按照《The The Annotated Transformer》教程写下来,卡在了Greedy Decoding部分,回溯发现问题出在两个部分
- 在Embeddings部分,教程中如下
class Embeddings(nn.Module):
def __init__(self, d_model, vocab):
super(Embeddings, self).__init__()
self.lut = nn.Embedding(vocab, d_model)
self.d_model = d_model
def forward(self, x):
x = x.long()
print(type(x))
print(x)
return self.lut(x) * math.sqrt(self.d_model)
会报错,提示要求输入的x为LongTensor,在这里我增加了
x = x.long()
解决
- 在LabelSmoothing部分,使用scatter_函数的时候提示。没搞定,这样模型训练应该是有问题的……
RuntimeError: Expected object of type torch.LongTensor but found type torch.IntTensor for argument #3 'index'
查了半天,应该是Pytorch新老版本的问题,现搁置起来,继续跑完项目所有的训练步骤。
如果有大神遇到这样的问题解决了,麻烦留言啊,我快疯了……