概念
迁移学习简单来说就是使用别人已经训练好的模型的参数,并根据需求修改模型。比如vgg模型默认是输入一张三通道的图,并在最后一层输出包含1000个特征数的分类结果,假如我们的数据集特征只有10个,那么只需要把其模型的最后一层输出从1000改成10,然后将模型中除了最后一层的网络层都关闭权值更新,此时就可以将我们的训练集放入训练了。由于前面的层数都已经训练的很好了,此时基本就只需要训练模型的最后一层即可,从而达到消耗少、且收敛快的目的
简单示例
下面导入了torchvision
中提供的vgg
预训练模型,我们先取消模型的权重更新,并查看模型的结构:
import torch
from torch import nn
import torchvision
from torchvision import datasets
vgg = torchvision.models.vgg16(pretrained=True)
# 载入预训练好的vgg16模型
for param in vgg.parameters():
param.requires_grad = False
# 取消所有层的权值更新
print(vgg)
# VGG(
# (features): Sequential(
# (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
# ...
# (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
# ...
# (classifier): Sequential(
# (0): Linear(in_features=25088, out_features=4096, bias=True)
# ...
# (6): Linear(in_features=4096, out_features=1000, bias=True)
# )
# )
可以看到在模型的最后是一个序列模型,而序列模型的最后一层是一个输出为1000个特征的全连接层,此时我们如果特征只有10个,那么要修改的就只有这最后一层。而最后一层的属性名是classifier
,所以我们的修改代码如下:
vgg.classifier[-1] = nn.Linear(4096, 10)
# 修改classifier层的最后一层
至此,我们就完成了对模型的修改,接下来直接训练就可以了~
更多示例参考
https://ptorch.com/news/138.html
https://blog.csdn.net/weixin_43845931/article/details/89304733