改变张量的形状

PyTorch 中改变张量形状有 view、reshape 和 resize_ (没有原地操作的resize方法未来会被丢弃) 三种方式,其中 resize_ 比较特殊,它能够在修改张量形状的同时改变张量的大小,而 view 和 reshape 方法不能改变张量的大小,只能够重新调整张量形状。

resize_ 方法比较特殊,后续用到的时候再详细介绍。本文主要介绍 view 和 reshape 方法,在 PyTorch 中 view 方法存在很长时间,reshape 方法是在 PyTorch0.4 的版本中引入,两种方法功能上相似,但是一些细节上稍有不同,因此这里介绍两个方法的不同之处。

  • view 只能用于数据连续存储的张量,而 reshape 则不需要考虑张量中的数据是否连续存储

nD 张量底层实现是使用一块连续内存的一维数组,由于 PyTorch 底层实现是 C 语言 (C/C++ 使用行优先方式),所以n维张量也使用行优先方式。比如对于下面形状为 (3 x 3) 的 2D 张量:

2D 张量在内存中实际以一维数组的形式进行存储,行优先的方式指的是存储的顺序按照 2D 张量的行依次存储。

上面形状为 (3 x 3) 的 2D 张量通常称为存储的逻辑结构,而实际存储的一维数组形式称为存储的物理结构。

  1. 如果元素在存储的逻辑结构上相邻,在存储的物理结构中也相邻,则称为连续存储的张量;

  2. 如果元素在存储的逻辑结构上相邻,但是在存储的物理结构中不相邻,则称为不连续存储的张量;

为了方便理解代码,先来熟悉一些方法。

  • 可以通过 tensor.is_contiguous() 来查看 tensor 是否为连续存储的张量;

  • PyTorch 中的转置操作能够将连续存储的张量变成不连续存储的张量;

>>> import torch
>>> a = torch.tensor([[0, 1, 2],
                      [3, 4, 5],
                      [6, 7, 8]])
>>> # 判断a是否为连续存储的张量
>>> print(a.is_contiguous())

True

>>> view_a = a.view(1, 9)
>>> reshape_a = a.reshape(9, 1)
>>> # 通过转置操作将a变成不连续存储的张量
>>> a.t_()
>>> # 判断a是否为连续存储的张量
>>> print(a.is_contiguous())

False

>>> # view_t_a = a.view(1, 9) error
>>> reshape_t_a = a.reshape(1, 9)

其中 view_t_a = a.view(1, 9) 会抛出异常,再次验证了 view 只能用于数据连续存储的张量,而 reshape 则不需要考虑张量中的数据是否连续存储。

  • view 方法会返回原始张量的视图,而 reshape 方法可能返回的是原始张量的视图或者拷贝

原始张量的视图简单来说就是和原始张量共享数据,因此如果改变使用 view 方法返回的新张量,原始张量也会发生想对应的改变。

>>> import torch
>>> a = torch.tensor([[0, 1, 2],
                  [3, 4, 5],
                  [6, 7, 8]])
>>> view_a = a.view(1, 9)
>>> print(view_a)

tensor([[0, 1, 2, 3, 4, 5, 6, 7, 8]])

>>> # 更改张量中的元素值
>>> view_a[:, 1] = 100
>>> print(a)

tensor([[  0, 100,   2],
        [  3,   4,   5],
        [  6,   7,   8]])

>>> print(view_a)

tensor([[  0, 100,   2,   3,   4,   5,   6,   7,   8]])

reshape 方法可能返回的是原始张量的视图或者拷贝,当处理连续存储的张量 reshape 返回的是原始张量的视图,而当处理不连续存储的张量 reshape 返回的是原始张量的拷贝。

>>> import torch
>>> a = torch.tensor([[0, 1, 2],
                      [3, 4, 5],
                      [6, 7, 8]])
>>> # 判断a是否为连续存储的张量
>>> print(a.is_contiguous())

True

>>> reshape_a = a.reshape(1, 9)
>>> # 更改新张量的元素值
>>> reshape_a[:, 1] = 100
>>> print(a)

tensor([[  0, 100,   2],
        [  3,   4,   5],
        [  6,   7,   8]])

>>> print(reshape_a)

tensor([[  0, 100,   2,   3,   4,   5,   6,   7,   8]])
>>> import torch
>>> a = torch.tensor([[0, 1, 2],
                      [3, 4, 5],
                      [6, 7, 8]])
>>> # 通过转置将a变成不连续存储的张量
>>> a.t_()
>>> # 判断a是否为连续存储的张量
>>> print(a.is_contiguous())

False

>>> reshape_a = a.reshape(1, 9)
>>> # 更改新张量的元素值
>>> reshape_a[:, 1] = 100
>>> print(a)

tensor([[0, 3, 6],
        [1, 4, 7],
        [2, 5, 8]])

>>> print(reshape_a)

tensor([[  0, 100,   6,   1,   4,   7,   2,   5,   8]])

原文地址:
PyTorch入门笔记-改变张量的形状

最后编辑于
©著作权归作者所有,转载或内容合作请联系作者
平台声明:文章内容(如有图片或视频亦包括在内)由作者上传并发布,文章内容仅代表作者本人观点,简书系信息发布平台,仅提供信息存储服务。

推荐阅读更多精彩内容