Multi-Head Self-Attention in Computer Vision

import torch
import torch.nn as nn
from torchvision import transforms
from PIL import Image

class Attention(nn.Module):
    def __init__(self,dim=768,num_heads=8,qkv_bias=False,attn_drop=0.,proj_drop=0.):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = dim//num_heads
        self.qkv = nn.Linear(dim,3*dim,bias=qkv_bias)
        self.qkv_scale = self.head_dim**-0.5
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim,dim,bias=False)
        self.proj_drop = nn.Dropout(proj_drop)
    def forward(self,x):
        #Split the channels into num_heads chunks along dim, run scaled dot-product attention in each chunk (head), then concat the per-head results along head_dim to form the output (see the shape check after this class)
        B,N,C = x.shape
        #permute vs. transpose in PyTorch: https://blog.csdn.net/qq_41740004/article/details/104712173
        qkv = self.qkv(x).reshape(B,N,3,self.num_heads,self.head_dim).permute(2,0,3,1,4)
        #qkv:(3,B,num_heads,N,head_dim)
        q,k,v = qkv[0],qkv[1],qkv[2]
        #numpy element-wise product: * or np.multiply(); matrix product: @, np.dot(), np.matmul()
        #pytorch element-wise product: * or torch.mul(); matrix product: @, torch.matmul() (torch.mm() for 2-D tensors)
        #https://blog.csdn.net/ykf173/article/details/104630175
        attn = (q @ k.transpose(-2,-1))*self.qkv_scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        attn = (attn @ v).transpose(1,2).reshape(B,N,C)
        x = self.proj(attn)
        x = self.proj_drop(x)
        return x
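
# A minimal shape check for Attention (an illustrative sketch, not part of the
# original script; the dummy sizes are arbitrary): with dim=768 and num_heads=8
# each head works on head_dim=96, and the module maps (B,N,C) -> (B,N,C).
attn_demo = Attention(dim=768, num_heads=8)
demo_tokens = torch.randn(2, 197, 768)   # B=2, N=196 patches + 1 cls token
print(attn_demo(demo_tokens).shape)      # torch.Size([2, 197, 768])
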
class PatchEmbed(nn.Module):
    def __init__(self,img_size=224,patch_size=16,in_chans=3,dim=768,norm_layer=nn.LayerNorm):
        super().__init__()
        self.img_size = img_size if isinstance(img_size,tuple) else (img_size,img_size)
        self.patch_size = patch_size if isinstance(patch_size,tuple) else (patch_size,patch_size)
        self.num_patches = (self.img_size[0]//self.patch_size[0])*(self.img_size[1]//self.patch_size[1])
        self.conv = nn.Conv2d(in_chans,dim,(self.patch_size[0],self.patch_size[1]),(self.patch_size[0],self.patch_size[1]),padding=0)
        self.norm = norm_layer(dim)
    def forward(self,x):
        B,C,H,W = x.shape
        assert H == self.img_size[0] and W == self.img_size[1]
        x = self.conv(x)
        x = x.flatten(2).transpose(-2,-1)#(B,dim,H/ps,W/ps) -> (B,num_patches,dim)
        x = self.norm(x)
        return x
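
# A minimal shape check for PatchEmbed (an illustrative sketch, not part of the
# original script): a 224x224 RGB image with patch_size=16 yields
# (224/16)*(224/16) = 196 patches, each embedded to dim=768.
embed_demo = PatchEmbed(img_size=224, patch_size=16, in_chans=3, dim=768)
print(embed_demo.num_patches)                          # 196
print(embed_demo(torch.randn(1, 3, 224, 224)).shape)   # torch.Size([1, 196, 768])
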
class MLP(nn.Module):
    def __init__(self,c1,c_=None,c2=None,drop_rate=0.,act_layer=nn.GELU):
        super().__init__()
        c_ = c_ or c1
        c2 = c2 or c1
        self.fc1 = nn.Linear(c1,c_)
        self.act = act_layer()
        self.fc2 = nn.Linear(c_,c2)
        self.drop = nn.Dropout(drop_rate)
    def forward(self,x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x
class Block(nn.Module):
    def __init__(self,dim=768,mlp_ratio=1.,num_heads=8,attn_drop=0.,drop=0.,norm_layer=nn.LayerNorm):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = Attention(dim,num_heads=num_heads,attn_drop=attn_drop,proj_drop=drop)
        self.drop_path = nn.Identity()
        self.norm2 = norm_layer(dim)
        hidden_dim = int(dim*mlp_ratio)
        self.mlp = MLP(dim,hidden_dim,dim,drop)
    def forward(self,x):
        x = x+self.drop_path(self.attn(self.norm1(x)))
        x = x+self.drop_path(self.mlp(self.norm2(x)))
        return x
class Simple_VIT(nn.Module):
    def __init__(self,img_size=224,patch_size=16,in_chans=3,embedding_dim=768,depth=8,drop=0.,num_class=10):
        super().__init__()

        self.PatchEmbed = PatchEmbed(img_size,patch_size,in_chans,embedding_dim)
        num_patches = self.PatchEmbed.num_patches
        self.cls_token = nn.Parameter(torch.zeros(1,1,embedding_dim))#In the first iteration the weights multiplying this token get zero gradient (the x in w*x is 0), but the token itself still receives a gradient and is updated, so from the next iteration on those weights update as well; verified in position_embedding_20210521.py (see also the sketch after this class)
        self.position_embed = nn.Parameter(torch.zeros(1,num_patches+1,embedding_dim))
        self.pos_drop = nn.Dropout(drop)
        self.blocks = nn.Sequential(*[Block(dim=embedding_dim,drop=drop) for i in range(depth)])
        self.norm = nn.LayerNorm(embedding_dim)
        self.pre_logits = nn.Identity()
        self.num_features = self.embedding_dim = embedding_dim
        self.head = nn.Linear(embedding_dim,num_class) if num_class>0 else nn.Identity()
    def forward_feature(self,x):
        x = self.PatchEmbed(x)
        cls_token = self.cls_token.expand(x.shape[0],-1,-1)
        x = torch.cat((cls_token,x),dim=1)
        x = x+self.position_embed
        x = self.pos_drop(x)
        x = self.blocks(x)
        x = self.norm(x)
        return self.pre_logits(x[:,0])
    def forward(self,x):
        x = self.forward_feature(x)
        x = self.head(x)
        return x
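
# A minimal, hypothetical sketch of the cls_token zero-init remark above (this
# is not the position_embedding_20210521.py script mentioned in the comment):
# a zero-initialized token fed through a bias-free Linear layer gives zero
# gradient for the Linear weight on the first step, while the token itself
# still gets a gradient, so both update from the second step onward.
demo_token = nn.Parameter(torch.zeros(1, 4))      # stands in for cls_token
demo_linear = nn.Linear(4, 4, bias=False)         # stands in for a weight acting on it
demo_opt = torch.optim.SGD([demo_token, *demo_linear.parameters()], lr=0.1)
for step in range(2):
    demo_opt.zero_grad()
    demo_linear(demo_token).sum().backward()
    print(step,
          'weight grad norm:', demo_linear.weight.grad.norm().item(),   # 0.0 at step 0
          'token grad norm:', demo_token.grad.norm().item())            # non-zero at step 0
    demo_opt.step()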

image = Image.open('multi.jpg').resize((224,224))
image.save('src.png')
to_tensor = transforms.ToTensor()
to_PIL = transforms.ToPILImage()
model = Simple_VIT()
model.eval()
img = to_tensor(image)
img = img.unsqueeze(0)
res = model(img)
#res= to_PIL(res.squeeze(0))
print(res.softmax(dim=-1).data)
print(torch.argmax(res,dim=1).item())

1. Why the attention scores are scaled:
A large input makes the gradients of the following softmax very small and can even make them vanish.
Derivative of softmax:
https://zhuanlan.zhihu.com/p/105758059
Conclusions:
1.1. Gradient of softmax + cross-entropy loss with respect to the logits: y - a, where a is the one-hot label and y is the softmax output.
1.2. Derivative of softmax itself (the original figure showing the formula is missing; the standard result is written out below).
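
As a stand-in for the missing figure, these are the standard formulas (well-known results, written out here rather than recovered from the image), with softmax y_i = e^{z_i} / \sum_k e^{z_k}:

\frac{\partial y_i}{\partial z_j} = y_i(\delta_{ij} - y_j)

\frac{\partial L}{\partial z_i} = y_i - a_i \quad (softmax + cross-entropy with one-hot label a)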

2. Why the scale can be head\_dim^{-0.5}:
Assume each component of q and k is an independent random variable with mean 0 and variance 1. After the dot product q@k=attn, each component of attn has mean 0 and variance head_dim.

The larger the magnitude of the scores, the more "concentrated" the softmax output becomes.

A larger variance makes it more likely that some component takes a large value, so after softmax one \hat y is close to 1 while the others are close to 0, and the gradient vanishes when it is backpropagated to attn. Multiplying each component by head\_dim^{-0.5} brings the variance back to 1.
Note: if the softmax sits in the output layer, the magnitude of its input is less of a concern; in attention the softmax sits in a middle layer of the network, so it has to be taken into account.
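
A quick numerical check of this argument (an illustrative sketch; head_dim=64 and the sample sizes are arbitrary choices):

import torch

head_dim = 64
q = torch.randn(100000, head_dim)
k = torch.randn(100000, head_dim)

dots = (q * k).sum(dim=-1)                      # one q@k dot product per row
print(dots.var().item())                        # ≈ head_dim (≈ 64)
print((dots * head_dim ** -0.5).var().item())   # ≈ 1 after scaling

# Scores with std ≈ sqrt(head_dim) push softmax toward a near one-hot output,
# while the scaled scores give a much smoother distribution.
scores = torch.randn(1, 8) * head_dim ** 0.5
print(scores.softmax(dim=-1))
print((scores * head_dim ** -0.5).softmax(dim=-1))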

References:
1.https://www.zhihu.com/search?type=content&q=attention%E4%B8%BA%E4%BB%80%E4%B9%88%E8%A6%81scale
2.https://zhuanlan.zhihu.com/p/105758059
3.https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
