Github 项目推荐 | PyTorch 实现的 GAN 文本生成框架

2019 年 6 月 10 日 AI研习社

Github项目地址:https://github.com/williamSYSU/TextGAN-PyTorch


TextGAN是一个用于生成基于GANs的文本生成模型的PyTorch框架。TextGAN是一个基准测试平台,支持基于GAN的文本生成模型的研究。由于大多数基于GAN的文本生成模型都是由Tensorflow实现的,TextGAN可以帮助那些习惯了PyTorch的人更快地进入文本生成领域。

目前,只有少数基于GAN的模型被实现,包括 SeqGAN (Yu et. al, 2017), LeakGAN (Guo et. al, 2018) 和 RelGAN (Nie et. al, 2018)。

环境要求

  • PyTorch >= 1.0.0

  • Python 3.6

  • Numpy 1.14.5

  • CUDA 7.5+ (For GPU)

  • nltk 3.4

  • tqdm 4.32.1

运行 pip install -r requirements.txt 即可安装。 如果出现了CUDA问题,请查看PyTorch官方的入门指南(https://pytorch.org/get-started/locally/)。

实现模型和原始论文

  • SeqGAN - SeqGAN: Sequence Generative Adversarial Nets with Policy Gradient

    https://arxiv.org/abs/1609.05473

  • LeakGAN - Long Text Generation via Adversarial Training with Leaked Information

    https://arxiv.org/abs/1709.08624

  • RelGAN - RelGAN: Relational Generative Adversarial Networks for Text Generation

    https://openreview.net/forum?id=rJedV3R5tm

入门

  • 开始

git clone
cd TextGAN-PyTorch

对于真实数据实验,可以从下载Image COCOEMNLP新闻数据集,下载链接:

https://drive.google.com/drive/folders/1XvT3GqbK1wh3XhTgqBLWUtH_mLzGnKZP?usp=sharing

  • 使用SeqGAN运行

cd run
python3 run_seqgan.py 0 0 # The first 0 is job_id, the second 0 is gpu_id

  • 使用LeakGAN运行

cd run
python3 run_leakgan.py 0 0

  • 使用RelGAN运行

cd run
python3 run_relgan.py 0 0

特点

1.Instructor

对于每个模型,整个运行过程在instructor/oracle_data/seqgan_instructor.py中定义。 (以合成数据实验中的SeqGAN为例)。 init_model()optimize()等基本函数在instructor.py的基类BasicInstructor中定义。 如果要添加新的基于GAN的文本生成模型,请在Instructor/oracle_data下创建一个新的Instructor,并定义模型的训练过程。

2.可视化

使用utils/visualization.py可视化日志文件,包括模型丢失和度量标准分数。 在log_file_list中自定义日志文件,不超过 len(color_list)。 日志文件名应排除.txt。

3.日志记录

TextGAN-PyTorch使用Python中的logging(日志记录)模块来记录正在运行的进程,如生成器的丢失和度量标准分数。 为了便于可视化,将分别在log/log _****_ ****。txt和save/**/log.txt中保存两个相同的日志文件。 此外,代码将自动保存模型的状态字典和批量大小的生成器样本,每个日志步骤为./save/**/models./save/**/samples,其中**取决于您的超级参数。

4.运行信号

你可以使用基于字典文件run_signal.txt的Signal类(请查看utils/helpers.py)轻松控制训练过程。

如果要使用Signal,只需编辑本地文件run_signal.txt并将pre_sig设置为Fasle,程序将停止预训练过程并进入下一个训练阶段。 如果你认为当前的训练已经足够,可以非常方便地提前停止训练。

5.自动选择GPU

config.py中,程序会自动选择nvidia-smiGPU-Util最少的GPU设备。 默认情况下启用此功能。 如果要手动选择GPU设备,请取消注释run_[run_model].py中的--device args并使用命令指定GPU设备。

TODO

  • 添加实验结果

  • 修复LeakGAN模型中的错误

  • instrutor/real_data中添加SeqGANLeakGANinstructors


 点击 阅读原文 ,进技术交流小组查看更多Github项目推荐

登录查看更多
33

相关内容

在自然语言处理中,另外一个重要的应用领域,就是文本的自动撰写。关键词、关键短语、自动摘要提取都属于这个领域的一种应用。
【IJCAI2020-华为诺亚】面向深度强化学习的策略迁移框架
专知会员服务
25+阅读 · 2020年5月25日
【CVPR2020】MSG-GAN:用于稳定图像合成的多尺度梯度GAN
专知会员服务
26+阅读 · 2020年4月6日
必读的10篇 CVPR 2019【生成对抗网络】相关论文和代码
专知会员服务
31+阅读 · 2020年1月10日
一网打尽!100+深度学习模型TensorFlow与Pytorch代码实现集合
【GitHub实战】Pytorch实现的小样本逼真的视频到视频转换
专知会员服务
35+阅读 · 2019年12月15日
【书籍】深度学习框架:PyTorch入门与实践(附代码)
专知会员服务
160+阅读 · 2019年10月28日
Keras作者François Chollet推荐的开源图像搜索引擎项目Sis
专知会员服务
29+阅读 · 2019年10月17日
生成式对抗网络GAN异常检测
专知会员服务
114+阅读 · 2019年10月13日
基于PyTorch/TorchText的自然语言处理库
专知
27+阅读 · 2019年4月22日
Github 项目推荐 | YOLOv3 的最小化 PyTorch 实现
AI研习社
25+阅读 · 2018年5月31日
用PyTorch实现各种GANs(附论文和代码地址)
Github 项目推荐 | 用 Pytorch 实现的 Capsule Network
AI研习社
22+阅读 · 2018年3月7日
Adversarial Mutual Information for Text Generation
Arxiv
13+阅读 · 2020年6月30日
Arxiv
4+阅读 · 2018年5月21日
VIP会员
相关VIP内容
【IJCAI2020-华为诺亚】面向深度强化学习的策略迁移框架
专知会员服务
25+阅读 · 2020年5月25日
【CVPR2020】MSG-GAN:用于稳定图像合成的多尺度梯度GAN
专知会员服务
26+阅读 · 2020年4月6日
必读的10篇 CVPR 2019【生成对抗网络】相关论文和代码
专知会员服务
31+阅读 · 2020年1月10日
一网打尽!100+深度学习模型TensorFlow与Pytorch代码实现集合
【GitHub实战】Pytorch实现的小样本逼真的视频到视频转换
专知会员服务
35+阅读 · 2019年12月15日
【书籍】深度学习框架:PyTorch入门与实践(附代码)
专知会员服务
160+阅读 · 2019年10月28日
Keras作者François Chollet推荐的开源图像搜索引擎项目Sis
专知会员服务
29+阅读 · 2019年10月17日
生成式对抗网络GAN异常检测
专知会员服务
114+阅读 · 2019年10月13日
Top
微信扫码咨询专知VIP会员