In PyTorch, both view() and reshape() return a tensor with a new shape:
import torch
x = torch.arange(10)
x_2x5 = x.view(2, 5)
print(x_2x5)
x_5x2 = x.reshape(5, 2)
print(x_5x2)
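A minimal sketch, assuming a recent PyTorch version, that checks whether the two results above share memory with x. On a contiguous tensor both calls return views, so no data is copied. The names same_storage_view and same_storage_reshape are illustrative only.

```python
import torch

x = torch.arange(10)
x_2x5 = x.view(2, 5)
x_5x2 = x.reshape(5, 2)

# On a contiguous tensor both calls return views: the first element of each
# result sits at the same address as the first element of x.
same_storage_view = x_2x5.data_ptr() == x.data_ptr()
same_storage_reshape = x_5x2.data_ptr() == x.data_ptr()
print(same_storage_view, same_storage_reshape)  # True True

# Because memory is shared, writing through one view is visible everywhere.
x_2x5[0, 0] = 100
print(x[0].item())        # 100
print(x_5x2[0, 0].item())  # 100
```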
The differences are:
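- view() requires the tensor to be stored in contiguous memory (strictly speaking, the new shape must be compatible with the tensor's existing strides); on noncontiguous memory it raises an error.
- reshape() has no such requirement: it returns a view when the shape allows it and otherwise copies the data. On contiguous memory it returns a view without copying, so any speed difference from view() is negligible. Because it may or may not copy, code should not rely on whether the result shares memory with the input.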
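The contiguity requirement can be inspected directly. This sketch, assuming a recent PyTorch version, uses is_contiguous(), stride() and data_ptr() to show that a transpose only swaps strides, and that reshape() returns a view on the contiguous tensor but copies on the noncontiguous one. The names r1 and r2 are illustrative.

```python
import torch

x = torch.arange(10)
x_2x5 = x.view(2, 5)
y = x_2x5.t()  # transpose: swaps the strides, moves no data

print(x_2x5.is_contiguous(), x_2x5.stride())  # True (5, 1)
print(y.is_contiguous(), y.stride())          # False (1, 5)

# On the contiguous tensor, reshape() returns a view (same storage as x).
r1 = x_2x5.reshape(10)
print(r1.data_ptr() == x.data_ptr())  # True

# On the noncontiguous y no view is possible, so reshape() copies.
r2 = y.reshape(10)
print(r2.data_ptr() == x.data_ptr())  # False
```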
import torch
x = torch.arange(10)
# contiguous memory
x_2x5 = x.view(2, 5)
print(x_2x5)
x_5x2 = x.reshape(5, 2)
print(x_5x2)
# noncontiguous memory
y = x_2x5.t()
y_1x10 = y.view(10)
Error message:
y_1x10 = y.view(10)
RuntimeError: view size is not compatible with input tensor's size and stride (at least one dimension spans across two contiguous subspaces). Use .reshape(...) instead.
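Why view(10) fails can be seen with a simplified model of strided memory: element (i, j) lives at offset i * stride(0) + j * stride(1) (this sketch assumes the storage offset is 0, which holds here). A 1-D view needs one constant step between consecutive offsets, and y has two different steps. The names offsets and steps are illustrative.

```python
import torch

x = torch.arange(10)
y = x.view(2, 5).t()  # shape (5, 2), strides (1, 5)

# Memory offset of each element of y, visited in row-major (logical) order.
offsets = [i * y.stride(0) + j * y.stride(1)
           for i in range(y.size(0)) for j in range(y.size(1))]
print(offsets)  # [0, 5, 1, 6, 2, 7, 3, 8, 4, 9]

# A 1-D view needs a single constant step between consecutive offsets.
steps = sorted({b - a for a, b in zip(offsets, offsets[1:])})
print(steps)  # [-4, 5]: more than one step, so view(10) cannot exist
```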
Fix: call Tensor.contiguous(memory_format=torch.contiguous_format) → Tensor to copy the noncontiguous tensor into contiguous memory, then call view(). As the error message suggests, calling reshape() directly also works.
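A short sketch, assuming a recent PyTorch version, of what contiguous() does in each case: on a tensor that is already contiguous it returns the same data without copying, and on a noncontiguous tensor it copies the elements into a new row-major buffer. The name y_c is illustrative.

```python
import torch

x = torch.arange(10)
x_2x5 = x.view(2, 5)
y = x_2x5.t()

# Already contiguous: contiguous() hands back the same data, no copy.
print(x_2x5.contiguous().data_ptr() == x_2x5.data_ptr())  # True

# Noncontiguous: contiguous() copies into a fresh row-major buffer.
y_c = y.contiguous()
print(y_c.is_contiguous(), y_c.stride())  # True (2, 1)
print(y_c.data_ptr() == y.data_ptr())     # False
```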
import torch
x = torch.arange(10)
# contiguous memory
x_2x5 = x.view(2, 5)
print(x_2x5)
x_5x2 = x.reshape(5, 2)
print(x_5x2)
# noncontiguous memory
y = x_2x5.t()
y_1x10 = y.contiguous().view(10)
print(y_1x10.shape)
Output of the final print statement:
torch.Size([10])
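The two fixes give the same result. This sketch, assuming a recent PyTorch version, compares the explicit contiguous().view() route with reshape(), which performs the copy implicitly when a view is impossible. The names a and b are illustrative.

```python
import torch

x = torch.arange(10)
y = x.view(2, 5).t()

a = y.contiguous().view(10)  # explicit copy, then view
b = y.reshape(10)            # reshape copies implicitly when it must

print(torch.equal(a, b))  # True
print(a.tolist())         # [0, 5, 1, 6, 2, 7, 3, 8, 4, 9]
```

Choosing contiguous().view() makes the copy visible in the code, while reshape() is shorter and only copies when necessary.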