原文:
www.kdnuggets.com/2018/12/handling-imbalanced-datasets-deep-learning.html
像灭霸一样为你的数据集带来平衡
并不是所有数据都是完美的。实际上,如果你能获得一个完美平衡的现实世界数据集,你将非常幸运。大多数情况下,你的数据会有某种程度的类别不平衡,即每个类别的示例数量不同。
在投入时间进行任何可能漫长的深度学习项目之前,理解为什么我们应该这样做是很重要的,以确保这是一个有价值的投资。类别平衡技术只有在我们真正关心少数类时才是必要的。
比如,假设我们要预测是否应该根据市场的当前状况、房子的属性和我们的预算来购买一栋房子。在这种情况下,如果我们决定买房,那么这个决定是否正确是非常重要的,因为这是一个巨大的投资。同时,如果我们的模型说不买而实际上应该买,这也不是大问题。即使错过了某一栋房子,仍然会有其他房子可供选择,但对如此大的资产做出错误投资会很糟糕。
在这个例子中,我们绝对需要我们的少数“买”类别非常准确,而对于“不要买”类别则不是那么重要。然而在实际情况中,由于在数据中购买会比不购买少得多,我们的模型会倾向于非常好地学习“不要买”类别,因为它有最多的数据,并可能在“买”类别上表现不佳。这就需要平衡数据,以便我们可以更重视“买”预测的正确性!
那么如果我们不太关心少数类呢?例如,假设我们正在做图像分类,并且你的类别分布如下:
初看起来,平衡数据似乎有帮助。但也许我们对那些少数类不太感兴趣。也许我们的主要目标是获得最高可能的百分比准确度。在这种情况下,进行任何平衡并没有太大意义,因为我们的百分比准确度大部分来自于训练示例更多的类别。其次,即使数据集不平衡,分类交叉熵损失在追求最高百分比准确度时也往往表现良好。总的来说,我们的少数类对实现主要目标贡献不大,因此平衡并不是必要的。
说到这里,当我们遇到需要平衡数据的情况时,我们可以使用两种技术来帮助我们。
权重平衡通过调整每个训练样本在计算损失时所承载的权重来平衡我们的数据。通常,我们的损失函数中的每个样本和类别将承担相等的权重,即 1.0。但是有时我们可能希望某些类别或某些训练样本承载更多的权重,如果它们更为重要。再次以我们买房的例子为例,由于“买入”类别的准确性对我们最为重要,因此该类别中的训练样本应该对损失函数有显著影响。
我们可以通过根据类别的不同,将每个样本的损失乘以某个因子来给类别赋权。在 Keras 中,我们可以这样做:
我们创建了一个字典,其中基本上规定了我们的“买入”类别应在损失函数中占据 75%的权重,因为“买入”类别比“不要买”类别更为重要,而“不要买”类别的权重则相应设置为 25%。当然,这些值可以很容易地调整,以找到最适合您应用的设置。如果我们的某个类别的样本显著多于其他类别,我们也可以使用这种权重平衡的方法。与其花费时间和资源去收集更多的少数类样本,不如尝试使用权重平衡,使所有类别对我们的损失函数的贡献相等。
另一种平衡训练样本权重的方法是焦点损失。其主要思想是:在我们的数据集中,某些训练样本自然比其他样本更容易分类。在训练过程中,这些样本将以 99%的准确率被分类,而其他更具挑战性的样本可能仍表现较差。问题在于,那些容易分类的训练样本仍然在贡献损失。为什么我们还要对它们给予相同的权重,而在其他更具挑战性的数据点上,如果正确分类,能对我们的整体准确率贡献更多?!
这正是焦点损失可以解决的问题!焦点损失降低了对已分类样本的权重,而不是对所有训练样本给予相等的权重。这有助于将更多的训练重点放在那些难以分类的数据上!在数据不平衡的实际设置中,由于我们拥有更多的数据,我们的多数类很快就会被良好分类。因此,为了确保我们在少数类上也能实现高准确率,我们可以使用焦点损失在训练过程中给予这些少数类样本更多的相对权重。焦点损失可以很容易地在 Keras 中作为自定义损失函数实现:
选择合适的类别权重有时可能会很复杂。简单的逆频率方法可能效果不好。焦点损失可以提供帮助,但即便如此,它也会平等地减少每个类别的所有良好分类示例的权重。因此,平衡数据的另一种方式是通过直接采样。请查看下面的图示。
欠采样和过采样
在上图的左侧和右侧,我们的蓝色类别样本远多于橙色类别。在这种情况下,我们有两个预处理选项可以帮助训练我们的机器学习模型。
欠采样意味着我们只会从多数类中选择一部分数据,只使用与少数类相同数量的样本。这种选择应保持类别的概率分布。这样很简单!我们通过减少样本数量平衡了数据集!
过采样意味着我们会创建副本以使少数类的样本数量与多数类相同。这些副本的生成会保持少数类的分布。我们没有获取更多数据,但依然平衡了数据集!如果发现类别权重难以有效设置,采样可以是平衡类别的一个好替代方案。
关注我的推特,我会发布最新最前沿的 AI、技术和科学内容!
简介: George Seif 是一名认证极客及 AI/机器学习工程师。
原文. 经许可转载。
相关内容:
-
数据科学家必知的 5 种聚类算法
-
三种提高不平衡数据集上机器学习模型性能的技术
-
使用 Python 提升数据预处理速度 2-6 倍
1. 谷歌网络安全证书 - 快速进入网络安全职业。
2. 谷歌数据分析专业证书 - 提升您的数据分析技能
3. 谷歌 IT 支持专业证书 - 支持您的组织 IT 工作