import torch
import torch.nn as nn
from torchvision import transforms
from PIL import Image
class Attention(nn.Module):
def __init__(self,dim=768,num_heads=8,qkv_bias=False,attn_drop=0.,proj_drop=0.):
super().__init__()
self.num_heads = num_heads
self.head_dim = dim//num_heads
self.qkv = nn.Linear(dim,3*dim,bias=qkv_bias)
self.qkv_scale = self.head_dim**-0.5
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(dim,dim,bias=False)
self.proj_drop = nn.Dropout(proj_drop)
def forward(self,x):
        # Multi-head attention: split the embedding (dim) into num_heads chunks of size head_dim, run the attention matrix product on each chunk independently, then concatenate the per-head outputs (each head_dim wide) back to dim.
B,N,C = x.shape
        # permute vs. transpose: https://blog.csdn.net/qq_41740004/article/details/104712173
qkv = self.qkv(x).reshape(B,N,3,self.num_heads,self.head_dim).permute(2,0,3,1,4)
#qkv:(3,B,num_heads,N,head_dim)
q,k,v = qkv[0],qkv[1],qkv[2]
        # NumPy element-wise product: * or np.multiply(); matrix product: @, np.dot(), np.matmul()
        # PyTorch element-wise product: * or torch.mul(); matrix product: @, torch.matmul() (torch.mm() is 2-D only)
        # Element-wise product = Hadamard product; matrix product = ordinary matrix multiplication (matrix-vector product as a special case)
        # https://blog.csdn.net/ykf173/article/details/104630175
attn = (q @ k.transpose(-2,-1))*self.qkv_scale
attn = attn.softmax(dim=-1)
attn = self.attn_drop(attn)
attn = (attn @ v).transpose(1,2).reshape(B,N,C)
x = self.proj(attn)
x = self.proj_drop(x)
return x
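# Quick shape sanity check (illustrative addition, not part of the original script):
# attention is shape-preserving, so a random (batch, tokens, dim) input comes back with the same shape.
_attn_check = Attention(dim=768, num_heads=8)
assert _attn_check(torch.randn(2, 197, 768)).shape == (2, 197, 768)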
class PatchEmbed(nn.Module):
def __init__(self,img_size=224,patch_size=16,in_chans=3,dim=768,norm_layer=nn.LayerNorm):
super().__init__()
self.img_size = img_size if isinstance(img_size,tuple) else (img_size,img_size)
self.patch_size = patch_size if isinstance(patch_size,tuple) else (patch_size,patch_size)
        self.num_patches = (self.img_size[0]//self.patch_size[0])*(self.img_size[1]//self.patch_size[1])
        self.conv = nn.Conv2d(in_chans,dim,kernel_size=self.patch_size,stride=self.patch_size,padding=0)
self.norm = norm_layer(dim)
def forward(self,x):
B,C,H,W = x.shape
assert H == self.img_size[0] and W == self.img_size[1]
x = self.conv(x)
        x = x.flatten(2).transpose(-2,-1)
x = self.norm(x)
return x
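# Shape sanity check (illustrative addition): a 224x224 RGB image is split into
# (224/16)**2 = 196 patches, each projected to a 768-dim token.
_pe_check = PatchEmbed()
assert _pe_check(torch.randn(1, 3, 224, 224)).shape == (1, 196, 768)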
class MLP(nn.Module):
    def __init__(self,c1,c_=None,c2=None,drop_rate=0.,act_layer=nn.GELU):
super().__init__()
c_ = c_ or c1
c2 = c2 or c1
self.fc1 = nn.Linear(c1,c_)
self.act = act_layer()
self.fc2 = nn.Linear(c_,c2)
self.drop = nn.Dropout(drop_rate)
def forward(self,x):
x = self.fc1(x)
x = self.act(x)
x = self.fc2(x)
x = self.drop(x)
return x
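# MLP sanity check (illustrative addition): with c_=None and c2=None the `or` fallbacks
# make both the hidden and output width default to c1, so the block is shape-preserving.
_mlp_check = MLP(768, None, None, 0.)
assert _mlp_check(torch.randn(2, 197, 768)).shape == (2, 197, 768)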
class Block(nn.Module):
    def __init__(self,dim=768,mlp_ratio=1.,num_heads=8,attn_drop=0.,qkv_bias=False,drop=0.,norm_layer=nn.LayerNorm):
super().__init__()
self.norm1 = norm_layer(dim)
        self.attn = Attention(dim,num_heads=num_heads,qkv_bias=qkv_bias,attn_drop=attn_drop,proj_drop=drop)
self.drop_path = nn.Identity()
self.norm2 = norm_layer(dim)
hidden_dim = int(dim*mlp_ratio)
self.mlp = MLP(dim,hidden_dim,dim,drop)
def forward(self,x):
x = x+self.drop_path(self.attn(self.norm1(x)))
x = x+self.drop_path(self.mlp(self.norm2(x)))
return x
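# Block sanity check (illustrative addition): the residual additions require both the
# attention branch and the MLP branch to preserve the (batch, tokens, dim) shape.
_blk_check = Block(dim=768)
assert _blk_check(torch.randn(2, 197, 768)).shape == (2, 197, 768)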
class Simple_VIT(nn.Module):
def __init__(self,img_size=224,patch_size=16,in_chans=3,embedding_dim=768,depth=8,drop=0.,num_class=10):
super().__init__()
self.PatchEmbed = PatchEmbed(img_size,patch_size,in_chans,embedding_dim)
num_patches = self.PatchEmbed.num_patches
        self.cls_token = nn.Parameter(torch.zeros(1,1,embedding_dim))  # zero-init: in the first iteration the weights that multiply this token get zero gradient (for w*x with x=0, dL/dw is proportional to x, i.e. 0), but the token itself still receives a nonzero gradient and is updated, so those weights start updating from the second iteration; verified in position_embedding_20210521.py
self.position_embed = nn.Parameter(torch.zeros(1,num_patches+1,embedding_dim))
self.pos_drop = nn.Dropout(drop)
self.blocks = nn.Sequential(*[Block(dim=embedding_dim,drop=drop) for i in range(depth)])
self.norm = nn.LayerNorm(embedding_dim)
self.pre_logits = nn.Identity()
self.num_features = self.embedding_dim = embedding_dim
self.head = nn.Linear(embedding_dim,num_class) if num_class>0 else nn.Identity()
def forward_feature(self,x):
x = self.PatchEmbed(x)
cls_token = self.cls_token.expand(x.shape[0],-1,-1)
x = torch.cat((cls_token,x),dim=1)
x = x+self.position_embed
x = self.pos_drop(x)
x = self.blocks(x)
x = self.norm(x)
return self.pre_logits(x[:,0])
def forward(self,x):
x = self.forward_feature(x)
x = self.head(x)
return x
image = Image.open('multi.jpg').resize((224,224))
image.save('src.png')
to_tensor = transforms.ToTensor()
to_PIL = transforms.ToPILImage()
model = Simple_VIT()
model.eval()
img = to_tensor(image)
img = img.unsqueeze(0)
res = model(img)
#res= to_PIL(res.squeeze(0))
print(res.softmax(dim=-1).data)
print(torch.argmax(res,dim=1).item())
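# Minimal check of the zero-init claim made at cls_token above (my own sketch, not part of
# the original model): with a zero-initialized input parameter, the linear layer's weight
# gets a zero gradient on the first backward pass, while the parameter itself does not.
_x = nn.Parameter(torch.zeros(1, 4))
_lin = nn.Linear(4, 2, bias=False)
_lin(_x).sum().backward()
print(_lin.weight.grad.abs().sum().item())  # 0.0: the weight sees no gradient while its input is all zeros
print(_x.grad.abs().sum().item())           # > 0: the zero-initialized parameter still gets updated, so the weight moves from step two onward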
1. Why the attention scores are scaled:
A large input pushes the subsequent softmax into its saturated region, where the gradients become very small and can effectively vanish.
Derivative of softmax:
https://zhuanlan.zhihu.com/p/105758059
Key results:
1.1. Softmax with cross-entropy loss: the gradient w.r.t. the logits is y-a, where a is the one-hot label and y is the softmax output
1.2. Softmax alone: dy_i/dx_j = y_i*(delta_ij - y_j), where y = softmax(x)
2. Why the scale factor can be 1/sqrt(head_dim):
Assume every component of q and k is an independent random variable with mean 0 and variance 1. After the dot product attn = q @ k^T, each component of attn has mean 0 and variance head_dim.
The larger the variance, the more likely some component takes a large magnitude; after softmax one entry is then close to 1 and all others close to 0, so the gradient that reaches attn during backpropagation essentially vanishes. Multiplying every component by 1/sqrt(head_dim) brings the variance back to 1 (a small numerical sketch follows below).
Note: when softmax is the output layer, the magnitude of its input matters much less; in attention the softmax sits in the middle of the network, so it does matter.
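A small numerical sketch of the argument above (my own illustration, not from the original): draw q and k with i.i.d. standard-normal components, compare the variance of q @ k^T with and without the 1/sqrt(head_dim) scale, and use the diagonal of the softmax Jacobian as a rough measure of how much gradient can still flow.

import torch

torch.manual_seed(0)
head_dim, n = 64, 1000
q = torch.randn(n, head_dim)
k = torch.randn(n, head_dim)

scores = q @ k.T                                # each entry is a sum of head_dim products
print(scores.var().item())                      # ~head_dim (~64): variance grows with head_dim
print((scores * head_dim**-0.5).var().item())   # ~1: scaling restores unit variance

# y_i*(1 - y_i) is the diagonal of the softmax Jacobian; it collapses toward 0 when the
# softmax saturates, which is what happens for the unscaled scores.
for s in (scores, scores * head_dim**-0.5):
    y = s.softmax(dim=-1)
    print((y * (1 - y)).sum(dim=-1).mean().item())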
References:
1.https://www.zhihu.com/search?type=content&q=attention%E4%B8%BA%E4%BB%80%E4%B9%88%E8%A6%81scale
2.https://zhuanlan.zhihu.com/p/105758059
3.https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py