import os
import time
import torch
from torch import nn
from torchvision import datasets, transforms
from models.networks.ByVGGnet import VGGNet
# Training hyper-parameters: mini-batch size for both DataLoaders and the
# Adam learning rate.
batch_size = 50
lr = 0.0001
# Directory for saved weight files (relative to the current working directory).
save_dir = '../save_weight_files/'
# exist_ok=True avoids the race between an exists() check and makedirs().
os.makedirs(save_dir, exist_ok=True)
def save_network(network, network_label, epoch_label):
    """Serialize ``network.state_dict()`` into the module-level ``save_dir``.

    The file is named ``<epoch_label>_net_<network_label>.pth``.

    Args:
        network: module whose ``state_dict()`` is written with ``torch.save``.
        network_label: name placed after ``_net_`` in the filename.
        epoch_label: filename prefix (the training loop passes the epoch).
    """
    filename = str(epoch_label) + '_net_' + str(network_label) + '.pth'
    target = os.path.join(save_dir, filename)
    torch.save(network.state_dict(), target)
# Preprocessing shared by the train and test sets: resize, convert to a
# [0, 1] tensor, then normalize each of the 3 channels to [-1, 1].
train_transforms = transforms.Compose([
    # Resize replaces transforms.Scale, which is deprecated and removed in
    # newer torchvision releases. With an int size the shorter edge becomes
    # 227 and the aspect ratio is preserved.
    # NOTE(review): default batching needs equal-sized images, so this
    # presumably relies on square source images — confirm, or use (227, 227).
    transforms.Resize(227),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])
train_dir = '../../data/IDADP-PRCV2019/Atraindataset'
test_dir = '../../data/IDADP-PRCV2019/Atestdataset'
# ImageFolder derives class labels from the sub-directory names of each root.
train_datasets = datasets.ImageFolder(train_dir, transform=train_transforms)
train_dataloader = torch.utils.data.DataLoader(
    train_datasets, batch_size=batch_size, shuffle=True)
test_datasets = datasets.ImageFolder(test_dir, transform=train_transforms)
# Evaluation does not benefit from shuffling; a fixed order keeps it reproducible.
test_dataloader = torch.utils.data.DataLoader(
    test_datasets, batch_size=batch_size, shuffle=False)
model = VGGNet().cuda()
model_dict = model.state_dict()
# Warm-start from an earlier checkpoint in save_dir; the filename matches the
# '<epoch>_net_<label>.pth' pattern used when weights are saved.
# map_location='cpu' lets the file load even if it was saved from a GPU index
# that does not exist here; load_state_dict then copies the tensors onto the
# model's CUDA parameters.
pretrained_path = os.path.join(save_dir, '9_net_myVgg16.pth')
dict_path = torch.load(pretrained_path, map_location='cpu')
# Drop the last classifier layer (classifier.6) so it keeps its fresh
# initialization — presumably because the class count differs from the
# checkpoint; confirm. Keys the current model does not have are dropped too.
excluded_keys = {'classifier.6.weight', 'classifier.6.bias'}
pretrained_dict = {k: v for k, v in dict_path.items()
                   if k not in excluded_keys and k in model_dict}
model_dict.update(pretrained_dict)
model.load_state_dict(model_dict)
lossF = nn.CrossEntropyLoss().cuda()
# Every parameter is fine-tuned, not only the classifier.
optimizer = torch.optim.Adam(model.parameters(), lr=lr, betas=(0.9, 0.999))
def Accuracy(net=None, loader=None, device='cuda'):
    """Return the top-1 accuracy, in percent, of a classifier over a loader.

    Args:
        net: model to evaluate; defaults to the module-level ``model``.
        loader: iterable of ``(images, labels)`` batches; defaults to the
            module-level ``test_dataloader``.
        device: device each batch is moved to before the forward pass.

    Returns:
        ``100.0 * correct / total`` as a float, or ``0.0`` if ``loader``
        yields no samples.

    The network is put in eval mode for the pass, so layers such as dropout
    or batch-norm (if present) run in inference mode. Its previous
    train/eval mode is restored afterwards, even if an exception occurs.
    """
    if net is None:
        net = model
    if loader is None:
        loader = test_dataloader
    was_training = net.training
    net.eval()
    correct = 0
    total = 0
    try:
        with torch.no_grad():
            for images, labels in loader:
                images, labels = images.to(device), labels.to(device)
                predicted = net(images).argmax(1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()
    finally:
        net.train(was_training)
    if total == 0:
        return 0.0
    return 100.0 * correct / total
# Main training loop: every 50 batches log the loss, the sampled training
# accuracy, and the full test-set accuracy. Weights are saved every 5 epochs.
num_epochs = 200
for epoch in range(num_epochs):
    # Re-enter training mode each epoch in case an evaluation pass left the
    # model in eval mode.
    model.train()
    total_accurcy = []  # training accuracies of the logged batches this epoch
    for i, (images, labels) in enumerate(train_dataloader):
        starttime = time.time()
        images, labels = images.cuda(), labels.cuda()
        optimizer.zero_grad()
        output = model(images)
        loss = lossF(output, labels)
        loss.backward()
        optimizer.step()
        if i % 50 == 0:
            pred_y = output.argmax(1)
            accuracy = (pred_y == labels).sum().item() / labels.size(0)
            total_accurcy.append(accuracy)
            avgAccurcy = sum(total_accurcy) / len(total_accurcy)
            # Time of this batch only, measured before the (slow) test pass.
            batch_time = time.time() - starttime
            print("Epoch:%d/%d Batch: %d/%d Time Taken:%d sec ----- loss:%f-- trian accuracy: %.4f--avg accuracy: %.4f--test accuracy: %.4f" % (
                epoch, num_epochs, i, len(train_dataloader), batch_time,
                loss.item(), accuracy, avgAccurcy, Accuracy()))
    # (epoch + 1) % 5 == 0 already implies epoch > 0.
    if (epoch + 1) % 5 == 0:
        save_network(model, 'planBVgg16', str(epoch))
2019-08-12
最后编辑于 :
©著作权归作者所有,转载或内容合作请联系作者
- 文/潘晓璐 我一进店门,熙熙楼的掌柜王于贵愁眉苦脸地迎上来,“玉大人,你说我怎么就摊上这事。” “怎么了?”我有些...
- 文/花漫 我一把揭开白布。 她就那样静静地躺着,像睡着了一般。 火红的嫁衣衬着肌肤如雪。 梳的纹丝不乱的头发上,一...
- 文/苍兰香墨 我猛地睁开眼,长吁一口气:“原来是场噩梦啊……” “哼!你这毒妇竟也来了?” 一声冷哼从身侧响起,我...
推荐阅读更多精彩内容
- 知识点 抽象类 abstract 所有的类都是用来描绘对象的,如果一个类中没有包含足够的信息来描绘一个具体的对象...