中间特征可视化

中间特征可视化

参考: PyTorch | 提取神经网络中间层特征进行可视化

    def get_feature(self):
        input=self.process_image()
        print(input.shape)  
        x=input
        for index,layer in enumerate(self.pretrained_model):
            x=layer(x)
            if (index == self.selected_layer):
                return x

参考: pytorch模型中间层特征的提取

class FeatureExtractor(nn.Module):
    def __init__(self, submodule, extracted_layers):
        super(FeatureExtractor, self).__init__()
        self.submodule = submodule
        self.extracted_layers = extracted_layers
 
    def forward(self, x):
        outputs = []
        for name, module in self.submodule._modules.items():
        # 目前不展示全连接层
            if "fc" in name: 
                x = x.view(x.size(0), -1)
            print(module)
            x = module(x)
            print(name)
            if name in self.extracted_layers:
                outputs.append(x)
        return outputs

综上, 完整代码 ⤵️

import cv2
import torchvision.transforms as transforms
import torchvision.models as models
import numpy as np
import torch.nn as nn
from PIL import Image


simple_transform = transforms.Compose([transforms.Resize((224, 224)),
                                       transforms.ToTensor(),  # H, W, C -> C, W, H 归一化到(0,1),简单直接除以255
                                       transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                                            std=[0.229, 0.224, 0.225])
                                       ])


# 中间特征提取
class FeatureExtractor(nn.Module):
    def __init__(self, submodule, extracted_layer):
        super(FeatureExtractor, self).__init__()
        self.submodule = submodule
        self.extracted_layers = extracted_layer

    def forward(self, x):
        for name, module in self.submodule._modules.items():
            if name is "fc":
                x = x.view(x.size(0), -1)
            x = module(x)
            print("moudle_name", name)
            if name in self.extracted_layers:
                return x


class VisualSingleFeature():
    def __init__(self, extract_features, save_path):
        self.extract_features = extract_features
        self.save_path = save_path

    def get_single_feature(self):
        print(self.extract_features.shape)  # ex. torch.Size([1, 128, 112, 112])

        extract_feature = self.extract_features[:, 0, :, :]
        print(extract_feature.shape)  # ex. torch.Size([1, 112, 112])

        extract_feature = extract_feature.view(extract_feature.shape[1], extract_feature.shape[2])
        print(extract_feature.shape)  # ex. torch.Size([112, 112])

        return extract_feature

    def save_feature_to_img(self):
        # to numpy
        extract_feature = self.get_single_feature().data.numpy()
        # use sigmod to [0,1]
        extract_feature = 1.0/(1+np.exp(-1*extract_feature))
        # to [0,255]
        extract_feature = np.round(extract_feature*255)
        print(extract_feature[0])
        # save image
        cv2.imwrite(self.save_path, extract_feature)


def single_image_sample():
    img_path = './snorlax.png'
    input_img = Image.open(img_path).convert('RGB')  # 读取图像
    input_tensor = simple_transform(input_img)
    print(input_tensor.shape)  # torch.Size([3, 224, 224])
    x = input_tensor[np.newaxis, :, :, :]
    print(x.shape)  # torch.Size([1, 3, 224, 224])
    return x
snorlax.png
    # test VGG16
    x = single_image_sample()
    for target_layer in range(0, 31):
        pretrained_module = models.vgg16(pretrained=True).features
        myexactor = FeatureExtractor(submodule=pretrained_module, extracted_layer=str(target_layer))
        target_features = myexactor(x)
        savepath = './VGG16/layer_{}.jpg'.format(target_layer)  # 需手动创建文件夹`./VGG16`
        VisualSingleFeature(target_features, savepath).save_feature_to_img()
    # test resnet50 sequential
    x = single_image_sample()
    for target_sequential in ['conv1', 'bn1', 'relu', 'maxpool', 'layer1', 'layer2', 'layer3', 'layer4', 'avgpool']:
        pretrained_module = models.resnet50(pretrained=True)
        myexactor = FeatureExtractor(submodule=pretrained_module, extracted_layer=target_sequential)
        target_features = myexactor(x)
        savepath = './Resnet50/{}.jpg'.format(target_sequential)
        VisualSingleFeature(target_features, savepath).save_feature_to_img()
    # test resnet50 layer1
    x = single_image_sample()
    pretrained_model = models.resnet50(pretrained=True)
    pretrained_module = pretrained_model.layer1

    for pre_sequential in ['conv1', 'bn1', 'relu', 'maxpool']:
        pre_module = getattr(pretrained_model, pre_sequential)
        x = pre_module(x)

    for target_Bottleneck_index in range(0, 3):
        myexactor = FeatureExtractor(submodule=pretrained_module, extracted_layer=str(target_Bottleneck_index))
        target_features = myexactor(x)
        savepath = './Resnet50/layer1/{}.jpg'.format(target_Bottleneck_index)
        VisualSingleFeature(target_features, savepath).save_feature_to_img()
    # test resnet50 layer1 Bottleneck0
    x = single_image_sample()
    pretrained_model = models.resnet50(pretrained=True)
    pretrained_module = pretrained_model.layer1._modules['0']

    for pre_sequential in ['conv1', 'bn1', 'relu', 'maxpool']:
        pre_module = getattr(pretrained_model, pre_sequential)
        x = pre_module(x)

    for target_sequential in ['conv1', 'bn1', 'conv2', 'bn2', 'conv3', 'bn3', 'relu']:
        myexactor = FeatureExtractor(submodule=pretrained_module, extracted_layer=target_sequential)
        target_features = myexactor(x)
        savepath = './Resnet50/layer1/Bottleneck0/{}.jpg'.format(target_sequential)
        VisualSingleFeature(target_features, savepath).save_feature_to_img()


优化:visual all feature outputs, not only single feature

 # [基于Pytorch的特征图提取](https://blog.csdn.net/ZOUZHEN_ID/article/details/84025943)
 # 特征输出可视化
    for i in range(feature_channel_number):
        ax = plt.subplot(6, 6, i + 1)
        ax.set_title('Feature {}'.format(i))
        ax.axis('off')
        plt.imshow(target_features.data.numpy()[0,i,:,:],cmap='jet')

    plt.show()

上述中间特征提取方法的“致命”缺点:Sequential只能表示简单的顺序模型结构(连U-Net的的跳跃连接都表示不了),复杂模型无力,这部分最好用register_forward_hook来做! ⤵️

from torchvision import transforms
import numpy as np

simple_transform = transforms.Compose([transforms.Resize((224, 224)),
                                       transforms.ToTensor(),  # H, W, C -> C, W, H 归一化到(0,1),简单直接除以255
                                       transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                                            std=[0.229, 0.224, 0.225])
                                       ])


def get_single_feature(extract_features):
    print(extract_features.shape)
    extract_feature = extract_features[:, 0, :, :]
    print(extract_feature.shape)
    extract_feature = extract_feature.view(extract_feature.shape[1], extract_feature.shape[2])
    print(extract_feature.shape)

    return extract_feature


def save_feature_to_img(extract_features, save_path):
    extract_feature = get_single_feature(extract_features).data.numpy()
    extract_feature = 1.0/(1+np.exp(-1*extract_feature))
    extract_feature = np.round(extract_feature*255)
    print(extract_feature[0])
    import cv2
    cv2.imwrite(save_path, extract_feature)
    print('{} saved.'.format(save_path))


def hook(module, input, output):
    save_feature_to_img(output.data, save_path)


if __name__ == '__main__':

    from PIL import Image
    img_path = './snorlax.png'
    input_img = Image.open(img_path).convert('RGB')
    input_tensor = simple_transform(input_img)
    x = input_tensor[np.newaxis, :, :, :]
    
    target_layer = 5
    save_path = './VGG16/layer_{}.png'.format(str(target_layer))

    import torchvision.models as models
    net = models.vgg16(pretrained=True)
    handle = getattr(net.features, '_modules')[target_layer].register_forward_hook(hook)
    _ = net(x)
    handle.remove()
最后编辑于
©著作权归作者所有,转载或内容合作请联系作者
  • 序言:七十年代末,一起剥皮案震惊了整个滨河市,随后出现的几起案子,更是在滨河造成了极大的恐慌,老刑警刘岩,带你破解...
    沈念sama阅读 219,366评论 6 508
  • 序言:滨河连续发生了三起死亡事件,死亡现场离奇诡异,居然都是意外死亡,警方通过查阅死者的电脑和手机,发现死者居然都...
    沈念sama阅读 93,521评论 3 395
  • 文/潘晓璐 我一进店门,熙熙楼的掌柜王于贵愁眉苦脸地迎上来,“玉大人,你说我怎么就摊上这事。” “怎么了?”我有些...
    开封第一讲书人阅读 165,689评论 0 356
  • 文/不坏的土叔 我叫张陵,是天一观的道长。 经常有香客问我,道长,这世上最难降的妖魔是什么? 我笑而不...
    开封第一讲书人阅读 58,925评论 1 295
  • 正文 为了忘掉前任,我火速办了婚礼,结果婚礼上,老公的妹妹穿的比我还像新娘。我一直安慰自己,他们只是感情好,可当我...
    茶点故事阅读 67,942评论 6 392
  • 文/花漫 我一把揭开白布。 她就那样静静地躺着,像睡着了一般。 火红的嫁衣衬着肌肤如雪。 梳的纹丝不乱的头发上,一...
    开封第一讲书人阅读 51,727评论 1 305
  • 那天,我揣着相机与录音,去河边找鬼。 笑死,一个胖子当着我的面吹牛,可吹牛的内容都是我干的。 我是一名探鬼主播,决...
    沈念sama阅读 40,447评论 3 420
  • 文/苍兰香墨 我猛地睁开眼,长吁一口气:“原来是场噩梦啊……” “哼!你这毒妇竟也来了?” 一声冷哼从身侧响起,我...
    开封第一讲书人阅读 39,349评论 0 276
  • 序言:老挝万荣一对情侣失踪,失踪者是张志新(化名)和其女友刘颖,没想到半个月后,有当地人在树林里发现了一具尸体,经...
    沈念sama阅读 45,820评论 1 317
  • 正文 独居荒郊野岭守林人离奇死亡,尸身上长有42处带血的脓包…… 初始之章·张勋 以下内容为张勋视角 年9月15日...
    茶点故事阅读 37,990评论 3 337
  • 正文 我和宋清朗相恋三年,在试婚纱的时候发现自己被绿了。 大学时的朋友给我发了我未婚夫和他白月光在一起吃饭的照片。...
    茶点故事阅读 40,127评论 1 351
  • 序言:一个原本活蹦乱跳的男人离奇死亡,死状恐怖,灵堂内的尸体忽然破棺而出,到底是诈尸还是另有隐情,我是刑警宁泽,带...
    沈念sama阅读 35,812评论 5 346
  • 正文 年R本政府宣布,位于F岛的核电站,受9级特大地震影响,放射性物质发生泄漏。R本人自食恶果不足惜,却给世界环境...
    茶点故事阅读 41,471评论 3 331
  • 文/蒙蒙 一、第九天 我趴在偏房一处隐蔽的房顶上张望。 院中可真热闹,春花似锦、人声如沸。这庄子的主人今日做“春日...
    开封第一讲书人阅读 32,017评论 0 22
  • 文/苍兰香墨 我抬头看了看天上的太阳。三九已至,却和暖如春,着一层夹袄步出监牢的瞬间,已是汗流浃背。 一阵脚步声响...
    开封第一讲书人阅读 33,142评论 1 272
  • 我被黑心中介骗来泰国打工, 没想到刚下飞机就差点儿被人妖公主榨干…… 1. 我叫王不留,地道东北人。 一个月前我还...
    沈念sama阅读 48,388评论 3 373
  • 正文 我出身青楼,却偏偏与公主长得像,于是被迫代替她去往敌国和亲。 传闻我的和亲对象是个残疾皇子,可洞房花烛夜当晚...
    茶点故事阅读 45,066评论 2 355