PyTorch Tensor 介绍和基本操作

PyTorch 是深度学习领域广泛使用的框架,而 Tensor 是其核心数据结构。理解 Tensor,不仅是构建神经网络的第一步,也是搞懂维度转换、索引切片、模型输入输出的关键所在。

本文将从最基础的 Tensor 结构、创建方法、维度理解入手,结合你整理的笔记内容和实战代码,帮助你从零掌握 PyTorch 中的 Tensor 操作。


一、什么是 Tensor?

Tensor 可以理解为一个通用的多维数组结构,在 PyTorch 中,它支持:

  • 多维数据存储(标量、向量、矩阵、张量)
  • GPU 运算加速
  • 自动求导功能

它与 Python 的 list 或 NumPy 中的 ndarray 有一定相似性,但功能更丰富。


二、Tensor 的形状与维度

通过 .shape 可以获取 Tensor 的形状:

import torch

a = torch.arange(1, 10)
b = a.reshape(3, 3)
c = torch.arange(1, 13).reshape(2, 3, 2)

print(a.shape)  # torch.Size([9])
print(b.shape)  # torch.Size([3, 3])
print(c.shape)  # torch.Size([2, 3, 2])

换一种直观理解:

c 本质上可以想象成一个嵌套的 list:

  • 2 是最外层:有两个“块”
  • 3 是每个块中包含 3 个列表(行)
  • 2 是每行里包含两个值

这就是形状 [2, 3, 2] 背后的含义。


三、如何创建 Tensor

3.1 直接创建

torch.tensor([1, 2, 3], dtype=torch.float32, requires_grad=False)

参数说明

  • data: 原始数据(list、array)
  • dtype: 数据类型,如 torch.float32
  • requires_grad: 是否启用自动求导

3.2 随机生成

torch.rand(2, 3, 4)   # 均匀分布 [0,1)
torch.randn(5, 4)     # 该张量填充了来自均值 '0' 和方差 '1' 的正态分布的随机数(也称为标准正态分布)
torch.randint(0, 10, (2, 3))  # 在0到10中随机获取数值,组成一个 2*3 的 tensor

3.3 常用构造函数

函数 功能
torch.zeros(2, 3) 创建一个大小为 2*3,里面的值为 0 的 Tensor
torch.ones(2, 3) 创建一个大小为 2*3,里面的值为 1 的 Tensor
torch.eye(n, m) 创建一个 n*m 的矩阵,对角线是 1,其他是 0
torch.arange(1, 10, 2) 创建一个1维 tensor 和 python range 一样
torch.linspace(0, 1, 5) 等间距采样
linspace(start, end ,steps) 在 start 到 end 内均匀的取出 steps 个值

四、如何理解高维 Tensor?(0D 到 5D)

维度 类型 举例 含义
0D 标量 5 单个值
1D 向量 [1, 2, 3] 一维数组
2D 矩阵 [[1, 2], [3, 4]] 行列形式
3D 立方体 [[[1,2],[3,4]],[[5,6],[7,8]]] 嵌套结构
4D 批量图像 torch.tensor(10, 5, 28, 28) 视频帧、图像批次
5D 批量序列 torch.tensor(32, 10, 5, 28, 28) LSTM/3D卷积输入

五、Tensor 的基础运算

a = torch.tensor([1, 2])
b = torch.tensor([3, 4])

a + b  # tensor([4, 6])
a - b  # tensor([-2, -2])
a * b  # tensor([3, 8])
a / b  # tensor([0.3333, 0.5000])

运算默认逐元素操作(element-wise),支持广播机制。


六、Tensor 切片、索引与条件筛选

Tensor 的切片操作和 Python list / NumPy 类似,可以对全数据或指定维度扫描。

基础切片

x = torch.arange(1, 10).reshape(3, 3)
# tensor([[1, 2, 3],
#         [4, 5, 6],
#         [7, 8, 9]])

x[1]         # 第 1 行 → tensor([4, 5, 6])
x[0:2]       # 第 0 到 1 行
x[:, 1]      # 所有行的第 1 列 → tensor([2, 5, 8])
x[0:2, 1:3]  # 左上角 2x2 子矩阵
x[n:m, x:y]  # 获取第 n 到第 m-1 行,在获取每行的第 x 列到第 y-1 列的值

条件筛选

x[(x > 2) & (x < 5)]   # 选择大于 2 小于 5 的元素
x[x > 3]               # 选择大于 3 的元素

步长切片和对角线

x[::2, ::2]  # 每 2 行 2 列选择
x.diag()     # 主对角线

七、Tensor 变换:reshape、unsqueeze、squeeze

reshape :改变形状

x = torch.arange(12)
x.reshape(3, 4)

unsqueeze :增加维度

x.unsqueeze(0).shape  # [1, 12]
x.unsqueeze(1).shape  # [12, 1]

squeeze :删除尺寸为 1 的维度

a = torch.rand(1, 12, 1)
a.squeeze(0).shape  # [12, 1]
a.squeeze(2).shape  # [1, 12]

view :与 reshape 相似,tensor 中元素数量和值不变的前提下修改 shape

x.view(2, 6)

八、Tensor 合并与拆分

torch.cat() 合并

a = torch.tensor([[1, 2, 3], [4, 5, 6]])
b = torch.tensor([[7, 8, 9], [10, 11, 12]])

# 按行合并
c = torch.cat((a, b), dim=0)
# 按列合并
c = torch.cat((a, b), dim=1)

torch.chunk() 拆分

拆解规则

  • 尽量平均:尽量把张量划分为相等大小的块。
  • 不能整除的时候,前面的块会稍微大一点。向上取整,反向抹零
# 将 Tensor 按给定维度分成 chunks 个块
chunks = torch.chunk(c, chunks=3, dim=1)

九、基本的操作工具

torch.flip() 置反

a.flip(0)  # 按行置反
a.flip(1)  # 按列置反

torch.gather() 按 index 查找

input = torch.tensor([[1, 2], [3, 4]])
index = torch.tensor([[0, 1], [1, 0]])
result = torch.gather(input, dim=1, index=index)

tensor.max(dim)

values, indices = input.max(dim=1)  # 每行最大值
©著作权归作者所有,转载或内容合作请联系作者
平台声明:文章内容(如有图片或视频亦包括在内)由作者上传并发布,文章内容仅代表作者本人观点,简书系信息发布平台,仅提供信息存储服务。

推荐阅读更多精彩内容