PyTorch:切片函数index_select()

index_select()函数有两种用法。
第一种是将被切片的函数作为参数传入index_select()中

torch.index_select(input, dim, index, out=None)

还有一种是调用张量内置的index_select()函数。

input.index_select(dim, index)

index_select()函数的作用是针对张量input,在它的dim维度上切取index指定的范围切片。

参数:
input:被操作的张量
dim:维度
index:一维Tensor,表示索引下标的范围

例如

import torch
a = torch.tensor([[1, 2, 3, 4], [4, 5, 6, 7]])

b = torch.index_select(a, 0, torch.tensor([1]))
print(b)

c = torch.index_select(a, 1, torch.tensor([1,3]))
print(c)

输出为


这里维度dim从0开始算,则b表示在第0维(即行)上,切下下标为1的行;c表示在第1维(即列)上,切下下标为1和3的列。

最后编辑于
©著作权归作者所有,转载或内容合作请联系作者
平台声明:文章内容(如有图片或视频亦包括在内)由作者上传并发布,文章内容仅代表作者本人观点,简书系信息发布平台,仅提供信息存储服务。

推荐阅读更多精彩内容