Datawhale AI 夏令营 AI for Science 催化反应产率预测 Task2 Baseline 学习

简介

这一次课程使用RNN进行催化反应产率预测。

总体想法是把数据集中的反应物和产物通过SMILES字符串表示出来,然后根据基本的原子、连接键等将化学反应对应的SMILES字符串转化为整数序列,再通过RNN进行训练和预测。

下面的代码是对DataWhale提供的baseline的学习。

过程

1. 导入必要的库

import re
import time
import pandas as pd
from typing import List, Tuple
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, Subset

re模块是为了从化学反应的SMILES字符串中提取出反应的基本组成元素。

RNN主要使用pytorch中的RNN。

2. 定义RNN模型

# 定义RNN模型
class RNNModel(nn.Module):
    def __init__(self, num_embed, input_size, hidden_size, output_size, num_layers, dropout, device):
        super(RNNModel, self).__init__()
        self.embed = nn.Embedding(num_embed, input_size)
        self.rnn = nn.RNN(input_size, hidden_size, num_layers=num_layers, 
                          batch_first=True, dropout=dropout, bidirectional=True)
        self.fc = nn.Sequential(nn.Linear(2 * num_layers * hidden_size, output_size),
                                nn.Sigmoid(),
                                nn.Linear(output_size, 1),
                                nn.Sigmoid())

    def forward(self, x):
        # x : [bs, seq_len]
        x = self.embed(x)
        # x : [bs, seq_len, input_size]
        _, hn = self.rnn(x) # hn : [2*num_layers, bs, h_dim]
        hn = hn.transpose(0,1)
        z = hn.reshape(hn.shape[0], -1) # z shape: [bs, 2*num_layers*h_dim]
        output = self.fc(z).squeeze(-1) # output shape: [bs, 1]
        return output
  1. 这个RNNModel包括了三层:

    • 嵌入层:nn.Embedding(num_embeddings, embedding_dim),将整数索引的序列转换为稠密向量表示:

      • num_embeddings:词汇表的大小,由于vocab_full.txt有294行,这里填写294
      • embedding_dim:每个嵌入向量的维度大小,也就是输出向量(即RNN层输入向量)的大小,这里是input_size
      • 输入nn.Embedding(num_embeddings, embedding_dim)(x):这里面x的shape是(batch_size, seq_len),seq_len是每个样本的元素数量
      • 输出:即RNN层的输入,shape是(batch_size, seq_len, input_size)
    • RNN层:nn.rnn(input_size, hidden_size, num_layers=num_layers, batch_first=True, dropout=dropout, bidirectional=True):

      • input_size:输入张量的特征数,即输入RNN的单元的向量的大小。输入300是因为
      • hidden_size:RNN 单元的隐藏状态向量的大小,也是输出向量的大小
      • num_layers:RNN 层的数量
      • dropout:浮点数,表示在训练过程中 RNN 层之间的输出被丢弃的概率,可防止过拟合
      • bidirectional:布尔值,指示是否使用双向 RNN。若为True,则是双向 RNN
    • 全连接层:

      • nn.Linear,包括两个线性层和两个 Sigmoid 激活函数,将RNN层的输出转化为一个预测结果
  2. 前向传播self.forward(x)方法:

    输入的数据x是DataLoader划分后得数据。x的shape是(batch_size, seq_len)。

    这个 forward方法首先通过嵌入层将输入 x 转换为嵌入表示,然后将其馈送到 RNN 层中。之后,它处理 RNN 的最后一个隐藏状态 (hn),并将其通过全连接层产生最终的输出。
    由于hn的shape是(2*num_layers, batch_size, hidden_size),通过hn.transpose(0,1)将其转换为( batch_size, 2*num_layers, hidden_size),以便作为全连接层的输入。

    z = hn.reshape(hn.shape[0], -1),将hn的shape变为(bs, 2*num_layers*h_dim)。

    output = self.fc(z).squeeze(-1) ,将张量z经过全连接层处理,最终得到模型的输出。

    这几步算是前向传播的流水线操作,必须记住这些步骤。

3. 数据预处理

# tokenizer,鉴于SMILES的特性,这里需要自己定义tokenizer和vocab
# 这里直接将smiles str按字符拆分,并替换为词汇表中的序号
class Smiles_tokenizer():
    def __init__(self, pad_token, regex, vocab_file, max_length):
        self.pad_token = pad_token
        self.regex = regex
        self.vocab_file = vocab_file
        self.max_length = max_length

        with open(self.vocab_file, "r") as f:
            lines = f.readlines()
        lines = [line.strip("\n") for line in lines]
        vocab_dic = {}
        for index, token in enumerate(lines):
            vocab_dic[token] = index
        self.vocab_dic = vocab_dic

    def _regex_match(self, smiles):
        regex_string = r"(" + self.regex + r"|"
        regex_string += r".)"
        prog = re.compile(regex_string)

        tokenised = []
        for smi in smiles:
            tokens = prog.findall(smi)
            if len(tokens) > self.max_length:
                tokens = tokens[:self.max_length]
            tokenised.append(tokens) # 返回一个所有的字符串列表
        return tokenised
    
    def tokenize(self, smiles):
        tokens = self._regex_match(smiles)
        # 添加上表示开始和结束的token:<cls>, <end>
        tokens = [["<CLS>"] + token + ["<SEP>"] for token in tokens]
        tokens = self._pad_seqs(tokens, self.pad_token)
        token_idx = self._pad_token_to_idx(tokens)
        return tokens, token_idx

    def _pad_seqs(self, seqs, pad_token):
        pad_length = max([len(seq) for seq in seqs])
        padded = [seq + ([pad_token] * (pad_length - len(seq))) for seq in seqs]
        return padded

    def _pad_token_to_idx(self, tokens):
        idx_list = []
        for token in tokens:
            tokens_idx = []
            for i in token:
                if i in self.vocab_dic.keys():
                    tokens_idx.append(self.vocab_dic[i])
                else:
                    self.vocab_dic[i] = max(self.vocab_dic.values()) + 1
                    tokens_idx.append(self.vocab_dic[i])
            idx_list.append(tokens_idx)
        
        return idx_list

# 读数据并处理
def read_data(file_path, train=True):
    df = pd.read_csv(file_path)
    reactant1 = df["Reactant1"].tolist()
    reactant2 = df["Reactant2"].tolist()
    product = df["Product"].tolist()
    additive = df["Additive"].tolist()
    solvent = df["Solvent"].tolist()
    if train:
        react_yield = df["Yield"].tolist()
    else:
        react_yield = [0 for i in range(len(reactant1))]
    
    # 将reactant拼到一起,之间用.分开。product也拼到一起,用>分开
    input_data_list = []
    for react1, react2, prod, addi, sol in zip(reactant1, reactant2, product, additive, solvent):
        input_info = ".".join([react1, react2])
        input_info = ">".join([input_info, prod])
        input_data_list.append(input_info)
    output = [(react, y) for react, y in zip(input_data_list, react_yield)]

    return output

class ReactionDataset(Dataset):
    def __init__(self, data: List[Tuple[List[str], float]]):
        self.data = data
        
    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]
    
def collate_fn(batch):
    REGEX = r"\[[^\]]+]|Br?|Cl?|N|O|S|P|F|I|b|c|n|o|s|p|\(|\)|\.|=|#|-|\+|\\\\|\/|:|~|@|\?|>|\*|\$|\%[0-9]{2}|[0-9]"
    tokenizer = Smiles_tokenizer("<PAD>", REGEX, "../vocab_full.txt", max_length=300)
    smi_list = []
    yield_list = []
    for i in batch:
        smi_list.append(i[0])
        yield_list.append(i[1])
    tokenizer_batch = torch.tensor(tokenizer.tokenize(smi_list)[1])
    yield_list = torch.tensor(yield_list)
    return tokenizer_batch, yield_list

这里面有两个类,一个是对数据集进行处理的,一个是自定义的反应数据集。详细看一下Smiles_tokenizer类的主要作用:

3.1 从read_data函数开始
def read_data(file_path, train=True):
    df = pd.read_csv(file_path)
    reactant1 = df["Reactant1"].tolist()
    reactant2 = df["Reactant2"].tolist()
    product = df["Product"].tolist()
    additive = df["Additive"].tolist()
    solvent = df["Solvent"].tolist()
    if train:
        react_yield = df["Yield"].tolist()
    else:
        react_yield = [0 for i in range(len(reactant1))]
    input_data_list = []

这一部分代码是讲数据集中的各列提取出来并转化为列表,其中反应物、产物、催化剂和溶剂均为SMILES字符串组成的列表,对于训练集,react_yield是各反应对应的产率组成的列表,对于测试集,react_yield都是0构成的列表。经过这一步,我们可以得到下面各个列表:

reactant1:['c1ccc2c(c1)Nc1ccccc1O2','c1ccc2c(c1)Nc1ccccc1O2', ...]
reactant2:['Brc1ccccc1I', 'Brc1ccccc1I', ...]
product:['Brc1ccccc1N1c2ccccc2Oc2ccccc21','Brc1ccccc1N1c2ccccc2Oc2ccccc21', ...]
additive:['CC(C)(C)[O-].CC(C)(C)[PH+](C(C)(C)C)C(C)(C)C.F[B-](F)(F)F.F[B-](F)(F)F.O=C(C=Cc1ccccc1)C=Cc1ccccc1.O=C(C=Cc1ccccc1)C=Cc1ccccc1.[H+].[Na+].[Pd]','C1COCCOCCOCCOCCOCCO1.O=C([O-])[O-].[Cu+].[I-].[K+].[K+]',]
solvent:['Cc1ccccc1', 'Clc1ccccc1Cl', ...]
react_yield:[0.78, 0.9, ...]
def read_data(file_path, train=True):
    ...
    input_data_list = []
    for react1, react2, prod, addi, sol in zip(reactant1, reactant2, product, additive, solvent):
        input_info = ".".join([react1, react2])
        input_info = ">".join([input_info, prod])
        input_data_list.append(input_info)
        output = [(react, y) for react, y in zip(input_data_list, react_yield)]
    return output

这一部分是将上面得到的各列表转化为“反应物1.反应物2>产物的SMILES字符串”,然后输出“(反应物1.反应物2>产物的SMILES字符串, 产率)”组成的列表:

[('c1ccc2c(c1)Nc1ccccc1O2.Brc1ccccc1I>Brc1ccccc1N1c2ccccc2Oc2ccccc21', 0.78),
 ('c1ccc2c(c1)Nc1ccccc1O2.Brc1ccccc1I>Brc1ccccc1N1c2ccccc2Oc2ccccc21', 0.9),
...
]
3.2 Smiles_tokenizer类

首先从构造函数开始:

class Smiles_tokenizer():
    def __init__(self, pad_token, regex, vocab_file, max_length):
        self.pad_token = pad_token
        self.regex = regex
        self.vocab_file = vocab_file
        self.max_length = max_length

        with open(self.vocab_file, "r") as f:
            lines = f.readlines()
        lines = [line.strip("\n") for line in lines]
        vocab_dic = {}
        for index, token in enumerate(lines):
            vocab_dic[token] = index
        self.vocab_dic = vocab_dic

这个构造函数主要形成一个由各基本原子、基团以及化学键等为键组成的字典。其中每个项都由“基础 : 整数序号”组成:

vocab_dic:
{'<PAD>': 0,
 '<CLS>': 1,
 '<MASK>': 2,
 '<SEP>': 3,
 '[UNK]': 4,
 '>': 5,
 'C': 6,
...
}

然后是_regex_match方法:

class Smiles_tokenizer():
    ...
    def _regex_match(self, smiles):
        regex_string = r"(" + self.regex + r"|"
        regex_string += r".)"
        prog = re.compile(regex_string)

        tokenised = []
        for smi in smiles:
            tokens = prog.findall(smi)
            if len(tokens) > self.max_length:
                tokens = tokens[:self.max_length]
            tokenised.append(tokens) # 返回一个所有的字符串列表
        return tokenised

上面的正则表达式是:

REGEX = r"\[[^\]]+]|Br?|Cl?|N|O|S|P|F|I|b|c|n|o|s|p|\(|\)|\.|=|#|-|\+|\\\\|\/|:|~|@|\?|>|\*|\$|\%[0-9]{2}|[0-9]"
self.regex = REGEX

这个正则表达式匹配的内容包括:方括号里面的内容(例如:[NH4+])、N、P、Br、.、@、=等符号。

_regex_match传入的smiles参数,是一个SMILES字符串表示的反应方程式构成的列表。通过for循环,不断用这个正则表达式将SMILES字符串表示的反应方程式中的各个元素提起出来,作为tokens,也就是说tokens是反应方程式中的各个元素组成的列表(不妨称之为:特征元素列表),而tokenised就是整个smiles对应的特征元素列表:

# 下面是一个第四个反应的tokens的例子,是个一维列表
tokens:['C', '1', 'C', 'O', 'C', 'C', 'N', '1', '.', 'F', 'c', '1', 'c', 'n', 'c', '(', 'Cl', ')', 'n', 'c', '1', 'Cl', '>', 'F', 'c', '1', 'c', 'n', 'c', '(', 'Cl', ')', 'n', 'c', '1', 'N', '1', 'C', 'C', 'O', 'C', 'C', '1']
# tokenised是嵌套列表
tokenised:[
    ['c', '1', 'c', 'c', 'c', '2', 'c', '(', 'c', '1', ')', 'N', 'c', '1', 'c', 'c', 'c', 'c', 'c', '1', 'O', '2', '.', 'Br', 'c', '1', 'c', 'c', 'c', 'c', 'c', '1', 'I', '>', 'Br', 'c', '1', 'c', 'c', 'c', 'c', 'c', '1', 'N', '1', 'c', '2', 'c', 'c', 'c', 'c', 'c', '2', 'O', 'c', '2', 'c', 'c', 'c', 'c', 'c', '2', '1']
    ...
]

接着看_pad_seqs和_pad_token_to_idx方法:

class Smiles_tokenizer():
    ...
    def _pad_seqs(self, seqs, pad_token):
        pad_length = max([len(seq) for seq in seqs])
        padded = [seq + ([pad_token] * (pad_length - len(seq))) for seq in seqs]
        return padded

    def _pad_token_to_idx(self, tokens):
        idx_list = []
        for token in tokens:
            tokens_idx = []
            for i in token:
                if i in self.vocab_dic.keys():
                    tokens_idx.append(self.vocab_dic[i])
                else:
                    self.vocab_dic[i] = max(self.vocab_dic.values()) + 1
                    tokens_idx.append(self.vocab_dic[i])
            idx_list.append(tokens_idx)
        
        return idx_list

这两个方法,第一个是把所有反应方程,通过填充<PAD>(也就是0),变成相同的长度,也就是把每个方程对应的tokens变成拥有相同的元素数量,便于后续的处理。

第二个方法是根据前面的vocab_dic,把tokenised中的所有特征元素变成整数:

# 下面是tokenised在padded后的形式,其中展示的列表是填充后的第四个反应
padded:[
    ...
  ['C', '1', 'C', 'O', 'C', 'C', 'N', '1', '.', 'F', 'c', '1', 'c', 'n', 'c', '(', 'Cl', ')', 'n', 'c', '1', 'Cl', '>', 'F', 'c', '1', 'c', 'n', 'c', '(', 'Cl', ')', 'n', 'c', '1', 'N', '1', 'C', 'C', 'O', 'C', 'C', '1', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>']
    ...
]
# 下面是tokenised在padded后对应的idx_list,其中展示的列表是填充后的第四个反应
idx_list:[
    ...
    [6, 8, 6, 19, 6, 6, 13, 8, 35, 39, 7, 8, 7, 16, 7, 9, 40, 15, 16, 7, 8, 40, 5, 39, 7, 8, 7, 16, 7, 9, 40, 15, 16, 7, 8, 13, 8, 6, 6, 19, 6, 6, 8, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
    ...
]

最后一下 tokenize方法:

class Smiles_tokenizer():
    ...    
    def tokenize(self, smiles):
        tokens = self._regex_match(smiles)
        # 添加上表示开始和结束的token:<cls>, <end>
        tokens = [["<CLS>"] + token + ["<SEP>"] for token in tokens]
        tokens = self._pad_seqs(tokens, self.pad_token)
        token_idx = self._pad_token_to_idx(tokens)
        return tokens, token_idx

这个方法是把前面的几个方法汇总到一起,也就是最后返回的tokens就是上面padded的形式(前后添加了"<CLS>"和"<SEP>"),token_idx就是idx_list的形式:

tokens:[
    ...
  ['<CLS>', 'C', '1', 'C', 'O', 'C', 'C', 'N', '1', '.', 'F', 'c', '1', 'c', 'n', 'c', '(', 'Cl', ')', 'n', 'c', '1', 'Cl', '>', 'F', 'c', '1', 'c', 'n', 'c', '(', 'Cl', ')', 'n', 'c', '1', 'N', '1', 'C', 'C', 'O', 'C', 'C', '1', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<SEP>']
    ...
]

token_idx:[
    ...
    [1, 6, 8, 6, 19, 6, 6, 13, 8, 35, 39, 7, 8, 7, 16, 7, 9, 40, 15, 16, 7, 8, 40, 5, 39, 7, 8, 7, 16, 7, 9, 40, 15, 16, 7, 8, 13, 8, 6, 6, 19, 6, 6, 8, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
    ...
]

至此,就把文本型的数据转化为了由整数构成的数据了,而collate_fn函数就是通过Smiles_tokenizer类获得的token_idx转化成张量。

def collate_fn(batch):
    REGEX = r"\[[^\]]+]|Br?|Cl?|N|O|S|P|F|I|b|c|n|o|s|p|\(|\)|\.|=|#|-|\+|\\\\|\/|:|~|@|\?|>|\*|\$|\%[0-9]{2}|[0-9]"
    tokenizer = Smiles_tokenizer("<PAD>", REGEX, "../vocab_full.txt", max_length=300)
    smi_list = []
    yield_list = []
    for i in batch:
        smi_list.append(i[0])
        yield_list.append(i[1])
    tokenizer_batch = torch.tensor(tokenizer.tokenize(smi_list)[1])
    yield_list = torch.tensor(yield_list)
    return tokenizer_batch, yield_list
tokenizer_batch:
tensor([[ 1,  7,  8,  7,  7,  7, 11,  7,  9,  7,  8, 15, 13,  7,  8,  7,  7,  7,
          7,  7,  8, 19, 11, 35, 17,  7,  8,  7,  7,  7,  7,  7,  8, 38,  5, 17,
          7,  8,  7,  7,  7,  7,  7,  8, 13,  8,  7, 11,  7,  7,  7,  7,  7, 11,
         19,  7, 11,  7,  7,  7,  7,  7, 11,  8,  3],
        [ 1,  7,  8,  7,  7,  7, 11,  7,  9,  7,  8, 15, 13,  7,  8,  7,  7,  7,
          7,  7,  8, 19, 11, 35, 17,  7,  8,  7,  7,  7,  7,  7,  8, 38,  5, 17,
          7,  8,  7,  7,  7,  7,  7,  8, 13,  8,  7, 11,  7,  7,  7,  7,  7, 11,
         19,  7, 11,  7,  7,  7,  7,  7, 11,  8,  3],
        ...
       ])
yield_list:
tensor([0.7800, 0.9000, ...])

4. 训练数据获得模型

def train():
    ## super param
    N = 10  #int / int(len(dataset) * 1)  # 或者你可以设置为数据集大小的一定比例,如 int(len(dataset) * 0.1)
    NUM_EMBED = 294 # nn.Embedding()
    INPUT_SIZE = 300 # src length
    HIDDEN_SIZE = 512
    OUTPUT_SIZE = 512
    NUM_LAYERS = 10
    DROPOUT = 0.2
    CLIP = 1 # CLIP value
    N_EPOCHS = 100
    LR = 0.0001
    
    start_time = time.time()  # 开始计时
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # device = 'cpu'
    data = read_data("../dataset/round1_train_data.csv")
    dataset = ReactionDataset(data)
    subset_indices = list(range(N))
    subset_dataset = Subset(dataset, subset_indices)
    train_loader = DataLoader(dataset, batch_size=128, shuffle=True, collate_fn=collate_fn)

    model = RNNModel(NUM_EMBED, INPUT_SIZE, HIDDEN_SIZE, OUTPUT_SIZE, NUM_LAYERS, DROPOUT, device).to(device)
    model.train()
    
    optimizer = optim.Adam(model.parameters(), lr=LR)
    # criterion = nn.MSELoss() # MSE
    criterion = nn.L1Loss() # MAE

    best_loss = 10
    for epoch in range(N_EPOCHS):
        epoch_loss = 0
        for i, (src, y) in enumerate(train_loader):
            src, y = src.to(device), y.to(device)
            optimizer.zero_grad()
            output = model(src)
            loss = criterion(output, y)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), CLIP)  # 使用范数裁剪梯度
            optimizer.step()
            epoch_loss += loss.item()
            loss_in_a_epoch = epoch_loss / len(train_loader)
        print(f'Epoch: {epoch+1:02} | Train Loss: {loss_in_a_epoch:.3f}')
        if loss_in_a_epoch < best_loss:
            # 在训练循环结束后保存模型
            torch.save(model.state_dict(), '../model/RNN.pth')
    end_time = time.time()  # 结束计时
    # 计算并打印运行时间
    elapsed_time_minute = (end_time - start_time)/60
    print(f"Total running time: {elapsed_time_minute:.2f} minutes")

if __name__ == '__main__':
    train()

前面几个大写的字母是超参数。

通过DataLoader划分数据集,一共23538/128=184个batch。

训练模型包括以下几步:

  1. 实例化RNNModel并以train模式运行;
  2. 选择优化器为 Adam
  3. 使用平均绝对误差 (MAE) 作为损失函数
  4. 循环训练100次:
    • 遍历每个批次的数据
    • 使用GPU设备进行运算
    • 清除梯度
    • 前向传播计算输出
    • 计算损失
    • 反向传播计算梯度
    • 应用梯度裁剪
    • 更新模型参数
    • 累加每个批次的损失
    • 计算整个周期的平均损失

5. 使用测试集验证模型

# 生成结果文件
def predicit_and_make_submit_file(model_file, output_file):
    NUM_EMBED = 294
    INPUT_SIZE = 300
    HIDDEN_SIZE = 512
    OUTPUT_SIZE = 512
    NUM_LAYERS = 10
    DROPOUT = 0.2
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    test_data = read_data("../dataset/round1_test_data.csv", train=False)
    test_dataset = ReactionDataset(test_data)
    test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False, collate_fn=collate_fn) 

    model = RNNModel(NUM_EMBED, INPUT_SIZE, HIDDEN_SIZE, OUTPUT_SIZE, NUM_LAYERS, DROPOUT, device).to(device)
    # 加载最佳模型
    model.load_state_dict(torch.load(model_file))
    model.eval()
    output_list = []
    for i, (src, y) in enumerate(test_loader):
        src, y = src.to(device), y.to(device)
        with torch.no_grad():
            output = model(src)
            output_list += output.detach().tolist()
    ans_str_lst = ['rxnid,Yield']
    for idx,y in enumerate(output_list):
        ans_str_lst.append(f'test{idx+1},{y:.4f}')
    with open(output_file,'w') as fw:
        fw.writelines('\n'.join(ans_str_lst))

    print("done!!!")
    
predicit_and_make_submit_file("../model/RNN.pth",
                              "../output/RNN_submit.txt")

上一步已经训练好了模型,这一步使用这个模型对测试集进行预测。

各超参数还是那些参数,数据集换成测试集。

实例化RNNModel,加载上一步保存的训练好的模型,然后设置模型为评估模式。

遍历测试DataLoader中的每个批次的数据,获取输入数据和标签。需要使用torch.no_grad()上下文管理器来禁用梯度计算,提高预测速度。

最后将预测结果写入txt文件。

©著作权归作者所有,转载或内容合作请联系作者
平台声明:文章内容(如有图片或视频亦包括在内)由作者上传并发布,文章内容仅代表作者本人观点,简书系信息发布平台,仅提供信息存储服务。

推荐阅读更多精彩内容