Referring to the official Swin Transformer source code, we made the following modifications:
- Added the DropPath (stochastic depth) strategy
- Added a norm layer on the output of each stage
- Added a norm layer to each PatchMerging layer
- In the original code every block (and every head within it) uses its own, non-shared relative position bias; here we use a single shared one (see the sketch right after this list)
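To make the last point concrete, here is a minimal, standalone sketch of how a shared relative position bias can be looked up (the variable names and `M = 7` are illustrative; it mirrors the `PosBias` module in the full code below): a single `(2M-1)^2 x 1` embedding table is indexed with pairwise relative coordinates, producing one `M^2 x M^2` bias that every block and every head reuses.

```python
import torch
import torch.nn as nn

M = 7                                                       # window size (illustrative)
table = nn.Embedding((2 * M - 1) ** 2, 1)                   # one shared bias table
ys, xs = torch.meshgrid(torch.arange(M), torch.arange(M))   # coordinates inside a window
coords = torch.stack((ys.flatten(), xs.flatten()), dim=1)   # M^2 x 2
rel = coords.unsqueeze(1) - coords.unsqueeze(0)             # pairwise relative offsets
idx = (rel[..., 0] + M - 1) * (2 * M - 1) + (rel[..., 1] + M - 1)
bias = table(idx.long()).squeeze(-1)                        # M^2 x M^2, reused everywhere
print(bias.shape)                                           # torch.Size([49, 49])
```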
```python
# -*- coding: utf-8 -*-
import torch
import torch.nn as nn
from torch.nn import functional as F
import torch.utils.checkpoint as cp
from mmcv.cnn import (constant_init, kaiming_init)
from mmcv.runner import load_checkpoint
from torch.nn.modules.batchnorm import _BatchNorm
from timm.models.layers import DropPath, trunc_normal_
from mmdet.utils import get_root_logger
from ..builder import BACKBONES
class Mlp(nn.Module):
    expansion = 4
def __init__(self, in_feature, hidden_feature=None, out_feature=None, act_layer=nn.GELU, drop=0.):
super().__init__()
out_feature = out_feature or in_feature
        hidden_feature = hidden_feature or in_feature * self.expansion
self.fc1 = nn.Conv2d(in_feature, hidden_feature, 1, 1, 0)
self.act = act_layer()
self.fc2 = nn.Conv2d(hidden_feature, out_feature, 1, 1, 0)
self.drop = nn.Dropout(drop)
def forward(self, x):
return self.drop(self.fc2(self.drop(self.act(self.fc1(x)))))
class PatchEmbedding(nn.Module):
def __init__(self, in_feature, out_feature, kernel_size=4, norm_layer=nn.LayerNorm, drop=0.):
super().__init__()
self.patch_size = kernel_size
self.fc = nn.Conv2d(in_feature, out_feature, kernel_size=kernel_size, stride=kernel_size, padding=0)
self.drop = nn.Dropout(drop)
if norm_layer is not None:
self.norm = norm_layer(out_feature)
else:
self.norm = None
def forward(self, x):
_, _, H, W = x.size()
if W % self.patch_size != 0:
x = F.pad(x, (0, self.patch_size - W % self.patch_size))
if H % self.patch_size != 0:
x = F.pad(x, (0, 0, 0, self.patch_size - H % self.patch_size))
x = self.drop(self.fc(x))
if self.norm is not None:
x = self.norm(x.permute(0, 2, 3, 1).contiguous()).permute(0, 3, 1, 2).contiguous()
return x
class PatchMerging(nn.Module):
def __init__(self, in_feature, out_feature, kernel_size=2, norm_layer=nn.LayerNorm, drop=0.):
super().__init__()
self.fc = nn.Linear(in_feature* kernel_size**2, out_feature, bias=False)
self.kernel_size = kernel_size
self.norm = norm_layer(in_feature* kernel_size**2)
self.drop = nn.Dropout(drop)
def forward(self, x):
B, C, H, W = x.size()
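        # gather each kernel_size x kernel_size neighborhood into the channel dimension, then normalize and project it to out_feature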
x = x.view(B, C, H//self.kernel_size, self.kernel_size, W//self.kernel_size, self.kernel_size).permute(0, 2, 4, 1, 3, 5).contiguous()
x = self.drop(self.fc(self.norm(torch.flatten(x, 3))))
x = x.permute(0, 3, 1, 2).contiguous()
return x
class WMSA(nn.Module):
def __init__(self, dim, head_dim=32, M=7, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):
super().__init__()
self.dim = dim
self.n_heads = dim // head_dim
self.scale = qk_scale or head_dim ** -0.5
self.head_dim = head_dim
self.M = M
self.q = nn.Conv2d(dim, dim, kernel_size=1, stride=1, padding=0, bias=qkv_bias)
self.kv = nn.Conv2d(dim, dim*2, kernel_size=1, stride=1, padding=0, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Conv2d(dim, dim, 1, 1, 0)
self.proj_drop = nn.Dropout(proj_drop)
def forward(self, x, pos_bias, shift, masks):
"""
:param x: tensor, BCHW
:param pos_bias: tensor, M^2 x M^2
:param shift: int, 0 or 1
:param masks: dict{"top":tensor(M^2 x M^2), "left":tensor, "topleft":tensor}
:return: tensor, BCHW
"""
B, C, H, W = x.size()
wn_h, wn_w = H//self.M, W//self.M
q = self.q(x).view(B, self.n_heads, self.head_dim, wn_h, self.M, wn_w, self.M).permute(0, 3, 5, 1,2,4,6).contiguous()
q = q.view(B, wn_h, wn_w, self.n_heads, self.head_dim, -1) # B x wh x ww x n_head x head_dim x M^2
kv = self.kv(x).view(B, 2, self.n_heads, self.head_dim, wn_h, self.M, wn_w, self.M).permute(1, 0, 4, 6, 2, 3, 5, 7).contiguous() # 2 x B x wh x wn x n_head x head_dim x M^2
kv = kv.view(2, B, wn_h, wn_w, self.n_heads, self.head_dim, -1)
k, v = kv[0], kv[1]
attn = ((q.transpose(-2, -1).contiguous())@k) * self.scale + pos_bias.expand(1, 1, 1, 1, self.M**2, self.M**2) # B x wh x wn x n_head x M^2 x M^2
        if shift == 1:
            # after the cyclic shift, the last row/column of windows mixes original and
            # wrapped-around content, so attention across these regions is masked out
            attn[:, :-1, -1] += masks["left"].expand(1, 1, 1, self.M**2, self.M**2).to(x.device)
            attn[:, -1, :-1] += masks["top"].expand(1, 1, 1, self.M**2, self.M**2).to(x.device)
            attn[:, -1, -1] += masks["topleft"].expand(1, 1, self.M**2, self.M**2).to(x.device)
attn = self.attn_drop(F.softmax(attn, dim=-1))
x = (v @ (attn.transpose(-2, -1).contiguous())).view(B, wn_h, wn_w, -1, self.M, self.M) # B x wh x wn x C x M x M
x = x.permute(0, 3, 1, 4, 2, 5).contiguous().view(B, C, H, W)
x = self.proj_drop(self.proj(x))
return x
class SwinTransformerBlock(nn.Module):
def __init__(self, in_feature, head_dim=32, M=7, shift=0, norm_layer=nn.LayerNorm, drop_path=0.):
super().__init__()
self.norm1 = norm_layer(in_feature)
self.norm2 = norm_layer(in_feature)
self.multi_attn = WMSA(in_feature, head_dim, M)
self.mlp = Mlp(in_feature)
self.M = M
self.shift = shift
self.drop_path = DropPath(drop_path) if drop_path>0. else nn.Identity()
def forward(self, x, pos_bias, masks):
b, c, oh, ow = x.size()
# cyclic shifted window
        shift_stride = self.M // 2
        # pad H and W up to multiples of the window size M so the windows tile exactly
        padding_bottom = 0 if oh % self.M == 0 else (oh // self.M + 1) * self.M - oh
        padding_right = 0 if ow % self.M == 0 else (ow // self.M + 1) * self.M - ow
x = F.pad(x, (0, padding_right, 0, padding_bottom), 'constant', 0)
h, w = x.size(-2), x.size(-1)
if self.shift == 1: # top-left
x = x.roll((-shift_stride, -shift_stride), (-1, -2))
norm1 = self.norm1(x.permute(0, 2, 3, 1).contiguous()).permute(0, 3, 1, 2).contiguous() # B x C x H x W
z1 = self.multi_attn(norm1, pos_bias, self.shift, masks)
z2 = x + self.drop_path(z1)
norm2 = self.norm2(z2.permute(0, 2, 3, 1).contiguous()).permute(0, 3, 1, 2).contiguous()
z2 = z2 + self.drop_path(self.mlp(norm2))
if self.shift ==1:
z2 = torch.roll(z2, (shift_stride, shift_stride), (-1, -2))
return z2[..., :oh, :ow]
class Stage(nn.Module):
def __init__(self, in_feature, out_feature, num_layers, patch_norm=True, patch_merge=PatchEmbedding, M=7, head_dim=32, stride=2, drop_path=(0.2, 0.2)):
super().__init__()
self.downsample = patch_merge(in_feature, out_feature, kernel_size=stride, norm_layer=nn.LayerNorm if patch_norm else None)
self.blocks = nn.ModuleList()
for k in range(num_layers//2):
self.blocks.append(
SwinTransformerBlock(out_feature, head_dim, M, 0, drop_path=drop_path[k*2])
)
self.blocks.append(
SwinTransformerBlock(out_feature, head_dim, M, 1, drop_path=drop_path[2*k+1])
)
def forward(self, x, pos_bias, masks):
x = self.downsample(x)
for m in self.blocks:
x = m(x, pos_bias, masks)
return x
class PosBias(nn.Module):
def __init__(self, M):
super().__init__()
self.M = M
self.emb_dict = nn.Embedding((2*M-1)**2, 1)
def forward(self, device):
x, y = torch.meshgrid(torch.arange(self.M), torch.arange(self.M))
indices = torch.stack((x.flatten(), y.flatten()), dim=1)
indices = indices.unsqueeze(1) - indices.unsqueeze(0)
indices = (indices[..., 0] +self.M-1) * (2*self.M-1) + (indices[..., 1] + self.M-1)
indices = indices.long().to(device)
return self.emb_dict(indices).squeeze(-1) # M**2 x M**2
@BACKBONES.register_module()
class SwinTransformer(nn.Module):
"""SwinTransformer backbone
Args:
model_type (str): type of the swin transformer type, from {'T', 'S', 'B', 'L'}
out_indices (Sequence [int]): Output from which stages.
M (int): size of the window.
head_dim (int): dim of each head in MSA.
        patch_norm (bool): Whether to add a norm layer after patch embedding.
            Default: True.
        frozen_stages (int): Number of leading stages whose parameters are
            frozen (requires_grad=False). -1 means not freezing any parameters.
        drop_path_rate (float): Stochastic depth rate used by DropPath.
            Default: 0.2.
Example:
>>> from mmdet.models import SwinTransformer
>>> import torch
>>> self = SwinTransformer(model_type="T")
>>> self.eval()
>>> inputs = torch.rand(1, 3, 32, 32)
>>> level_outputs = self.forward(inputs)
>>> for level_out in level_outputs:
... print(tuple(level_out.shape))
(1, 96, 8, 8)
(1, 192, 4, 4)
(1, 384, 2, 2)
(1, 768, 1, 1)
"""
arch_settings = {
"T": (96, (2, 2, 6, 2)),
"S": (96, (2, 2, 18, 2)),
"B": (128, (2, 2, 18, 2)),
"L": (192, (2, 2, 18, 2))
}
def __init__(self, model_type, out_indices=(0,1,2,3), M=7, head_dim=32, patch_norm=True, frozen_stages=-1, drop_path_rate=0.2):
super().__init__()
init_feature, layers = self.arch_settings[model_type]
self.frozen_stages = frozen_stages
self.out_indices = out_indices
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(layers))] # stochastic depth decay rule
self.pos_bias = PosBias(M)
self.backbone = nn.ModuleList()
self.backbone.append(Stage(3, init_feature, layers[0], patch_norm, PatchEmbedding, M, head_dim, stride=4, drop_path=dpr[:layers[0]]))
out_feature_dims = [init_feature]
for i, v in enumerate(layers[1:]):
self.backbone.append(Stage(init_feature, 2*init_feature, v, True, PatchMerging, M, head_dim, stride=2,
drop_path=dpr[sum(layers[:i+1]):sum(layers[:i+2])]))
init_feature *= 2
out_feature_dims.append(init_feature)
#add a norm layer for each output
for k in out_indices:
self.add_module(f'norm_stage{k}', nn.LayerNorm(out_feature_dims[k]))
self.M = M
self.masks = {
"top": self.create_mask("top"),
"left": self.create_mask("left"),
"topleft": self.create_mask("topleft")
}
self._freeze_stages()
def create_mask(self, d):
""" get the mask according to the direction
:param d: str, (top, left, topleft)
:return : tensor, M^2 x M^2
"""
        base = torch.ones(self.M, self.M)      # group id of every position inside a window
        mask = torch.zeros(self.M**2, self.M**2)
        stride = self.M // 2                   # cyclic-shift offset
        s_stride = self.M - stride             # rows/cols that still hold original (un-wrapped) content
        if d == 'top':
            base[:s_stride] = 0                # two groups: original rows vs. wrapped rows
            base = base.flatten()
            mask[base==0] = base               # mark cross-group keys for group-0 queries
            mask[base==1] = 1 - base           # mark cross-group keys for group-1 queries
elif d == "left":
base[:, :s_stride] = 0
base = base.flatten()
mask[base==0] = base
mask[base==1] = 1 - base
elif d == "topleft":
base[:s_stride, :s_stride] = 0
base[s_stride:, :s_stride] = 2
base[s_stride:, s_stride:]=3
base = base.flatten()
mask[base==0] = (~(base ==0) ).float()
mask[base==1] = (~(base ==1)).float()
mask[base==2] = (~(base ==2)).float()
mask[base==3] = (~(base ==3)).float()
mask[mask>0] = float('-inf')
return mask
def _freeze_stages(self):
if self.frozen_stages>=0:
self.pos_bias.eval()
for param in self.pos_bias.parameters():
param.requires_grad = False
for i in range(self.frozen_stages):
for param in self.backbone[i].parameters():
param.requires_grad = False
def init_weights(self, pretrained=None):
"""Initialize the weights in backbone.
Args:
pretrained (str, optional): Path to pre-trained weights.
Defaults to None.
"""
def _init_weights(m):
if isinstance(m, nn.Linear):
trunc_normal_(m.weight, std=.02)
if isinstance(m, nn.Linear) and m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
            elif isinstance(m, nn.Conv2d):
                kaiming_init(m)  # mmcv's kaiming_init also zero-initializes the bias when present
if isinstance(pretrained, str):
self.apply(_init_weights)
logger = get_root_logger()
load_checkpoint(self, pretrained, strict=False, logger=logger)
elif pretrained is None:
self.apply(_init_weights)
else:
raise TypeError('pretrained must be a str or None')
def forward(self, x):
outs = []
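        # the shared M^2 x M^2 relative position bias is computed once per forward pass and reused by every stage, block and head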
posb = self.pos_bias(x.device)
for i, m in enumerate(self.backbone):
x = m(x, posb, self.masks)
if i in self.out_indices:
norm_layer = getattr(self, f'norm_stage{i}')
outs.append(norm_layer(x.permute(0, 2, 3, 1).contiguous()).permute(0, 3, 1, 2).contiguous())
return tuple(outs)
def train(self, mode=True):
"""Convert the model into training mode while keep normalization layer
freezed."""
super(SwinTransformer, self).train(mode)
self._freeze_stages()
```
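Since the class is registered via `@BACKBONES.register_module()`, it can be referenced by name from an mmdetection config. The fragment below is an illustrative sketch rather than a tested config: the backbone arguments follow the `__init__` signature above, while the neck settings (an FPN fed with the stage output channels of the 'T' variant) are assumptions for demonstration.

```python
# Illustrative mmdetection config fragment (detector layout assumed, not part of the source)
model = dict(
    backbone=dict(
        type='SwinTransformer',      # registered backbone name
        model_type='T',              # one of 'T', 'S', 'B', 'L'
        out_indices=(0, 1, 2, 3),
        M=7,
        head_dim=32,
        patch_norm=True,
        frozen_stages=-1,
        drop_path_rate=0.2),
    neck=dict(
        type='FPN',
        in_channels=[96, 192, 384, 768],   # stage output dims for the 'T' variant
        out_channels=256,
        num_outs=5))
```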