PyTorch 快速上手

torch.Tensor数据类型

torch.Tensor是一种包含单一数据类型元素的多维矩阵。

Data tyoe CPU tensor GPU tensor
32-bit floating point torch.FloatTensor torch.cuda.FloatTensor
64-bit floating point torch.DoubleTensor torch.cuda.DoubleTensor
16-bit floating point N/A torch.cuda.HalfTensor
8-bit integer (unsigned) torch.ByteTensor torch.cuda.ByteTensor
8-bit integer (signed) torch.CharTensor torch.cuda.CharTensor
16-bit integer (signed) torch.ShortTensor torch.cuda.ShortTensor
32-bit integer (signed) torch.IntTensor torch.cuda.IntTensor
64-bit integer (signed) torch.LongTensor torch.cuda.LongTensor

torch.Tensor是默认的tensor类型(torch.FlaotTensor)的简称

会改变tensor的函数操作会用一个下划线后缀来标示。比如,torch.FloatTensor.abs_()会在原地计算绝对值,并返回改变后的tensor,而tensor.FloatTensor.abs()将会在一个新的tensor中计算结果。

创建Tensor

# uninitialized
torch.empty()
torch.FloatTensor()
torch.IntTensor(d1,d2,d3)

torch.tensor([1.2, 3]).type()
# 设置默认数据类型
torch.set_default_tensor_type(torch.DoubleTensor)

# 随机初始化
a = torch.rand(3,3)  #  [0,1]
torch.rand_like(a)
torch.randint(1,10,[3,3])  # [min, max]
# 正态分布
torch.randn(3,3)  # N(0,1)
torch.normal(mean=torch.full([10], 0), std=torch.arange(1, 0, -0.1))

torch.full([2,3], 7)  # 每个元素都设置为7
torch.full([], 7)  # 标量
torch.arange(0,10)

# linspace/logspace
torch.linspace(0,10, steps=4)
torch.logspace(0, -1, steps=10)

# ones/zeros/eye/*_like
torch.ones(3,3)
torch.zeros(3,3)
torch.eye(3,4)

# randperm == random.shuffle
torch.randperm(10)

Tensor 切片

类似于numpy切片操作,eg: a[1:10,:], a[:10:2,:]

a = torch.randn(4,3,28,28)
a[:2]
a[:2, 1:, :,:].shape # output: [2,2,28,28]

# select by specific index
a.index_select(0, torch.tensor([0,2])) 
a[...].shape # 任意维度
a[..., :2] # 与*list 变长解包类似?

# select by mask
x = torch.randn(3,4)
mask = x.ge(0.5)
torch.masked_select(x, mask)

# select by flatten index
src = torch.tensor([[4,3,5], [6,7,8]])
torch.take(src, torch.tensor([0,2]))

Tensor维度变换

  • view/reshape
  • squeeze/unsqueeze
  • transpose/t/permute
  • expand/repeat
# view reshape  (lost dim information)
In [41]: a = torch.rand(4,1 ,28, 28)

In [42]: a.shape
Out[42]: torch.Size([4, 1, 28, 28])

In [43]: a.view(4, 28*28)
Out[43]: 
tensor([[0.6006, 0.8933, 0.1474,  ..., 0.5848, 0.9790, 0.6479],
        [0.1824, 0.8874, 0.1635,  ..., 0.3386, 0.3563, 0.0075],
        [0.8867, 0.9460, 0.1208,  ..., 0.1569, 0.2614, 0.7639],
        [0.1437, 0.5749, 0.2275,  ..., 0.5167, 0.6074, 0.5263]])
In [44]: a.view(4, 28*28).shape
Out[44]: torch.Size([4, 784])

# unsqueeze(维度增加)
In [50]: b = torch.rand(32)
In [51]: f = torch.rand(4,32, 14,14)
In [52]: b = b.unsqueeze(1).unsqueeze(2).unsqueeze(0)
In [53]: b.shape
Out[53]: torch.Size([1, 32, 1, 1])
# expand/repeat
b.expand([4,32,14,14]) # [1,32,1,1] -> [4,32,14,14]
b.repeat(4,1,32,32) # 重复

# a.t() 2d数据
# transpose
a.transpose(1,3) # 指定交换的dim
a.transpose(1,3).contiguous()

# permute 交换维度
# [b c h w] -> [b h w c]
b.permute(0,2,3,1) # [b h w c]
最后编辑于
©著作权归作者所有,转载或内容合作请联系作者
平台声明:文章内容(如有图片或视频亦包括在内)由作者上传并发布,文章内容仅代表作者本人观点,简书系信息发布平台,仅提供信息存储服务。