24、ViP模块
论文《VISION PERMUTATOR: A PERMUTABLE MLP-LIKE ARCHITECTURE FOR VISUAL RECOGNITION》
1、作用
论文提出的Vision Permutator是一种简单、数据高效的类MLP(多层感知机)架构,用于视觉识别。不同于其他MLP类模型,它通过线性投影分别对特征表示在高度和宽度维度进行编码,能够保留2D特征表示中的位置信息,有效捕获沿一个空间方向的长距离依赖关系,同时保留另一方向上的精确位置信息。
2、机制
1、视觉置换器:
Vision Permutator采用与视觉变换器类似的令牌化操作,将输入图像均匀划分为小块,并通过线性投影将它们映射为令牌嵌入。随后,这些令牌嵌入被送入一系列Permutator块中进行特征编码。
2、Permute-MLP:
Permutator块包含一个用于空间信息编码的Permute-MLP和一个用于通道信息混合的Channel-MLP。Permute-MLP通过独立处理令牌表示沿高度和宽度的维度,生成具有特定方向信息的令牌,这对于视觉识别至关重要。
3、加权Permute-MLP:
在简单的Permute-MLP基础上,引入加权Permute-MLP来重新校准不同分支的重要性,进一步提高模型性能。
3、独特优势
1、空间信息编码:
Vision Permutator通过在高度和宽度维度上分别对特征进行编码,相比于其他将两个空间维度混合为一个进行处理的MLP类模型,能够更有效地保留空间位置信息,从而提高模型对图像中对象的识别能力。
2、性能提升:
实验表明,即使在不使用额外大规模训练数据的情况下,Vision Permutator也能达到81.5%的ImageNet顶级-1准确率,并且仅使用25M可学习参数,这比大多数同等大小模型的CNNs和视觉变换器都要好。
3、模型高效:
Vision Permutator的结构简单、数据高效,在确保高准确性的同时提高了训练和推理速度,展现了MLP类模型在视觉识别任务中的潜力。
4、代码
import torch
from torch import nn
class MLP(nn.Module):
def __init__(self, in_features, hidden_features, out_features, act_layer=nn.GELU, drop=0.1):
super().__init__()
# 第一层全连接层
self.fc1 = nn.Linear(in_features, hidden_features)
# 激活函数
self.act = act_layer()
# 第二层全连接层
self.fc2 = nn.Linear(hidden_features, out_features)
# Dropout层
self.drop = nn.Dropout(drop)
def forward(self, x):
# 顺序通过第一层全连接层、激活函数、Dropout、第二层全连接层、Dropout
return self.drop(self.fc2(self.drop(self.act(self.fc1(x)))))
class WeightedPermuteMLP(nn.Module):
def __init__(self, dim, seg_dim=8, qkv_bias=False, proj_drop=0.):
super().__init__()
# 分段维度,用于在特定维度上分段处理特征
self.seg_dim = seg_dim
# 定义对通道C、高度H、宽度W的MLP处理层
self.mlp_c = nn.Linear(dim, dim, bias=qkv_bias)
self.mlp_h = nn.Linear(dim, dim, bias=qkv_bias)
self.mlp_w = nn.Linear(dim, dim, bias=qkv_bias)
# 重置权重的MLP层
self.reweighting = MLP(dim, dim // 4, dim * 3)
# 最终投影层
self.proj = nn.Linear(dim, dim)
self.proj_drop = nn.Dropout(proj_drop)
def forward(self, x):
B, H, W, C = x.shape
# 通道维度的处理
c_embed = self.mlp_c(x)
# 高度维度的处理
S = C // self.seg_dim
h_embed = x.reshape(B, H, W, self.seg_dim, S).permute(0, 3, 2, 1, 4).reshape(B, self.seg_dim, W, H * S)
h_embed = self.mlp_h(h_embed).reshape(B, self.seg_dim, W, H, S).permute(0, 3, 2, 1, 4).reshape(B, H, W, C)
# 宽度维度的处理
w_embed = x.reshape(B, H, W, self.seg_dim, S).permute(0, 3, 1, 2, 4).reshape(B, self.seg_dim, H, W * S)
w_embed = self.mlp_w(w_embed).reshape(B, self.seg_dim, H, W, S).permute(0, 2, 3, 1, 4).reshape(B, H, W, C)
# 计算三个维度的权重并应用softmax进行归一化
weight = (c_embed + h_embed + w_embed).permute(0, 3, 1, 2).flatten(2).mean(2)
weight = self.reweighting(weight).reshape(B, C, 3).permute(2, 0, 1).softmax(0).unsqueeze(2).unsqueeze(2)
# 加权融合处理后的特征
x = c_embed * weight[0] + w_embed * weight[1] + h_embed * weight[2]
# 应用投影层和Dropout
x = self.proj_drop(self.proj(x))
return x
if __name__ == '__main__':
input = torch.randn(64, 8, 8, 512) # 模拟输入数据
seg_dim = 8 # 定义分段维度
vip = WeightedPermuteMLP(512, seg_dim) # 初始化模型
out = vip(input) # 前向传播
print(out.shape)