最近跑一个项目时发现调用mps会出现nan loss的问题,甚至运行速度不如cpu
调用mps时发现的问题
使用cpu正常运行甚至速度还更快
解决方法
使用以下代码更新pytorch
pip3 install --pre torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/nightly/cpu
更新后就没有问题了,速度也提升了
参考:https://discuss.pytorch.org/t/loss-becomes-nan-or-inf-when-using-mps/164774