【干货】对抗自编码器PyTorch手把手实战系列——PyTorch实现对抗自编码器

即使是非计算机行业, 大家也知道很多有名的神经网络结构, 比如CNN在处理图像上非常厉害, RNN能够建模序列数据. 然而CNN, RNN之类的神经网络结构本身, 并不能用于执行比如图像的内容和风格分离, 生成一个逼真的图片, 用少量的label信息来分类图像, 或者做数据压缩等任务. 因为上述几个任务, 都需要特殊的网络结构和训练算法 .


有没有一个网络结构, 能够把上述任务全搞定呢? 显然是有的, 那就是对抗自编码器Adversarial Autoencoder(AAE) . 在本文中, 我们将构建一个AAE, 来压缩数据, 分离图像的内容和风格, 用少量样本来分类图像, 然后生成它们。

本系列文章, 专知小组成员Huaiwen一共分成四篇讲解,这是第二篇:


PyTorch实现对抗自编码器


1.对抗自编码器




常规的Autoencoder长这样:

上回, 我们说, 随意取一个隐变量, 比如 传给解码器, 结果解码器生成不了一个有意义的图像. 

这是因为Autoencoder, 其实是在做是哈希(个人理解). 比方说把菜和菜单上的编号对应起来, 那么我们去饭店点菜, 如果你常来的话, 你会告诉服务员, 我要1318号菜, 然后, 人家给你端上来一盘烤鸭. 某一天, 你突发奇想, 说我要00号菜, 人家没有, 只好拿菜单里有的, 尽量靠近00号的几个菜拼一拼, 那这菜能不能吃就两说了.


所以, 本篇文章, 我们想要, 强制让Encoder 的输出(即隐变量) , 服从某种分布, 比如正态分布. 那么, 经过大量的学习之后, Autoencoder会学到一些必要的知识(强行拟合). 这个时候只要随便给一个从这个分布里采样出来的隐变量, Decoder都会生成一个相对合理的图片. 

 

那么, 问题来了, 我们怎么强制让隐变量h 服从某种分布?


GAN这个时候就上场了. 我这里简单介绍一下GAN(对抗生成网络)

我们从数据库里, 抽出一些real images 作为样本集A, 用生成器Generator以一些随机数为输入, 生成一些样本, 作为样本集B, 把样本集A, B混起来, 交给判别器Discriminator, 让判别器判断, 哪个样本是Real(0), 哪个样本是Fake(0). 生成器致力于生成高仿样本(让判别器认为它生成的都是真的), 判别器致力于成为鉴宝专家, 不放过一点蛛丝马迹. 他们就这样生成和对抗着, 直到生成器生成的样本, 判别器都认为是真的为止. 训练结束后, 一般我们会用生成器来生成我们想要的样本.


把GAN的思想加到Autoencoder里来, 是为了让隐变量 尽量像是真的从某种分布出采样的.

对比一下GAN, 我们发现, 在AAE(Adversarial Autoencoder)里, Autoencoder承担的职能, 就是GAN中的生成器Generator. 其中隐变量, 我们这里称为 对应于GAN, 就是Fake image. 那么, Real Image 就是真的从正太分布(随你选什么分布)中采样出来的. 然后, 把它们混合, 送个判别器Discriminator去判断. 训练步骤分两步: 1. 训练判别器Discriminator, 然后反向传播, 2. 将生成器和判别器连起来, 然后反向传播, 注意判别器参数需要固定住。


简单解释一下上图的符号: 即输入样本,  是Encoder, 即给定输入样本产生的过程. 是隐变量, 它是从 里采样出来的. 是真是样本, 是从你设定的分布中采样出来的. 是Decoder, 即给定隐变量产生的过程, 而是Decoder重建出来的样本. 是判别器Discriminator.

 

那么, 剩下的事情就比较清晰了, 我们的目标是: 1. 要尽量相近, Autoencoder要发挥自己的能力. 2. GAN也要发挥自己的能力, 所以也要尽量相近.

光看Autoencoder. 需要最小化, 跟下面的结构没啥关系.

光看下半部分, 这里需要训练判别器Discriminator

 然后将生成器和判别器连起来, 更新生成器(这里是Encoder)


2.PyTorch实现




直接上代码:

import一些包

import torch
from torch import mean, log, rand
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable as V
import torchvision.datasets as dsets
import torchvision.transforms as transforms

我们计划让z服从高斯分布, 也即 


Encoder模型:

# q(z|x)
class Q_net(nn.Module):
def __init__(self,X_dim,N,z_dim):
super(Q_net, self).__init__()
self.lin1 = nn.Linear(X_dim, N)
self.lin2 = nn.Linear(N, N)
self.lin3_gauss = nn.Linear(N, z_dim)
def forward(self, x):
x = F.dropout(self.lin1(x), p=0.25, training=self.training)
x = F.relu(x)
x = F.dropout(self.lin2(x), p=0.25, training=self.training)
x = F.relu(x)
z_gauss = self.lin3_gauss(x)
return z_gauss


Decoder模型:

# p(x|z)
class P_net(nn.Module):
def __init__(self,X_dim,N,z_dim):
super(P_net, self).__init__()
self.lin1 = nn.Linear(z_dim, N)
self.lin2 = nn.Linear(N, N)
self.lin3 = nn.Linear(N, X_dim)
def forward(self, x):
x = F.dropout(self.lin1(x), p=0.25, training=self.training)
x = F.relu(x)
x = F.dropout(self.lin2(x), p=0.25, training=self.training)
x = self.lin3(x)
return F.sigmoid(x)


判别器Discriminator:

# D()
class D_net_gauss(nn.Module):
def __init__(self,N,z_dim):
super(D_net_gauss, self).__init__()
self.lin1 = nn.Linear(z_dim, N)
self.lin2 = nn.Linear(N, N)
self.lin3 = nn.Linear(N, 1)
def forward(self, x):
x = F.dropout(self.lin1(x), p=0.2, training=self.training)
x = F.relu(x)
x = F.dropout(self.lin2(x), p=0.2, training=self.training)
x = F.relu(x)
return F.sigmoid(self.lin3(x))


那么, 用Torchvision加载下MNIST的数据:

# MNIST Dataset 
dataset = dsets.MNIST(root='./data',
                     train=True,
                     transform=transforms.ToTensor(),  
                     download=True)

# Data Loader (Input Pipeline)
data_loader = torch.utils.data.DataLoader(dataset=dataset,
                                         batch_size=100,
                                         shuffle=True)


定义一些辅助函数, 让代码更简洁:

def to_np(x):
return x.data.cpu().numpy()

def to_var(x):
if torch.cuda.is_available():
x = x.cuda()
return V(x)


对模型进行配置:

EPS = 1e-15
# 学习率
gen_lr = 0.0001
reg_lr = 0.00005
# 隐变量的维度
z_red_dims = 120
# encoder
Q = Q_net(784,1000,z_red_dims).cuda()
# decoder
P = P_net(784,1000,z_red_dims).cuda()
# discriminator
D_gauss = D_net_gauss(500,z_red_dims).cuda()


#encode/decode 优化器
optim_P = torch.optim.Adam(P.parameters(), lr=gen_lr)
optim_Q_enc = torch.optim.Adam(Q.parameters(), lr=gen_lr)
# GAN部分优化器
optim_Q_gen = torch.optim.Adam(Q.parameters(), lr=reg_lr)
optim_D = torch.optim.Adam(D_gauss.parameters(), lr=reg_lr)


实现训练步骤:

# 数据迭代器
data_iter = iter(data_loader)
iter_per_epoch = len(data_loader)
total_step = 50000

for step in range(total_step):

if (step+1) % iter_per_epoch == 0:
data_iter = iter(data_loader)

# 从MNSIT数据集中拿样本
   images, labels = next(data_iter)
images, labels = to_var(images.view(images.size(0), -1)), to_var(labels)

# 把这三个模型的累积梯度清空
   P.zero_grad()
Q.zero_grad()
D_gauss.zero_grad()
################ Autoencoder部分 ######################
   # encoder 编码x, 生成z
   z_sample = Q(images)
# decoder 解码z, 生成x'
   X_sample = P(z_sample)
# 这里计算下autoencoder 的重建误差|x' - x|
   recon_loss = F.binary_cross_entropy(X_sample + EPS, images + EPS)

# 优化autoencoder
   recon_loss.backward()
optim_P.step()
optim_Q_enc.step()

################ GAN 部分 #############################

   # 从正太分布中, 采样real gauss(真-高斯分布样本点)
   z_real_gauss = V(randn(images.size()[0], z_red_dims) * 5.).cuda()
# 判别器判别一下真的样本, 得到loss
   D_real_gauss = D_gauss(z_real_gauss)

# 用encoder 生成假样本
   Q.eval() # 切到测试形态, 这时候, Q(即encoder)不参与优化
   z_fake_gauss = Q(images)
# 用判别器判别假样本, 得到loss
   D_fake_gauss = D_gauss(z_fake_gauss)

# 判别器总误差
   D_loss = -mean(log(D_real_gauss + EPS) + log(1 - D_fake_gauss + EPS))

# 优化判别器
   D_loss.backward()
optim_D.step()

# encoder充当生成器
   Q.train() # 切换训练形态, Q(即encoder)参与优化
   z_fake_gauss = Q(images)
D_fake_gauss = D_gauss(z_fake_gauss)

G_loss = -mean(log(D_fake_gauss + EPS))

G_loss.backward()
# 仅优化Q
   optim_Q_gen.step()

# 训练结束后, 存一下encoder的参数
torch.save(Q.state_dict(), 'Q_encoder_weights.pt')


那么, 让我们看一下效果:

上半部分是我们想要的分布, 下边是经过训练之后的分布, 你可以看到, 随着迭代次数的增加, 它们越来越像.

生成器的Loss在下降, 判别器越来越难以判断谁是真的谁是假的.

-END-

专 · 知

人工智能领域主题知识资料查看获取【专知荟萃】人工智能领域26个主题知识资料全集(入门/进阶/论文/综述/视频/专家等)

请PC登录www.zhuanzhi.ai或者点击阅读原文,注册登录专知,获取更多AI知识资料

请扫一扫如下二维码关注我们的公众号,获取人工智能的专业知识!

请加专知小助手微信(Rancho_Fang),加入专知主题人工智能群交流!加入专知主题群(请备注主题类型:AI、NLP、CV、 KG等)交流~

点击“阅读原文”,使用专知

展开全文
Top
微信扫码咨询专知VIP会员