# Project: https://github.com/PyTorchLightning/pytorch-lightning
# Notes compiled from the project author's walkthrough video:
# "Converting from PyTorch to PyTorch Lightning" (on YouTube)
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import pytorch_lightning as pl
class Net(pl.LightningModule):
    """Skeleton LightningModule illustrating a PyTorch -> Lightning conversion.

    Method bodies are placeholders from the walkthrough video; the
    DataLoader calls must be given a real Dataset before training.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x):
        # Can be folded into training_step to simplify forward; placeholder here.
        pass

    def loss_func(self, y_hat, y):
        """Cross-entropy loss between logits ``y_hat`` and integer targets ``y``."""
        return F.cross_entropy(y_hat, y)

    def configure_optimizers(self):
        # Adam over all module parameters with a fixed learning rate.
        return optim.Adam(self.parameters(), lr=1e-3)

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        # return {'loss': F.cross_entropy(y_hat, y)}
        loss = self.loss_func(y_hat, y)
        return {'loss': loss}
        ################################
        # log = {'train_loss': loss}
        # return {'loss': loss, 'log': log}
        # Returning a 'log' dict makes the train_loss curve visible in TensorBoard.

    def log_func(self):
        # Do whatever you want here: print, file operations, etc.
        pass

    def validation_step(self, batch, batch_idx):
        # NOTE: validation data should NOT be shuffled.
        x, y = batch
        y_hat = self(x)
        val_loss = self.loss_func(y_hat, y)
        if batch_idx == 0:
            # e.g. inspect the first batch only (batch size via x.size(0)).
            self.log_func()
        return {'val_loss': val_loss}

    ##############################################################
    # Defining the dataloaders on the module means fit() needs no
    # dataloader arguments.
    ################################
    def train_dataloader(self):
        # Placeholder: pass a real Dataset, e.g. DataLoader(dataset, batch_size=...).
        loader = torch.utils.data.DataLoader()
        return loader

    def val_dataloader(self):
        # Placeholder; remember shuffle=False for validation data.
        loader = torch.utils.data.DataLoader()
        return loader

    ################################
    # Use a logger such as TensorBoard instead of the log_func call
    # in validation_step.
    ################################
    def validation_epoch_end(self, outputs):
        # Average the per-batch losses; `outputs` collects the dicts
        # returned by validation_step.
        val_loss = torch.stack([x['val_loss'] for x in outputs]).mean()
        # Other data can be passed through too, e.g. images reconstructed by a VAE:
        # x_hat = outputs[0]['x_hat']
        # grid = torchvision.utils.make_grid(x_hat)
        # self.logger.experiment is the TensorBoard SummaryWriter, so:
        # self.logger.experiment.add_image('images', grid, 0)
        # (kept commented — `grid` is only defined in the commented example above,
        #  so running it as written would raise NameError)
        log = {'avg_val_loss': val_loss}
        return {'log': log}
        ################################
        # Returning a dict containing the key 'val_loss' triggers automatic
        # model checkpointing:
        # return {'val_loss': val_loss}
if __name__ == '__main__':
    # The dataloaders could instead live on the module (see Net above).
    # Placeholders: pass a real Dataset, e.g. DataLoader(dataset, batch_size=...).
    train_loader = torch.utils.data.DataLoader()
    val_loader = torch.utils.data.DataLoader()  # shuffle=False
    net = Net()
    # fast_dev_run executes a single train batch and a single val batch,
    # verifying the whole pipeline without a full training run.
    trainer = pl.Trainer(fast_dev_run=True)
    # For a full training run, plain Trainer() suffices.
    # train_percent_check=0.1 trains on only 10% of the data.
    trainer.fit(net,
                train_dataloader=train_loader,
                val_dataloaders=val_loader
                )
    ################################
    # Using argparse with Lightning.
    # NOTE: the stdlib module is `argparse` — the original `argparser`
    # does not exist and raised ModuleNotFoundError.
    from argparse import ArgumentParser
    parser = ArgumentParser()
    parser = pl.Trainer.add_argparse_args(parser)
    parser.add_argument('--batch_size', default=32, type=int, help='batch size')
    parser.add_argument('--learning_rate', default=1e-3, type=float)
    args = parser.parse_args()
    net = Net()
    trainer = pl.Trainer.from_argparse_args(args, fast_dev_run=True)
    trainer.fit(net)
    ################################
    # Single-GPU training:
    # terminal: python main.py --gpus 1 --batch_size 256
    # Multi-GPU training:
    # defaults to DP (DataParallel); DDP (DistributedDataParallel) is preferred
    # terminal: python main.py --gpus 2 --distributed_backend ddp
    ################################
    # 16-bit training: built into PyTorch 1.6 (apex before that).
    # May require small code changes, e.g. the loss function:
    # from F.binary_cross_entropy to
    # F.binary_cross_entropy_with_logits(y_hat,y,reduction='sum')