中间特征可视化

参考: PyTorch | 提取神经网络中间层特征进行可视化

    def get_feature(self):
        """Feed the preprocessed image through the layers one by one and
        return the activation of the layer indexed by ``self.selected_layer``."""
        x = self.process_image()
        print(x.shape)
        for layer_index, layer in enumerate(self.pretrained_model):
            x = layer(x)
            if layer_index == self.selected_layer:
                return x

参考: pytorch模型中间层特征的提取

class FeatureExtractor(nn.Module):
    """Run *submodule* layer by layer and collect the outputs of every
    layer whose name appears in *extracted_layers*."""

    def __init__(self, submodule, extracted_layers):
        super(FeatureExtractor, self).__init__()
        self.submodule = submodule
        self.extracted_layers = extracted_layers

    def forward(self, x):
        collected = []
        for layer_name, layer in self.submodule._modules.items():
            # Fully-connected layers are not visualized for now;
            # flatten the activations before feeding them in.
            if "fc" in layer_name:
                x = x.view(x.size(0), -1)
            print(layer)
            x = layer(x)
            print(layer_name)
            if layer_name in self.extracted_layers:
                collected.append(x)
        return collected

综上, 完整代码 ⤵️

import cv2
import torchvision.transforms as transforms
import torchvision.models as models
import numpy as np
import torch.nn as nn
from PIL import Image


# Standard ImageNet preprocessing: resize, convert to tensor, normalize.
simple_transform = transforms.Compose([transforms.Resize((224, 224)),
                                       transforms.ToTensor(),  # H, W, C -> C, H, W; scales to (0, 1) by dividing by 255
                                       transforms.Normalize(mean=[0.485, 0.456, 0.406],  # ImageNet channel statistics
                                                            std=[0.229, 0.224, 0.225])
                                       ])


# 中间特征提取
# Intermediate feature extraction
class FeatureExtractor(nn.Module):
    """Run *submodule* layer by layer and return the output of the first
    layer whose name appears in *extracted_layer*.

    Args:
        submodule: the model (or sub-model) whose children are iterated.
        extracted_layer: name (or container of names) of the target layer.
    """

    def __init__(self, submodule, extracted_layer):
        super(FeatureExtractor, self).__init__()
        self.submodule = submodule
        self.extracted_layers = extracted_layer

    def forward(self, x):
        for name, module in self.submodule._modules.items():
            # BUGFIX: the original used `name is "fc"` — identity comparison
            # with a string literal is unreliable; compare by value instead.
            if name == "fc":
                # flatten before the fully-connected layer
                x = x.view(x.size(0), -1)
            x = module(x)
            print("module_name", name)
            if name in self.extracted_layers:
                return x


class VisualSingleFeature():
    """Save channel 0 of an extracted feature map tensor as an image file."""

    def __init__(self, extract_features, save_path):
        self.extract_features = extract_features
        self.save_path = save_path

    def get_single_feature(self):
        """Return channel 0 of the (1, C, H, W) feature map as an (H, W) tensor."""
        feats = self.extract_features
        print(feats.shape)  # ex. torch.Size([1, 128, 112, 112])

        channel0 = feats[:, 0, :, :]
        print(channel0.shape)  # ex. torch.Size([1, 112, 112])

        height, width = channel0.shape[1], channel0.shape[2]
        channel0 = channel0.view(height, width)
        print(channel0.shape)  # ex. torch.Size([112, 112])

        return channel0

    def save_feature_to_img(self):
        """Squash the single-channel feature to [0, 255] and write it to disk."""
        feature = self.get_single_feature().data.numpy()
        # sigmoid maps the raw activations into (0, 1)
        feature = 1.0/(1+np.exp(-1*feature))
        # scale into [0, 255] for image output
        feature = np.round(feature*255)
        print(feature[0])
        cv2.imwrite(self.save_path, feature)


def single_image_sample():
    """Load ./snorlax.png, apply the ImageNet transform, and return a
    batched (1, 3, 224, 224) tensor."""
    img_path = './snorlax.png'
    image = Image.open(img_path).convert('RGB')  # load as RGB
    tensor = simple_transform(image)
    print(tensor.shape)  # torch.Size([3, 224, 224])
    batched = tensor[np.newaxis, :, :, :]  # prepend the batch dimension
    print(batched.shape)  # torch.Size([1, 3, 224, 224])
    return batched
if __name__ == '__main__':
    # NOTE(review): the guard line above was lost in the blog extraction,
    # leaving this script body orphan-indented; restored here.

    # test VGG16 — pretrained model load hoisted out of the loop (it is
    # loop-invariant and loading it 31 times was pure waste)
    x = single_image_sample()
    pretrained_module = models.vgg16(pretrained=True).features
    for target_layer in range(0, 31):
        myexactor = FeatureExtractor(submodule=pretrained_module, extracted_layer=str(target_layer))
        target_features = myexactor(x)
        savepath = './VGG16/layer_{}.jpg'.format(target_layer)  # create the ./VGG16 folder manually first
        VisualSingleFeature(target_features, savepath).save_feature_to_img()

    # test resnet50 top-level sequentials
    x = single_image_sample()
    pretrained_module = models.resnet50(pretrained=True)
    for target_sequential in ['conv1', 'bn1', 'relu', 'maxpool', 'layer1', 'layer2', 'layer3', 'layer4', 'avgpool']:
        myexactor = FeatureExtractor(submodule=pretrained_module, extracted_layer=target_sequential)
        target_features = myexactor(x)
        savepath = './Resnet50/{}.jpg'.format(target_sequential)
        VisualSingleFeature(target_features, savepath).save_feature_to_img()

    # test resnet50 layer1
    x = single_image_sample()
    pretrained_model = models.resnet50(pretrained=True)
    pretrained_module = pretrained_model.layer1

    # run the stem (conv1 -> bn1 -> relu -> maxpool) so x has the shape layer1 expects
    for pre_sequential in ['conv1', 'bn1', 'relu', 'maxpool']:
        pre_module = getattr(pretrained_model, pre_sequential)
        x = pre_module(x)

    for target_Bottleneck_index in range(0, 3):
        myexactor = FeatureExtractor(submodule=pretrained_module, extracted_layer=str(target_Bottleneck_index))
        target_features = myexactor(x)
        savepath = './Resnet50/layer1/{}.jpg'.format(target_Bottleneck_index)
        VisualSingleFeature(target_features, savepath).save_feature_to_img()

    # test resnet50 layer1 Bottleneck0
    x = single_image_sample()
    pretrained_model = models.resnet50(pretrained=True)
    pretrained_module = pretrained_model.layer1._modules['0']

    for pre_sequential in ['conv1', 'bn1', 'relu', 'maxpool']:
        pre_module = getattr(pretrained_model, pre_sequential)
        x = pre_module(x)

    for target_sequential in ['conv1', 'bn1', 'conv2', 'bn2', 'conv3', 'bn3', 'relu']:
        myexactor = FeatureExtractor(submodule=pretrained_module, extracted_layer=target_sequential)
        target_features = myexactor(x)
        savepath = './Resnet50/layer1/Bottleneck0/{}.jpg'.format(target_sequential)
        VisualSingleFeature(target_features, savepath).save_feature_to_img()


优化:visual all feature outputs, not only single feature

 # [基于Pytorch的特征图提取](https://blog.csdn.net/ZOUZHEN_ID/article/details/84025943)
 # Visualize every feature channel in a 6x6 grid.
 # NOTE(review): snippet assumes `matplotlib.pyplot as plt`, `target_features`
 # and `feature_channel_number` are in scope, and feature_channel_number <= 36
 # (the subplot grid is fixed at 6x6) — confirm before reuse.
    for i in range(feature_channel_number):
        ax = plt.subplot(6, 6, i + 1)
        ax.set_title('Feature {}'.format(i))
        ax.axis('off')
        plt.imshow(target_features.data.numpy()[0,i,:,:],cmap='jet')

    plt.show()

上述中间特征提取方法的“致命”缺点:Sequential只能表示简单的顺序模型结构(连U-Net的跳跃连接都表示不了),复杂模型无力,这部分最好用register_forward_hook来做! ⤵️

from torchvision import transforms
import numpy as np

# Standard ImageNet preprocessing: resize, convert to tensor, normalize.
simple_transform = transforms.Compose([transforms.Resize((224, 224)),
                                       transforms.ToTensor(),  # H, W, C -> C, H, W; scales to (0, 1) by dividing by 255
                                       transforms.Normalize(mean=[0.485, 0.456, 0.406],  # ImageNet channel statistics
                                                            std=[0.229, 0.224, 0.225])
                                       ])


def get_single_feature(extract_features):
    """Slice channel 0 out of a (1, C, H, W) feature map and return it as (H, W)."""
    print(extract_features.shape)
    channel0 = extract_features[:, 0, :, :]
    print(channel0.shape)
    height, width = channel0.shape[1], channel0.shape[2]
    channel0 = channel0.view(height, width)
    print(channel0.shape)

    return channel0


def save_feature_to_img(extract_features, save_path):
    """Squash channel 0 of *extract_features* into [0, 255] and write it to *save_path*."""
    feature = get_single_feature(extract_features).data.numpy()
    # sigmoid maps raw activations into (0, 1), then scale to [0, 255]
    feature = 1.0/(1+np.exp(-1*feature))
    feature = np.round(feature*255)
    print(feature[0])
    import cv2
    cv2.imwrite(save_path, feature)
    print('{} saved.'.format(save_path))


def hook(module, input, output):
    """Forward-hook callback: save the hooked layer's output feature map.

    Signature (module, input, output) is fixed by register_forward_hook.
    NOTE: relies on the module-level ``save_path`` set in ``__main__``.
    """
    save_feature_to_img(output.data, save_path)


if __name__ == '__main__':

    # Load and preprocess the sample image.
    from PIL import Image
    img_path = './snorlax.png'
    input_img = Image.open(img_path).convert('RGB')
    input_tensor = simple_transform(input_img)
    x = input_tensor[np.newaxis, :, :, :]

    target_layer = 5
    save_path = './VGG16/layer_{}.png'.format(str(target_layer))  # read by hook()

    import torchvision.models as models
    net = models.vgg16(pretrained=True)
    # BUGFIX: _modules is an OrderedDict keyed by *strings* ('0', '1', ...),
    # so `getattr(net.features, '_modules')[target_layer]` with an int key
    # raised KeyError. nn.Sequential supports int indexing directly.
    handle = net.features[target_layer].register_forward_hook(hook)
    _ = net(x)  # triggers the hook, which writes the feature image
    handle.remove()  # detach the hook so later forwards don't re-save
最后编辑于
©著作权归作者所有,转载或内容合作请联系作者
平台声明:文章内容(如有图片或视频亦包括在内)由作者上传并发布,文章内容仅代表作者本人观点,简书系信息发布平台,仅提供信息存储服务。