本港台开奖现场直播 j2开奖直播报码现场
当前位置: 新闻频道 > IT新闻 >

码报:【组图】带你理解CycleGAN,并用TensorFlow轻松实现(4)

时间:2017-05-27 22:44来源:本港台直播 作者:118KJ 点击:
我们讨论了如何构建生成器,但是为了完成网络的对抗训练部分,还需要构建鉴别器。鉴别器将一张图像作为输入,并尝试预测其为原始图像或是生成器的

  我们讨论了如何构建生成器,但是为了完成网络的对抗训练部分,还需要构建鉴别器。鉴别器将一张图像作为输入,并尝试预测其为原始图像或是生成器的输出图像。生成器的结构如下所示:

  

码报:【j2开奖】带你理解CycleGAN,并用TensorFlow轻松实现

  图8:生成器的结构

  鉴别器本身就属于卷积网络,需要从图像中提取特征。

  o_c1 = general_conv2d(input_disc, ndf, f, f, 2, 2)o_c2 = general_conv2d(o_c1, ndf* 2, f, f, 2, 2)o_enc_A = general_conv2d(o_c2, ndf* 4, f, f, 2, 2)o_c4 = general_conv2d(o_enc_A, ndf* 8, f, f, 2, 2)

  下一步是确定这些特征是否属于该特定类别,添加一个产生1维输出的卷积层来完成这个任务。这里,ndf表示鉴别器初始层的特征个数,可以尝试调整来获得最佳效果。

  decision = general_conv2d(o_c4, 1, f, f, 1, 1, 0.02)

  我们已经完成该模型的两个主要组成部分,即生成器和鉴别器。由于要使这个模型既可以从A→B和B→A两个方向工作,我们设置了两个生成器,直播,即生成器A→B和生成器B→A,以及两个鉴别器,即鉴别器A和鉴别器B。

  建立模型

  在定义损失函数前,先定义基础输入变量,来构建模型。

  input_A = tf.placeholder(tf.float32, [batch_size, img_width, img_height, img_layer], name= "input_A")input_B = tf.placeholder(tf.float32, [batch_size, img_width, img_height, img_layer], name= "input_B")

  这些占位符将作为输入,同时定义模型如下:

  gen_B = build_generator(input_A, name= "generator_AtoB")gen_A = build_generator(input_B, name= "generator_BtoA")dec_A = build_discriminator(input_A, name= "discriminator_A")dec_B = build_discriminator(input_B, name= "discriminator_B")dec_gen_A = build_discriminator(gen_A, "discriminator_A")dec_gen_B = build_discriminator(gen_B, "discriminator_B")cyc_A = build_generator(gen_B, "generator_BtoA")cyc_B = build_generator(gen_A, "generator_AtoB")

  上面的变量名在本质上是非常直观的。gen表示使用相应的生成器后生成的图像,dec表示在将相应输入传递到鉴别器后做出的判断。

  损失函数

  现在我们有两个生成器和两个鉴别器。我们要按照实际目的来设计损失函数。损失函数应该包括如下四个部分:

鉴别器必须允许所有相应类别的原始图像,即对应输出置1;

鉴别器必须拒绝所有想要愚弄过关的生成图像,即对应输出置0;

生成器必须使鉴别器允许通过所有的生成图像,来实现愚弄操作;

所生成的图像必须保留有原始图像的特性,所以如果我们使用生成器GeneratorA→B生成一张假图像,那么要能够使用另一个生成器GeneratorB→A来努力恢复成原始图像。此过程必须满足循环一致性。

  鉴别器损失

  第1部分

  我们通过训练鉴别器,使其对A类图像的输出接近于1,鉴别器B也是如此。鉴别器A的训练目标为最小化“(DiscriminatorA(a)?1)2”的值,鉴别器B也是如此。对应代码如下:

  D_A_loss_1 = tf.reduce_mean(tf.squared_difference(dec_A, 1))D_B_loss_1 = tf.reduce_mean(tf.squared_difference(dec_B, 1))

  第2部分

  由于鉴别器应该能够区分生成图像和原始图像,所以在处理生成图像时期望输出为0,即鉴别器A要最小化“(DiscriminatorA(GeneratorB→A(b)))2”的值。对应代码如下:

  D_A_loss_2 = tf.reduce_mean(tf.square(dec_gen_A))D_B_loss_2 = tf.reduce_mean(tf.square(dec_gen_B))D_A_loss = (D_A_loss_1 + D_A_loss_2)/ 2D_B_loss = (D_B_loss_1 + D_B_loss_2)/ 2

  生成器损失

(责任编辑:本港台直播)
顶一下
(0)
0%
踩一下
(0)
0%
------分隔线----------------------------
栏目列表
推荐内容