TF 提供了一种统一输入数据的格式—— TFRecord ~
它有两个特别好的优点:
1.可以将一个样本的所有信息统一起来存储,这些信息可以是不同的数据类型;
2.利用文件队列的多线程操作,使得数据的读取和批量处理更加方便快捷。
part 1 获得数据
从 CelebA 数据集的20多万个数据中,得到每一个样本的图像及对应的标签,用作图像分类的训练和测试数据:
def get_data(txt_path,img_path):
imgs = []
labels = []
with open(txt_path) as f:
# 解压后的 list_attr_celeba.txt 文件从第三行开始是数据内容
line = f.readline() # 第一行
line = f.readline() # 第二行
line = f.readline() # 第三行
while line:
array = line.split()
file_name = array[0]
# print(file_name)
img = cv2.imread(img_path+file_name)
img = cv2.resize(img,(96,128))
imgs.append(img)
label = np.zeros([5,2])
for i,idx in enumerate([16,35,36,38,39]):
l = int(array[idx])
if l == 1:
label[i,1] = 1
else:
label[i,0] = 1
labels.append(label)
line = f.readline()
print('Data prepared!')
return imgs,labels
调用上面定义的 get_data()函数,得到 images 和 labels(这里label取了5类,判断人脸是否含有帽子/眼镜/项链/耳环/领带等装饰):
txt_path = r'E:/celeA/list_attr_celeba.txt'
img_path = r'E:/celeA/img_align_celeba/'
imgs,labels=get_data(txt_path,img_path)
len(imgs),len(labels) # (202599, 202599)
part 2 创建一个 writer 将数据写入TFRecord文件
TFRcord 文件中的数据都是通过 tf.train.Example()
定义的,其中包含了一个从属性名称到取值的字典。
属性名称为一个字符串,属性的取值可以为字符串(BytesList)/ 实数列表(FloatList)/ 整数列表(Int64List)。
# 生成字符串型的属性。
def _bytes_feature(value):
return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))
# 生成整数型的属性。
def _int64_feature(value):
return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))
# 生成实数型的属性。
def _float_feature(value):
return tf.train.Feature(float_list=tf.train.FloatList(value=[value]))
如:imgs 的数据类型为 uint8,而labels的数据类型为float64,tfrecord可以将图片及其对应的标签编码成字符串,作为 tfrecord 文件中的一条数据。下面取前20万数据做为训练数据,写入20个文件,每个文件记录10000条数据,剩下的作为测试数据:
num_shards = 20 # 文件数
instances_per_shard = 10000 # 每个文件包含的数据量
for i in range(num_shards):
# 文件名如'E:/celeA/data/data.tfrecords-00000-of-00100'
filename = ('E:/celeA/all_data/test.tfrecords-%.5d-of-%.5d'%(i,num_shards-1))
writer = tf.python_io.TFRecordWriter(filename)
for j in range(instances_per_shard):
# 将图像和标签转化成字符串
image_raw = test_x[instances_per_shard*i+j].tostring()
label_raw = test_y[instances_per_shard*i+j].tostring()
# 将图像和标签数据作为一个example
example = tf.train.Example(features=tf.train.Features(feature={
'image':_bytes_feature(image_raw),
'label':_bytes_feature(label_raw)
}))
writer.write(example.SerializeToString())
writer.close()
part 3 创建一个reader来读取TFRecord文件
files=tf.train.match_filenames_once('E:/celeA/all_data/data.tfrecords-*')
# 文件队列,方便利用多线程管理原始文件列表
filename_queue = tf.train.string_input_producer(files,shuffle=False)
reader = tf.TFRecordReader()
# 解析读入的单个数据
_,serialized_example = reader.read(filename_queue)
features = tf.parse_single_example(
serialized_example,
features={
# tf.FixedLenFeature() 是一种属性解析方法,解析结果为一个Tensor
'image':tf.FixedLenFeature([],tf.string),
'label':tf.FixedLenFeature([],tf.string) # 这里的数据格式要和写入时一样
})
# tf.decode_raw() 用于解析字符串
# tf.cast() 转换数据类型
img = tf.decode_raw(features['image'],tf.uint8)
image = tf.reshape(tf.cast(img,tf.float32), [96,128,3])
l = tf.decode_raw(features['label'],tf.float64)
label = tf.reshape(tf.cast(l,tf.float32), [5,2])
with tf.Session() as sess:
sess.run(tf.local_variables_initializer())
# print(sess.run(files))
# 启用多线程
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(sess=sess,coord=coord)
for i in range(80):
im,la = sess.run([image,label])
print(im.shape,la.dtype) # (96, 128, 3) float32
coord.request_stop()
coord.join(threads)
注意:在解析字符串时,解析的数据类型如果和原始数据的数据类型不一样,解析得到的结果就和原始数据不同,所以在读写 tfrecord 文件时一定要明确原始数据类型。这里image的原始数据类型是uint8,为了作为在tensorflow中网络的输入数据(一般是 tf.float32),利用tf.cast()
函数将数据类型转换成 tf.float32 ,label 亦然。
还有一点需要注意:
因为用到的文件队列操作,这里需要开启多线程(指定线程数量,默认为1)。
part 4 组合数据 batching
在训练网络时,通常将训练数据分成小批量的数据进行训练,这样能够提高模型训练效率。tensorflow 提供了tf.train.batch()
和tf.train.shuffle_batch
函数来将组织小批量数据。
batch_size = 64
min_after_dequeue = 64 # 定义出队时最少元素个数来保证随机打乱的顺序
capacity = min_after_dequeue+3*batch_size # batch 队列中最多可以存储的数据个数
batch_x,batch_y = tf.train.shuffle_batch([image,label],batch_size=batch_size,capacity=capacity,min_after_dequeue=min_after_dequeue)
with tf.Session() as sess:
sess.run(tf.local_variables_initializer())
# print(sess.run(files))
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(sess=sess,coord=coord)
for i in range(10):
b_x,b_y = sess.run([batch_x,batch_y])
print(b_x,b_y)
coord.request_stop()
coord.join(threads)
part 5 输入数据处理框架
1.生成用于训练和测试的 tfrecord 文件
2.定义计算图
3.开启会话,在训练过程中,从不同文件中读取小批量数据(是否按顺序,可选)进行训练/验证