50

官方资源帖!手把手教你在TensorFlow 2.0中实现CycleGAN,推特上百赞

 5 years ago
source link: https://www.tuicool.com/articles/NvqIVrj
Go to the source link to view the article. You can view the picture content, updated content and better typesetting reading experience. If the link is broken, please click the button below to view the snapshot at that time.
neoserver,ios ssh client

Last updated on 2019年7月1日

CycleGAN,一个可以将一张图像的特征迁移到另一张图像的酷算法,此前可以完成马变斑马、冬天变夏天、苹果变桔子等一颗赛艇的效果。

6nAfAvU.jpg!web

这行被顶会ICCV收录的研究自提出后,就为图形学等领域的技术人员所用,甚至还成为不少艺术家用来创作的工具。

FzMzQjm.jpg!web

也是目前大火的“换脸”技术的老前辈了。

b6b22mY.jpg!web

如果你还没学会这项厉害的研究,那这次一定要抓紧上车了。

现在,TensorFlow开始手把手教你,在TensorFlow 2.0中CycleGAN实现大法。

这个 官方教程贴 几天内收获了满满人气,获得了Google AI工程师、哥伦比亚大学数据科学研究所Josh Gordon的推荐,推特上已近600赞。

aqquYf2.jpg!web

有国外网友称赞太棒,表示很高兴看到TensorFlow 2.0教程中涵盖了最先进的模型。

这份教程全面详细,想学CycleGAN不能错过这个:

详细内容

在TensorFlow 2.0中实现CycleGAN,只要7个步骤就可以了。

1、设置输入Pipeline

安装tensorflow_examples包,用于导入生成器和鉴别器。

!pip install -q git+https: //github.com/tensorflow/examples.git

!pip install -q tensorflow-gpu==2.0.0-beta1
import tensorflow as tf
from __future__ import absolute_import, division, print_function, unicode_literals
import tensorflow_datasets as tfds
from tensorflow_examples.models.pix2pix import pix2pix
import os
import time
import matplotlib.pyplot as plt
from IPython.display import clear_output
tfds.disable_progress_bar()
AUTOTUNE = tf.data.experimental.AUTOTUNE

2、输入pipeline

在这个教程中,我们主要学习马到斑马的图像转换,如果想寻找类似的数据集,可以前往:

https://www.tensorflow.org/datasets/datasets#cycle_gan

在CycleGAN论文中也提到,将随机抖动( Jitter )和镜像应用到训练集中,这是避免过度拟合的图像增强技术。

和在Pix2Pix中的操作类似,在随机抖动中吗,图像大小被调整成286×286,然后随机裁剪为256×256。

在随机镜像中吗,图像随机水平翻转,即从左到右进行翻转。

ZviUNfQ.jpg!web

BUFFER_SIZE = 1000
BATCH_SIZE = 1
IMG_WIDTH = 256
IMG_HEIGHT = 256
def random_crop(image):
  cropped_image = tf.image.random_crop(
      image, size=[IMG_HEIGHT, IMG_WIDTH, 3])
  return cropped_image
# normalizing the images to [-1, 1]
def normalize(image):
  image = tf.cast(image, tf.float32)
  image = (image / 127.5) - 1
  return image
def random_jitter(image):
  # resizing to 286 x 286 x 3
  image = tf.image.resize(image, [286, 286],
                          method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)
  # randomly cropping to 256 x 256 x 3
  image = random_crop(image)
  # random mirroring
  image = tf.image.random_flip_left_right(image)
  return image
def preprocess_image_train(image, label):
  image = random_jitter(image)
  image = normalize(image)
  return image
def preprocess_image_test(image, label):
  image = normalize(image)
  return image
train_horses = train_horses.map(
    preprocess_image_train, num_parallel_calls=AUTOTUNE).cache().shuffle(
    BUFFER_SIZE).batch(1)
train_zebras = train_zebras.map(
    preprocess_image_train, num_parallel_calls=AUTOTUNE).cache().shuffle(
    BUFFER_SIZE).batch(1)
test_horses = test_horses.map(
    preprocess_image_test, num_parallel_calls=AUTOTUNE).cache().shuffle(
    BUFFER_SIZE).batch(1)
test_zebras = test_zebras.map(
    preprocess_image_test, num_parallel_calls=AUTOTUNE).cache().shuffle(
    BUFFER_SIZE).batch(1)
sample_horse = next(iter(train_horses))
sample_zebra = next(iter(train_zebras))
plt.subplot(121)
plt.title('Horse')
plt.imshow(sample_horse[0] * 0.5 + 0.5)
plt.subplot(122)
plt.title('Horse with random jitter')
plt.imshow(random_jitter(sample_horse[0]) * 0.5 + 0.5)

fqMvYfI.jpg!web

6B7JBjZ.jpg!web

3、导入并重新使用Pix2Pix模型

通过安装tensorflow_examples包,从Pix2Pix中导入生成器和鉴别器。

这个教程中使用的模型体系结构与Pix2Pix中很类似,但也有一些差异,比如Cyclegan使用的是实例规范化而不是批量规范化,比如Cyclegan论文使用的是修改后的resnet生成器等。

我们训练两个生成器(G和F)和两个鉴别器(X和Y)。生成器G架构图像X转换为图像Y,生成器F将图像Y转换为图像X。

鉴别器D_X区分图像X和生成的图像X(F(Y)),辨别器D_Y区分图像Y和生成的图像Y(G(X))。

Iby6V3B.jpg!web

OUTPUT_CHANNELS = 3
generator_g = pix2pix.unet_generator(OUTPUT_CHANNELS, norm_type='instancenorm')
generator_f = pix2pix.unet_generator(OUTPUT_CHANNELS, norm_type='instancenorm')
discriminator_x = pix2pix.discriminator(norm_type='instancenorm', target=False)
discriminator_y = pix2pix.discriminator(norm_type='instancenorm', target=False)
to_zebra = generator_g(sample_horse)
to_horse = generator_f(sample_zebra)
plt.figure(figsize=(8, 8))
contrast = 8
plt.subplot(221)
plt.title('Horse')
plt.imshow(sample_horse[0] * 0.5 + 0.5)
plt.subplot(222)
plt.title('To Zebra')
plt.imshow(to_zebra[0] * 0.5 * contrast + 0.5)
plt.subplot(223)
plt.title('Zebra')
plt.imshow(sample_zebra[0] * 0.5 + 0.5)
plt.subplot(224)
plt.title('To Horse')
plt.imshow(to_horse[0] * 0.5 * contrast + 0.5)
plt.show()

ayMveyv.jpg!web

plt.figure(figsize=(8, 8))
plt.subplot(121)
plt.title('Is a real zebra?')
plt.imshow(discriminator_y(sample_zebra)[0, ..., -1], cmap='RdBu_r')
plt.subplot(122)
plt.title('Is a real horse?')
plt.imshow(discriminator_x(sample_horse)[0, ..., -1], cmap='RdBu_r')
plt.show()

fMbM3e6.jpg!web

4、损失函数

在CycleGAN中,因为没有用于训练的成对数据,因此无法保证输入X和目标Y在训练期间是否有意义。因此,为了强制学习正确的映射,CycleGAN中提出了“循环一致性损失”(cycle consistency loss)。

鉴别器和生成器的损失与Pix2Pix中的类似。

LAMBDA = 10
loss_obj = tf.keras.losses.BinaryCrossentropy(from_logits=True)
def discriminator_loss(real, generated):
  real_loss = loss_obj(tf.ones_like(real), real)
  generated_loss = loss_obj(tf.zeros_like(generated), generated)
  total_disc_loss = real_loss + generated_loss
  return total_disc_loss * 0.5
def generator_loss(generated):
  return loss_obj(tf.ones_like(generated), generated)

循环一致性意味着结果接近原始输入。

例如将一个句子和英语翻译成法语,再将其从法语翻译成英语后,结果与原始英文句子相同。

在循环一致性损失中,图像X通过生成器传递C产生的图像Y^,生成的图像Y^通过生成器传递F产生的图像X^,然后计算平均绝对误差X和X^。

前向循环一致性损失为:

IVN7jyj.png!web

反向循环一致性损失为:

mqyyIvJ.jpg!web

def calc_cycle_loss(real_image, cycled_image):
  loss1 = tf.reduce_mean(tf.abs(real_image - cycled_image))
  return LAMBDA * loss1

初始化所有生成器和鉴别器的的优化:

generator_g_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)
generator_f_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)
discriminator_x_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)
discriminator_y_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)

5、检查点

checkpoint_path = "./checkpoints/train"
ckpt = tf.train.Checkpoint(generator_g=generator_g,
                           generator_f=generator_f,
                           discriminator_x=discriminator_x,
                           discriminator_y=discriminator_y,
                           generator_g_optimizer=generator_g_optimizer,
                           generator_f_optimizer=generator_f_optimizer,
                           discriminator_x_optimizer=discriminator_x_optimizer,
                           discriminator_y_optimizer=discriminator_y_optimizer)
ckpt_manager = tf.train.CheckpointManager(ckpt, checkpoint_path, max_to_keep=5)
# if a checkpoint exists, restore the latest checkpoint.
if ckpt_manager.latest_checkpoint:
  ckpt.restore(ckpt_manager.latest_checkpoint)
  print ('Latest checkpoint restored!!')

6、训练

注意:为了使本教程的训练时间合理,本示例模型迭代次数较少(40次,论文中为200次),预测效果可能不如论文准确。

EPOCHS = 40
def generate_images(model, test_input):
  prediction = model(test_input)
  plt.figure(figsize=(12, 12))
  display_list = [test_input[0], prediction[0]]
  title = ['Input Image', 'Predicted Image']
  for i in range(2):
    plt.subplot(1, 2, i+1)
    plt.title(title[i])
    # getting the pixel values between [0, 1] to plot it.
    plt.imshow(display_list[i] * 0.5 + 0.5)
    plt.axis('off')
  plt.show()

尽管训练起来很复杂,但基本的步骤只有四个,分别为:获取预测、计算损失、使用反向传播计算梯度、将梯度应用于优化程序。

@tf.function
def train_step(real_x, real_y):
  # persistent is set to True because gen_tape and disc_tape is used more than
  # once to calculate the gradients.
  with tf.GradientTape(persistent=True) as gen_tape, tf.GradientTape(
      persistent=True) as disc_tape:
    fake_y = generator_g(real_x, training=True)
    cycled_x = generator_f(fake_y, training=True)
    fake_x = generator_f(real_y, training=True)
    cycled_y = generator_g(fake_x, training=True)
    disc_real_x = discriminator_x(real_x, training=True)
    disc_real_y = discriminator_y(real_y, training=True)
    disc_fake_x = discriminator_x(fake_x, training=True)
    disc_fake_y = discriminator_y(fake_y, training=True)
    # calculate the loss
    gen_g_loss = generator_loss(disc_fake_y)
    gen_f_loss = generator_loss(disc_fake_x)
    # Total generator loss = adversarial loss + cycle loss
    total_gen_g_loss = gen_g_loss + calc_cycle_loss(real_x, cycled_x)
    total_gen_f_loss = gen_f_loss + calc_cycle_loss(real_y, cycled_y)
    disc_x_loss = discriminator_loss(disc_real_x, disc_fake_x)
    disc_y_loss = discriminator_loss(disc_real_y, disc_fake_y)
  # Calculate the gradients for generator and discriminator
  generator_g_gradients = gen_tape.gradient(total_gen_g_loss, 
                                            generator_g.trainable_variables)
  generator_f_gradients = gen_tape.gradient(total_gen_f_loss, 
                                            generator_f.trainable_variables)
  discriminator_x_gradients = disc_tape.gradient(
      disc_x_loss, discriminator_x.trainable_variables)
  discriminator_y_gradients = disc_tape.gradient(
      disc_y_loss, discriminator_y.trainable_variables)
  # Apply the gradients to the optimizer
  generator_g_optimizer.apply_gradients(zip(generator_g_gradients, 
                                             generator_g.trainable_variables))
  generator_f_optimizer.apply_gradients(zip(generator_f_gradients, 
                                             generator_f.trainable_variables))
  discriminator_x_optimizer.apply_gradients(
      zip(discriminator_x_gradients,
      discriminator_x.trainable_variables))
  discriminator_y_optimizer.apply_gradients(
      zip(discriminator_y_gradients,
      discriminator_y.trainable_variables))
for epoch in range(EPOCHS):
  start = time.time()
  n = 0
  for image_x, image_y in tf.data.Dataset.zip((train_horses, train_zebras)):
    train_step(image_x, image_y)
    if n % 10 == 0:
      print ('.', end='')
    n+=1
  clear_output(wait=True)
  # Using a consistent image (sample_horse) so that the progress of the model
  # is clearly visible.
  generate_images(generator_g, sample_horse)
  if (epoch + 1) % 5 == 0:
    ckpt_save_path = ckpt_manager.save()
    print ('Saving checkpoint for epoch {} at {}'.format(epoch+1,
                                                         ckpt_save_path))
  print ('Time taken for epoch {} is {} sec\n'.format(epoch + 1,
                                                      time.time()-start))

rANZV3j.jpg!web

NNbmArb.jpg!web

7、使用测试集生成图像

# Run the trained model on the test dataset
for inp in test_horses.take(5):
  generate_images(generator_g, inp)

73ARjuI.jpg!web

EjM3uee.jpg!web

r6n2UnQ.jpg!web

8、进阶学习方向

在上面的教程中,我们学习了如何从Pix2Pix中实现的生成器和鉴别器进一步实现CycleGAN,接下来的学习你可以尝试使用TensorFlow中的其他数据集。

你还可以用更多次的迭代改善结果,或者实现论文中修改的ResNet生成器,进行知识点的进一步巩固。

传送门

https://www.tensorflow.org/beta/tutorials/generative/cyclegan

GitHub地址:

https://github.com/tensorflow/docs/blob/master/site/en/r2/tutorials/generative/cyclegan.ipynb


About Joyk


Aggregate valuable and interesting links.
Joyk means Joy of geeK