import torch
import cv2
import numpy as np
import matplotlib.pyplot as plt
from mmseg.apis import inference_model, init_model
import mmcv
config_file = '你的配置文件.py'
checkpoint_file = '你的权重文件.pth'
# 1. 初始化模型
model = init_model(config_file, checkpoint_file, device='cuda:0')
# ==================== 核心修改区 开始 ====================
# 2. 准备一个全局变量,用来“接住”特征图
captured_features = None
# 3. 定义 Hook 函数
def hook_fn(module, input, output):
global captured_features
# output 就是这一层输出的 Tensor,我们把它转到 CPU 上存起来
captured_features = output.detach().cpu()
# 4. 找到你的“细节点支”并挂上 Hook (🚨 这里最关键!)
# 你需要把下面这行代码里的 `model.backbone.spatial_branch_layers`
#或`modules_dict['backbone.spatial_branch_layers']` 换成你网络中某一层的真实变量名。
# 如果你不知道它叫什么,可以加一行 print(model) 在终端里找一下层级结构。
# 或者去训练日志里找,比如backbone.spatial_branch_layers.2.0.downsample.1.bias
# target_layer = model.backbone.spatial_branch_layers[2]
# 如果你非要精确到 2 里面的 0:
# target_layer = model.backbone.spatial_branch_layers[2][0]
# 推荐下面这种更方便
# 将模型所有的层转化为一个字典
modules_dict = dict(model.named_modules())
# 通常是整个block,如下
target_layer = modules_dict['backbone.spatial_branch_layers.2']
hook_handle = target_layer.register_forward_hook(hook_fn)
# ==================== 核心修改区 结束 ====================
# 5. 测试单张图像 (这一步会自动触发 Hook)
img_path = 'demo/demo.png'
result = inference_model(model, img_path)
# 6. 用完 Hook 记得拆掉,防止内存泄漏
hook_handle.remove()
# 7. 开始特征图处理与可视化
if captured_features is not None:
print(f"成功截获特征图!形状为: {captured_features.shape}")
# 获取单张图片的特征 (去掉 Batch 维度),变成 (C, H, W)
features = captured_features[0]
# 策略:在通道维度(dim=0)求绝对值的平均,融合成一张 (H, W) 的 2D 图
agg_map = torch.mean(torch.abs(features), dim=0).numpy()
# Min-Max 归一化到 [0, 1] 区间,为了画图更亮
agg_map = (agg_map - np.min(agg_map)) / (np.max(agg_map) - np.min(agg_map) + 1e-8)
# 画出高大上的伪彩色特征图
plt.figure(figsize=(8, 8))
# 'viridis' 是一种从深蓝到黄绿色的学术渐变色,非常适合展示高频响应
# 如果你想展示纯粹的线框感,可以把 'viridis' 改成 'gray'
plt.imshow(agg_map, cmap='viridis', interpolation='nearest')
plt.axis('off')
#plt.tight_layout()
# 核心去白边代码开始
plt.margins(0, 0) # 将内部空白设为 0
plt.subplots_adjust(top=1, bottom=0, right=1, left=0, hspace=0, wspace=0) # 把绘图区拉伸到填满整个画布
# 核心去白边代码结束
# 保存或展示
plt.savefig('detail_feature_map.jpg', dpi=300, bbox_inches='tight')
plt.show()
else:
print("未能截获特征图,请检查 target_layer 的名字是否正确!")
增加亮度版
import torch
import cv2
import numpy as np
from mmseg.apis import inference_model, init_model
import mmcv
config_file = 'demo.py'
checkpoint_file = 'demo.pth'
# 1. 初始化模型
model = init_model(config_file, checkpoint_file, device='cuda:0')
# ==================== 核心修改区 开始 ====================
# 2. 准备一个全局变量,用来“接住”特征图
captured_features = None
# 3. 定义 Hook 函数
def hook_fn(module, input, output):
global captured_features
captured_features = output.detach().cpu()
# 4. 找到你的层并挂上 Hook
modules_dict = dict(model.named_modules())
# 这里可以随时切换你要提取的分支
target_layer = modules_dict['backbone.demo'] # 修改
hook_handle = target_layer.register_forward_hook(hook_fn)
# ==================== 核心修改区 结束 ====================
# 5. 测试单张图像
img_path = 'demo.jpg'
result = inference_model(model, img_path)
# 6. 用完 Hook 记得拆掉
hook_handle.remove()
# 7. 开始特征图处理与可视化 (✨ 全新升级区 ✨)
if captured_features is not None:
print(f"成功截获特征图!形状为: {captured_features.shape}")
# 获取单张图片的特征 (C, H, W)
features = captured_features[0]
# 🌟 升级 1:从 mean 换成 max,大幅强化稀疏的高频划痕响应
agg_map, _ = torch.max(torch.abs(features), dim=0)
agg_map = agg_map.numpy()
# 🌟 升级 2:百分位截断法 (科学提亮)
# 99.9 表示忽略最亮的 0.1% 的极端异常像素,把剩下的特征动态范围拉满
clip_val = np.percentile(agg_map, 99.9)
agg_map = np.clip(agg_map, 0, clip_val)
# 正常的 Min-Max 归一化到 [0, 1]
agg_map = (agg_map - np.min(agg_map)) / (np.max(agg_map) - np.min(agg_map) + 1e-8)
# 🌟 升级 3:OpenCV 纯净保存法 (绝对无白边,尺寸严丝合缝)
# 转为 8-bit [0, 255]
heatmap_uint8 = np.uint8(255 * agg_map)
# 映射伪彩色 (VIRIDIS 是深蓝到黄绿,学术界标配)
colored_heatmap = cv2.applyColorMap(heatmap_uint8, cv2.COLORMAP_VIRIDIS)
# 获取原图尺寸,使用 Nearest(最近邻) 插值放大,保持真实的马赛克感/锐利感
original_img = cv2.imread(img_path)
h, w = original_img.shape[:2]
colored_heatmap = cv2.resize(colored_heatmap, (w, h), interpolation=cv2.INTER_NEAREST)
# 保存图像
save_path = 'feature_map_pure.jpg'
cv2.imwrite(save_path, colored_heatmap)
print(f"✨ 完美无白边的特征图已保存至: {save_path}")
else:
print("未能截获特征图,请检查 target_layer 的名字是否正确!")