但是,我们在不配对的数据集中没有这个对象,也没有预先定义好的用于学习的有意义转换,所以我们将要创建它。我们需要确保输入图像和生成图像之间存在一些有意义的关联。 所以,作者试图通过生成器将输入图像(inputA)从域DA映射到目标域DB中,转换成对应图像。但是为了确保这些图像之间存在有意义的关系,它们必须共享一些特征,这些特征可用于将此输出图像映射回输入图像,因此必须有另一个生成器能将此输出图像映射回原始域。因此,我们需要定义inputA和genB之间有意义的映射。 简而言之,该模型通过从域DA获取输入图像,该输入图像被传递到第一个生成器GeneratorA→B,其任务是将来自域DA的给定图像转换到目标域DB中的图像。然后这个新生成的图像被传递到另一个生成器GeneratorB→A,其任务是在原始域DA转换回图像CyclicA,这里可与自动编码器作对比。 正如上面讨论的,这个输出图像必须与原始输入图像相似,用来定义非配对数据集中原来不存在的有意义映射。 如图5所示,两个输入被传递到对应的鉴别器(一个是对应于该域的原始图像,另一个是通过生成器产生的图像),并且鉴别器的任务是区分它们,识别出生成器输出的生成图像,并拒绝此生成图像。生成器想要确保这些图像被鉴别器接受,所以它将尝试生成与DB类中原始图像非常接近的新图像。事实上,在生成器分布与所需分布相同时,生成器和鉴别器之间实现了纳什均衡(Nash equilibrium)。 我们可以通过TensorFlow轻松实现CycleGAN,下面将介绍CycleGAN各部分的实现细节,可在GitHub上找到完整代码。 构建生成器 生成器的结构已在下图列出。
图6:生成器结构 生成器由三个部分组成:编码器、转换器和解码器。 该生成器的超参数定义如下,包括卷积核个数、批数量、池化大小和输入图像的格式: ngf = 32# Number of filters in first layer of generatorndf = 64# Number of filters in first layer of discriminatorbatch_size = 1# batch_sizepool_size = 50# pool_sizeimg_width = 256# Imput image will of width 256img_height = 256# Input image will be of height 256img_depth = 3# RGB format 前三个参数简单易懂,我们将在生成图像库部分中解释pool_size的含义。 编码 为了简单起见,在此文章中我们把输入大小固定设置为[256,256,3]。第一步是利用卷积网络从输入图像中提取特征。要了解有关卷积网络的基础知识,你可以查看文末的CNN介绍链接。卷积网络将一张图像作为输入,不同大小的卷积核能在输入图像上移动并提取特征,步幅(stride)大小能决定在图像中卷积核窗口的数量。所以编码器的第一层定义如下: o_c1 = general_conv2d(input_gen, num_features=ngf, window_width= 7, window_height= 7, stride_width= 1, stride_height= 1) 其中,input_gen是生成器的输入图像,num_features是在卷积层中卷积得到的特征图谱数量,也可以看作是提取不同特征的滤波器数量。window_width和window_height表示在输入图像上滑动来提取特征的滤波器窗口大小。类似地,stride_width和stride_height定义了每次迭代后滤波器的移位方式。输出Oc1是尺寸为[256,256,64]的张量,继续传输给下个卷积层。这里,鉴别器第一层的滤波器个数设置为64,完成对general_conv2d函数的定义。当然可以添加其他层,如ReLU层或批归一化层(BN层),在本教程中跳过这些层的介绍。 defgeneral_conv2d(inputconv, o_d=64, f_h=7, f_w=7, s_h=1, s_w=1):withtf.variable_scope(name): conv = tf.contrib.layers.conv2d(inputconv, num_features, [window_width, window_height], [stride_width, stride_height], padding, activation_fn= None, weights_initializer=tf.truncated_normal_initializer(stddev=stddev), biases_initializer=tf.constant_initializer( 0.0)) 接下来: o_c2 = general_conv2d(o_c1, num_features= 64* 2, window_width= 3, window_height= 3, stride_width= 2, stride_height= 2) # o_c2.shape = (128, 128, 128)o_enc_A = general_conv2d(o_c2, num_features= 64* 4, window_width= 3, window_height= 3, stride_width= 2, stride_height= 2) # o_enc_A.shape = (64, 64, 256) (责任编辑:本港台直播) |