pytorch主要组成模块分为
数据读入,模型构建,模型初始化,损失函数,优化器,训练,评估,可视化等
数据读入:分两种,
一,pytorch自带数据集
通过torchvision中的datasets来获取
二,自定义方法
通过集成pytorch自带的Dataset类和Dataloader来设立自己的数据集
Dataset中主要为__init__,__len__,__getitem__三个函数
__init__用于传入外部参数,定义样本集
__len__用过返回数据集的样本个数
__getitem__用于逐个读出样本集合中的元素并最终返回训练和验证所用的数据集
Dataloader用于按批次读入数据
主要参数:batch_size(每批次的样本数)
num_workers(多少个进程)
shuffle(是否打乱)
drop_last(未满批次数的样本去除)
模型构建主要依赖于pytorch中的nn.Module
集成nn.Module的类,在类中有__init__和forward函数
前者负责定义模型中的所有层,后者定义这些层如何向前传播
初始化基于pytorch.nn.init
通常会自己去封装一个initialize_weights函数来进行初始化
大概:判断是什么层,如何初始化
pytorch提供了很多种损失函数,种类太多日后再慢慢分析
训练和评估模型通过定义train和val函数来实现,
训练:送入数据,梯度置零,送入模型,计算损失函数,反向传播,用优化器更新参数
测试的不同点在于:
预先设置torch.no_grad,不需要把优化器梯度置零,不需要反向传播和更新参数