35、AFT模块
论文《An Attention Free Transformer》
1、作用
注意力自由变换器(AFT)旨在通过去除传统Transformer中的点积自注意力机制,提供一种更高效的变换器模型。它特别适用于需要高计算效率和较低内存消耗的应用场景,如移动设备和边缘计算。
2、机制
AFT通过直接对输入特征进行变换来实现序列间的关联,不再需要复杂的自注意力计算。它使用一种简单的基于位置的加权策略,通过这种方式,每个输出元素是输入元素的加权和,权重由元素的相对位置决定。这种方法极大地降低了模型的复杂性和运行时内存需求。
3、独特优势
1、高效性:AFT由于避免了昂贵的自注意力计算,因此在执行速度和计算效率上有明显优势。
2、简化模型结构:通过消除自注意力机制,AFT简化了模型结构,使得模型更加轻量化,易于实现和部署。
3、适应性强:AFT的结构使其更容易适应于不同的任务和数据集,具有良好的泛化能力。
4、资源占用低:对于资源受限的环境,如移动设备和边缘计算设备,AFT提供了一种实用的解决方案,能够在保持较高性能的同时,降低资源消耗。
4、代码
import numpy as np
import torch
from torch import nn
from torch.nn import init
class AFT_FULL(nn.Module):
# 初始化AFT_FULL模块
def __init__(self, d_model, n=49, simple=False):
super(AFT_FULL, self).__init__()
# 定义QKV三个线性变换层
self.fc_q = nn.Linear(d_model, d_model)
self.fc_k = nn.Linear(d_model, d_model)
self.fc_v = nn.Linear(d_model, d_model)
# 根据simple参数决定位置偏置的初始化方式
if (simple):
self.position_biases = torch.zeros((n, n)) # 简单模式下为零矩阵
else:
self.position_biases = nn.Parameter(torch.ones((n, n))) # 非简单模式下为可学习的参数
self.d_model = d_model
self.n = n # 输入序列的长度
self.sigmoid = nn.Sigmoid() # 使用Sigmoid函数
self.init_weights() # 初始化模型权重
def init_weights(self):
# 对模块中的参数进行初始化
for m in self.modules():
if isinstance(m, nn.Conv2d):
init.kaiming_normal_(m.weight, mode='fan_out')
if m.bias is not None:
init.constant_(m.bias, 0)
elif isinstance(m, nn.BatchNorm2d):
init.constant_(m.weight, 1)
init.constant_(m.bias, 0)
elif isinstance(m, nn.Linear):
init.normal_(m.weight, std=0.001)
if m.bias is not None:
init.constant_(m.bias, 0)
def forward(self, input):
bs, n, dim = input.shape # 输入的批大小、序列长度和特征维度
# 通过QKV变换生成查询、键和值
q = self.fc_q(input) # bs,n,dim
k = self.fc_k(input).view(1, bs, n, dim) # 1,bs,n,dim,为了后续运算方便
v = self.fc_v(input).view(1, bs, n, dim) # 1,bs,n,dim
# 使用位置偏置和键值对进行加权求和
numerator = torch.sum(torch.exp(k + self.position_biases.view(n, 1, -1, 1)) * v, dim=2) # n,bs,dim
denominator = torch.sum(torch.exp(k + self.position_biases.view(n, 1, -1, 1)), dim=2) # n,bs,dim
# 计算加权求和的结果,并通过sigmoid函数调制查询向量
out = (numerator / denominator) # n,bs,dim
out = self.sigmoid(q) * (out.permute(1, 0, 2)) # bs,n,dim,最后将结果重新排列
return out
# 示例使用
if __name__ == '__main__':
block = AFT_FULL(d_model=512, n=64).cuda()
input = torch.rand(64, 64, 512).cuda()
output = block(input)
print(output.shape) # 打印输出形状