人脸变漫画脸!AI 教你轻松 Pick 胡歌漫画头像

2020 年 5 月 9 日 CSDN

作者 | 李秋键
责编 | Carol
出品 | AI科技大本营(ID:rgznai100)

近几天一个GitHub项目火遍了朋友圈,那就是卡通头像AI生成小程序。如下图所见:

而这个项目的基本原理是用Python搭建的GAN算法模型,进行训练得出。

而所谓的GAN就是指生成对抗网络深度学习模型。网络中有生成器G(generator)和鉴别器(Discriminator)。有两个数据域分别为X,Y。G 负责把X域中的数据拿过来拼命地模仿成真实数据并把它们藏在真实数据中,而 D 就拼命地要把伪造数据和真实数据分开。经过二者的博弈以后,G 的伪造技术越来越厉害,D 的鉴别技术也越来越厉害。直到 D 再也分不出数据是真实的还是 G 生成的数据的时候,这个对抗的过程达到一个动态的平衡。

而CycleGAN本质上是两个镜像对称的GAN,构成了一个环形网络。

两个GAN共享两个生成器,并各自带一个判别器,即共有两个判别器和两个生成器。一个单向GAN两个loss,两个即共四个loss。 

可以实现无配对的两个图片集的训练是CycleGAN与Pixel2Pixel相比的一个典型优点。但是我们仍然需要通过训练创建这个映射来确保输入图像和生成图像间存在有意义的关联,即输入输出共享一些特征。

简而言之,该模型通过从域DA获取输入图像,该输入图像被传递到第一个生成器GeneratorA→B,其任务是将来自域DA的给定图像转换到目标域DB中的图像。然后这个新生成的图像被传递到另一个生成器GeneratorB→A,其任务是在原始域DA转换回图像,这里可与自动编码器作对比。这个输出图像必须与原始输入图像相似,用来定义非配对数据集中原来不存在的有意义映射。

在本次的项目中就是利用了CycleGAN进行搭建模型。模型训练数据集如下:

 

实验前的准备


首先我们使用的python版本是3.6.5所用到的库有pytorch和TensorFlow,用来训练和加载神经网络常见的框架;face-alignment用来是用来提取人脸特征的常用库;

dlib是一个机器学习的开源库,包含了机器学习的很多算法,使用起来很方便,直接包含头文件即可,并且不依赖于其他库(自带图像编解码库源码)。Dlib可以帮助您创建很多复杂的机器学习方面的软件来帮助解决实际问题。目前Dlib已经被广泛的用在行业和学术领域,包括机器人,嵌入式设备,移动电话和大型高性能计算环境。


模型的训练


1、数据集处理和准备:

训练数据包括真实照片和卡通画像,为降低训练复杂度,我们对两类数据进行了如下预处理:

· 检测人脸及关键点。

· 根据关键点旋转校正人脸。

· 将关键点边界框按固定的比例扩张并裁剪出人脸区域。

· 使用人像分割模型将背景置白。

为了形成匹配效果,需要准备一些卡通人物图片和真实的人脸图片进行训练

2、模型的训练:

模型的训练使用python train.py --dataset photo2cartoon进行训练即可。

3、神经网络结构搭建:

整个算法的搭建正如上面可见,需要有生成器和判别器。使用论文提出的一种Soft-AdaLIN(Soft Adaptive Layer-Instance Normalization)归一化方法,在反规范化时将编码器的均值方差(照片特征)与解码器的均值方差(卡通特征)相融合。

模型结构方面,在U-GAT-IT的基础上,在编码器之前和解码器之后各增加了2个hourglass模块,渐进地提升模型特征抽象和重建能力。

部分代码如下:

   
   
     
class ResnetGenerator(nn.Module):
    def __init__( self, ngf= 64, img_size= 256, light=False):
         super(ResnetGenerator,  self).__init__()
         self.light = light
         self.ConvBlock1 = nn.Sequential(nn.ReflectionPad2d( 3),
                                       nn.Conv2d( 3, ngf, kernel_size= 7, stride= 1, padding= 0, bias=False),
                                    nn.InstanceNorm2d(ngf),
                                       nn.ReLU(True))
         self.HourGlass1 = HourGlass(ngf, ngf)
         self.HourGlass2 = HourGlass(ngf, ngf)
         # Down-Sampling
         self.DownBlock1 = nn.Sequential(nn.ReflectionPad2d( 1),
                                        nn.Conv2d(ngf, ngf* 2, kernel_size= 3, stride= 2, padding= 0, bias=False),
                                        nn.InstanceNorm2d(ngf *  2),
                                        nn.ReLU(True))
         self.DownBlock2 = nn.Sequential(nn.ReflectionPad2d( 1),
                                        nn.Conv2d(ngf* 2, ngf* 4, kernel_size= 3, stride= 2, padding= 0, bias=False),
                                        nn.InstanceNorm2d(ngf* 4),
                                        nn.ReLU(True))
         # Encoder Bottleneck
         self.EncodeBlock1 = ResnetBlock(ngf* 4)
         self.EncodeBlock2 = ResnetBlock(ngf* 4)
         self.EncodeBlock3 = ResnetBlock(ngf* 4)
         self.EncodeBlock4 = ResnetBlock(ngf* 4)
         # Class Activation Map
         self.gap_fc = nn.Linear(ngf* 41)
         self.gmp_fc = nn.Linear(ngf* 41)
         self.conv1x1 = nn.Conv2d(ngf* 8, ngf* 4, kernel_size= 1, stride= 1)
         self.relu = nn.ReLU(True)
         # Gamma, Beta block
         if  self.light:
             self.FC = nn.Sequential(nn.Linear(ngf* 4, ngf* 4),
                                    nn.ReLU(True),
                                    nn.Linear(ngf* 4, ngf* 4),
                                    nn.ReLU(True))
         else:
             self.FC = nn.Sequential(nn.Linear(img_size //4*img_size//4*ngf*4, ngf*4),
                                    nn.ReLU(True),
                                    nn.Linear(ngf* 4, ngf* 4),
                                    nn.ReLU(True))
         # Decoder Bottleneck
         self.DecodeBlock1 = ResnetSoftAdaLINBlock(ngf* 4)
         self.DecodeBlock2 = ResnetSoftAdaLINBlock(ngf* 4)
         self.DecodeBlock3 = ResnetSoftAdaLINBlock(ngf* 4)
         self.DecodeBlock4 = ResnetSoftAdaLINBlock(ngf* 4)
         # Up-Sampling
         self.UpBlock1 = nn.Sequential(nn.Upsample(scale_factor= 2),
                                      nn.ReflectionPad2d( 1),
                                      nn.Conv2d(ngf* 4, ngf* 2, kernel_size= 3, stride= 1, padding= 0, bias=False),
                                      LIN(ngf* 2),
                                      nn.ReLU(True))
         self.UpBlock2 = nn.Sequential(nn.Upsample(scale_factor= 2),
                                      nn.ReflectionPad2d( 1),
                                      nn.Conv2d(ngf* 2, ngf, kernel_size= 3, stride= 1, padding= 0, bias=False),
                                      LIN(ngf),
                                      nn.ReLU(True))
         self.HourGlass3 = HourGlass(ngf, ngf)
         self.HourGlass4 = HourGlass(ngf, ngf, False)
         self.ConvBlock2 = nn.Sequential(nn.ReflectionPad2d( 3),
                                        nn.Conv2d( 33, kernel_size= 7, stride= 1, padding= 0, bias=False),
                                        nn.Tanh())
    def forward( self, x):
        x =  self.ConvBlock1(x)
        x =  self.HourGlass1(x)
        x =  self.HourGlass2(x)
        x =  self.DownBlock1(x)
        x =  self.DownBlock2(x)
        x =  self.EncodeBlock1(x)
        content_features1 = F.adaptive_avg_pool2d(x,  1).view(x.shape[ 0],  -1)
        x =  self.EncodeBlock2(x)
        content_features2 = F.adaptive_avg_pool2d(x,  1).view(x.shape[ 0],  -1)
        x =  self.EncodeBlock3(x)
        content_features3 = F.adaptive_avg_pool2d(x,  1).view(x.shape[ 0],  -1)
        x =  self.EncodeBlock4(x)
        content_features4 = F.adaptive_avg_pool2d(x,  1).view(x.shape[ 0],  -1)
        gap = F.adaptive_avg_pool2d(x,  1)
        gap_logit =  self.gap_fc(gap.view(x.shape[ 0],  -1))
        gap_weight = list( self.gap_fc.parameters())[ 0]
        gap = x * gap_weight.unsqueeze( 2).unsqueeze( 3)
        gmp = F.adaptive_max_pool2d(x,  1)
        gmp_logit =  self.gmp_fc(gmp.view(x.shape[ 0],  -1))
        gmp_weight = list( self.gmp_fc.parameters())[ 0]
        gmp = x * gmp_weight.unsqueeze( 2).unsqueeze( 3)
        cam_logit = torch.cat([gap_logit, gmp_logit],  1)
        x = torch.cat([gap, gmp],  1)
        x =  self.relu( self.conv1x1(x))
        heatmap = torch.sum(x, dim= 1, keepdim=True)
         if  self.light:
            x_ = F.adaptive_avg_pool2d(x,  1)
            style_features =  self.FC(x_.view(x_.shape[ 0],  -1))
         else:
            style_features =  self.FC(x.view(x.shape[ 0],  -1))
        x =  self.DecodeBlock1(x, content_features4, style_features)
        x =  self.DecodeBlock2(x, content_features3, style_features)
        x =  self.DecodeBlock3(x, content_features2, style_features)
        x =  self.DecodeBlock4(x, content_features1, style_features)
        x =  self.UpBlock1(x)
        x =  self.UpBlock2(x)
        x =  self.HourGlass3(x)
        x =  self.HourGlass4(x)
         out =  self.ConvBlock2(x)
         return  out, cam_logit, heatmap
4、提取人脸特征:

为了提取人脸特征以达到加载到网络中的目的,我们需要正确框出人脸同时计算特征距离,以方便后面训练模型师损失函数的调用。

代码如下:

   
   
     
     
     
       
class FaceFeatures(object):
     def __init__(self, weights_path, device):
         self.device = device
         self.model = MobileFaceNet( 512).to(device)
         self.model.load_state_dict(torch.load(weights_path))
         self.model.eval()
     def infer(self, batch_tensor):
         # crop face
        h, w = batch_tensor.shape[ 2:]
        top = int(h /  2.1 * ( 0. 8 -  0. 33))
        bottom = int(h - (h /  2.1 *  0. 3))
        size = bottom - top
        left = int(w /  2 - size /  2)
        right = left + size
        batch_tensor = batch_tensor[ ::top: bottom,  left: right]
        batch_tensor = F.interpolate(batch_tensor, size=[ 112112], mode= 'bilinear', align_corners=True)
        features =  self.model(batch_tensor)
         return features
     def cosine_distance(self, batch_tensor1, batch_tensor2):
        feature1 =  self.infer(batch_tensor1)
        feature2 =  self.infer(batch_tensor2)
         return  1 - torch.cosine_similarity(feature1, feature2)


模型测试


在训练好模型后,我们使用python test.py --photo_path ./images/1.jpg --save_path ./images/2.png测试生成图片。其中1.jpg是原始图片,最终会生成2.jpg图片。

使用python data_process.py --data_path YourPhotoFolderPath --save_path YourSaveFolderPath批量生成

1、调用模型:

调用模型首先要使用torch进行加载模型,读取神经网络参数。在对原始图片提取人脸特征的基础上,加载进网络进行生成即可。因为这里我们还需要对生成的数据进行转换成图片,我们这里还需要使用numpy和opencv进行图片的转化。因为加载如模型和模型生成的必然是数据,而我们需要将生成器产生的数据再转换为图片,就用到了这两个库。

代码如下:

   
   
     
     
     
       
class Photo2Cartoon:
     def __init__(self):
         self.pre = Preprocess()
         self.device = torch.device( "cuda:0"  if torch.cuda.is_available()  else  "cpu")
         self.net = ResnetGenerator(ngf= 32, img_size= 256, light=True).to( self.device)
        params = torch.load( './models/photo2cartoon_weights.pt', map_location= self.device)
         self.net.load_state_dict(params[ 'genA2B'])
     def inference(self, img):
         # face alignment and segmentation
        face_rgba =  self.pre.process(img)
         if face_rgba is  None:
            print( 'can not detect face!!!')
             return None
        face_rgba = cv2.resize(face_rgba, ( 256256), interpolation=cv2.INTER_AREA)
        face = face_rgba[ :::3].copy()
        mask = face_rgba[ ::3][ ::, np.newaxis].copy() /  255.
        face = (face*mask + ( 1-mask)* 255) /  127.5 -  1
        face = np.transpose(face[np.newaxis,  :::], ( 0312)).astype(np.float32)
        face = torch.from_numpy(face).to( self.device)
         # inference
        with torch.no_grad():
            cartoon =  self.net(face)[ 0][ 0]
         # post-process
        cartoon = np.transpose(cartoon.cpu().numpy(), ( 120))
        cartoon = (cartoon +  1) *  127.5
        cartoon = (cartoon * mask +  255 * ( 1 - mask)).astype(np.uint8)
        cartoon = cv2.cvtColor(cartoon, cv2.COLOR_RGB2BGR)
         return cartoon
if __name_ _ ==  '__main__':
    img = cv2.cvtColor(cv2.imread(args.photo_path), cv2.COLOR_BGR2RGB)
    c2p = Photo2Cartoon()
    cartoon = c2p.inference(img)
     if cartoon is  not  None:
        cv2.imwrite(args.save_path, cartoon)
到这里,我们整体的程序就搭建完成,下面为我们程序的运行结果: 

在这里附上源码地址:

链接:https://pan.baidu.com/s/1jYVt8T0IPqpYmuNIRyvNGg

提取码:54vp

作者简介:

李秋键,CSDN 博客专家,CSDN达人课作者。硕士在读于中国矿业大学,开发有taptap安卓武侠游戏一部,vip视频解析,文意转换工具,写作机器人等项目,发表论文若干,多次高数竞赛获奖等等。

更多精彩推荐

全球 Python 调查报告:Python 2 正在消亡,PyCharm 比 VS Code 更受欢迎!

雷军喜提第四家上市公司;梨视频 App 被全网下架;Flutter 1.17 稳定版发布 | 极客头条

微服务太杂乱难以管理?一站式服务治理平台来袭!

开源一年,阿里轻量级AI推理引擎MNN 1.0.0正式发布

Redis 6.0 新特性:多线程连环 13 问!

从技术原理解析区块链为何列入新基建

你点的每个“在看”,我都认真当成了喜欢
登录查看更多
0

相关内容

 【SIGGRAPH 2020】人像阴影处理,Portrait Shadow Manipulation
专知会员服务
28+阅读 · 2020年5月19日
【天津大学】风格线条画生成技术综述
专知会员服务
31+阅读 · 2020年4月26日
【干货书】流畅Python,766页pdf,中英文版
专知会员服务
223+阅读 · 2020年3月22日
TensorFlow Lite指南实战《TensorFlow Lite A primer》,附48页PPT
专知会员服务
68+阅读 · 2020年1月17日
必读的10篇 CVPR 2019【生成对抗网络】相关论文和代码
专知会员服务
31+阅读 · 2020年1月10日
【GitHub实战】Pytorch实现的小样本逼真的视频到视频转换
专知会员服务
35+阅读 · 2019年12月15日
已删除
将门创投
8+阅读 · 2019年7月10日
项目 | 基于GAN的人脸照片涂鸦编辑
机器学习算法与Python学习
5+阅读 · 2019年3月1日
AI都可以将文字轻松转成图像
计算机视觉战队
4+阅读 · 2018年7月24日
【学界】实景照片秒变新海诚风格漫画:清华大学提出CartoonGAN
GAN生成式对抗网络
14+阅读 · 2018年6月20日
如何上手深度学习中的图像领域?有这个资源库就够了
数据挖掘入门与实战
5+阅读 · 2018年4月13日
Pluralistic Image Completion
Arxiv
8+阅读 · 2019年3月11日
Arxiv
7+阅读 · 2018年6月8日
Arxiv
4+阅读 · 2018年4月9日
Arxiv
6+阅读 · 2018年3月28日
Arxiv
6+阅读 · 2018年3月12日
Arxiv
7+阅读 · 2018年1月21日
Arxiv
5+阅读 · 2017年7月23日
VIP会员
相关资讯
已删除
将门创投
8+阅读 · 2019年7月10日
项目 | 基于GAN的人脸照片涂鸦编辑
机器学习算法与Python学习
5+阅读 · 2019年3月1日
AI都可以将文字轻松转成图像
计算机视觉战队
4+阅读 · 2018年7月24日
【学界】实景照片秒变新海诚风格漫画:清华大学提出CartoonGAN
GAN生成式对抗网络
14+阅读 · 2018年6月20日
如何上手深度学习中的图像领域?有这个资源库就够了
数据挖掘入门与实战
5+阅读 · 2018年4月13日
相关论文
Pluralistic Image Completion
Arxiv
8+阅读 · 2019年3月11日
Arxiv
7+阅读 · 2018年6月8日
Arxiv
4+阅读 · 2018年4月9日
Arxiv
6+阅读 · 2018年3月28日
Arxiv
6+阅读 · 2018年3月12日
Arxiv
7+阅读 · 2018年1月21日
Arxiv
5+阅读 · 2017年7月23日
Top
微信扫码咨询专知VIP会员