3

GAN[01]-生成对抗网络:从GAN到到WGAN

 2 years ago
source link: https://yerfor.github.io/2020/02/06/gan-01/
Go to the source link to view the article. You can view the picture content, updated content and better typesetting reading experience. If the link is broken, please click the button below to view the snapshot at that time.
neoserver,ios ssh client

GAN[01]-生成对抗网络:从GAN到到WGAN

发表于

2020-02-06 更新于 2020-03-09

生成对抗网络(Generative Adversarial Networks, GAN)是Ian J. Goodfellow在2014年提出的机器学习框架,广泛应用于生成任务中(目前有图像生成/AI换脸/风格迁移等应用),是近年来计算机视觉领域最热门的研究方向之一。

GAN包含两个模块:生成模型(Generative Model)和判别模型(Discriminative Model),生成模型用于生成数据(可以是图像、点云等任意高维数据,前提是要有对应的database,GAN并不能无中生有),判别模型用于判断接收的数据是真实数据还是由生成模型伪造的。生成模型和判别模型在相互博弈的过程中逐渐变强,可以生成以假乱真的图像。

本文主要内容有:

  • 介绍了生成对抗网络(GAN)的概念和具体算法
  • 指出了传统GAN用JS散度作为度量的缺陷
  • 介绍了WGAN的原理,将GAN与WGAN的效果进行了对比,
  • 提供了WGAN-GP的具体实现代码

1. 生成对抗网络(GAN)介绍

1.1 Generator和Discriminator

  • Generator和Discriminator都是神经网络。
  • Generator网络的输入通常是一个来自于高斯分布的噪声向量,在网络中经过一系列上采样,输出一张图片或其他高维数据,下图是用于图像生成和文本生成时的Generator输入输出结构。
  • 可以把Generator生成图片的过程视作:
    • 输入一个高斯分布,Generator将该分布扭曲变形成一个和Database差不多形状的分布,生成的图片都是这个分布的一个采样。

0

  • Discriminator网络的输入是和Generator输出同类型的数据,如下图所示,当完成图像生成任务时,Discriminator的输入是图片,它的任务是判断当前输入的图片是真实的图片还是Generator伪造的。Discriminator的输出是[0,1]的标量,输出越接近1代表iscriminator认为当前输入的图片Discriminator认为当前输入的图片越像真实的图像。

1

  • 在训练的每个迭代的步骤中,都会先把真实图像和Generator生成的图像都传给Discriminator,Discriminator要学会分辨真实图像和伪造的图像(相当于一个简单的二分类问题);随后Generator会朝增大Discriminator给自己伪造的图片打的分数的方向更新。
  • 在训练过程中,Generator和Discriminator都会不断变强。Discriminator会发现越来越多真实图像具有的特征,具有越来越强的辨真伪的能力;而Generator要骗过Discriminator,伪造图像的能力也会不断增强,直到在人类眼里可以以假乱真。

1.2 算法流程

  • 随机初始化Generator和Discriminator的参数
  • 对每个迭代步骤:
    • 从数据集中随机抽取m个样本{$x_1,…,x_m$}
    • 从高斯分布中随机抽取m个噪声{$z_1,…,z_m$}
    • 将噪声全部输入到Generator网络中,得到m个生成的图像{$x’_1,…,x’_m$},其中$x’_i=Generator(z_i)$
    • 将数据集中的m个样本和伪造的m个图像都输入到Discriminator中,用以下方式更新Discriminator:
      • 定义目标函数:$V=\frac{1}{m}\Sigma^{m}_{i=1}logD(x_i)+\frac{1}{m}\Sigma^{m}_{i=1}log(1-D(x’_i))$
        • 上式类似于分类问题中的交叉熵,最大化上式,意味着最大化对真实图像的评分$D(x_i)$,同时最小化对虚假图像的评分$D(x’_i)$。
      • 采用梯度法最大化目标函数:$\theta_d \leftarrow \theta_d + \eta\cdot\nabla V(\theta_d)$
      • 至此完成了对Discriminator的更新
    • 将之前生成的m张图像重新传入更新后的Discriminator网络中,得到m个评分:$D(G(z_i))$
    • 用以下方式更新Generator:
      • 定义目标函数:$V = -\frac{1}{m}\Sigma_{i=1}^{m}logD(G(z_i))$
        • 上式除了负号部分奖励了Generator伪造的图片“骗过了”Discriminator,得到高分的情况。加上负号后变成了惩罚,所以要最小化该式
      • 利用梯度法最小化目标函数:$\theta_g \leftarrow \theta_g - \eta\cdot\nabla V(\theta_g)$
      • 至此完成了对Generator的更新

2

1.3 最小化目标函数使Discriminator判别分布间的的JS散度

GAN的目标是生成和数据集相同风格,可以以假乱真的图像,从统计学的角度看,实际上就是要拟合出一个可以描述数据集图像的概率分布$P_{data}(x)$。注意$x$是Generator的输出,即图像。

3

如下图所示,Generator的输入$z$是一个高斯分布,经过Generator后https://github.com/yerfor/WGAN-TensorFlow2该分布被弯曲变形成下图所示的$P_G(x)$,在不断训练的过程中,$P_G(x)$会越来越趋同于$P_{data}(x)$。统计学中用散度($Div$)来描述两组概率分布的相似程度,所以我们期望训练出来的Generator可以表示为$G=argminDiv(P_G,P_{data})$。

4

要计算$Div(P_G,P_{data})$,必须要能够判断一个图像属不属于$P_{data}$,根据1.1中GAN的算法,我们训练了Discriminator来拟合$P_{data}(x)$的形状,只有落在该区域内的图像才能获得高分。当Discriminator训练的足够好的时候,我们认为它已经可以描述$P_{data}(x)$,下面是条件:

  • 当然,在GAN中我们不可能一步到位地训练Discriminator,因为我们没有足够的负样本,实作中这些负样本都要由Generator来提供。所以一般都是把二个网络一起迭代更新。

我们认为被最优化后的$V$可以最好的分辨图片是否属于Data,可以数学证明(略):

也就是说我们用交叉熵作为目标函数训练Discriminator,实际是教会了Discriminator计算输入图片和数据集分布之间的JS散度

如果让Generator的目标是最小化$maxV(G,D)$,那么训练Generator 的过程实际上就是减小Generator生成的数据分布和数据集分布之间的JS散度。

  • 本节讨论的意义在于:修改目标函数V,可以起到修改作为度量标准的散度的目的。如下表所示:

5

2. Wasertein GAN (WGAN)

2.1 传统GAN的问题

Ian Fellow提出的传统GAN模型,在实际训练时会出现很难收敛,以及存在Mode Collapse和Mode Dropping等问题。

  • Mode Collapse:多个噪声输入经过Generator产生的都是同一个图像)

mode_dropping

  • Mode Dropping:每次迭代的Generator都只能产生某一种特征,比如下图所示,一次只产生白皮肤、黄皮肤。日哦批评

mode_dropping

这很可能是因为传统GAN采用的JS散度不能很好描述两个分布之间的距离。

由数学推导知道,用传统交叉熵作为目标函数训练,Discriminator训练到极致的时候,它的函数D(x)其实就是评价输入的图像和我们现有的图像数据分布之间的JS散度。然而JS散度有一个巨大的问题,就是当两个数据没有overlap的时候,它的值始终是log2,这个值和数据分布之间的距离无关,仅在二者重叠时有一个突然的变化,这不利于我们模型的收敛。

从梯度下降的角度来分析交叉熵作为Discriminator的目标函数的缺点,可能更加直观。我们知道Discriminator训练到最好的情况下评价的是JS散度,而由于真实图片和生成的图片之间很可能没有overlap,这会导致D(x)变成一个方形波的形状,即对真实图片都是1,对虚假图片都是-1(因为传统GAN采用了sigmoid的激活函数激活函数),这会导致梯度消失,如下图左边子图所示。一个解决思路是简单地把Sigmoid函数去掉,如下图右边子图所示(这个就是LSGAN),这虽然能阻止梯度消失,但使用的Discriminator采用的仍然是趋近于JS散度的度量模式,效果也不是很好。

lsgan

2.2 WGAN

WGAN度量的是Earth Moving距离,即将一个分布搬到另一个分布需要搬运的数据量最小值作为度量两个数据分布之间距离的标准。Earth Moving距离是一个需要求解min的数学式,但经过推导后发现,可以由下图的式子来表示:

wgan

发现该式子完全继承了上一小节的LSGAN的目标函数,只是多亏了一个$D(x)\in 1-Lipschitzd$的约束。这个约束的意思就是D(x)函数在任意一点的导数不能大于1。

当时提出WGAN的人最初没有想到好方法满足该式,于是简单地对Discriminator做了一个weight clilpping,实验下来发现也能用。

后来,WGAN的作者再次发文,指出D满足$1-Lipschitzd$等价于D(x)相对于x(x是discriminator的输入,即图像)的梯度小于等于1,所以在目标函数的后面剪去了一项Gradient Penalty,这样目标函数变为:

这种改良后的WGAN因为增加了Gradient Penalty,也被称作WGAN-GP。它能够显著减少Mode Collapse和Mode Dropping。

WGAN是对传统GAN的全方面优化,下面是两个结构performance的对比:

  • GAN,iterations=10000

gan-10000

  • WGAN-GP,iterations=10000

10000

2.3 Coding

使用Tensorflow 2.0复现了WGAN-GP的效果,项目已经上传到github上了,召唤传送门

batch_size=512,训练了3万个iterations,训练过程如下图所示:

WGAN-demo


About Joyk


Aggregate valuable and interesting links.
Joyk means Joy of geeK