7

WGAN 笔记

 3 years ago
source link: https://lotabout.me/2018/WGAN/
Go to the source link to view the article. You can view the picture content, updated content and better typesetting reading experience. If the link is broken, please click the button below to view the snapshot at that time.
neoserver,ios ssh client
Table of Contents

Wasserstein GAN(WGAN) 解决传统 GAN 的训练难,训练过程不稳定等问题了。WGAN 的背后有强劲的数学支撑,因此要想理解这它的原理,需要理解许多数学公式的推导。这个笔记尽量尝试从直觉的角度来理解 WGAN 背后的原理。

GAN 的问题

我们知道,GAN 的目的是训练一个生成器 G,使生成的数据的分布 PG 与真实数据的分布 Pdata 尽可能接近。为了衡量接近程度,GAN 使用 JS Divergence来衡量。

从应用的角度,我们甚至不需要知道它是什么,我们只要知道,对于两个分布 Pr 和 Pg,如果它们不重合或重合的部分可以忽略,则它们的 JS 距离 JS(Pr,Pg)=log2 是常数,用梯度下降时,产生的梯度会(近似)为 0。而在 GAN 的训练中,两个分布不重合或重合可忽略的情况 几乎总是出现,因此导致 GAN 的训练中

Wasserstein GAN

依旧地,我们甚至不需要知道 Wasserstein Distance 是什么,只需要知道它有着很好的性质,两个分布的差异都会反应在 Wasserstein Distance 上,因此,不会出现梯度消失的问题。

现在的问题是怎么计算它?答曰无法计算,但在 Wasserstein GAN 论文里证明了如下的事实:

W(Pdata,PG)=maxD∈1-Lipschitz{Ex∼Pdata[D(x)]−Ex∼PG[D(x)]}

在接下去之前我们先说说什么是 1-Lipschitz。如果一个函数 f 满足下面式子:

||f(x1)−f(x2)||≤K||x1−x2||

我们就称它为 K-Lipschitz,当 K=1时,就是 1-Lipschitz。

在图像生成的 GAN 中,上式中的 D(x) 可以认为是以图像为输入,输出图像的质量(是否接近真实图像)。那么我们可以找到两种类型的 D,一类变化剧烈,即赋予真实图像很大的值,而其它图像的值就很小(下图蓝色);另一类则变化平缓(下图绿色)。相像一下,如果用变化剧烈的 D 作为判别器去训练生成器,则会倾向于生成和真实图像一模一样的图片,导致多样性不高。而 1-Lipschitz 的作用就是限制 D 的变化要更平缓一些,是符合直觉的。

于是我们现在的目标是找到一个函数 D 满足 1-Lipschitz 且让上面的式子最大。“最大化”倒是好说,我们不断用梯度上升,但怎么保证我们的判别器 D 满足 1-Lipschitz 呢?还是没有办法,但我们可以做一些 workaround。

Weight Clipping

对于神经网络中的所有权重,在更新梯度后,我们事先选中某个常数 c, 做下面的操作:

  • 如果权重 w>c,则赋值 w←c
  • 如果权重 w<−c,则赋值 w←−c

直觉上,如果神经网络的权重都限制在一定的范围内,那么网络的输出也会被限定在一定范围内。换句话说,这个网络会属于某个 K-Lipschitz。当然,我们并不确定K 是多少,并且这样的函数也不一定能使 Ex∼Pdata[D(x)]−Ex∼PG[D(x)] 最大化。

不管怎么说吧,这就是原版 WGAN 的方法,对 GAN 的具大提升。

Gradient Penalty

新版的 WGAN 提出了不用 weight clipping,而用加惩罚项的方式,我们去优化下面这个目标:

W(Pdata,PG)=maxD{Ex∼Pdata[D(x)]−Ex∼PG[D(x)]−λ∫xmax(0,||∇xD(x)||−1)dx⏟regularization}

为什么呢?因为如果 D∈1-Lipschitz,显然对于所有 x,我们有 ||∇xD(x)||≤1。但同之前一样,我们无法穷举所有 x 求积分,于是我们又用期望来近似它,于是有:

W(Pdata,PG)=maxD{Ex∼Pdata[D(x)]−Ex∼PG[D(x)]−λEx∼Ppenalty[max(0,||∇xD(x)||−1)]⏟regularization}

那这里的 Ppenalty 又是什么?它代表的是输入 x 的分布,那具体如何采样呢?新版 WGAN 是这样设计的:

  1. 从真实数据 Pdata 中采样得到一个点
  2. 从生成器生成的数据 PG 中采样得到一个点
  3. 为这两个点连线
  4. 在线上随机采样得到一个点作为 Ppenalty 的点。

How to sample P_penalty

为什么这么采样?直觉上,我们想将 PG “拉”向 Pdata,于是希望 D 在它们之间的这些数据上能更平缓地变化。而惩罚项就是为了保证 D “平缓变化”的,于是正则项中的 Ppenalty 就在这些数据点上进行采样。

最后,实际中我们其实并不是用 max(0,||∇xD(x)||−1) 这个惩罚项,而是用 (||∇xD(x)||−1)2。也就是说,我们惩罚的目的不是让 ||∇xD(x)|| 尽可能小于1,而是要让它尽可能 等于 1。想象一个完美的判别器 D 满足优化的目标,则在 Pdata 附近它要尽可能大,而在 PG 附近要尽可能小,也就是说 D 越斜越好,但由于 ||∇xD(x)||≤1,那么 ||∇xD(x)|| 只能是 1。所以,真正的优化目标如下:

W(Pdata,PG)=maxD{Ex∼Pdata[D(x)]−Ex∼PG[D(x)]−λEx∼Ppenalty[(||∇xD(x)||−1)2]}

GAN 的优化目标是 JS Divergence,它有许多缺点不利于 GAN 的训练。Wasserstein Distance 是一个更好的距离度量,它最终可以转化为优化问题,我们需要找出一个判别器 D,并要求它满足 1-Lipschitz。实际使用时我们并做不到这一点,于是有两种方法来近似:weight clipping 和 gradient penalty。


About Joyk


Aggregate valuable and interesting links.
Joyk means Joy of geeK