Pytorch 之 模型的保存与调用

介绍关于用pytorch搭建模型时,对模型进行保存以及再次调用模型参数的相关函数命令。

使用 torch.save(model.state_dict(), PATH)来保存模型学习到的参数,给模型恢复提供最大的灵活性。

先对模型进行实例化,再用load_state_dict()调用模型,在对模型进行推理之前,调用model.eval():

model = TheModelClass(*args, **kwargs)

model.load_state_dict(torch.load(PATH))  #该函数只接收字典对象,而不是保存对象的路径,在这之前要反序列化保存的state_dict。

model.eval()


torch.load( f, map_location=None, pickle_module=<module 'pickle' from '/opt/conda/lib/python3.6/pickle.py'>,  **pickle_load_args)  从文件加载用torch.save()保存的对象。  目前需要知道该函数前两个参数的正确使用即可

f: 类似于文件的对象,或包含文件名称的字符串,如:要载入的模型所在的完整路径的字符串

map_location: 一个函数,torch.device,字符串或字典,明确如何重映射存储空间位置

pickle_module:用于解开元数据和对象的模块(必须与序列化文件的pickle_module相匹配)

pickle_load_args:(只有Python3才有)可选择的关键字参数,并传递给pickle_module.load()和pickle_module.Unpickler(),比如,errors=...。

©著作权归作者所有,转载或内容合作请联系作者
【社区内容提示】社区部分内容疑似由AI辅助生成,浏览时请结合常识与多方信息审慎甄别。
平台声明:文章内容(如有图片或视频亦包括在内)由作者上传并发布,文章内容仅代表作者本人观点,简书系信息发布平台,仅提供信息存储服务。

相关阅读更多精彩内容

友情链接更多精彩内容