深度学习模块36-AFT模块

35、AFT模块

论文《An Attention Free Transformer》

1、作用

注意力自由变换器(AFT)旨在通过去除传统Transformer中的点积自注意力机制,提供一种更高效的变换器模型。它特别适用于需要高计算效率和较低内存消耗的应用场景,如移动设备和边缘计算。

2、机制

AFT通过直接对输入特征进行变换来实现序列间的关联,不再需要复杂的自注意力计算。它使用一种简单的基于位置的加权策略,通过这种方式,每个输出元素是输入元素的加权和,权重由元素的相对位置决定。这种方法极大地降低了模型的复杂性和运行时内存需求。

3、独特优势

1、高效性:AFT由于避免了昂贵的自注意力计算,因此在执行速度和计算效率上有明显优势。

2、简化模型结构:通过消除自注意力机制,AFT简化了模型结构,使得模型更加轻量化,易于实现和部署。

3、适应性强:AFT的结构使其更容易适应于不同的任务和数据集,具有良好的泛化能力。

4、资源占用低:对于资源受限的环境,如移动设备和边缘计算设备,AFT提供了一种实用的解决方案,能够在保持较高性能的同时,降低资源消耗。

4、代码

import numpy as np
import torch
from torch import nn
from torch.nn import init

class AFT_FULL(nn.Module):
    # 初始化AFT_FULL模块
    def __init__(self, d_model, n=49, simple=False):
        super(AFT_FULL, self).__init__()
        # 定义QKV三个线性变换层
        self.fc_q = nn.Linear(d_model, d_model)
        self.fc_k = nn.Linear(d_model, d_model)
        self.fc_v = nn.Linear(d_model, d_model)
        # 根据simple参数决定位置偏置的初始化方式
        if (simple):
            self.position_biases = torch.zeros((n, n))  # 简单模式下为零矩阵
        else:
            self.position_biases = nn.Parameter(torch.ones((n, n)))  # 非简单模式下为可学习的参数
        self.d_model = d_model
        self.n = n  # 输入序列的长度
        self.sigmoid = nn.Sigmoid()  # 使用Sigmoid函数

        self.init_weights()  # 初始化模型权重

    def init_weights(self):
        # 对模块中的参数进行初始化
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                init.kaiming_normal_(m.weight, mode='fan_out')
                if m.bias is not None:
                    init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                init.constant_(m.weight, 1)
                init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                init.normal_(m.weight, std=0.001)
                if m.bias is not None:
                    init.constant_(m.bias, 0)

    def forward(self, input):
        bs, n, dim = input.shape  # 输入的批大小、序列长度和特征维度

        # 通过QKV变换生成查询、键和值
        q = self.fc_q(input)  # bs,n,dim
        k = self.fc_k(input).view(1, bs, n, dim)  # 1,bs,n,dim,为了后续运算方便
        v = self.fc_v(input).view(1, bs, n, dim)  # 1,bs,n,dim

        # 使用位置偏置和键值对进行加权求和
        numerator = torch.sum(torch.exp(k + self.position_biases.view(n, 1, -1, 1)) * v, dim=2)  # n,bs,dim
        denominator = torch.sum(torch.exp(k + self.position_biases.view(n, 1, -1, 1)), dim=2)  # n,bs,dim

        # 计算加权求和的结果,并通过sigmoid函数调制查询向量
        out = (numerator / denominator)  # n,bs,dim
        out = self.sigmoid(q) * (out.permute(1, 0, 2))  # bs,n,dim,最后将结果重新排列

        return out

# 示例使用
if __name__ == '__main__':
    block = AFT_FULL(d_model=512, n=64).cuda()
    input = torch.rand(64, 64, 512).cuda()
    output = block(input)
    print(output.shape) # 打印输出形状

©著作权归作者所有,转载或内容合作请联系作者
【社区内容提示】社区部分内容疑似由AI辅助生成,浏览时请结合常识与多方信息审慎甄别。
平台声明:文章内容(如有图片或视频亦包括在内)由作者上传并发布,文章内容仅代表作者本人观点,简书系信息发布平台,仅提供信息存储服务。

相关阅读更多精彩内容

友情链接更多精彩内容