我们讨论了如何构建生成器,但是为了完成网络的对抗训练部分,还需要构建鉴别器。鉴别器将一张图像作为输入,并尝试预测其为原始图像或是生成器的输出图像。生成器的结构如下所示:
图8:生成器的结构 鉴别器本身就属于卷积网络,需要从图像中提取特征。 o_c1 = general_conv2d(input_disc, ndf, f, f, 2, 2)o_c2 = general_conv2d(o_c1, ndf* 2, f, f, 2, 2)o_enc_A = general_conv2d(o_c2, ndf* 4, f, f, 2, 2)o_c4 = general_conv2d(o_enc_A, ndf* 8, f, f, 2, 2) 下一步是确定这些特征是否属于该特定类别,添加一个产生1维输出的卷积层来完成这个任务。这里,ndf表示鉴别器初始层的特征个数,可以尝试调整来获得最佳效果。 decision = general_conv2d(o_c4, 1, f, f, 1, 1, 0.02) 我们已经完成该模型的两个主要组成部分,即生成器和鉴别器。由于要使这个模型既可以从A→B和B→A两个方向工作,我们设置了两个生成器,直播,即生成器A→B和生成器B→A,以及两个鉴别器,即鉴别器A和鉴别器B。 建立模型 在定义损失函数前,先定义基础输入变量,来构建模型。 input_A = tf.placeholder(tf.float32, [batch_size, img_width, img_height, img_layer], name= "input_A")input_B = tf.placeholder(tf.float32, [batch_size, img_width, img_height, img_layer], name= "input_B") 这些占位符将作为输入,同时定义模型如下: gen_B = build_generator(input_A, name= "generator_AtoB")gen_A = build_generator(input_B, name= "generator_BtoA")dec_A = build_discriminator(input_A, name= "discriminator_A")dec_B = build_discriminator(input_B, name= "discriminator_B")dec_gen_A = build_discriminator(gen_A, "discriminator_A")dec_gen_B = build_discriminator(gen_B, "discriminator_B")cyc_A = build_generator(gen_B, "generator_BtoA")cyc_B = build_generator(gen_A, "generator_AtoB") 上面的变量名在本质上是非常直观的。gen表示使用相应的生成器后生成的图像,dec表示在将相应输入传递到鉴别器后做出的判断。 损失函数 现在我们有两个生成器和两个鉴别器。我们要按照实际目的来设计损失函数。损失函数应该包括如下四个部分: 鉴别器必须允许所有相应类别的原始图像,即对应输出置1; 鉴别器必须拒绝所有想要愚弄过关的生成图像,即对应输出置0; 生成器必须使鉴别器允许通过所有的生成图像,来实现愚弄操作; 所生成的图像必须保留有原始图像的特性,所以如果我们使用生成器GeneratorA→B生成一张假图像,那么要能够使用另一个生成器GeneratorB→A来努力恢复成原始图像。此过程必须满足循环一致性。 鉴别器损失 第1部分 我们通过训练鉴别器,使其对A类图像的输出接近于1,鉴别器B也是如此。鉴别器A的训练目标为最小化“(DiscriminatorA(a)?1)2”的值,鉴别器B也是如此。对应代码如下: D_A_loss_1 = tf.reduce_mean(tf.squared_difference(dec_A, 1))D_B_loss_1 = tf.reduce_mean(tf.squared_difference(dec_B, 1)) 第2部分 由于鉴别器应该能够区分生成图像和原始图像,所以在处理生成图像时期望输出为0,即鉴别器A要最小化“(DiscriminatorA(GeneratorB→A(b)))2”的值。对应代码如下: D_A_loss_2 = tf.reduce_mean(tf.square(dec_gen_A))D_B_loss_2 = tf.reduce_mean(tf.square(dec_gen_B))D_A_loss = (D_A_loss_1 + D_A_loss_2)/ 2D_B_loss = (D_B_loss_1 + D_B_loss_2)/ 2 生成器损失 (责任编辑:本港台直播) |