本港台开奖现场直播 j2开奖直播报码现场
当前位置: 新闻频道 > IT新闻 >

码报:【组图】带你理解CycleGAN,并用TensorFlow轻松实现(5)

时间:2017-05-27 22:44来源:本港台直播 作者:118KJ 点击:
最终生成器应该能够提高鉴别器对生成图像的输出值。如果鉴别器对生成图像的输出值尽可能接近1,则生成器的作用达到。故生成器想要最小化“(Discri

  最终生成器应该能够提高鉴别器对生成图像的输出值。如果鉴别器对生成图像的输出值尽可能接近1,则生成器的作用达到。故生成器想要最小化“(DiscriminatorB(GeneratorA→B(a))?1)2”,因此损失为:

  g_loss_B_1 = tf.reduce_mean(tf.squared_difference(dec_gen_A, 1))g_loss_A_1 = tf.reduce_mean(tf.squared_difference(dec_gen_A, 1))

  循环损失

  最后一个重要参数为循环丢失(cyclic loss),能判断用另一个生成器得到的生成图像与原始图像的差别。因此原始图像和循环图像之间的差异应该尽可能小。

  cyc_loss = tf.reduce_mean(tf.abs(input_A-cyc_A)) + tf.reduce_mean(tf.abs(input_B-cyc_B))

  所以完整的生成器损失为:

  g_loss_A = g_loss_A_1 + 10*cyc_lossg_loss_B = g_loss_B_1 + 10*cyc_loss

  cyc_loss的乘法因子设置为10,说明循环损失比鉴别损失更重要。

  混合参数

  定义好损失函数,接下来只需要训练模型来最小化损失函数。

  d_A_trainer = optimizer.minimize(d_loss_A, var_list=d_A_vars)d_B_trainer = optimizer.minimize(d_loss_B, var_list=d_B_vars)g_A_trainer = optimizer.minimize(g_loss_A, var_list=g_A_vars)g_B_trainer = optimizer.minimize(g_loss_B, var_list=g_B_vars) 训练模型 forepoch inrange( 0, 100): # Define the learning rate schedule. The learning rate is kept# constant upto 100 epochs and then slowly decayedif(epoch < 100) : curr_lr = 0.0002else: curr_lr = 0.0002- 0.0002*(epoch- 100)/ 100# Running the training loop for all batchesforptr inrange( 0,num_images): # Train generator G_A->B_, gen_B_temp = sess.run([g_A_trainer, gen_B], feed_dict={input_A:A_input[ptr], input_B:B_input[ptr], lr:curr_lr}) # We need gen_B_temp because to calculate the error in training D_B_ = sess.run([d_B_trainer], feed_dict={input_A:A_input[ptr], input_B:B_input[ptr], lr:curr_lr}) # Same for G_B->A and D_A as follow_, gen_A_temp = sess.run([g_B_trainer, gen_A], feed_dict={input_A:A_input[ptr], input_B:B_input[ptr], lr:curr_lr}) _ = sess.run([d_A_trainer], feed_dict={input_A:A_input[ptr], input_B:B_input[ptr], lr:curr_lr})

  你可以在训练函数中看到,在训练时需要不断调用不同鉴别器和生成器。为了训练模型,需要输入训练图像和选择优化器的学习率。由于batch_size设置为1,所以num_batches等于num_images。

  我们已经完成了模型构建,下面是模型中一些默认超参数。

  生成图像库

  计算每个生成图像的鉴别器损失是不可能的,因为会耗费大量的计算资源。为了加快训练,我们存储了之前每个域的所有生成图像,并且每次仅使用一张图像来计算误差。首先,逐个填充图像库使其完整,然后随机将某个库中的图像替换为最新的生成图像,并使用这个替换图像来作为该步的训练。

  defimage_pool(self, num_gen, gen_img, gen_pool):if(num_gen < pool_size): gen_img_pool[num_gen] = gen_img returngen_img else: p = random.random() ifp > 0.5: # Randomly selecting an id to return for calculating the discriminator lossrandom_id = random.randint( 0,pool_size- 1) temp = gen_img_pool[random_id] gen_pool[random_id] = gen_img returntemp else: returngen_img gen_image_pool_A = tf.placeholder(tf.float32, [batch_size, img_width, img_height, img_layer], name= "gen_img_pool_A")gen_image_pool_B = tf.placeholder(tf.float32, [batch_size, img_width, img_height, img_layer], name= "gen_img_pool_B")gen_pool_rec_A = build_gen_discriminator(gen_image_pool_A, "d_A")gen_pool_rec_B = build_gen_discriminator(gen_image_pool_B, "d_B") # Also the discriminator loss will change as followD_A_loss_2 = tf.reduce_mean(tf.square(gen_pool_rec_A))D_A_loss_2 = tf.reduce_mean(tf.square(gen_pool_rec_A))

  图像库代码仍需要微小的修改,完整代码见文末。

  结果

  我们运行了野马转斑马的模型,但是由于缺乏图像库,该模型只运行了100步,得到以下结果。

  

码报:【j2开奖】带你理解CycleGAN,并用TensorFlow轻松实现

  图9:野马转斑马的实际效果 讨论

  1.在训练时,我们发现初始化很大程度影响了输出结果,因此通过多次训练来获得最佳效果。你会发现图10中特殊的背景颜色,这个效果只有在10-20步的训练时才能观察到,你可以再运行代码试试。

  

码报:【j2开奖】带你理解CycleGAN,并用TensorFlow轻松实现

  图10:该模型出现失真效果

(责任编辑:本港台直播)
顶一下
(0)
0%
踩一下
(0)
0%
------分隔线----------------------------
栏目列表
推荐内容