Distilling the Knowledge in a Neural Network
1503.02531v1
摘要
提高几乎任何机器学习算法性能的一个非常简单的方法是在相同的数据上训练许多不同的模型,然后平均它们的预测。不幸的是,使用整个模型集合进行预测是很麻烦的,而且可能计算成本太高,不允许部署到大量用户,特别是当单个模型是大型神经网络时。Caruana和他的合作者已经证明,可以将集成中的知识压缩成一个单一的模型,这是更容易部署的,我们使用不同的压缩技术进一步开发了这种方法。我们在MNIST上取得了一些令人惊讶的结果,并且我们表明,通过将一个模型集成中的知识提取到一个单一的模型中,我们可以显著地改进一个大量使用的商业系统的声学模型。我们还介绍了一种由一个或多个完整模型和许多专业模型组成的新型集成,这些模型学习区分完整模型混淆的细粒度类。与专家的混合不同,这些专家模型可以快速和并行地进行训练。
导言
许多昆虫的幼虫形态可以从环境中提取能量和营养,以及一种完全不同的成虫形态,可以适应非常不同的旅行和繁殖需求。在大规模的机器学习中,我们通常在训练阶段和部署阶段使用非常相似的模型,尽管它们的要求非常不同:对于语音识别和物体识别等任务,训练必须从庞大且冗余度极高的数据集中提取结构,但它不需要实时运行,可以使用大量的计算资源。然而,部署到大量用户时,对延迟和计算资源有更严格的要求。与昆虫的类比表明,如果这样可以更容易地从数据中提取结构,我们应该愿意训练非常麻烦的模型。繁琐的模型可以是单独训练的模型的集合,也可以是使用非常强的正则化器训练的一个非常大的模型,如dropout[9]。**一旦繁琐的模型被训练好,我们就可以使用一种不同的训练,我们称之为“distillation”,将知识从繁琐的模型转移到一个更适合部署的小模型。**Rich卡鲁阿纳和他的合作者[1]已经率先提出了这一策略的一个版本。在他们的重要论文中,他们令人信服地证明了通过大量的模型集合所获得的知识可以转移到一个单一的小模型中。
可能阻碍了对这一极具潜力方法进行更深入研究的一个概念性障碍是,我们往往将训练模型中的知识与学习到的参数值联系在一起,这使得我们很难看到如何改变模型的形式但保持相同的知识。对知识的一个更抽象的观点,使它从任何特定的实例化中解放出来,即它是一种已学习的知识从输入向量映射到输出向量。对于学习区分大量类的繁琐模型,正常的训练目标是最大化正确答案的平均对数概率,但学习的副作用是训练模型分配概率所有的错误的答案,即使这些概率很小,其中一些比其他人大得多。错误答案的相对概率告诉了我们很多关于繁琐的模型是如何泛化的。例如,一个BMW的图片可能被误认为垃圾车的可能性很小,但这个错误仍然比误认为胡萝卜的可能性大很多倍。
人们普遍认为,用于训练的目标函数应该尽可能接近地反映用户的真实目标。尽管如此,当真正的目标是很好地推广到新数据时,模型通常被训练以优化训练数据的性能。显然,训练模型进行泛化良好会更好,但这需要关于正确的泛化方法的信息,而这些信息通常是不可用的。然而,当我们将知识从大模型中提取出来到小模型中时,我们可以训练小模型以与大模型相同的方式进行泛化。如果繁琐的模型有良好的泛化能力,例如,它是大型的集成不同的模型的平均,一个以同样的方式训练成泛化的小模型通常在测试数据上做得会比一个在用于训练集成的相同训练集中以正常方式训练的小模型好得多。
**将繁琐模型的泛化能力转移到小模型的一个明显方法是使用繁琐模型产生的类概率作为训练小模型的“软目标”。**在这个转移阶段,我们可以使用相同的训练集或一个单独的“转移”集。当繁琐的模型是一个简单的模型的大型集合时,我们可以使用它们各自的预测分布的算术或几何平均值作为软目标。当软目标有高熵,与硬目标相比,它们提供了更多的信息,训练案例之间的梯度方差也更小,因此小模型通常可以用比原始繁琐模型更少的数据进行训练,并使用更高的学习率。
对于像MNIST这样的任务,繁琐的模型几乎总是以非常高的可信度产生正确的答案,因此关于学习函数的许多信息存在于软目标中非常小的概率的比率中。例如,一个版本的 2 可能得到$10^{−6}$是3, $10^{−9}$ 的概率是7,而对于另一个版本可能是相反的。这是有价值的信息,定义了数据上丰富的相似性结构(即它说哪个2看起来像3,哪个看起来像7),但在转移阶段对交叉熵代价函数的影响很小,因为概率接近于零。卡鲁阿纳和他的合作者规避这个问题通过使用logits(输入最终softmax)而不是由softmax产生的概率作为学习小模型的目标,它们最小化了由复杂模型产生的logits和由小模型产生的logits之间的平方差。我们更一般的解决方案,称为“蒸馏”,是提高最终softmax的温度,直到繁琐的模型产生一组适当的软目标集。然后,我们使用相同的高温,训练小模型来匹配这些软目标。我们稍后将展示,匹配繁琐模型的logits实际上是蒸馏的一种特殊情况。
用于训练小模型的transfer集可以完全由未标记的数据[1]组成,或者我们也可以使用原始的训练集。我们发现,使用原始训练集效果很好,特别是如果我们在目标函数中添加一个小项,鼓励小模型预测真正的目标,并匹配繁琐的模型提供的软目标。通常情况下,小模型不能完全匹配软目标,并且偏离正确答案的方向是有帮助的。
Distillation
神经网络通常通过使用“softmax”输出层来生成类概率。该输出层将为每个类计算的逻辑值$z_i$与所有其他类的逻辑值进行比较,并将其转换为概率$q_i$。SoftMax计算公式:

其中T是通常设置为1的温度。对T使用更高的值会在类上产生更平滑的概率分布。
最简单的蒸馏方法,知识被转移到被提炼出来的模型中通过在一个transfer集上训练它,并在transfer集中的每个情况下使用soft target分布,是由使用高温的复制模型产生的。在训练蒸馏模型时也使用相同的高温度,但在蒸馏模型训练完成后,使用温度1。
当知道所有或部分transfer集的正确标签时,该方法可以通过训练蒸馏模型产生正确的标签来显著改进。一种方法是使用正确的标签来修改soft targets,**但我们发现更好的方法是简单地使用两个不同目标函数的加权平均值。第一个目标函数是与soft targets的交叉熵,这个交叉熵是用与蒸馏模型的softmax相同的高温来生成软目标来计算的。第二个目标函数是具有正确标签的交叉熵。这是用蒸馏模型的softmax中完全相同的logits来计算的,但在温度为1时。**我们发现,最好的结果通常是通过使用第二个目标函数的相对低的权重。由于软目标产生的梯度的大小为$1/T^2$,所以当同时使用硬目标和软目标时,将它们乘以$T^2$是很重要的。这确保了在实验元参数时,如果改变用于蒸馏的温度,则硬目标和软目标的相对贡献大致保持不变.
本质
先用复杂模型算出softmax,用蒸馏模型拟合复杂模型的softmax结果。
Last updated