load_model 如何导入自定义的loss 函数

训练一个lstm模型,然后保存为model.h5文件,之后load_model("model.h5") 出错,错误如下

ValueError: Unknown loss function:root_mean_squared_error

原因:训练模型时的loss函数是自己定义的RMSE函数,如下:

def root_mean_squared_error(y_true, y_pred):

        return K.sqrt(K.mean(K.square(y_pred - y_true), axis=-1))

在此更正一下,RMSE 函数定义应该没有axis=-1,上面那个函数应该是定义的MAE

def root_mean_squared_error(y_true, y_pred):

        return K.sqrt(K.mean(K.square(y_pred - y_true)))

模型编译如下:

model.compile(optimizer = "rmsprop", loss = root_mean_squared_error, metrics =["accuracy"])

经过网上查找,找到一个快速并且有效的解决办法,在这里和大家分享,希望可以帮助小伙伴,解决同样的issue

需要再将root_mean_squared_error定义一遍,就是再写一遍(如何在你的script中已经存在root_mean_squared_error函数,就不需要重新定义了。我是写了两个scripts,一个用于模型训练,一个用于模型应用new data进行regression)

        def root_mean_squared_error(y_true, y_pred):

                return K.sqrt(K.mean(K.square(y_pred - y_true)))

然后在load_model中加入一个参数custom_objects如下:

model = load_model('model.h5', custom_objects={'root_mean_squared_error': root_mean_squared_error})


如果解决了您的问题,给个赞👍吧,谢谢!!!

最后编辑于
©著作权归作者所有,转载或内容合作请联系作者
【社区内容提示】社区部分内容疑似由AI辅助生成,浏览时请结合常识与多方信息审慎甄别。
平台声明:文章内容(如有图片或视频亦包括在内)由作者上传并发布,文章内容仅代表作者本人观点,简书系信息发布平台,仅提供信息存储服务。

相关阅读更多精彩内容

友情链接更多精彩内容