50行PyTorch代码实现生成对抗网络(GANs)

【导读】这是一份非常简单的PyTorch实现GAN教程和代码。文中另附有TensorFlow实现版本。


作者 | Dev Nag

编译 | Xiaowen



2014年,蒙特利尔大学的Ian Goodfellow和他的同事发表了一篇令人惊叹的论文,向世界介绍了生成对抗网络GAN。通过计算图和博弈论的创新结合,他们表明,如果建模能力足够强,两个相互对抗的模型将能够通过普通的反向传播进行协同训练。


模型扮演两个截然不同的角色(也就是对抗)。给定一些真实的数据集R,G是生成器(Generator),试图创建看起来像真实数据的假数据,而D是判别器(Discriminator),从真实的集合或G中获取数据并标记差异。Goodfellow的比喻是,G就像是一个伪造者试图把真实的画与他们的输出相匹配,而D则是侦探的团队,试图分辨出不同之处。(除了在这种情况下,伪造者永远无法看到原始数据,只有D的判断——他们就像盲伪造者。)



在理想情况下,随着时间的推移,D和G都会变得更好,直到G本质上成为真正物品的“主伪造者”,而D则不知所措,“无法区分这两种分布”。


在实践中, Goodfellow已经证明,G能够在原始数据集上执行一种形式的无监督学习,找到某种方式以(可能)更低维的方式来表示这些数据。正如Yann LeCun所指出的,无监督学习是真正的AI的“蛋糕”。



这种强大的技术似乎需要一吨的代码才可以开始,对吧?不。使用PyTorch,我们实际上可以在50行代码中创建一个非常简单的GAN。实际上只有5个因素需要考虑:

  • R:原始的真实数据集

  • I:进入生成器的随机噪声

  • G:试图复制/模仿原始数据集的生成器

  • D:试图区分G的输出与真实的R的判别器

  • Loop:实际的“训练”循环,我们教G来欺骗D,D来小心G。


(1)R:在我们的例子中,我们将从最简单的R (钟形曲线)开始。该函数采用均值和标准差,并返回一个函数,该函数提供了具有这些参数的高斯样本数据的正确形状。在我们的样本代码中,我们将使用平均值为4.0,标准差为1.25的数据。



(2)I:对生成器的输入也是随机的,但是为了使我们的工作有些难度,让我们使用统一的分布而不是普遍的分布。这意味着我们的模型G不能简单地移动/缩放输入来复制R,而是必须以非线性的方式重塑数据。


 


(3)G:生成器是标准前馈图——两个隐藏层,三个线性映射。我们使用ELU(指数线性单元)。G将从 I 获得均匀分布的数据样本并且以某种方式模拟来自R的正态分布的样本。



(4)D:判别码与G的生成码非常相似;一个包含两个隐藏层和三个线性映射的前馈图。它将从R或G中获取样本,并输出0到1之间的单个标量,解释为‘假’与‘真’。这是神经网络所能得到的最大限度的误差。



(5)最后,训练循环在两种模式之间交替进行:第一次用准确的标签训练D关于真实数据vs假数据;然后用不准确的标签来训练G以愚弄D。

 


即使你以前没见过PyTorch,你也可能知道上图代码的结构。在第一个(绿色)部分,我们把两种类型的数据都给D,并对D的猜测和实际的标签应用一个可微的标准。然后我们显式地调用‘back()’来计算梯度,用于更新d_optimizer.step()中的参数。G是有使用的,但是这里没有训练。


然后,在最后一节(红色)中,我们对G 做了同样的操作,注意,我们也在D中运行G的输出(我们实际上是给伪造者一个测试来练习),但是我们没有在这一步优化或更改D。我们不希望侦探D学习错误的标签。因此,我们只调用g_optimizer.step()。


仅此而已。还有其他一些示例代码,但GAN特有的东西只是这5个组件,没有别的。



在D和G之间进行了几千轮的训练之后,我们得到了什么呢?判别器D很快就好了(G在缓慢地上升),但是一旦它达到了一定的能力水平,G就有了一个值得尊敬的对手,并开始快速改进和提高。


超过20,000次训练回合,G的输出平均值超过4.0,然后回到一个相当稳定的正确范围(左)。同样,标准偏差最初是向错误的方向下降,然后上升到预期的1.25左右(右),匹配R。


 


让我们来展示G生成的最终分布。

 


还不错诶。左边的尾巴比右边长一点,但是偏态和峰态看起来应该是高斯分布了。


G几乎完全拟合了原始的数据分布R,而D正在角落里瑟瑟发抖,无法区分G和R。这正是我们想要的。


本文代码在这儿[1]。


最后,提供给大家一些参考资料。Goodfellow的其他GAN工作[2],包括这里适用的小型批处理识别方法。另外还有NIPS2016上一个两小时的演讲教程[3]。对于TensorFlow的用户来说,这里也有一份教程[4]。


参考链接:

1. https://github.com/devnag/pytorch-generative-adversarial-networks

2. https://arxiv.org/pdf/1606.03498.pdf

3. https://channel9.msdn.com/Events/Neural-Information-Processing-Systems-Conference/Neural-Information-Processing-Systems-Conference-NIPS-2016/Generative-Adversarial-Networks

4. http://blog.aylien.com/introduction-generative-adversarial-networks-code-tensorflow/


原文链接:

https://medium.com/@devnag/generative-adversarial-networks-gans-in-50-lines-of-code-pytorch-e81b79659e3f


-END-

专 · 知


人工智能领域26个主题知识资料全集获取加入专知人工智能服务群: 欢迎微信扫一扫加入专知人工智能知识星球群,获取专业知识教程视频资料和与专家交流咨询!




请PC登录www.zhuanzhi.ai或者点击阅读原文,注册登录专知,获取更多AI知识资料!




请加专知小助手微信(扫一扫如下二维码添加),加入专知主题群(请备注主题类型:AI、NLP、CV、 KG等)交流~




请关注专知公众号,获取人工智能的专业知识!

点击“阅读原文”,使用

展开全文
Top
微信扫码咨询专知VIP会员