mmseg画特征图

import torch
import cv2
import numpy as np
import matplotlib.pyplot as plt
from mmseg.apis import inference_model, init_model
import mmcv

config_file = '你的配置文件.py'
checkpoint_file = '你的权重文件.pth'

# 1. 初始化模型
model = init_model(config_file, checkpoint_file, device='cuda:0')

# ==================== 核心修改区 开始 ====================

# 2. 准备一个全局变量,用来“接住”特征图
captured_features = None

# 3. 定义 Hook 函数
def hook_fn(module, input, output):
    global captured_features
    # output 就是这一层输出的 Tensor,我们把它转到 CPU 上存起来
    captured_features = output.detach().cpu()

# 4. 找到你的“细节点支”并挂上 Hook (🚨 这里最关键!)
# 你需要把下面这行代码里的 `model.backbone.spatial_branch_layers`
#或`modules_dict['backbone.spatial_branch_layers']` 换成你网络中某一层的真实变量名。
# 如果你不知道它叫什么,可以加一行 print(model) 在终端里找一下层级结构。
# 或者去训练日志里找,比如backbone.spatial_branch_layers.2.0.downsample.1.bias
# target_layer = model.backbone.spatial_branch_layers[2]
# 如果你非要精确到 2 里面的 0:
# target_layer = model.backbone.spatial_branch_layers[2][0]

# 推荐下面这种更方便
# 将模型所有的层转化为一个字典
modules_dict = dict(model.named_modules())

# 通常是整个block,如下
target_layer = modules_dict['backbone.spatial_branch_layers.2']

hook_handle = target_layer.register_forward_hook(hook_fn)

# ==================== 核心修改区 结束 ====================

# 5. 测试单张图像 (这一步会自动触发 Hook)
img_path = 'demo/demo.png' 
result = inference_model(model, img_path)

# 6. 用完 Hook 记得拆掉,防止内存泄漏
hook_handle.remove()

# 7. 开始特征图处理与可视化
if captured_features is not None:
    print(f"成功截获特征图!形状为: {captured_features.shape}")
    
    # 获取单张图片的特征 (去掉 Batch 维度),变成 (C, H, W)
    features = captured_features[0]
    
    # 策略:在通道维度(dim=0)求绝对值的平均,融合成一张 (H, W) 的 2D 图
    agg_map = torch.mean(torch.abs(features), dim=0).numpy()
    
    # Min-Max 归一化到 [0, 1] 区间,为了画图更亮
    agg_map = (agg_map - np.min(agg_map)) / (np.max(agg_map) - np.min(agg_map) + 1e-8)
    
    # 画出高大上的伪彩色特征图 
    plt.figure(figsize=(8, 8))
    # 'viridis' 是一种从深蓝到黄绿色的学术渐变色,非常适合展示高频响应
    # 如果你想展示纯粹的线框感,可以把 'viridis' 改成 'gray'
    plt.imshow(agg_map, cmap='viridis', interpolation='nearest')
    plt.axis('off')
    #plt.tight_layout()
    # 核心去白边代码开始
    plt.margins(0, 0) # 将内部空白设为 0
    plt.subplots_adjust(top=1, bottom=0, right=1, left=0, hspace=0, wspace=0) # 把绘图区拉伸到填满整个画布
    # 核心去白边代码结束
    
    # 保存或展示
    plt.savefig('detail_feature_map.jpg', dpi=300, bbox_inches='tight')
    plt.show()
else:
    print("未能截获特征图,请检查 target_layer 的名字是否正确!")

增加亮度版

import torch
import cv2
import numpy as np
from mmseg.apis import inference_model, init_model
import mmcv

config_file = 'demo.py'
checkpoint_file = 'demo.pth'

# 1. 初始化模型
model = init_model(config_file, checkpoint_file, device='cuda:0')

# ==================== 核心修改区 开始 ====================

# 2. 准备一个全局变量,用来“接住”特征图
captured_features = None

# 3. 定义 Hook 函数
def hook_fn(module, input, output):
    global captured_features
    captured_features = output.detach().cpu()

# 4. 找到你的层并挂上 Hook
modules_dict = dict(model.named_modules())

# 这里可以随时切换你要提取的分支

target_layer = modules_dict['backbone.demo'] # 修改

hook_handle = target_layer.register_forward_hook(hook_fn)

# ==================== 核心修改区 结束 ====================

# 5. 测试单张图像
img_path = 'demo.jpg' 
result = inference_model(model, img_path)

# 6. 用完 Hook 记得拆掉
hook_handle.remove()

# 7. 开始特征图处理与可视化 (✨ 全新升级区 ✨)
if captured_features is not None:
    print(f"成功截获特征图!形状为: {captured_features.shape}")
    
    # 获取单张图片的特征 (C, H, W)
    features = captured_features[0]
    
    # 🌟 升级 1:从 mean 换成 max,大幅强化稀疏的高频划痕响应
    agg_map, _ = torch.max(torch.abs(features), dim=0)
    agg_map = agg_map.numpy()
    
    # 🌟 升级 2:百分位截断法 (科学提亮)
    # 99.9 表示忽略最亮的 0.1% 的极端异常像素,把剩下的特征动态范围拉满
    clip_val = np.percentile(agg_map, 99.9) 
    agg_map = np.clip(agg_map, 0, clip_val)
    
    # 正常的 Min-Max 归一化到 [0, 1]
    agg_map = (agg_map - np.min(agg_map)) / (np.max(agg_map) - np.min(agg_map) + 1e-8)
    
    # 🌟 升级 3:OpenCV 纯净保存法 (绝对无白边,尺寸严丝合缝)
    # 转为 8-bit [0, 255]
    heatmap_uint8 = np.uint8(255 * agg_map)
    
    # 映射伪彩色 (VIRIDIS 是深蓝到黄绿,学术界标配)
    colored_heatmap = cv2.applyColorMap(heatmap_uint8, cv2.COLORMAP_VIRIDIS)
    
    # 获取原图尺寸,使用 Nearest(最近邻) 插值放大,保持真实的马赛克感/锐利感
    original_img = cv2.imread(img_path)
    h, w = original_img.shape[:2]
    colored_heatmap = cv2.resize(colored_heatmap, (w, h), interpolation=cv2.INTER_NEAREST)
    
    # 保存图像
    save_path = 'feature_map_pure.jpg'
    cv2.imwrite(save_path, colored_heatmap)
    print(f"✨ 完美无白边的特征图已保存至: {save_path}")

else:
    print("未能截获特征图,请检查 target_layer 的名字是否正确!")
最后编辑于
©著作权归作者所有,转载或内容合作请联系作者
【社区内容提示】社区部分内容疑似由AI辅助生成,浏览时请结合常识与多方信息审慎甄别。
平台声明:文章内容(如有图片或视频亦包括在内)由作者上传并发布,文章内容仅代表作者本人观点,简书系信息发布平台,仅提供信息存储服务。

友情链接更多精彩内容