Mutihead-Self-Attention in Computer Vision

import torch
import torch.nn as nn
from torchvision import transforms
from PIL import Image

class Attention(nn.Module):
    def __init__(self,dim=768,num_heads=8,qkv_bias=False,attn_drop=0.,proj_drop=0.):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = dim//num_heads
        self.qkv = nn.Linear(dim,3*dim,bias=qkv_bias)
        self.qkv_scale = self.head_dim**-0.5
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim,dim,bias=False)
        self.proj_drop = nn.Dropout(proj_drop)
    def forward(self,x):
        #具体操作是将矩阵沿着dim方向均分num_head份,每一份进行矩阵乘积,最后把各个结果沿着head_dim方向进行concat输出
        B,N,C = x.shape
        #permute转置和transpose转置:https://blog.csdn.net/qq_41740004/article/details/104712173
        qkv = self.qkv(x).reshape(B,N,3,self.num_heads,self.head_dim).permute(2,0,3,1,4)
        #qkv:(3,B,num_heads,N,head_dim)
        q,k,v = qkv[0],qkv[1],qkv[2]
        #numpy点乘:使用*或者np.multiply();叉乘:@, np.dot(), np.matmul()
        #pytorch点乘:*或者np.multiply();叉乘:@,torch.mm()
        #点乘-矩阵内积;叉乘-矩阵乘积-矩阵向量积
        #https://blog.csdn.net/ykf173/article/details/104630175
        attn = (q @ k.transpose(-2,-1))*self.qkv_scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        attn = (attn @ v).transpose(1,2).reshape(B,N,C)
        x = self.proj(attn)
        x = self.proj_drop(x)
        return x
class PatchEmbed(nn.Module):
    def __init__(self,img_size=224,patch_size=16,in_chans=3,dim=768,norm_layer=nn.LayerNorm):
        super().__init__()
        self.img_size = img_size if isinstance(img_size,tuple) else (img_size,img_size)
        self.patch_size = patch_size if isinstance(patch_size,tuple) else (patch_size,patch_size)
        self.num_patches = self.img_size[0]//self.patch_size[0]*self.img_size[1]//self.patch_size[1]
        self.conv = nn.Conv2d(in_chans,dim,(self.patch_size[0],self.patch_size[1]),(self.patch_size[0],self.patch_size[1]),padding=0)
        self.norm = norm_layer(dim)
    def forward(self,x):
        B,C,H,W = x.shape
        assert H == self.img_size[0] and W == self.img_size[1]
        x = self.conv(x)
        x = x.flatten(2).transpose(-2,-1)\
        x = self.norm(x)
        return x
class MLP(nn.Module):
    def __init__(self,c1,c_,c2,drop_rate,act_layer=nn.GELU):
        super().__init__()
        c_ = c_ or c1
        c2 = c2 or c1
        self.fc1 = nn.Linear(c1,c_)
        self.act = act_layer()
        self.fc2 = nn.Linear(c_,c2)
        self.drop = nn.Dropout(drop_rate)
    def forward(self,x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x
class Block(nn.Module):
    def __init__(self,dim=768,mlp_ratio=1.,num_heads=8,attn_drop=0.,qkv_scale=1.,drop=0.,norm_layer=nn.LayerNorm):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = Attention(dim,num_heads,attn_drop,qkv_scale,drop)
        self.drop_path = nn.Identity()
        self.norm2 = norm_layer(dim)
        hidden_dim = int(dim*mlp_ratio)
        self.mlp = MLP(dim,hidden_dim,dim,drop)
    def forward(self,x):
        x = x+self.drop_path(self.attn(self.norm1(x)))
        x = x+self.drop_path(self.mlp(self.norm2(x)))
        return x
class Simple_VIT(nn.Module):
    def __init__(self,img_size=224,patch_size=16,in_chans=3,embedding_dim=768,depth=8,drop=0.,num_class=10):
        super().__init__()

        self.PatchEmbed = PatchEmbed(img_size,patch_size,in_chans,embedding_dim)
        num_patches = self.PatchEmbed.num_patches
        self.cls_token = nn.Parameter(torch.zeros(1,1,embedding_dim))#第一轮与之相关的权重不会更新,因为w*x中x为0,所以反向传播后>,w的梯度为0,但是本轮x会更新,下一轮w可以进行更新,因为下一轮x就不是0了,验证:position_embedding_20210521.py
        self.position_embed = nn.Parameter(torch.zeros(1,num_patches+1,embedding_dim))
        self.pos_drop = nn.Dropout(drop)
        self.blocks = nn.Sequential(*[Block(dim=embedding_dim,drop=drop) for i in range(depth)])
        self.norm = nn.LayerNorm(embedding_dim)
        self.pre_logits = nn.Identity()
        self.num_features = self.embedding_dim = embedding_dim
        self.head = nn.Linear(embedding_dim,num_class) if num_class>0 else nn.Identity()
    def forward_feature(self,x):
        x = self.PatchEmbed(x)
        cls_token = self.cls_token.expand(x.shape[0],-1,-1)
        x = torch.cat((cls_token,x),dim=1)
        x = x+self.position_embed
        x = self.pos_drop(x)
        x = self.blocks(x)
        x = self.norm(x)
        return self.pre_logits(x[:,0])
    def forward(self,x):
        x = self.forward_feature(x)
        x = self.head(x)
        return x

image = Image.open('multi.jpg').resize((224,224))
image.save('src.png')
to_tensor = transforms.ToTensor()
to_PIL = transforms.ToPILImage()
model = Simple_VIT()
model.eval()
img = to_tensor(image)
img = img.unsqueeze(0)
res = model(img)
#res= to_PIL(res.squeeze(0))
print(res.softmax(dim=-1).data)
print(torch.argmax(res,dim=1).item())

1.关于attention操作之后为什么要进行scale操作:
比较大的输入会使得后续softmax的梯度变得很小,甚至导致梯度消失
softmax求导:
https://zhuanlan.zhihu.com/p/105758059
这里给出结论:
1.1.softmax loss求导:a-y,其中a是label,y是softmax的输出
1.2.softmax求导:

image.png

2.scale的取值为什么可以是head\_dim^{-0.5}:
假设attention操作中的q和k每个分量是相互独立的均值为0方差为1的随机变量,则q和k进行向量积操作之后q@k=attn,attn的每个分量变为均值为0方差为head_dim的随机变量

量级越大,softmax输出结果越"集中"

方差越大分量越有可能取到较大的量级,导致sotfmax操作之后的结果某一个\hat y取值接近1而其他y取值接近于0,导致梯度反向传播到attn的时候导致梯度消失,而对每个分量乘以head\_dim^{-0.5}会将其方差限制回1。
注意:如果softmax位于输出层,则不用过于考虑输入量级对softmax结果的影响;而attention中softmax位于网络中间层,所以需要考虑。

参考:
1.https://www.zhihu.com/search?type=content&q=attention%E4%B8%BA%E4%BB%80%E4%B9%88%E8%A6%81scale
2.https://zhuanlan.zhihu.com/p/105758059
3.https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py

最后编辑于
©著作权归作者所有,转载或内容合作请联系作者
  • 序言:七十年代末,一起剥皮案震惊了整个滨河市,随后出现的几起案子,更是在滨河造成了极大的恐慌,老刑警刘岩,带你破解...
    沈念sama阅读 215,133评论 6 497
  • 序言:滨河连续发生了三起死亡事件,死亡现场离奇诡异,居然都是意外死亡,警方通过查阅死者的电脑和手机,发现死者居然都...
    沈念sama阅读 91,682评论 3 390
  • 文/潘晓璐 我一进店门,熙熙楼的掌柜王于贵愁眉苦脸地迎上来,“玉大人,你说我怎么就摊上这事。” “怎么了?”我有些...
    开封第一讲书人阅读 160,784评论 0 350
  • 文/不坏的土叔 我叫张陵,是天一观的道长。 经常有香客问我,道长,这世上最难降的妖魔是什么? 我笑而不...
    开封第一讲书人阅读 57,508评论 1 288
  • 正文 为了忘掉前任,我火速办了婚礼,结果婚礼上,老公的妹妹穿的比我还像新娘。我一直安慰自己,他们只是感情好,可当我...
    茶点故事阅读 66,603评论 6 386
  • 文/花漫 我一把揭开白布。 她就那样静静地躺着,像睡着了一般。 火红的嫁衣衬着肌肤如雪。 梳的纹丝不乱的头发上,一...
    开封第一讲书人阅读 50,607评论 1 293
  • 那天,我揣着相机与录音,去河边找鬼。 笑死,一个胖子当着我的面吹牛,可吹牛的内容都是我干的。 我是一名探鬼主播,决...
    沈念sama阅读 39,604评论 3 415
  • 文/苍兰香墨 我猛地睁开眼,长吁一口气:“原来是场噩梦啊……” “哼!你这毒妇竟也来了?” 一声冷哼从身侧响起,我...
    开封第一讲书人阅读 38,359评论 0 270
  • 序言:老挝万荣一对情侣失踪,失踪者是张志新(化名)和其女友刘颖,没想到半个月后,有当地人在树林里发现了一具尸体,经...
    沈念sama阅读 44,805评论 1 307
  • 正文 独居荒郊野岭守林人离奇死亡,尸身上长有42处带血的脓包…… 初始之章·张勋 以下内容为张勋视角 年9月15日...
    茶点故事阅读 37,121评论 2 330
  • 正文 我和宋清朗相恋三年,在试婚纱的时候发现自己被绿了。 大学时的朋友给我发了我未婚夫和他白月光在一起吃饭的照片。...
    茶点故事阅读 39,280评论 1 344
  • 序言:一个原本活蹦乱跳的男人离奇死亡,死状恐怖,灵堂内的尸体忽然破棺而出,到底是诈尸还是另有隐情,我是刑警宁泽,带...
    沈念sama阅读 34,959评论 5 339
  • 正文 年R本政府宣布,位于F岛的核电站,受9级特大地震影响,放射性物质发生泄漏。R本人自食恶果不足惜,却给世界环境...
    茶点故事阅读 40,588评论 3 322
  • 文/蒙蒙 一、第九天 我趴在偏房一处隐蔽的房顶上张望。 院中可真热闹,春花似锦、人声如沸。这庄子的主人今日做“春日...
    开封第一讲书人阅读 31,206评论 0 21
  • 文/苍兰香墨 我抬头看了看天上的太阳。三九已至,却和暖如春,着一层夹袄步出监牢的瞬间,已是汗流浃背。 一阵脚步声响...
    开封第一讲书人阅读 32,442评论 1 268
  • 我被黑心中介骗来泰国打工, 没想到刚下飞机就差点儿被人妖公主榨干…… 1. 我叫王不留,地道东北人。 一个月前我还...
    沈念sama阅读 47,193评论 2 367
  • 正文 我出身青楼,却偏偏与公主长得像,于是被迫代替她去往敌国和亲。 传闻我的和亲对象是个残疾皇子,可洞房花烛夜当晚...
    茶点故事阅读 44,144评论 2 352

推荐阅读更多精彩内容

  • 参考链接: https://github.com/DA-southampton/NLP_ability/blob/...
    张知道q阅读 4,357评论 0 0
  • 一、AutoML架构图 https://zhuanlan.zhihu.com/p/44850626[https:/...
    加油11dd23阅读 268评论 0 1
  • [更新中...]---------------------------------Reference-------...
    chuuuing阅读 640评论 0 0
  • 我是黑夜里大雨纷飞的人啊 1 “又到一年六月,有人笑有人哭,有人欢乐有人忧愁,有人惊喜有人失落,有的觉得收获满满有...
    陌忘宇阅读 8,534评论 28 53
  • 信任包括信任自己和信任他人 很多时候,很多事情,失败、遗憾、错过,源于不自信,不信任他人 觉得自己做不成,别人做不...
    吴氵晃阅读 6,187评论 4 8