torch.split(tensor, split_size, dim=)
tensor是要切割的张量,dim表示在哪个维度上面进行切割
注意:split_size是切分后每块的大小,不是切分为多少块!
a = torch.LongTensor([[1,2,3,4],[2,3,4,5]])
b = torch.cat(torch.split(a, 4, dim=1), dim=0)
print(b)
输出:tensor([[1, 2, 3, 4],
[2, 3, 4, 5]])
torch.split(tensor, split_size, dim=)
tensor是要切割的张量,dim表示在哪个维度上面进行切割
a = torch.LongTensor([[1,2,3,4],[2,3,4,5]])
b = torch.cat(torch.split(a, 4, dim=1), dim=0)
print(b)
输出:tensor([[1, 2, 3, 4],
[2, 3, 4, 5]])