创建张量
x=torch.empty(5,3,dtype=torch.long) #未初始化空张量 5*3维
x=torch.rand(5,3) #随机初始化 [0,1]区间均匀分布 5*3维
x=torch.randn(5,3) #随机初始化 正态分布(均值0,方差1)5*3维
x=torch.zeros(5,3,dtype=torch.long) #以0初始化 5*3维
x=torch.ones(5,3) #以1初始化 5*3维
x=torch.tensor([5.5,3]) #用矩阵创建张量 1*2维
print(x)
x=x.new_ones(5,3,dtype=torch.double) #基于现用张量创建新的张量,保留原来的性质如dtype
print(x)
x=x.new_zeros(5,3,dtype=torch.double)
print(x)
x=torch.randn_like(x,dtype=torch.float) #覆盖类型,与x具有相同size
print(x)
print(x.size())
output:
tensor([5.5000, 3.0000])
tensor([[1., 1., 1.],
[1., 1., 1.],
[1., 1., 1.],
[1., 1., 1.],
[1., 1., 1.]], dtype=torch.float64)
tensor([[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.]], dtype=torch.float64)
tensor([[ 0.2667, -0.6650, 0.7252],
[ 2.2209, -0.8658, -0.2760],
[-0.9438, -0.2276, -1.3241],
[ 0.8277, -0.8688, -2.5501],
[-0.1919, 2.1744, 0.4230]])
torch.Size([5, 3])
张量运算
加法运算
- x+y
- torch.add(x+y)
- torch.add(x+y,out=result)
- y.add_(x)
例子
x=torch.randn(2,2)
y=torch.rand(2,2)
print(x+y)
print(torch.add(x,y))
result=torch.empty(2,2)
torch.add(x,y,out=result)
print(result)
y.add_(x)
print(y)
output:
tensor([[ 1.0287, -0.3419],
[ 0.3529, 0.8750]])
tensor([[ 1.0287, -0.3419],
[ 0.3529, 0.8750]])
tensor([[ 1.0287, -0.3419],
[ 0.3529, 0.8750]])
tensor([[ 1.0287, -0.3419],
[ 0.3529, 0.8750]])
Resize/Reshape张量
- x.view()
例子
x=torch.randn(4,4)
y=x.view(16)
z=x.view(-1,8) # 第一个参数是-1表示该维度从另一个维度推断
print(x.size(),y.size(),z.size())
output:
torch.Size([4, 4]) torch.Size([16]) torch.Size([2, 8])
item()函数
张量只有一个值时,返回一个python值
例子
x=torch.randn(1)
print(x)
print(x.item())
output:
tensor([0.6466])
0.6466003656387329
torch张量与NumPy相互转换
- b=a.numpy()
- b=torch.from_numpy(a)
- Torch Tensor和NumPy数组共享内存(如果TorchTensor在CPU上),更改一个将更改另一个。
例子1
a=torch.ones(5)
b=a.numpy()
print(a)
print(b)
print(a.add_(1))
print(b)
output:
tensor([1., 1., 1., 1., 1.])
[1. 1. 1. 1. 1.]
tensor([2., 2., 2., 2., 2.])
[2. 2. 2. 2. 2.]
例子2
import numpy as np
a = np.ones(5)
b = torch.from_numpy(a)
np.add(a, 1, out=a)
print(a)
print(b)
output:
[2. 2. 2. 2. 2.]
tensor([2., 2., 2., 2., 2.], dtype=torch.float64)