卷积层越往上,需要增加高层特征的数量。我们将图像压缩成256个尺寸大小64×64的特征向量,接着将DA域中图像的特征向量转换为DB域中图像的特征向量。 总而言之,我们将DA域中一个尺寸为[256,256,3]的图像,输入到设计的编码器中,获得了尺寸为[64,64,256]的输出OAenc。 转换 这些网络层的作用是组合图像的不同相近特征,然后基于这些特征,确定如何将图像的特征向量OAenc从DA域转换为DB域的特征向量。因此,作者使用了6层Resnet模块: o_r1 = build_resnet_block(o_enc_A, num_features= 64* 4)o_r2 = build_resnet_block(o_r1, num_features= 64* 4)o_r3 = build_resnet_block(o_r2, num_features= 64* 4)o_r4 = build_resnet_block(o_r3, num_features= 64* 4)o_r5 = build_resnet_block(o_r4, num_features= 64* 4)o_enc_B = build_resnet_block(o_r5, num_features= 64* 4) # o_enc_B.shape = (64, 64, 256) 这里OBenc表示该层的最终输出,尺寸为[64,64,256],这可以看作是DB域中图像的特征向量。 你一定很想知道build_resnet_block函数的内容及作用。build_resnet_block是一个由两个卷积层组成的神经网络层,其中部分输入数据直接添加到输出。这样做是为了确保先前网络层的输入数据信息直接作用于后面的网络层,使得相应输出与原始输入的偏差缩小,否则原始图像的特征将不会保留在输出中且输出结果会偏离目标轮廓。在上面也提到,这个任务的一个主要目标是保留原始图像的特征,如目标的大小和形状,因此残差网络非常适合完成这些转换。Resnet模块的结构如下所示:
图7:Resnet模块的结构 Resnet模块的代码如下: defresnet_blocks(input_res, num_features):out_res_1 = general_conv2d(input_res, num_features, window_width= 3, window_heigth= 3, stride_width= 1, stride_heigth= 1) out_res_2 = general_conv2d(out_res_1, num_features, window_width= 3, window_heigth= 3, stride_width= 1, stride_heigth= 1) return(out_res_2 + input_res) 解码 到目前为止,我们已经将特征向量OAenc传递到转换层,得到了另一个大小为[64,64,256]的特征向量OBenc。 解码过程与编码方式完全相反,从特征向量中还原出低级特征,开奖,这是利用了反卷积层(deconvolution)来完成的。 o_d1 = general_deconv2d(o_enc_B, num_features=ngf* 2window_width= 3, window_height= 3, stride_width= 2, stride_height= 2)o_d2 = general_deconv2d(o_d1, num_features=ngf, window_width= 3, window_height= 3, stride_width= 2, stride_height= 2) 最后,我们将这些低级特征转换得到一张在DB域中的图像,代码如下所示: gen_B = general_conv2d(o_d2, num_features= 3, window_width= 7, window_height= 7, stride_width= 1, stride_height= 1) 最后,我们得到了一个大小为[256,256,3]的生成图像genB,构建生成器的代码可以用如下函数实现: defbuild_generator(input_gen):o_c1 = general_conv2d(input_gen, num_features=ngf, window_width= 7, window_height= 7, stride_width= 1, stride_height= 1) o_c2 = general_conv2d(o_c1, num_features=ngf* 2, window_width= 3, window_height= 3, stride_width= 2, stride_height= 2) o_enc_A = general_conv2d(o_c2, num_features=ngf* 4, window_width= 3, window_height= 3, stride_width= 2, stride_height= 2) # Transformationo_r1 = build_resnet_block(o_enc_A, num_features= 64* 4) o_r2 = build_resnet_block(o_r1, num_features= 64* 4) o_r3 = build_resnet_block(o_r2, num_features= 64* 4) o_r4 = build_resnet_block(o_r3, num_features= 64* 4) o_r5 = build_resnet_block(o_r4, num_features= 64* 4) o_enc_B = build_resnet_block(o_r5, num_features= 64* 4) #Decodingo_d1 = general_deconv2d(o_enc_B, num_features=ngf* 2window_width= 3, window_height= 3, stride_width= 2, stride_height= 2) o_d2 = general_deconv2d(o_d1, num_features=ngf, window_width= 3, window_height= 3, stride_width= 2, stride_height= 2) gen_B = general_conv2d(o_d2, num_features= 3, window_width= 7, window_height= 7, stride_width= 1, stride_height= 1) returngen_B 构建鉴别器 (责任编辑:本港台直播) |