中间特征可视化
参考: PyTorch | 提取神经网络中间层特征进行可视化
def get_feature(self):
    """Run the processed image through ``self.pretrained_model`` up to a layer.

    The tensor returned by ``self.process_image()`` is fed through each layer
    of ``self.pretrained_model`` in iteration order; the output of the layer
    whose index equals ``self.selected_layer`` is returned immediately.

    Returns:
        The selected layer's output, or ``None`` if ``self.selected_layer``
        never matches an index of ``self.pretrained_model``.
    """
    # Named ``image`` instead of ``input`` so the builtin is not shadowed.
    image = self.process_image()
    print(image.shape)
    x = image
    for index, layer in enumerate(self.pretrained_model):
        x = layer(x)
        if index == self.selected_layer:
            return x
    # selected_layer was never reached; same result as the original fall-through.
    return None
class FeatureExtractor(nn.Module):
    """Collect the outputs of selected direct children of ``submodule``.

    The children of ``submodule`` are applied one after another in their
    registration order (``submodule._modules``), so this only reproduces the
    real forward pass of purely sequential models.

    Args:
        submodule: module whose direct children are executed in sequence.
        extracted_layers: container of child names whose outputs are kept;
            membership is tested with ``in``.
    """

    def __init__(self, submodule, extracted_layers):
        super(FeatureExtractor, self).__init__()
        self.submodule = submodule
        self.extracted_layers = extracted_layers

    def forward(self, x):
        """Run ``x`` through every child and return the requested outputs.

        Returns:
            list of the outputs of the children named in
            ``self.extracted_layers``, in execution order.
        """
        outputs = []
        for name, module in self.submodule._modules.items():
            # Fully connected layers are not visualized for now; any child whose
            # name contains "fc" receives its input flattened to (N, -1).
            # ``reshape`` gives the same result as ``view`` but, unlike ``view``,
            # also works when ``x`` is not contiguous in memory.
            if "fc" in name:
                x = x.reshape(x.size(0), -1)
            print(module)
            x = module(x)
            print(name)
            if name in self.extracted_layers:
                outputs.append(x)
        return outputs
综上, 完整代码 ⤵️
import cv2
import torchvision.transforms as transforms
import torchvision.models as models
import numpy as np
import torch.nn as nn
from PIL import Image
# Preprocessing pipeline: resize to 224x224, convert to a float tensor, then
# normalize each channel with the mean/std values below (the usual ImageNet
# statistics).
simple_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    # PIL image (H x W x C, values 0-255) -> tensor (C x H x W) in [0, 1];
    # the scaling is a plain division by 255.
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])
# 中间特征提取
class FeatureExtractor(nn.Module):
    """Run ``submodule``'s direct children in order and stop at a target child.

    The children are applied one after another in registration order
    (``submodule._modules``), so the result only matches the real forward
    pass for purely sequential models (e.g. ``vgg16().features``).

    Args:
        submodule: module whose direct children are executed in sequence.
        extracted_layer: name of the child to stop at (e.g. ``"5"`` or
            ``"layer1"``), or an iterable of names; the output of the first
            matching child is returned.
    """

    def __init__(self, submodule, extracted_layer):
        super(FeatureExtractor, self).__init__()
        self.submodule = submodule
        self.extracted_layers = extracted_layer

    def forward(self, x):
        """Return the output of the first child named in ``extracted_layers``.

        Returns ``None`` if no child name matches.
        """
        targets = self.extracted_layers
        # A single name must match exactly: ``name in "10"`` is a substring
        # test and would wrongly stop at child "0" (or "1").
        if isinstance(targets, str):
            targets = (targets,)
        for name, module in self.submodule._modules.items():
            # Compare by value; ``is`` on strings only works by interning accident.
            if name == "fc":
                # Flatten everything but the batch dim before the linear layer;
                # ``reshape`` also handles non-contiguous tensors.
                x = x.reshape(x.size(0), -1)
            x = module(x)
            print("moudle_name", name)
            if name in targets:
                return x
        return None
class VisualSingleFeature():
    """Save channel 0 of a feature map as a grayscale image.

    Args:
        extract_features: feature tensor indexed as ``[sample, channel, h, w]``,
            e.g. ``torch.Size([1, 128, 112, 112])``.
        save_path: output image path passed to ``cv2.imwrite``.
    """

    def __init__(self, extract_features, save_path):
        self.extract_features = extract_features
        self.save_path = save_path

    def get_single_feature(self):
        """Return channel 0 of the first sample as an (H, W) tensor."""
        print(self.extract_features.shape)  # ex. torch.Size([1, 128, 112, 112])
        extract_feature = self.extract_features[:, 0, :, :]
        print(extract_feature.shape)  # ex. torch.Size([1, 112, 112])
        # Take the first sample; the former ``view(H, W)`` only worked when
        # the batch contained exactly one sample.
        extract_feature = extract_feature[0]
        print(extract_feature.shape)  # ex. torch.Size([112, 112])
        return extract_feature

    def save_feature_to_img(self):
        """Squash the feature map into [0, 255] with a sigmoid and write it.

        Raises:
            OSError: if ``cv2.imwrite`` reports that the image was not written
                (e.g. the target directory does not exist).
        """
        # detach/cpu so this also works for tensors that require grad or live on a GPU
        extract_feature = self.get_single_feature().detach().cpu().numpy()
        # sigmoid maps activations into (0, 1)
        extract_feature = 1.0 / (1 + np.exp(-1 * extract_feature))
        # scale to [0, 255]
        extract_feature = np.round(extract_feature * 255)
        print(extract_feature[0])
        # write an explicit 8-bit image; imwrite signals failure by returning False
        if not cv2.imwrite(self.save_path, extract_feature.astype(np.uint8)):
            raise OSError('failed to write image to {}'.format(self.save_path))
def single_image_sample(img_path='./snorlax.png'):
    """Load an image and return it as a preprocessed batch of one.

    Args:
        img_path: path of the image to load (defaults to the sample image).

    Returns:
        ``simple_transform`` applied to the RGB image, with a leading batch
        dimension added, e.g. ``torch.Size([1, 3, 224, 224])``.
    """
    # Context manager closes the underlying file once the RGB copy is made.
    with Image.open(img_path) as img:
        input_img = img.convert('RGB')  # read the image as RGB
    input_tensor = simple_transform(input_img)
    print(input_tensor.shape)  # torch.Size([3, 224, 224])
    x = input_tensor.unsqueeze(0)  # add the batch dimension
    print(x.shape)  # torch.Size([1, 3, 224, 224])
    return x
snorlax.png
# test VGG16: save channel 0 of every layer of vgg16().features
import os

import torch

x = single_image_sample()
# Load the pretrained weights once rather than on every loop iteration.
pretrained_module = models.vgg16(pretrained=True).features.eval()
os.makedirs('./VGG16', exist_ok=True)  # previously had to be created by hand
with torch.no_grad():  # inference only: no autograd graph needed
    for target_layer in range(len(pretrained_module)):  # 31 children for VGG16
        myexactor = FeatureExtractor(submodule=pretrained_module, extracted_layer=str(target_layer))
        target_features = myexactor(x)
        savepath = './VGG16/layer_{}.jpg'.format(target_layer)
        VisualSingleFeature(target_features, savepath).save_feature_to_img()
# test resnet50 sequential: save channel 0 of each top-level ResNet50 stage
import os

import torch

x = single_image_sample()
# Load once, and switch to eval mode so BatchNorm uses its stored running
# statistics instead of this single image's (and does not update them).
pretrained_module = models.resnet50(pretrained=True).eval()
os.makedirs('./Resnet50', exist_ok=True)
with torch.no_grad():
    for target_sequential in ['conv1', 'bn1', 'relu', 'maxpool', 'layer1', 'layer2', 'layer3', 'layer4', 'avgpool']:
        myexactor = FeatureExtractor(submodule=pretrained_module, extracted_layer=target_sequential)
        target_features = myexactor(x)
        savepath = './Resnet50/{}.jpg'.format(target_sequential)
        VisualSingleFeature(target_features, savepath).save_feature_to_img()
# test resnet50 layer1: save channel 0 after each Bottleneck of layer1
import os

import torch

x = single_image_sample()
# eval(): BatchNorm must use its running statistics for inference.
pretrained_model = models.resnet50(pretrained=True).eval()
pretrained_module = pretrained_model.layer1
os.makedirs('./Resnet50/layer1', exist_ok=True)
with torch.no_grad():
    # Run the stem (conv1 -> bn1 -> relu -> maxpool) to obtain layer1's input.
    for pre_sequential in ['conv1', 'bn1', 'relu', 'maxpool']:
        pre_module = getattr(pretrained_model, pre_sequential)
        x = pre_module(x)
    for target_Bottleneck_index in range(len(pretrained_module)):
        myexactor = FeatureExtractor(submodule=pretrained_module, extracted_layer=str(target_Bottleneck_index))
        target_features = myexactor(x)
        savepath = './Resnet50/layer1/{}.jpg'.format(target_Bottleneck_index)
        VisualSingleFeature(target_features, savepath).save_feature_to_img()
# test resnet50 layer1 Bottleneck0: save channel 0 after each child of layer1[0]
import os

import torch

x = single_image_sample()
# eval(): BatchNorm must use its running statistics for inference.
pretrained_model = models.resnet50(pretrained=True).eval()
pretrained_module = pretrained_model.layer1._modules['0']
os.makedirs('./Resnet50/layer1/Bottleneck0', exist_ok=True)
# NOTE: FeatureExtractor replays the Bottleneck's children one after another
# in registration order, so no skip connection is ever added and 'relu' runs
# only at its position in that order (after bn3, per the list below). Treat
# these maps as an approximation of the block's real intermediate activations.
with torch.no_grad():
    # Run the stem (conv1 -> bn1 -> relu -> maxpool) to obtain layer1's input.
    for pre_sequential in ['conv1', 'bn1', 'relu', 'maxpool']:
        pre_module = getattr(pretrained_model, pre_sequential)
        x = pre_module(x)
    for target_sequential in ['conv1', 'bn1', 'conv2', 'bn2', 'conv3', 'bn3', 'relu']:
        myexactor = FeatureExtractor(submodule=pretrained_module, extracted_layer=target_sequential)
        target_features = myexactor(x)
        savepath = './Resnet50/layer1/Bottleneck0/{}.jpg'.format(target_sequential)
        VisualSingleFeature(target_features, savepath).save_feature_to_img()
优化:
Visualize all feature-map channels, not only a single one.
# [基于Pytorch的特征图提取](https://blog.csdn.net/ZOUZHEN_ID/article/details/84025943)
# Visualize the feature outputs: one subplot per channel on a 6 x 6 grid.
# Assumes ``plt`` (matplotlib.pyplot), ``target_features`` (indexed as
# [sample, channel, h, w]) and ``feature_channel_number`` are defined earlier.
features = target_features.data.numpy()[0]  # convert once, outside the loop
# A 6 x 6 grid holds at most 36 subplots; plt.subplot raises beyond that.
for i in range(min(feature_channel_number, 36)):
    ax = plt.subplot(6, 6, i + 1)
    ax.set_title('Feature {}'.format(i))
    ax.axis('off')
    plt.imshow(features[i, :, :], cmap='jet')
plt.show()  # show the whole grid once, after all subplots are drawn
上述中间特征提取方法的“致命”缺点:Sequential只能表示简单的顺序模型结构(连U-Net的跳跃连接都表示不了),对复杂模型无能为力,这部分最好用register_forward_hook来做! ⤵️
from torchvision import transforms
import numpy as np
# Preprocessing pipeline: resize to 224x224, convert to a float tensor, then
# normalize each channel with the mean/std values below (the usual ImageNet
# statistics).
simple_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    # PIL image (H x W x C, values 0-255) -> tensor (C x H x W) in [0, 1];
    # the scaling is a plain division by 255.
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])
def get_single_feature(extract_features):
    """Return channel 0 of the first sample of a feature map.

    Args:
        extract_features: tensor indexed as ``[sample, channel, h, w]``.

    Returns:
        The ``(h, w)`` slice ``extract_features[0, 0]``.
    """
    print(extract_features.shape)
    extract_feature = extract_features[:, 0, :, :]
    print(extract_feature.shape)
    # Take the first sample; the former ``view(H, W)`` only worked when the
    # batch contained exactly one sample.
    extract_feature = extract_feature[0]
    print(extract_feature.shape)
    return extract_feature
def save_feature_to_img(extract_features, save_path):
    """Write channel 0 of a feature map to ``save_path`` as an 8-bit image.

    Activations are squashed into (0, 1) with a sigmoid and scaled to [0, 255].

    Raises:
        OSError: if ``cv2.imwrite`` reports that nothing was written
            (e.g. the target directory does not exist).
    """
    import cv2

    # detach/cpu so this also works for tensors that require grad or live on a GPU
    extract_feature = get_single_feature(extract_features).detach().cpu().numpy()
    extract_feature = 1.0 / (1 + np.exp(-1 * extract_feature))
    extract_feature = np.round(extract_feature * 255)
    print(extract_feature[0])
    # imwrite returns False instead of raising, so check it before reporting success.
    if not cv2.imwrite(save_path, extract_feature.astype(np.uint8)):
        raise OSError('failed to write image to {}'.format(save_path))
    print('{} saved.'.format(save_path))
def hook(module, input, output):
    """Forward hook that saves the hooked layer's output as an image.

    Registered via ``register_forward_hook``; ``module`` and ``input`` are
    unused. The destination is the module-level ``save_path`` variable.
    """
    layer_output = output.data
    save_feature_to_img(layer_output, save_path)
if __name__ == '__main__':
    import os

    import torch
    import torchvision.models as models
    from PIL import Image

    img_path = './snorlax.png'
    with Image.open(img_path) as img:  # closes the file after conversion
        input_img = img.convert('RGB')
    input_tensor = simple_transform(input_img)
    x = input_tensor[np.newaxis, :, :, :]
    target_layer = 5
    # Module-level name: read by ``hook`` when it fires.
    save_path = './VGG16/layer_{}.png'.format(str(target_layer))
    os.makedirs(os.path.dirname(save_path), exist_ok=True)
    net = models.vgg16(pretrained=True).eval()
    # ``_modules`` is keyed by strings ("0", "1", ...), so ``_modules[5]``
    # raises KeyError; index the Sequential directly instead.
    handle = net.features[target_layer].register_forward_hook(hook)
    try:
        with torch.no_grad():  # inference only
            _ = net(x)
    finally:
        handle.remove()  # detach the hook even if the forward pass fails