Distilling the Knowledge in a Neural Network

知识蒸馏(Knowledge Distilling)是模型压缩的一种方法,是指利用已经训练的一个较复杂的Teacher模型,指导一个较轻量的Student模型训练,从而在减小模型大小和计算资源的同时,尽量保持原Teacher模型的准确率的方法。这种方法受到大家的注意,主要是由于Hinton的论文Distilling the Knowledge in a Neural Network。这篇博客做一总结。后续还会有KD方法的改进相关论文的心得介绍。

背景

这里我将Wang Naiyang在知乎相关问题的回答粘贴如下,将KD方法的motivation讲的很清楚。图森也发了论文对KD进行了改进,下篇笔记总结。

Knowledge Distill是一种简单弥补分类问题监督信号不足的办法。传统的分类问题,模型的目标是将输入的特征映射到输出空间的一个点上,例如在著名的Imagenet比赛中,就是要将所有可能的输入图片映射到输出空间的1000个点上。这么做的话这1000个点中的每一个点是一个one hot编码的类别信息。这样一个label能提供的监督信息只有log(class)这么多bit。然而在KD中,我们可以使用teacher model对于每个样本输出一个连续的label分布,这样可以利用的监督信息就远比one hot的多了。另外一个角度的理解,大家可以想象如果只有label这样的一个目标的话,那么这个模型的目标就是把训练样本中每一类的样本强制映射到同一个点上,这样其实对于训练很有帮助的类内variance和类间distance就损失掉了。然而使用teacher model的输出可以恢复出这方面的信息。具体的举例就像是paper中讲的, 猫和狗的距离比猫和桌子要近,同时如果一个动物确实长得像猫又像狗,那么它是可以给两类都提供监督。综上所述,KD的核心思想在于”打散”原来压缩到了一个点的监督信息,让student模型的输出尽量match teacher模型的输出分布。其实要达到这个目标其实不一定使用teacher model,在数据标注或者采集的时候本身保留的不确定信息也可以帮助模型的训练。

蒸馏

这篇论文很好阅读。论文中实现蒸馏是靠soften softmax prob实现的。在分类任务中,常常使用交叉熵作为损失函数,使用one-hot编码的标注好的类别标签1,2,…,K1,2,…,K作为target,如下所示: \mathcal{L}=-\sum_{i=1}^{K} t_{i} \log p_{i}

Cross entropy loss:

作者指出,粗暴地使用one-hot编码丢失了类间和类内关于相似性的额外信息。举个例子,在手写数字识别时,22和33就长得很像。但是使用上述方法,完全没有考虑到这种相似性。对于已经训练好的模型,当识别数字2时,很有可能它给出的概率是:数字2为0.99,数字3为 10^{-2} ,数字7为 10^{-4} 。如何能够利用训练好的Teacher模型给出的这种信息呢?

可以使用带温度的softmax函数。对于softmax的输入(下文统一称为logit),我们按照下式给出输出: q_{i}=\frac{\exp \left(z_{i} / T\right)}{\sum_{j} \exp \left(z_{j} / T\right)}

其中,当T=1时,就是普通的softmax变换。这里令T>1,就得到了软化的softmax。(这个很好理解,除以一个比1大的数,相当于被squash了,线性的sqush被指数放大,差距就不会这么大了)。OK,有了这个东西,我们将Teacher网络和Student的最后充当分类器的那个全连接层的输出都做这个处理。

对Teacher网络的logit如此处理,得到的就是soft target。相比于one-hot的ground truth或softmax的prob输出,这个软化之后的target能够提供更多的类别间和类内信息。
可以对待训练的Student网络也如此处理,这样就得到了另外一个“交叉熵”损失:

\mathcal{L}_{s o f t}=-\sum_{i=1}^{K} p_{i} \log q_{i}

实现

这里给出一个开源的MXNet的实现:kd loss by mxnet。MXNet中的SoftmaxOutput不仅能直接支持one-hot编码类型的array作为label输入,甚至label的dtype也可以不是整型!

def kd(student_hard_logits, teacher_hard_logits, temperature, weight_lambda, prefix):
    student_soft_logits = student_hard_logits / temperature
    teacher_soft_logits = teacher_hard_logits / temperature
    teacher_soft_labels = mx.symbol.SoftmaxActivation(teacher_soft_logits,
        name="teacher%s_soft_labels" % prefix)
    kd_loss = mx.symbol.SoftmaxOutput(data=student_soft_logits, label=teacher_soft_labels,
                                      grad_scale=weight_lambda, name="%skd_loss" % prefix)
    return kd_loss

好文链接:

编辑于 2020-05-31 11:54