[代码解读]U-Net++ Pytorch

代码:
https://github.com/4uiiurz1/pytorch-nested-unet
文件包括:utils.py,preprocess_dsb2018.py,dataset.py,train.py, archs.py,losses.py,metrics.py,test.py

  • utils.py

def str2bool: => tru = 1; false = 0
def count_params(): =>计算训练参数量

  • preprocess_dsb2018.py

数据预处理
创建好image,mask所在目录,规范image channel = 3,resize image & mask,将千奇百态的图片按上述处理后再重命名保存。

  • dataset.py

将image,mask [0,255]的范围归一化为[0,1]。
如果进行数据增强,则将部分图片左右翻转或上下翻转,并将(h,w,c) =>(c,h,w)

  • train.py

用训练集训练模型,并在验证集的约束下优化模型并保存。

  • test.py

加载最优模型对测试集预测,保存分割的mask,并用相关指标评估。


preprocess_dsb2018.py

  • glob.glob()
    返回所有匹配的文件路径列表(list);该方法需要一个参数用来指定匹配的路径字符串(字符串可以为绝对路径也可以为相对路径),其返回的文件名只包括当前目录里的文件名,不包括子文件夹里的文件。

  • tqdm
    Tqdm 是 Python 进度条库,可以在 Python 长循环中添加一个进度提示信息用法:tqdm(iterator)

  • skimage.io.imread
    io.imread读出图片格式是uint8(unsigned int);value是numpy array;图像数据是以RGB的格式进行存储的,通道值默认范围0-255。(height,width, channel)
    skimage图片信息
from skimage import io, data
img = data.chelsea()
io.imshow(img)
print(type(img))  #显示类型
print(img.shape)  #显示尺寸
print(img.shape[0])  #图片高度
print(img.shape[1])  #图片宽度
print(img.shape[2])  #图片通道数
print(img.size)   #显示总像素个数
print(img.max())  #最大像素值
print(img.min())  #最小像素值
print(img.mean()) #像素平均值
print(img[0][0])  #图像的像素值

image.shape[0], 图片垂直尺寸
image.shape[1], 图片水平尺寸
image.shape[2], 图片通道数

  • len()
    返回字符串、列表、字典、元组等长度
    if (len(image.shape) == 2):
    如果图片只有height,width,没有channel

  • np.tile()
    函数形式: tile(A,rep)
    功能:重复A的各个维度
    参数类型:
    A: Array类的都可以
    rep:A沿着各个维度重复的次数

  • skimage.io.imsave
    使用io模块的imsave(fname,arr)函数来实现。第一个参数表示保存的路径和名称,第二个参数表示需要保存的数组变量。
    保存图片的同时也起到了转换格式的作用。如果读取时图片格式为jpg图片,保存为png格式,则将图片从jpg图片转换为png图片并保存。

  • os.path.basename()
    返回path最后的文件名。若path以/或\结尾,那么就会返回空值。

path='D:\CSDN'
os.path.basename(path)=CSDN
path='/root/runoob.txt'
os.path.basename(path)=runoob.txt
  • Numpy的布尔索引与花式索引
for mask_path in glob(path+'/masks/*'):
     mask_ = imread(mask_path) > 127
     mask[mask_] = 1
  • imread(mask_path) 是numpy.ndarray,(h,w,c),范围是[0,255],0为黑色,255为白色,
  • mask_ 数组中 > 127(白色)的元素记为ture,否则记为false.

  • cv2.resize()
    cv2.resize(src,dsize,dst=None,fx=None,fy=None,interpolation=None)
    • scr:原图
    • dsize:输出图像尺寸
    • fx:沿水平轴的比例因子
    • fy:沿垂直轴的比例因子
    • interpolation:插值方法
      只会resize原图像的水平方向尺寸和垂直方向尺寸,不会对channel有影响。
      详见:
      cv2.resize()

train.py

  • .__ dict __
    通俗的理解:每个参数,变量,对象都是以字典的形式存储,每一个key对应一个value
    __ dict __.png

def main()

  • argparse.ArgumentParser()

    • choices - 设置参数值的范围,如果choices中的类型不是字符串,要指定type。#parser.add_argument(“-y”, choices=[‘a’, ‘b’, ‘d’])
    • metavar - 参数的名字,在显示 帮助信息时才用到. # parser.add_argument(“-o”, metavar=”OOOOOO”)
      更多见--python3中argparse模块详解
  • vars()
    返回对象object的属性和属性值的字典对象。

  • getattr()
    getattr():从名字上看获取属性值.

class Person():
    age = 14
Tom = Person()
print(getattr(Tom,'age'))

此时的结果为14,
若,该属性不存在

getattr(Tom,'name')
AttributeError: 'Person' object has no attribute 'name'
  • train_test_split()
    • 所在包:sklearn.model_selection
    • 功能:划分数据的训练集与测试集
    • 参数解读:train_test_split (*arrays,test_size, train_size, rondom_state=None, shuffle=True, stratify=None)
    • arrays:特征数据和标签数据(array,list,dataframe等类型),要求所有数据长度相同。
    • test_size / train_size: 测试集/训练集的大小,若输入小数表示比例,若输入整数表示数据个数。
    • rondom_state:随机种子(一个整数),其实就是一个划分标记,对于同一个数据集,如果rondom_state相同,则划分结果也相同。
    • shuffle:是否打乱数据的顺序,再划分,默认True。
    • stratify:none或者array/series类型的数据,表示按这列进行分层采样。
xtrain,xtest,ytrain,ytest=train_test_split(data,label,test_size=0.2,stratify=data['a'],random_state=1)
  • pytorch固定部分参数进行网络训练
if args.optimizer == 'Adam':
    optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=args.lr)
elif args.optimizer == 'SGD':
    optimizer = optim.SGD(filter(lambda p: p.requires_grad, model.parameters()), lr=args.lr,
                momentum=args.momentum, weight_decay=args.weight_decay, nesterov=args.nesterov)

详见:
pytorch固定部分参数进行网络训练
pytorch固定部分参数进行网络训练

class AverageMeter

计算并存储平均值和当前值

class AverageMeter(object):
    """Computes and stores the average and current value"""
    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

def train

  • .cuda()
input = input.cuda()

x.cuda()操作将 x 变为Tensor类型


cuda = >Tensor
  • python 默认参数
def enroll(name, gender, age=6, city='Beijing'):
    print('name:', name)
    print('gender:', gender)
    print('age:', age)
    print('city:', city)

只有与默认参数不符的学生才需要提供额外的信息:
enroll('Bob', 'M', 7)
enroll('Adam', 'M', city='Tianjin')

  • losses = AverageMeter()
    实例化AverageMeter(),调用init初始化函数

  • losses.update(loss.item(), input.size(0))
    input.size()覆盖默认参数n=1
    input经过.cuda后为Tensor类型,input [bacth_size, c, w, h]


dataset.py

from skimage.io import imread

image = imread(img_path)
mask = imread(mask_path)

image = image.astype('float32') / 255
mask = mask.astype('float32') / 255

if self.aug:
    if random.uniform(0, 1) > 0:
        image = image[:, ::-1, :].copy()
        mask = mask[:, ::-1].copy()
    if random.uniform(0, 1) > 0.5:
        image = image[::-1, :, :].copy()
        mask = mask[::-1, :].copy()

image = image.transpose((2, 0, 1))
mask = mask[:,:,np.newaxis]
mask = mask.transpose((2, 0, 1))
  • image = imread(img_path)

    • imread的读取的image结果为numpy.ndarray, type为('uint8'),image里都是整数uint8,范围[0-255]
    • 假设image为(300, 400, 3) (h,w,c),channel顺序为RGB
      array([
      [ [143, 198, 201 (dim=3)],[143, 198, 201],... (w=200)],
      [ [143, 198, 201],[143, 198, 201],... ], ...(h=100) ], dtype=uint8)
  • image = image.astype('float32') / 255
    将image的type改为float32,并把数据范围缩小到[0,1]

  • ** image = image[:, ::-1, :].copy()**
    image = image[:, ::-1, :] 表示将图像向右翻转180°
    image = image[::-1,: , :]表示将图像向下翻转180°

  • image = image.transpose((2, 0, 1))
    image (h,w,c) transpose为(c,w,h)


utils.py

def str2bool(v):
    if v.lower() in ['true', 1]:
        return True
    elif v.lower() in ['false', 0]:
        return False
    else:
        raise argparse.ArgumentTypeError('Boolean value expected.')

def count_params(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)
  • count_params()
    计算模型训练参数量

  • pytorch 获取模型参数量

# Find total parameters and trainable parameters
total_params = sum(p.numel() for p in model.parameters())
print(f'{total_params:,} total parameters.')

total_trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f'{total_trainable_params:,} training parameters.')

arch.py

有Unet和Unet++


losses.py

一些计算losses的函数


metrics.py

评估函数


参考链接:
image.shap--yournevermore
skimage.io.imread与cv2.imread的区别
python 处理图像的常见操作-- 平移缩放裁剪
skimage图像处理
skimage.io.imread API
numpy模块的tile()方法简单说明
Python os.path() 模块
os.path.basename()作用
python之getattr()函数
train_test_split数据集分割
OpenCV、Skimage、PIL图像处理的细节差异
Numpy的布尔索引与花式索引
image = image[:, ::-1, :]的含义是什么
python 矩阵转置transpose
pytorch 获取模型参数量

最后编辑于
©著作权归作者所有,转载或内容合作请联系作者
  • 序言:七十年代末,一起剥皮案震惊了整个滨河市,随后出现的几起案子,更是在滨河造成了极大的恐慌,老刑警刘岩,带你破解...
    沈念sama阅读 215,076评论 6 497
  • 序言:滨河连续发生了三起死亡事件,死亡现场离奇诡异,居然都是意外死亡,警方通过查阅死者的电脑和手机,发现死者居然都...
    沈念sama阅读 91,658评论 3 389
  • 文/潘晓璐 我一进店门,熙熙楼的掌柜王于贵愁眉苦脸地迎上来,“玉大人,你说我怎么就摊上这事。” “怎么了?”我有些...
    开封第一讲书人阅读 160,732评论 0 350
  • 文/不坏的土叔 我叫张陵,是天一观的道长。 经常有香客问我,道长,这世上最难降的妖魔是什么? 我笑而不...
    开封第一讲书人阅读 57,493评论 1 288
  • 正文 为了忘掉前任,我火速办了婚礼,结果婚礼上,老公的妹妹穿的比我还像新娘。我一直安慰自己,他们只是感情好,可当我...
    茶点故事阅读 66,591评论 6 386
  • 文/花漫 我一把揭开白布。 她就那样静静地躺着,像睡着了一般。 火红的嫁衣衬着肌肤如雪。 梳的纹丝不乱的头发上,一...
    开封第一讲书人阅读 50,598评论 1 293
  • 那天,我揣着相机与录音,去河边找鬼。 笑死,一个胖子当着我的面吹牛,可吹牛的内容都是我干的。 我是一名探鬼主播,决...
    沈念sama阅读 39,601评论 3 415
  • 文/苍兰香墨 我猛地睁开眼,长吁一口气:“原来是场噩梦啊……” “哼!你这毒妇竟也来了?” 一声冷哼从身侧响起,我...
    开封第一讲书人阅读 38,348评论 0 270
  • 序言:老挝万荣一对情侣失踪,失踪者是张志新(化名)和其女友刘颖,没想到半个月后,有当地人在树林里发现了一具尸体,经...
    沈念sama阅读 44,797评论 1 307
  • 正文 独居荒郊野岭守林人离奇死亡,尸身上长有42处带血的脓包…… 初始之章·张勋 以下内容为张勋视角 年9月15日...
    茶点故事阅读 37,114评论 2 330
  • 正文 我和宋清朗相恋三年,在试婚纱的时候发现自己被绿了。 大学时的朋友给我发了我未婚夫和他白月光在一起吃饭的照片。...
    茶点故事阅读 39,278评论 1 344
  • 序言:一个原本活蹦乱跳的男人离奇死亡,死状恐怖,灵堂内的尸体忽然破棺而出,到底是诈尸还是另有隐情,我是刑警宁泽,带...
    沈念sama阅读 34,953评论 5 339
  • 正文 年R本政府宣布,位于F岛的核电站,受9级特大地震影响,放射性物质发生泄漏。R本人自食恶果不足惜,却给世界环境...
    茶点故事阅读 40,585评论 3 322
  • 文/蒙蒙 一、第九天 我趴在偏房一处隐蔽的房顶上张望。 院中可真热闹,春花似锦、人声如沸。这庄子的主人今日做“春日...
    开封第一讲书人阅读 31,202评论 0 21
  • 文/苍兰香墨 我抬头看了看天上的太阳。三九已至,却和暖如春,着一层夹袄步出监牢的瞬间,已是汗流浃背。 一阵脚步声响...
    开封第一讲书人阅读 32,442评论 1 268
  • 我被黑心中介骗来泰国打工, 没想到刚下飞机就差点儿被人妖公主榨干…… 1. 我叫王不留,地道东北人。 一个月前我还...
    沈念sama阅读 47,180评论 2 367
  • 正文 我出身青楼,却偏偏与公主长得像,于是被迫代替她去往敌国和亲。 传闻我的和亲对象是个残疾皇子,可洞房花烛夜当晚...
    茶点故事阅读 44,139评论 2 352