接着前面的文章说到的transformer,本篇将要介绍在图像中如何将transformer运用到图片分类中去的。我们知道CNN具有平移不变形,但是transformer基于self-attentation可以获得long-range信息(更大的感受野),但是CNN需要更多深层的Conv-layers来不断增大感受野。
这里将给出论文地址及代码地址:
论文:An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale
官方代码地址:https://github.com/google-research/vision_transformer(本文讲解对象)
jeonsworld/ViT-pytorch(相对来说容易理解)
本博客讲解代码地址:https://github.com/lucidrains/vit-pytorch
这里选择代码地址的原因首先是因为其star比较高,其次拥有多种变形模型及使用的pytorch框架进行编写,便于代码阅读。
一、论文阅读
这里主要讲解论文重点的部分
1. 优点:
相当于卷积模型来比,transformer在减少计算资源的同时获得了非常出色的结果。当对中等规模数据集(例如ImageNet)进行训练的时候,此模型所产生的适合的精度要比同等规模的ResNet低几个百分点。数据量越大,模型越友好。其中的Attention机制是一个很重要的机制我们通过一下图可以看出其优势, 让我们关注我们需要的物体,忽略没有用的东西。
传统的卷积神经网络需要大的感受野需要不断地卷积,才能获得更大的感受野,但是在我们
Transformer
的模式下,我们在浅层就已经获得很大的感受野,在做attention
的时候就已经看到了所有的信息,所以说其根本不需要堆叠,直接就可以获得全局信息(这里看完文章就会有感悟)。2. 启发:
很多图像特征提取器将CNN与专门注意力机制结合在一起,但是未能在硬件加速器扩展。
3. 与CNN的差距原因:
transformer缺乏CNN固有的一些感应偏差,例如平移不变性和局部性,因此在训练不足的数量时候,很难有好的效果。
4. VIT原理:
- 步骤一(input):
输入图像大小尺寸为(), 首先我们将图片进行切分,按照patch_size
进行切分,这样我们就得到了大小的一个个图块, 这里的图块数量为. 联想到transformer,这里的N就可以理解是序列长度,其中序列中每个element的维度dim
称之为patch embedding
。
在我们进行图片分类的时候我们一般在序列前加入一个element,我们称此element为, 这样我们得到序列长度为N+1,在训练的时候我们可以通过此element进行图片分类。最后再加上位置矩阵(注意这里是add不是concate)构成我们的输入矩阵z0
。
步骤二(forward):
transformer编码器主要由两个components构成分别是MSA(multi-head self-attention)
和 MLP([MultiLayer Perceptron)
组成。下面是前向传播的计算公式:
-
第一个公式
这里表示的是类别element,表示的是输入的每个patch,E代表的对应的权重,N表示的patch的数量 代表的是position的信息,这里不加位置编码,加一维编码(即1,2,3。。。),以及加位置编码的效果如下, 我们发现有位置编码比没有的效果好,但是多少维的效果差不多,我们一般采用2维度
。下面是分类不同维度编码效果,但是检测任务就不一定了。
-
第二个公式
这里的LN表示的Layer Normalization,MSA的公式如下:
这里的qkv
矩阵之前说过了,如通过输入z
与权重得到而来,我们在通过公式得到我们的Attention权重。最终利用v矩阵与attention权重相乘得到。因为考虑到多头记住所以我们得到如下公式:
还有一点注意的是根据如下公式可以看到vit模型结构也采用了残差机制:
5 微调和更高的分辨率
微调:
可删除预训练的head,附加初始化foward层, K是类的数量。更高分辨率:
当提供更高分辨率的图像时,我们将图块大小保持不变,这会导致更大的有效序列长度。ViT可以处理任意序列长度(直到内存限制),但是,预训练的位置embedding可能不再有意义。因此,我们根据预先训练的位置嵌入在原始图像中的位置执行2D插值。请注意,只有在分辨率调整和色块提取中,将有关图像2D结构的感应偏差手动注入到Vision Transformer中。
二、代码解读
本博客讲解代码地址:https://github.com/lucidrains/vit-pytorch
这里主要做的是猫狗分类模型,图片大小为256*256
1. 主函数
这里的main函数可以理解为是常规操作,大家稍微看下就理解了。
#!/usr/bin/python
# -*- coding: UTF-8 -*-
"""
@author:Maocheng Hu
@project_name:vit
@file:train.py
@time:2021/04/27/14/17
@ide:PyCharm
@email: wojiaohumaocheng@gmail.com
┏┓ ┏┓
┏┛┻━━━┛┻┓
┃ ┃
┃ ┳┛ ┗┳ ┃
┃ ┻ ┃
┗━┓ ┏━┛
┃ ┗━━━┓
┃ 神兽保佑 ┣┓
┃ 永无BUG! ┏┛
┗┓┓┏━┳┓┏┛
┃┫┫ ┃┫┫
┗┻┛ ┗┻┛
"""
import os
import glob
from tqdm import tqdm
import torch
import torch.nn as nn
from PIL import Image
import torch.optim as optim
from linformer import Linformer
# from vit_pytorch.efficient import ViT
from vit_pytorch import ViT
from torch.optim.lr_scheduler import StepLR
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Dataset
from sklearn.model_selection import train_test_split
# image augmentation
def image_augmentation():
train_transforms = transforms.Compose(
[
transforms.Resize((224, 224)),
transforms.RandomResizedCrop(224),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
]
)
val_transforms = transforms.Compose(
[
transforms.Resize((224, 224)),
transforms.RandomResizedCrop(224),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
]
)
test_transforms = transforms.Compose(
[
transforms.Resize((224, 224)),
transforms.RandomResizedCrop(224),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
]
)
return train_transforms, val_transforms, test_transforms
# load data
def load_data(dataset):
train_list = glob.glob(os.path.join("{}/{}".format(dataset, "train"), "*.jpg"))
test_list = glob.glob(os.path.join("{}/{}".format(dataset, "test"), "*.jpg"))
train_label_list = [path.split('/')[-1].split('.')[0] for path in train_list]
# stratify for balancing classes
train_list, valid_list = train_test_split(train_list,
test_size=0.2,
stratify=train_label_list,
random_state=2021)
return train_list, valid_list, test_list
class CatsDogsDataset(Dataset):
def __init__(self, file_list, transform=None):
self.file_list = file_list
self.transformer = transform
def __len__(self):
self.file_length = len(self.file_list)
return self.file_length
def __getitem__(self, idx):
img_path = self.file_list[idx]
img = Image.open(img_path)
img_transformed = self.transformer(img)
label = img_path.split("/")[-1].split(".")[0]
label = 1 if label == "dog" else 0
return img_transformed, label
def main():
dataset = "dataset"
# Training settings
batch_size = 1
epochs = 20
lr = 3e-5
gamma = 0.7
seed = 42
device = 'cuda'
train_transforms, val_transforms, test_transforms = image_augmentation()
train_list, valid_list, test_list = load_data(dataset)
train_data = CatsDogsDataset(train_list, transform=train_transforms)
valid_data = CatsDogsDataset(valid_list, transform=test_transforms)
test_data = CatsDogsDataset(test_list, transform=test_transforms)
train_loader = DataLoader(dataset=train_data, batch_size=batch_size, shuffle=True, num_workers=0)
valid_loader = DataLoader(dataset=valid_data, batch_size=batch_size, shuffle=True, num_workers=0)
test_loader = DataLoader(dataset=test_data, batch_size=batch_size, shuffle=True, num_workers=0)
# efficient_transformer = Linformer(
# dim=128,
# seq_len=49 + 1, # 7x7 patches + 1 cls-token
# depth=12,
# heads=8,
# k=64
# )
model = ViT(
image_size=224,
patch_size=32,
num_classes=1000,
dim=1024,
depth=6,
heads=16,
mlp_dim=2048,
dropout=0.1,
emb_dropout=0.1
).to(device)
# loss function
criterion = nn.CrossEntropyLoss()
# optimizer
optimizer = optim.Adam(model.parameters(), lr=lr)
# scheduler
scheduler = StepLR(optimizer, step_size=1, gamma=gamma)
for epoch in range(epochs):
epoch_loss = 0
epoch_accuracy = 0
for data, label in tqdm(train_loader):
data = data.to(device)
label = label.to(device)
output = model(data)
loss = criterion(output, label)
optimizer.zero_grad()
loss.backward()
optimizer.step()
acc = (output.argmax(dim=1) == label).float().mean()
epoch_accuracy += acc / len(train_loader)
epoch_loss += loss / len(train_loader)
with torch.no_grad():
epoch_val_accuracy = 0
epoch_val_loss = 0
for data, label in valid_loader:
data = data.to(device)
label = label.to(device)
val_output = model(data)
val_loss = criterion(val_output, label)
acc = (val_output.argmax(dim=1) == label).float().mean()
epoch_val_accuracy += acc / len(valid_loader)
epoch_val_loss += val_loss / len(valid_loader)
print(
f"Epoch : {epoch + 1} - loss : {epoch_loss:.4f} - acc: {epoch_accuracy:.4f} - val_loss : {epoch_val_loss:.4f} - val_acc: {epoch_val_accuracy:.4f}\n"
)
if __name__ == '__main__':
main()
2. 输入构成
class ViT(nn.Module):
def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, pool='cls', channels=3,
dim_head=64, dropout=0., emb_dropout=0.):
super().__init__()
assert image_size % patch_size == 0, 'Image dimensions must be divisible by the patch size.'
num_patches = (image_size // patch_size) ** 2
patch_dim = channels * patch_size ** 2 # total patch dimension
assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'
self.to_patch_embedding = nn.Sequential(
Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=patch_size, p2=patch_size),
nn.Linear(patch_dim, dim),
)
self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
self.dropout = nn.Dropout(emb_dropout)
self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)
self.pool = pool
self.to_latent = nn.Identity()
self.mlp_head = nn.Sequential(
nn.LayerNorm(dim),
nn.Linear(dim, num_classes)
)
def forward(self, img):
x = self.to_patch_embedding(img)
b, n, _ = x.shape
cls_tokens = repeat(self.cls_token, '() n d -> b n d', b=b)
x = torch.cat((cls_tokens, x), dim=1)
x += self.pos_embedding[:, :(n + 1)]
x = self.dropout(x)
x = self.transformer(x)
x = x.mean(dim=1) if self.pool == 'mean' else x[:, 0]
x = self.to_latent(x)
return self.mlp_head(x)
关于输入我们只要看如下代码即可
def forward(self, img):
x = self.to_patch_embedding(img)
b, n, _ = x.shape
cls_tokens = repeat(self.cls_token, '() n d -> b n d', b=b)
x = torch.cat((cls_tokens, x), dim=1)
x += self.pos_embedding[:, :(n + 1)]
x = self.dropout(x)
x = self.transformer(x)
x = x.mean(dim=1) if self.pool == 'mean' else x[:, 0]
x = self.to_latent(x)
return self.mlp_head(x)
首先我们先进行参数介绍
-
image_size
图片大小这里为224*224 -
patch_size
表示的是patch大小这里为 32 * 32,所以我们可以得到patch_num为7 * 7 -
num_classes
为图片类别数量 这里为 2 -
dim
表示的是序列每一个element的维度大小,这里为 1024 -
depth
表示的transformer模型的层数 -
heads
表示的是Multi-head Attention layer的head数,这里为16 -
mlp_dim
MLP层的hidden dim -
emb_dropout
对于输入做dropout
(1)确定patch size 大小
assert image_size % patch_size == 0, 'Image dimensions must be divisible by the patch size.' # 注意这里的patch size的大小必须能被图片尺寸整除
num_patches = (image_size // patch_size) ** 2 # 我们可以得到 num_patch数量为224^2//32^2 为7 * 7即49个。
(2) 确定patch dim大小
patch_dim = channels * patch_size ** 2 # 可以理解将图片所有像素重新排列等到patch的通道数即我们这一步可以得到(patch_num, patch_dim), 这里的patch_dim = 3 * 32 * 32 = 3072
(3) 使用分类方法:
assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)' # 这里我们使用 cls , 即单独用一个element来综合我们的特征信息,如果是用mean的话总是用我们的到的z(除去cls的特征信息的平局值)的平局值来综和我们的信息。如果这里不理解,后面也会介绍的。
(4) 维度转换:
self.to_patch_embedding = nn.Sequential(
Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=patch_size, p2=patch_size),
nn.Linear(patch_dim, dim),
)
# 1. 这里的b代表batch_size, c代表channel,h以及w分别代表图像的高以及宽,p1及p2代表图像横纵切分的patch_size大小,所以Rerrange的矩阵大小为[batch_size, (7, 7),(32, 32, 3)] -->[batch_size, 49, 3072]
# 2. 经过nn.linear(3072, 1024), 最后我们得到(batch_size, 49, 1024)
也有的用卷积方式
self.patch_embeddings = Conv2d(in_channels=in_channels,
out_channels=config.hidden_size,
kernel_size=patch_size,
stride=patch_size)
(5) 加入类别位置
cls_tokens = repeat(self.cls_token, '() n d -> b n d', b=b)
x = torch.cat((cls_tokens, x), dim=1)
# 这里的操作相当于给input(batch_size, 49, 1024)第二个维度在增加一个类别维度得到(batch_size, 50, 1024)
(6) 地址编码
如下图所示。我们发现,位置越接近,往往具有更相似的位置编码。此外,出现了行列结构;同一行/列中的patch具有相似的位置编码。
# 这里相当于进行位置编码
self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
x += self.pos_embedding[:, :(n + 1)]
# 注意这里的位置编码是直接相加的
(7)输入加dropout
self.dropout = nn.Dropout(emb_dropout)
x = self.dropout(x)
最终我们得到的输入尺寸为(batch_size, 50, 1024)
3. 模型构建
通过上述的输入x,将输入到我们的transformer模型里
# 1. 输入模型得到结果
self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)
x = self.transformer(x)
# 2. Transformer模型
class Transformer(nn.Module):
def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout=0.):
super().__init__()
self.layers = nn.ModuleList([])
for _ in range(depth):
self.layers.append(nn.ModuleList([
PreNorm(dim, Attention(dim, heads=heads, dim_head=dim_head, dropout=dropout)),
PreNorm(dim, FeedForward(dim, mlp_dim, dropout=dropout))
]))
def forward(self, x):
for attn, ff in self.layers:
x = attn(x) + x
x = ff(x) + x
return x
# 3. Attentation机制
class Attention(nn.Module):
def __init__(self, dim, heads=8, dim_head=64, dropout=0.):
super().__init__()
inner_dim = dim_head * heads
project_out = not (heads == 1 and dim_head == dim)
self.heads = heads
self.scale = dim_head ** -0.5
self.attend = nn.Softmax(dim=-1)
self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False)
self.to_out = nn.Sequential(
nn.Linear(inner_dim, dim),
nn.Dropout(dropout)
) if project_out else nn.Identity()
def forward(self, x):
b, n, _, h = *x.shape, self.heads
qkv = self.to_qkv(x).chunk(3, dim=-1)
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=h), qkv)
dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale
attn = self.attend(dots)
out = einsum('b h i j, b h j d -> b h i d', attn, v)
out = rearrange(out, 'b h n d -> b n (h d)')
return self.to_out(out)
# 4. layer norm
class PreNorm(nn.Module):
def __init__(self, dim, fn):
super().__init__()
self.norm = nn.LayerNorm(dim)
self.fn = fn
def forward(self, x, **kwargs):
return self.fn(self.norm(x), **kwargs)
# 5. forward 机制
class FeedForward(nn.Module):
def __init__(self, dim, hidden_dim, dropout=0.):
super().__init__()
self.net = nn.Sequential(
nn.Linear(dim, hidden_dim),
nn.GELU(),
nn.Dropout(dropout),
nn.Linear(hidden_dim, dim),
nn.Dropout(dropout)
)
def forward(self, x):
return self.net(x)
# 6 输出
x = x.mean(dim=1) if self.pool == 'mean' else x[:, 0]
return self.mlp_head(x)
下面我们分解说下
(1)输入模型得到结果
self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)
x = self.transformer(x)
(2) Transformer机制
class Transformer(nn.Module):
def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout=0.):
super().__init__()
self.layers = nn.ModuleList([])
for _ in range(depth):
self.layers.append(nn.ModuleList([
PreNorm(dim, Attention(dim, heads=heads, dim_head=dim_head, dropout=dropout)),
PreNorm(dim, FeedForward(dim, mlp_dim, dropout=dropout))
]))
def forward(self, x):
for attn, ff in self.layers:
x = attn(x) + x
x = ff(x) + x
return x
这里我们主要关注两个部分:
第一个部分:
PreNorm(dim, Attention(dim, heads=heads, dim_head=dim_head, dropout=dropout)),
PreNorm(dim, FeedForward(dim, mlp_dim, dropout=dropout))
第一个部分是attention机制, 第二个部分是forward机制。同时for循环,表示的多层机制。
第二个部分
x = attn(x) + x
x = ff(x) + x
这里使用残差的方式。
(3)Attention机制
class Attention(nn.Module):
def __init__(self, dim, heads=8, dim_head=64, dropout=0.): # dim为50, heads为16
super().__init__()
inner_dim = dim_head * heads # 这里的inner_dim 为 单个 头维度64,heads为头数量(这里为16),所以inner_dim为1024
project_out = not (heads == 1 and dim_head == dim)
self.heads = heads
self.scale = dim_head ** -0.5
self.attend = nn.Softmax(dim=-1)
self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False) # 这里的to_qkv为[50, 1024 * 3]
self.to_out = nn.Sequential(
nn.Linear(inner_dim, dim),
nn.Dropout(dropout)
) if project_out else nn.Identity()
def forward(self, x):
b, n, _, h = *x.shape, self.heads # 这里的b为batch_size, n为50, _为1024, h为16
qkv = self.to_qkv(x).chunk(3, dim=-1) # 通过同一个线性并行权重计算得到[50, 1024*3], 再通过chunk最后一个维度切分得到([50, 1024], [50, 1024], [50, 1024])的tuple形式。
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=h), qkv) # 通过一下转换分别得到q, k, v矩阵, 就是讲[batch_size, 50, 16, 64], 转换成[batch_size, 16, 50, 64]
# 下面需要根据公式进行qkv操作了,从而得到输出z。即z = softmax(Q * K^T/sqrt(d_k))V
dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale # 这里是将q与k相乘并除上sqrt(d_k)得到z_0
attn = self.attend(dots) # 对z_0进行softmax得到z_1
out = einsum('b h i j, b h j d -> b h i d', attn, v) # 这里将结果z_1与矩阵V进行相乘
out = rearrange(out, 'b h n d -> b n (h d)') # 这里对最终结果进行rerange得到shap为[batch_size, 50, 16,*64]
return self.to_out(out) # 过线性连接得到矩阵shape[1, 50, 1024]
这里是先用全连接在差分qkv矩阵
也有分别在一开始直接用全连接生成qkv矩阵
, 如下所示:
self.query = Linear(config.hidden_size, self.all_head_size)
self.key = Linear(config.hidden_size, self.all_head_size)
self.value = Linear(config.hidden_size, self.all_head_size)
self.out = Linear(config.hidden_size, config.hidden_size)
self.attn_dropout = Dropout(config.transformer["attention_dropout_rate"])
self.proj_dropout = Dropout(config.transformer["attention_dropout_rate"])
self.softmax = Softmax(dim=-1)
def transpose_for_scores(self, x):
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
x = x.view(*new_x_shape)
return x.permute(0, 2, 1, 3)
def forward(self, hidden_states):
mixed_query_layer = self.query(hidden_states)
mixed_key_layer = self.key(hidden_states)
mixed_value_layer = self.value(hidden_states)
query_layer = self.transpose_for_scores(mixed_query_layer)
key_layer = self.transpose_for_scores(mixed_key_layer)
value_layer = self.transpose_for_scores(mixed_value_layer)
(4) layer norm
# 这里相当于每一次无论是attention还是forward都是首先对输入矩阵进行layer norm
class PreNorm(nn.Module):
def __init__(self, dim, fn):
super().__init__()
self.norm = nn.LayerNorm(dim)
self.fn = fn
def forward(self, x, **kwargs):
return self.fn(self.norm(x), **kwargs)
(5)Feedforward
下面很容易理解了就是MLP的一部分了。相当于用了线性连接,这里dim=50, hidden_dim=64
, 这里使用的激活函数为GELU。最终我们得到[batch_size, 50,1024]矩阵。记住这里的多层矩阵输出仍然是[batch_size, 50, 1024]
class FeedForward(nn.Module):
def __init__(self, dim, hidden_dim, dropout=0.):
super().__init__()
self.net = nn.Sequential(
nn.Linear(dim, hidden_dim),
nn.GELU(),
nn.Dropout(dropout),
nn.Linear(hidden_dim, dim),
nn.Dropout(dropout)
)
def forward(self, x):
return self.net(x)
(6)输出
x = x.mean(dim=1) if self.pool == 'mean' else x[:, 0] # 这里x的shape为[batch_size, 50, 1024]输出
就是我之前说的如果我们采用mean, 则是在第二个维度求平均,如果不是mean其实第二个维度第一个就是特征表达。最终我们得到[1, 1024]个维度。
x = self.to_latent(x) # 相当于一个容器,把输入都保留下来了。这里我认为相当于保存特征,方便后面finetune操作。
self.mlp_head = nn.Sequential(
nn.LayerNorm(dim),
nn.Linear(dim, num_classes)
)
return self.mlp_head(x) # 多层感知机, 输出类别的概率
4. 损失函数
# loss function
criterion = nn.CrossEntropyLoss()
参考