介绍
最近在读取tfrecord时,遇到了关于tensorf shape的问题。
我们需要知道,大多数情况下图片进行encode编码保存在tfrecord时 是一个一维张量,shape为(1,)。 而在输入神经网络之前,我们必须要将这个图片张量reshape成一个合乎网络结构需求的三维张量。
在针对这样的需求时,我们会发现,大部分同学会选择在生成tfrecord前就定义好网络的输入shape,例如[224,224,3], 然后将所有的图片先reshape成这个大小,接着存储在tfrecord中。
这种方式的优点在于提前完成的reshape,避免了后续很多的shape uncompatible 的问题,以及后续训练中不用再对图片进行reshape,加快了训练速度。
缺点在于,限制了网络输入尺寸的定义。每修改一次神经网络的输入shape。
当我们需要从存储着不定尺寸图片的tfrecord读取数据时, 我们是无法直接将图片reshape成指定的网络结构输入尺寸的。例如图片大小 [667,1085,3]。显然,我们无法直接将其reshape成 [224,224,3]的。那么我们该如何处理呢?
按照思路,我们应该先将图片的一维tensor 转换成三维tensor, 然后再利用 tf.image库中不同的reshape 操作,将三维图片tensor转换为需要的 tensor大小。
按照这种思路,在这里,我总结了两种读写tfrecord的方式,并对这两种方式的不同点,尤其是容易导致bug的地方进行了整理。
第一种: 利用slim.dataset.Dataset读写tfrecord文件,这种方式常见于利用slim库进行目标检测等网络的实现过程中。
第二种:tf.parse_single_example 是更为常见的一种方式
利用slim.dataset.Dataset读写tfrecord文件
利用这个这个接口读写tfrecord非常的方便。它的神奇之处在于,
它不需要图片宽高的信息,只需要其二进制string tensor。 这个接口会自动返回一个三维图片tensor。 在此基础上,我们可以很方便的对其进行reshape,然后输入神经网络。
具体步骤如下:
在生成tfrecord文件时,我们需要先定义 tf_example的写入格式,然后在将图片文件依据这个写入格式,生成tfrecord文件
- 定义 tf_example的写入特征
def int64_feature(value):
return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))
def int64_list_feature(value):
return tf.train.Feature(int64_list=tf.train.Int64List(value=value))
def bytes_feature(value):
return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))
def bytes_list_feature(value):
return tf.train.Feature(bytes_list=tf.train.BytesList(value=value))
def float_list_feature(value):
return tf.train.Feature(float_list=tf.train.FloatList(value=value))
def create_tf_example(image_path, label, resize_size=None):
with tf.gfile.GFile(image_path, 'rb') as fid:
encoded_jpg = fid.read()
encoded_jpg_io = io.BytesIO(encoded_jpg)
image = Image.open(encoded_jpg_io)
# 对于可能存在RGBA 4通道的图片进行处理
image,is_process = process_image_channels(image)
# 如有必要,那么就在生成tfrecord时即进行resize
width, height = image.size
if resize_size is not None:
if width > height:
width = int(width * resize_size / height)
height = resize_size
else:
width = resize_size
height = int(height * resize_size / width)
image = image.resize((width, height), Image.ANTIALIAS)
# update encode_jpg
if is_process or resize_size is not None:
bytes_io = io.BytesIO()
image.save(bytes_io, format='JPEG')
encoded_jpg = bytes_io.getvalue()
tf_example = tf.train.Example(
features=tf.train.Features(feature={
'image/encoded': bytes_feature(encoded_jpg),
'image/format': bytes_feature('jpg'.encode()),
'image/class/label': int64_feature(label),
'image/height': int64_feature(height),
'image/width': int64_feature(width)}))
return tf_example
- 生成完整的tfrecord文件
在定义完对应的tf_example 方式后,我们可以遍历图片文件,生成完整的tfrecord文件了。
def generate_tfrecord(annotation_dict, output_path, resize_size=None):
num_valid_tf_example = 0
writer = tf.python_io.TFRecordWriter(output_path)
for image_path, label in annotation_dict.items():
if not tf.gfile.GFile(image_path):
print('%s does not exist.' % image_path)
continue
tf_example = create_tf_example(image_path, label, resize_size)
if tf_example:
writer.write(tf_example.SerializeToString())
num_valid_tf_example += 1
if num_valid_tf_example % 100 == 0:
print('Create %d TF_Example.' % num_valid_tf_example)
writer.close()
print('Total create TF_Example: %d' % num_valid_tf_example)
对应着,在读取tfrecord时,slim提供了 slim.dataset.Dataset 的API接口,非常方便对读入的tfrecord数据进行操作。
def get_record_dataset(record_path,
reader=None,
num_samples=50000,
num_classes=32):
"""Get a tensorflow record file.
Args:
"""
if not reader:
reader = tf.TFRecordReader
keys_to_features = {
'image/encoded':
tf.FixedLenFeature((), tf.string, default_value=''),
'image/format':
tf.FixedLenFeature((), tf.string, default_value='jpeg'),
'image/class/label':
tf.FixedLenFeature([1], tf.int64, default_value=tf.zeros([1],
dtype=tf.int64))}
items_to_handlers = {
'image': slim.tfexample_decoder.Image(image_key='image/encoded',
format_key='image/format'),
'label': slim.tfexample_decoder.Tensor('image/class/label', shape=[])}
decoder = slim.tfexample_decoder.TFExampleDecoder(
keys_to_features, items_to_handlers)
labels_to_names = None
items_to_descriptions = {
'image': 'An image with shape image_shape.',
'label': 'A single integer.'}
return slim.dataset.Dataset(
data_sources=record_path,
reader=reader,
decoder=decoder,
num_samples=num_samples,
num_classes=num_classes,
items_to_descriptions=items_to_descriptions,
labels_to_names=labels_to_names)
在返回了slim.dataset.Dataset这个slim支持的data封装后, 我们可直接对返回的图片数据进行reshape,保证这个图片张量的shape与网络结构的输入层shape一致。
dataset = get_record_dataset(FLAGS.record_path, num_samples=num_samples,
num_classes=FLAGS.num_classes)
data_provider = slim.dataset_data_provider.DatasetDataProvider(dataset)
image, label = data_provider.get(['image', 'label'])
# 输出当前tensor的静态shape 和动态shape,与另一种读取方式进行对比
print("----------tf.shape(image): ",tf.shape(image))
print("----------image.get_shape(): ",image.get_shape())
image = _fixed_sides_resize(image, output_height=368, output_width=368)
inputs, labels = tf.train.batch([image, label],
batch_size=FLAGS.batch_size,
#capacity=5*FLAGS.batch_size,
allow_smaller_final_batch=True)
其中,对三维图片张量进行reshape的代码如下
def _fixed_sides_resize(image, output_height, output_width):
"""Resize images by fixed sides.
Args:
image: A 3-D image `Tensor`.
output_height: The height of the image after preprocessing.
output_width: The width of the image after preprocessing.
Returns:
resized_image: A 3-D tensor containing the resized image.
"""
output_height = tf.convert_to_tensor(output_height, dtype=tf.int32)
output_width = tf.convert_to_tensor(output_width, dtype=tf.int32)
image = tf.expand_dims(image, 0)
resized_image = tf.image.resize_nearest_neighbor(
image, [output_height, output_width], align_corners=False)
resized_image = tf.squeeze(resized_image)
resized_image.set_shape([None, None, 3])
return resized_image
完成了这几步之后,我们就可以利用image 和 label 进行神经网络训练了。
利用tf.parse_single_example 读写tfrecord文件
这种方式我们需要自己手动将一维的图片tensor,先还原成三维图片tensor。 因为每一张图片的shape不相同。那么我们需要将图片的shape也存入tfrecord文件中。当我们从tfrecord文件中读取时,我们先利用tf.reshape将一维图片张量还原成三维图片张量,再reshape规定的网络输入尺寸。
- 照例,此处的重点在于tf_example的构建。在这一部分,我将图片的shape作为一个feature,也存入了tfrecord里面。 那么,在对张量的还原时,我们可以利用这个三维的shape tensor,
def create_tf_example(image_path, label, resize_size=None):
with tf.gfile.GFile(image_path, 'rb') as fid:
encoded_jpg = fid.read()
encoded_jpg_io = io.BytesIO(encoded_jpg)
image = Image.open(encoded_jpg_io)
# 对于RGBA 4通道的图片进行处理
image,is_process = process_image_channels(image)
# Resize
width, height = image.size
if resize_size is not None:
if width > height:
width = int(width * resize_size / height)
height = resize_size
else:
width = resize_size
height = int(height * resize_size / width)
image = image.resize((width, height), Image.ANTIALIAS)
img_array = np.asarray(image)
shape = img_array.shape
byte_image = image.tobytes()
tf_example = tf.train.Example(
features=tf.train.Features(feature={
'image': bytes_feature(byte_image),
'label': int64_feature(label),
'img_shape': int64_list_feature(shape)}))
return tf_example
在完成这个后,我们仍旧可以使用上述提及的generate_tfrecord 函数来生成对应的tfrecord
那么,对应这种方式生成的tfrecord文件,我们该如何读取呢?
在这里,我给出对应的parse_example函数就足以了。
def parse(serialized):
# Define a dict with the data-names and types we expect to
# find in the TFRecords file.
# It is a bit awkward that this needs to be specified again,
# because it could have been written in the header of the
# TFRecords file instead.
features = {
'image':
tf.FixedLenFeature((), tf.string, default_value=''),
'label':
tf.FixedLenFeature([1], tf.int64, default_value=tf.zeros([1],
dtype=tf.int64)),
'img_shape':
tf.FixedLenFeature(shape=(3,), dtype=tf.int64)}
# Parse the serialized data so we get a dict with our data.
parsed_example = tf.parse_single_example(
serialized=serialized, features=features)
# Get the image as raw bytes.
image_raw = parsed_example['image']
# Decode the raw bytes so it becomes a tensor with type.
image = tf.decode_raw(image_raw, tf.uint8)
# The type is now uint8 but we need it to be float.
image = tf.cast(image, tf.float32)
shape = parsed_example['img_shape']
image = tf.reshape(image,shape=shape)
if not (shape[0] == shape[1] == default_img_size):
image = _fixed_sides_resize(image,default_img_size,default_img_size)
image.set_shape([default_img_size,default_img_size,3])
label = parsed_example['label']
# The image and label are now correct TensorFlow types.
return image, label
在这里,读写tfrecord的重要流程就已经展现好了。
对比
这两种方式有一个比较重要的区别,那就是制作tfrecord时存储的图片信息不同。
使用slim api时 我们制作tfrecord 时,相关代码为
with tf.gfile.GFile(image_path, 'rb') as fid:
encoded_jpg = fid.read()
当我们使用第二种方式时,制作tfrecord时存储的图片信息的相关代码如下所示。
image = Image.open(img_dir)
byte_image = image.tobytes()
第一种方式保存的图片信息,其字节数不等于图片的height, width, channel的乘积。 所以不能用 第二种的方式去读取这种方式存储的tfrecord。 会出现 reshape时 维度不对的错误。 当然,使用slim.dataset.Dataset 则不需要考虑这个问题了。 网络上使用slim.dataset.Dataset 来加载tfrecord的方式,都是使用第一种方式存储的tfrecord数据。
第二种方式,其存储的图片字节大小等于图片的height, width, channel的乘积。所以它可以直接用tf.reshape直接将原图矩阵还原回来,然后再进行下一步的reshape操作。
总结
之所以写这篇文章,是因为网络上针对不定尺寸图片tfrecord读取的解决方案不是很完善。
例如 https://stackoverflow.com/questions/40258943/using-height-width-information-stored-in-a-tfrecords-file-to-set-shape-of-a-ten
将height, width,channel 分别存入tfrecord,然后按照提问者描述这样是不成功的。
再例如https://stackoverflow.com/questions/35028173/how-to-read-images-with-different-size-in-a-tfrecord-file 提供的解决方案
image_rows = tf.cast(features['rows'], tf.int32)
image_cols = tf.cast(features['cols'], tf.int32)
image_data = tf.decode_raw(features['image_raw'], tf.uint8)
image = tf.reshape(image_data, tf.pack([image_rows, image_cols, 3]))
这种方式在tf.reshape阶段会报错,因为我们无法将 两个tensor和一个int数值组合起来。最完善的方式是直接将shape作为一个整体存入tfrecord中,最终读取出来就是一个张量了。