PyTorch Study Notes - Convolutional Seq2Seq (Model Training)

PyTorch Study Notes - torchtext and PyTorch, Example 5

0. Introduction to the PyTorch Seq2Seq Project

After working through the basics of torchtext, I found this tutorial, "Understanding and Implementing Seq2Seq Models with PyTorch and torchtext".
The project consists of six sub-projects:

  1. Seq2Seq with neural networks
  2. Learning phrase representations with an RNN encoder-decoder for statistical machine translation
  3. NMT by jointly learning to align and translate
  4. Packed padded sequences, masking and inference
  5. Convolutional Seq2Seq
  6. Transformer

5. Convolutional Seq2Seq

5.1 Preparing the Data

5.2 Building the Model

5.3 Training the Model

# imports used in this section (they may already be in scope from 5.1)
import time
import math

import torch
import torch.nn as nn
import torch.optim as optim

INPUT_DIM = len(SRC.vocab)
OUTPUT_DIM = len(TRG.vocab)
EMB_DIM = 256
HID_DIM = 512
ENC_LAYERS = 10
DEC_LAYERS = 10
ENC_KERNEL_SIZE = 3
DEC_KERNEL_SIZE = 3
ENC_DROPOUT = 0.25
DEC_DROPOUT = 0.25
PAD_IDX = TRG.vocab.stoi['<pad>']
    
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    
enc = Encoder(INPUT_DIM, EMB_DIM, HID_DIM, ENC_LAYERS, ENC_KERNEL_SIZE, ENC_DROPOUT, device)
dec = Decoder(OUTPUT_DIM, EMB_DIM, HID_DIM, DEC_LAYERS, DEC_KERNEL_SIZE, DEC_DROPOUT, PAD_IDX, device)

model = Seq2Seq(enc, dec, device).to(device)
model

(Printing the model shows the module hierarchy. For this convolutional Seq2Seq it lists, for both the encoder and the decoder, the token and position embeddings, the linear projections between embedding and hidden sizes, and the stack of 10 Conv1d layers. The GRU/attention summary that originally appeared here belongs to the earlier attention-based tutorial, not to this model.)

def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

print(f'The model has {count_parameters(model):,} trainable parameters')
The model has 37,351,685 trainable parameters
optimizer = optim.Adam(model.parameters())
criterion = nn.CrossEntropyLoss(ignore_index = PAD_IDX)
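
Because batches are padded to a common length, positions holding <pad> should not contribute to the loss; that is what ignore_index=PAD_IDX does. A quick toy check with made-up numbers (using a stand-in pad index rather than the real vocabulary):

# Toy check (made-up numbers): ignore_index drops padded targets from the loss
pad_idx = 1                                     # stands in for TRG.vocab.stoi['<pad>']
toy_crit = nn.CrossEntropyLoss(ignore_index=pad_idx)
toy_logits = torch.randn(4, 10)                 # 4 target positions, vocab size 10
toy_targets = torch.tensor([3, pad_idx, 7, 2])  # the second position is padding
print(toy_crit(toy_logits, toy_targets))        # loss is averaged over the 3 non-pad positions
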
def train(model, iterator, optimizer, criterion, clip):
    
    model.train()
    
    epoch_loss = 0
    
    for i, batch in enumerate(iterator):
        
        src = batch.src
        trg = batch.trg
        
        optimizer.zero_grad()
        
        output, _ = model(src, trg[:,:-1])
        
        #output = [batch size, trg sent len - 1, output dim]
        #trg = [batch size, trg sent len]
        
        output = output.contiguous().view(-1, output.shape[-1])
        trg = trg[:,1:].contiguous().view(-1)
        
        #output = [batch size * trg sent len - 1, output dim]
        #trg = [batch size * trg sent len - 1]
        
        loss = criterion(output, trg)
        
        loss.backward()
        
        torch.nn.utils.clip_grad_norm_(model.parameters(), clip)
        
        optimizer.step()
        
        epoch_loss += loss.item()
        
    return epoch_loss / len(iterator)
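
The slicing in train() is the usual one-token shift for teacher forcing: the decoder is fed trg without its final token, and the loss compares the predictions against trg without its initial <sos> token. A toy illustration with made-up token ids:

# Toy illustration (made-up token ids) of the trg[:,:-1] / trg[:,1:] shift used in train()
toy_trg = torch.tensor([[2, 10, 11, 12, 3]])     # [<sos>, y1, y2, y3, <eos>], batch size 1
decoder_input = toy_trg[:, :-1]                  # fed to the model: [<sos>, y1, y2, y3]
targets = toy_trg[:, 1:].contiguous().view(-1)   # compared with the output: [y1, y2, y3, <eos>]
print(decoder_input.shape, targets.shape)        # torch.Size([1, 4]) torch.Size([4])
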
def evaluate(model, iterator, criterion):
    
    model.eval()
    
    epoch_loss = 0
    
    with torch.no_grad():
    
        for i, batch in enumerate(iterator):

            src = batch.src
            trg = batch.trg

            output, _ = model(src, trg[:,:-1])
        
            #output = [batch size, trg sent len - 1, output dim]
            #trg = [batch size, trg sent len]

            output = output.contiguous().view(-1, output.shape[-1])
            trg = trg[:,1:].contiguous().view(-1)

            #output = [batch size * trg sent len - 1, output dim]
            #trg = [batch size * trg sent len - 1]
            
            loss = criterion(output, trg)

            epoch_loss += loss.item()
        
    return epoch_loss / len(iterator)
def epoch_time(start_time, end_time):
    elapsed_time = end_time - start_time
    elapsed_mins = int(elapsed_time / 60)
    elapsed_secs = int(elapsed_time - (elapsed_mins * 60))
    return elapsed_mins, elapsed_secs
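
A quick sanity check of epoch_time with a made-up 66-second interval (it matches the per-epoch times printed below):

mins, secs = epoch_time(0.0, 66.0)
print(f'{mins}m {secs}s')   # 1m 6s
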
N_EPOCHS = 10
CLIP = 1

best_valid_loss = float('inf')

for epoch in range(N_EPOCHS):
    
    start_time = time.time()
    
    train_loss = train(model, train_iterator, optimizer, criterion, CLIP)
    valid_loss = evaluate(model, valid_iterator, criterion)
    
    end_time = time.time()
    
    epoch_mins, epoch_secs = epoch_time(start_time, end_time)
    
    if valid_loss < best_valid_loss:
        best_valid_loss = valid_loss
        torch.save(model.state_dict(), 'tut5-model.pt')
    
    print(f'Epoch: {epoch+1:02} | Time: {epoch_mins}m {epoch_secs}s')
    print(f'\tTrain Loss: {train_loss:.3f} | Train PPL: {math.exp(train_loss):7.3f}')
    print(f'\t Val. Loss: {valid_loss:.3f} |  Val. PPL: {math.exp(valid_loss):7.3f}')
# 10 epochs
Epoch: 01 | Time: 1m 6s
    Train Loss: 4.154 | Train PPL:  63.715
     Val. Loss: 2.897 |  Val. PPL:  18.116
Epoch: 02 | Time: 1m 6s
    Train Loss: 2.952 | Train PPL:  19.140
     Val. Loss: 2.368 |  Val. PPL:  10.680
Epoch: 03 | Time: 1m 6s
    Train Loss: 2.556 | Train PPL:  12.884
     Val. Loss: 2.125 |  Val. PPL:   8.370
Epoch: 04 | Time: 1m 6s
    Train Loss: 2.335 | Train PPL:  10.334
     Val. Loss: 1.987 |  Val. PPL:   7.291
Epoch: 05 | Time: 1m 6s
    Train Loss: 2.193 | Train PPL:   8.966
     Val. Loss: 1.926 |  Val. PPL:   6.862
Epoch: 06 | Time: 1m 6s
    Train Loss: 2.089 | Train PPL:   8.074
     Val. Loss: 1.878 |  Val. PPL:   6.538
Epoch: 07 | Time: 1m 6s
    Train Loss: 2.011 | Train PPL:   7.470
     Val. Loss: 1.835 |  Val. PPL:   6.264
Epoch: 08 | Time: 1m 6s
    Train Loss: 1.946 | Train PPL:   7.001
     Val. Loss: 1.818 |  Val. PPL:   6.159
Epoch: 09 | Time: 1m 6s
    Train Loss: 1.890 | Train PPL:   6.621
     Val. Loss: 1.802 |  Val. PPL:   6.064
Epoch: 10 | Time: 1m 6s
    Train Loss: 1.850 | Train PPL:   6.359
     Val. Loss: 1.790 |  Val. PPL:   5.988
# 20 more epochs (training continued from the run above)
Epoch: 01 | Time: 1m 6s
    Train Loss: 1.815 | Train PPL:   6.144
     Val. Loss: 1.771 |  Val. PPL:   5.880
Epoch: 02 | Time: 1m 6s
    Train Loss: 1.779 | Train PPL:   5.926
     Val. Loss: 1.753 |  Val. PPL:   5.772
Epoch: 03 | Time: 1m 6s
    Train Loss: 1.751 | Train PPL:   5.759
     Val. Loss: 1.732 |  Val. PPL:   5.651
Epoch: 04 | Time: 1m 6s
    Train Loss: 1.723 | Train PPL:   5.600
     Val. Loss: 1.735 |  Val. PPL:   5.671
Epoch: 05 | Time: 1m 6s
    Train Loss: 1.700 | Train PPL:   5.472
     Val. Loss: 1.736 |  Val. PPL:   5.672
Epoch: 06 | Time: 1m 6s
    Train Loss: 1.674 | Train PPL:   5.333
     Val. Loss: 1.721 |  Val. PPL:   5.589
Epoch: 07 | Time: 1m 6s
    Train Loss: 1.651 | Train PPL:   5.211
     Val. Loss: 1.720 |  Val. PPL:   5.587
Epoch: 08 | Time: 1m 6s
    Train Loss: 1.631 | Train PPL:   5.108
     Val. Loss: 1.720 |  Val. PPL:   5.585
Epoch: 09 | Time: 1m 6s
    Train Loss: 1.613 | Train PPL:   5.020
     Val. Loss: 1.722 |  Val. PPL:   5.596
Epoch: 10 | Time: 1m 6s
    Train Loss: 1.590 | Train PPL:   4.905
     Val. Loss: 1.708 |  Val. PPL:   5.520
Epoch: 11 | Time: 1m 6s
    Train Loss: 1.579 | Train PPL:   4.848
     Val. Loss: 1.719 |  Val. PPL:   5.577
Epoch: 12 | Time: 1m 6s
    Train Loss: 1.562 | Train PPL:   4.770
     Val. Loss: 1.728 |  Val. PPL:   5.632
Epoch: 13 | Time: 1m 6s
    Train Loss: 1.552 | Train PPL:   4.719
     Val. Loss: 1.703 |  Val. PPL:   5.493
Epoch: 14 | Time: 1m 6s
    Train Loss: 1.539 | Train PPL:   4.660
     Val. Loss: 1.723 |  Val. PPL:   5.602
Epoch: 15 | Time: 1m 6s
    Train Loss: 1.526 | Train PPL:   4.598
     Val. Loss: 1.710 |  Val. PPL:   5.529
Epoch: 16 | Time: 1m 6s
    Train Loss: 1.518 | Train PPL:   4.565
     Val. Loss: 1.704 |  Val. PPL:   5.494
Epoch: 17 | Time: 1m 6s
    Train Loss: 1.517 | Train PPL:   4.560
     Val. Loss: 1.726 |  Val. PPL:   5.616
Epoch: 18 | Time: 1m 6s
    Train Loss: 2.414 | Train PPL:  11.177
     Val. Loss: 2.562 |  Val. PPL:  12.961
Epoch: 19 | Time: 1m 6s
    Train Loss: 2.830 | Train PPL:  16.952
     Val. Loss: 2.583 |  Val. PPL:  13.240
Epoch: 20 | Time: 1m 6s
    Train Loss: 12.083 | Train PPL: 176818.618
     Val. Loss: 15.417 |  Val. PPL: 4961313.167
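
Training clearly diverges in the final epochs (the loss blows up at epoch 20), but the loop above only saves the checkpoint with the lowest validation loss, so the damage is contained. As a follow-up sketch (assuming a test_iterator was built in 5.1 alongside the train and validation iterators, as in the original tutorial), the best checkpoint can be reloaded and scored on the test set:

model.load_state_dict(torch.load('tut5-model.pt'))

test_loss = evaluate(model, test_iterator, criterion)

print(f'| Test Loss: {test_loss:.3f} | Test PPL: {math.exp(test_loss):7.3f} |')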

Thanks to Colab; without it, this much compute would have worn out my laptop's GPU.
