pack_sequence 用于把多个变长 tensor 打包成一个 PackedSequence。
pack 操作的输入 sequences 是由 tensor 组成的 list,默认要求按长度从大到小排序(若不想预先排序,可传 enforce_sorted=False)。
# Example: pack_sequence packs a list of variable-length tensors into a
# PackedSequence (data interleaved time-major, plus per-step batch sizes).
import torch
import torch.nn.utils.rnn
from torch.nn.utils.rnn import pack_sequence

a = torch.tensor([1, 2, 3])
b = torch.tensor([4, 5])
c = torch.tensor([6])
# Inputs must be sorted longest-first (default enforce_sorted=True).
print(pack_sequence([a, b, c]))
# -> PackedSequence(data=tensor([1, 4, 6, 2, 5, 3]), batch_sizes=tensor([3, 2, 1]))
pad操作,sequences也是list。给list里的tensor都用padding_value来pad成最长的长度,并组合成一个tensor
torch.nn.utils.rnn.pad_sequence(sequences, batch_first=False, padding_value=0)
# Example: pad_sequence pads every tensor in the list with padding_value
# up to the longest length and stacks them (output is seq-first by default:
# (max_len, batch, *)).
from torch.nn.utils.rnn import pad_sequence

a = torch.ones(25, 300)
b = torch.ones(22, 300)
c = torch.ones(15, 300)
print(pad_sequence([a, b, c]).shape)
# -> torch.Size([25, 3, 300])
实际使用较多的是下面两个函数,因为数据常常在预处理阶段就已经 pad 过了
torch.nn.utils.rnn.pack_padded_sequence(input, lengths, batch_first=False, enforce_sorted=True)
torch.nn.utils.rnn.pad_packed_sequence(sequence, batch_first=False, padding_value=0.0, total_length=None)
# Round-trip example: sort a padded batch by length, run it through an LSTM
# via pack_padded_sequence, then restore the original row order.
import torch as t
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

a = t.tensor([[1, 2, 3], [6, 0, 0], [4, 5, 0]])  # (batch_size, max_length)
lengths = t.tensor([3, 1, 2])

# Sort rows longest-first and remember how to undo the permutation:
# un_idx maps sorted positions back to the original order.
sorted_lens, order = lengths.sort(0, descending=True)
_, un_idx = t.sort(order, dim=0)
a = a[order]

# Layers: embedding (index 0 is padding) and a batch-first LSTM.
emb = t.nn.Embedding(20, 2, padding_idx=0)
lstm = t.nn.LSTM(input_size=2, hidden_size=4, batch_first=True)

embedded = emb(a)
packed = t.nn.utils.rnn.pack_padded_sequence(input=embedded, lengths=sorted_lens, batch_first=True)
lstm_out, _ = lstm(packed)
out, _ = pad_packed_sequence(lstm_out, batch_first=True)

# Undo the length sort so output rows line up with the original input order.
out = out[un_idx]