本代码由ipynb文件转换为py文件,所以有的地方有改动
- 首先导入各种包
from __future__ import division, print_function
# 导入 print_function 是为了兼容 python2：这样即使在 python2 中，print 也必须按 python3 的函数形式调用
# division 启用真除法（true division），如 3/4=0.75；这在 python3 中是默认行为
# get_ipython().magic('matplotlib inline')
import matplotlib.pyplot as plt
import matplotlib
import numpy as np
from tf_unet import image_gen
from tf_unet import unet
from tf_unet import util
- 画图的设置以及设置随机数种子
# Default colormap for image plots produced later in the notebook.
plt.rcParams.update({'image.cmap': 'gist_earth'})
# Seed NumPy's global RNG so the randomly generated images are reproducible.
np.random.seed(seed=98765)
- 设置图片尺寸以及建立生成随机数据集的类的一个实例
# Both image axes are 572 pixels long.
nx = ny = 572
# Synthetic grayscale data provider; cnt=20 is forwarded to the image generator.
generator = image_gen.GrayScaleDataProvider(nx, ny, cnt=20)
# Calling the provider draws a batch of one (image, label) pair.
x_test, y_test = generator(1)
展开
3.1 查看GrayScaleDataProvider这个类
class GrayScaleDataProvider(BaseDataProvider):
    """Data provider that draws synthetic grayscale images from create_image_and_label.

    :param nx: size of the first image axis passed to create_image_and_label
    :param ny: size of the second image axis passed to create_image_and_label
    :param kwargs: stored and forwarded unchanged to create_image_and_label on
        every draw; ``rectangles=True`` switches this provider to three classes.
    """

    channels = 1
    n_class = 2

    def __init__(self, nx, ny, **kwargs):
        # The base class is built with its default (unbounded) clipping limits.
        super(GrayScaleDataProvider, self).__init__()
        self.nx, self.ny = nx, ny
        self.kwargs = kwargs
        # With squares enabled there are three classes (background/circle/square);
        # only then is the class-level default of two overridden on the instance.
        if kwargs.get("rectangles", False):
            self.n_class = 3

    def _next_data(self):
        # One fresh random (image, label) pair per call.
        options = self.kwargs
        return create_image_and_label(self.nx, self.ny, **options)
3.1.1 查看父类 BaseDataProvider
class BaseDataProvider(object):
    """
    Abstract base class for DataProvider implementation. Subclasses have to
    override the `_next_data` method that loads the next data and label array.
    This implementation automatically takes the absolute value of the data,
    clips it with the given min/max and normalizes the values to [0, 1]. To
    change this behavior the `_process_data` method can be overridden. To enable
    some post processing such as data augmentation the `_post_process` method
    can be overridden.

    :param a_min: (optional) min value used for clipping
    :param a_max: (optional) max value used for clipping
    """

    channels = 1
    n_class = 2

    def __init__(self, a_min=None, a_max=None):
        # A missing bound falls back to -inf / +inf, i.e. no clipping on that side.
        self.a_min = a_min if a_min is not None else -np.inf
        # Test a_max itself (not a_min) so an explicit a_max is honoured even
        # when a_min is omitted.
        self.a_max = a_max if a_max is not None else np.inf

    def _load_data_and_label(self):
        """Load one sample and return it with a leading batch axis of length 1.

        :returns: (data, labels) with shapes (1, ny, nx, channels) and
            (1, ny, nx, n_class), where ny/nx are the first two axes of the
            processed data.
        """
        data, label = self._next_data()

        train_data = self._process_data(data)
        labels = self._process_labels(label)

        train_data, labels = self._post_process(train_data, labels)

        ny = train_data.shape[0]
        nx = train_data.shape[1]

        return (train_data.reshape(1, ny, nx, self.channels),
                labels.reshape(1, ny, nx, self.n_class))

    def _process_labels(self, label):
        """Turn a 2-class mask into two label planes; return other labels unchanged.

        For n_class == 2, channel 1 holds the mask itself and channel 0 its
        complement, both as float32.
        """
        if self.n_class == 2:
            ny = label.shape[0]
            nx = label.shape[1]
            labels = np.zeros((ny, nx, self.n_class), dtype=np.float32)
            labels[..., 1] = label
            # logical_not equals ~ for bool masks, but unlike ~ it also yields the
            # correct complement for 0/1 integer or float masks.
            labels[..., 0] = np.logical_not(label)
            return labels

        return label

    def _process_data(self, data):
        """Normalize data: absolute value, clip to [a_min, a_max], rescale to [0, 1].

        A constant array is left at all zeros instead of being divided by zero.
        """
        # np.clip returns a new array, so the in-place ops below do not touch `data`.
        data = np.clip(np.fabs(data), self.a_min, self.a_max)
        data -= np.amin(data)

        if np.amax(data) != 0:
            data /= np.amax(data)

        return data

    def _post_process(self, data, labels):
        """
        Post processing hook that can be used for data augmentation

        :param data: the data array
        :param labels: the label array
        """
        return data, labels

    def __call__(self, n):
        """Draw n samples and stack them into batch arrays.

        :param n: number of samples; must be >= 1
        :returns: X of shape (n, ny, nx, channels) and Y of shape
            (n, ny, nx, n_class)
        """
        # The first sample (each already carrying a batch axis of 1) fixes the
        # spatial size used to allocate the batch arrays.
        train_data, labels = self._load_data_and_label()
        ny = train_data.shape[1]
        nx = train_data.shape[2]

        X = np.zeros((n, ny, nx, self.channels))
        Y = np.zeros((n, ny, nx, self.n_class))

        X[0] = train_data
        Y[0] = labels
        for i in range(1, n):
            # Load the next image and its labels.
            train_data, labels = self._load_data_and_label()
            X[i] = train_data
            Y[i] = labels

        # X and Y hold all n samples (indices 0 .. n-1).
        return X, Y
其中 `__call__()` 方法使实例对象可以像函数一样被调用（例如上面的 `generator(1)`）
聚焦该方法,第一行调用了_load_data_and_label()
def _load_data_and_label(self):
data, label = self._next_data() #生成单张图片以及对应的标签
train_data = self._process_data(data) # 将图片每个像素得灰度值归一化
labels = self._process_labels(label) # 将布尔类型的lable放在lablels的第一个通道,将~lable放在labels的第0个通道
train_data, labels = self._post_process(train_data, labels) #pass 这个钩子函数在这个模块里还没编辑,但是其它模块里面编辑了哦
nx = train_data.shape[1]
ny = train_data.shape[0]
return train_data.reshape(1, ny, nx, self.channels), labels.reshape(1, ny, nx, self.n_class),
# 输出 增加一个维度的单张图片(tran_data),和增加一个维度的lebels
这里又调用了 `_next_data()` 方法。父类 BaseDataProvider 中并没有实现该方法（其文档说明要求子类覆盖它），
它由子类 GrayScaleDataProvider 实现：
def _next_data(self):
    """Produce one random (image, label) pair via create_image_and_label.

    The image size comes from ``self.nx`` / ``self.ny``; every keyword stored in
    ``self.kwargs`` (e.g. ``cnt``, ``rectangles``) is forwarded unchanged.
    """
    size_x, size_y, options = self.nx, self.ny, self.kwargs
    return create_image_and_label(size_x, size_y, **options)
查看 create_image_and_label() 函数
该函数生成单张合成图像及其对应的标签
def create_image_and_label(nx, ny, cnt=10, r_min=5, r_max=50, border=92, sigma=20, rectangles=False):
    """Generate one synthetic grayscale image and its segmentation label.

    Draws ``cnt`` filled circles (and, if ``rectangles`` is set, ``cnt // 2``
    filled squares) with random gray values in [1, 255) onto a background of
    ones, adds Gaussian noise and rescales the image to [0, 1].

    :param nx: size of the first image axis
    :param ny: size of the second image axis
    :param cnt: number of circles
    :param r_min: minimum circle radius / square side (inclusive)
    :param r_max: maximum circle radius / square side (exclusive)
    :param border: circle centres are drawn from [border, nx - border) x
        [border, ny - border), so nx and ny must exceed 2 * border
    :param sigma: standard deviation of the additive Gaussian noise
    :param rectangles: also draw squares and return a 3-channel label
    :returns: (image, label). image has shape (nx, ny, 1). If ``rectangles``,
        label is a bool array of shape (nx, ny, 3) with channel 0 = covered by
        neither shape, 1 = circles, 2 = squares; otherwise label is the
        (nx, ny) bool circle mask.
    """
    image = np.ones((nx, ny, 1))
    # Plain `bool`: the `np.bool` alias was deprecated in NumPy 1.20 and
    # removed in 1.24, where it raises AttributeError.
    label = np.zeros((nx, ny, 3), dtype=bool)
    mask = np.zeros((nx, ny), dtype=bool)
    for _ in range(cnt):
        a = np.random.randint(border, nx - border)
        b = np.random.randint(border, ny - border)
        r = np.random.randint(r_min, r_max)
        h = np.random.randint(1, 255)

        # Disc of radius r centred at (a, b), built from broadcast offset grids.
        y, x = np.ogrid[-a:nx - a, -b:ny - b]
        m = x * x + y * y <= r * r
        mask = np.logical_or(mask, m)

        image[m] = h

    label[mask, 1] = 1

    if rectangles:
        mask = np.zeros((nx, ny), dtype=bool)
        for _ in range(cnt // 2):
            a = np.random.randint(nx)
            b = np.random.randint(ny)
            r = np.random.randint(r_min, r_max)
            h = np.random.randint(1, 255)

            m = np.zeros((nx, ny), dtype=bool)
            # Slicing stops at the array edge, so squares near it are partial.
            m[a:a + r, b:b + r] = True
            mask = np.logical_or(mask, m)
            image[m] = h

        label[mask, 2] = 1

        # Channel 0: pixels covered by neither circles nor squares.
        label[..., 0] = ~(np.logical_or(label[..., 1], label[..., 2]))

    image += np.random.normal(scale=sigma, size=image.shape)
    image -= np.amin(image)
    # Skip the rescale for a constant image (e.g. cnt=0, sigma=0), which would
    # otherwise be divided by zero and turn into NaN.
    peak = np.amax(image)
    if peak != 0:
        image /= peak

    if rectangles:
        return image, label
    return image, label[..., 1]
- 所以,_next_data()方法的作用就是生成单张图片和其对应的标签