代码:
https://github.com/4uiiurz1/pytorch-nested-unet
文件包括:utils.py,preprocess_dsb2018.py,dataset.py,train.py, archs.py,losses.py,metrics.py,test.py
-
utils.py
def str2bool: => tru = 1; false = 0
def count_params(): =>计算训练参数量
-
preprocess_dsb2018.py
数据预处理
创建好image,mask所在目录,规范image channel = 3,resize image & mask,将千奇百态的图片按上述处理后再重命名保存。
-
dataset.py
将image,mask [0,255]的范围归一化为[0,1]。
如果进行数据增强,则将部分图片左右翻转或上下翻转,并将(h,w,c) =>(c,h,w)
-
train.py
用训练集训练模型,并在验证集的约束下优化模型并保存。
-
test.py
加载最优模型对测试集预测,保存分割的mask,并用相关指标评估。
preprocess_dsb2018.py
glob.glob()
返回所有匹配的文件路径列表(list);该方法需要一个参数用来指定匹配的路径字符串(字符串可以为绝对路径也可以为相对路径),其返回的文件名只包括当前目录里的文件名,不包括子文件夹里的文件。tqdm
Tqdm 是 Python 进度条库,可以在 Python 长循环中添加一个进度提示信息用法:tqdm(iterator)
-
skimage.io.imread
io.imread读出图片格式是uint8(unsigned int);value是numpy array;图像数据是以RGB的格式进行存储的,通道值默认范围0-255。(height,width, channel)
skimage图片信息
from skimage import io, data
img = data.chelsea()
io.imshow(img)
print(type(img)) #显示类型
print(img.shape) #显示尺寸
print(img.shape[0]) #图片高度
print(img.shape[1]) #图片宽度
print(img.shape[2]) #图片通道数
print(img.size) #显示总像素个数
print(img.max()) #最大像素值
print(img.min()) #最小像素值
print(img.mean()) #像素平均值
print(img[0][0]) #图像的像素值
image.shape[0], 图片垂直尺寸
image.shape[1], 图片水平尺寸
image.shape[2], 图片通道数
len()
返回字符串、列表、字典、元组等长度
if (len(image.shape) == 2):
如果图片只有height,width,没有channelnp.tile()
函数形式: tile(A,rep)
功能:重复A的各个维度
参数类型:
A: Array类的都可以
rep:A沿着各个维度重复的次数skimage.io.imsave
使用io模块的imsave(fname,arr)函数来实现。第一个参数表示保存的路径和名称,第二个参数表示需要保存的数组变量。
保存图片的同时也起到了转换格式的作用。如果读取时图片格式为jpg图片,保存为png格式,则将图片从jpg图片转换为png图片并保存。os.path.basename()
返回path最后的文件名。若path以/或\结尾,那么就会返回空值。
path='D:\CSDN'
os.path.basename(path)=CSDN
path='/root/runoob.txt'
os.path.basename(path)=runoob.txt
- Numpy的布尔索引与花式索引
for mask_path in glob(path+'/masks/*'):
mask_ = imread(mask_path) > 127
mask[mask_] = 1
- imread(mask_path) 是numpy.ndarray,(h,w,c),范围是[0,255],0为黑色,255为白色,
- mask_ 数组中 > 127(白色)的元素记为ture,否则记为false.
- mask[mask_] = 1将mask_ 元素为1的的地方赋值为1。
详见:
Numpy的布尔索引与花式索引
-
cv2.resize()
cv2.resize(src,dsize,dst=None,fx=None,fy=None,interpolation=None)- scr:原图
- dsize:输出图像尺寸
- fx:沿水平轴的比例因子
- fy:沿垂直轴的比例因子
- interpolation:插值方法
只会resize原图像的水平方向尺寸和垂直方向尺寸,不会对channel有影响。
详见:
cv2.resize()
train.py
-
.__ dict __
通俗的理解:每个参数,变量,对象都是以字典的形式存储,每一个key对应一个value
def main()
-
argparse.ArgumentParser()
- choices - 设置参数值的范围,如果choices中的类型不是字符串,要指定type。#parser.add_argument(“-y”, choices=[‘a’, ‘b’, ‘d’])
- metavar - 参数的名字,在显示 帮助信息时才用到. # parser.add_argument(“-o”, metavar=”OOOOOO”)
更多见--python3中argparse模块详解
vars()
返回对象object的属性和属性值的字典对象。getattr()
getattr():从名字上看获取属性值.
class Person():
age = 14
Tom = Person()
print(getattr(Tom,'age'))
此时的结果为14,
若,该属性不存在
getattr(Tom,'name')
AttributeError: 'Person' object has no attribute 'name'
-
train_test_split()
- 所在包:sklearn.model_selection
- 功能:划分数据的训练集与测试集
- 参数解读:train_test_split (*arrays,test_size, train_size, rondom_state=None, shuffle=True, stratify=None)
- arrays:特征数据和标签数据(array,list,dataframe等类型),要求所有数据长度相同。
- test_size / train_size: 测试集/训练集的大小,若输入小数表示比例,若输入整数表示数据个数。
- rondom_state:随机种子(一个整数),其实就是一个划分标记,对于同一个数据集,如果rondom_state相同,则划分结果也相同。
- shuffle:是否打乱数据的顺序,再划分,默认True。
- stratify:none或者array/series类型的数据,表示按这列进行分层采样。
xtrain,xtest,ytrain,ytest=train_test_split(data,label,test_size=0.2,stratify=data['a'],random_state=1)
- pytorch固定部分参数进行网络训练
if args.optimizer == 'Adam':
optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=args.lr)
elif args.optimizer == 'SGD':
optimizer = optim.SGD(filter(lambda p: p.requires_grad, model.parameters()), lr=args.lr,
momentum=args.momentum, weight_decay=args.weight_decay, nesterov=args.nesterov)
详见:
pytorch固定部分参数进行网络训练
pytorch固定部分参数进行网络训练
class AverageMeter
计算并存储平均值和当前值
class AverageMeter(object):
"""Computes and stores the average and current value"""
def __init__(self):
self.reset()
def reset(self):
self.val = 0
self.avg = 0
self.sum = 0
self.count = 0
def update(self, val, n=1):
self.val = val
self.sum += val * n
self.count += n
self.avg = self.sum / self.count
def train
- .cuda()
input = input.cuda()
x.cuda()操作将 x 变为Tensor类型
- python 默认参数
def enroll(name, gender, age=6, city='Beijing'):
print('name:', name)
print('gender:', gender)
print('age:', age)
print('city:', city)
只有与默认参数不符的学生才需要提供额外的信息:
enroll('Bob', 'M', 7)
enroll('Adam', 'M', city='Tianjin')
losses = AverageMeter()
实例化AverageMeter(),调用init初始化函数losses.update(loss.item(), input.size(0))
input.size()覆盖默认参数n=1
input经过.cuda后为Tensor类型,input [bacth_size, c, w, h]
dataset.py
from skimage.io import imread
image = imread(img_path)
mask = imread(mask_path)
image = image.astype('float32') / 255
mask = mask.astype('float32') / 255
if self.aug:
if random.uniform(0, 1) > 0:
image = image[:, ::-1, :].copy()
mask = mask[:, ::-1].copy()
if random.uniform(0, 1) > 0.5:
image = image[::-1, :, :].copy()
mask = mask[::-1, :].copy()
image = image.transpose((2, 0, 1))
mask = mask[:,:,np.newaxis]
mask = mask.transpose((2, 0, 1))
-
image = imread(img_path)
- imread的读取的image结果为numpy.ndarray, type为('uint8'),image里都是整数uint8,范围[0-255]
- 假设image为(300, 400, 3) (h,w,c),channel顺序为RGB
array([
[ [143, 198, 201 (dim=3)],[143, 198, 201],... (w=200)],
[ [143, 198, 201],[143, 198, 201],... ], ...(h=100) ], dtype=uint8)
image = image.astype('float32') / 255
将image的type改为float32,并把数据范围缩小到[0,1]** image = image[:, ::-1, :].copy()**
image = image[:, ::-1, :] 表示将图像向右翻转180°
image = image[::-1,: , :]表示将图像向下翻转180°image = image.transpose((2, 0, 1))
image (h,w,c) transpose为(c,w,h)
utils.py
def str2bool(v):
if v.lower() in ['true', 1]:
return True
elif v.lower() in ['false', 0]:
return False
else:
raise argparse.ArgumentTypeError('Boolean value expected.')
def count_params(model):
return sum(p.numel() for p in model.parameters() if p.requires_grad)
count_params()
计算模型训练参数量pytorch 获取模型参数量
# Find total parameters and trainable parameters
total_params = sum(p.numel() for p in model.parameters())
print(f'{total_params:,} total parameters.')
total_trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f'{total_trainable_params:,} training parameters.')
arch.py
有Unet和Unet++
losses.py
一些计算losses的函数
metrics.py
评估函数
参考链接:
image.shap--yournevermore
skimage.io.imread与cv2.imread的区别
python 处理图像的常见操作-- 平移缩放裁剪
skimage图像处理
skimage.io.imread API
numpy模块的tile()方法简单说明
Python os.path() 模块
os.path.basename()作用
python之getattr()函数
train_test_split数据集分割
OpenCV、Skimage、PIL图像处理的细节差异
Numpy的布尔索引与花式索引
image = image[:, ::-1, :]的含义是什么
python 矩阵转置transpose
pytorch 获取模型参数量