本篇介绍两个高阶操作where和gather
where
- torch.where(condition, x, y) -> Tensor
gather
应用场景
在总共10类的分类任务中,如果真实标签是100到109,而不是0到1。但是网络的预测值却是0到1,那我怎么将预测值和真实值映射到一起呢?
有人可能就说,我让预测值都加100,或者真实标签都减100,不就行了嘛。
但如果真实标签和预测值之间不是这种简单的映射关系呢,比如是[100, 102, 105, 106, 109, 110, 200, 400, 900, 10000]呢,当然,你也可以粗暴地找个映射关系。
但最优雅的方式还是使用这里的gather 方法
- 这里虽然label中的值千奇百怪,但是这些值的索引还是0到10,这样就可以简单映射了。