读取数据流程
该过程可以分为三步
-
构造文件名队列
把文件名队列读取进来,并随即打乱shufflu,即从filename到Filename到FilenameQueeue阶段
-
读取与解码
使用读取器从上一步拿到的文件名队列,从文件中读取数据,按照一个样本为单位读取的,图片,文本,的编码不同,使用的解码器不同。
-
批处理阶段
构建批次即batch_size
另一种tfrecord的数据处理形式
对于数据容量不太大的数据集,将其整体转化为Tensorflow专用的格式输入到模型中进行训练是一个非常好的方法,对于某些容量非常庞大的工程,而且往往原始数据集和转换后的数据集容量过大,使得加载和读取耗费更多的资源,从而引起良一系列问题
因此在工程中,除了直接将数据集转化成专用的数据格式之外,还有一种常用的方法就是将需要读取的数据地址集转换成专用的格式,每次直接在其中读取生成batch后的地址,将地址读取后直接在模型每部生成包含25个图片格式的TFRecord。代码如下
def get_batch(image_list,label_list,img_width,img_height,batch_size,capacity):
image=tf.cast(image_list,tf.string)
label=tf.cast(label_list,tf.int32)
input_queue=tf.train.slice_input_producer([image,label])
label=input_queue[1]
image_contents=tf.read_file(input_queue[0])
image=tf.image.decode_jpeg(image_contents,channels=3)
image=tf.image.resize_image_with_crop_or_pad(image,img_width,img_height)
image=tf.image.per_image_standardization(image) # 将图片标准化
image_batch,label_batch=tf.train.batch([image,label],batch_size=batch_size,
num_threads=64,
capacity=capacity)
label_batch=tf.reshape(label_batch,[batch_size])
return image_batch,label_batch
在这里 get_batch(image_list,label_list.img_width,img_height,batch_size,capacity)函数中有6个参数,主要说capacity分别是每次生成的图片数量和内存中存储的最大数据容量,这里可根据不同硬件配置制定。