对比学习+语义分割:Cross-Image Pixel Contrasting for Semantic Segmentation

论文名称:Cross-Image Pixel Contrasting for Semantic Segmentation

这是一种将对比学习运用到全监督语义分割里的方法,主要起到辅助训练的作用,在实际推理部署的时候,原本用来对比学习的分支是去除的。

主要解决的问题:

  • 模型训练的时候只考虑当前一张图像的内容,无法站在整个数据集的内容上考虑问题。

创新点:

  • 基于像素到像素的对比(pixel to pixel)+像素到区域的对比(pixel to region),设计了更加高效的memory bank。
  • 设计了一种更加合理的难样本采样策略,Segmentation-Aware Hard Anchor Sampling。

整体结构

整体结构

通过上图可以看到,相比常规的全监督语义分割结构,本文的方法只是额外增加了一条用于对比计算的辅助分支,该分支在实际推理部署的时候是去除的,所以对于语义分割模型本身是不增加推理负担的。

Memory Bank

这个东西就是一个数据池,保存了历史数据用于对比计算,这里面保存的都是经过了模型特征提取后的D维特征,本文D=256。

Pixel to Pixel

这个是针对整个数据集的图像来操作的,就是对所有类设置一个专属的队列(整个数据集有多少个类就有多少个队列,比如COCO数据集有80类,那么就有80个队列),训练的时候从每个mini-batch中的每个类选取V个D维像素加入到对应类的队列T里,T是远大于V的。一旦T被装满了,那么就去旧留新。通过这种方式Memory Bank能动态存储绝大部分图像的内容特征。

num_pixel = idxs.shape[0]
perm = torch.randperm(num_pixel)    #随机选择一定像素
K = min(num_pixel, self.pixel_update_freq)    #跟预设值比较,减少代码出错的操作
feat = this_feat[:, perm[:K]]
feat = torch.transpose(feat, 0, 1)
ptr = int(pixel_queue_ptr[lb])

if ptr + K >= self.memory_size:     #队列满了则去旧留新
    pixel_queue[lb, -K:, :] = nn.functional.normalize(feat, p=2, dim=1)
    pixel_queue_ptr[lb] = 0
else:
    pixel_queue[lb, ptr:ptr + K, :] = nn.functional.normalize(feat, p=2, dim=1)
    pixel_queue_ptr[lb] = (pixel_queue_ptr[lb] + 1) % self.memory_size  

上面的源码可以大致看到,首先会随机选择一定数量的像素加入到队列,如果队列满了则旧的数据会被新的数据代替。

Pixel to Region

就是将区域的一大块特征用一个像素点的特征去表示,主要是用来弥补pixel to pixel 采样不充分的问题,这个方法是针对一张图像的操作,将多个一张图像的特征拼接到一起训练就能获取全局信息。怎么操作的呢?比如一张图像上有3个地方是猫的区域,首先对这个3块区域在XY坐标上进行求平均,最后变为3个D维的像素特征(D,1,1),然后再对这个3个像素点在对应通道维度上求平均,最后当前图像的3只猫就被一个D维的像素点表示。

# segment enqueue and dequeue
feat = torch.mean(this_feat[:, idxs], dim=1).squeeze(1)
ptr = int(segment_queue_ptr[lb])
segment_queue[lb, ptr, :] = nn.functional.normalize(feat.view(-1), p=2, dim=0)
segment_queue_ptr[lb] = (segment_queue_ptr[lb] + 1) % self.memory_size

这个方法是训练前期开辟一大块内存,然后每一个mini-batch都会加入一定的特征进去,越到训练后期特征越多,等一轮训练完成后就清空,再重新开始。

困难样本采样策略Segmentation-Aware Hard Anchor Sampling

这个采样策略其实很简单,相比现有的难负挖掘采样策略,它是随机采一半困难样本,剩下的一半就随机采样,这剩下的一本里面应该既有困难样本也有简单样本,这样做的目的是防止全部使用困难样本训练导致过拟合。举个例子,当前是一个类别为猫的像素特征,首先会从memory bank中选择512个D维的像素点,这些像素点属于狗、羊等其他跟猫特征接近的动物,或者是猫的但经常分类分错的像素点,再随机从memory bank中选择一些像素点放一起,源码中是总共选择1024个点用于跟当前的限度点进行损失计算。怎么确定是困难样本还是简单样本呢?就是通过模型的mask图的像素值跟label值对不对的上,mask值跟label值匹配就是困难样本。

©著作权归作者所有,转载或内容合作请联系作者
【社区内容提示】社区部分内容疑似由AI辅助生成,浏览时请结合常识与多方信息审慎甄别。
平台声明:文章内容(如有图片或视频亦包括在内)由作者上传并发布,文章内容仅代表作者本人观点,简书系信息发布平台,仅提供信息存储服务。

相关阅读更多精彩内容

友情链接更多精彩内容