2022-1-5, Wed., 13:37 于鸢尾花基地
可以采用如下方式对之前保存的预训练模型进行批量测试:
for ckpt in ckpt_list:
model = ptl_module.load_from_checkpoint(ckpt, args=args)
trainer.test(model, dataloaders=test_dataloader)
然而,在上述循环中,通过trainer.test
每执行一次测试,都只是执行了一个epoch
的测试(也就是执行多次ptl_module.test_step
和一次ptl_module.test_epoch_end
),而不可能把ckpt_list
中的多个预训练模型(checkpoint
)当做多个epoch
,多次执行ptl_module.test_epoch_end
。
我们期望,对多个checkpoint
的测试能像对多个epoch
的训练一样简洁:
trainer.test(ptl_module, dataloaders=test_dataloader)
怎么做到?在训练过程中,要训练多少个epoch
是由参数max_epochs
来决定的;而在测试过程中,怎么办?PTL并非完整地保存了所有epoch的预训练模型。
由于在测试过程中对各checkpoint
是独立测试的,如果要统计多个checkpoint
的最优性能(如最大PSNR/SSIM),怎么办?这里的一个关键问题是如何保存每次测试得到的评估结果,好像PTL并未对此提供接口。
解决方案
PTL提供了“回调类(Callback)”(在 pytorch_lightning.callbacks
中),可以自定义一个回调类,并重载on_test_epoch_end
方法,来监听ptl_module.test_epoch_end
。
如何使用?只需要在定义trainer
时,把该自定义的回调函数加入其参数callbacks
即可:ptl.Trainer(callbacks=[MetricTracker()])
。这里,MetricTracker
为自定义的回调类,具体如下:
class MetricTracker(Callback):
def __init__(self):
self.optim_metrics = None
def on_test_epoch_end(self, trainer, pl_module):
if self.optim_metrics is None:
self.optim_metrics = pl_module.metrics_dict
return
tensorboard = pl_module.logger.experiment
metrics_key_list, metrics_val_list = [], []
for k in pl_module.metrics_dict:
# comp_fun 是自己定义的比较函数
self.optim_metrics[k] = comp_fun(self.optim_metrics[k], pl_module.metrics_dict[k])
评论: 由于MetricTracker
具有与Trainer
相同的生命周期,因此,在整个测试过程中,MetricTracker
能够维护一个最优的评估结果optim_metrics
。