https://arxiv.org/abs/2008.05258 https://github.com/ZHKKKe/PixelSSL
https://paperswithcode.com/paper/guided-collaborative-training-for-pixel-wise-1
https://www.ecva.net/papers/eccv_2020/papers_ECCV/papers/123580426.pdf
Guided Collaborative Training for Pixel-wise Semi-Supervised Learning
MODNet作者的另一篇工作。MODNet发表于2020.11,这篇工作发表于2020.8
摘要:我们研究了半监督学习(SSL)对不同像素级任务的推广。虽然SSL方法在图像分类中取得了令人印象深刻的效果,但由于其对密集输出的要求,将其应用于像素级任务的性能并不理想。此外,现有的像素级SSL方法仅适用于某些任务,因为它们通常需要使用特定于任务的属性。在本文中,我们提出了一个新的SSL框架,命名为引导协作训练(GCT),用于像素级任务,主要有两个技术贡献。首先,GCT通过一种新型的缺陷检测器来解决密集输出所引起的问题。第二,GCT中的模块通过两个新提出的独立于任务特定属性的约束从未标记数据中协作学习。因此,GCT可以应用于广泛的像素级任务,而无需结构调整。我们在四个具有挑战性的视觉任务上进行了大量的实验,包括语义分割、真实图像去噪、人像抠图和夜间图像增强,实验结果表明GCT的性能大大优于现有的SSL方法。
1. 引言
深度学习在许多视觉任务中取得了显著的成功。然而,为了训练而收集大量的标记数据是昂贵的,特别是对于需要对每个像素进行精确标记的像素级任务,例如语义分割中的类别掩码和图像去噪中的干净图片。近年来,半监督学习(SSL)通过添加未标记数据进行训练,成为缓解标签缺失的一个重要研究方向。许多SSL方法已经被提出用于图像分类,并取得了令人印象深刻的结果,包括基于对抗的方法[11,25,39,43],基于一致性的方法[21,23,36,41],以及与自监督学习相结合的方法[42,46]。相比之下,只有很少的工作将SSL应用于特定的像素级任务[7,19,20,30],他们主要关注语义分割。
在这项工作中,我们研究了SSL对不同像素级任务的泛化。这样的泛化对于SSL在新的视觉任务中使用非常重要。然而,推广现有的像素级SSL方法并不简单,因为它们是通过使用特定于任务的属性为特定任务设计的(2.2节),例如,假设输入和输出之间的语义内容相似。另一种可能的泛化方法是将用于图像分类的SSL方法应用于像素级任务。但是,如图1所示,密集输出导致了两个关键问题,导致这些方法在像素级任务上的性能不令人满意。
首先,密集输出需要的像素级预测置信度(2.3节)是难以估计的。像素级任务可以是像素级分类(如语义分割和阴影检测)或像素级回归(如图像去噪和抠图)。虽然在像素级分类中可以用最大分类概率来表示预测置信度,但在像素级回归中却不可用。第二,为SSL设计的现有扰动(2.4节)不适合密集输出。在像素级的任务中,输入中的强扰动(例如,剪切Mean Teacher[41])既会改变输入图像,也会改变真值标签。因此,来自同一原始图像的不同扰动,就会导致不同的标签,这在SSL中是不好的。此外,通过Dropout[40]的扰动在大多数像素级任务是不起作用的。尽管Dual Student[21]提出通过不同的模型初始化来产生扰动,其训练策略只能用于图像分类。
为了解决上述两个由密集输出引起的问题,我们提出了一个新的SSL框架,命名为Guided Collaborative Training(GCT),用于像素级任务。它包括三个模块:特定任务的两个模型(任务模型)和一个新的缺陷检测器。GCT克服了这两个问题:(1)通过缺陷检测器的输出(即缺陷概率图)近似像素级预测置信度;(2)将双学生任务中使用的扰动扩展到像素级任务。由于不同的模型初始化会导致对同一输入的预测不一致,因此我们可以在预测中集成可靠的像素(ensemble the reliable pixels),即缺陷概率较低的像素。此外,对缺陷概率图的最小化应有助于纠正预测中不可靠的像素。基于这些思想,我们引入了两个SSL约束,一个是任务模型之间的动态一致性约束,一个是缺陷检测器和每个任务模型之间的缺陷修正约束,这样就可以使GCT中的模块在缺陷概率图的指导下,而不是在任务特定属性的指导下,协同学习未标记的数据。因此,GCT可以应用于不同的像素级任务,只需替换任务模型而无需进行结构调整。
我们在语义分割(像素分类)和真实图像去噪(像素回归)的标准基准上评估GCT。我们也在我们自己的实际数据集上进行了实验,即含有大量未标记数据的数据集,用于人像抠图和夜间图像增强(都是逐像素回归),以证明GCT在实际应用中的普遍性。GCT超越了最新的SSL方法[19,41,46],可以应用于这四个具有挑战性的像素级任务。我们设想这项工作将有助于未来研究和开发具有稀缺标签的新视觉任务。
2相关工作
2.1用于图像分类的SSL
我们的工作与图像分类的SSL方法的两个主要分支相关。基于对抗的方法[11,25,39,43]从GAN[14]集成鉴别器,并尝试通过图像级对抗约束来匹配标记和未标记数据之间的潜在分布。基于一致性的方法[21,23,36,41]通过对不同扰动下的预测应用一致性约束,从未标记数据中学习。除此之外,一些最新的工作将自监督学习与SSL结合起来[42,46],或者通过插值有标记和无标记的数据来扩展训练集[5,6]。
2.2像素级任务的SSL
现有的像素级SSL研究主要集中在语义分割方面。GANs通过与图像分类中的SSL方法相结合,在该研究领域中占据了主导地位。例如,Hunget等人[19]提取可靠的预测来生成用于训练的伪标签。Mittalet等人[30]将Mean-Teacher[41]修改为多标签分类器,并将其用作过滤器,以去除不确定类别。此外,Lee等人[18]和Huang等人[24]研究SSL环境下的弱监督学习。然而,这些工作需要预定义的类别,这是基于分类的任务的一般属性。Chen等人[7]将SSL应用于人脸草图合成,属于像素回归。它将预先训练好的VGG[38]网络作为特征抽取器,对未标记的数据施加感知约束。遗憾的是,感知约束只能用于在输入和输出之间具有相似语义内容的任务中。例如,由于类别掩码的语义内容与输入图像不同,因此它不适用于分割。
2.3 SSL预测置信度
预测置信度是计算SSL约束所必需的,SSL约束将置信度较高的预测作为目标,即伪标签。早期的研究表明,对目标取平均有更高的置信度。例如,时态模型[23]将各个时期的预测作为目标进行累加;均值教师[41]通过指数移动平均法定义一个显式模型来生成目标;FastSWA[4]进一步对各个时期之间的模型进行平均,以生成更好的目标。其他文献[25,28,39]将最大分类概率作为预测置信度。
在像素级SSL中,鉴别器的输出用于近似预测置信度[19,30]。相反,我们建议使用缺陷检测器来估计预测置信度,有两个关键区别。首先,缺陷检测器预测具有位置信息的密集概率图,而鉴别器预测图像级概率。第二,利用有标签数据的真值生成缺陷检测器的目标。
2.4 SSL中的扰动
许多SSL方法严重依赖于扰动进行训练。基于一致性的方法[23,35,41]利用数据扩充来改变输入。为了进一步改善不一致性,VAT[31]生成虚拟对抗性噪声,而S4L[46]向输入添加旋转操作。其他如Mix-Match[6]和ReMixMatch[5]通过数据插值生成扰动样本。除了输入中的扰动外,Dropout还通过随机选择节点来扰动预测[34]。Dual Student[21]中的模型由于不同的初始化,对相同的输入有不一致的预测。
由于数据增加和数据丢失带来的扰动不适用于密集输出,GCT在产生扰动时遵循Dual Student。然而,与Dual Student不同,GCT通过基于缺陷检测器的两个SSL约束从未标记的数据中学习,从而允许GCT适用于不同的像素级任务。
3 有引导的协作训练
3.1 GCT概述
在本节中,我们首先对GCT进行概述。然后我们介绍了缺陷检测器和两个SSL约束。图2所述是GCT概览。和是两个任务模型,在下文中记为。的结构是任意的,GCT允许任务模型具有不同的结构。唯一的要求是和应该有不同的初始化,以形成它们之间的扰动(这与Dual Student相同)。表示缺陷检测器。半监督学习中,有标签数据记为,其标签记为;无标签数据记为。和的输入是完全一样的。给定输入,GCT框架首先预测,其形状为,其中根据具体任务而定。然后,将和拼接起来,输入给,以估计,其形状为。预测置信度图可以近似为。我们分两步迭代地训练GCT,就像GAN一样[14]。
第一步,我们用固定的来训练。对于有标签数据,预测值被其真值标签所监督:
其中是具体任务的约束,是像素索引。为了学习未标记数据,我们提出了动态一致性约束和缺陷修正约束,由缺陷概率图引导,分别在3.3节和3.4节介绍。的最终约束是3项约束的组合:
其中是有标签数据及其对应的标签。和是平衡两项的超参数。
第二步,从有标签数据学习。我们通过一个基于和的经典图像处理流程来计算的真值。在我们的框架中,是以均方误差MSE来训练的:
其中是的真值,会在3.2节介绍。
3.2 缺陷检测器
在有标签的数据上,缺陷检测器的目标是学习表示和之间差异的缺陷概率图,即中的缺陷区域。找到缺陷区域的一个简单方法是。然而,在许多任务上学习这个目标是困难的,因为它是稀疏和尖锐的(图3的第4列)。为了解决这个问题,我们引入了一个图像处理流程,它将转换为稠密概率图(图3的第5列)。由三个基本的图像处理操作组成:膨胀、模糊和归一化。为了估计未标记数据的缺陷概率图,我们应用了一个常见的SSL假设[48]:无标签数据的分布与有标签数据的分布相同。因此,在有标签数据上训练的也应该在无标签数据上很好地工作。
缺陷检测器的结构类似于[19]中的全卷积鉴别器。然而,由于的目标是图片级的表示真假的概率,因此在训练过程中对所有预测的像素进行平均以获得单个置信值。在像素级任务中,有的像素的预测是准确的,有的像素的预测则不准确。准确度高的像素则一般会有高的置信度,对置信度取平均以获得总体置信度是不合适的。例如,在小的局部区域中,可能比的置信度更高(更准确),尽管的平均预测置信度低于。因此,在逐像素任务中,每个像素预测置信度(来自缺陷检测器)比平均预测置信度(来自鉴别器)更有意义。图3显示了四个验证任务中和的结果。
3.3 动态一致性约束
GCT中的两个任务模型由于它们之间的扰动,对相同的输入有不一致的预测。我们使用动态一致性约束来集成和中的可靠像素。通常,标准一致性约束[23,41]是单向的,例如,从集成模型到临时模型。这里,“动态”表示我们的是双向的,其方向随缺陷概率而变化(图4a)。直观地说,如果中的一个像素具有更低的缺陷概率,我们将其视为中对应像素的伪标签。为了保证伪标签的质量,我们引入了一个缺陷阈值,对和中缺陷概率值都高于的像素禁用。通过这个过程,任务模型之间可以进行有效的知识交换,使它们成为协作者。
数学形式上,给定样本,GCT输出、,其通过前向传播对应的缺陷概率图为和。我们首先将的值归一化到[0,1],然后将像素值大于的都置为1:
是布尔值转整数的函数,条件为真时输出为1,否则为0。动态一致性约束定义为:
表示另一个任务模型,例如,如果那么。如果中的缺陷概率值小于和中相应的像素,将通过从学习该像素。我们使用MSE,因为它在SSL中被广泛使用,并且对于许多任务都通用。为了防止在训练开始时不可靠的知识交换,我们对应用了一个具有个epoch的余弦爬升操作(来自标准一致性约束)。
3.4 缺陷修正约束
除了,缺陷修正约束试图修正任务模型的不可靠预测(图4b)。的核心思想是迫使缺陷概率图中的值变为零。我们将的(同时是固定的)定义为:
我们使用二值掩膜在没有的像素(即同时在两个任务模型中预测不可靠的像素)上启用:
我们认为缺陷检测器有助于通过改进任务模型。对于仅包含一个任务模型和缺陷检测器的系统,目标可从式(3)和(6)中得出:
其中,和具有相同的分布。通过去掉像素求和运算,我们简化了公式(8)。在这种情况下,学习缺陷概率图,而使用零标签对其进行优化。如果我们假设训练过程在迭代中收敛到最优解,则有:
其中是当前迭代。因此,目标的值在训练过程中发生变化,并且在时等于。目标的一致性表明和在某种程度上是协作的。
为了说明和对抗性约束之间的区别,我们将公式(8)与LSGAN的目标进行了比较[29]。如果我们为SSL修改LSGAN,其目标应该是:
其中是试图区分和的标准鉴别器。相反,试图匹配和之间的分布。在这里,我们反转标签,即1表示假,0表示真,以与公式(8)一致。由于D的目标是常数,我们有:
也就是说,和在整个训练过程中是对立的。
4 实验
为了在不同比例的有标签数据下评估我们的框架,我们在语义分割和真实图像去噪的标准基准上进行了实验。同时,我们也在真实人像抠图和夜间图像增强数据集上进行了实验,证明了GCT在实际应用中的推广性。我们进一步进行烧蚀实验来分析GCT的各个方面。
实施细节。 我们将GCT与仅使用标记数据(SupOnly)训练的模型以及几种可应用于各种像素级任务的最新SSL方法进行了比较:(1)在[19]中提出的基于对抗的方法(AdvSSL);(2)基于一致性的平均教师(MT)[41];(3)自监督SSL(S4L)[46]。对于AdvSSL,我们去掉了要求分类概率的约束,使其与像素回归兼容。对于MT,我们使用MSE作为一致性约束。我们不添加高斯噪声作为额外的扰动,因为它会降低性能。对于S4L,在任务模型的末尾加入一个交叉熵训练的四类分类器来预测旋转角度。(0°,90°,180°,270°)
实验设置。 我们注意到,由于超参数的不一致性,现有的像素级SSL工作通常报告一个完全监督的基线,其性能低于原始文件。在图像分类中,文献[32]也讨论过类似的情况。为了公平地评估SSL的性能,我们定义了一些训练规则来改进SupOnly基线。我们将训练样本的总数表示为N=S×T×b,其中S是训练epoch数,T是每个每个epoch中的迭代次数,b是批量大小,批量大小在每个任务中是固定的。对于在标准基准上进行的实验:
(1) 我们根据原始文件中的超参数训练完全监督基线,以获得可比较的结果。在(2)和(3)中使用相同的超参数(S除外)。
(2) 我们使用与(1)相同的S来训练由有标签数据(SupOnly)监督的模型。虽然T随着有标签数据的减少而减小,但为了防止过度拟合,我们不会通过训练更多的epoch来增加N。
(3) 我们调整S,以确保SSL实验中的N与(1)相同。在SSL实验中,每个批量都包含有标签数据和无标签数据。我们将一个epoc定义为一次遍历所有无标签数据。同时,在一个epoch内,标签数据会被重复训练多次。
通过遵循这些规则,SupOnly基线获得足够好的性能,并且不会过度拟合。SSL方法训练的模型具有与完全监督基线相同的计算开销,即相同的N。在实际数据集的实验中,我们首先对SupOnly基线进行了S个epoch的训练。然后,我们用相同的S训练SSL模型。我们使用网格搜索来为所有SSL方法找到合适的超参数。
4.1语义分割实验
语义分割[9,10,27]以图像为输入,并预测一系列类别掩码,该类别掩码将输入图像中的每个像素映射到类别(图3a)。我们在PASCALVOC 2012数据集[12]上进行了实验,该数据集包含20个前景类和1个背景类。将分割边界数据集(SBD)[16]作为附加有标签数据集,以扩充数据集。因此,我们有10582个训练样本和1449个验证样本。训练期间,在随机缩放和水平翻转之后,输入图像将裁剪为321x321。和之前的工作[19,30]一样,我们使用以ResNet-101[17]为主干的DeepLab-v2[9],作为SupOnly基线和SSL方法中的任务模型。除了多尺度融合技巧外,还采用了与DeepLab-v2原始文件相同的配置。
对于SSL,我们从所有数据中随机抽取1/16、1/8、1/4、1/2个样本作为有标签数据,并将训练集的其余部分用作无标签数据。注意,所有SSL方法都使用相同的数据拆分。表1显示了在PASCAL VOC 2012数据集上的mIOU,模型是在Microsoft COCO数据集[26]上预训练过的。GCT相比监督基线性能提高了1.26%(1/2标签下)至3.76%(1/8标签下)。此外,我们的完全监督基线(75.32%)与DeepLab-v2的原始文件(75:14%)具有可比性,优于[19]中重新报告的结果(73.6%)。因此,所有SSL方法在完整标签下都有轻微的改进。
4.2 图像去噪实验
真实图像去噪[3,15,47]是一项致力于从输入自然图像中去除真实噪声而非合成噪声的任务(图3b)。我们在SIDD数据集[1]上进行了实验,这是真实图像去噪的最大基准之一。它包含160个图像对(噪声图像和干净图像)用于训练,40个图像对用于验证。我们将每个图像对分割成大小为256x256的多个图像块进行训练。训练样本总数约为3万个。我们使用在NTRIE 2019真实图像去噪挑战赛中获得第二名的方法DHDN[33]作为任务模型,因为第一名获得者的代码尚未公布。采用峰值信噪比(PSNR)作为验证指标。
在图像去噪中,即使预测值与真值之间存在很小的误差,也会产生明显的视觉伪影。这意味着很难获得可靠的伪标签,也就是说,这个任务对于SSL来说很困难。我们注意到GCT中具有相同架构的任务模型具有相似的预测。因此,来自不同初始化的扰动不够强。为了缓解这个问题,我们将其中一个任务模型替换为在NTRIE 2019挑战赛中获得第三名的DIND[45]。我们仍然使用DHDN进行验证。
我们随机抽取1/16,1/8,1/4,1/2个标记图像对用于SSL。如表2所示,我们的完全监督基线达到39.38dB(PSNR),这与SIDD基准上的顶级结果相当。虽然SSL在这个困难的任务中表现出有限的性能,但是GCT在所有标记比率下都优于其他SSL方法。值得注意的是,GCT在1/16个标签(只有10个标签图像对)的情况下,PSNR提高了0.61dB,而以前的SSL方法在PSNR上最多提高了0.33dB。
4.3 人像抠图实验
抠图[37,44]是根据图像和预定义trimap预测前景的matte。matte中的每个像素值都是[0,1]之间的概率。我们关注的是人像抠图,它在智能手机上有着重要的应用,例如背景虚化。在图3c中,通过将trimap的未知区域内的像素设置为灰色,将trimap合并到输入图像中以进行可视化。由于没有开源基准,我们首先从Flickr收集了8000张肖像图像。然后,我们根据预训练的分割模型生成trimap。之后,我们选择了300张有精细细节的图像,用PhotoShop处理(每张图像大约20分钟)。最后,我们将100幅有标签图像与7700幅无标签图像作为训练集,其余200幅有标签图像作为验证集。对于每个有标签图像,我们通过随机裁剪生成15个样本,通过背景替换(使用OpenImage dataset[22])生成35个样本。对于每个无标签图像,我们通过随机裁剪生成5个样本。本文的任务模型结构是从[44]中得来的,这是抠图任务的一个里程碑。
在本任务中,我们通过在两种配置上进行实验来验证用SSL使用无标签数据的影响。用100幅有标签的图像,(1)随机选取一半(3850幅)无标签图像进行训练,(2)用所有(7700幅)无标签图像进行训练。如表3所示,对于3850和7700个无标签图像,GCT比监督基线分别提高了1.96dB和3.99dB。这表明通过增加无标签数据量可以有效地提高SSL性能。此外,与现有的SSL方法相比,使用GCT将无标签图像的数量增加一倍可以实现更显著的改进(2.03dB)。
4.4 夜景增强
夜间图像增强[8,13]是另一种常见的视觉应用。此任务调整夜间图像中的信道系数以显示更多细节(图3d)。我们的数据集包含1900张智能手机拍摄的夜间图像,其中400张图像使用Photoshop进行标记(大约15分钟/张)。我们将200幅标记图像与1500幅未标记图像结合起来进行训练,并使用另外200幅标记图像进行测试。我们在训练期间使用水平翻转、轻微旋转和随机裁剪(到512x512)作为数据增强。我们把HDRNet[13]作为任务模型。由于数据集很小,我们只使用一个SSL配置进行了测试(表3)。与其他三个任务中的实验类似,GCT优于现有的SSL方法。
4.5 消冗实验
我们进行了烧蚀研究,分析了所提出的SSL约束、GCT中的超参数以及缺陷检测器和Mean-Teacher的组合。
SSL约束条件的影响 默认情况下,GCT同时通过两个SSL约束从未标记的数据中学习。在图5中,我们在语义分割和真实图像去噪的基准上比较了仅使用一个SSL约束训练GCT的实验。结果表明,和都是有效的。带的GCT带来的性能提升是令人印象深刻的,证明了两个任务模型之间的知识交换是可靠和有效的。同时,带的GCT的曲线表明,缺陷检测器对未标记数据的学习也起着至关重要的作用。此外,将与相结合可以使GCT获得最佳性能。
GCT中的超参数。 我们分析了GCT所需的两个超参数(3.3节):缺陷阈值和的余弦上升时间,在PASCAL VOC基准上进行1/8标签的语义分割。表4(左)显示了不同的下的结果,它控制了两个SSL约束的组合。具体而言,当=0.0时只应用,而当=1.0时只应用。实验表明,可以大致设置,例如,适合于语义分割。个epoch的余弦上升防止了不可靠的知识交换,这是由于早期训练阶段缺陷检测器未收敛造成的。表4(右)中的结果表明,GCT对具有鲁棒性,尽管余弦上升对于最佳性能是必要的。
缺陷检测器与MT相结合。 MT中的一致性约束从教师模型应用到学生模型。但是,教师模型在某些像素上可能比学生模型差,这可能会导致性能下降。为了避免这个问题,当教师预测的缺陷概率大于学生预测的缺陷概率时,我们使用缺陷检测器来禁用一致性约束。在1/8个标签的情况下,该方法在Pascal VOC上将MT的mIOU值从69:81%提高到70.47%,在SIDD上将MT的PSNR值从38.22dB提高到38.42dB。
5 结论
我们研究了SSL在不同像素级任务中的推广,指出了现有SSL方法在这些任务中的不足,据我们所知,我们是第一个。我们为像素级SSL提供了一个新的通用框架GCT。我们的实验证明了它在各种视觉任务中的有效性。同时,我们还注意到,对于需要高精度伪标签的任务(如图像去噪),SSL的性能仍然有限。一个可能的未来工作是调查这个问题,并探索如何创建更准确的伪标签。