pytorch学习(十八)—预训练模型微调

训练结果

image.png
image.png
image.png
image.png
image.png
image.png
image.png

完整工程

  • 工程目录结构


    image.png
  • 代码

import torch
import torch.optim as optim
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder
import torchvision.models as models
import torchvision.transforms as transforms
import numpy as np
import copy


# ---------------------------------------------------------
# 载入预训练的AlexNet模型
model = models.alexnet(pretrained=True)
# 修改输出层,2分类
model.classifier[6] = nn.Linear(in_features=4096, out_features=2)


# -------------------------数据集----------------------------------------------------

transform = transforms.Compose([transforms.Resize((227,227)),
                                transforms.ToTensor()])

train_dataset = ImageFolder(root='./data/train', transform=transform)
val_dataset = ImageFolder(root='./data/val', transform=transform)

train_dataloader = DataLoader(dataset=train_dataset, batch_size=4, num_workers=4, shuffle=True)
val_dataloader = DataLoader(dataset=val_dataset, batch_size=4, num_workers=4, shuffle=False)


# ------------------优化方法,损失函数--------------------------------------------------
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
loss_fc = nn.CrossEntropyLoss()
scheduler = optim.lr_scheduler.StepLR(optimizer, 20, 0.1)


# --------------------判断是否支持GPU--------------------------------------------------
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
model.to(device)

# -------------------训练-------------------------------------------------------------

epoch_nums = 50
best_model_wts = model.state_dict()
best_acc = 0
for epoch in range(epoch_nums):
    scheduler.step()
    running_loss = 0.0
    epoch_loss = 0.0
    correct = 0
    total = 0

    for i, sample_batch in enumerate(train_dataloader):
        inputs = sample_batch[0]
        labels = sample_batch[1]

        inputs.to(device)
        labels.to(device)

        model.train()
        optimizer.zero_grad()
        # forward
        outputs = model(inputs)
        # loss
        loss = loss_fc(outputs, labels)

        loss.backward()
        optimizer.step()

        #
        running_loss += loss.item()
        if i % 10 == 9:
            correct = 0
            total = 0
            for images_test, labels_test in val_dataloader:
                model.eval()
                images_test = images_test.to(device)
                labels_test = labels_test.to(device)
                outputs_test = model(images_test)
                _, prediction = torch.max(outputs_test, 1)
                correct += ((prediction == labels_test).sum()).item()
                total += labels_test.size(0)
            accuracy = correct/total
            print('[{}, {}] running loss={:.5f}, accuracy={:.5f}'.format(epoch + 1, i + 1, running_loss/10, accuracy))
            running_loss = 0.0
            if accuracy > best_acc:
                best_acc = accuracy
                best_model_wts = copy.deepcopy(model.state_dict())


print('Train finish')
torch.save(best_model_wts, './models/model_50.pth')
©著作权归作者所有,转载或内容合作请联系作者
【社区内容提示】社区部分内容疑似由AI辅助生成,浏览时请结合常识与多方信息审慎甄别。
平台声明:文章内容(如有图片或视频亦包括在内)由作者上传并发布,文章内容仅代表作者本人观点,简书系信息发布平台,仅提供信息存储服务。

相关阅读更多精彩内容

友情链接更多精彩内容