Transformer teaching code, excerpted from the d2l library

  • tf_learn.py
import torch
import math 
import torch.nn as nn

def sequence_mask(X, valid_len, value=-1e9):
    """根据valid_len将X中的非关联元素设为value。
    args:
        X: torch.Tensor, 输入的张量,形状为(batch_size * Q_timesteps, K_timesteps)
        valid_len: torch.Tensor, 有效长度,形状为(batch_size*time_steps,)
        value: float, 要替换的值, 默认为-1e9, 用于softmax操作, 使得非关联元素接近0
    return:
        torch.Tensor, 返回更新后的X
    """
    maxlen = X.size(1)
    mask = torch.arange((maxlen), dtype=torch.float32, device=X.device)[None, :]
    mask = mask < valid_len[:, None]
    X[~mask] = value
    return X
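
# Usage sketch (added for illustration with toy shapes; not part of the original d2l code):
# X = torch.arange(8, dtype=torch.float32).reshape(2, 4)
# print(sequence_mask(X, torch.tensor([1, 3])))
# # row 0 keeps only X[0, 0], row 1 keeps X[1, :3]; all other entries become -1e9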

def masked_softmax(X, valid_lens):
    """这个函数在张量`X`的最后一个轴上执行softmax操作,但在此之前,它会根据`valid_lens`屏蔽某些元素。
    args:
        X: torch.Tensor, 输入张量,形状为(batch_size, Q_timesteps, K_timesteps)
        valid_lens: torch.Tensor, 有效长度,形状为(batch_size,) 即直接表示每个输入序列上的有效长度,或者 (batch_size, num_steps)即每个输入序列的每个时间点的有效长度
    return:
        torch.Tensor, 返回softmax操作后的张量
    """
    # `X`: 3D张量, `valid_lens`: 1D或2D张量
    # 如果`valid_lens`为None,表示无需屏蔽任何元素,我们只需在`X`的最后一个轴上执行常规的softmax操作。
    if valid_lens is None:
        return nn.functional.softmax(X, dim=-1)
    else:
        # 存储`X`的形状以备后用。
        shape = X.shape
        # 如果`valid_lens`是1D张量,我们将`valid_lens`中的每个元素重复`shape[1]`次。
        # 这是因为我们想创建一个与`X`的第二个维度匹配的掩码。
        if valid_lens.dim() == 1:
            # 如果`valid_lens`是1D张量,我们将`valid_lens`中的每个元素重复`time_steps`次,形成1D张量。
            valid_lens = torch.repeat_interleave(valid_lens, shape[1])
        else:
            # 如果`valid_lens`不是1D张量,我们将其展平并转换为1D张量。
            valid_lens = valid_lens.reshape(-1)
        # 我们将`X`重塑为2D (-1, shape[-1]),并应用`sequence_mask`函数。
        # 这个函数将替换被掩码的元素(不在有效长度内的元素)为一个非常大的负值(-1e9)。
        # 当我们稍后应用softmax函数时,这些大的负值将变为0,有效地“屏蔽”这些元素。
        X = sequence_mask(X.reshape(-1, shape[-1]), valid_lens,
                              value=-1e9)
        # 最后,我们将`X`重塑回原来的形状,并在最后一个轴上执行softmax操作。
        # 结果是一个与`X`形状相同的张量,但最后一个轴上的某些元素被屏蔽(设为0)。
        return nn.functional.softmax(X.reshape(shape), dim=-1)

# X = torch.arange(24).reshape(2, 3, 4).type(torch.float32)
# valid_lens = torch.tensor([3, 2])
# ms = masked_softmax(X, valid_lens)
# print(ms)

class DotProductAttention(nn.Module):
    """
    这是一个实现了缩放点积注意力机制的类。
    """
    def __init__(self, dropout, **kwargs):
        super(DotProductAttention, self).__init__(**kwargs)
        self.dropout = nn.Dropout(dropout)

    def forward(self, queries, keys, values, valid_lens=None):
        """前向传播函数,接收查询、键、值和有效长度作为输入。

        Args:
            queries: torch.Tensor, 查询张量,形状为(batch_size, num_Q_timesteps, d_model)
            keys: torch.Tensor, 键张量,形状为(batch_size, num_K_timesteps, d_model)
            values: torch.Tensor, 值张量,形状为(batch_size, num_V_timesteps, d_model)
            valid_lens: torch.Tensor, 有效长度,形状为(batch_size,) 或 (batch_size, num_Q_timesteps)
        Returns:
            torch.Tensor, 返回输出张量,形状为(batch_size, num_Q_timesteps, d_model)
        """
        # 获取查询的最后一个维度的大小,即d_model。
        d = queries.shape[-1]
        # 计算查询和键的点积,然后除以sqrt(d)进行缩放,得到得分。
        # 使用`transpose`函数交换键的最后两个维度,以便进行矩阵乘法。
        scores = torch.bmm(queries, keys.transpose(1,2)) / math.sqrt(d)
        # 使用`masked_softmax`函数对得分进行softmax操作并生成掩码,得到注意力权重。
        self.attention_weights = masked_softmax(scores, valid_lens)
        # 将注意力权重应用到值上,得到输出。在应用注意力权重之前,先对其进行dropout操作。
        return torch.bmm(self.dropout(self.attention_weights), values)
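
# Usage sketch (added for illustration with toy shapes; not part of the original d2l code):
# attention = DotProductAttention(dropout=0.0)
# attention.eval()
# Q, K, V = torch.randn(2, 3, 8), torch.randn(2, 4, 8), torch.randn(2, 4, 8)
# out = attention(Q, K, V, valid_lens=torch.tensor([2, 4]))
# print(out.shape, attention.attention_weights.shape)  # (2, 3, 8) and (2, 3, 4)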

def transpose_qkv(X, num_heads):
    """Transposition for parallel computation of multiple attention heads.

    Defined in :numref:`sec_multihead-attention`"""
    # Shape of input `X`:
    # (`batch_size`, no. of queries or key-value pairs, `d_model`).
    # Shape of output `X`:
    # (`batch_size`, no. of queries or key-value pairs, `num_heads`,
    # `d_model` / `num_heads`)
    X = X.reshape(X.shape[0], X.shape[1], num_heads, -1)

    # Shape of output `X`:
    # (`batch_size`, `num_heads`, no. of queries or key-value pairs,
    # `d_model` / `num_heads`)
    X = X.permute(0, 2, 1, 3)

    # Shape of `output`:
    # (`batch_size` * `num_heads`, no. of queries or key-value pairs,
    # `d_model` / `num_heads`)
    return X.reshape(-1, X.shape[2], X.shape[3])

def transpose_output(X, num_heads):
    """Reverse the operation of `transpose_qkv`.

    Defined in :numref:`sec_multihead-attention`"""
    X = X.reshape(-1, num_heads, X.shape[1], X.shape[2])
    X = X.permute(0, 2, 1, 3)
    return X.reshape(X.shape[0], X.shape[1], -1)
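
# Shape sketch (added for illustration; not part of the original d2l code):
# transpose_qkv splits d_model into `num_heads` chunks and transpose_output undoes it exactly.
# X = torch.randn(2, 5, 8)
# Xh = transpose_qkv(X, num_heads=2)        # shape (2 * 2, 5, 4)
# Xr = transpose_output(Xh, num_heads=2)    # shape (2, 5, 8)
# print(Xh.shape, Xr.shape, torch.equal(X, Xr))  # ... True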

class MultiHeadAttention(nn.Module):
    """Multi-head attention.

    Defined in :numref:`sec_multihead-attention`"""
    def __init__(self, key_size, query_size, value_size, d_model,
                 num_heads, dropout, bias=False, **kwargs):
        super(MultiHeadAttention, self).__init__(**kwargs)
        self.num_heads = num_heads
        self.attention = DotProductAttention(dropout)
        self.W_q = nn.Linear(query_size, d_model, bias=bias)
        self.W_k = nn.Linear(key_size, d_model, bias=bias)
        self.W_v = nn.Linear(value_size, d_model, bias=bias)
        self.W_o = nn.Linear(d_model, d_model, bias=bias)

    def forward(self, queries, keys, values, valid_lens):
        # Shape of `queries`, `keys`, or `values`:
        # (`batch_size`, no. of queries or key-value pairs, `d_model`)
        # Shape of `valid_lens`:
        # (`batch_size`,) or (`batch_size`, no. of queries)
        # After transposing, shape of output `queries`, `keys`, or `values`:
        # (`batch_size` * `num_heads`, no. of queries or key-value pairs,
        # `d_model` / `num_heads`)
        queries = transpose_qkv(self.W_q(queries), self.num_heads)
        keys = transpose_qkv(self.W_k(keys), self.num_heads)
        values = transpose_qkv(self.W_v(values), self.num_heads)

        if valid_lens is not None:
            # On axis 0, copy the first item (scalar or vector) for
            # `num_heads` times, then copy the next item, and so on
            valid_lens = torch.repeat_interleave(
                valid_lens, repeats=self.num_heads, dim=0)

        # Shape of `output`: (`batch_size` * `num_heads`, no. of queries,
        # `d_model` / `num_heads`)
        output = self.attention(queries, keys, values, valid_lens)

        # Shape of `output_concat`:
        # (`batch_size`, no. of queries, `d_model`)
        output_concat = transpose_output(output, self.num_heads)
        return self.W_o(output_concat)
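
# Usage sketch (added for illustration with toy shapes; not part of the original d2l code):
# mha = MultiHeadAttention(24, 24, 24, 24, num_heads=8, dropout=0.0)
# mha.eval()
# X = torch.randn(2, 6, 24)
# print(mha(X, X, X, torch.tensor([4, 6])).shape)  # torch.Size([2, 6, 24])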

class PositionWiseFFN(nn.Module):
    """Positionwise feed-forward network.

    Defined in :numref:`sec_transformer`"""
    def __init__(self, ffn_num_input, ffn_num_hiddens, ffn_num_outputs,
                 **kwargs):
        super(PositionWiseFFN, self).__init__(**kwargs)
        self.dense1 = nn.Linear(ffn_num_input, ffn_num_hiddens)
        self.relu = nn.ReLU()
        self.dense2 = nn.Linear(ffn_num_hiddens, ffn_num_outputs)

    def forward(self, X):
        return self.dense2(self.relu(self.dense1(X)))

# ffn = PositionWiseFFN(4, 4, 8)
# ffn.eval()
# ffn(torch.ones((2, 3, 4)))

class AddNorm(nn.Module):
    """Residual connection followed by layer normalization.

    Defined in :numref:`sec_transformer`"""
    def __init__(self, normalized_shape, dropout, **kwargs):
        super(AddNorm, self).__init__(**kwargs)
        self.dropout = nn.Dropout(dropout)
        self.ln = nn.LayerNorm(normalized_shape)

    def forward(self, X, Y):
        return self.ln(self.dropout(Y) + X)
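
# Usage sketch (added for illustration; not part of the original d2l code):
# add_norm = AddNorm([4], dropout=0.5)
# add_norm.eval()
# print(add_norm(torch.ones((2, 3, 4)), torch.ones((2, 3, 4))).shape)  # torch.Size([2, 3, 4])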

class EncoderBlock(nn.Module):
    """Transformer encoder block.

    Defined in :numref:`sec_transformer`"""
    def __init__(self, key_size, query_size, value_size, d_model,
                 norm_shape, ffn_num_input, ffn_num_hiddens, num_heads,
                 dropout, use_bias=False, **kwargs):
        super(EncoderBlock, self).__init__(**kwargs)
        self.attention = MultiHeadAttention(
            key_size, query_size, value_size, d_model, num_heads, dropout,
            use_bias)
        self.addnorm1 = AddNorm(norm_shape, dropout)
        self.ffn = PositionWiseFFN(
            ffn_num_input, ffn_num_hiddens, d_model)
        self.addnorm2 = AddNorm(norm_shape, dropout)

    def forward(self, X, valid_lens):
        Y = self.addnorm1(X, self.attention(X, X, X, valid_lens))
        return self.addnorm2(Y, self.ffn(Y))

# Shape check (kept as a comment, like the examples above, so importing this module does not run it):
# X = torch.ones((2, 100, 24))
# valid_lens = torch.tensor([3, 2])
# encoder_blk = EncoderBlock(24, 24, 24, 24, [100, 24], 24, 48, 8, 0.5)
# encoder_blk.eval()
# encoder_blk(X, valid_lens).shape  # torch.Size([2, 100, 24])

class PositionalEncoding(nn.Module):
    """Positional encoding.

    Defined in :numref:`sec_self-attention-and-positional-encoding`"""
    def __init__(self, d_model, dropout, max_len=1000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(dropout)
        # Create a long enough `P`
        self.P = torch.zeros((1, max_len, d_model))
        X = torch.arange(max_len, dtype=torch.float32).reshape(
            -1, 1) / torch.pow(10000, torch.arange(
            0, d_model, 2, dtype=torch.float32) / d_model)
        self.P[:, :, 0::2] = torch.sin(X)
        self.P[:, :, 1::2] = torch.cos(X)

    def forward(self, X):
        X = X + self.P[:, :X.shape[1], :].to(X.device)
        return self.dropout(X)
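
# Usage sketch (added for illustration; not part of the original d2l code):
# pe = PositionalEncoding(d_model=20, dropout=0.0)
# pe.eval()
# print(pe(torch.zeros((1, 60, 20))).shape)  # torch.Size([1, 60, 20])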

class TransformerEncoder(nn.Module):
    """Transformer encoder.

    Defined in :numref:`sec_transformer`"""
    def __init__(self, vocab_size, key_size, query_size, value_size,
                 d_model, norm_shape, ffn_num_input, ffn_num_hiddens,
                 num_heads, num_layers, dropout, use_bias=False, **kwargs):
        super(TransformerEncoder, self).__init__(**kwargs)
        self.d_model = d_model
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.pos_encoding = PositionalEncoding(d_model, dropout)
        self.blks = nn.Sequential()
        for i in range(num_layers):
            self.blks.add_module("block"+str(i),
                EncoderBlock(key_size, query_size, value_size, d_model,
                             norm_shape, ffn_num_input, ffn_num_hiddens,
                             num_heads, dropout, use_bias))

    def forward(self, X, valid_lens, *args):
        # Since positional encoding values are between -1 and 1, the embedding
        # values are multiplied by the square root of the embedding dimension
        # to rescale before they are summed up
        X = self.pos_encoding(self.embedding(X) * math.sqrt(self.d_model))
        self.attention_weights = [None] * len(self.blks)
        for i, blk in enumerate(self.blks):
            X = blk(X, valid_lens)
            self.attention_weights[
                i] = blk.attention.attention.attention_weights
        return X

# encoder = TransformerEncoder(200, 24, 24, 24, 24, [100, 24], 24, 48, 8, 2, 0.5)
# encoder.eval()
# encoder(torch.ones((2, 100), dtype=torch.long), torch.tensor([3, 2])).shape  # torch.Size([2, 100, 24])

class DecoderBlock(nn.Module):
    def __init__(self, key_size, query_size, value_size, d_model,
                 norm_shape, ffn_num_input, ffn_num_hiddens, num_heads,
                 dropout, i, **kwargs):
        super(DecoderBlock, self).__init__(**kwargs)
        self.i = i
        self.attention1 = MultiHeadAttention(
            key_size, query_size, value_size, d_model, num_heads, dropout)
        self.addnorm1 = AddNorm(norm_shape, dropout)
        self.attention2 = MultiHeadAttention(
            key_size, query_size, value_size, d_model, num_heads, dropout)
        self.addnorm2 = AddNorm(norm_shape, dropout)
        self.ffn = PositionWiseFFN(ffn_num_input, ffn_num_hiddens,
                                   d_model)
        self.addnorm3 = AddNorm(norm_shape, dropout)

    def forward(self, X, state):
        enc_outputs, enc_valid_lens = state[0], state[1]
        # During training, all the tokens of any output sequence are processed
        # at the same time, so `state[2][self.i]` is `None` as initialized.
        # During prediction, `state[2][self.i]` contains previous tokens of
        # the output sequence
        if state[2][self.i] is None: # training, or the first decoding step during prediction
            key_values = X
        else: # subsequent decoding steps during prediction
            key_values = torch.cat((state[2][self.i], X), axis=1)
        state[2][self.i] = key_values
        if self.training:
            batch_size, num_steps, _ = X.shape
            # Shape of `dec_valid_lens`: (`batch_size`, `num_steps`), where
            # every row is [1, 2, ..., `num_steps`]
            dec_valid_lens = torch.arange(1, num_steps + 1,
                                          device=X.device).repeat(batch_size, 1)
        else:
            dec_valid_lens = None

        # Self-attention
        X2 = self.attention1(X, key_values, key_values, dec_valid_lens)
        Y = self.addnorm1(X, X2)
        # Encoder-decoder attention. Shape of `enc_outputs`:
        # (`batch_size`, `num_steps`, `d_model`)
        Y2 = self.attention2(Y, enc_outputs, enc_outputs, enc_valid_lens)
        Z = self.addnorm2(Y, Y2)
        return self.addnorm3(Z, self.ffn(Z)), state
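
# Usage sketch (added for illustration with toy shapes; not part of the original d2l code).
# During training, dec_valid_lens is [1, 2, ..., num_steps] per row, i.e. a causal mask;
# in eval mode the key_values are accumulated in `state[2][i]` across decoding steps.
# dec_blk = DecoderBlock(24, 24, 24, 24, [10, 24], 24, 48, 8, 0.5, i=0)
# dec_blk.eval()
# X = torch.ones((2, 10, 24))
# state = [torch.ones((2, 10, 24)), torch.tensor([3, 2]), [None]]
# print(dec_blk(X, state)[0].shape)  # torch.Size([2, 10, 24])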

class TransformerDecoder(nn.Module):
    def __init__(self, vocab_size, key_size, query_size, value_size,
                 d_model, norm_shape, ffn_num_input, ffn_num_hiddens,
                 num_heads, num_layers, dropout, **kwargs):
        super(TransformerDecoder, self).__init__(**kwargs)
        self.d_model = d_model
        self.num_layers = num_layers
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.pos_encoding = PositionalEncoding(d_model, dropout)
        self.blks = nn.Sequential()
        for i in range(num_layers):
            self.blks.add_module("block"+str(i),
                DecoderBlock(key_size, query_size, value_size, d_model,
                             norm_shape, ffn_num_input, ffn_num_hiddens,
                             num_heads, dropout, i))
        self.dense = nn.Linear(d_model, vocab_size)

    def init_state(self, enc_outputs, enc_valid_lens, *args):
        return [enc_outputs, enc_valid_lens, [None]*self.num_layers]

    def forward(self, X, state):
        X = self.pos_encoding(self.embedding(X) * math.sqrt(self.d_model))
        self._attention_weights = [[None, None] for _ in range(self.num_layers)]
        for i, blk in enumerate(self.blks):
            X, state = blk(X, state)
            # Decoder self-attention weights
            self._attention_weights[i][0] = blk.attention1.attention.attention_weights
            # Encoder-decoder attention weights
            self._attention_weights[i][1] = blk.attention2.attention.attention_weights
        return self.dense(X), state

    @property
    def attention_weights(self):
        return self._attention_weights
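
# Usage sketch (added for illustration with toy shapes; not part of the original d2l code):
# decoder = TransformerDecoder(200, 24, 24, 24, 24, [10, 24], 24, 48, 8, 2, 0.5)
# decoder.eval()
# state = decoder.init_state(torch.ones((2, 10, 24)), torch.tensor([3, 2]))
# out, state = decoder(torch.ones((2, 10), dtype=torch.long), state)
# print(out.shape)  # torch.Size([2, 10, 200])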
    

class Transformer(nn.Module):
    def __init__(self, src_vocab, tgt_vocab, key_size, query_size, value_size,
                 d_model, norm_shape, ffn_num_input, ffn_num_hiddens,
                 num_heads, num_layers, dropout, **kwargs):
        super(Transformer, self).__init__(**kwargs)
        self.encoder = TransformerEncoder(
            src_vocab, key_size, query_size, value_size, d_model,
            norm_shape, ffn_num_input, ffn_num_hiddens, num_heads,
            num_layers, dropout)
        self.decoder = TransformerDecoder(
            tgt_vocab, key_size, query_size, value_size, d_model,
            norm_shape, ffn_num_input, ffn_num_hiddens, num_heads,
            num_layers, dropout)
    
    def forward(self, src, tgt, src_valid_lens):
        enc_outputs = self.encoder(src, src_valid_lens)
        return self.decoder(tgt, self.decoder.init_state(enc_outputs, src_valid_lens))
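
# End-to-end shape sketch (added for illustration with toy vocab sizes; not part of the original d2l code):
# model = Transformer(200, 300, 24, 24, 24, 24, [24], 24, 48, 8, 2, 0.5)
# model.eval()
# src = torch.ones((2, 10), dtype=torch.long)
# tgt = torch.ones((2, 10), dtype=torch.long)
# out, state = model(src, tgt, torch.tensor([3, 2]))
# print(out.shape)  # torch.Size([2, 10, 300])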
  • tf_learn_train.py
import torch
from tf_learn import Transformer
import d2l.torch as d2l

d_model, num_layers, dropout, batch_size, num_steps = 32, 2, 0.1, 64, 10
lr, num_epochs, device = 0.005, 200, d2l.try_gpu()
ffn_num_input, ffn_num_hiddens, num_heads = d_model, 64, 4
key_size, query_size, value_size = d_model, d_model, d_model
norm_shape = [d_model] # layer normalization shape

train_iter, src_vocab, tgt_vocab = d2l.load_data_nmt(batch_size, num_steps)

net = Transformer(len(src_vocab), len(tgt_vocab), key_size, query_size,
                            value_size, d_model, norm_shape, ffn_num_input,
                            ffn_num_hiddens, num_heads, num_layers, dropout)

# net = d2l.EncoderDecoder(encoder, decoder)

d2l.train_seq2seq(net, train_iter, lr, num_epochs, tgt_vocab, device)

# save model
torch.save(net.state_dict(), 'transformer.pth')
  • tf_learn_predict.py
import torch
from tf_learn import Transformer
import d2l.torch as d2l

d_model, num_layers, dropout, batch_size, num_steps = 32, 2, 0.1, 64, 10
lr, num_epochs, device = 0.005, 200, d2l.try_gpu()
ffn_num_input, ffn_num_hiddens, num_heads = d_model, 64, 4
key_size, query_size, value_size = d_model, d_model, d_model
norm_shape = [d_model] # layer normalization shape

train_iter, src_vocab, tgt_vocab = d2l.load_data_nmt(batch_size, num_steps)

net = Transformer(len(src_vocab), len(tgt_vocab), key_size, query_size,
                            value_size, d_model, norm_shape, ffn_num_input,
                            ffn_num_hiddens, num_heads, num_layers, dropout)

net.load_state_dict(torch.load('transformer.pth'))
net.to(device)
net.eval()

engs = ['go .', "i lost .", 'he\'s calm .', 'i\'m home .']
fras = ['va !', 'j\'ai perdu .', 'il est calme .', 'je suis chez moi .']

for eng, fra in zip(engs, fras):
    translation, dec_attention_weight_seq = d2l.predict_seq2seq(
        net, eng, src_vocab, tgt_vocab, num_steps, device, True)
    print(f'{eng} => {translation}, ',
          f'bleu {d2l.bleu(translation, fra, k=2):.3f}')