论文《AXIAL ATTENTION IN MULTIDIMENSIONAL TRANSFORMERS》
1、作用
该论文提出了 Axial Transformer——一种基于自注意力的自回归模型,适用于图像以及其他可组织为高维张量的数据。现有的自回归模型在处理高维数据时,要么计算资源需求过大,要么为降低资源需求而牺牲分布的表达能力或实现的简便性。Axial Transformer 的设计目标是:在完整保留对数据联合分布的表达能力、并能用标准深度学习框架轻松实现的同时,只需合理的内存和计算资源,并在标准生成建模基准上取得最先进的结果。
2、机制
1、轴向注意力:
Axial Transformer 不将张量展平后对整个元素序列应用标准自注意力,而是每次只沿张量的单个轴应用注意力,称为“轴向注意力”。由于这种操作天然地与张量的多个维度对齐,它在计算量和内存占用上都比标准自注意力节省显著。
2、半并行结构:
Axial Transformer 的层结构使得解码时绝大部分上下文可以并行计算,且无需引入任何独立性假设;即使对于非常大的 Axial Transformer,这一结构也同样适用。
3、独特优势
1、计算效率:
Axial Transformer 通过轴向注意力在资源使用上实现了显著节省:对于总元素数为 $N$、形状为 $N^{1/d} \times \cdots \times N^{1/d}$ 的 $d$ 维张量,标准自注意力的开销为 $O(N^2)$,而轴向注意力只需 $O(N \cdot N^{1/d})$,即节省了 $O(N^{(d-1)/d})$ 倍。
2、完全表达性:
尽管 Axial Transformer 每次只沿单个轴应用注意力,但其结构设计确保模型能够表达数据的全局依赖关系,不会丢失对任何先前像素的依赖。
3、简单易实现:
Axial Transformer 不需要为 GPU 或 TPU 编写专门的底层内核,只需使用深度学习框架中广泛可用的高效操作(主要是稠密矩阵乘法 MatMul)即可简单实现。
4、代码
import torch
from torch import nn
from operator import itemgetter
from torch.autograd.function import Function
from torch.utils.checkpoint import get_device_states, set_device_states
class Deterministic(nn.Module):
    """Wrap a module so that one forward pass can be replayed under the same RNG state.

    A call with ``record_rng=True`` snapshots the CPU RNG state (and, when CUDA
    has been initialised, the GPU RNG states of the devices holding the tensor
    arguments). A later call with ``set_rng=True`` runs the wrapped module inside
    a forked RNG context restored to that snapshot. Stochastic layers such as
    dropout therefore draw the same random numbers on the replay, and the
    caller's global RNG state is left untouched.

    Args:
        net: the module (or callable) to wrap.
    """

    def __init__(self, net):
        super().__init__()
        self.net = net
        # Snapshots filled in by record_rng(); None until the first recording.
        self.cpu_state = None
        self.cuda_in_fwd = None
        self.gpu_devices = None
        self.gpu_states = None

    def record_rng(self, *args):
        """Snapshot the CPU RNG state, plus GPU states for *args*' devices if CUDA is initialised."""
        self.cpu_state = torch.get_rng_state()
        if not torch.cuda._initialized:
            return
        self.cuda_in_fwd = True
        self.gpu_devices, self.gpu_states = get_device_states(*args)

    def forward(self, *args, record_rng=False, set_rng=False, **kwargs):
        """Call the wrapped module, optionally recording or replaying RNG state.

        Args:
            *args, **kwargs: forwarded to the wrapped module.
            record_rng: snapshot RNG state before running.
            set_rng: run under the previously recorded RNG state.
        """
        if record_rng:
            self.record_rng(*args)

        if not set_rng:
            return self.net(*args, **kwargs)

        devices = self.gpu_devices if self.cuda_in_fwd else []
        # fork_rng restores the outer RNG state on exit, so the replay has no lasting effect.
        with torch.random.fork_rng(devices=devices, enabled=True):
            torch.set_rng_state(self.cpu_state)
            if self.cuda_in_fwd:
                set_device_states(self.gpu_devices, self.gpu_states)
            return self.net(*args, **kwargs)
class ReversibleBlock(nn.Module):
    """One reversible residual block (RevNet-style additive coupling).

    The input is split in two halves along dim 1 and transformed as::

        y1 = x1 + f(x2)
        y2 = x2 + g(y1)

    Because ``(x1, x2)`` can be reconstructed exactly from ``(y1, y2)``, no
    activations are stored. ``backward_pass`` recomputes them while
    propagating gradients.

    Args:
        f, g: the two residual functions. Each is wrapped in ``Deterministic`` so
            its RNG state can be replayed during the backward recomputation.
    """

    def __init__(self, f, g):
        super().__init__()
        self.f = Deterministic(f)
        self.g = Deterministic(g)

    def forward(self, x, f_args=None, g_args=None):
        """Apply the coupling without building an autograd graph.

        Args:
            x: input whose dim 1 holds the two concatenated halves.
            f_args, g_args: optional extra keyword arguments for ``f`` / ``g``.

        Returns:
            ``torch.cat([y1, y2], dim=1)``.
        """
        # None instead of a shared mutable {} default.
        f_args = {} if f_args is None else f_args
        g_args = {} if g_args is None else g_args

        x1, x2 = torch.chunk(x, 2, dim=1)
        # Gradients are computed manually in backward_pass, so no graph is kept here.
        # RNG state is recorded only in training mode, where a backward pass will replay it.
        with torch.no_grad():
            y1 = x1 + self.f(x2, record_rng=self.training, **f_args)
            y2 = x2 + self.g(y1, record_rng=self.training, **g_args)
        return torch.cat([y1, y2], dim=1)

    def backward_pass(self, y, dy, f_args=None, g_args=None):
        """Reconstruct this block's input from its output and back-propagate gradients.

        Parameter gradients of ``f`` and ``g`` are accumulated as a side effect of
        the ``torch.autograd.backward`` calls.

        Args:
            y: this block's output (``y1`` and ``y2`` concatenated along dim 1).
            dy: gradient of the loss w.r.t. ``y``.
            f_args, g_args: optional extra keyword arguments for ``f`` / ``g``.

        Returns:
            ``(x, dx)``: the reconstructed (detached) block input and the gradient
            w.r.t. it, ready to be passed to the preceding block.
        """
        f_args = {} if f_args is None else f_args
        g_args = {} if g_args is None else g_args

        y1, y2 = torch.chunk(y, 2, dim=1)
        del y
        dy1, dy2 = torch.chunk(dy, 2, dim=1)
        del dy

        # Recompute g(y1) under the recorded RNG state. Backpropagating dy2 through it
        # leaves g's contribution to dL/dy1 in y1.grad.
        with torch.enable_grad():
            y1.requires_grad = True
            gy1 = self.g(y1, set_rng=True, **g_args)
            torch.autograd.backward(gy1, dy2)

        with torch.no_grad():
            x2 = y2 - gy1  # invert y2 = x2 + g(y1)
            del y2, gy1
            dx1 = dy1 + y1.grad  # total dL/dy1, which equals dL/dx1 since y1 = x1 + f(x2)
            del dy1
            y1.grad = None

        # Recompute f(x2). Backpropagating dx1 through it leaves f's contribution
        # to dL/dx2 in x2.grad.
        with torch.enable_grad():
            x2.requires_grad = True
            fx2 = self.f(x2, set_rng=True, **f_args)
            torch.autograd.backward(fx2, dx1, retain_graph=True)

        with torch.no_grad():
            x1 = y1 - fx2  # invert y1 = x1 + f(x2)
            del y1, fx2
            dx2 = dy2 + x2.grad
            del dy2
            x2.grad = None
            x = torch.cat([x1, x2.detach()], dim=1)
            dx = torch.cat([dx1, dx2], dim=1)

        return x, dx
class IrreversibleBlock(nn.Module):
    """Same additive coupling as a reversible block, but computed with ordinary autograd.

    Computes ``y1 = x1 + f(x2)`` and ``y2 = x2 + g(y1)`` on the two halves of
    dim 1. Intermediate activations are stored normally, with no RNG replay or
    manual backward.
    """

    def __init__(self, f, g):
        super().__init__()
        # f and g are stored as-is (no Deterministic wrapper).
        self.f = f
        self.g = g

    def forward(self, x, f_args, g_args):
        """Apply the coupling; *f_args* / *g_args* are keyword dicts for f / g."""
        left, right = torch.chunk(x, 2, dim=1)
        new_left = left + self.f(right, **f_args)
        new_right = right + self.g(new_left, **g_args)
        return torch.cat([new_left, new_right], dim=1)
# 可逆函数实现,用于在可逆网络中应用自定义的可逆操作
class _ReversibleFunction(Function):
@staticmethod
def forward(ctx, x, blocks, kwargs):
ctx.kwargs = kwargs
for block in blocks:
x = block(x, **kwargs)
ctx.y = x.detach()
ctx.blocks = blocks
return x
@staticmethod
def backward(ctx, dy):
y = ctx.y
kwargs = ctx.kwargs
for block in ctx.blocks[::-1]:
y, dy = block.backward_pass(y, dy, **kwargs)
return dy, None, None
class ReversibleSequence(nn.Module):
    """Chain ReversibleBlocks into one reversible network.

    Args:
        blocks: iterable of ``(f, g)`` pairs; each pair becomes one ReversibleBlock.
    """

    def __init__(self, blocks):
        super().__init__()
        self.blocks = nn.ModuleList(ReversibleBlock(f, g) for f, g in blocks)

    def forward(self, x, arg_route=(True, True), **kwargs):
        """Run *x* through the reversible chain.

        The input is duplicated along dim 1 so both coupling halves start from *x*.
        The two output halves are averaged at the end.

        Args:
            x: input tensor; its dim 1 is duplicated and later averaged.
            arg_route: ``(to_f, to_g)`` flags; when true, ``**kwargs`` is passed to f / g.
            **kwargs: extra keyword arguments for the block functions.
        """
        route_f, route_g = arg_route
        block_kwargs = {
            'f_args': kwargs if route_f else {},
            'g_args': kwargs if route_g else {},
        }
        doubled = torch.cat((x, x), dim=1)
        out = _ReversibleFunction.apply(doubled, self.blocks, block_kwargs)
        return torch.stack(out.chunk(2, dim=1)).mean(dim=0)
def exists(val):
    """Return True if *val* is anything other than None (falsy values count as existing)."""
    if val is None:
        return False
    return True
def map_el_ind(arr, ind):
    """Return a list holding ``element[ind]`` for every element of *arr*."""
    return [element[ind] for element in arr]
def sort_and_return_indices(arr):
    """Sort *arr* and report where each sorted element originally came from.

    Returns:
        ``(sorted_values, original_indices)`` as two lists. Ties are ordered by
        original index. For a permutation, ``original_indices`` is its inverse.
    """
    pairs = sorted(zip(arr, range(len(arr))))
    return map_el_ind(pairs, 0), map_el_ind(pairs, 1)
def calculate_permutations(num_dimensions, emb_dim):
    """Build one axis permutation per axial dimension.

    The tensor has ``num_dimensions + 2`` dims. Dim 0 is never treated as
    axial, and ``emb_dim`` is the embedding dim. Each returned permutation
    moves one axial dim and then the embedding dim to the last two positions.
    The remaining dims are kept in front in ascending order.

    Args:
        num_dimensions: number of axial dimensions.
        emb_dim: index of the embedding dim; negative values count from the end.

    Returns:
        A list with one permutation (list of ints) per axial dimension.

    Raises:
        ValueError: if ``emb_dim`` does not refer to a dimension in ``[1, num_dimensions + 2)``.
    """
    total_dimensions = num_dimensions + 2
    normalized = emb_dim if emb_dim > 0 else (emb_dim + total_dimensions)
    # Dim 0 or an out-of-range index used to yield silently invalid permutations.
    if not 1 <= normalized < total_dimensions:
        raise ValueError(
            f'emb_dim must index a dimension in [1, {total_dimensions}) '
            f'(negative values count from the end), got {emb_dim}'
        )

    axial_dims = [ind for ind in range(1, total_dimensions) if ind != normalized]

    permutations = []
    for axial_dim in axial_dims:
        last_two_dims = [axial_dim, normalized]
        # sorted() makes the leading-dim order explicit instead of relying on set iteration order.
        dims_rest = sorted(set(range(total_dimensions)) - set(last_two_dims))
        permutations.append([*dims_rest, *last_two_dims])
    return permutations
class ChanLayerNorm(nn.Module):
    """Layer normalisation across dim 1 (channels), with a learnable per-channel gain and bias.

    The gain ``g`` and bias ``b`` have shape ``(1, dim, 1, 1)``, so the input is
    expected to be 4-D with channels on dim 1.

    Args:
        dim: number of channels (size of dim 1).
        eps: stability constant added to the variance before the square root.
    """

    def __init__(self, dim, eps=1e-5):
        super().__init__()
        self.eps = eps
        self.g = nn.Parameter(torch.ones(1, dim, 1, 1))
        self.b = nn.Parameter(torch.zeros(1, dim, 1, 1))

    def forward(self, x):
        """Normalise *x* over dim 1, then scale by ``g`` and shift by ``b``."""
        var = torch.var(x, dim=1, unbiased=False, keepdim=True)
        mean = torch.mean(x, dim=1, keepdim=True)
        # eps goes inside the square root (as in nn.LayerNorm). sqrt(var) has an
        # infinite derivative at var == 0, which produced NaN gradients whenever all
        # channels at a position were equal.
        return (x - mean) * torch.rsqrt(var + self.eps) * self.g + self.b
class PreNorm(nn.Module):
    """Apply LayerNorm over the last dimension, then the wrapped function.

    Args:
        dim: size of the last dimension that is normalised.
        fn: callable applied to the normalised tensor.
    """

    def __init__(self, dim, fn):
        super().__init__()
        self.fn = fn
        self.norm = nn.LayerNorm(dim)

    def forward(self, x):
        return self.fn(self.norm(x))
class Sequential(nn.Module):
    """Non-reversible counterpart of ReversibleSequence: plain residual updates.

    Args:
        blocks: iterable of ``(f, g)`` pairs. For each pair the input is updated
            as ``x = x + f(x)`` and then ``x = x + g(x)``.
    """

    def __init__(self, blocks):
        super().__init__()
        self.blocks = blocks

    def forward(self, x):
        out = x
        for f, g in self.blocks:
            out = out + f(out)
            out = out + g(out)
        return out
class PermuteToFrom(nn.Module):
    """Run a sequence function along one axis of a multi-dimensional tensor.

    The input is permuted by ``permutation`` and all but the last two dims are
    flattened into one leading dim. ``fn`` is applied to the resulting
    ``(N, t, d)`` tensor, and the result is reshaped and permuted back.

    Args:
        permutation: axis order that moves the target axis and the feature dim last.
        fn: callable mapping ``(N, t, d)`` to a tensor of the same shape.
    """

    def __init__(self, permutation, fn):
        super().__init__()
        self.fn = fn
        _, inv_permutation = sort_and_return_indices(permutation)
        self.permutation = permutation
        self.inv_permutation = inv_permutation

    def forward(self, x, **kwargs):
        permuted = x.permute(*self.permutation).contiguous()
        full_shape = permuted.shape
        *_, length, channels = full_shape
        flat = self.fn(permuted.reshape(-1, length, channels), **kwargs)
        restored = flat.reshape(*full_shape)
        return restored.permute(*self.inv_permutation).contiguous()
class AxialPositionalEmbedding(nn.Module):
    """Learned positional embedding made of one broadcastable parameter per axis.

    For each axial dimension ``i`` a parameter ``param_i`` is created. It has
    size ``dim`` on the embedding axis, the axis length on its own axis, and
    1 everywhere else. ``forward`` adds all of them to the input by broadcasting.

    Args:
        dim: embedding size.
        shape: lengths of the axial dimensions, in order.
        emb_dim_index: index of the embedding dim in the full
            ``len(shape) + 2``-dim tensor; negative values count from the end.
    """

    def __init__(self, dim, shape, emb_dim_index=1):
        super().__init__()
        total_dimensions = len(shape) + 2
        # Normalise negative indices, as calculate_permutations does. Previously
        # e.g. -2 was never excluded from the axial indices and overwrote an axial size.
        if emb_dim_index < 0:
            emb_dim_index += total_dimensions
        ax_dim_indexes = [i for i in range(1, total_dimensions) if i != emb_dim_index]

        self.num_axials = len(shape)

        for i, (axial_dim, axial_dim_index) in enumerate(zip(shape, ax_dim_indexes)):
            param_shape = [1] * total_dimensions
            param_shape[emb_dim_index] = dim
            param_shape[axial_dim_index] = axial_dim
            setattr(self, f'param_{i}', nn.Parameter(torch.randn(*param_shape)))

    def forward(self, x):
        """Return *x* plus every per-axis embedding (broadcast addition)."""
        for i in range(self.num_axials):
            x = x + getattr(self, f'param_{i}')
        return x
class SelfAttention(nn.Module):
    """Multi-head scaled dot-product attention over ``(batch, length, dim)`` inputs.

    Args:
        dim: input/output feature size.
        heads: number of attention heads.
        dim_heads: per-head size; defaults to ``dim // heads``.
    """

    def __init__(self, dim, heads, dim_heads=None):
        super().__init__()
        self.dim_heads = dim_heads if dim_heads is not None else dim // heads
        dim_hidden = self.dim_heads * heads

        self.heads = heads
        self.to_q = nn.Linear(dim, dim_hidden, bias=False)
        self.to_kv = nn.Linear(dim, 2 * dim_hidden, bias=False)
        self.to_out = nn.Linear(dim_hidden, dim)

    def forward(self, x, kv=None):
        """Attend from *x* to *kv* (self-attention when *kv* is None)."""
        if kv is None:
            kv = x
        q = self.to_q(x)
        k, v = self.to_kv(kv).chunk(2, dim=-1)

        b, _, d = q.shape
        h, e = self.heads, self.dim_heads

        def split_heads(t):
            # (b, n, h*e) -> (b*h, n, e)
            return t.reshape(b, -1, h, e).transpose(1, 2).reshape(b * h, -1, e)

        q, k, v = split_heads(q), split_heads(k), split_heads(v)

        scores = torch.einsum('bie,bje->bij', q, k) * (e ** -0.5)
        weights = scores.softmax(dim=-1)
        out = torch.einsum('bij,bje->bie', weights, v)

        # (b*h, n, e) -> (b, n, h*e)
        out = out.reshape(b, h, -1, e).transpose(1, 2).reshape(b, -1, d)
        return self.to_out(out)
class AxialAttention(nn.Module):
    """Self-attention applied separately along each axial dimension.

    Args:
        dim: feature size, found at ``dim_index`` of the input.
        num_dimensions: number of axial dims; the input has ``num_dimensions + 2`` dims.
        heads, dim_heads: forwarded to each SelfAttention.
        dim_index: index of the feature dim (negative counts from the end).
        sum_axial_out: if True, sum the outputs of the per-axis attentions applied
            to the same input; otherwise apply them one after another.
    """

    def __init__(self, dim, num_dimensions=2, heads=8, dim_heads=None, dim_index=-1, sum_axial_out=True):
        assert (dim % heads) == 0, 'hidden dimension must be divisible by number of heads'
        super().__init__()
        self.dim = dim
        self.total_dimensions = num_dimensions + 2
        self.dim_index = dim_index if dim_index > 0 else (dim_index + self.total_dimensions)

        self.axial_attentions = nn.ModuleList(
            PermuteToFrom(permutation, SelfAttention(dim, heads, dim_heads))
            for permutation in calculate_permutations(num_dimensions, dim_index)
        )
        self.sum_axial_out = sum_axial_out

    def forward(self, x):
        assert len(x.shape) == self.total_dimensions, 'input tensor does not have the correct number of dimensions'
        assert x.shape[self.dim_index] == self.dim, 'input tensor does not have the correct input dimension'

        if not self.sum_axial_out:
            out = x
            for attn in self.axial_attentions:
                out = attn(out)
            return out

        return sum(attn(x) for attn in self.axial_attentions)
class AxialImageTransformer(nn.Module):
    """Stack of axial-attention and convolutional feed-forward layers for image tensors.

    Each of the ``depth`` repetitions contributes two ``(f, g)`` pairs:

    - attention along the first spatial axis, then along the second (each
      pre-normed with LayerNorm);
    - two convolutional feed-forward sub-blocks.

    The pairs run either reversibly (ReversibleSequence) or as plain residuals
    (Sequential).

    The feed-forward sub-blocks use Conv2d and ChanLayerNorm on dim 1, so the
    input is presumably ``(batch, dim, height, width)``.
    NOTE(review): dim_index values other than 1 look unsupported by the conv path; confirm.

    Args:
        dim: channel/feature size.
        depth: number of repetitions.
        heads, dim_heads: attention head configuration.
        dim_index: index of the channel dim.
        reversible: use reversible residual execution when True.
        axial_pos_emb_shape: spatial shape for a learned axial positional
            embedding; no embedding is added when None.
    """

    def __init__(self, dim, depth, heads=8, dim_heads=None, dim_index=1, reversible=True, axial_pos_emb_shape=None):
        super().__init__()
        permutations = calculate_permutations(2, dim_index)

        def make_ff():
            return nn.Sequential(
                ChanLayerNorm(dim),
                nn.Conv2d(dim, dim * 4, 3, padding=1),
                nn.LeakyReLU(inplace=True),
                nn.Conv2d(dim * 4, dim, 3, padding=1),
            )

        if exists(axial_pos_emb_shape):
            self.pos_emb = AxialPositionalEmbedding(dim, axial_pos_emb_shape, dim_index)
        else:
            self.pos_emb = nn.Identity()

        layers = nn.ModuleList([])
        for _ in range(depth):
            # (f, g) = (attention along one spatial axis, attention along the other)
            layers.append(nn.ModuleList([
                PermuteToFrom(perm, PreNorm(dim, SelfAttention(dim, heads, dim_heads)))
                for perm in permutations
            ]))
            # (f, g) = (feed-forward, feed-forward)
            layers.append(nn.ModuleList([make_ff(), make_ff()]))

        wrapper = ReversibleSequence if reversible else Sequential
        self.layers = wrapper(layers)

    def forward(self, x):
        return self.layers(self.pos_emb(x))
if __name__ == '__main__':
    # Smoke test: build a model and print the output shape.
    # Fall back to CPU instead of crashing on .cuda() when no GPU is present.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    block = AxialImageTransformer(
        dim=64,
        depth=12,
        reversible=True,
    ).to(device)
    # (batch, channels, height, width); avoids shadowing the builtin `input`.
    x = torch.rand(1, 64, 64, 64, device=device)
    output = block(x)
    print(output.shape)