复现:
1.从GitHub下载源码以后,按照README 配置U-net的环境。
这里给出GitHub源码链接:pytorch Unet

配置环境
requirements.txt文件在源码文件里
2.配置完成后进行训练,但是我们没有数据集
解决方法:
用来测试汽车数据集
提取码:upb3
自己制作数据集的需要注意,放在data里的imgs和masks,注意image文件命名如果是1.+后缀,对应mask文件就需要是1._mask+后缀。(网盘给出的数据集已经满足要求,直接放在对应imgs和masks目录下即可)
3.在train.py文件中修改
修改channel,如果是RGB图像,channel=3,如果是灰度图,channel=1;
修改classes,就是背景+你的数据集里有几个类别;比如我给的那个数据集有汽车和背景两个类,那么classes=2

修改参数
4.开始训练
python train.py
可以自己调整训练参数

训练图
训练完成,模型会保存在checkpoints路径下
5.进行预测
在predict.py文件,修改参数,注意路径问题,下图所示default参数里 前面还要加上/checkpoints

修改路径
将需要预测的图片放到主目录里,
-i 是指定预测的照片 其他参数可以自己看get_args部分
python predict.py -i 1.jpg -o output.jpg

预测效果