Datawhale深入浅出pytorchtask02笔记

pytorch主要组成模块分为

数据读入,模型构建,模型初始化,损失函数,优化器,训练,评估,可视化等

数据读入:分两种,

一,pytorch自带数据集

自带方法

通过torchvision中的datasets来获取

二,自定义方法

自己动手

通过集成pytorch自带的Dataset类和Dataloader来设立自己的数据集

Dataset中主要为__init__,__len__,__getitem__三个函数

__init__用于传入外部参数,定义样本集

__len__用过返回数据集的样本个数

__getitem__用于逐个读出样本集合中的元素并最终返回训练和验证所用的数据集

Dataloader用于按批次读入数据

主要参数:batch_size(每批次的样本数)

num_workers(多少个进程)

shuffle(是否打乱)

drop_last(未满批次数的样本去除)

模型构建主要依赖于pytorch中的nn.Module

集成nn.Module的类,在类中有__init__和forward函数

前者负责定义模型中的所有层,后者定义这些层如何向前传播

一个简单的cnn模型

Alexnet

初始化基于pytorch.nn.init

相关函数

通常会自己去封装一个initialize_weights函数来进行初始化

大概:判断是什么层,如何初始化

初始化函数的封装

pytorch提供了很多种损失函数,种类太多日后再慢慢分析

训练和评估模型通过定义train和val函数来实现,

训练:送入数据,梯度置零,送入模型,计算损失函数,反向传播,用优化器更新参数

测试的不同点在于:

预先设置torch.no_grad,不需要把优化器梯度置零,不需要反向传播和更新参数

图像分类的训练和评估函数
©著作权归作者所有,转载或内容合作请联系作者
平台声明:文章内容(如有图片或视频亦包括在内)由作者上传并发布,文章内容仅代表作者本人观点,简书系信息发布平台,仅提供信息存储服务。

推荐阅读更多精彩内容