PyTorch 是深度学习领域广泛使用的框架,而 Tensor 是其核心数据结构。理解 Tensor,不仅是构建神经网络的第一步,也是搞懂维度转换、索引切片、模型输入输出的关键所在。
本文将从最基础的 Tensor 结构、创建方法、维度理解入手,结合你整理的笔记内容和实战代码,帮助你从零掌握 PyTorch 中的 Tensor 操作。
一、什么是 Tensor?
Tensor 可以理解为一个通用的多维数组结构,在 PyTorch 中,它支持:
- 多维数据存储(标量、向量、矩阵、张量)
- GPU 运算加速
- 自动求导功能
它与 Python 的 list 或 NumPy 中的 ndarray 有一定相似性,但功能更丰富。
二、Tensor 的形状与维度
通过 .shape
可以获取 Tensor 的形状:
import torch
a = torch.arange(1, 10)
b = a.reshape(3, 3)
c = torch.arange(1, 13).reshape(2, 3, 2)
print(a.shape) # torch.Size([9])
print(b.shape) # torch.Size([3, 3])
print(c.shape) # torch.Size([2, 3, 2])
换一种直观理解:
c
本质上可以想象成一个嵌套的 list:
-
2
是最外层:有两个“块” -
3
是每个块中包含 3 个列表(行) -
2
是每行里包含两个值
这就是形状 [2, 3, 2]
背后的含义。
三、如何创建 Tensor
3.1 直接创建
torch.tensor([1, 2, 3], dtype=torch.float32, requires_grad=False)
参数说明:
-
data
: 原始数据(list、array) -
dtype
: 数据类型,如torch.float32
-
requires_grad
: 是否启用自动求导
3.2 随机生成
torch.rand(2, 3, 4) # 均匀分布 [0,1)
torch.randn(5, 4) # 该张量填充了来自均值 '0' 和方差 '1' 的正态分布的随机数(也称为标准正态分布)
torch.randint(0, 10, (2, 3)) # 在0到10中随机获取数值,组成一个 2*3 的 tensor
3.3 常用构造函数
函数 | 功能 |
---|---|
torch.zeros(2, 3) |
创建一个大小为 2*3,里面的值为 0 的 Tensor |
torch.ones(2, 3) |
创建一个大小为 2*3,里面的值为 1 的 Tensor |
torch.eye(n, m) |
创建一个 n*m 的矩阵,对角线是 1,其他是 0 |
torch.arange(1, 10, 2) |
创建一个1维 tensor 和 python range 一样 |
torch.linspace(0, 1, 5) |
等间距采样 |
linspace(start, end ,steps) |
在 start 到 end 内均匀的取出 steps 个值 |
四、如何理解高维 Tensor?(0D 到 5D)
维度 | 类型 | 举例 | 含义 |
---|---|---|---|
0D | 标量 | 5 |
单个值 |
1D | 向量 | [1, 2, 3] |
一维数组 |
2D | 矩阵 | [[1, 2], [3, 4]] |
行列形式 |
3D | 立方体 | [[[1,2],[3,4]],[[5,6],[7,8]]] |
嵌套结构 |
4D | 批量图像 | torch.tensor(10, 5, 28, 28) |
视频帧、图像批次 |
5D | 批量序列 | torch.tensor(32, 10, 5, 28, 28) |
LSTM/3D卷积输入 |
五、Tensor 的基础运算
a = torch.tensor([1, 2])
b = torch.tensor([3, 4])
a + b # tensor([4, 6])
a - b # tensor([-2, -2])
a * b # tensor([3, 8])
a / b # tensor([0.3333, 0.5000])
运算默认逐元素操作(element-wise),支持广播机制。
六、Tensor 切片、索引与条件筛选
Tensor 的切片操作和 Python list / NumPy 类似,可以对全数据或指定维度扫描。
基础切片
x = torch.arange(1, 10).reshape(3, 3)
# tensor([[1, 2, 3],
# [4, 5, 6],
# [7, 8, 9]])
x[1] # 第 1 行 → tensor([4, 5, 6])
x[0:2] # 第 0 到 1 行
x[:, 1] # 所有行的第 1 列 → tensor([2, 5, 8])
x[0:2, 1:3] # 左上角 2x2 子矩阵
x[n:m, x:y] # 获取第 n 到第 m-1 行,在获取每行的第 x 列到第 y-1 列的值
条件筛选
x[(x > 2) & (x < 5)] # 选择大于 2 小于 5 的元素
x[x > 3] # 选择大于 3 的元素
步长切片和对角线
x[::2, ::2] # 每 2 行 2 列选择
x.diag() # 主对角线
七、Tensor 变换:reshape、unsqueeze、squeeze
reshape :改变形状
x = torch.arange(12)
x.reshape(3, 4)
unsqueeze :增加维度
x.unsqueeze(0).shape # [1, 12]
x.unsqueeze(1).shape # [12, 1]
squeeze :删除尺寸为 1 的维度
a = torch.rand(1, 12, 1)
a.squeeze(0).shape # [12, 1]
a.squeeze(2).shape # [1, 12]
view :与 reshape 相似,tensor 中元素数量和值不变的前提下修改 shape
x.view(2, 6)
八、Tensor 合并与拆分
torch.cat() 合并
a = torch.tensor([[1, 2, 3], [4, 5, 6]])
b = torch.tensor([[7, 8, 9], [10, 11, 12]])
# 按行合并
c = torch.cat((a, b), dim=0)
# 按列合并
c = torch.cat((a, b), dim=1)
torch.chunk() 拆分
拆解规则
- 尽量平均:尽量把张量划分为相等大小的块。
- 不能整除的时候,前面的块会稍微大一点。向上取整,反向抹零
# 将 Tensor 按给定维度分成 chunks 个块
chunks = torch.chunk(c, chunks=3, dim=1)
九、基本的操作工具
torch.flip() 置反
a.flip(0) # 按行置反
a.flip(1) # 按列置反
torch.gather() 按 index 查找
input = torch.tensor([[1, 2], [3, 4]])
index = torch.tensor([[0, 1], [1, 0]])
result = torch.gather(input, dim=1, index=index)
tensor.max(dim)
values, indices = input.max(dim=1) # 每行最大值