pytorch-lightning 框架初探

项目地址 https://github.com/PyTorchLightning/pytorch-lightning
以下内容整理自项目作者的讲解视频:Converting from PyTorch to PyTorch Lightning (油管视频需梯自备子)

import torch.nn as nn 
import torch  
import torch.optim as optim
import pytorch_lightning as pl

class Net(pl.LightningModule):

    def __init__(self):
        super().__init__()

    def forward(self,x):
        # 可以结合training_step函数,简化forward的内容
        pass

    def loss_func(self, y_hat, y):
        return F.cross_entropy(y_hat, y)

    def configure_optimizers(self):
        return optim.Adam(self.parameters(), lr=1e-3)
    
    def training_step(self, batch, batch_idx):
        x,y = batch #
        y_hat = self(x)
        # return {'loss':F.cross_entropy(y_hat, y)}
        loss = self.loss_func(y_hat, y)
        return {'loss':loss}
        ################################
        # log = {'train_loss':loss}
        # return {'loss':loss, 'log':log}
        # 这样就可以在tensorboard中看到train_loss的曲线

    def log_func(self,):
        # do whatever you want, print, file operation, etc.
        pass

    def validation_step(self, batch, batch_idx):
        # !!! val data 不应该用shuffle
        x,y = batch #
        y_hat = self(x)
        val_loss = self.loss_func(y_hat, y)

        if batch_idx == 0:
            n = x.size(0)
            self.log_func()

        return {'val_loss':val_loss}

    ##############################################################
    ###  这里定义了dataloader fit里就不用通过参数传入了
    ################################
    def train_dataloader(self):
        loader = torch.utils.data.DataLoader()
        return loader

    def val_dataloader(self):
        loader = torch.utils.data.DataLoader()
        return loader

    ################################
    # 使用tensorboard等 logger,  替代validation_step中log_func这一部分
    ################################
    def validation_epoch_end(self, outputs):

        # 计算batch的平均损失,这里的outputs就是validation_step返回的
        val_loss = torch.stack([x['val_loss'] for x in outputs]).mean()
        # 也可以传入其他数据,如VAE 重建的图像
        # x_hat = outputs[0]['x_hat']
        # grid = torchvision.utils.make_grid(x_hat)
        # self.logger.experiment 就是 tensorboard SummaryWriter
        self.logger.experiment.add_image('images', grid,0)

        log = {'avg_val_loss':val_loss}
        return {'log':log}
        ################################
        # 如果return的dict中有key='val_loss'会自动出发保存模型
        # return {'val_loss':val_loss}

    
if __name__ == '__main__':

    # dataloader 可以放到module中

    train_loader = torch.utils.data.DataLoader()
    val_loader= torch.utils.data.DataLoader() # shuffle=False
    net =Net()

    # 快速跑完一个train batch和一个dev batch
    # 验证整个流程没错
    trainer = pl.Trainer(fast_dev_run=True) 
    # 完整的训练过程 Trainer() 即可
    # train_percent_check=0.1  只训练0.1的数据
    trainer.fit(net,
                train_dataloader=train_loader,
                val_dataloaders=val_loader
                )

    ################################
    # argparser 的使用

    from argparser import ArgumentParser

    parser = ArgumentParser()
    parser = pl.Trainer.add_argparse_args(parser)
    parser.add_argument('--batch_size', default=32, type=int, help='batch size')
    parser.add_argument('--learning_rate', default=1e-3, type=float)

    args = parser.parse_args()

    net = Net()
    trainer = pl.Trainer.from_argparse_args(args, fast_dev_run=True)
    trainer.fit(net)

    ################################
    # 单GPU训练
    # terminal:  python main.py --gpus 1 --batch_size 256
    # 多GPU训练
    # 默认用DP dataparallel 但用DDP更好 distributed DP
    # terminal:  python main.py --gpus 2 --distributed_backend ddp 

    ################################
    # 16 bit 训练  pytorch 1.6 内建 apex
    # 可能需要修改一定的代码,比如说Loss函数   
    # from F.binary_cross_entropy  to  
    # F.binary_cross_entropy_with_logits(y_hat,y,reduction='sum')

最后编辑于
©著作权归作者所有,转载或内容合作请联系作者
  • 序言:七十年代末,一起剥皮案震惊了整个滨河市,随后出现的几起案子,更是在滨河造成了极大的恐慌,老刑警刘岩,带你破解...
    沈念sama阅读 214,837评论 6 496
  • 序言:滨河连续发生了三起死亡事件,死亡现场离奇诡异,居然都是意外死亡,警方通过查阅死者的电脑和手机,发现死者居然都...
    沈念sama阅读 91,551评论 3 389
  • 文/潘晓璐 我一进店门,熙熙楼的掌柜王于贵愁眉苦脸地迎上来,“玉大人,你说我怎么就摊上这事。” “怎么了?”我有些...
    开封第一讲书人阅读 160,417评论 0 350
  • 文/不坏的土叔 我叫张陵,是天一观的道长。 经常有香客问我,道长,这世上最难降的妖魔是什么? 我笑而不...
    开封第一讲书人阅读 57,448评论 1 288
  • 正文 为了忘掉前任,我火速办了婚礼,结果婚礼上,老公的妹妹穿的比我还像新娘。我一直安慰自己,他们只是感情好,可当我...
    茶点故事阅读 66,524评论 6 386
  • 文/花漫 我一把揭开白布。 她就那样静静地躺着,像睡着了一般。 火红的嫁衣衬着肌肤如雪。 梳的纹丝不乱的头发上,一...
    开封第一讲书人阅读 50,554评论 1 293
  • 那天,我揣着相机与录音,去河边找鬼。 笑死,一个胖子当着我的面吹牛,可吹牛的内容都是我干的。 我是一名探鬼主播,决...
    沈念sama阅读 39,569评论 3 414
  • 文/苍兰香墨 我猛地睁开眼,长吁一口气:“原来是场噩梦啊……” “哼!你这毒妇竟也来了?” 一声冷哼从身侧响起,我...
    开封第一讲书人阅读 38,316评论 0 270
  • 序言:老挝万荣一对情侣失踪,失踪者是张志新(化名)和其女友刘颖,没想到半个月后,有当地人在树林里发现了一具尸体,经...
    沈念sama阅读 44,766评论 1 307
  • 正文 独居荒郊野岭守林人离奇死亡,尸身上长有42处带血的脓包…… 初始之章·张勋 以下内容为张勋视角 年9月15日...
    茶点故事阅读 37,077评论 2 330
  • 正文 我和宋清朗相恋三年,在试婚纱的时候发现自己被绿了。 大学时的朋友给我发了我未婚夫和他白月光在一起吃饭的照片。...
    茶点故事阅读 39,240评论 1 343
  • 序言:一个原本活蹦乱跳的男人离奇死亡,死状恐怖,灵堂内的尸体忽然破棺而出,到底是诈尸还是另有隐情,我是刑警宁泽,带...
    沈念sama阅读 34,912评论 5 338
  • 正文 年R本政府宣布,位于F岛的核电站,受9级特大地震影响,放射性物质发生泄漏。R本人自食恶果不足惜,却给世界环境...
    茶点故事阅读 40,560评论 3 322
  • 文/蒙蒙 一、第九天 我趴在偏房一处隐蔽的房顶上张望。 院中可真热闹,春花似锦、人声如沸。这庄子的主人今日做“春日...
    开封第一讲书人阅读 31,176评论 0 21
  • 文/苍兰香墨 我抬头看了看天上的太阳。三九已至,却和暖如春,着一层夹袄步出监牢的瞬间,已是汗流浃背。 一阵脚步声响...
    开封第一讲书人阅读 32,425评论 1 268
  • 我被黑心中介骗来泰国打工, 没想到刚下飞机就差点儿被人妖公主榨干…… 1. 我叫王不留,地道东北人。 一个月前我还...
    沈念sama阅读 47,114评论 2 366
  • 正文 我出身青楼,却偏偏与公主长得像,于是被迫代替她去往敌国和亲。 传闻我的和亲对象是个残疾皇子,可洞房花烛夜当晚...
    茶点故事阅读 44,114评论 2 352