【ICML2019】中科院自动化所-针对小样本问题的学习生成匹配网络方法

成功的深度神经网络训练往往离不开大量的数据,但是如果每个类别只有几个标注样本的学习问题该怎么办呢?最近,来自中科院自动化所NLPR的多媒体计算团队研究人员提出一种直接使用小样本训练数据生成具有分类功能的网络权重的元学习方法。此论文已被ICML-2019接收。


LGM-Net: Learning to Generate Matching Networks for Few-Shot Learning

论文链接: https://arxiv.org/abs/1905.06331

论文代码: https://github.com/likesiwell/LGM-Net


01

动机与目标


当前人工智能领域的成功,往往依赖于计算机运算能力的提升以及利用大量的数据,但人类智能却可以通过利用积累的学习经验,在面对新的问题时,从少量的样本(few-shot)中进行有效的学习。在现实中,随着更多应用场景的涌现,我们也将必然面临更多数据不足的问题,因此如何能够让机器像人类一样能够利用学习经验,从少量数据中进行有效学习,成为了一个重要的研究方向。

目前,深度神经网络的训练往往需要大量数据和训练时间,当训练数据较少时,神经网络通常容易过拟合,这是由于经典的随机梯度下降(SGD)算法是一种通用的优化算法,它们没有包含任何针对当前任务的先验知识,目的是在神经网络的损失景观(loss landscape)中寻找最优点或者次优点。当一个神经网络的计算结构固定时,网络的参数权重决定了网络的功能,当数据量较少的时候,使用SGD在神经网络的参数空间中得到的参数点并不具有好的泛华能力;但是当数据量充足的时候,使用SGD却可以得到较好泛化能力的参数点。所以,具有良好泛化能力的参数点是存在的,只是经典的优化算法在小数据场景下达不到。我们将使神经网络针对某个任务具有良好泛化能力的参数称为功能权重(functional weights)。


根据上面所述,我们可以认为功能权重是一个基于训练数据的条件概率分布。于是,我们针对小样本学习问题提出一种元学习方法,可以基于训练数据直接产生出目标网络的功能权重来,让神经网络在大量的任务中积累经验,自己学会如何解决小样本问题。


02

提出的方法

2.1  元学习与任务情景式训练简介

目前常见的元学习框架主要由一个元学习器(metalearner)和一个基础学习器(base learner)组成。基础学习器是针对某个任务设计的学习器,而元学习器从各种学习任务中积累学习经验(元知识),用来指导帮助基础学习器的学习。常见的元学习器形式如学习一种参数化的更新算法(如Learning to learn[1], Meta-LSTM[3]),或一个好的初始化参数(如MAML[2])或者一种参数调节器(如Introspection[4])。


任务情景训练是在matchingnetworks[5]中首次提出使用,它从针对小样本的问题,从原始数据中,随机抽取样本组成不同的小样本学习任务,在训练过程中,模拟元学习器帮助基础学习器解决当前目标任务的过程,在元学习器中积累学习经验,可以泛化到没有见过的新任务中去。具体参见各种元学习文章中的训练过程。


2.2
模型框架

如图一所示,我们的模型框架主要由两部分组成,即元网络(MetaNet)对应meta learner和目标网络(TargetNet)对应base learner。目标网络是针对某个问题设计的网络结构,如分类或回归网络,并且目标网络中没有可学习参数,它的全部参数由元网络产生,在本文中使用matching networks的网络结构作为TargetNet来解决小样本分类问题。


图一:算法框架示意图

 

我们同样使用任务情景训练方式来训练我们的算法,训练过程如图二所示。从一个元训练数据中抽取一个batch的N-way K-shot 任务,对于这个batch中每个任务,我们使用这个任务的训练样本通过MetaNet采样生成一个TargetNet的参数,然后使用这个任务的测试样本或者这个任务的测试损失,累积这个batch中的所有损失,更新MetaNet中的可学习参数,对于高维的数据,我们通过一个可学习的embedding module将其转换为低纬度特征用于上述训练。


图二:算法训练过程


2.2.1 元网络模块

MetaNet module由任务环境编码网络(taskcontext encoder)和参数生成器(conditional weight generator)组成,它的目的是编码任务数据(如果是高维数据则编码数据特征)然后采样生成目标网络的功能参数。


在task context encoder中,我们使用任务样本特征统计平均的方式来编码这个任务,使用重新参数化的方式采样一个这个任务的编码向量c_i,如下所示:


 


再通过conditional weight generator,将任务编码向量转换为每一层的参数:


同时,我们需要约束生成权重的尺度,于是我们采用类似weights normalization的方法,如果生成的是convolution weights,则对每个kernel进行L2 normalization,如果生成的是全连接层矩阵,则对矩阵中每个超平面参数进行L2 normalization。这样做可以稳定训练过程。

 

2.2.2 目标网络模块

目标网络模块采用与matching networks相同的计算结构来解决小样本问题。通过计算一个任务中的测试样本与训练样本之间的概率注意力核函数,来推断出测试样本的类别,在通过交叉熵函数,作为当前任务的损失,具体计算过程如下:


2.2.3 任务间标准化

以前的元学习方法,在任务情景训练中,都将每个任务独立对待。但是,我们发现,其实任务之间总是有一些可以共享的信息,可以帮助训练元网络的。于是我们提出了任务间标准化的方法,来实现这个目的。我们的具体实现非常简单,直接使用batch normalization方法,用在整个任务batch的所有数据上,在训练过程中,通过BN的计算,可以达到所有任务都相互关联的目的。


03

实验

3.1 合成实验数据

我们构建了四个合成数据(如图三所示),将我们的算法用在这些数据的小样本学习上,通过可视化的方式来体现我们方法的insight。如果四所示,我们在新的任务上可视化了不同网络参数情况下的决策边界。第一列代表TargetNet随机初始化参数情况下的决策边界,第二列表示TargetNet使用梯度下降得到的参数情况下的决策边界,第三列表示TargetNet使用MetaNet产生的参数情况下的决策边界。可以看到,直接梯度下降训练会导致严重过拟合,而使用训练后的MetaNet产生的参数的决策边界,可以很好地适应同类型的任务。


图三:合成数据的可视化,相同颜色的点代表同一个类别


图四:目标网络在新任务决策边界可视化


3.2  Omniglot 和 nimiImageNet实验结果对比

在Omniglot数据上,虽然我们的结果没有超越所有的其他的方法,但是也同样达到了比较好的性能,由此可以说明我们方法的有效性。



在miniImageNet数据集上,我们发现对于1-shot任务我们的方法均达到了state of the art的水平。



04

结论与讨论

 

元学习的概念最早在Jürgen Schmidhuber的博士论文中提出,它的最初目的就是学习学习器训练过程中可以迁移的学习经验,来指导其他学习器的学习。最近几年,元学习不仅在小样本任务上取得很好的效果,并且也逐渐开始在应用任务领域出现,相信不久的将来,元学习也将会在人工智能领域的研究发挥更大的作用。


在元学习中,有两个方面是最重要的。一个是如何针对具体任务构建任务情景训练结构;另一个是如何针对具体任务构建元学习器。本文在前人研究的基础上,提出了一种新参数生成的元学习器形式,与参数更新器形式,参数调节器形式,共享初始化形式一样,相信对未来的元学习的研究起到重要的启发作用。


05

参考文献

 

[1]. Andrychowicz, M., Denil, M., Gomez,S., Hoffman, M. W., Pfau, D., Schaul, T. & De Freitas, N. Learning to learnby gradient descent by gradient descent. NIPS, 2016

[2]. Finn, C., Abbeel, P., & Levine, S.Model-agnostic meta-learning for fast adaptation of deep networks. ICML, 2017

[3]. Ravi, S. and Larochelle, H.Optimization as a model for fewshot learning. ICLR, 2017.

[4]. Sinha, A., Sarkar, M., Mukherjee, A.,& Krishnamurthy, B. Introspection: Accelerating neural network training bylearning weight evolution, 2017, ICLR

[5]. Vinyals, O., Blundell, C., Lillicrap, T., & Wierstra, D.Matching networks for one shot learning. NIPS, 2016.


-END-

专 · 知

专知,专业可信的人工智能知识分发,让认知协作更快更好!欢迎登录www.zhuanzhi.ai,注册登录专知,获取更多AI知识资料!

欢迎微信扫一扫加入专知人工智能知识星球群,获取最新AI专业干货知识教程视频资料和与专家交流咨询

请加专知小助手微信(扫一扫如下二维码添加),加入专知人工智能主题群,咨询技术商务合作~

专知《深度学习:算法到实战》课程全部完成!550+位同学在学习,现在报名,限时优惠!网易云课堂人工智能畅销榜首位!

点击“阅读原文”,了解报名专知《深度学习:算法到实战》课程

展开全文
Top
微信扫码咨询专知VIP会员