【开发】 用GAN来做图像生成,这是最好的方法

2017 年 8 月 9 日 GAN生成式对抗网络

来源:AI科技评论

作者: 天雨粟,原文载于作者的知乎专栏——机器不学习


前言

在我们之前的文章中,我们学习了如何构造一个简单的 GAN 来生成 MNIST 手写图片。对于图像问题,卷积神经网络相比于简单地全连接的神经网络更具优势,因此,我们这一节我们将继续深入 GAN,通过融合卷积神经网络来对我们的 GAN 进行改进,实现一个深度卷积 GAN。如果还没有亲手实践过 GAN 的小伙伴可以先去学习一下上一篇专栏:生成对抗网络(GAN)之 MNIST 数据生成。

专栏中的所有代码都在我的 GitHub中,欢迎 star 与 fork。

本次代码在 NELSONZHAO/zhihu/dcgan,里面包含了两个文件:

  • dcgan_mnist:基于 MNIST 手写数据集构造深度卷积 GAN 模型

  • dcgan_cifar:基于 CIFAR 数据集构造深度卷积 GAN 模型

本文主要以 MNIST 为例进行介绍,两者在本质上没有差别,只在细微的参数上有所调整。由于穷学生资源有限,没有对模型增加迭代次数,也没有构造更深的模型。并且也没有选取像素很高的图像,高像素非常消耗计算量。本节只是一个抛砖引玉的作用,让大家了解 DCGAN 的结构,如果有资源的小伙伴可以自己去尝试其他更清晰的图片以及更深的结构,相信会取得很不错的结果。

工具

  • Python3

  • TensorFlow 1.0

  • Jupyter notebook

正文

整个正文部分将包括以下部分:

- 数据加载

- 模型输入

- Generator

- Discriminator

- Loss

- Optimizer

- 训练模型

- 可视化

数据加载

数据加载部分采用 TensorFlow 中的 input_data 接口来进行加载。关于加载细节在前面的文章中已经写了很多次啦,相信看过我文章的小伙伴对 MNIST 加载也非常熟悉,这里不再赘述。

模型输入

在 GAN 中,我们的输入包括两部分,一个是真实图片,它将直接输入给 discriminator 来获得一个判别结果;另一个是随机噪声,随机噪声将作为 generator 来生成图片的材料,generator 再将生成图片传递给 discriminator 获得一个判别结果。

上面的函数定义了输入图片与噪声图片两个 tensor。

Generator

生成器接收一个噪声信号,基于该信号生成一个图片输入给判别器。在上一篇专栏文章生成对抗网络(GAN)之 MNIST 数据生成中,我们的生成器是一个全连接层的神经网络,而本节我们将生成器改造为包含卷积结构的网络,使其更加适合处理图片输入。整个生成器结构如下:

我们采用了 transposed convolution 将我们的噪声图片转换为了一个与输入图片具有相同 shape 的生成图像。我们来看一下具体的实现代码:

上面的代码是整个生成器的实现细节,里面包含了一些 trick,我们来一步步地看一下。

首先我们通过一个全连接层将输入的噪声图像转换成了一个 1 x 4*4*512 的结构,再将其 reshape 成一个 [batch_size, 4, 4, 512] 的形状,至此我们其实完成了第一步的转换。接下来我们使用了一个对加速收敛及提高卷积神经网络性能中非常有效的方法——加入 BN(batch normalization),它的思想是归一化当前层输入,使它们的均值为 0 和方差为 1,类似于我们归一化网络输入的方法。它的好处在于可以加速收敛,并且加入 BN 的卷积神经网络受权重初始化影响非常小,具有非常好的稳定性,对于提升卷积性能有很好的效果。关于 batch normalization,我会在后面专栏中进行一个详细的介绍。

完成 BN 后,我们使用 Leaky ReLU 作为激活函数,在上一篇专栏中我们已经提过这个函数,这里不再赘述。最后加入 dropout 正则化。剩下的 transposed convolution 结构层与之类似,只不过在最后一层中,我们不采用 BN,直接采用 tanh 激活函数输出生成的图片。

在上面的 transposed convolution 中,很多小伙伴肯定会对每一层 size 的变化疑惑,在这里来讲一下在 TensorFlow 中如何来计算每一层 feature map 的 size。首先,在卷积神经网络中,假如我们使用一个 k x k 的 filter 对 m x m x d 的图片进行卷积操作,strides 为 s,在 TensorFlow 中,当我们设置 padding='same'时,卷积以后的每一个 feature map 的 height 和 width 为;当设置 padding='valid'时,每一个 feature map 的 height 和 width 为。那么反过来,如果我们想要进行 transposed convolution 操作,比如将 7 x 7 的形状变为 14 x 14,那么此时,我们可以设置 padding='same',strides=2 即可,与 filter 的 size 没有关系;而如果将 4 x 4 变为 7 x 7 的话,当设置 padding='valid'时,即,此时 s=1,k=4 即可实现我们的目标。

上面的代码中我也标注了每一步 shape 的变化。

Discriminator

Discriminator 接收一个图片,输出一个判别结果(概率)。其实 Discriminator 完全可以看做一个包含卷积神经网络的图片二分类器。结构如下:

实现代码如下:

上面代码其实就是一个简单的卷积神经网络图像识别问题,最终返回 logits(用来计算 loss)与 outputs。这里没有加入池化层的原因在于图片本身经过多层卷积以后已经非常小了,并且我们加入了 batch normalization 加速了训练,并不需要通过 max pooling 来进行特征提取加速训练。

Loss Function

Loss 部分分别计算 Generator 的 loss 与 Discriminator 的 loss,和之前一样,我们加入 label smoothing 防止过拟合,增强泛化能力。

Optimizer

GAN 中实际包含了两个神经网络,因此对于这两个神经网络要分开进行优化。代码如下:

这里的 Optimizer 和我们之前不同,由于我们使用了 TensorFlow 中的 batch normalization 函数,这个函数中有很多 trick 要注意。首先我们要知道,batch normalization 在训练阶段与非训练阶段的计算方式是有差别的,这也是为什么我们在使用 batch normalization 过程中需要指定 training 这个参数。上面使用 tf.control_dependencies 是为了保证在训练阶段能够一直更新 moving averages。具体参考 A Gentle Guide to Using Batch Normalization in Tensorflow - Rui Shu。

训练

到此为止,我们就完成了深度卷积 GAN 的构造,接着我们可以对我们的 GAN 来进行训练,并且定义一些辅助函数来可视化迭代的结果。代码太长就不放上来了,可以直接去我的 GitHub 下载。

我这里只设置了 5 轮 epochs,每隔 100 个 batch 打印一次结果,每一行代表同一个 epoch 下的 25 张图:

我们可以看出仅仅经过了少部分的迭代就已经生成非常清晰的手写数字,并且训练速度是非常快的。

上面的图是最后几次迭代的结果。我们可以回顾一下上一篇的一个简单的全连接层的 GAN,收敛速度明显不如深度卷积 GAN。

总结

到此为止,我们学习了一个深度卷积 GAN,并且看到相比于之前简单的 GAN 来说,深度卷积 GAN 的性能更加优秀。当然除了 MNST 数据集以外,小伙伴儿们还可以尝试很多其他图片,比如我们之前用到过的 CIFAR 数据集,我在这里也实现了一个 CIFAR 数据集的图片生成,我只选取了马的图片进行训练:

刚开始训练时:

训练 50 个 epochs:

这里我只设置了 50 次迭代,可以看到最后已经生成了非常明显的马的图像,可见深度卷积 GAN 的优势。

我的 GitHub:NELSONZHAO (Nelson Zhao)

上面包含了我的专栏中所有的代码实现,欢迎 star,欢迎 fork。


高质量延伸阅读

☞ 【学界】邢波团队提出contrast-GAN:实现生成式语义处理

☞  【专栏】阿里SIGIR 2017论文:GAN在信息检索领域的应用

☞ 【学界】康奈尔大学说对抗样本出门会失效,被OpenAI怼回来了!

☞ 警惕人工智能系统中的木马、病毒 ——深度学习对抗样本简介

☞ 【生成图像】Facebook发布的LR-GAN如何生成图像?这里有一篇Pytorch教程

☞ 【智能自动化学科前沿讲习班第1期】国立台湾大学(位于中国台北)李宏毅教授:Anime Face Generation

☞ 【变狗为猫】伯克利图像迁移cycleGAN,猫狗互换效果感人

☞ 【论文】对抗样本到底会不会对无人驾驶目标检测产生干扰?又有人发文质疑了

【智能自动化学科前沿讲习班第1期】王飞跃教授:生成式对抗网络GAN的研究进展与展望

【开发】看完立刻理解GAN!初学者也没关系

【专栏】基于对抗学习的生成式对话模型的坚实第一步 :始于直观思维的曲折探索

☞ 【重磅】平行将成为一种常态:从SimGAN获得CVPR 2017最佳论文奖说起

☞ 【最新】OpenAI:3段视频演示无人驾驶目标检测强大的对抗性样本!

☞  【干货】生成对抗网络(GAN)之MNIST数据生成

☞ 【论文】CVPR 2017最佳论文出炉,DenseNet和苹果首篇论文获奖

☞   AI侦探敲碎深度学习黑箱

☞ 【深度学习】解析深度学习的局限性与未来,谷歌Keras之父「连发两文」发人深省

☞   苹果重磅推出AI技术博客,CVPR合成逼真照片论文打响第一枪

☞ 【Ian Goodfellow 五问】GAN、深度学习,如何与谷歌竞争

☞ 【巨头升级寡头】AI产业数据称王,GAN和迁移学习能否突围BAT垄断?

☞ 【高大上的DL】BEGAN: Boundary Equilibrium GAN

☞ 【最详尽的GAN介绍】王飞跃等:生成式对抗网络 GAN 的研究进展与展望

☞ 【最全GAN变体列表】Ian Goodfellow推荐:GAN动物园

☞   二十世纪的十大科学骗局

☞ 【DCGAN】深度卷积生成对抗网络的无监督学习,补全人脸合成图像匹敌真实照片

【学界】让莫奈画作变成照片:伯克利图像到图像翻译新研究

☞ 【DualGAN】对偶学习的生成对抗网络

☞ 【开源】收敛速度更快更稳定的Wasserstein GAN(WGAN)

☞ 【Valse 2017】生成对抗网络(GAN)研究年度进展评述

☞ 【开源】谷歌新推BEGAN模型用于人脸数据集:效果惊人!

☞ 【深度】Ian Goodfellow AIWTB开发者大会演讲:对抗样本与差分隐私

☞   论文引介 | StackGAN: Stacked Generative Adversarial Networks

☞ 【专题GAN】GAN应用情况调研

☞ 【纵览】从自编码器到生成对抗网络:一文纵览无监督学习研究现状

☞ 【论文解析】Ian Goodfellow 生成对抗网络GAN论文解析

☞ 【VALSE 前沿】利用对抗学习改进目标检测的结果

☞ 【干货】全面分析GAN,以及如何用TF实现GAN?

☞   苹果首份AI论文横空出世,提出SimGAN训练方法

☞ 【推荐】条条大路通罗马LS-GAN:把GAN建立在Lipschitz密度上

☞   到底什么是生成式对抗网络GAN?

☞   看穿机器学习(W-GAN模型)的黑箱

☞   看穿机器学习的黑箱(II)

【Geometric GAN】引入线性分类器SVM的Geometric GAN

☞ 【征稿】“生成式对抗网络GAN技术与应用”专刊

☞ 【GAN for NLP】PaperWeekly 第二十四期 --- GAN for NLP

☞ 【学界 】 从感知机到GAN,机器学习简史梳理

☞ 【Demo】GAN学习指南:从原理入门到制作生成Demo

☞ 【学界】伯克利与OpenAI整合强化学习与GAN:让智能体学习自动发现目标

☞ 【解读】通过拳击学习生成对抗网络(GAN)的基本原理

☞ 【人物 】Ian Goodfellow亲述GAN简史:人工智能不能理解它无法创造的东西

☞ 【DCGAN】DCGAN: 一类稳定的GANs

☞ 【DCGAN】DCGAN:深度卷积生成对抗网络的无监督学习,补全人脸合成图像匹敌真实照片

☞ 【原理】 直观理解GAN背后的原理:以人脸图像生成为例

☞ 【干货】深入浅出 GAN·原理篇文字版(完整)

☞   带你理解CycleGAN,并用TensorFlow轻松实现

☞   PaperWeekly 第39期 | 从PM到GAN - LSTM之父Schmidhuber横跨22年的怨念

☞ 【CycleGAN】加州大学开源图像处理工具CycleGAN

☞ 【SIGIR2017满分论文】IRGAN:大一统信息检索模型的博弈竞争

☞ 【贝叶斯GAN】贝叶斯生成对抗网络(GAN):当下性能最好的端到端半监督/无监督学习

☞ 【贝叶斯GAN】贝叶斯生成对抗网络(GAN):当下性能最好的端到端半监督/无监督学习

☞ 【GAN X NLP】自然语言对抗生成:加拿大研究员使用GAN生成中国古诗词

☞    ICLR 2017 | GAN Missing Modes 和 GAN

☞ 【论文汇总】生成对抗网络及其变体

☞ 【AI】未来AI这样帮你一键修片,那还有PS什么事?

☞ 【学界】CMU新研究试图统一深度生成模型:搭建GAN和VAE之间的桥梁

☞ 【专栏】大漠孤烟,长河落日:面向景深结构的风景照生成技术

☞ 【开发】最简单易懂的 GAN 教程:从理论到实践(附代码)

☞ 【论文访谈】求同存异,共创双赢 - 基于对抗网络的利用不同分词标准语料的中文分词方法

☞ 【LeCun论战Yoav】自然语言GAN惹争议:深度学习远离NLP?

☞ 【争论】从Yoav Goldberg与Yann LeCun争论,看当今的深度学习、NLP与arXiv风气

☞ 【观点】Yoav Goldberg撰文再回应Yann LeCun:「深度学习这群人」不了解NLP(附各方评论)

☞   PaperWeekly 第41期 | 互怼的艺术:从零直达 WGAN-GP

☞ 【业界】CMU和谷歌联手研制左右互搏的对抗性机器人

☞ 【谷歌 GAN 生成人脸】对抗创造新艺术风格,128 像素扩展到 4000

☞ 【原理】模拟上帝之手的对抗博弈——GAN背后的数学原理

☞ 【原理】只知道GAN你就OUT了——VAE背后的哲学思想及数学原理

登录查看更多
点赞 0

Most of the recent successful methods in accurate object detection build on the convolutional neural networks (CNN). However, due to the lack of scale normalization in CNN-based detection methods, the activated channels in the feature space can be completely different according to a scale and this difference makes it hard for the classifier to learn samples. We propose a Scale Aware Network (SAN) that maps the convolutional features from the different scales onto a scale-invariant subspace to make CNN-based detection methods more robust to the scale variation, and also construct a unique learning method which considers purely the relationship between channels without the spatial information for the efficient learning of SAN. To show the validity of our method, we visualize how convolutional features change according to the scale through a channel activation matrix and experimentally show that SAN reduces the feature differences in the scale space. We evaluate our method on VOC PASCAL and MS COCO dataset. We demonstrate SAN by conducting several experiments on structures and parameters. The proposed SAN can be generally applied to many CNN-based detection methods to enhance the detection accuracy with a slight increase in the computing time.

点赞 0
阅读1+

Images are not simply sets of objects: each image represents a web of interconnected relationships. These relationships between entities carry semantic meaning and help a viewer differentiate between instances of an entity. For example, in an image of a soccer match, there may be multiple persons present, but each participates in different relationships: one is kicking the ball, and the other is guarding the goal. In this paper, we formulate the task of utilizing these "referring relationships" to disambiguate between entities of the same category. We introduce an iterative model that localizes the two entities in the referring relationship, conditioned on one another. We formulate the cyclic condition between the entities in a relationship by modelling predicates that connect the entities as shifts in attention from one entity to another. We demonstrate that our model can not only outperform existing approaches on three datasets --- CLEVR, VRD and Visual Genome --- but also that it produces visually meaningful predicate shifts, as an instance of interpretable neural networks. Finally, we show that by modelling predicates as attention shifts, we can even localize entities in the absence of their category, allowing our model to find completely unseen categories.

点赞 0
阅读1+

Convolutional neural networks (CNNs) are usually built by stacking convolutional operations layer-by-layer. Although CNN has shown strong capability to extract semantics from raw pixels, its capacity to capture spatial relationships of pixels across rows and columns of an image is not fully explored. These relationships are important to learn semantic objects with strong shape priors but weak appearance coherences, such as traffic lanes, which are often occluded or not even painted on the road surface as shown in Fig. 1 (a). In this paper, we propose Spatial CNN (SCNN), which generalizes traditional deep layer-by-layer convolutions to slice-byslice convolutions within feature maps, thus enabling message passings between pixels across rows and columns in a layer. Such SCNN is particular suitable for long continuous shape structure or large objects, with strong spatial relationship but less appearance clues, such as traffic lanes, poles, and wall. We apply SCNN on a newly released very challenging traffic lane detection dataset and Cityscapse dataset. The results show that SCNN could learn the spatial relationship for structure output and significantly improves the performance. We show that SCNN outperforms the recurrent neural network (RNN) based ReNet and MRF+CNN (MRFNet) in the lane detection dataset by 8.7% and 4.6% respectively. Moreover, our SCNN won the 1st place on the TuSimple Benchmark Lane Detection Challenge, with an accuracy of 96.53%.

点赞 0
阅读1+

We address the task of entity-relationship (E-R) retrieval, i.e, given a query characterizing types of two or more entities and relationships between them, retrieve the relevant tuples of related entities. Answering E-R queries requires gathering and joining evidence from multiple unstructured documents. In this work, we consider entity and relationships of any type, i.e, characterized by context terms instead of pre-defined types or relationships. We propose a novel IR-centric approach for E-R retrieval, that builds on the basic early fusion design pattern for object retrieval, to provide extensible entity-relationship representations, suitable for complex, multi-relationships queries. We performed experiments with Wikipedia articles as entity representations combined with relationships extracted from ClueWeb-09-B with FACC1 entity linking. We obtained promising results using 3 different query collections comprising 469 E-R queries.

点赞 0
阅读2+

This paper presents a framework for localization or grounding of phrases in images using a large collection of linguistic and visual cues. We model the appearance, size, and position of entity bounding boxes, adjectives that contain attribute information, and spatial relationships between pairs of entities connected by verbs or prepositions. Special attention is given to relationships between people and clothing or body part mentions, as they are useful for distinguishing individuals. We automatically learn weights for combining these cues and at test time, perform joint inference over all phrases in a caption. The resulting system produces state of the art performance on phrase localization on the Flickr30k Entities dataset and visual relationship detection on the Stanford VRD dataset.

点赞 0
阅读1+
小贴士
Top