Image Segmentation: Training a U-Net Model on the VOC2007 Dataset (2)

- Code -
data.py

import os

from torch.utils.data import Dataset
from utils import *
from torchvision import transforms
transform=transforms.Compose([transforms.ToTensor()]) # convert to tensor and scale pixel values to [0, 1]

class MyDataset(Dataset):
    def __init__(self,path): # store the dataset root path
        self.path=path
        self.name=os.listdir(os.path.join(path,'SegmentationClass')) # list every label file name under SegmentationClass

    def __len__(self):
        return len(self.name) # number of samples

    def __getitem__(self, index):
        segment_name=self.name[index] # label file name, e.g. xxx.png
        segment_path=os.path.join(self.path,'SegmentationClass',segment_name)
        image_path=os.path.join(self.path,'JPEGImages',segment_name.replace('png','jpg')) # swap the extension to get the matching JPEG image
        segment_image=keep_image_size_open(segment_path)
        image=keep_image_size_open(image_path)
        return transform(image),transform(segment_image)

if __name__ =='__main__':
    data=MyDataset(r'E:\VOCdevkit\VOC2007')
    print(data[0][0].shape) # image tensor shape
    print(data[0][1].shape) # label tensor shape

utils.py

from PIL import Image

def keep_image_size_open(path, size=(256, 256)):
    img = Image.open(path)
    temp = max(img.size) # longest side of the image
    mask = Image.new('RGB', (temp, temp), (0, 0, 0)) # square black canvas
    mask.paste(img, (0, 0)) # paste the original image at the top-left corner
    mask = mask.resize(size)
    return mask
# pad to a square before resizing, so the aspect ratio is kept and the image is not distorted
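
A quick way to check the helper (a minimal sketch, not from the original post; the image path is only a placeholder): whatever the input aspect ratio, the output is always a 256x256 PIL image.

from utils import keep_image_size_open

img = keep_image_size_open(r'E:\VOCdevkit\VOC2007\JPEGImages\000032.jpg') # placeholder path, use any image on disk
print(img.size) # (256, 256): the image is padded to a square first, then resized, so nothing is stretched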

net.py

import torch
from torch import nn
from torch.nn import functional as F

class Conv_Block(nn.Module): # convolution block: two 3x3 convs, each followed by BatchNorm, Dropout2d and LeakyReLU
    def __init__(self,in_channel,out_channel): # input and output channel counts
        super(Conv_Block, self).__init__()
        self.layer=nn.Sequential(
            nn.Conv2d(in_channel,out_channel,3,1,1,padding_mode='reflect',bias=False), # 3x3 conv, stride 1, padding 1 (reflect)
            nn.BatchNorm2d(out_channel),
            nn.Dropout2d(0.3),
            nn.LeakyReLU(),
            nn.Conv2d(out_channel,out_channel,3,1,1,padding_mode='reflect',bias=False),
            nn.BatchNorm2d(out_channel),
            nn.Dropout2d(0.3),
            nn.LeakyReLU()
        )
    def forward(self,x):
        return self.layer(x)

class DownSample(nn.Module): # downsampling with a stride-2 convolution
    def __init__(self,channel):
        super(DownSample,self).__init__()
        self.layer=nn.Sequential(
            nn.Conv2d(channel,channel,3,2,1, padding_mode='reflect',bias=False),
            nn.BatchNorm2d(channel),
            nn.LeakyReLU()
        )
    def forward(self,x):
        return self.layer(x)

class UpSample(nn.Module):
    def __init__(self,channel):
        super(UpSample,self).__init__()
        self.layer=nn.Conv2d(channel,channel//2,1,1) # 1x1 conv to halve the channel count
    def forward(self,x,feature_map):
        up=F.interpolate(x,scale_factor=2,mode='nearest') # nearest-neighbor interpolation doubles the spatial size
        out=self.layer(up)
        return torch.cat((out,feature_map),dim=1)

class UNet(nn.Module): # the full U-Net: encoder, bottleneck, and decoder with skip connections
    def __init__(self):
        super(UNet,self).__init__()
        self.c1=Conv_Block(3,64)
        self.d1=DownSample(64)
        self.c2=Conv_Block(64,128)
        self.d2=DownSample(128)
        self.c3=Conv_Block(128,256)
        self.d3=DownSample(256)
        self.c4=Conv_Block(256,512)
        self.d4=DownSample(512)
        self.c5=Conv_Block(512,1024)
        self.u1=UpSample(1024)
        self.c6=Conv_Block(1024,512)
        self.u2=UpSample(512)
        self.c7=Conv_Block(512,256)
        self.u3=UpSample(256)
        self.c8=Conv_Block(256,128)
        self.u4=UpSample(128)
        self.c9=Conv_Block(128,64)
        self.out=nn.Conv2d(64,3,3,1,1) # output layer: a 3-channel map at the input resolution
        self.Th=nn.Sigmoid() # sigmoid squashes every output pixel into [0, 1]

    def forward(self,x):
        R1=self.c1(x)
        R2=self.c2(self.d1(R1))
        R3=self.c3(self.d2(R2))
        R4=self.c4(self.d3(R3))
        R5=self.c5(self.d4(R4))
        O1=self.c6(self.u1(R5,R4))
        O2=self.c7(self.u2(O1,R3))
        O3=self.c8(self.u3(O2,R2))
        O4=self.c9(self.u4(O3,R1))

        return self.Th(self.out(O4))

if __name__ =='__main__':
    x=torch.randn(2,3,256,256)
    net=UNet()
    print(net(x).shape)
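
As an extra sanity check (a short sketch added here, not part of the original post), the trainable parameter count can also be printed to confirm the encoder and decoder were wired as intended:

from net import UNet

net = UNet()
total = sum(p.numel() for p in net.parameters() if p.requires_grad) # count trainable parameters
print(f'trainable parameters: {total}')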

train.py

from torch import nn,optim # layers and optimizers
import torch
from torch.utils.data import DataLoader
from data import *
from net import *
from torchvision.utils import save_image

device=torch.device('cuda' if torch.cuda.is_available() else 'cpu')
weight_path='params/unet.pth'
data_path=r'E:\VOCdevkit\VOC2007'
save_path='train_image'

if __name__ =='__main__':
    data_loader=DataLoader(MyDataset(data_path),batch_size=2,shuffle=True) # batch size can be raised or lowered to fit memory
    net=UNet().to(device)
    if os.path.exists(weight_path):
        net.load_state_dict(torch.load(weight_path)) # resume from previously saved weights
        print('successfully loaded weights')
    else:
        print('no saved weights found, training from scratch')

    opt=optim.Adam(net.parameters())
    loss_fun=nn.BCELoss() # binary cross-entropy on the sigmoid output

    epoch=1
    while True: # train indefinitely; stop manually once the loss is low enough
        for i,(image,segment_image) in enumerate(data_loader):
            image,segment_image=image.to(device),segment_image.to(device)

            out_image=net(image)
            train_loss=loss_fun(out_image,segment_image)

            opt.zero_grad() # clear accumulated gradients
            train_loss.backward()
            opt.step() # update the parameters

            if i%5==0: # print the loss every 5 batches
                print(f'{epoch}-{i}-train_loss===>>{train_loss.item()}')

            if i%50==0: # save the weights every 50 batches
                torch.save(net.state_dict(),weight_path)

            _image=image[0]
            _segment_image=segment_image[0]
            _out_image=out_image[0]

            img=torch.stack([_image,_segment_image,_out_image],dim=0) # input, label and prediction side by side
            save_image(img,f'{save_path}/{i}.png')

        epoch+=1 # advance the epoch counter after each full pass over the dataset
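
Note that torch.save and save_image raise FileNotFoundError if the params/ and train_image/ directories do not exist yet. A small guard near the top of the script (a sketch reusing the paths defined above) avoids that:

import os

os.makedirs('params', exist_ok=True) # directory that holds unet.pth
os.makedirs(save_path, exist_ok=True) # directory for the per-batch visualizations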

test.py

import os

import torch

from net import *
from utils import keep_image_size_open # utils.py lives in the same directory as this script
from data import *
from torchvision.utils import save_image

net=UNet()

weights='params/unet.pth'
if os.path.exists(weights):
    net.load_state_dict(torch.load(weights))
    print('successfully loaded weights')
else:
    print('no saved weights found')

_input=input('please input image path:')

img=keep_image_size_open(_input)
img_data=transform(img)
print(img_data.shape)
img_data=torch.unsqueeze(img_data,dim=0) # add a batch dimension
net.eval() # disable dropout and use running BatchNorm statistics at inference time
with torch.no_grad():
    out=net(img_data)
save_image(out,'result/result.jpg')
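
As in train.py, the output directory must exist before save_image runs; a one-line guard placed earlier in the script (a sketch, using the same result/ path) prevents a FileNotFoundError on the first run:

os.makedirs('result', exist_ok=True) # create the output directory if it is missing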