论文笔记|类别不平衡半监督学习中的平滑自适应权重:提高对未知分布数据的可信性

论文标题:Smoothed Adaptive Weighting for Imbalanced Semi-Supervised Learning: Improve Reliability Against Unknown Distribution Data

论文链接:https://proceedings.mlr.press/v162/lai22b/lai22b.pdf

代码链接:https://github.com/ZJUJeffLai/SAW_SSL

论文来源:ICML2022

作者单位:University of California Davis; Southern University of Science and Technology; University of Kentucky

摘要

  在目前类不平衡的半监督学习研究中,往往假设无标注数据的类别分布与有标注数据的类别分布一致。然而在现实中无标注数据的类别分布往往是未知的。本文提出了一个针对一致性损失的自适应权重调整算法,命名为Smoothed Adaptive Weighting (SAW),来增强半监督学习框架的鲁棒性。

模型框架

先验知识

  半监督学习的范式一般是通过最小化有标注数据和无标注数据的联合损失:
\min _{\theta \in \Theta} \sum_{i=1}^{m} \mathcal{L}_{l}\left(x_{i}^{(1)}, y_{i}^{(1)} ; \theta\right)+\sum_{j=1}^{n} \Omega\left(x_{j}^{(\mathrm{u})} ; \theta\right)
其中\Omega是每个无标注样本的正则化项,通常指的是一致性损失(consistency loss):
\mathcal{L}_{c}(x ; \theta):=\sum_{k=1}^{C} p(x ; \theta)_{k} \cdot \log \left(h(\operatorname{pt}(x) ; \theta)_{k}\right)
其中\operatorname{pt}(x)代表不同的增强方式(对抗扰动、fixmatch中的强增强等等)。在本文,我们仅考虑\Omega=\mathcal{L}_{c}的情况。

深入挖掘一致性损失

  在本节,我们定义如下加权的一致性损失(无标注样本对扰动所输出的结果一致):
\mathcal{L}_{c w}(x ; w, \theta):=\sum_{k=1}^{C} w_{k} \cdot p(x ; \theta)_{k} \cdot \log \left(h(\operatorname{pt}(x) ; \theta)_{k}\right)
其中w_{k}是第k个类的权重。在完全监督学习中,此权重和每个类别的样本数量相关。然而在半监督学习中,无标注样本的类别未知。本文研究首先假设有标注数据和无标注数据类别分布一致,即\gamma_{l}=\gamma_{u},然后再研究有标注数据和无标注数据类别分布不一致的情况。

  首先是类别分布一致的情况,本文根据有标注数据类别的数量,设置成相反比例的权重:
w_{k} \propto 1 / E_{k}, \text { where } E_{k}=\left(1-\beta^{n_{k}}\right) /(1-\beta)
其中\beta是超参数,\beta=0代表相同的权重,\beta=1E_{k} \rightarrow n_{k}代表完全和类别数量成反比。如下图可见,这两种情况均产生了不理想的效果,然而平滑的(smoothed)权重却产生了较好效果,这个结果暗示了在一致性损失中添加更加平滑的权重的重要性。

  本文还采用了另一种平滑的函数:
w_{k} \propto \frac{n}{n+\nu \cdot n_{k}}, \text { for } k=1, \ldots, C
其中\nu是控制平滑程度的超参数,下图展示了其不同取值的表现:

针对未知分布数据的平滑自适应权重

  在现实世界中,有标注数据和无标注数据类别分布往往不一致,无标注数据的类别分布不可知。本节提出根据各类别学习难度来生成一致性损失中的权重,这种方法不依赖对无标注数据的任何假设。
  在Flexmatch中,作者提出对某一类的学习效果可以由预测为这个类别的样本数量反映出来。本文继承了这个观点,认为伪标签的数量能够反映出各类别在训练中的学习情况。本文将少数类视为难学习类,利用某类伪标签的数量来衡量模型对此类学习的难易度。更少的伪标签意味着更难学习的类。
  本文希望根据估计的各类学习难度减少伪标签的偏置误差。通过自适应地调整一致性损失的权重w^{(\mathrm{u})},鼓励模型关注于难学类别。对于总损失:
\mathcal{L}\left(\theta, w^{(\mathbf{u})}\right)=\sum_{i=1}^{m} \mathcal{L}_{l w}\left(x_{i}^{(l)}, y_{i}^{(l)} ; w^{(l)}, \theta\right)+\sum_{j=1}^{n} \mathcal{L}_{c w}\left(x_{j}^{(\mathrm{u})} ; w^{(\mathrm{u})}, \theta\right)
其中w^{(l)}由有标注数据的类别分布确定,w^{(\mathrm{u})}初始化为uniform。在第一个epoch训练完成后,利用更新后的模型生成one-hot的伪标签p(x, \theta),计算所有伪标签的类别分布n_{k}=\sum_{j=1}^{n} p\left(x_{j}, \theta\right)_{k}。随后利用前文介绍的平滑性函数(1-\beta) /\left(1-\beta^{n_{k}}\right)确定各类的权重项。为了避免n_{k}=0的情况,本文调整n_{k}\hat{n}_{k}=\max \left(n_{k}, 1\right)。整体算法框架如下图:

实验

CIFAR10-LT(有标注数据和无标注数据类别分布一致):


CIFAR10-LT(有标注数据和无标注数据类别分布不一致):


CIFAR100-LT和STL-10(有标注数据和无标注数据类别分布一致):


CIFAR10-LT(有标注数据和无标注数据类别分布一致),测试集是reversed的类别分布:


©著作权归作者所有,转载或内容合作请联系作者
【社区内容提示】社区部分内容疑似由AI辅助生成,浏览时请结合常识与多方信息审慎甄别。
平台声明:文章内容(如有图片或视频亦包括在内)由作者上传并发布,文章内容仅代表作者本人观点,简书系信息发布平台,仅提供信息存储服务。

相关阅读更多精彩内容

友情链接更多精彩内容