The tool boils down to just a few lines of code:
# 1 start a new experiment
wandb.init(project="new-project-model")
# 2 capture a dictionary of hyperparameters with config
wandb.config = {"learning_rate": 0.001, "epochs": 100, "batch_size": 128}
# 3 optional: track gradients
wandb.watch(model)
# 4 log metrics inside your training loop to visualize model performance
wandb.log(metrics)
# 5 optional: save model at the end
model.to_onnx()
wandb.save("model.onnx")
The sections below walk through how to use each of these steps.
Installation
Activate your virtual environment; if you are working in a Jupyter notebook, you can install with the following command.
!pip install wandb -Uq
Importing the module
import wandb
Next, go to the Weights & Biases website, register an account, and find your API key; enter the key when prompted after running the following command.
wandb.login()
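If you would rather not paste the key interactively (for example on a remote server or in CI), wandb.login() also accepts the key directly; a minimal sketch, assuming the key has already been exported as the WANDB_API_KEY environment variable:

import os
import wandb

# Assumption: the shell has `export WANDB_API_KEY=...` set beforehand.
wandb.login(key=os.environ["WANDB_API_KEY"])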
1 Configure the config dictionary
Put the hyperparameters needed for model training into a config dict; these values are then passed through config to the dataloader and the rest of the training pipeline.
config = dict(
    epochs=10,
    classes=2,
    batch_size=1024,
    learning_rate=0.001
)
2 wandb.init() as the entry point
# Everything related to model training starts from wandb.init().
with wandb.init(project="pytorch-demo", config=config):
    # access all HPs through wandb.config, so logging matches execution!
    config = wandb.config

    # build the model, data, and optimization objects from config
    # (make() is a user-defined helper, sketched below)
    model, train_loader, test_loader, criterion, optimizer = make(config)

    # and use them to train the model
    train(model, train_loader, criterion, optimizer, config)

    # and test its final performance
    test(model, test_loader)
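The make(config) call above is user code rather than a wandb API; here is a minimal sketch of such a helper, assuming a toy fully connected classifier and random tensors as placeholder data (the architecture and dataset are illustrative only):

import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset

def make(config):
    # Placeholder data: 10,000 random 20-dimensional feature vectors with class labels.
    features = torch.randn(10_000, 20)
    labels = torch.randint(0, config.classes, (10_000,))
    dataset = TensorDataset(features, labels)

    # The batch size comes straight from the config dict, as described above.
    train_loader = DataLoader(dataset, batch_size=config.batch_size, shuffle=True)
    test_loader = DataLoader(dataset, batch_size=config.batch_size)

    # A toy classifier; replace with your own model.
    model = nn.Sequential(nn.Linear(20, 64), nn.ReLU(), nn.Linear(64, config.classes))

    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=config.learning_rate)
    return model, train_loader, test_loader, criterion, optimizer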
3 Monitoring with wandb.watch()
Watch the model while it trains.
def train(model, train_loader, criterion, optimizer, config):
    # Tell wandb to watch what the model gets up to: gradients, weights, and more!
    wandb.watch(model, log="all", log_freq=1024)
    ...
Note that log_freq should not be set too small: a small value means data is uploaded to the W&B servers very frequently, which lowers GPU utilization, as shown in the figure below.
Setting log="all" records both gradients and parameters; as training runs longer, the volume of logged data grows and GPU throughput drops accordingly.
def watch(
    models,
    criterion=None,
    log: Optional[Literal["gradients", "parameters", "all"]] = "gradients",
    log_freq: int = 1000,
    idx: Optional[int] = None,
    log_graph: bool = False,
):
    """Hooks into the torch model to collect gradients and the topology.

    Should be extended to accept arbitrary ML models.

    Args:
        models: (torch.Module) The model to hook, can be a tuple
        criterion: (torch.F) An optional loss value being optimized
        log: (str) One of "gradients", "parameters", "all", or None
        log_freq: (int) log gradients and parameters every N batches
        idx: (int) an index to be used when calling wandb.watch on multiple models
        log_graph: (boolean) log graph topology
    """
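Based on the signature above, two hedged variants are worth sketching: throttling uploads by logging only gradients at a lower frequency, and using idx to distinguish several models watched within the same run (the generator/discriminator names below are hypothetical):

# Log only gradients, and only every 4096 batches, to keep upload overhead low.
wandb.watch(model, criterion, log="gradients", log_freq=4096)

# When watching more than one model in a run, tell them apart with idx.
wandb.watch(generator, log="all", log_freq=1024, idx=0)
wandb.watch(discriminator, log="all", log_freq=1024, idx=1)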
4 Logging metrics
def train(model, train_loader, criterion, optimizer, config):
    # Tell wandb to watch what the model gets up to: gradients, weights, and more!
    wandb.watch(model, log="all", log_freq=1024)
    ...
    for i_batch, batch in enumerate(train_loader):
        ...
        if i_batch % 20 == 0:
            # log metrics as a dictionary
            wandb.log({"train_loss": loss})
Logging on every batch can exceed W&B's logging limits and slow training down, so only log every few batches. Open the W&B site and navigate to the project to see charts like the one below.
Each key passed to wandb.log() produces its own chart.
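A single wandb.log() call can record several metrics at once, with each key becoming a separate chart; a minimal sketch (the accuracy and epoch variables are assumed to exist in the training loop):

# Each key in the dictionary gets its own chart on the project page.
wandb.log({
    "train_loss": loss,
    "train_accuracy": accuracy,
    "epoch": epoch,
})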
5 Saving the model
Save the trained model as an ONNX file. Note that the example inputs passed to torch.onnx.export are expected as a tuple.
def test(model, test_loader):
    model.eval()
    # Run the model on some test examples
    with torch.no_grad():
        for features, labels in test_loader:
            ...
    # Save the model in the exchangeable ONNX format
    torch.onnx.export(model, features, "model.onnx")
    wandb.save("model.onnx")
The parameters are documented as follows.
Export a model into ONNX format. This exporter runs your model
once in order to get a trace of its execution to be exported;
at the moment, it supports a limited set of dynamic models (e.g., RNNs.)

Arguments:
    model (torch.nn.Module): the model to be exported.
    args (tuple of arguments): the inputs to the model, e.g., such that
        ``model(*args)`` is a valid invocation of the model. Any non-Tensor
        arguments will be hard-coded into the exported model; any Tensor
        arguments will become inputs of the exported model, in the order they
        occur in args. If args is a Tensor, this is equivalent to having
        called it with a 1-ary tuple of that Tensor.
        (Note: passing keyword arguments to the model is not currently
        supported. Give us a shout if you need it.)
    f: a file-like object (has to implement fileno that returns a file descriptor)
        or a string containing a file name. A binary Protobuf will be written
        to this file.
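To make the tuple requirement concrete, here is a minimal sketch of exporting with a one-element tuple built from a dummy input (the input shape is an assumption and must match whatever your model expects):

import torch

# Dummy batch matching the model's expected input shape (assumed to be 20 features here).
dummy_input = torch.randn(1, 20)

# Per the docstring above, a 1-ary tuple is equivalent to passing the tensor itself.
torch.onnx.export(model, (dummy_input,), "model.onnx")
wandb.save("model.onnx")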
If you found this useful, give it a like ^_~
References:
https://docs.wandb.ai/guides/track/limits
https://docs.wandb.ai/quickstart
https://colab.research.google.com/github/wandb/examples/blob/master/colabs/pytorch/Simple_PyTorch_Integration.ipynb#scrollTo=O9IZcHfDRrHR
Principles behind sweep hyperparameter optimization: https://arxiv.org/pdf/1807.01774.pdf
https://colab.research.google.com/github/wandb/examples/blob/master/colabs/pytorch/Organizing_Hyperparameter_Sweeps_in_PyTorch_with_W%26B.ipynb#scrollTo=QkRm4tLyRLGn