https://blog.csdn.net/Tramac/article/details/70158426
Python层作为输入的四个必须的函数:
setup(self, bottom, top)
reshape(self, bottom, top)
forward(self, bottom, top)
backward(self, bottom, top)
class VOCSegDataLayer(caffe.layer) #定义为Python层
该类的作用:获取所需要的image对,即input image、label image,作为FCN的输入。
①def setup(self, bottom, top)
setup()的参数:voc_dir:数据的根目录. split:train/val/test模式的选择. mean:图片的均值. randomize:初始化随机数生成,作用? seed:随机化seed,起点,作用?
以下为该5个参数的初始化:
params = eval(self.param_str) #将字符串当成有效的表达式来求值并返回计算结果
self.voc_dir = params['voc_dir']
self.split = params['split']
self.mean = np.array(params['mean']) #变成数组
self.random = params.get('randomize', True) #初始化self.random,默认为True
self.seed = params.get('seed', None)
if len(top)!= 2: #判断输出是否为2(data和label)
报错
if len(bottom) != 0 #判断输入是否为0(数据层无bottom)
#load indices for images and labels(获取image和lable的索引)
split_f = '{}/ImageSets/Segmentation/{}.txt'.format(self.voc_dir, self.split) #.format() 用{}代替%,()中的内容按顺序放入前面的{}中,txt文件中存的是图片的索引
self.indices = open(split_f, 'r').read().splitline() #open()文件打开操作,'r'读模式, 'w'写模式,'a'追加, 'b'二进制, '+'读/写. read()用于从文件读取指定的字节数,若未给定或为负值则读取所有. splitline()按照行('\r','\r\n','\n')分隔,返回一个包含各行作为元素的列表,默认不包含换行符.
self.idx = 0
#如果不是train模式,则不需要参数self.random
if 'train' not in self.split: #not in 运算符,如果在指定的序列中没有找到值,返回True,否则返回False.
#如果是train模式,则需要初始化self.random.
if self.random:
random.seed(self.seed) #seed()不能直接访问,需要导入random模块.()中的self.seed为改变随机数生成器的种子seed.无返回值. 作用:设置生成随机数用的起始值,调用任何其他random模块函数之前调用这个函数.
self.idx = random.randint(0, len(self.indices)-1) #random.randint(a, b)用于生成一个指定范围内的整数,用来打乱'txt'中文件的顺序?
②def reshape(self, bottom, top): #获取image+label对
self.data = self.load_image(self.indices[self.idx]) #self.indices前面得到的文件名列表,self.idx索引值
self.label = self.load_label(self.indices[self.idx])
#load_image()和load_label()定义的两个数据接口函数
#reshape输出的形状
top[0].reshape(1, *self.data.shape) #image
top[1].reshape(1, *self.label.shape) #label
③def forward(self, bottom, top): #assign output分配输出
top[0].data[...] = self.data #将self.data,self.label赋值到top.data?
top[1].data[...] = self.label
#选择下一个输入
if self.randon:
self.idx = random.randint(0, len(self.indices)-1)
else:
self.idx +=1 #若不是train,idx没有随机,从0开始
if self.idx == len(self.indices):
self.idx = 0
④def backward(self, top, propagate_down, bottom)
pass #无反向过程
def load_image(self.idx): #调用时,输入的为indices[idx]
#获取输入的图像,预处理来适配caffe:转为float→改变通道(RGB→BGR)→减去均值→改为CxHxW
im = Image.open('{}/JPEGImage/{}.jpg'.format(self.voc_dir, idx)) #打开图片
in_ = np.array(im, dtype = np.float32 #创建数组并制定数组中元素的类型
in_ = in_[:,:,::-1] #RGB-BGR?
in_ -= self.mean #减去均值
in_ = in_.transpose((2,0,1)) #改变通道顺序
return in_
def load_label(self.idx) #获取label,label为1xHxW单通道,整数数组,单通道是由loss决定的
im = Image.open('{}/SegmentationClass/{}.png'.format(self.voc_dir, idx))
label = np.array(im, dtype = np.unit8)
label = label[np.newaxis,...]
return label
#[...]代表许多产生一个完整的索引元组必要的分号,label本来为2维图像,只有H,W,需要将其变为3维,即1xHxW.
---------------------
本文来自 Tramac 的CSDN 博客 ,全文地址请点击:https://blog.csdn.net/Tramac/article/details/70158426?utm_source=copy