CAM系列(二)之Grad-CAM(原理讲解和代码实现)

上篇文章介绍了CAM的开篇之作CAM系列(一)之CAM(原理讲解和PyTorch代码实现),本篇接着介绍泛化性和通用性能更好的Grad-CAM。

Grad-CAM的提出背景:CAM揭示了卷积神经网络分类模型中图像的空间特征与其类别权重之间的联系,然鹅,CAM只适用于模型中有全局平均池化层并且只有一个全连接层(即输出层)的情形,如ResNet,MobileNet等。因为CAM算法中生成类激活图所需要的类别权重,即为全局平均池化层和全连接输出层之间的,对应着图片类别的权重。对于VggNet,DenseNet等有着多个全连接层的模型,CAM则不再适用,因为无法获取到类别权重。为了解决这一问题,Grad-CAM应运而生。

左:满足CAM使用场景的模型;右:不满足CAM使用场景的模型

如上图所示,当模型包含全局平均池化层,且只有一个全连接层时,其输出的某一类别与各特征图通道有着明确的一一对应关系;当模型不包含全局平均池化层,且有着多个全连接层时,其输出的某一类别与特征图的关系不再明确,因此CAM不再适用。

Grad-CAM原理:顾名思义,就是采用梯度的CAM算法。哪里需要梯度呢?根据上一篇CAM系列(一)之CAM对类激活图算法原理的讲解,该算法实际有两个关键要素:

  1. 输入图像经过CNN处理后,其最后一层卷积层的输出特征图;
  2. 与输入图像类别相关的、数量和特征图通道数一致的权重。

有了这两个要素,只需要将二者对应相乘相加即可得到类别显著图。

对于VggNet,DenseNet等这一类模型,因为应用CAM时不符合上述第二点要素,因此,Grad-CAM另辟蹊径,采用特征图的梯度信息生成了相应的权重。需要注意的是:这里的梯度不是模型训练时由 Loss反向传播计算得到的梯度,而是模型输出的类别置信分数反向传播计算得到的梯度。因为这里的权重必须包含类别信息才有意义,如CAM中用的权重就是直接生成图像相应的类别分数的权重。

为什么必须包含类别信息呢?再举一个例子,下图中包含两个类别的目标,使用CAM定位时,如何决定是定位猫,还是定位狗呢?如果使用与输出类别中的猫相关的信息获得的权重筛选出来的特征图高亮了猫所在的区域,而用输出类别中编码狗的类别分数信息获得的权重筛选出来的特征图高亮了狗所在的区域,那么就很好的解释了CNN是如何判别猫或者是狗的。因此,CAM算法中的权重其实有着通过类别信息来筛选相应的特征图的功能。

一幅图像中包含多个类别的目标

Grad-CAM就是将图像输入CNN,先前向传播获得第一个要素——最后一层的输出特征图(维度为[C, H, W]),并获得模型输出的类别 logits(未经softmax映射)。然后利用待定位的类别logit(如猫的logit为2.35)进行反向传播,获得最后一层输出特征图关于这个类别分数的梯度(维度为[C, H, W])。最后对特征图梯度的空间维度计算平均值(维度变为[C, ]),得到第二个要素——与类别信息有关且与特征图通道数一致的权重。其它步骤都和CAM相同,也就是说,Grad-CAM只是提出了一种更加通用的权重获取方法。

Grad-CAM代码实现:

本文以PyTorch自带的VGG11-BN为例,分步骤讲解并用代码实现Grad-CAM的整个流程和细节。
Grad-CAM前面的几个实现步骤与CAM相同,这里照搬。

1.准备工作

首先导入需要用到的包:

import math
import torch
from torch import Tensor
from torch import nn
import torch.nn.functional as F
from typing import Optional, List
import torchvision.transforms as transforms
from PIL import Image
import torchvision.models as models
from torch import Tensor
from matplotlib import cm
from torchvision.transforms.functional import to_pil_image

定义输入图片路径,和保存输出的类激活图的路径:

img_path = '/home/dell/img/1.JPEG'     # 输入图片的路径
save_path = '/home/dell/cam/CAM1.png'    # 类激活图保存路径

定义输入图片预处理方式。由于本文用的输入图片来自ILSVRC-2012验证集,因此采用PyTorch官方文档提供的ImageNet验证集处理流程:

preprocess = transforms.Compose([transforms.Resize(256),
                                transforms.CenterCrop(224),
                                transforms.ToTensor(),
                                transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
2.获取CNN最后一层卷积层的输出特征图

本文选用的CNN模型是PyTorch自带的VGG-11-BN,首先导入预训练模型:

net = models.vgg11_bn(pretrained=True).cuda()   # 导入模型
# print(net)

由于特征图是模型前向传播时的中间变量,不能直接从模型中获取,需要用到PyTorch提供的hook工具,补课请参考我的这两篇博客:hook1hook2

通过输出模型(print(net))我们就能看到VGG11-BN输出最后一层特征图的层名为net.features。我们用hook工具注册这一层,以便获得它的输出特征图:

feature_map = []     # 建立列表容器,用于盛放输出特征图

def forward_hook(module, inp, outp):     # 定义hook
    feature_map.append(outp)    # 把输出装入字典feature_map

net.features.register_forward_hook(forward_hook)    # 对net.layer4这一层注册前向传播

由于Grad-CAM需要获取最后一层卷积层输出特征图的梯度,梯度也是中间变量,需要用到hook工具注册获取:

grad = []     # 建立列表容器,用于盛放特征图的梯度

def backward_hook(module, inp, outp):    # 定义hook 
    grad.append(outp)    # 把输出装入列表grad

net.features.register_full_backward_hook(backward_hook)    # 对net.features这一层注册反向传播

做好了前向和后向hook的定义和注册工作,现在只需要对输入图片进行预处理,然后分别执行一次模型前向传播和反向传播,即可获得CNN最后一层卷积层的输出特征图及其梯度:

orign_img = Image.open(img_path).convert('RGB')    # 打开图片并转换为RGB模型
img = preprocess(orign_img)     # 图片预处理
img = torch.unsqueeze(img, 0)     # 增加batch维度 [1, 3, 224, 224]

out = net(img.cuda())     # 前向传播 
cls_idx = torch.argmax(out).item()    # 获取预测类别编码 
score = out[:, cls_idx].sum()    # 获取预测类别分数
net.zero_grad() 
score.backward(retain_graph=True)    # 由预测类别分数反向传播

这时我们想要的特征图已经装在列表feature_map中了,特征图的梯度装在了列表grad中。

3.获取权重

如前文所述,Grad-CAM所需的权重是特征图关于类别分数的梯度的空间平均值,因此我们只需要对上一步获得的特征图梯度在空间上求平均即可获得权重。

由于我也不知道这张图的类别标签,这里假设模型对这张图像分类正确,我们来获得其输出类别所对应的权重:

weights = grad[0][0].squeeze(0).mean(dim=(1, 2))    # 获得权重  
4.对特征图的通道进行加权叠加,获得Grad-CAM
grad_cam = (weights.view(*weights.shape, 1, 1) * feature_map[0].squeeze(0)).sum(0)
5.对Grad-CAM进行ReLU激活和归一化

我们首先定义归一化函数:

def _normalize(cams: Tensor) -> Tensor:
        """CAM normalization"""
        cams.sub_(cams.flatten(start_dim=-2).min(-1).values.unsqueeze(-1).unsqueeze(-1))
        cams.div_(cams.flatten(start_dim=-2).max(-1).values.unsqueeze(-1).unsqueeze(-1))

        return cams

然后对类激活图执行ReLU激活和归一化,并利用PyTorch的 to_pil_image函数将其转换为PIL格式以便下步处理:

grad_cam = _normalize(F.relu(grad_cam, inplace=True)).cpu()
mask = to_pil_image(grad_cam.detach().numpy(), mode='F')
6.将类激活图覆盖到输入图像上,实现目标定位

我们将两个图像交叠融合的过程封装成了函数:

def overlay_mask(img: Image.Image, mask: Image.Image, colormap: str = 'jet', alpha: float = 0.6) -> Image.Image:
    """Overlay a colormapped mask on a background image

    Args:
        img: background image
        mask: mask to be overlayed in grayscale
        colormap: colormap to be applied on the mask
        alpha: transparency of the background image

    Returns:
        overlayed image
    """

    if not isinstance(img, Image.Image) or not isinstance(mask, Image.Image):
        raise TypeError('img and mask arguments need to be PIL.Image')

    if not isinstance(alpha, float) or alpha < 0 or alpha >= 1:
        raise ValueError('alpha argument is expected to be of type float between 0 and 1')

    cmap = cm.get_cmap(colormap)    
    # Resize mask and apply colormap
    overlay = mask.resize(img.size, resample=Image.BICUBIC)
    overlay = (255 * cmap(np.asarray(overlay) ** 2)[:, :, 1:]).astype(np.uint8)
    # Overlay the image with the mask
    overlayed_img = Image.fromarray((alpha * np.asarray(img) + (1 - alpha) * overlay).astype(np.uint8))

    return overlayed_img

接下来就是激动人心的时刻了!!!将类激活图作为掩码,以一定的比例覆盖到原始输入图像上,生成类激活图:

result = overlay_mask(orign_img, mask) 

这里的变量result已经是有着PIL图片格式的类激活图了,我们可以通过:

result.show()

可视化输出,也可以通过:

result.save(save_path)

将图片保存在本地查看。我们在这里展示一下输入图像和输出定位图像的对比:


(左)输入图像;(右)定位图像
最后编辑于
©著作权归作者所有,转载或内容合作请联系作者
  • 序言:七十年代末,一起剥皮案震惊了整个滨河市,随后出现的几起案子,更是在滨河造成了极大的恐慌,老刑警刘岩,带你破解...
    沈念sama阅读 219,589评论 6 508
  • 序言:滨河连续发生了三起死亡事件,死亡现场离奇诡异,居然都是意外死亡,警方通过查阅死者的电脑和手机,发现死者居然都...
    沈念sama阅读 93,615评论 3 396
  • 文/潘晓璐 我一进店门,熙熙楼的掌柜王于贵愁眉苦脸地迎上来,“玉大人,你说我怎么就摊上这事。” “怎么了?”我有些...
    开封第一讲书人阅读 165,933评论 0 356
  • 文/不坏的土叔 我叫张陵,是天一观的道长。 经常有香客问我,道长,这世上最难降的妖魔是什么? 我笑而不...
    开封第一讲书人阅读 58,976评论 1 295
  • 正文 为了忘掉前任,我火速办了婚礼,结果婚礼上,老公的妹妹穿的比我还像新娘。我一直安慰自己,他们只是感情好,可当我...
    茶点故事阅读 67,999评论 6 393
  • 文/花漫 我一把揭开白布。 她就那样静静地躺着,像睡着了一般。 火红的嫁衣衬着肌肤如雪。 梳的纹丝不乱的头发上,一...
    开封第一讲书人阅读 51,775评论 1 307
  • 那天,我揣着相机与录音,去河边找鬼。 笑死,一个胖子当着我的面吹牛,可吹牛的内容都是我干的。 我是一名探鬼主播,决...
    沈念sama阅读 40,474评论 3 420
  • 文/苍兰香墨 我猛地睁开眼,长吁一口气:“原来是场噩梦啊……” “哼!你这毒妇竟也来了?” 一声冷哼从身侧响起,我...
    开封第一讲书人阅读 39,359评论 0 276
  • 序言:老挝万荣一对情侣失踪,失踪者是张志新(化名)和其女友刘颖,没想到半个月后,有当地人在树林里发现了一具尸体,经...
    沈念sama阅读 45,854评论 1 317
  • 正文 独居荒郊野岭守林人离奇死亡,尸身上长有42处带血的脓包…… 初始之章·张勋 以下内容为张勋视角 年9月15日...
    茶点故事阅读 38,007评论 3 338
  • 正文 我和宋清朗相恋三年,在试婚纱的时候发现自己被绿了。 大学时的朋友给我发了我未婚夫和他白月光在一起吃饭的照片。...
    茶点故事阅读 40,146评论 1 351
  • 序言:一个原本活蹦乱跳的男人离奇死亡,死状恐怖,灵堂内的尸体忽然破棺而出,到底是诈尸还是另有隐情,我是刑警宁泽,带...
    沈念sama阅读 35,826评论 5 346
  • 正文 年R本政府宣布,位于F岛的核电站,受9级特大地震影响,放射性物质发生泄漏。R本人自食恶果不足惜,却给世界环境...
    茶点故事阅读 41,484评论 3 331
  • 文/蒙蒙 一、第九天 我趴在偏房一处隐蔽的房顶上张望。 院中可真热闹,春花似锦、人声如沸。这庄子的主人今日做“春日...
    开封第一讲书人阅读 32,029评论 0 22
  • 文/苍兰香墨 我抬头看了看天上的太阳。三九已至,却和暖如春,着一层夹袄步出监牢的瞬间,已是汗流浃背。 一阵脚步声响...
    开封第一讲书人阅读 33,153评论 1 272
  • 我被黑心中介骗来泰国打工, 没想到刚下飞机就差点儿被人妖公主榨干…… 1. 我叫王不留,地道东北人。 一个月前我还...
    沈念sama阅读 48,420评论 3 373
  • 正文 我出身青楼,却偏偏与公主长得像,于是被迫代替她去往敌国和亲。 传闻我的和亲对象是个残疾皇子,可洞房花烛夜当晚...
    茶点故事阅读 45,107评论 2 356

推荐阅读更多精彩内容