Sweating the details: code for testing some LLM data dimensions

I always feel my grasp of some LLM details is shaky, so to burn them into memory I added tests for the data dimensions this time. The code is also more properly structured than before. To push the parameters of the verification process all the way down to the first line of code, I even added extra parameter passing through the training loop; I think it is still worth keeping as a reference and a record. The next step is to verify the dimension transformations and contents of the first stage over and over, almost like playing with them, until no blind spots remain.

1. Code

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torch.nn.functional as F

# ===================== Core configuration =====================
VOCAB_SIZE = 1000  # vocabulary size
EMBED_DIM = 16  # token embedding dimension
NUM_HEADS = 4  # number of attention heads; EMBED_DIM % NUM_HEADS must be 0
HEAD_DIM = EMBED_DIM // NUM_HEADS  # dimension of each attention head
NUM_LAYERS = 3  # number of decoder layers
SEQ_LEN = 22  # training sequence length
GEN_MAX_LEN = 50  # maximum length of a generated sequence
BATCH_SIZE = 9  # batch size
LR = 1e-3  # learning rate
EPOCHS = 5  # number of training epochs
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")


# ===================== Dataset: autoregressive language modeling =====================
class DummyTextDataset(Dataset):
    """Generates dummy token sequences to simulate a pretraining corpus"""

    def __init__(self, vocab_size, seq_len, sample_num=1000):
        self.data = torch.randint(0, vocab_size, (sample_num, seq_len))  # [num_samples, seq_len]

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        x = self.data[idx]  # input sequence
        y = torch.roll(x, -1, dims=-1)  # labels: the input shifted left by one, so y[i] == x[i+1] (next-token prediction)
        y[-1] = 0  # the last token has no successor; its label is set to 0
        return x.to(DEVICE), y.to(DEVICE)
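
# A quick hand-check of the label construction above (my own worked example, not printed by the script):
#   x                          = tensor([5, 6, 7, 8])
#   torch.roll(x, -1, dims=-1) = tensor([6, 7, 8, 5])   # y[i] == x[i+1]; the last element wraps around
#   after y[-1] = 0, y         = tensor([6, 7, 8, 0])   # the wrapped token is replaced by a dummy label 0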


# ===================== Core component 1: multi-head self-attention with causal mask =====================
class MultiHeadSelfAttention(nn.Module):
    def __init__(self, embed_dim, num_heads, head_dim):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = head_dim

        # Q/K/V linear projections (three separate weight matrices; the dimension is unchanged)
        self.w_q = nn.Linear(embed_dim, embed_dim, bias=False)
        self.w_k = nn.Linear(embed_dim, embed_dim, bias=False)
        self.w_v = nn.Linear(embed_dim, embed_dim, bias=False)
        # Output projection
        self.w_o = nn.Linear(embed_dim, embed_dim, bias=False)
        print('self.w_q', self.w_q.weight.shape)
        print('self.w_k', self.w_k.weight.shape)
        print('self.w_v', self.w_v.weight.shape)
        print('self.w_o', self.w_o.weight.shape)
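        # Note: nn.Linear stores its weight as [out_features, in_features], so the shapes printed
        # here (and in the run log) are transposed relative to the layer's (in, out) signature.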

    def generate_causal_mask(self, seq_len):
        """生成因果掩码:上三角为-inf,下三角为0,防止看到未来Token
        mask形状: [1, seq_len, seq_len]
        """
        mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=1).to(DEVICE)
        mask = mask.masked_fill(mask == 1, float("-inf"))
        return mask
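
    # For seq_len = 3 the mask built above is (hand-computed, not printed by the script):
    #   [[0., -inf, -inf],
    #    [0.,   0., -inf],
    #    [0.,   0.,   0.]]
    # Adding it to the attention scores drives every position above the diagonal to 0 after softmax.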

    def scaled_dot_product_attention(self, q, k, v, mask=None):
        """
        Scaled dot-product attention: Attention(Q,K,V) = softmax(QK^T/√d_k)V
        q/k/v shape: [batch_size, num_heads, seq_len, head_dim]
        """
        # 1. Compute QK^T, scaled by √head_dim
        attn_scores = torch.matmul(q, k.transpose(-2, -1)) / torch.sqrt(
            torch.tensor(self.head_dim, dtype=torch.float32)).to(DEVICE)
        # 2. Apply the causal mask (future positions get -inf scores, which become 0 after softmax)
        if mask is not None:
            attn_scores = attn_scores + mask
        # 3. Softmax normalization gives the attention weights
        attn_weights = F.softmax(attn_scores, dim=-1)
        # 4. Weighted sum over V gives the attention output
        output = torch.matmul(attn_weights, v)
        return output, attn_weights
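
    # The [seq_len, seq_len] mask broadcasts against attn_scores of shape [batch, num_heads, seq_len, seq_len],
    # so every sample and every head sees the same causal pattern. As a cross-check (not a change to this
    # script), PyTorch >= 2.0 offers F.scaled_dot_product_attention(q, k, v, is_causal=True), which computes
    # the same softmax(QK^T/√d_k)V but does not return the attention weights.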

    def split_heads(self, x):
        """将输入拆分为多个头:[batch, seq_len, embed_dim] → [batch, num_heads, seq_len, head_dim]"""
        batch_size, seq_len, embed_dim = x.shape
        return x.reshape(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)

    def concat_heads(self, x):
        """将多个头的输出拼接:[batch, num_heads, seq_len, head_dim] → [batch, seq_len, embed_dim]"""
        batch_size, num_heads, seq_len, head_dim = x.shape
        return x.transpose(1, 2).reshape(batch_size, seq_len, self.embed_dim)
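
    # Shape round-trip with the settings above (batch=9, seq_len=22, embed_dim=16, num_heads=4, head_dim=4):
    #   split_heads : [9, 22, 16] → reshape → [9, 22, 4, 4] → transpose(1, 2) → [9, 4, 22, 4]
    #   concat_heads: [9, 4, 22, 4] → transpose(1, 2) → [9, 22, 4, 4] → reshape → [9, 22, 16]
    # reshape (rather than view) is used because the transposed tensor is no longer contiguous.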

    def forward(self, x, batch_idx, epoch):
        batch_size, seq_len, _ = x.shape
        # 1. Q/K/V linear projections
        q = self.w_q(x)
        k = self.w_k(x)
        v = self.w_v(x)
        # 2. Split into heads
        q = self.split_heads(q)
        k = self.split_heads(k)
        v = self.split_heads(v)
        # 3. Build the causal mask
        mask = self.generate_causal_mask(seq_len)
        # 4. Scaled dot-product attention
        attn_out, _ = self.scaled_dot_product_attention(q, k, v, mask)
        # 5. Concatenate the head outputs
        attn_out_concat = self.concat_heads(attn_out)
        # 6. Output projection
        output = self.w_o(attn_out_concat)
        if batch_idx == 1 and epoch == 1:
            print('attn.q', q.shape)
            print('attn.mask', mask.shape)
            print('attn_out.', attn_out.shape)
            print('attn_out_concat', attn_out_concat.shape)
            print('attn.output', output.shape)
        return output


# ===================== Core component 2: feed-forward network (FFN) =====================
class FeedForwardNetwork(nn.Module):
    def __init__(self, embed_dim, hidden_dim=None):
        super().__init__()
        # FFN structure: Linear → ReLU → Linear; the hidden dimension is typically 4× embed_dim
        self.hidden_dim = hidden_dim if hidden_dim else embed_dim * 4
        self.fc1 = nn.Linear(embed_dim, self.hidden_dim)
        self.fc2 = nn.Linear(self.hidden_dim, embed_dim)
        print('self.fc1', self.fc1.weight.shape)
        print('self.fc2', self.fc2.weight.shape)
        self.relu = nn.ReLU()  # GPT uses GELU; ReLU is used here for simplicity

    def forward(self, x, batch_idx, epoch):
        # x: [batch, seq_len, embed_dim]
        if batch_idx == 1 and epoch == 1:
            print('forward.x', x.shape)
        return self.fc2(self.relu(self.fc1(x)))
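
    # With EMBED_DIM = 16 the FFN expands each token 16 → 64 → 16, which matches the fc1/fc2 weight shapes printed above.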


# ===================== Core component 3: Decoder Block =====================
class DecoderBlock(nn.Module):
    def __init__(self, embed_dim, num_heads, head_dim):
        super().__init__()
        # Sublayer 1: multi-head self-attention + residual connection + layer norm
        self.self_attn = MultiHeadSelfAttention(embed_dim, num_heads, head_dim)
        self.norm1 = nn.LayerNorm(embed_dim)
        # Sublayer 2: feed-forward network + residual connection + layer norm
        self.ffn = FeedForwardNetwork(embed_dim)
        self.norm2 = nn.LayerNorm(embed_dim)
        # Dropout to reduce overfitting
        self.dropout = nn.Dropout(0.1)

    def forward(self, x, batch_idx, epoch):
        # Residual connection: input + sublayer output
        attn_out = self.self_attn(x, batch_idx, epoch)
        x = self.norm1(x + self.dropout(attn_out))  # Post-LN: normalize after the residual add (original Transformer layout)
        ffn_out = self.ffn(x, batch_idx, epoch)
        x = self.norm2(x + self.dropout(ffn_out))
        if batch_idx == 1 and epoch == 1:
            print('DecoderBlock x.shape', x.shape)
        return x
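
# A Pre-LN variant (what GPT-2 and most current LLMs actually use) would normalize before each sublayer
# instead; a sketch only, keeping this script's extra debug arguments:
#   x = x + self.dropout(self.self_attn(self.norm1(x), batch_idx, epoch))
#   x = x + self.dropout(self.ffn(self.norm2(x), batch_idx, epoch))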

# Custom Sequential container that forwards extra arguments, to make the parameter passing more transparent
class MultiInputSequential(nn.Sequential):
    def forward(self, x, *args, **kwargs):
        # Iterate over every module in the Sequential, passing all arguments along
        for module in self:
            # Hand x plus the extra args/kwargs to each DecoderBlock
            x = module(x, *args, **kwargs)
        return x
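
# An equivalent and perhaps more common idiom is nn.ModuleList plus an explicit loop in the parent's forward;
# the subclassed Sequential above is kept only to make the extra debug-argument passing visible.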

# ===================== Full model: decoder-only LM =====================
class DecoderOnlyLM(nn.Module):
    def __init__(self, vocab_size, embed_dim, num_heads, head_dim, num_layers):
        super().__init__()
        # 1. Token embedding: maps token IDs to vectors
        self.token_embedding = nn.Embedding(vocab_size, embed_dim)
        print('self.token_embedding', self.token_embedding.weight.shape)
        # 2. Position embedding: learned positional information (simpler than sinusoidal encoding)
        self.pos_embedding = nn.Embedding(GEN_MAX_LEN, embed_dim)  # sized to the maximum generation length
        print('self.pos_embedding', self.pos_embedding.weight.shape)
        # 3. Stack of decoder layers
        self.decoder_blocks = MultiInputSequential(
            *[DecoderBlock(embed_dim, num_heads, head_dim) for _ in range(num_layers)]
        )
        # 4. Output projection to the vocabulary
        self.fc_out = nn.Linear(embed_dim, vocab_size)

        # Weight tying: the embedding and output layers share parameters (a common LLM trick to cut the parameter count)
        self.fc_out.weight = self.token_embedding.weight
        print('fc_out', self.fc_out.weight.shape)
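        # With tying, fc_out reuses the [1000, 16] embedding matrix instead of allocating its own copy,
        # saving VOCAB_SIZE * EMBED_DIM = 16,000 weights (fc_out keeps its own bias of size 1000).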

    def forward(self, x, batch_idx, epoch):
        batch_size, seq_len = x.shape
        if batch_idx == 1 and epoch == 1:
            print('x', x)
            print('batch_size, seq_len', batch_size, seq_len)
            print('x.shape', x.shape)

        # Position indices: [0,1,2,...,seq_len-1] → [batch_size, seq_len]
        pos_ids = torch.arange(seq_len, device=DEVICE).unsqueeze(0).repeat(batch_size, 1)
        # Token embedding + position embedding (the core: word vector + position vector)
        x_1 = self.token_embedding(x)
        x_2 = self.pos_embedding(pos_ids)
        if batch_idx == 1 and epoch == 1:
            print('pos_ids', pos_ids.shape, pos_ids)
            print('self.token_embedding', self.token_embedding.weight.shape)
            print('self.pos_embedding', self.pos_embedding.weight.shape)
            print('x_1', x_1.shape)
            print('x_2', x_2.shape)

        x = x_1 + x_2   # [batch, seq_len, embed_dim]
        if batch_idx == 1 and epoch == 1:
            print('x = token_embedding + self.pos_embedding', x.shape)
        # Pass through the stacked decoder blocks
        x = self.decoder_blocks(x, batch_idx, epoch)

        # Project to the vocabulary space and output logits
        logits = self.fc_out(x)  # [batch, seq_len, vocab_size]
        if batch_idx == 1 and epoch == 1:
            print('after decode_blocks', x.shape)
            print('logits', logits.shape)
        return logits

    @torch.no_grad()  # disable gradient computation during inference to save memory
    def generate(self, prompt, max_len=GEN_MAX_LEN):
        """
        Greedy decoding for text generation
        :param prompt: input prompt sequence, shape [batch_size, prompt_len]
        :param max_len: maximum length of the generated sequence
        :return: the full generated sequence, shape [batch_size, max_len]
        """
        self.eval()  # switch to evaluation mode
        generated = prompt.to(DEVICE)

        for _ in range(max_len - prompt.shape[1]):
            # 1. Forward pass to get logits
            logits = self(generated, 0, 0)
            # 2. Take the logits of the last token to predict the next one
            next_token_logits = logits[:, -1, :]
            next_token = torch.argmax(next_token_logits, dim=-1).unsqueeze(1)
            # 3. Append the new token to the generated sequence
            generated = torch.cat([generated, next_token], dim=1)
            # 4. Safety check against over-long generation
            if generated.shape[1] >= max_len:
                break

        self.train()  # switch back to training mode
        return generated
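
    # Greedy argmax decoding collapses easily on toy data like this; a sampling variant (a sketch,
    # not part of the original script) would replace the argmax with, e.g.:
    #   probs = F.softmax(next_token_logits / temperature, dim=-1)
    #   next_token = torch.multinomial(probs, num_samples=1)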


# ===================== Training + inference pipeline =====================
def main():
    # 1. Build the dataset and data loader
    dataset = DummyTextDataset(VOCAB_SIZE, SEQ_LEN)
    dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)
    print(dataloader)

    # 2. Initialize the model, loss function, and optimizer
    model = DecoderOnlyLM(VOCAB_SIZE, EMBED_DIM, NUM_HEADS, HEAD_DIM, NUM_LAYERS).to(DEVICE)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=LR)

    # 3. Training loop
    print("===== Starting pretraining =====")
    for epoch in range(EPOCHS):
        total_loss = 0.0
        for batch_idx, (x, y) in enumerate(dataloader):
            if batch_idx == 1 and epoch == 1:
                print('batch_idx: ', batch_idx)
                print('x', x)
                print('y', y)
            # Forward pass
            logits = model(x, batch_idx, epoch)  # [batch, seq_len, vocab_size]
            # Reshape to compute the loss

            loss = criterion(logits.reshape(-1, VOCAB_SIZE), y.reshape(-1))
            if batch_idx == 1 and epoch == 1:
                print('logits.reshape(-1, VOCAB_SIZE): ', logits.reshape(-1, VOCAB_SIZE).shape)
                print('y.reshape(-1)', y.reshape(-1).shape)
                print('loss', loss)
            # Backward pass and optimization
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_loss += loss.item()

        avg_loss = total_loss / len(dataloader)
        print(f"Epoch [{epoch + 1}/{EPOCHS}], Avg Loss: {avg_loss:.4f}")

    # 4. Inference/generation example
    print("\n===== Generating text =====")
    # Random prompt sequence (replace with real token sequences in practice)
    prompt = torch.randint(0, VOCAB_SIZE, (1, 5))  # batch_size=1, prompt_len=5
    print(f"Prompt token sequence: {prompt.squeeze().tolist()}")

    # Generate text
    generated_seq = model.generate(prompt, max_len=GEN_MAX_LEN)
    print(f"Generated token sequence: {generated_seq.squeeze().tolist()}")


if __name__ == "__main__":
    main()
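
As a follow-up to the dimension checks, a few-line standalone sanity test along these lines could be run from a separate file (a sketch only; it assumes the script above is saved as MiniLLM.py, as in the run log below):

# shape_check.py: hypothetical standalone sanity check for the printed dimensions
import torch
from MiniLLM import (DecoderOnlyLM, VOCAB_SIZE, EMBED_DIM, NUM_HEADS,
                     HEAD_DIM, NUM_LAYERS, BATCH_SIZE, SEQ_LEN, DEVICE)

model = DecoderOnlyLM(VOCAB_SIZE, EMBED_DIM, NUM_HEADS, HEAD_DIM, NUM_LAYERS).to(DEVICE)
x = torch.randint(0, VOCAB_SIZE, (BATCH_SIZE, SEQ_LEN), device=DEVICE)
logits = model(x, 0, 0)
assert logits.shape == (BATCH_SIZE, SEQ_LEN, VOCAB_SIZE)  # expect [9, 22, 1000], as in the log
print("total parameters:", sum(p.numel() for p in model.parameters()))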

2. Output

D:\book\PythonAI\Code\.venv\Scripts\python.exe D:\book\PythonAI\Code\ch09\MiniLLM.py 
<torch.utils.data.dataloader.DataLoader object at 0x00000177D64403B0>
self.token_embedding torch.Size([1000, 16])
self.pos_embedding torch.Size([50, 16])
self.w_q torch.Size([16, 16])
self.w_k torch.Size([16, 16])
self.w_v torch.Size([16, 16])
self.w_o torch.Size([16, 16])
self.fc1 torch.Size([64, 16])
self.fc2 torch.Size([16, 64])
self.w_q torch.Size([16, 16])
self.w_k torch.Size([16, 16])
self.w_v torch.Size([16, 16])
self.w_o torch.Size([16, 16])
self.fc1 torch.Size([64, 16])
self.fc2 torch.Size([16, 64])
self.w_q torch.Size([16, 16])
self.w_k torch.Size([16, 16])
self.w_v torch.Size([16, 16])
self.w_o torch.Size([16, 16])
self.fc1 torch.Size([64, 16])
self.fc2 torch.Size([16, 64])
fc_out torch.Size([1000, 16])
===== Starting pretraining =====
Epoch [1/5], Avg Loss: 11.0605
batch_idx:  1
x tensor([[622, 653, 952, 627, 247,  76, 148, 553, 565, 770, 854, 863, 972, 819,
         382, 605, 906, 869, 455, 992,  57, 813],
        [976,  34, 440, 770,  24, 446, 408, 181, 642,  59, 496, 517, 898, 278,
         179, 598, 899, 534, 475, 325, 862, 116],
        [593, 236, 265, 407, 334, 744, 667, 458, 584, 190, 556, 503, 700, 158,
         210,  65, 372, 390, 346, 574, 304, 337],
        [146, 993,  54, 799, 111, 226, 416, 868, 110, 551, 917, 458, 748,  82,
          13, 978, 378, 794, 368, 664, 601, 702],
        [449, 436, 253, 184, 945, 676,  79, 992, 460, 438, 662, 804, 298, 657,
         838, 742, 251, 806, 121, 880,  40, 486],
        [978, 408, 174, 915, 315, 236, 869, 151,   3, 201, 696, 682, 578,  83,
         284, 621, 295, 868, 884, 665, 579, 733],
        [502, 845, 855, 238, 364, 444, 833, 901, 264, 675, 569, 536, 914, 675,
         483,  93, 310, 370, 618, 302, 336, 179],
        [886, 143, 754, 534, 534, 741, 742,  34, 871, 341, 275, 794, 519, 103,
         116, 824, 684, 302, 995, 224, 326, 800],
        [367, 309, 779, 168, 378, 475, 349, 773, 391, 319, 134, 111, 800, 275,
         918, 860, 477, 632, 259,  23,  56, 330]])
y tensor([[653, 952, 627, 247,  76, 148, 553, 565, 770, 854, 863, 972, 819, 382,
         605, 906, 869, 455, 992,  57, 813,   0],
        [ 34, 440, 770,  24, 446, 408, 181, 642,  59, 496, 517, 898, 278, 179,
         598, 899, 534, 475, 325, 862, 116,   0],
        [236, 265, 407, 334, 744, 667, 458, 584, 190, 556, 503, 700, 158, 210,
          65, 372, 390, 346, 574, 304, 337,   0],
        [993,  54, 799, 111, 226, 416, 868, 110, 551, 917, 458, 748,  82,  13,
         978, 378, 794, 368, 664, 601, 702,   0],
        [436, 253, 184, 945, 676,  79, 992, 460, 438, 662, 804, 298, 657, 838,
         742, 251, 806, 121, 880,  40, 486,   0],
        [408, 174, 915, 315, 236, 869, 151,   3, 201, 696, 682, 578,  83, 284,
         621, 295, 868, 884, 665, 579, 733,   0],
        [845, 855, 238, 364, 444, 833, 901, 264, 675, 569, 536, 914, 675, 483,
          93, 310, 370, 618, 302, 336, 179,   0],
        [143, 754, 534, 534, 741, 742,  34, 871, 341, 275, 794, 519, 103, 116,
         824, 684, 302, 995, 224, 326, 800,   0],
        [309, 779, 168, 378, 475, 349, 773, 391, 319, 134, 111, 800, 275, 918,
         860, 477, 632, 259,  23,  56, 330,   0]])
x tensor([[622, 653, 952, 627, 247,  76, 148, 553, 565, 770, 854, 863, 972, 819,
         382, 605, 906, 869, 455, 992,  57, 813],
        [976,  34, 440, 770,  24, 446, 408, 181, 642,  59, 496, 517, 898, 278,
         179, 598, 899, 534, 475, 325, 862, 116],
        [593, 236, 265, 407, 334, 744, 667, 458, 584, 190, 556, 503, 700, 158,
         210,  65, 372, 390, 346, 574, 304, 337],
        [146, 993,  54, 799, 111, 226, 416, 868, 110, 551, 917, 458, 748,  82,
          13, 978, 378, 794, 368, 664, 601, 702],
        [449, 436, 253, 184, 945, 676,  79, 992, 460, 438, 662, 804, 298, 657,
         838, 742, 251, 806, 121, 880,  40, 486],
        [978, 408, 174, 915, 315, 236, 869, 151,   3, 201, 696, 682, 578,  83,
         284, 621, 295, 868, 884, 665, 579, 733],
        [502, 845, 855, 238, 364, 444, 833, 901, 264, 675, 569, 536, 914, 675,
         483,  93, 310, 370, 618, 302, 336, 179],
        [886, 143, 754, 534, 534, 741, 742,  34, 871, 341, 275, 794, 519, 103,
         116, 824, 684, 302, 995, 224, 326, 800],
        [367, 309, 779, 168, 378, 475, 349, 773, 391, 319, 134, 111, 800, 275,
         918, 860, 477, 632, 259,  23,  56, 330]])
batch_size, seq_len 9 22
x.shape torch.Size([9, 22])
pos_ids torch.Size([9, 22]) tensor([[ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
         18, 19, 20, 21],
        [ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
         18, 19, 20, 21],
        [ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
         18, 19, 20, 21],
        [ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
         18, 19, 20, 21],
        [ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
         18, 19, 20, 21],
        [ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
         18, 19, 20, 21],
        [ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
         18, 19, 20, 21],
        [ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
         18, 19, 20, 21],
        [ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
         18, 19, 20, 21]])
self.token_embedding torch.Size([1000, 16])
self.pos_embedding torch.Size([50, 16])
x_1 torch.Size([9, 22, 16])
x_2 torch.Size([9, 22, 16])
x = token_embedding + self.pos_embedding torch.Size([9, 22, 16])
attn.q torch.Size([9, 4, 22, 4])
attn.mask torch.Size([22, 22])
attn_out. torch.Size([9, 4, 22, 4])
attn_out_concat torch.Size([9, 22, 16])
attn.output torch.Size([9, 22, 16])
forward.x torch.Size([9, 22, 16])
DecoderBlock x.shape torch.Size([9, 22, 16])
attn.q torch.Size([9, 4, 22, 4])
attn.mask torch.Size([22, 22])
attn_out. torch.Size([9, 4, 22, 4])
attn_out_concat torch.Size([9, 22, 16])
attn.output torch.Size([9, 22, 16])
forward.x torch.Size([9, 22, 16])
DecoderBlock x.shape torch.Size([9, 22, 16])
attn.q torch.Size([9, 4, 22, 4])
attn.mask torch.Size([22, 22])
attn_out. torch.Size([9, 4, 22, 4])
attn_out_concat torch.Size([9, 22, 16])
attn.output torch.Size([9, 22, 16])
forward.x torch.Size([9, 22, 16])
DecoderBlock x.shape torch.Size([9, 22, 16])
after decode_blocks torch.Size([9, 22, 16])
logits torch.Size([9, 22, 1000])
logits.reshape(-1, VOCAB_SIZE):  torch.Size([198, 1000])
y.reshape(-1) torch.Size([198])
loss tensor(9.2886, grad_fn=<NllLossBackward0>)
Epoch [2/5], Avg Loss: 8.6743
Epoch [3/5], Avg Loss: 7.7679
Epoch [4/5], Avg Loss: 7.2881
Epoch [5/5], Avg Loss: 6.9872

===== Generating text =====
Prompt token sequence: [643, 749, 939, 185, 893]
Generated token sequence: [643, 749, 939, 185, 893, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]

Process finished with exit code 0
