实践 | 使用fasttext进行文档分类

2018 年 4 月 29 日 黑龙江大学自然语言处理实验室 兜哥

本文授权转载自公众号:兜哥带你学安全


fasttext原理

fasttext提供了一种有效且快速的方式生成词向量以及进行文档分类。fasttext模型输入一个词的序列,输出这个词序列属于不同类别的概率。fasttext模型架构和Word2Vec中的CBOW模型很类似。不同之处在于,fasttext预测标签,而CBOW模型预测中间词。fasttext设计的初衷就是为了作为一个文档分类器,副产品是也生成了词向量。

fasttext特性

n-gram

在词袋模型中,把单词当做独立的个体,没有考虑词前后的关系。比如"我打你"和“你打我“,使用词袋模型的话,这两句话是完全一样的。词袋的特征为:

                                                                                                                                                                                       
                                                                                                                                                                                         
                                                                                                                                                                                         
                                                                                                                                                                                         
  1. ["我",“打“,”你”]

"我打你"和“你打我“对应的特征向量均为:

                                                                                                                                                                                       
                                                                                                                                                                                         
                                                                                                                                                                                         
                                                                                                                                                                                         
  1. [1,1,1]

n-gram是对词袋模型的一种改善,它会关注一个单词的前后关系,比如n-gram中最常见的2-gram,就关注单词的前一个词,比如"我打你",就可以拆分为"我打"和"打你"。这两句话一起建模的话,2-gram对应的特征为:

                                                                                                                                                                                       
                                                                                                                                                                                         
                                                                                                                                                                                         
                                                                                                                                                                                         
  1. ["我打""打你""你打""打我"]

"我打你"对应的特征向量为:

                                                                                                                                                                                       
                                                                                                                                                                                         
                                                                                                                                                                                         
                                                                                                                                                                                         
  1. [1,1,0,0]

"你打我"对应的特征向量为:

                                                                                                                                                                                       
                                                                                                                                                                                         
                                                                                                                                                                                         
                                                                                                                                                                                         
  1. [0,0,1,1]

与Word2Vec使用词袋模型不同,fasttext使用了n-gram模型,因此fasttext可以更有效的表达词前后的之间的关系。

高效率

fasttext在使用标准多核CPU的情况下10分钟内处理超过10亿个词汇,特别是与深度模型对比,fastText能将训练时间由数天缩短到几秒钟。使用一个标准多核CPU,得到了在10分钟内训练完超过10亿词汇量模型的结果。

安装fasttext

fasttext的安装非常简便,直接从github上同步最新的代码并进行安装即可。

                                                                                                                                                                                       
                                                                                                                                                                                         
                                                                                                                                                                                         
                                                                                                                                                                                         
  1. $ git clone https://github.com/facebookresearch/fastText.git

  2. $ cd fastText

  3. $ pip install .

预训练模型

facebook已经基于其收集的海量语料,训练好了fasttext的词向量模型,目前已经支持了150多种语言。有需要的读者可以直接下载并使用,对应的链接为:

                                                                                                                                                                                       
                                                                                                                                                                                         
                                                                                                                                                                                         
                                                                                                                                                                                         
  1. https://github.com/facebookresearch/fastText/blob/master/docs/crawl-vectors.md


数据集

数据集依然使用搜狗实验室提供的"搜狐新闻数据",该数据来自搜狐新闻2012年6月—7月期间国内,国际,体育,社会,娱乐等18个频道的新闻数据,提供URL和正文信息。对应的网址为:

                                                                                                                                                                                       
                                                                                                                                                                                         
                                                                                                                                                                                         
                                                                                                                                                                                         
  1. http://www.sogou.com/labs/resource/cs.php

数据文件的格式为:

                                                                                                                                                                                       
                                                                                                                                                                                         
                                                                                                                                                                                         
                                                                                                                                                                                         
  1. <url>页面URL</url>

  2. <docno>页面ID</docno>

  3. <contenttitle>页面标题</contenttitle>

  4. <content>页面内容</content>

  5. </doc>

我们可以看到数据文件中并没有标记页面内容属于哪个频道,如果需要做文档分类,搜狗提供了页面URL和频道之间的映射关系。

下载SogouTCE文件,可以看到具体的映射关系举例如下:

                                                                                                                                                                                       
                                                                                                                                                                                         
                                                                                                                                                                                         
                                                                                                                                                                                         
  1. http://www.xinhuanet.com/auto/  汽车

  2. http://www.xinhuanet.com/fortune    财经

  3. http://www.xinhuanet.com/internet/  IT

  4. http://www.xinhuanet.com/health/    健康

  5. http://www.xinhuanet.com/sports 体育

  6. http://www.xinhuanet.com/travel 旅游

  7. http://www.xinhuanet.com/edu    教育

  8. http://www.xinhuanet.com/employment 招聘

  9. http://www.xinhuanet.com/life   文化

  10. http://www.xinhuanet.com/mil    军事

  11. http://www.xinhuanet.com/olympics/  奥运

  12. http://www.xinhuanet.com/society    社会

数据清洗

搜狐新闻数据的文件默认编码格式为gb18030,因此解压缩后要线转换成utf-8格式。

                                                                                                                                                                                       
                                                                                                                                                                                         
                                                                                                                                                                                         
                                                                                                                                                                                         
  1. tar -zxvf news_sohusite_xml.full.tar.gz

  2. cat news_sohusite_xml.dat | iconv -f gb18030 -t utf-8 > news_sohusite_xml-utf8.txt

转换完格式后查看文件内容,文件以xml形式记录,举例如下:

                                                                                                                                                                                       
                                                                                                                                                                                         
                                                                                                                                                                                         
                                                                                                                                                                                         
  1. <doc>

  2. <url>http://gongyi.sohu.com/s2008/sourceoflife/</url>

  3. <docno>f2467af22cd2f0ea-34913306c0bb3300</docno>

  4. <contenttitle>中国西部是地球上主要干旱带之一,妇女是当地劳动力...</contenttitle>

  5. <content>同心县地处宁夏中部干旱带的核心区, 冬寒长,春暖迟,夏热短,秋凉早,干旱少雨,蒸发强烈,风大沙多。主要自然灾害有沙尘暴、干热风、霜冻、冰雹等,其中以干旱危害最为严重。由于生态环境的极度恶劣,导致农村经济发展缓慢,人民群众生产、生活水平低下,靠天吃饭的被动局

  6. 面依然存在,同心,又是国家级老、少、边、穷县之一…[详细]</content>

  7. </doc>

但是数据文件并不是标准的xml格式,如下所示,该文件相对标准的xml格式缺少了根元素。

                                                                                                                                                                                       
                                                                                                                                                                                         
                                                                                                                                                                                         
                                                                                                                                                                                         
  1. <doc>

  2.    <url></url>

  3.    <docno></docno>

  4.    <contenttitle></contenttitle>

  5.    <content></content>

  6. </doc>

  7. <doc>

  8.    <url></url>

  9.    <docno></docno>

  10.    <contenttitle></contenttitle>

  11.    <content></content>

  12. </doc>

所有的doc节点都直接是最顶层,没有根节点。因此要添加根节点使该文本文件符合xml文件的规范,最简单的一种形式就是在文件的开始和结尾添加根元素标签。

                                                                                                                                                                                       
                                                                                                                                                                                         
                                                                                                                                                                                         
                                                                                                                                                                                         
  1. <?xml version="1.0" encoding="utf-8"?>

  2. <docs>

  3.    <doc>

  4.        <url></url>

  5.        <docno></docno>

  6.        <contenttitle></contenttitle>

  7.        <content></content>

  8.    </doc>

  9.    <doc>

  10.        <url></url>

  11.        <docno></docno>

  12.        <contenttitle></contenttitle>

  13.        <content></content>

  14.    </doc>

  15. </docs>

可以直接使用文本编辑工具在数据文件的开始和结尾进行修改,但是这有可能导致你的终端因为内存使用过大而崩溃。一种比较稳妥的做法是使用程序完成。

                                                                                                                                                                                       
                                                                                                                                                                                         
                                                                                                                                                                                         
                                                                                                                                                                                         
  1. def make_xml():

  2.    print "<?xml version="1.0" encoding="utf-8"?>"

  3.    print "<docs>"

  4.    with open("data/news_sohusite_xml-utf8.txt") as F:

  5.        for line in F:

  6.            print line

  7.        F.close()

  8.    print "</docs>"

在终端执行该程序,并将标准输出的结果保存即可,剩下的操作只要解析xml文件即可。下面我们介绍另一种方法,观察可以发现,url和content是成对出现的,并且一一对应。我们可以过滤这两个字段的内容,分别保存成content文件和url文件。首先过滤出url字段的内容,并且删除掉url标签。

                                                                                                                                                                                       
                                                                                                                                                                                         
                                                                                                                                                                                         
                                                                                                                                                                                         
  1. cat news_sohusite_xml-utf8.txt | grep '<url>' | sed  's/<url>//g' | sed  's/</url>//g' > news_sohusite_url.txt

然后过滤出content字段的内容,并且删除掉content标签。

                                                                                                                                                                                       
                                                                                                                                                                                         
                                                                                                                                                                                         
                                                                                                                                                                                         
  1. cat news_sohusite_xml-utf8.txt | grep '<content>' | sed  's/<content>//g' | sed  's/</content>//g' > news_sohusite_content.txt

content是中文内容,需要使用jieba进行切词,可以把切词的动作也放到上面的命令里面。

                                                                                                                                                                                       
                                                                                                                                                                                         
                                                                                                                                                                                         
                                                                                                                                                                                         
  1. cat news_sohusite_xml-utf8.txt | grep '<content>' | sed  's/<content>//g' | sed  's/</content>//g' | python -m jieba -d ' '  > news_sohusite_content.txt

加载url和对应领域的映射关系的文件,以哈希的形式保存对应的映射关系。

                                                                                                                                                                                       
                                                                                                                                                                                         
                                                                                                                                                                                         
                                                                                                                                                                                         
  1. def load_SogouTCE():

  2.    SogouTCE=[]

  3.    SogouTCE_kv = {}

  4.    with open("../data/SogouTCE.txt") as F:

  5.        for line in F:

  6.            (url,channel)=line.split()

  7.            SogouTCE.append(url)

  8.        F.close()

  9.    for index,url in enumerate(SogouTCE):

  10.        #删除http前缀

  11.        url=re.sub('http://','',url)

  12.        print "k:%s v:%d" % (url,index)

  13.        SogouTCE_kv[url]=index

  14.    return  SogouTCE_kv

我们分析下各个领域的数据分布情况,把匹配上的url对应的标记打印出来。

                                                                                                                                                                                       
                                                                                                                                                                                         
                                                                                                                                                                                         
                                                                                                                                                                                         
  1. def load_url(SogouTCE_kv):

  2.    labels=[]

  3.    with open("../data/news_sohusite_url.txt") as F:

  4.        for line in F:

  5.            for k,v in SogouTCE_kv.items():

  6.                if re.search(k,line,re.IGNORECASE):

  7.                    #print "x:%s y:%d" % (line,v)

  8.                    print v

  9.                    labels.append(v)

  10.        F.close()

  11.    return  labels

运行程序,分析各个领域对应的url数量。

                                                                                                                                                                                       
                                                                                                                                                                                         
                                                                                                                                                                                         
                                                                                                                                                                                         
  1. python fasttext.py > v.txt

  2. cat v.txt | sort -n | uniq -c

每行的第一个字段是数量,第二个字段是对应的领域的id,结果表明搜狐新闻数据集中在某几个领域,并且分布不均匀。为了避免样本不均衡导致的误判,我们选择数量上占前三的领域作为后继分析的数据,id分别为81,79和91。

                                                                                                                                                                                       
                                                                                                                                                                                         
                                                                                                                                                                                         
                                                                                                                                                                                         
  1. 138576 79

  2. 27489 80

  3. 199871 81

  4. 23409 82

  5. 44537 83

  6. 2179 84

  7. 13012 85

  8. 1924 87

  9. 3294 88

  10. 842 89

  11. 50138 91

  12. 5882 92

反查对应的url为:

                                                                                                                                                                                       
                                                                                                                                                                                         
                                                                                                                                                                                         
                                                                                                                                                                                         
  1. kit.sohu.com/ id:81

  2. auto.sohu.com/ id:79

  3. yule.sohu.com/ id:91

过滤我们关注的领域的内容,将content保存在x列表里,对应的领域的id保存在y列表里,作为标签使用,至此我们完成了数据清洗的工作。

                                                                                                                                                                                       
                                                                                                                                                                                         
                                                                                                                                                                                         
                                                                                                                                                                                         
  1. def load_selecteddata(SogouTCE_kv):

  2.    x=[]

  3.    y=[]

  4.    #加载content列表

  5.    with open("../data/news_sohusite_content.txt") as F:

  6.        content=F.readlines()

  7.        F.close()

  8.    # 加载url列表

  9.    with open("../data/news_sohusite_url.txt") as F:

  10.        url = F.readlines()

  11.        F.close()

  12.    for index,u in  enumerate(url):

  13.        for k, v in SogouTCE_kv.items():

  14.            # 只加载id为81,79和91的数据

  15.            if re.search(k, u, re.IGNORECASE) and v in (81, 79, 91):

  16.                #保存url对应的content内容

  17.                x.append(content[index])

  18.                y.append(v)

  19.    return x,y

删除停用词

在处理中文语料时,需要删除停用词。所谓停用词就是对理解中文含义没有明显作用的哪些单词,常见的停用词举例如下:

                                                                                                                                                                                       
                                                                                                                                                                                         
                                                                                                                                                                                         
                                                                                                                                                                                         
  1. 一一  

  2. 一下  

  3. 一个  

  4. 一些  

  5. 一何  

  6. 一切  

  7. 一则  

  8. 一则通过  

  9. 一天  

  10. 一定  

  11. 一方面  

  12. 一旦  

  13. 一时

另外所有的字母和数字还有标点符号也可以作为停用词。我们把停用词保存在一个文本文件里面便于配置使用。定义加载停用词的函数。

                                                                                                                                                                                       
                                                                                                                                                                                         
                                                                                                                                                                                         
                                                                                                                                                                                         
  1. def load_stopwords():

  2.    with open("stopwords.txt") as F:

  3.        stopwords=F.readlines()

  4.        F.close()

  5.    return [word.strip() for word in stopwords]

使用停用词过滤之前提取的文本内容。

                                                                                                                                                                                       
                                                                                                                                                                                         
                                                                                                                                                                                         
                                                                                                                                                                                         
  1. stopwords=load_stopwords()

  2. #切割token

  3. x=[  [word for word in line.split() if word not in stopwords]   for line in x]

文档分类

数据文件格式

fasttext对训练和测试的数据格式有一定的要求,数据文件和标签文件要合并到一个文件里面。文件中的每一行代表一条记录,同时每条记录的最后标记对应的标签。默认情况下标签要以__label__开头,比如:

                                                                                                                                                                                       
                                                                                                                                                                                         
                                                                                                                                                                                         
                                                                                                                                                                                         
  1. 这是一条测试数据    __label__1

python下实现合并数据文件和标签文件的功能非常简单。

                                                                                                                                                                                       
                                                                                                                                                                                         
                                                                                                                                                                                         
                                                                                                                                                                                         
  1. def dump_file(x,y,filename):

  2.    with open(filename, 'w') as f:

  3.        for i,v in enumerate(x):

  4.            line="%s __label__%d " % (v,y[i])

  5.            f.write(line)

  6.        f.close()

加载数据清洗后的数据和标签,随机划分成训练数据和测试数据,其中测试数据占20%。

                                                                                                                                                                                       
                                                                                                                                                                                         
                                                                                                                                                                                         
                                                                                                                                                                                         
  1. SogouTCE_kv=load_SogouTCE()

  2. x,y=load_selecteddata(SogouTCE_kv)

  3. # 分割训练集和测试集

  4. x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=0.2)

按照fasttext的格式要求保存成训练数据和测试数据。

                                                                                                                                                                                       
                                                                                                                                                                                         
                                                                                                                                                                                         
                                                                                                                                                                                         
  1. #按照fasttest的要求生成训练数据和测试数据

  2. dump_file(x_train,y_train,"../data/sougou_train.txt")

  3. dump_file(x_test, y_test, "../data/sougou_test.txt")

查看训练数据文件的内容,举例如下:

                                                                                                                                                                                       
                                                                                                                                                                                         
                                                                                                                                                                                         
                                                                                                                                                                                         
  1. 长安 标致 雪铁龙 九寨沟 试驾 __label__79

训练模型

下面开始训练fasttext模型。

                                                                                                                                                                                       
                                                                                                                                                                                         
                                                                                                                                                                                         
                                                                                                                                                                                         
  1. # train_supervised uses the same arguments and defaults as the fastText cli

  2. model = train_supervised(

  3.        input="../data/sougou_train.txt", epoch=25, lr=0.6, wordNgrams=2, verbose=2, minCount=1

  4.    )

其中比较重要的几个参数的含义为:

  • input;表示训练数据文件的路径

  • epoch:表示训练的次数

  • lr:表示初始的学习速率

  • wordNgrams:表示n-gram的值,一般使用2,表示2-gram

  • minCount:表示参与计算的单词的最小出现次数。

验证效果

fasttext默认情况下会计算对应的准确率和召回率。

                                                                                                                                                                                       
                                                                                                                                                                                         
                                                                                                                                                                                         
                                                                                                                                                                                         
  1. def print_results(N, p, r):

  2.    print("N " + str(N))

  3.    print("P@{} {:.3f}".format(1, p))

  4.    print("R@{} {:.3f}".format(1, r))

使用测试数据文件进行校验。

                                                                                                                                                                                       
                                                                                                                                                                                         
                                                                                                                                                                                         
                                                                                                                                                                                         
  1. print_results(*model.test("../data/sougou_test.txt"))

运行程序,显示加载了36M的单词,其中包含288770的单词组合,标记类型一共3种。

                                                                                                                                                                                       
                                                                                                                                                                                         
                                                                                                                                                                                         
                                                                                                                                                                                         
  1. Read 36M words

  2. Number of words:  288770

  3. Number of labels: 3

验证效果如下所示,准确率为99.0%,召回率为99.0%,对应的F1计算为99.0%,效果非常不错。

                                                                                                                                                                                       
                                                                                                                                                                                         
                                                                                                                                                                                         
                                                                                                                                                                                         
  1. Progress: 100.0% words/sec/thread:  626183 lr:  0.000000 loss:  0.005640 ETA:   0h 0m

  2. N   71107

  3. P@1 0.990

  4. R@1 0.990






推荐阅读

基础 | TreeLSTM Sentiment Classification

基础 | 详解依存树的来龙去脉及用法

基础 | 基于注意力机制的seq2seq网络

原创 | Simple Recurrent Unit For Sentence Classification

原创 | Highway Networks For Sentence Classification


欢迎关注交流


登录查看更多
7

相关内容

【实用书】Python机器学习Scikit-Learn应用指南,247页pdf
专知会员服务
256+阅读 · 2020年6月10日
零样本文本分类,Zero-Shot Learning for Text Classification
专知会员服务
95+阅读 · 2020年5月31日
【Amazon】使用预先训练的Transformer模型进行数据增强
专知会员服务
56+阅读 · 2020年3月6日
【干货】用BRET进行多标签文本分类(附代码)
专知会员服务
84+阅读 · 2019年12月27日
计算机视觉最佳实践、代码示例和相关文档
专知会员服务
17+阅读 · 2019年10月9日
下载 | 最全中文文本分类模型库,上手即用
机器学习算法与Python学习
30+阅读 · 2019年10月17日
使用 Bert 预训练模型文本分类(内附源码)
数据库开发
102+阅读 · 2019年3月12日
FastText的内部机制
黑龙江大学自然语言处理实验室
5+阅读 · 2018年7月25日
收藏!CNN与RNN对中文文本进行分类--基于TENSORFLOW实现
全球人工智能
12+阅读 · 2018年5月26日
在Python中使用SpaCy进行文本分类
专知
24+阅读 · 2018年5月8日
专栏 | fastText原理及实践
机器之心
3+阅读 · 2018年1月26日
使用fasttext实现文本处理及文本预测
数据挖掘入门与实战
5+阅读 · 2018年1月13日
Arxiv
6+阅读 · 2019年8月22日
How to Fine-Tune BERT for Text Classification?
Arxiv
13+阅读 · 2019年5月14日
Arxiv
5+阅读 · 2019年4月25日
SlowFast Networks for Video Recognition
Arxiv
19+阅读 · 2018年12月10日
Arxiv
12+阅读 · 2018年9月15日
Arxiv
6+阅读 · 2018年7月12日
VIP会员
相关资讯
下载 | 最全中文文本分类模型库,上手即用
机器学习算法与Python学习
30+阅读 · 2019年10月17日
使用 Bert 预训练模型文本分类(内附源码)
数据库开发
102+阅读 · 2019年3月12日
FastText的内部机制
黑龙江大学自然语言处理实验室
5+阅读 · 2018年7月25日
收藏!CNN与RNN对中文文本进行分类--基于TENSORFLOW实现
全球人工智能
12+阅读 · 2018年5月26日
在Python中使用SpaCy进行文本分类
专知
24+阅读 · 2018年5月8日
专栏 | fastText原理及实践
机器之心
3+阅读 · 2018年1月26日
使用fasttext实现文本处理及文本预测
数据挖掘入门与实战
5+阅读 · 2018年1月13日
相关论文
Arxiv
6+阅读 · 2019年8月22日
How to Fine-Tune BERT for Text Classification?
Arxiv
13+阅读 · 2019年5月14日
Arxiv
5+阅读 · 2019年4月25日
SlowFast Networks for Video Recognition
Arxiv
19+阅读 · 2018年12月10日
Arxiv
12+阅读 · 2018年9月15日
Arxiv
6+阅读 · 2018年7月12日
Top
微信扫码咨询专知VIP会员