图像分类学习(3):X光胸片诊断识别——迁移学习

1、数据介绍

  • 数据源于kaggle,可在此链接自行下载
  • 数据集分为3个文件夹(train,test,val),并包含每个图像类别(Pneumonia / Normal)的子文件夹。有5,863个X射线图像(JPEG)和2个类别(肺炎/正常)

2、迁移学习

由于从头训练一个神经网络需要花费的时间较长,而且对数据量的要求也比较大。在实践中,很少有人从头开始训练整个卷积网络(随机初始化),因为拥有足够大小的数据集相对来说比较少见。相反,pytorch中提供的这些模型都已经预先在1000类的Imagenet数据集上训练完成。可以直接拿来训练自己的数据集,即称为模型微调或者迁移学习。
迁移学习包含微调和特征提取。 在微调中,我们从一个预训练模型开始,然后为我们的新任务更新所有的模型参数,实质上就是重新训练整个模型。 在特征提取中,我们从预训练模型开始,只更新产生预测的最后一层的权重。它被称为特征提取是因为我们使用预训练的CNN作为固定的特征提取器,并且仅改变输出层。
本次图像分类只对模型进行特征提取,即更改最后一个全连接层,然后进行模型训练

3、建立模型

3.1 了解数据

首先先导入需要用到的模块,定义一下基本参数。

import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import torchvision
from torchvision import datasets, models, transforms
import matplotlib.pyplot as plt
import time
import os
import copy

data_dir = './cv/chest_xray'
model_name = 'vgg'
num_classes = 2
batch_size = 16
num_epochs = 10
input_size = 224
device = torch.device('cuda' if torch.cuda.is_available() else 'gpu')

进行一系列数据增强,然后生成训练、验证、和测试数据集。

data_transforms = {
    'train': transforms.Compose([
        transforms.RandomResizedCrop(input_size),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
    'val': transforms.Compose([
        transforms.Resize(input_size),
        transforms.CenterCrop(input_size),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
    'test': transforms.Compose([
        transforms.Resize(input_size),
        transforms.CenterCrop(input_size),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
}

print("Initializing Datasets and Dataloaders...")


image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir, x), data_transforms[x]) for x in ['train', 'val', 'test']}

dataloaders_dict = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=batch_size, shuffle=True, num_workers=4) for x in ['train', 'val', 'test']}

定义一个查看图片和标签的函数

def imshow(inp, title=None):
    inp = inp.numpy().transpose((1, 2, 0))
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])
    inp = std * inp + mean
    inp = np.clip(inp, 0, 1)
    plt.imshow(inp)
    if title is not None:
        plt.title(title)
    plt.pause(0.001) 

imgs, labels = next(iter(dataloaders_dict['train']))


out = torchvision.utils.make_grid(imgs[:8])

imshow(out, title=[classes[x] for x in labels[:8]])

OUT :


可以看到图片有两个类别的X光片,PNEUMONIA(肺炎),NORMAL(正常)

classes = image_datasets['test'].classes
classes

OUT :
['NORMAL', 'PNEUMONIA']

3.2 建立VGG16迁移学习模型

model = torchvision.models.vgg16(pretrained=True)
model

pretrained=True,则会下载预训练权重,需要耐心等待一段时间。
查看一下VGG16的结构:

VGG(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU(inplace)
    (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): ReLU(inplace)
    (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (6): ReLU(inplace)
    (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (8): ReLU(inplace)
    (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): ReLU(inplace)
    (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (13): ReLU(inplace)
    (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (15): ReLU(inplace)
    (16): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (17): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (18): ReLU(inplace)
    (19): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (20): ReLU(inplace)
    (21): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (22): ReLU(inplace)
    (23): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (24): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (25): ReLU(inplace)
    (26): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (27): ReLU(inplace)
    (28): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (29): ReLU(inplace)
    (30): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (avgpool): AdaptiveAvgPool2d(output_size=(7, 7))
  (classifier): Sequential(
    (0): Linear(in_features=25088, out_features=4096, bias=True)
    (1): ReLU(inplace)
    (2): Dropout(p=0.5)
    (3): Linear(in_features=4096, out_features=4096, bias=True)
    (4): ReLU(inplace)
    (5): Dropout(p=0.5)
    (6): Linear(in_features=4096, out_features=1000, bias=True)
  )
)

可以看到VGG16主要由features和classifier两种结构组成,classifier[6]为最后一层,我们将它的输出改为我们的类别数2。由于我们只需要训练最后一层,再改之前我们先将模型的参数设置为不可更新。

# 先将模型参数改为不可更行
for param in model.parameters():
    param.requires_grad = False
# 再更改最后一层的输出,至此网络只能更行该层参数
model.classifier[6] = nn.Linear(4096, num_classes)

3.3 定义训练函数

def train_model(model, dataloaders, criterion, optimizer, mun_epochs=25):
    since = time.time()
    
    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0
    
    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch, num_epochs-1))
        print('-' * 10)
        
        for phase in ['train', 'val']:
            
            if phase == 'train':
                model.train()
            else:
                model.eval()
            
            running_loss = 0.0
            running_corrects = 0.0
            
            for inputs, labels in dataloaders[phase]:
                inputs, labels = inputs.to(device), labels.to(device)
                
                optimizer.zero_grad()
                
                with torch.set_grad_enabled(phase == 'train'):
                    outputs = model(inputs)
                    loss = criterion(outputs, labels)

                    _, preds = torch.max(outputs, 1)

                    if phase == 'train':
                        loss.backward()
                        optimizer.step()
                running_loss += loss.item() * inputs.size(0)
                running_corrects += (preds == labels).sum().item()
            epoch_loss = running_loss / len(dataloaders[phase].dataset)
            epoch_acc = running_corrects / len(dataloaders[phase].dataset)
            
            print('{} loss: {:.4f} Acc: {:.4f}'.format(phase, epoch_loss, epoch_acc))
            
            if phase == 'val' and epoch_acc > best_acc:
                best_acc = epoch_acc
                best_model_wts = copy.deepcopy(model.state_dict())
        print()
    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
    print('Best val Acc: {:.4f}'.format(best_acc))
    
    model.load_state_dict(best_model_wts)
    return model

3.4 定义优化器和损失函数

model = model.to(device)
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
criterion = nn.CrossEntropyLoss()

3.5 开始训练

model_ft = train_model(model, dataloaders_dict, criterion, optimizer, num_epochs)

OUT :

Epoch 0/9
----------
train loss: 0.4240 Acc: 0.8171
val loss: 0.2968 Acc: 0.8125

Epoch 1/9
----------
train loss: 0.4095 Acc: 0.8284
val loss: 0.1901 Acc: 0.9375

Epoch 2/9
----------
train loss: 0.3972 Acc: 0.8424
val loss: 0.2445 Acc: 0.9375

Epoch 3/9
----------
train loss: 0.4145 Acc: 0.8315
val loss: 0.1973 Acc: 0.9375

Epoch 4/9
----------
train loss: 0.4012 Acc: 0.8416
val loss: 0.1253 Acc: 1.0000

Epoch 5/9
----------
train loss: 0.3976 Acc: 0.8489
val loss: 0.1904 Acc: 0.9375

Epoch 6/9
----------
train loss: 0.4025 Acc: 0.8432
val loss: 0.1527 Acc: 1.0000

Epoch 7/9
----------
train loss: 0.3768 Acc: 0.8495
val loss: 0.1761 Acc: 1.0000

Epoch 8/9
----------
train loss: 0.3906 Acc: 0.8472
val loss: 0.1346 Acc: 1.0000

Epoch 9/9
----------
train loss: 0.3847 Acc: 0.8403
val loss: 0.0996 Acc: 1.0000

Training complete in 7m 16s
Best val Acc: 1.0000

经过10轮训练,验证集的准确率已达到100%,由于验证集的图片很少(大概只有十多张),可能不太能说明网络的训练效果。接下来看一下模型在测试集表现如何。

3.6 测试集评估

首先我们先拿出10张X光片给模型进行判断,看看它能否准确预测出X光片的类别。

imgs, labels = next(iter(dataloaders_dict['test']))
imgs, labels = imgs.to(device), labels.to(device)
outputs = model_ft(imgs)
_, preds = torch.max(outputs, 1)
print('real:' + ' '.join('%9s' % classes[labels[j]] for j in range(10)))
print('pred:' + ' '.join('%9s' % classes[preds[j]] for j in range(10)))

OUT :

real:   NORMAL PNEUMONIA PNEUMONIA    NORMAL    NORMAL    NORMAL    NORMAL PNEUMONIA PNEUMONIA PNEUMONIA
pred:PNEUMONIA PNEUMONIA PNEUMONIA PNEUMONIA    NORMAL    NORMAL    NORMAL PNEUMONIA PNEUMONIA PNEUMONIA

十张X片,其中第一张和第四张错误的将正常预测成了肺炎,其他八张预测正确。有80%的准确率,最后我们查看一下在全部的测试集中的准确率是否有80%左右。

correct = 0.0
for imgs, labels in dataloaders_dict['test']:
    imgs, labels = imgs.to(device), labels.to(device)
    output = model_ft(imgs)
    _, preds = torch.max(output, 1)
    correct += (preds == labels).sum().item()
print('test accuracy:{:.2f}%'.format(100 * correct / len(dataloaders_dict['test'].dataset)))

OUT :
test accuracy:82.53%

4、总结

模型预测的准确率为82.53%,并没有想象中的高。
若想进一步提升模型准确率,我觉得可以在以下几个方面改进:

  • 换一个带批标准化的VGG模型或直接换一个更强大的模型,如ResNet。
  • 减小学习率,或者训练时进行学习率衰减。
  • 尝试增加epoch,更改batch_size等。
最后编辑于
©著作权归作者所有,转载或内容合作请联系作者
  • 序言:七十年代末,一起剥皮案震惊了整个滨河市,随后出现的几起案子,更是在滨河造成了极大的恐慌,老刑警刘岩,带你破解...
    沈念sama阅读 212,657评论 6 492
  • 序言:滨河连续发生了三起死亡事件,死亡现场离奇诡异,居然都是意外死亡,警方通过查阅死者的电脑和手机,发现死者居然都...
    沈念sama阅读 90,662评论 3 385
  • 文/潘晓璐 我一进店门,熙熙楼的掌柜王于贵愁眉苦脸地迎上来,“玉大人,你说我怎么就摊上这事。” “怎么了?”我有些...
    开封第一讲书人阅读 158,143评论 0 348
  • 文/不坏的土叔 我叫张陵,是天一观的道长。 经常有香客问我,道长,这世上最难降的妖魔是什么? 我笑而不...
    开封第一讲书人阅读 56,732评论 1 284
  • 正文 为了忘掉前任,我火速办了婚礼,结果婚礼上,老公的妹妹穿的比我还像新娘。我一直安慰自己,他们只是感情好,可当我...
    茶点故事阅读 65,837评论 6 386
  • 文/花漫 我一把揭开白布。 她就那样静静地躺着,像睡着了一般。 火红的嫁衣衬着肌肤如雪。 梳的纹丝不乱的头发上,一...
    开封第一讲书人阅读 50,036评论 1 291
  • 那天,我揣着相机与录音,去河边找鬼。 笑死,一个胖子当着我的面吹牛,可吹牛的内容都是我干的。 我是一名探鬼主播,决...
    沈念sama阅读 39,126评论 3 410
  • 文/苍兰香墨 我猛地睁开眼,长吁一口气:“原来是场噩梦啊……” “哼!你这毒妇竟也来了?” 一声冷哼从身侧响起,我...
    开封第一讲书人阅读 37,868评论 0 268
  • 序言:老挝万荣一对情侣失踪,失踪者是张志新(化名)和其女友刘颖,没想到半个月后,有当地人在树林里发现了一具尸体,经...
    沈念sama阅读 44,315评论 1 303
  • 正文 独居荒郊野岭守林人离奇死亡,尸身上长有42处带血的脓包…… 初始之章·张勋 以下内容为张勋视角 年9月15日...
    茶点故事阅读 36,641评论 2 327
  • 正文 我和宋清朗相恋三年,在试婚纱的时候发现自己被绿了。 大学时的朋友给我发了我未婚夫和他白月光在一起吃饭的照片。...
    茶点故事阅读 38,773评论 1 341
  • 序言:一个原本活蹦乱跳的男人离奇死亡,死状恐怖,灵堂内的尸体忽然破棺而出,到底是诈尸还是另有隐情,我是刑警宁泽,带...
    沈念sama阅读 34,470评论 4 333
  • 正文 年R本政府宣布,位于F岛的核电站,受9级特大地震影响,放射性物质发生泄漏。R本人自食恶果不足惜,却给世界环境...
    茶点故事阅读 40,126评论 3 317
  • 文/蒙蒙 一、第九天 我趴在偏房一处隐蔽的房顶上张望。 院中可真热闹,春花似锦、人声如沸。这庄子的主人今日做“春日...
    开封第一讲书人阅读 30,859评论 0 21
  • 文/苍兰香墨 我抬头看了看天上的太阳。三九已至,却和暖如春,着一层夹袄步出监牢的瞬间,已是汗流浃背。 一阵脚步声响...
    开封第一讲书人阅读 32,095评论 1 267
  • 我被黑心中介骗来泰国打工, 没想到刚下飞机就差点儿被人妖公主榨干…… 1. 我叫王不留,地道东北人。 一个月前我还...
    沈念sama阅读 46,584评论 2 362
  • 正文 我出身青楼,却偏偏与公主长得像,于是被迫代替她去往敌国和亲。 传闻我的和亲对象是个残疾皇子,可洞房花烛夜当晚...
    茶点故事阅读 43,676评论 2 351