非平衡数据集 focal loss 多类分类

2019 年 4 月 23 日 AI研习社

本文为 AI 研习社编译的技术博客,原标题 :

Multi-class classification with focal loss for imbalanced datasets

作者 | Chengwei Zhang

翻译 | 汪鹏       校对 | 斯蒂芬·二狗子

审核 | Pita       整理 | 立鱼王

原文链接:

https://medium.com/swlh/multi-class-classification-with-focal-loss-for-imbalanced-datasets-c478700e65f5

注:本文的相关链接请访问文末【阅读原文】


焦点损失函数 Focal Loss(2017年何凯明大佬的论文)被提出用于密集物体检测任务。它可以训练高精度的密集物体探测器,哪怕前景和背景之间比例为1:1000(译者注:facal loss 就是为了解决目标检测中类别样本比例严重失衡的问题)。本教程将向您展示如何在给定的高度不平衡的数据集的情况下,应用焦点损失函数来训练一个多分类模型。


   背景

让我们首先了解类别不平衡数据集的一般的处理方法,然后再学习 focal loss 的解决方式。

在多分类问题中,类别平衡的数据集的目标标签是均匀分布的。若某类目标的样本相比其他类在数量上占据极大优势,则可以将该数据集视为不平衡的数据集。这种不平衡将导致两个问题:

  • 训练效率低下,因为大多数样本都是简单的目标,这些样本在训练中提供给模型不太有用的信息;

  • 简单的样本数量上的极大优势会搞垮训练,使模型性能退化。

一种常见的解决方案是执行某种形式的困难样本挖掘,实现方式就是在训练时选取困难样本 或 使用更复杂的采样,以及重新对样本加权等方案。

对具体图像分类问题,对数据增强技术方案变更,以便为样本不足的类创建增强的数据。

焦点损失函数旨在通过降低内部加权(简单样本)来解决类别不平衡问题,这样即使简单样本的数量很大,但它们对总损失的贡献却很小。也就是说,该函数侧重于用困难样本稀疏的数据集来训练。


   将 Focal Loss 应用于欺诈检测任务

为了演示,我们将会使用 Kaggle上的欺诈检测数据集 构建一个分类器,这个数据及具有极端的类不平衡问题,它包含总共6354407个正常样本和8213个欺诈案例,两者比例约为733:1。对这种高度不平衡的数据集的分类问题,若某模型简单猜测所有输入样本为“正常”就可以达到733 /(733 + 1)= 99.86%的准确度,这显然是不合理。因此,我们需要的是这个模型能够正确检测出欺诈案例。

为了证明focal loss 比传统技术更有效,让我们建立一个简单地使用类别权重 class_weight训练的基准模型,告诉模型“更多地关注”来自代表性不足的欺诈样本。


基准模型

基准模型的准确率达到了99.87%,略好于通过采取“简单路线”去猜测所有情况都为“正常”。

我们还绘制了混淆矩阵来展示模型在测试集上的分类性能。你可以看到总共有1140 + 480 = 1620 个样本被错误分类。


混淆矩阵-基准模型

现在让我们将focal loss应用于这个模型的训练。你可以在下面看到如何在Keras框架下自定义焦点损失函数focal loss 。


焦点损失函数-模型

焦点损失函数focal loss 有两个可调的参数。

  • 焦点参数γ(gamma)平滑地调整简单样本被加权的速率。当γ= 0时, focal loss 效果与交叉熵函数相同,并且随着 γ 增加,调制因子的影响同样增加(γ = 2在实验中表现的效果最好)。

  • α(alpha):平衡focal loss ,相对于非 α 平衡形式可以略微提高它的准确度。

现在让我们把训练好的模型与之前的模型进行比较性能。

Focal Loss 模型:

  • 精确度:99.94%

  • 总错误分类测试集样本:766 + 23 = 789,将错误数减少了一半。


混淆矩阵-focal loss模型


  结论及导读

在这个快速教程中,我们为你的知识库引入了一个新的工具来处理高度不平衡的数据集 — Focal Loss。并通过一个具体的例子展示了如何在Keras 的 API 中定义 focal loss进而改善你的分类模型。

你可以在我的GitHub上找到这篇文章的完整源代码。

有关focal loss的详细情况,可去查阅论文https://arxiv.org/abs/1708.02002。

最初发表于www.dlology.com.

想要继续查看该篇文章相关链接和参考文献?

点击底部【阅读原文】即可访问:

https://ai.yanxishe.com/page/TextTranslation/1646

AI求职百题斩 · 每日一题


每天进步一点点,长按扫码参与每日一题!




今日话题讨论

点击阅读原文,查看本文更多内容

登录查看更多
33

相关内容

RetinaNet是2018年Facebook AI团队在目标检测领域新的贡献。它的重要作者名单中Ross Girshick与Kaiming He赫然在列。来自Microsoft的Sun Jian团队与现在Facebook的Ross/Kaiming团队在当前视觉目标分类、检测领域有着北乔峰、南慕容一般的独特地位。这两个实验室的文章多是行业里前进方向的提示牌。 RetinaNet只是原来FPN网络与FCN网络的组合应用,因此在目标网络检测框架上它并无特别亮眼创新。文章中最大的创新来自于Focal loss的提出及在单阶段目标检测网络RetinaNet(实质为Resnet + FPN + FCN)的成功应用。Focal loss是一种改进了的交叉熵(cross-entropy, CE)loss,它通过在原有的CE loss上乘了个使易检测目标对模型训练贡献削弱的指数式,从而使得Focal loss成功地解决了在目标检测时,正负样本区域极不平衡而目标检测loss易被大批量负样本所左右的问题。此问题是单阶段目标检测框架(如SSD/Yolo系列)与双阶段目标检测框架(如Faster-RCNN/R-FCN等)accuracy gap的最大原因。在Focal loss提出之前,已有的目标检测网络都是通过像Boot strapping/Hard example mining等方法来解决此问题的。作者通过后续实验成功表明Focal loss可在单阶段目标检测网络中成功使用,并最终能以更快的速率实现与双阶段目标检测网络近似或更优的效果。
【CVPR2020-Oral】用于深度网络的任务感知超参数
专知会员服务
25+阅读 · 2020年5月25日
【CVPR2020】MSG-GAN:用于稳定图像合成的多尺度梯度GAN
专知会员服务
26+阅读 · 2020年4月6日
推荐 :如何改善你的训练数据集?(附案例)
数据分析
3+阅读 · 2019年6月19日
一文教你如何处理不平衡数据集(附代码)
大数据文摘
10+阅读 · 2019年6月2日
被忽略的Focal Loss变种
极市平台
29+阅读 · 2019年4月19日
Focal Loss for Dense Object Detection
统计学习与视觉计算组
11+阅读 · 2018年3月15日
【干货】机器学习中样本比例不平衡的处理方法
机器学习研究会
8+阅读 · 2018年1月14日
论文 | 用于密集对象检测的 Focal Loss 函数
七月在线实验室
9+阅读 · 2018年1月4日
何恺明大神的「Focal Loss」,如何更好地理解?
PaperWeekly
10+阅读 · 2017年12月28日
Arxiv
8+阅读 · 2018年11月27日
Arxiv
4+阅读 · 2018年10月4日
Arxiv
8+阅读 · 2018年4月12日
Arxiv
6+阅读 · 2018年3月19日
VIP会员
相关VIP内容
【CVPR2020-Oral】用于深度网络的任务感知超参数
专知会员服务
25+阅读 · 2020年5月25日
【CVPR2020】MSG-GAN:用于稳定图像合成的多尺度梯度GAN
专知会员服务
26+阅读 · 2020年4月6日
相关资讯
推荐 :如何改善你的训练数据集?(附案例)
数据分析
3+阅读 · 2019年6月19日
一文教你如何处理不平衡数据集(附代码)
大数据文摘
10+阅读 · 2019年6月2日
被忽略的Focal Loss变种
极市平台
29+阅读 · 2019年4月19日
Focal Loss for Dense Object Detection
统计学习与视觉计算组
11+阅读 · 2018年3月15日
【干货】机器学习中样本比例不平衡的处理方法
机器学习研究会
8+阅读 · 2018年1月14日
论文 | 用于密集对象检测的 Focal Loss 函数
七月在线实验室
9+阅读 · 2018年1月4日
何恺明大神的「Focal Loss」,如何更好地理解?
PaperWeekly
10+阅读 · 2017年12月28日
Top
微信扫码咨询专知VIP会员