I always feel I don't have a firm grip on some of the details of large language models, so to fix them in my memory I added checks of the tensor dimensions this time. The code is also more properly structured than before. To push the verification all the way down to the innermost lines of code, I even threaded extra arguments (batch_idx and epoch) through the training forward pass, which I think is worth recording for reference. The next step is to play with the dimension transforms and contents of that first stage, verifying them over and over until no blind spots remain.
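As a side note, the same kind of shape inspection can also be done without threading batch_idx and epoch through every forward(): a PyTorch forward hook can print shapes from outside the model. A minimal, self-contained sketch (not used in the script below; the layer and sizes are just illustrative):

import torch
import torch.nn as nn

# A forward hook prints a module's output shape every time it runs,
# without changing the module's forward() signature.
def shape_hook(module, inputs, output):
    print(type(module).__name__, tuple(output.shape))

layer = nn.Linear(16, 64)
layer.register_forward_hook(shape_hook)
layer(torch.randn(9, 22, 16))  # prints: Linear (9, 22, 64)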
1. Code
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torch.nn.functional as F
# ===================== Core configuration =====================
VOCAB_SIZE = 1000  # vocabulary size
EMBED_DIM = 16  # token embedding dimension
NUM_HEADS = 4  # number of attention heads; must satisfy EMBED_DIM % NUM_HEADS == 0
HEAD_DIM = EMBED_DIM // NUM_HEADS  # dimension of each attention head
NUM_LAYERS = 3  # number of decoder layers
SEQ_LEN = 22  # training sequence length
GEN_MAX_LEN = 50  # maximum generated sequence length
BATCH_SIZE = 9  # batch size
LR = 1e-3  # learning rate
EPOCHS = 5  # number of training epochs
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# ===================== Dataset: autoregressive language modeling =====================
class DummyTextDataset(Dataset):
    """Generates random token sequences to simulate a pre-training corpus."""
    def __init__(self, vocab_size, seq_len, sample_num=1000):
        self.data = torch.randint(0, vocab_size, (sample_num, seq_len))  # [num_samples, seq_len]
    def __len__(self):
        return len(self.data)
    def __getitem__(self, idx):
        x = self.data[idx]  # input sequence
        y = torch.roll(x, -1, dims=-1)  # labels: input shifted left by one (predict the next token)
        y[-1] = 0  # the last token has no successor, so its label is set to 0
        return x.to(DEVICE), y.to(DEVICE)
# ===================== Core component 1: multi-head self-attention with causal mask =====================
class MultiHeadSelfAttention(nn.Module):
    def __init__(self, embed_dim, num_heads, head_dim):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = head_dim
        # Q/K/V linear projections (separate weights for each; dimensions unchanged)
        self.w_q = nn.Linear(embed_dim, embed_dim, bias=False)
        self.w_k = nn.Linear(embed_dim, embed_dim, bias=False)
        self.w_v = nn.Linear(embed_dim, embed_dim, bias=False)
        # Output projection
        self.w_o = nn.Linear(embed_dim, embed_dim, bias=False)
        print('self.w_q', self.w_q.weight.shape)
        print('self.w_k', self.w_k.weight.shape)
        print('self.w_v', self.w_v.weight.shape)
        print('self.w_o', self.w_o.weight.shape)
    def generate_causal_mask(self, seq_len):
        """Builds the causal mask: -inf above the diagonal, 0 on and below it, so a token cannot attend to future tokens.
        mask shape: [seq_len, seq_len]
        """
        mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=1).to(DEVICE)
        mask = mask.masked_fill(mask == 1, float("-inf"))
        return mask
    def scaled_dot_product_attention(self, q, k, v, mask=None):
        """
        Scaled dot-product attention: Attention(Q,K,V) = softmax(QK^T/√d_k)V
        q/k/v shape: [batch_size, num_heads, seq_len, head_dim]
        """
        # 1. Compute QK^T, scaled by √head_dim
        attn_scores = torch.matmul(q, k.transpose(-2, -1)) / torch.sqrt(
            torch.tensor(self.head_dim, dtype=torch.float32)).to(DEVICE)
        # 2. Apply the causal mask (future positions get -inf, which becomes 0 after softmax)
        if mask is not None:
            attn_scores = attn_scores + mask
        # 3. Softmax normalization gives the attention weights
        attn_weights = F.softmax(attn_scores, dim=-1)
        # 4. Weighted sum over V gives the attention output
        output = torch.matmul(attn_weights, v)
        return output, attn_weights
    def split_heads(self, x):
        """Split into heads: [batch, seq_len, embed_dim] → [batch, num_heads, seq_len, head_dim]"""
        batch_size, seq_len, embed_dim = x.shape
        return x.reshape(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
    def concat_heads(self, x):
        """Merge the heads back: [batch, num_heads, seq_len, head_dim] → [batch, seq_len, embed_dim]"""
        batch_size, num_heads, seq_len, head_dim = x.shape
        return x.transpose(1, 2).reshape(batch_size, seq_len, self.embed_dim)
    def forward(self, x, batch_idx, epoch):
        batch_size, seq_len, _ = x.shape
        # 1. Q/K/V linear projections
        q = self.w_q(x)
        k = self.w_k(x)
        v = self.w_v(x)
        # 2. Split into heads
        q = self.split_heads(q)
        k = self.split_heads(k)
        v = self.split_heads(v)
        # 3. Build the causal mask
        mask = self.generate_causal_mask(seq_len)
        # 4. Scaled dot-product attention
        attn_out, _ = self.scaled_dot_product_attention(q, k, v, mask)
        # 5. Merge the heads
        attn_out_concat = self.concat_heads(attn_out)
        # 6. Output projection
        output = self.w_o(attn_out_concat)
        if batch_idx == 1 and epoch == 1:
            print('attn.q', q.shape)
            print('attn.mask', mask.shape)
            print('attn_out.', attn_out.shape)
            print('attn_out_concat', attn_out_concat.shape)
            print('attn.output', output.shape)
        return output
# ===================== Core component 2: feed-forward network (FFN) =====================
class FeedForwardNetwork(nn.Module):
    def __init__(self, embed_dim, hidden_dim=None):
        super().__init__()
        # FFN structure: Linear → ReLU → Linear; the hidden dimension is typically 4 × embed_dim
        self.hidden_dim = hidden_dim if hidden_dim else embed_dim * 4
        self.fc1 = nn.Linear(embed_dim, self.hidden_dim)
        self.fc2 = nn.Linear(self.hidden_dim, embed_dim)
        print('self.fc1', self.fc1.weight.shape)
        print('self.fc2', self.fc2.weight.shape)
        self.relu = nn.ReLU()  # GPT uses GELU; ReLU is used here for simplicity
    def forward(self, x, batch_idx, epoch):
        # x: [batch, seq_len, embed_dim]
        if batch_idx == 1 and epoch == 1:
            print('forward.x', x.shape)
        return self.fc2(self.relu(self.fc1(x)))
# ===================== Core component 3: Decoder block =====================
class DecoderBlock(nn.Module):
    def __init__(self, embed_dim, num_heads, head_dim):
        super().__init__()
        # Sub-layer 1: multi-head self-attention + residual connection + layer normalization
        self.self_attn = MultiHeadSelfAttention(embed_dim, num_heads, head_dim)
        self.norm1 = nn.LayerNorm(embed_dim)
        # Sub-layer 2: feed-forward network + residual connection + layer normalization
        self.ffn = FeedForwardNetwork(embed_dim)
        self.norm2 = nn.LayerNorm(embed_dim)
        # Dropout against overfitting
        self.dropout = nn.Dropout(0.1)
    def forward(self, x, batch_idx, epoch):
        # Residual connection: input + sub-layer output
        attn_out = self.self_attn(x, batch_idx, epoch)
        x = self.norm1(x + self.dropout(attn_out))  # Post-LN (as in the original Transformer; most modern LLMs use Pre-LN instead)
        ffn_out = self.ffn(x, batch_idx, epoch)
        x = self.norm2(x + self.dropout(ffn_out))
        if batch_idx == 1 and epoch == 1:
            print('DecoderBlock x.shape', x.shape)
        return x
# Custom Sequential container that forwards extra arguments, to make the parameter passing more transparent
class MultiInputSequential(nn.Sequential):
    def forward(self, x, *args, **kwargs):
        # Iterate over the modules in the Sequential, passing all arguments along
        for module in self:
            # Pass x plus the extra args/kwargs to every DecoderBlock
            x = module(x, *args, **kwargs)
        return x
# ===================== Full model: Decoder-only language model =====================
class DecoderOnlyLM(nn.Module):
    def __init__(self, vocab_size, embed_dim, num_heads, head_dim, num_layers):
        super().__init__()
        # 1. Token embedding: maps token IDs to vectors
        self.token_embedding = nn.Embedding(vocab_size, embed_dim)
        print('self.token_embedding', self.token_embedding.weight.shape)
        # 2. Positional embedding: learned position information (not sinusoidal; simpler to implement)
        self.pos_embedding = nn.Embedding(GEN_MAX_LEN, embed_dim)  # sized for the maximum generation length
        print('self.pos_embedding', self.pos_embedding.weight.shape)
        # 3. Stack of decoder blocks
        self.decoder_blocks = MultiInputSequential(
            *[DecoderBlock(embed_dim, num_heads, head_dim) for _ in range(num_layers)]
        )
        # 4. Output projection to the vocabulary
        self.fc_out = nn.Linear(embed_dim, vocab_size)
        # Weight tying: the embedding and output layers share parameters (a common LLM trick that reduces parameter count)
        self.fc_out.weight = self.token_embedding.weight
        print('fc_out', self.fc_out.weight.shape)
    def forward(self, x, batch_idx, epoch):
        batch_size, seq_len = x.shape
        if batch_idx == 1 and epoch == 1:
            print('x', x)
            print('batch_size, seq_len', batch_size, seq_len)
            print('x.shape', x.shape)
        # Position indices: [0,1,2,...,seq_len-1] → [batch_size, seq_len]
        pos_ids = torch.arange(seq_len, device=DEVICE).unsqueeze(0).repeat(batch_size, 1)
        # Token embedding + positional embedding (the core: word vector + position vector)
        x_1 = self.token_embedding(x)
        x_2 = self.pos_embedding(pos_ids)
        if batch_idx == 1 and epoch == 1:
            print('pos_ids', pos_ids.shape, pos_ids)
            print('self.token_embedding', self.token_embedding.weight.shape)
            print('self.pos_embedding', self.pos_embedding.weight.shape)
            print('x_1', x_1.shape)
            print('x_2', x_2.shape)
        x = x_1 + x_2  # [batch, seq_len, embed_dim]
        if batch_idx == 1 and epoch == 1:
            print('x = token_embedding + self.pos_embedding', x.shape)
        # Pass through the stacked decoder blocks
        x = self.decoder_blocks(x, batch_idx, epoch)
        # Project to the vocabulary space to obtain the logits
        logits = self.fc_out(x)  # [batch, seq_len, vocab_size]
        if batch_idx == 1 and epoch == 1:
            print('after decode_blocks', x.shape)
            print('logits', logits.shape)
        return logits
    @torch.no_grad()  # disable gradient computation during inference to save memory
    def generate(self, prompt, max_len=GEN_MAX_LEN):
        """
        Greedy decoding.
        :param prompt: prompt token sequence, shape [batch_size, prompt_len]
        :param max_len: maximum length of the generated sequence
        :return: the full generated sequence, shape [batch_size, max_len]
        """
        self.eval()  # switch to evaluation mode
        generated = prompt.to(DEVICE)
        for _ in range(max_len - prompt.shape[1]):
            # 1. Forward pass to get the logits
            logits = self(generated, 0, 0)
            # 2. Take the last token's logits to predict the next token
            next_token_logits = logits[:, -1, :]
            next_token = torch.argmax(next_token_logits, dim=-1).unsqueeze(1)
            # 3. Append the new token to the generated sequence
            generated = torch.cat([generated, next_token], dim=1)
            # 4. Safety check against over-long generation (redundant with the loop bound)
            if generated.shape[1] >= max_len:
                break
        self.train()  # switch back to training mode
        return generated
# ===================== Training + inference =====================
def main():
    # 1. Build the dataset and data loader
    dataset = DummyTextDataset(VOCAB_SIZE, SEQ_LEN)
    dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)
    print(dataloader)
    # 2. Initialize the model, loss function, and optimizer
    model = DecoderOnlyLM(VOCAB_SIZE, EMBED_DIM, NUM_HEADS, HEAD_DIM, NUM_LAYERS).to(DEVICE)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=LR)
    # 3. Training loop
    print("===== Starting pre-training =====")
    for epoch in range(EPOCHS):
        total_loss = 0.0
        for batch_idx, (x, y) in enumerate(dataloader):
            if batch_idx == 1 and epoch == 1:
                print('batch_idx: ', batch_idx)
                print('x', x)
                print('y', y)
            # Forward pass
            logits = model(x, batch_idx, epoch)  # [batch, seq_len, vocab_size]
            # Reshape to compute the loss
            loss = criterion(logits.reshape(-1, VOCAB_SIZE), y.reshape(-1))
            if batch_idx == 1 and epoch == 1:
                print('logits.reshape(-1, VOCAB_SIZE): ', logits.reshape(-1, VOCAB_SIZE).shape)
                print('y.reshape(-1)', y.reshape(-1).shape)
                print('loss', loss)
            # Backward pass and optimization
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
        avg_loss = total_loss / len(dataloader)
        print(f"Epoch [{epoch + 1}/{EPOCHS}], Avg Loss: {avg_loss:.4f}")
    # 4. Inference example
    print("\n===== Generating text =====")
    # Random prompt sequence (replace with real tokens in a real scenario)
    prompt = torch.randint(0, VOCAB_SIZE, (1, 5))  # batch_size=1, prompt_len=5
    print(f"Prompt token sequence: {prompt.squeeze().tolist()}")
    # Generate text
    generated_seq = model.generate(prompt, max_len=GEN_MAX_LEN)
    print(f"Generated token sequence: {generated_seq.squeeze().tolist()}")
if __name__ == "__main__":
    main()
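Before looking at the output, here is a small standalone sketch (my own addition, not part of the script above) that isolates the two dimension tricks I most wanted to verify: the head split/merge reshapes and the causal mask. The sizes are arbitrary small numbers so the tensors can be inspected by eye.

import torch

batch, seq_len, embed_dim, num_heads = 2, 5, 8, 2
head_dim = embed_dim // num_heads

x = torch.randn(batch, seq_len, embed_dim)
# split_heads: [2, 5, 8] -> [2, 5, 2, 4] -> [2, 2, 5, 4]
heads = x.reshape(batch, seq_len, num_heads, head_dim).transpose(1, 2)
# concat_heads: [2, 2, 5, 4] -> [2, 5, 2, 4] -> [2, 5, 8]
merged = heads.transpose(1, 2).reshape(batch, seq_len, embed_dim)
print(heads.shape, merged.shape, torch.equal(x, merged))  # the round trip is lossless

# Causal mask: row i may attend to columns 0..i only
mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=1)
mask = mask.masked_fill(mask == 1, float("-inf"))
print(mask)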
2. Output
D:\book\PythonAI\Code\.venv\Scripts\python.exe D:\book\PythonAI\Code\ch09\MiniLLM.py
<torch.utils.data.dataloader.DataLoader object at 0x00000177D64403B0>
self.token_embedding torch.Size([1000, 16])
self.pos_embedding torch.Size([50, 16])
self.w_q torch.Size([16, 16])
self.w_k torch.Size([16, 16])
self.w_v torch.Size([16, 16])
self.w_o torch.Size([16, 16])
self.fc1 torch.Size([64, 16])
self.fc2 torch.Size([16, 64])
self.w_q torch.Size([16, 16])
self.w_k torch.Size([16, 16])
self.w_v torch.Size([16, 16])
self.w_o torch.Size([16, 16])
self.fc1 torch.Size([64, 16])
self.fc2 torch.Size([16, 64])
self.w_q torch.Size([16, 16])
self.w_k torch.Size([16, 16])
self.w_v torch.Size([16, 16])
self.w_o torch.Size([16, 16])
self.fc1 torch.Size([64, 16])
self.fc2 torch.Size([16, 64])
fc_out torch.Size([1000, 16])
===== Starting pre-training =====
Epoch [1/5], Avg Loss: 11.0605
batch_idx: 1
x tensor([[622, 653, 952, 627, 247, 76, 148, 553, 565, 770, 854, 863, 972, 819,
382, 605, 906, 869, 455, 992, 57, 813],
[976, 34, 440, 770, 24, 446, 408, 181, 642, 59, 496, 517, 898, 278,
179, 598, 899, 534, 475, 325, 862, 116],
[593, 236, 265, 407, 334, 744, 667, 458, 584, 190, 556, 503, 700, 158,
210, 65, 372, 390, 346, 574, 304, 337],
[146, 993, 54, 799, 111, 226, 416, 868, 110, 551, 917, 458, 748, 82,
13, 978, 378, 794, 368, 664, 601, 702],
[449, 436, 253, 184, 945, 676, 79, 992, 460, 438, 662, 804, 298, 657,
838, 742, 251, 806, 121, 880, 40, 486],
[978, 408, 174, 915, 315, 236, 869, 151, 3, 201, 696, 682, 578, 83,
284, 621, 295, 868, 884, 665, 579, 733],
[502, 845, 855, 238, 364, 444, 833, 901, 264, 675, 569, 536, 914, 675,
483, 93, 310, 370, 618, 302, 336, 179],
[886, 143, 754, 534, 534, 741, 742, 34, 871, 341, 275, 794, 519, 103,
116, 824, 684, 302, 995, 224, 326, 800],
[367, 309, 779, 168, 378, 475, 349, 773, 391, 319, 134, 111, 800, 275,
918, 860, 477, 632, 259, 23, 56, 330]])
y tensor([[653, 952, 627, 247, 76, 148, 553, 565, 770, 854, 863, 972, 819, 382,
605, 906, 869, 455, 992, 57, 813, 0],
[ 34, 440, 770, 24, 446, 408, 181, 642, 59, 496, 517, 898, 278, 179,
598, 899, 534, 475, 325, 862, 116, 0],
[236, 265, 407, 334, 744, 667, 458, 584, 190, 556, 503, 700, 158, 210,
65, 372, 390, 346, 574, 304, 337, 0],
[993, 54, 799, 111, 226, 416, 868, 110, 551, 917, 458, 748, 82, 13,
978, 378, 794, 368, 664, 601, 702, 0],
[436, 253, 184, 945, 676, 79, 992, 460, 438, 662, 804, 298, 657, 838,
742, 251, 806, 121, 880, 40, 486, 0],
[408, 174, 915, 315, 236, 869, 151, 3, 201, 696, 682, 578, 83, 284,
621, 295, 868, 884, 665, 579, 733, 0],
[845, 855, 238, 364, 444, 833, 901, 264, 675, 569, 536, 914, 675, 483,
93, 310, 370, 618, 302, 336, 179, 0],
[143, 754, 534, 534, 741, 742, 34, 871, 341, 275, 794, 519, 103, 116,
824, 684, 302, 995, 224, 326, 800, 0],
[309, 779, 168, 378, 475, 349, 773, 391, 319, 134, 111, 800, 275, 918,
860, 477, 632, 259, 23, 56, 330, 0]])
x tensor([[622, 653, 952, 627, 247, 76, 148, 553, 565, 770, 854, 863, 972, 819,
382, 605, 906, 869, 455, 992, 57, 813],
[976, 34, 440, 770, 24, 446, 408, 181, 642, 59, 496, 517, 898, 278,
179, 598, 899, 534, 475, 325, 862, 116],
[593, 236, 265, 407, 334, 744, 667, 458, 584, 190, 556, 503, 700, 158,
210, 65, 372, 390, 346, 574, 304, 337],
[146, 993, 54, 799, 111, 226, 416, 868, 110, 551, 917, 458, 748, 82,
13, 978, 378, 794, 368, 664, 601, 702],
[449, 436, 253, 184, 945, 676, 79, 992, 460, 438, 662, 804, 298, 657,
838, 742, 251, 806, 121, 880, 40, 486],
[978, 408, 174, 915, 315, 236, 869, 151, 3, 201, 696, 682, 578, 83,
284, 621, 295, 868, 884, 665, 579, 733],
[502, 845, 855, 238, 364, 444, 833, 901, 264, 675, 569, 536, 914, 675,
483, 93, 310, 370, 618, 302, 336, 179],
[886, 143, 754, 534, 534, 741, 742, 34, 871, 341, 275, 794, 519, 103,
116, 824, 684, 302, 995, 224, 326, 800],
[367, 309, 779, 168, 378, 475, 349, 773, 391, 319, 134, 111, 800, 275,
918, 860, 477, 632, 259, 23, 56, 330]])
batch_size, seq_len 9 22
x.shape torch.Size([9, 22])
pos_ids torch.Size([9, 22]) tensor([[ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17,
18, 19, 20, 21],
[ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17,
18, 19, 20, 21],
[ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17,
18, 19, 20, 21],
[ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17,
18, 19, 20, 21],
[ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17,
18, 19, 20, 21],
[ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17,
18, 19, 20, 21],
[ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17,
18, 19, 20, 21],
[ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17,
18, 19, 20, 21],
[ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17,
18, 19, 20, 21]])
self.token_embedding torch.Size([1000, 16])
self.pos_embedding torch.Size([50, 16])
x_1 torch.Size([9, 22, 16])
x_2 torch.Size([9, 22, 16])
x = token_embedding + self.pos_embedding torch.Size([9, 22, 16])
attn.q torch.Size([9, 4, 22, 4])
attn.mask torch.Size([22, 22])
attn_out. torch.Size([9, 4, 22, 4])
attn_out_concat torch.Size([9, 22, 16])
attn.output torch.Size([9, 22, 16])
forward.x torch.Size([9, 22, 16])
DecoderBlock x.shape torch.Size([9, 22, 16])
attn.q torch.Size([9, 4, 22, 4])
attn.mask torch.Size([22, 22])
attn_out. torch.Size([9, 4, 22, 4])
attn_out_concat torch.Size([9, 22, 16])
attn.output torch.Size([9, 22, 16])
forward.x torch.Size([9, 22, 16])
DecoderBlock x.shape torch.Size([9, 22, 16])
attn.q torch.Size([9, 4, 22, 4])
attn.mask torch.Size([22, 22])
attn_out. torch.Size([9, 4, 22, 4])
attn_out_concat torch.Size([9, 22, 16])
attn.output torch.Size([9, 22, 16])
forward.x torch.Size([9, 22, 16])
DecoderBlock x.shape torch.Size([9, 22, 16])
after decode_blocks torch.Size([9, 22, 16])
logits torch.Size([9, 22, 1000])
logits.reshape(-1, VOCAB_SIZE): torch.Size([198, 1000])
y.reshape(-1) torch.Size([198])
loss tensor(9.2886, grad_fn=<NllLossBackward0>)
Epoch [2/5], Avg Loss: 8.6743
Epoch [3/5], Avg Loss: 7.7679
Epoch [4/5], Avg Loss: 7.2881
Epoch [5/5], Avg Loss: 6.9872
===== Generating text =====
Prompt token sequence: [643, 749, 939, 185, 893]
Generated token sequence: [643, 749, 939, 185, 893, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
Process finished with exit code 0
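A note on the result: the generated sequence collapses to token 0 right after the prompt, and the average loss ends up near 6.99, close to ln(1000) ≈ 6.91, the cross-entropy of uniform guessing over the vocabulary. Both are expected: the training tokens are uniformly random, so the only consistent target the model can learn is the artificial label 0 at the end of every sequence, and greedy argmax decoding then locks onto token 0. One way to get non-degenerate (though still meaningless) output is to sample instead of taking the argmax. A minimal sketch, reusing model, prompt, F, DEVICE, and GEN_MAX_LEN from the script above; the temperature value 0.8 is an arbitrary choice:

# Temperature sampling instead of greedy argmax (sketch; assumes the objects defined in the script above)
model.eval()
with torch.no_grad():
    generated = prompt.to(DEVICE)
    while generated.shape[1] < GEN_MAX_LEN:
        logits = model(generated, 0, 0)                       # [1, cur_len, VOCAB_SIZE]
        probs = F.softmax(logits[:, -1, :] / 0.8, dim=-1)     # temperature-scaled distribution over the vocabulary
        next_token = torch.multinomial(probs, num_samples=1)  # [1, 1], sampled rather than argmax
        generated = torch.cat([generated, next_token], dim=1)
print(generated.squeeze().tolist())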