原文及参考
原文地址:直接点我就能跳转
初识
作者认为,当前主流的半监督学习方法(semi-supervised learning, SSL)都可以归结于:在无label数据上构造人为标签。
比如伪标签(也叫self-training
),使用模型的预测作为label再次进行训练;一致性约束,对输入和模型进行随机扰动(增广 / dropout)后的输出结果作为人为标签(希望两者趋于一致)。
关于半监督学习的了解,我觉得这篇博文总结得非常好:https://www.jarvis73.com/2021/08/07/Semi-Supervised-Learning/
也出现了一些工作融合了两种思想一起进行SSL,但都比较复杂【并且是越来越复杂,比如Mixmatch和RemixMatch】。而本文就是打破这种趋势,提出了FixMatch,一种更简单并且更精确的SSL方法。其能在CIFAR-10上只用250个标注样本达到94.93%,甚至比全监督学习下还要高,哪怕每个类只用四个样本,也能训练得到不错的性能。
相知
主要算法
FixMatch的核心步骤如上图所示,对unlabled的输入图像分别进行弱增广Weakly-augmented
和强增广Strongly-augmented
,送入模型中得到两个logits,对于弱增广会根据置信度阈值构造伪标签,并将此伪标签与强增广样本计算交叉熵loss。整个过程包含了伪标签和一致性约束两种思想,对弱增广构造伪标签去与强增广输出进行一致性约束,下面介绍更多细节。
构造伪标签时,会选择一个阈值τ,只有最大置信度高于这个阈值才构造伪标签
存在有标签数据
X
X
X和无标签数据
U
U
U,所以构造损失函数包含两部分:① 监督损失
L
s
L_s
Ls,只对数据
X
X
X计算,采用最常用的交叉熵损失;② 无监督损失
L
u
L_u
Lu,对数据
U
U
U进行计算,如上文说的那样,进行两次增广(弱增广+强增广),用弱增广构造的伪标签去和强增广计算交叉熵。
最后两个loss会用一个超参数
λ
u
\lambda_u
λu进行加权,值得注意的是,这个超参数全程都是固定不变的,这与之前一致性约束的方法不一样。之前的方法会用一些ramp-up/down的策略去逐渐增加
L
u
L_u
Lu,但作者发现这在FixMatch中是不必要的,因为通过另外一个阈值τ完成了这个任务,在训练初期阶段,置信度高于τ的样本会比较少,随着训练的进行,样本数也会越来越多。
强弱增广:在实验中,对于弱增广采用了"翻转和平移";对于强增广作者采用了RandAugment和CTAugment,这两者都是一种AutoAugment的做法,具体实验原理博主并不太清楚,主要就是随机采样策略,从数十种增强类别和一系列增强幅度中选择一部分用于图像增强,具体可以参照论文附录E。总之,其核心就是让强增广后的图像变得难以确认,但仍保留足够的语义信息。
下面展示了某个库的实现代码,我们发现就是在AugmentList中随机选取固定数量的增广策略,最后再加上Cutout
顺便再插播一个题外话,就是很多非官方库按照论文进行复现,效果比论文高了5~7个点,而且不是偶然现象hhhh(主要是在4 label实验上)
一些细节:作者也提到对于SSL来说,训练策略特别重要(优化器、正则化项等)。因此在论文中也讨论了具体的实验设置,最后反向SGDM效果最好,使用cosine学习率衰减策略,并且使用了EMA来报告结果(详情见原论文)。
部分实验:
上表展示了FixMatch在不同数据集上的实验结果,效果还是非常不错的。特别是在每个类别只提供少量标签时(xx labels表示这个数据集一共只提供xx个带标签的数据),比其他方法表现得更优秀。
并且在CIFAR-10上实验,每个类别只提供一个labeled data,也能达到78%的准确率。
回顾
FixMatch同时引入了伪标签和一致性的思想,其中伪标签
操作指的是:选取阈值大于τ的弱增广图像logits作为软标签;一致性
指的是:使用弱增广的软标签与强增广输出计算交叉熵损失。FixMatch思路简单,效果也很好,属于一个里程碑式的工作,在其他领域也有很多利用了这种策略进行半监督学习的工作(目标检测、关键点估计)。
pytorch实现的代码:https://github.com/kekmodel/FixMatch-pytorch