RNN使用mini-batch的每一个数据是需要等长的,因此需要对输入数据进行 处理
输入格式【batch_size, seq_len, num_hiddens】
- 对于所有输入数据短的进行填充,长的进行截断,使所有数据都一样长。但填充的时候如果在最后补上缺失字符其实是相当于噪声的。如果只取最后一个状态来预测,效果极差。可选择在前面补上缺失字符,效果会好一点。
if pad_size:
if len(token) < pad_size:
token=[PAD] * (pad_size - len(token))+token
else:
token = token[:pad_size]
seq_len = pad_size
- 使用pad_sequence方法 将每个mini-batch中的数据变为等长,这样每个mini-batch不等长,而mini-batch中的数据是等长的
首先定义collate_fn 方法
def collate_fn(data):
"""
定义 dataloader 的返回值
:param data:
:return:
"""
data.sort(key=lambda x: len(x[0]), reverse=True)
data_length = [len(sq[0]) for sq in data]
x = [i[0] for i in data]
y = [i[1] for i in data]
data = rnn_utils.pad_sequence(x, batch_first=True, padding_value=0) # 这里最好也设置在最前面填充
return data.unsqueeze(-1), data_length, torch.tensor(y, dtype=torch.float32).view(-1, 1)
然后 在Dataloader中设置collate_fn
train_dataloader = DataLoader(train_dataset, batch_size=64, shuffle=True,collate_fn=collate_fn)