# 2019-08-12

import os
import time
import torch
from torch import nn
from torchvision import datasets, transforms
from models.networks.ByVGGnet import VGGNet


# Training hyper-parameters.
batch_size = 50  # images per mini-batch (train and test loaders)
lr = 0.0001      # Adam learning rate
# Directory that save_network() writes checkpoints into.
save_dir = '../save_weight_files/'
# exist_ok=True replaces the racy "check, then create" pattern of
# os.path.exists + os.makedirs with a single call.
os.makedirs(save_dir, exist_ok=True)

def save_network(network, network_label, epoch_label):
    """Persist a model's parameters to the module-level ``save_dir``.

    Args:
        network: module whose ``state_dict()`` is serialized with ``torch.save``.
        network_label: tag placed after ``_net_`` in the file name.
        epoch_label: tag placed at the start of the file name.

    The resulting file is ``<save_dir>/<epoch_label>_net_<network_label>.pth``.
    """
    target = os.path.join(save_dir, '%s_net_%s.pth' % (epoch_label, network_label))
    weights = network.state_dict()
    torch.save(weights, target)

# Preprocessing shared by the train and test sets: resize so the shorter side
# is 227 px, convert to a tensor, then normalize each channel with mean 0.5 /
# std 0.5 (maps ToTensor's [0, 1] range to [-1, 1]).
# transforms.Scale was deprecated and later removed from torchvision;
# transforms.Resize is its drop-in replacement with identical semantics.
# NOTE(review): an int size keeps the aspect ratio, so default batching only
# works if all images share one aspect ratio -- confirm for this dataset.
train_transforms = transforms.Compose([
    transforms.Resize(227),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])
train_dir = '../../data/IDADP-PRCV2019/Atraindataset'
test_dir = '../../data/IDADP-PRCV2019/Atestdataset'
train_datasets = datasets.ImageFolder(train_dir, transform=train_transforms)
train_dataloader = torch.utils.data.DataLoader(train_datasets, batch_size=batch_size, shuffle=True)
# Evaluation order does not affect accuracy, so the test set is read in a
# fixed order instead of being shuffled on every pass.
test_datasets = datasets.ImageFolder(test_dir, transform=train_transforms)
test_dataloader = torch.utils.data.DataLoader(test_datasets, batch_size=batch_size, shuffle=False)

# Build the network on the GPU and warm-start it from an earlier checkpoint.
# The final classifier layer (classifier.6.*) is deliberately not copied, so it
# keeps its fresh initialization -- presumably because its output size differs
# from the checkpoint's; TODO confirm.
model = VGGNet().cuda()
model_dict = model.state_dict()
dict_path = torch.load('../save_weight_files/9_net_myVgg16.pth')
_skipped_keys = {'classifier.6.weight', 'classifier.6.bias'}
pretrained_dict = {
    k: v for k, v in dict_path.items()
    if k not in _skipped_keys and k in model_dict
}
model_dict.update(pretrained_dict)
model.load_state_dict(model_dict)
lossF = nn.CrossEntropyLoss().cuda()
# All model parameters are optimized (the whole network is fine-tuned).
optimizer = torch.optim.Adam(model.parameters(), lr=lr, betas=(0.9, 0.999))

def Accuracy(net=None, loader=None):
    """Return the top-1 classification accuracy of a model, in percent.

    Args:
        net: model to evaluate; defaults to the module-level ``model``.
        loader: iterable of ``(images, labels)`` batches; defaults to the
            module-level ``test_dataloader``.

    Returns:
        ``100 * correct / total`` as a float, or ``0.0`` when ``loader``
        yields no samples (instead of raising ZeroDivisionError).

    The model is put in eval mode for the pass, so dropout/batch-norm layers
    (if any) act deterministically. Its previous train/eval mode is restored
    afterwards, even if an exception is raised.
    """
    if net is None:
        net = model
    if loader is None:
        loader = test_dataloader

    # Move inputs to wherever the model's weights live. When the model is on
    # the GPU this matches the previous hard-coded .cuda() calls.
    first_param = next(net.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cpu')

    was_training = net.training
    net.eval()
    correct = 0
    total = 0
    try:
        with torch.no_grad():
            for images, labels in loader:
                images, labels = images.to(device), labels.to(device)
                predicted = net(images).argmax(dim=1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()
    finally:
        net.train(was_training)

    if total == 0:
        return 0.0
    return 100.0 * correct / total


num_epochs = 200  # total passes over the training set
for epoch in range(num_epochs):
    # Make sure training-mode behavior (e.g. dropout) is active at the start of
    # each epoch, in case an evaluation pass left the model in eval mode.
    model.train()
    total_accurcy = []  # train accuracies of the sampled (logged) batches this epoch
    for i, (images, label) in enumerate(train_dataloader):
        starttime = time.time()
        images, label = images.cuda(), label.cuda()
        output = model(images)
        optimizer.zero_grad()
        loss = lossF(output, label)
        loss.backward()
        optimizer.step()
        # Every 50 batches, log progress. Note that Accuracy() runs over the
        # whole test set each time, which is the expensive part of logging.
        if i % 50 == 0:
            pred_y = output.argmax(dim=1)
            accuracy = (pred_y == label).sum().item() / label.size(0)
            total_accurcy.append(accuracy)
            avgAccurcy = sum(total_accurcy) / len(total_accurcy)
            # "Time Taken" covers only the current batch (starttime is reset per batch).
            print("Epoch:%d/%d Batch: %d/%d  Time Taken:%d sec ----- loss:%f-- train accuracy: %.4f--avg accuracy: %.4f--test accuracy: %.4f" % (
                epoch, num_epochs, i, len(train_dataloader), time.time() - starttime,
                loss.item(), accuracy, avgAccurcy, Accuracy()))
    # Checkpoint every 5 epochs (after epochs 4, 9, 14, ...). The old
    # `epoch > 0` guard was redundant: (epoch + 1) % 5 == 0 implies epoch >= 4.
    if (epoch + 1) % 5 == 0:
        save_network(model, 'planBVgg16', str(epoch))
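# 最后编辑于 (Last edited on)
# ©著作权归作者所有,转载或内容合作请联系作者 (© Copyright belongs to the author; contact the author for reprints or content collaboration.)
# 平台声明:文章内容(如有图片或视频亦包括在内)由作者上传并发布,文章内容仅代表作者本人观点,简书系信息发布平台,仅提供信息存储服务。 (Platform statement: the article content, including any images or videos, was uploaded and published by the author and represents only the author's own views; Jianshu is an information-publishing platform and provides only information-storage services.)

# 推荐阅读更多精彩内容 (Recommended further reading)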