本港台开奖现场直播 j2开奖直播报码现场
当前位置: 新闻频道 > IT新闻 >

wzatv:【j2开奖】学界 | 最小二乘GAN:比常规GAN更稳定,比WGAN收敛更迅速

时间:2017-03-05 18:41来源:118图库 作者:本港台直播 点击:
近来 GAN 证明是十分强大的。因为当真实数据的概率分布不可算时,传统生成模型无法直接应用,而 GAN 能以对抗的性质逼近概率分布。但其也有很大的限制,因为函数饱和过快,当判

近来 GAN 证明是十分强大的。因为当真实数据的概率分布不可算时,开奖,传统生成模型无法直接应用,而 GAN 能以对抗的性质逼近概率分布。但其也有很大的限制,因为函数饱和过快,当判别器越好时,生成器的消失也就越严重。所以不论是 WGAN 还是本文中的 LSGAN 都是试图使用不同的距离度量,从而构建一个不仅稳定,同时还收敛迅速的生成对抗网络。

项目地址:

由于生成对抗网络训练的一般框架 F-GAN 已经构建了起来,最近我们可以看到一些并不像常规 GAN 的修订版生成对抗网络,它们会学习使用其它度量方法,而不只是 Jensen-Shannon 散度 (Jensen-Shannon divergence/JSD)。

其中一个修订版就是 Wasserstein 生成对抗网络(WGAN),该生成网络使用 Wasserstein 距离度量而不是 JSD。Wasserstein GAN 运行十分流畅,甚至其作者都声称该系统已经克服了模型崩溃难题并给生成对抗提供了十分强大的损失函数。尽管 Wasserstein GAN 的实现是很直接的,但在 WGAN 背后的理论是十分困难并需要一些如权重剪枝(weight clipping)等「hack」知识。另外 WGAN 的训练过程和收敛都要比常规 GAN 要慢一点。

现在,问题是:我们能设计一个比 WGAN 运行得更稳定、收敛更快速、流程更简单更直接的生成对抗网络吗?我们的答案是肯定的!

最小二乘生成对抗网络

LSGAN 的主要思想就是在辨别器 D 中使用更加平滑和非饱和(non-saturating)梯度的损失函数。我们想要辨别器(discriminator)D 将生成器(generator)G 所生成的数据「拖」到真实数据流形(data manifold)Pdata(X),从而使得生成器 G 生成类似 Pdata(X) 的数据。

我们知道在常规 GAN 中,辨别器使用的是对数损失(log loss.)。而对数损失的决策边界就如下图所示:

  

wzatv:【j2开奖】学界 | 最小二乘GAN:比常规GAN更稳定,比WGAN收敛更迅速

因为辨别器 D 使用的是 sigmoid 函数,并且由于 sigmoid 函数饱和得十分迅速直播,所以即使是十分小的数据点 x,该函数也会迅速忽略 x 到决策边界 w 的距离。这也就意味着 sigmoid 函数本质上不会惩罚远离 w 的 x。这也就说明我们满足于将 x 标注正确,因此随着 x 变得越来越大,辨别器 D 的梯度就会很快地下降到 0。因此对数损失并不关心距离,它仅仅关注于是否正确分类。

为了学习 Pdata(X) 的流形(manifold),对数损失(log loss)就不再有效了。由于生成器 G 是使用辨别器 D 的梯度进行训练的,那么如果辨别器的梯度很快就饱和到 0,生成器 G 就不能获取足够学习 Pdata(X) 所需要的信息。

输入 L2 损失(L2 loss):

  

wzatv:【j2开奖】学界 | 最小二乘GAN:比常规GAN更稳定,比WGAN收敛更迅速

在 L2 损失(L2 loss)中,与 w(即上例图中 Pdata(X) 的回归线)相当远的数据将会获得与距离成比例的惩罚。因此梯度就只有在 w 完全拟合所有数据 x 的情况下才为 0。如果生成器 G 没有没有捕获数据流形(data manifold),那么这将能确保辨别器 D 服从多信息梯度(informative gradients)。

在优化过程中,辨别器 D 的 L2 损失想要减小的唯一方法就是使得生成器 G 生成的 x 尽可能地接近 w。只有这样,生成器 G 才能学会匹配 Pdata(X)。

最小二乘生成对抗网络(LSGAN)的整体训练目标可以用以下方程式表达:

  

wzatv:【j2开奖】学界 | 最小二乘GAN:比常规GAN更稳定,比WGAN收敛更迅速

在上面方程式中,我们选择 b=1 表明它为真实的数据,a=0 表明其为伪造数据。最后 c=1 表明我们想欺骗辨别器 D。

但是这些值并不是唯一有效的值。LSGAN 作者提供了一些优化上述损失的理论,即如果 b-c=1 并且 b-a=2,那么优化上述损失就等同于最小化 Pearson χ^2 散度(Pearson χ^2 divergence)。因此,选择 a=-1、b=1 和 c=0 也是同样有效的。

我们最终的训练目标就是以下方程式所表达的:

  

在 Pytorch 中 LSGAN 的实现

先将我们对常规生成对抗网络的修订给写出来:

1. 从辨别器 D 中移除对数损失

2. 使用 L2 损失代替对数损失

所以现在先让我们从第一个检查表(checklist)开始

G = torch.nn.Sequential(

torch.nn.Linear(z_dim, h_dim),

torch.nn.ReLU(),

torch.nn.Linear(h_dim, X_dim),

torch.nn.Sigmoid()

)

D = torch.nn.Sequential(

torch.nn.Linear(X_dim, h_dim),

torch.nn.ReLU(),

# No sigmoid

torch.nn.Linear(h_dim, 1),

)

G_solver = optim.Adam(G.parameters(), lr=lr)

D_solver = optim.Adam(D.parameters(), lr=lr)

剩下的就十分简单直接了,跟着上面的损失函数做就行。

for it in range(1000000):

# Sample data

z = Variable(torch.randn(mb_size, z_dim))

X, _ = mnist.train.next_batch(mb_size)

X = Variable(torch.from_numpy(X))

# Dicriminator

G_sample = G(z)

D_real = D(X)

D_fake = D(G_sample)

# Discriminator loss

D_loss = 0.5 * (torch.mean((D_real - 1)**2) + torch.mean(D_fake**2))

D_loss.backward()

D_solver.step()

reset_grad()

# Generator

G_sample = G(z)

D_fake = D(G_sample)

# Generator loss

G_loss = 0.5 * torch.mean((D_fake - 1)**2)

G_loss.backward()

G_solver.step()

reset_grad()

完整的代可以在此获得:https://github.com/wiseodd/generative-models

结语

(责任编辑:本港台直播)
顶一下
(0)
0%
踩一下
(0)
0%
------分隔线----------------------------
栏目列表
推荐内容