本港台开奖现场直播 j2开奖直播报码现场
当前位置: 新闻频道 > IT新闻 >

码报:【组图】带你理解CycleGAN,并用TensorFlow轻松实现(2)

时间:2017-05-27 22:44来源:本港台直播 作者:118KJ 点击:
但是,我们在不配对的数据集中没有这个对象,也没有预先定义好的用于学习的有意义转换,所以我们将要创建它。我们需要确保输入图像和生成图像之间

  但是,我们在不配对的数据集中没有这个对象,也没有预先定义好的用于学习的有意义转换,所以我们将要创建它。我们需要确保输入图像和生成图像之间存在一些有意义的关联。

  所以,作者试图通过生成器将输入图像(inputA)从域DA映射到目标域DB中,转换成对应图像。但是为了确保这些图像之间存在有意义的关系,它们必须共享一些特征,这些特征可用于将此输出图像映射回输入图像,因此必须有另一个生成器能将此输出图像映射回原始域。因此,我们需要定义inputA和genB之间有意义的映射。

  简而言之,该模型通过从域DA获取输入图像,该输入图像被传递到第一个生成器GeneratorA→B,其任务是将来自域DA的给定图像转换到目标域DB中的图像。然后这个新生成的图像被传递到另一个生成器GeneratorB→A,其任务是在原始域DA转换回图像CyclicA,这里可与自动编器作对比。

  正如上面讨论的,这个输出图像必须与原始输入图像相似,用来定义非配对数据集中原来不存在的有意义映射。

  如图5所示,两个输入被传递到对应的鉴别器(一个是对应于该域的原始图像,另一个是通过生成器产生的图像),并且鉴别器的任务是区分它们,识别出生成器输出的生成图像,并拒绝此生成图像。生成器想要确保这些图像被鉴别器接受,所以它将尝试生成与DB类中原始图像非常接近的新图像。事实上,在生成器分布与所需分布相同时,生成器和鉴别器之间实现了纳什均衡(Nash equilibrium)。

  我们可以通过TensorFlow轻松实现CycleGAN,下面将介绍CycleGAN各部分的实现细节,可在GitHub上找到完整代

  构建生成器

  生成器的结构已在下图列出。

  

码报:【j2开奖】带你理解CycleGAN,并用TensorFlow轻松实现

  图6:生成器结构

  生成器由三个部分组成:编码器、转换器和解码器。

  该生成器的超参数定义如下,包括卷积核个数、批数量、池化大小和输入图像的格式:

  ngf = 32# Number of filters in first layer of generatorndf = 64# Number of filters in first layer of discriminatorbatch_size = 1# batch_sizepool_size = 50# pool_sizeimg_width = 256# Imput image will of width 256img_height = 256# Input image will be of height 256img_depth = 3# RGB format

  前三个参数简单易懂,我们将在生成图像库部分中解释pool_size的含义。

  编码

  为了简单起见,在此文章中我们把输入大小固定设置为[256,256,3]。第一步是利用卷积网络从输入图像中提取特征。要了解有关卷积网络的基础知识,你可以查看文末的CNN介绍链接。卷积网络将一张图像作为输入,不同大小的卷积核能在输入图像上移动并提取特征,步幅(stride)大小能决定在图像中卷积核窗口的数量。所以编码器的第一层定义如下:

  o_c1 = general_conv2d(input_gen, num_features=ngf, window_width= 7, window_height= 7, stride_width= 1, stride_height= 1)

  其中,input_gen是生成器的输入图像,num_features是在卷积层中卷积得到的特征图谱数量,也可以看作是提取不同特征的滤波器数量。window_width和window_height表示在输入图像上滑动来提取特征的滤波器窗口大小。类似地,stride_width和stride_height定义了每次迭代后滤波器的移位方式。输出Oc1是尺寸为[256,256,64]的张量,继续传输给下个卷积层。这里,鉴别器第一层的滤波器个数设置为64,完成对general_conv2d函数的定义。当然可以添加其他层,如ReLU层或批归一化层(BN层),在本教程中跳过这些层的介绍。

  defgeneral_conv2d(inputconv, o_d=64, f_h=7, f_w=7, s_h=1, s_w=1):withtf.variable_scope(name): conv = tf.contrib.layers.conv2d(inputconv, num_features, [window_width, window_height], [stride_width, stride_height], padding, activation_fn= None, weights_initializer=tf.truncated_normal_initializer(stddev=stddev), biases_initializer=tf.constant_initializer( 0.0))

  接下来:

  o_c2 = general_conv2d(o_c1, num_features= 64* 2, window_width= 3, window_height= 3, stride_width= 2, stride_height= 2) # o_c2.shape = (128, 128, 128)o_enc_A = general_conv2d(o_c2, num_features= 64* 4, window_width= 3, window_height= 3, stride_width= 2, stride_height= 2) # o_enc_A.shape = (64, 64, 256)

(责任编辑:本港台直播)
顶一下
(0)
0%
踩一下
(0)
0%
------分隔线----------------------------
栏目列表
推荐内容