我们时常想使用一些功能, 它们可以被归类但又分散在项目的各个位置, 一般它们与核心代码无关. 例如, 我们想写追溯程序的运行状态, 一般我们会在程序运行的各个节点写下诸如 print(‘training starts....’)
的代码. 这些代码与主程序无关,如果全部写在主程序里面会显得比较冗杂, 它们实现同一个功能但又分散在程序的各个位置,这时就可以用到callbacks了.
原则
首先明确callbacks的使用原则.
- 独立. callback需要保持独立的功能, callbacks 之间也不要相互调用对方的方法.
- 无顺序. 不建议callbacks之间相互依赖, callbacks之间不要有执行的先后顺序.
- 不要调用上一级的方法. 例如
on_validation_end()
.
位置
要使用callbacks, 首先要明确这些钩子的位置. 官网给了一个伪代码来帮我们快速找到它们的位置.
使用
from lightning.pytorch.callbacks import Callback
# 通过继承创建一个自己的Hook
class MyPrintingCallback(Callback):
def on_train_start(self, trainer, pl_module):
print("Training is starting")
def on_train_end(self, trainer, pl_module):
print("Training is ending")
# 注册自定的Hook
trainer = Trainer(callbacks=[MyPrintingCallback()])
因此到位置的时候, 程序会执行注册钩子里的代码:
# pseudocode
on_train_start() # "Training is starting"
train()
on_train_end() # "Training is ending"
内置callbacks
有一些内置callbacks实现了一些用的功能, 以下显示了一些我觉得有用的:
Ref
https://lightning.ai/docs/pytorch/stable/extensions/callbacks.html
https://lightning.ai/docs/pytorch/latest/common/lightning_module.html#hooks
https://stephencowchau.medium.com/pytorch-lightning-hooks-and-callbacks-my-limited-understanding-d8e0a56dcf2b