官方资源帖!手把手教你在TensorFlow 2.0中实现CycleGAN,推特上百赞

百家 作者:量子位 2019-06-30 10:22:43
铜灵 发自 凹非寺
量子位 出品| 公众号 QbitAI

CycleGAN,一个可以将一张图像的特征迁移到另一张图像的酷算法,此前可以完成马变斑马、冬天变夏天、苹果变桔子等一颗赛艇的效果。

这行被顶会ICCV收录的研究自提出后,就为图形学等领域的技术人员所用,甚至还成为不少艺术家用来创作的工具。

也是目前大火的“换脸”技术的老前辈了。

如果你还没学会这项厉害的研究,那这次一定要抓紧上车了。

现在,TensorFlow开始手把手教你,在TensorFlow 2.0中CycleGAN实现大法。

这个官方教程贴几天内收获了满满人气,获得了Google AI工程师、哥伦比亚大学数据科学研究所Josh Gordon的推荐,推特上已近600赞。

有国外网友称赞太棒,表示很高兴看到TensorFlow 2.0教程中涵盖了最先进的模型。

这份教程全面详细,想学CycleGAN不能错过这个:

详细内容

在TensorFlow 2.0中实现CycleGAN,只要7个步骤就可以了。

1、设置输入Pipeline

安装tensorflow_examples包,用于导入生成器和鉴别器。

!pip?install?-q?git+https://github.com/tensorflow/examples.git

!pip?install?-q?tensorflow-gpu==2.0.0-beta1
import?tensorflow?as?tf


from?__future__?import?absolute_import,?division,?print_function,?unicode_literals

import?tensorflow_datasets?as?tfds
from?tensorflow_examples.models.pix2pix?import?pix2pix

import?os
import?time
import?matplotlib.pyplot?as?plt
from?IPython.display?import?clear_output

tfds.disable_progress_bar()
AUTOTUNE?=?tf.data.experimental.AUTOTUNE

2、输入pipeline

在这个教程中,我们主要学习马到斑马的图像转换,如果想寻找类似的数据集,可以前往:

https://www.tensorflow.org/datasets/datasets#cycle_gan

在CycleGAN论文中也提到,将随机抖动( Jitter )和镜像应用到训练集中,这是避免过度拟合的图像增强技术。

和在Pix2Pix中的操作类似,在随机抖动中吗,图像大小被调整成286×286,然后随机裁剪为256×256。

在随机镜像中吗,图像随机水平翻转,即从左到右进行翻转。

dataset,?metadata?=?tfds.load('cycle_gan/horse2zebra',
??????????????????????????????with_info=True,?as_supervised=True)

train_horses,?train_zebras?=?dataset['trainA'],?dataset['trainB']
test_horses,?test_zebras?=?dataset['testA'],?dataset['testB']


BUFFER_SIZE?=?1000
BATCH_SIZE?=?1
IMG_WIDTH?=?256
IMG_HEIGHT?=?256
def?random_crop(image):
??cropped_image?=?tf.image.random_crop(
??????image,?size=[IMG_HEIGHT,?IMG_WIDTH,?3])

??return?cropped_image


#?normalizing?the?images?to?[-1,?1]
def?normalize(image):
??image?=?tf.cast(image,?tf.float32)
??image?=?(image?/?127.5)?-?1
??return?image


def?random_jitter(image):
??#?resizing?to?286?x?286?x?3
??image?=?tf.image.resize(image,?[286,?286],
??????????????????????????method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)

??#?randomly?cropping?to?256?x?256?x?3
??image?=?random_crop(image)

??#?random?mirroring
??image?=?tf.image.random_flip_left_right(image)

??return?image


def?preprocess_image_train(image,?label):
??image?=?random_jitter(image)
??image?=?normalize(image)
??return?image


def?preprocess_image_test(image,?label):
??image?=?normalize(image)
??return?image


train_horses?=?train_horses.map(
????preprocess_image_train,?num_parallel_calls=AUTOTUNE).cache().shuffle(
????BUFFER_SIZE).batch(1)

train_zebras?=?train_zebras.map(
????preprocess_image_train,?num_parallel_calls=AUTOTUNE).cache().shuffle(
????BUFFER_SIZE).batch(1)

test_horses?=?test_horses.map(
????preprocess_image_test,?num_parallel_calls=AUTOTUNE).cache().shuffle(
????BUFFER_SIZE).batch(1)

test_zebras?=?test_zebras.map(
????preprocess_image_test,?num_parallel_calls=AUTOTUNE).cache().shuffle(
????BUFFER_SIZE).batch(1)


sample_horse?=?next(iter(train_horses))
sample_zebra?=?next(iter(train_zebras))


plt.subplot(121)
plt.title('Horse')
plt.imshow(sample_horse[0]?*?0.5?+?0.5)

plt.subplot(122)
plt.title('Horse?with?random?jitter')
plt.imshow(random_jitter(sample_horse[0])?*?0.5?+?0.5)


plt.subplot(121)
plt.title('Zebra')
plt.imshow(sample_zebra[0]?*?0.5?+?0.5)

plt.subplot(122)
plt.title('Zebra?with?random?jitter')
plt.imshow(random_jitter(sample_zebra[0])?*?0.5?+?0.5)

3、导入并重新使用Pix2Pix模型

通过安装tensorflow_examples包,从Pix2Pix中导入生成器和鉴别器。

这个教程中使用的模型体系结构与Pix2Pix中很类似,但也有一些差异,比如Cyclegan使用的是实例规范化而不是批量规范化,比如Cyclegan论文使用的是修改后的resnet生成器等。

我们训练两个生成器(G和F)和两个鉴别器(X和Y)。生成器G架构图像X转换为图像Y,生成器F将图像Y转换为图像X。

鉴别器D_X区分图像X和生成的图像X(F(Y)),辨别器D_Y区分图像Y和生成的图像Y(G(X))。

OUTPUT_CHANNELS?=?3

generator_g?=?pix2pix.unet_generator(OUTPUT_CHANNELS,?norm_type='instancenorm')
generator_f?=?pix2pix.unet_generator(OUTPUT_CHANNELS,?norm_type='instancenorm')

discriminator_x?=?pix2pix.discriminator(norm_type='instancenorm',?target=False)
discriminator_y?=?pix2pix.discriminator(norm_type='instancenorm',?target=False)


to_zebra?=?generator_g(sample_horse)
to_horse?=?generator_f(sample_zebra)
plt.figure(figsize=(8,?8))
contrast?=?8

plt.subplot(221)
plt.title('Horse')
plt.imshow(sample_horse[0]?*?0.5?+?0.5)

plt.subplot(222)
plt.title('To?Zebra')
plt.imshow(to_zebra[0]?*?0.5?*?contrast?+?0.5)

plt.subplot(223)
plt.title('Zebra')
plt.imshow(sample_zebra[0]?*?0.5?+?0.5)

plt.subplot(224)
plt.title('To?Horse')
plt.imshow(to_horse[0]?*?0.5?*?contrast?+?0.5)

plt.show()


plt.figure(figsize=(8,?8))

plt.subplot(121)
plt.title('Is?a?real?zebra?')
plt.imshow(discriminator_y(sample_zebra)[0,?...,?-1],?cmap='RdBu_r')

plt.subplot(122)
plt.title('Is?a?real?horse?')
plt.imshow(discriminator_x(sample_horse)[0,?...,?-1],?cmap='RdBu_r')

plt.show()

4、损失函数

在CycleGAN中,因为没有用于训练的成对数据,因此无法保证输入X和目标Y在训练期间是否有意义。因此,为了强制学习正确的映射,CycleGAN中提出了“循环一致性损失”(cycle consistency loss)。

鉴别器和生成器的损失与Pix2Pix中的类似。

LAMBDA?=?10


loss_obj?=?tf.keras.losses.BinaryCrossentropy(from_logits=True)


def?discriminator_loss(real,?generated):
??real_loss?=?loss_obj(tf.ones_like(real),?real)

??generated_loss?=?loss_obj(tf.zeros_like(generated),?generated)

??total_disc_loss?=?real_loss?+?generated_loss

??return?total_disc_loss?*?0.5


def?generator_loss(generated):
??return?loss_obj(tf.ones_like(generated),?generated)


循环一致性意味着结果接近原始输入。

例如将一个句子和英语翻译成法语,再将其从法语翻译成英语后,结果与原始英文句子相同。

在循环一致性损失中,图像X通过生成器传递C产生的图像Y^,生成的图像Y^通过生成器传递F产生的图像X^,然后计算平均绝对误差X和X^。

前向循环一致性损失为:

反向循环一致性损失为:

def?calc_cycle_loss(real_image,?cycled_image):
??loss1?=?tf.reduce_mean(tf.abs(real_image?-?cycled_image))

??return?LAMBDA?*?loss1

初始化所有生成器和鉴别器的的优化:

generator_g_optimizer?=?tf.keras.optimizers.Adam(2e-4,?beta_1=0.5)
generator_f_optimizer?=?tf.keras.optimizers.Adam(2e-4,?beta_1=0.5)

discriminator_x_optimizer?=?tf.keras.optimizers.Adam(2e-4,?beta_1=0.5)
discriminator_y_optimizer?=?tf.keras.optimizers.Adam(2e-4,?beta_1=0.5)

5、检查点

checkpoint_path?=?"./checkpoints/train"

ckpt?=?tf.train.Checkpoint(generator_g=generator_g,
???????????????????????????generator_f=generator_f,
???????????????????????????discriminator_x=discriminator_x,
???????????????????????????discriminator_y=discriminator_y,
???????????????????????????generator_g_optimizer=generator_g_optimizer,
???????????????????????????generator_f_optimizer=generator_f_optimizer,
???????????????????????????discriminator_x_optimizer=discriminator_x_optimizer,
???????????????????????????discriminator_y_optimizer=discriminator_y_optimizer)

ckpt_manager?=?tf.train.CheckpointManager(ckpt,?checkpoint_path,?max_to_keep=5)

#?if?a?checkpoint?exists,?restore?the?latest?checkpoint.
if?ckpt_manager.latest_checkpoint:
??ckpt.restore(ckpt_manager.latest_checkpoint)
??print?('Latest?checkpoint?restored!!')

6、训练

注意:为了使本教程的训练时间合理,本示例模型迭代次数较少(40次,论文中为200次),预测效果可能不如论文准确。
EPOCHS?=?40


def?generate_images(model,?test_input):
??prediction?=?model(test_input)

??plt.figure(figsize=(12,?12))

??display_list?=?[test_input[0],?prediction[0]]
??title?=?['Input?Image',?'Predicted?Image']

??for?i?in?range(2):
????plt.subplot(1,?2,?i+1)
????plt.title(title[i])
????#?getting?the?pixel?values?between?[0,?1]?to?plot?it.
????plt.imshow(display_list[i]?*?0.5?+?0.5)
????plt.axis('off')
??plt.show()


尽管训练起来很复杂,但基本的步骤只有四个,分别为:获取预测、计算损失、使用反向传播计算梯度、将梯度应用于优化程序。

@tf.function
def?train_step(real_x,?real_y):
??#?persistent?is?set?to?True?because?gen_tape?and?disc_tape?is?used?more?than
??#?once?to?calculate?the?gradients.
??with?tf.GradientTape(persistent=True)?as?gen_tape,?tf.GradientTape(
??????persistent=True)?as?disc_tape:

????fake_y?=?generator_g(real_x,?training=True)
????cycled_x?=?generator_f(fake_y,?training=True)

????fake_x?=?generator_f(real_y,?training=True)
????cycled_y?=?generator_g(fake_x,?training=True)

????disc_real_x?=?discriminator_x(real_x,?training=True)
????disc_real_y?=?discriminator_y(real_y,?training=True)

????disc_fake_x?=?discriminator_x(fake_x,?training=True)
????disc_fake_y?=?discriminator_y(fake_y,?training=True)

????#?calculate?the?loss
????gen_g_loss?=?generator_loss(disc_fake_y)
????gen_f_loss?=?generator_loss(disc_fake_x)

????#?Total?generator?loss?=?adversarial?loss?+?cycle?loss
????total_gen_g_loss?=?gen_g_loss?+?calc_cycle_loss(real_x,?cycled_x)
????total_gen_f_loss?=?gen_f_loss?+?calc_cycle_loss(real_y,?cycled_y)

????disc_x_loss?=?discriminator_loss(disc_real_x,?disc_fake_x)
????disc_y_loss?=?discriminator_loss(disc_real_y,?disc_fake_y)

??#?Calculate?the?gradients?for?generator?and?discriminator
??generator_g_gradients?=?gen_tape.gradient(total_gen_g_loss,?
????????????????????????????????????????????generator_g.trainable_variables)
??generator_f_gradients?=?gen_tape.gradient(total_gen_f_loss,?
????????????????????????????????????????????generator_f.trainable_variables)

??discriminator_x_gradients?=?disc_tape.gradient(
??????disc_x_loss,?discriminator_x.trainable_variables)
??discriminator_y_gradients?=?disc_tape.gradient(
??????disc_y_loss,?discriminator_y.trainable_variables)

??#?Apply?the?gradients?to?the?optimizer
??generator_g_optimizer.apply_gradients(zip(generator_g_gradients,?
?????????????????????????????????????????????generator_g.trainable_variables))

??generator_f_optimizer.apply_gradients(zip(generator_f_gradients,?
?????????????????????????????????????????????generator_f.trainable_variables))

??discriminator_x_optimizer.apply_gradients(
??????zip(discriminator_x_gradients,
??????discriminator_x.trainable_variables))

??discriminator_y_optimizer.apply_gradients(
??????zip(discriminator_y_gradients,
??????discriminator_y.trainable_variables))


for?epoch?in?range(EPOCHS):
??start?=?time.time()

??n?=?0
??for?image_x,?image_y?in?tf.data.Dataset.zip((train_horses,?train_zebras)):
????train_step(image_x,?image_y)
????if?n?%?10?==?0:
??????print?('.',?end='')
????n+=1

??clear_output(wait=True)
??#?Using?a?consistent?image?(sample_horse)?so?that?the?progress?of?the?model
??#?is?clearly?visible.
??generate_images(generator_g,?sample_horse)

??if?(epoch?+?1)?%?5?==?0:
????ckpt_save_path?=?ckpt_manager.save()
????print?('Saving?checkpoint?for?epoch?{}?at?{}'.format(epoch+1,
?????????????????????????????????????????????????????????ckpt_save_path))

??print?('Time?taken?for?epoch?{}?is?{}?secn'.format(epoch?+?1,
??????????????????????????????????????????????????????time.time()-start))


7、使用测试集生成图像

#?Run?the?trained?model?on?the?test?dataset
for?inp?in?test_horses.take(5):
??generate_images(generator_g,?inp)

8、进阶学习方向

在上面的教程中,我们学习了如何从Pix2Pix中实现的生成器和鉴别器进一步实现CycleGAN,接下来的学习你可以尝试使用TensorFlow中的其他数据集。

你还可以用更多次的迭代改善结果,或者实现论文中修改的ResNet生成器,进行知识点的进一步巩固。

传送门

https://www.tensorflow.org/beta/tutorials/generative/cyclegan

GitHub地址:
https://github.com/tensorflow/docs/blob/master/site/en/r2/tutorials/generative/cyclegan.ipynb

作者系网易新闻·网易号“各有态度”签约作者

AI社群 | 与优秀的人交流

小程序 | 全类别AI学习教程

量子位?QbitAI · 头条号签约作者

?'?' ? 追踪AI技术和产品新动态

喜欢就点「好看」吧 !?

关注公众号:拾黑(shiheibook)了解更多

[广告]赞助链接:

四季很好,只要有你,文娱排行榜:https://www.yaopaiming.com/
让资讯触达的更精准有趣:https://www.0xu.cn/

公众号 关注网络尖刀微信公众号
随时掌握互联网精彩
赞助链接