TensorFlow官方教程翻译4:导入数据

原文地址:https://www.tensorflow.org/programmers_guide/datasets
需要注意的是,如下教程的tf.data的模块需要将tensorflow升级到1.4的版本,才可以支持,低于1.4的版本的导入数据教程,见之前的翻译教程,戳这里
原创翻译,版权所有,请勿私自用于商务用途,转载请在文章首尾注明出处

Dataset的API让你能从简单,可重用的模块中构建复杂的输入管道。例如一个图片模型的输入管道,可能要从分布式的文件系统中获得数据,对每张图片做随机扰动,以及将随机选取的图片合并到一个批次中用作训练。文本模型的输入管道可能涉及到从原始文本数据中提取符号,然后将其转换到查找表中嵌入的标识符,以及将不同长度的序列组合成批次。Dataset的API使得处理大量的数据,不同的数据格式和复杂的转换变得容易。

Dataset的API为TensorFlow中引入了两个新的抽象概念:

  • tf.data.Dataset表示一个元素的序列,在这个序列中每个元素包含一个或多个Tensor对象。例如在一个图片输入管道中,一个元素可能是单个训练样本,这个元素包含一对Tensor分别表示图片数据和一个标签。有两种不同的方式创建一个dataset:
    • 创建一个source(例如Dataset.from_tensor_slices())从一个或多个tf.Tensor对象中构建一个dataset
    • 应用一个transformation(例如Dataset.batch())从一个或多个tf.data.Dataset对象中构建一个dataset
  • tf.data.Iterator提供从一个dataset中提取元素的主要方式。Iterator.get_next()返回的操作在运行时会产生一个Dataset的下一个元素,它通常充当着输入管道代码和你的模型之间的接口。最简单的迭代器是“一次性迭代器”,这种迭代器与特殊的Dataset联系并且只通过它迭代一次。对于更复杂的使用,Iterator.initializer操作能让你使用不同的数据集重新初始化和配置迭代器。例如,你可以在同一个程序中多次迭代训练和验证数据。

Basic mechanics

这部分的指南介绍了创建不同类型的Dataset和Iterator对象的基础,以及如何从它们中获取数据。
为了开始一个输入管道,你必须定义一个源。例如你可以使用tf.data.Dataset.from_tensors()或者tf.data.Dataset.from_tensor_slices()来使用在内存中的一些tensor构建一个Dataset。或者你的数据在硬盘里以推荐的TFRecord格式保存,那么你可以创建一个tf.data.TFRecordDataset。

一旦你有了一个Dataset对象,你可以通过在tf.data.Dataset对象上链接方法调用来将其转换成一个新的Dataset对象。比如你可以应用每个元素的转换,如Dataset.map()(来对每个元素调用函数),以及多元素的转换,如Dataset.batch()。有关转换的完整列表,请参阅tf.data.Dataset的文档。
最常见的从一个Dataset中消耗数值的方法就是创建一个迭代器对象,迭代器对象提供对于数据集中一个元素的一次访问(例如通过调用Dataset.make_one_shot_iterator())。tt.data.Iterator提供两个操作:Iterator.initializer,让你能(重新)初始化迭代器的状态;Iterator.get_next()会返回对应符号的下一个元素的tf.Tensor对象。根据你的使用情况,你可以选择不同类型的迭代器,下面概述了可选的迭代器。

Dataset structure

一个数据集包含的每个元素都有同样的结构。一个元素包含一个或者多个tf.Tensor对象,这些对象被称作部件。每个部件有一个tf.DType表示在tensor中元素的类型,和一个tf.TensorShape表示(可能是部分指定的)每个元素的静态形状。
Dataset.output_types和Dataset.output_shapes属性使得你能检查数据集元素的每个部件的推断的类型和形状。这些属性的嵌套结构映射到一个元素的结构,该元素可能是单个张量,张量元组或张量的嵌套元组。例如:

dataset1 = tf.data.Dataset.from_tensor_slices(tf.random_uniform([4, 10]))
print(dataset1.output_types)  # ==> "tf.float32"
print(dataset1.output_shapes)  # ==> "(10,)"

dataset2 = tf.data.Dataset.from_tensor_slices(
   (tf.random_uniform([4]),
    tf.random_uniform([4, 100], maxval=100, dtype=tf.int32)))
print(dataset2.output_types)  # ==> "(tf.float32, tf.int32)"
print(dataset2.output_shapes)  # ==> "((), (100,))"

dataset3 = tf.data.Dataset.zip((dataset1, dataset2))
print(dataset3.output_types)  # ==> (tf.float32, (tf.float32, tf.int32))
print(dataset3.output_shapes)  # ==> "(10, ((), (100,)))"

通常给一个元素的每个组件命名是比较方便的,例如如果它们表示一个训练样本的不同特征。除了元组,你可以使用collections.namedtuple或者将字符串映射到张量的字典来表示Dataset中的单个元素。

dataset = tf.data.Dataset.from_tensor_slices(
   {"a": tf.random_uniform([4]),
    "b": tf.random_uniform([4, 100], maxval=100, dtype=tf.int32)})
print(dataset.output_types)  # ==> "{'a': tf.float32, 'b': tf.int32}"
print(dataset.output_shapes)  # ==> "{'a': (), 'b': (100,)}"

Dataset的转换支持任何结构的数据集。当使用Dataset.map(),Dataset.flat_map()和Dataset.filter()转换时——这些转换会对每个元素应用一个函数,元素的结构决定调用函数的参数:

dataset1 = dataset1.map(lambda x: ...)

dataset2 = dataset2.flat_map(lambda x, y: ...)

# Note: Argument destructuring is not available in Python 3.
dataset3 = dataset3.filter(lambda x, (y, z): ...)

Creating an iterator

一旦你创建了一个Dataset对象代表你的输入数据,下一步就是创建一个Iterator来从数据集中获取数据。
Dataset的API现在支持下列迭代器,它们的复杂程度顺序递增:

  • one-shot,
  • initializable,
  • reinitializable
  • feedable.

one-shot迭代器是最简单的迭代器形式,它只支持迭代遍历数据集一次,不需要显式初始化。one-shot迭代器处理了现存的基于队列输入管道支持的几乎所有的情况,但是它们不支持参数化配置。使用Dataset.range()例子:

dataset = tf.data.Dataset.range(100)
iterator = dataset.make_one_shot_iterator()
next_element = iterator.get_next()

for i in range(100):
value = sess.run(next_element)
assert i == value

注意:现在,one-shot迭代器是唯一的易于和Estimator使用的迭代器类型。

initializable迭代器要求你在使用它之前,显式的调用iterator.initializer操作。这种不便换来的是它能让你使用一个或多个tf.placeholder()张量来参数化定义数据集,这些张量能在你初始化迭代器的时候被提供。接着Dataset.range()的例子:

max_value = tf.placeholder(tf.int64, shape=[])
dataset = tf.data.Dataset.range(max_value)
iterator = dataset.make_initializable_iterator()
next_element = iterator.get_next()

# Initialize an iterator over a dataset with 10 elements.
sess.run(iterator.initializer, feed_dict={max_value: 10})
for i in range(10):
  value = sess.run(next_element)
  assert i == value

# Initialize the same iterator over a dataset with 100 elements.
sess.run(iterator.initializer, feed_dict={max_value: 100})
for i in range(100):
  value = sess.run(next_element)
  assert i == value

reinitializable迭代器能被多个不同的Dataset对象初始化。例如,你可能有一个训练输入管道,它会对输入的图片进行随机扰动来提高其泛华能力,与此同时,有一个验证输入管道在不变的数据上评估预测。这些管道一般使用不同的Dataset对象,但这些对象有相同的结构(比如每个元素有相同的类型和兼容的形状)。

# Define training and validation datasets with the same structure.
training_dataset = tf.data.Dataset.range(100).map(
    lambda x: x + tf.random_uniform([], -10, 10, tf.int64))
validation_dataset = tf.data.Dataset.range(50)

# A reinitializable iterator is defined by its structure. We could use the
# `output_types` and `output_shapes` properties of either `training_dataset`
# or `validation_dataset` here, because they are compatible.
iterator = Iterator.from_structure(training_dataset.output_types,
                                   training_dataset.output_shapes)
next_element = iterator.get_next()

training_init_op = iterator.make_initializer(training_dataset)
validation_init_op = iterator.make_initializer(validation_dataset)

# Run 20 epochs in which the training dataset is traversed, followed by the
# validation dataset.
for _ in range(20):
  # Initialize an iterator over the training dataset.
  sess.run(training_init_op)
  for _ in range(100):
    sess.run(next_element)

  # Initialize an iterator over the validation dataset.
  sess.run(validation_init_op)
  for _ in range(50):
    sess.run(next_element)

feedable迭代器可以和tf.placeholder一起使用,通过熟悉的feed_dict机制在每次调用tf.Session.run的时候,选择使用何种Iterator。它提供了与reinitializable迭代器相同的功能,但是在迭代器切换的时候,它不需要从数据集的开头初始化迭代器。例如使用上述相同的训练集和验证集的例子,你可以使用tf.data.Iterator.from_string_handle来定义一个feedable迭代器,它可以让你在两个数据集之间切换:

# Define training and validation datasets with the same structure.
training_dataset = tf.data.Dataset.range(100).map(
    lambda x: x + tf.random_uniform([], -10, 10, tf.int64)).repeat()
validation_dataset = tf.data.Dataset.range(50)

# A feedable iterator is defined by a handle placeholder and its structure. We
# could use the `output_types` and `output_shapes` properties of either
# `training_dataset` or `validation_dataset` here, because they have
# identical structure.
handle = tf.placeholder(tf.string, shape=[])
iterator = tf.data.Iterator.from_string_handle(
    handle, training_dataset.output_types, training_dataset.output_shapes)
next_element = iterator.get_next()

# You can use feedable iterators with a variety of different kinds of iterator
# (such as one-shot and initializable iterators).
training_iterator = training_dataset.make_one_shot_iterator()
validation_iterator = validation_dataset.make_initializable_iterator()

# The `Iterator.string_handle()` method returns a tensor that can be evaluated
# and used to feed the `handle` placeholder.
training_handle = sess.run(training_iterator.string_handle())
validation_handle = sess.run(validation_iterator.string_handle())

# Loop forever, alternating between training and validation.
while True:
  # Run 200 steps using the training dataset. Note that the training dataset is
  # infinite, and we resume from where we left off in the previous `while` loop
  # iteration.
  for _ in range(200):
    sess.run(next_element, feed_dict={handle: training_handle})

  # Run one pass over the validation dataset.
  sess.run(validation_iterator.initializer)
  for _ in range(50):
    sess.run(next_element, feed_dict={handle: validation_handle})

Consuming values from an iterator

Iterator.get_next()返回一个或者多个tf.Tensor对象,它们对应迭代器的迭代器下一个元素的象征符号。每次这些张量被评估,它们获取在隐藏的数据集中的下一个元素的数值。(注意:像其他在TensorFlow中的状态对象,调用Iterator.get_next()不会马上推动迭代器。相反你必须在TensorFlow表达式中使用返回的tf.Tensor对象,并且将这个表达式的结果传给tf.Session.run()来获取下一个元素和推动迭代器。)
如果迭代器达到数据集的末尾,运行Iterator.get_next()操作会抛出tf.errors.OutOfRangeError的错误。此时迭代器会处于禁用状态,如果你想再次使用它,那么你必须重新初始化它。

dataset = tf.data.Dataset.range(5)
iterator = dataset.make_initializable_iterator()
next_element = iterator.get_next()

# Typically `result` will be the output of a model, or an optimizer's
# training operation.
result = tf.add(next_element, next_element)

sess.run(iterator.initializer)
print(sess.run(result))  # ==> "0"
print(sess.run(result))  # ==> "2"
print(sess.run(result))  # ==> "4"
print(sess.run(result))  # ==> "6"
print(sess.run(result))  # ==> "8"
try:
  sess.run(result)
except tf.errors.OutOfRangeError:
  print("End of dataset")  # ==> "End of dataset"

一个常见的模式是将“训练循环”封装在try-except块中:

sess.run(iterator.initializer)
while True:
  try:
    sess.run(result)
  except tf.errors.OutOfRangeError:
    break

如果数据集的每个元素是嵌套结构,那么Iterator.get_next()返回值会是以同样结构嵌套的一个或者多个tf.Tensor对象:

dataset1 = tf.data.Dataset.from_tensor_slices(tf.random_uniform([4, 10]))
dataset2 = tf.data.Dataset.from_tensor_slices((tf.random_uniform([4]), tf.random_uniform([4, 100])))
dataset3 = tf.data.Dataset.zip((dataset1, dataset2))

iterator = dataset3.make_initializable_iterator()

sess.run(iterator.initializer)
next1, (next2, next3) = iterator.get_next()

注意,评估next1,next2或者next3中任何一个,都会推动对于所有组件共用的迭代器。一个典型的迭代器消耗,是在单个表达式中包含其所有组件。

Reading input data

Consuming NumPy arrays

如果你的所有输入数据能装进内存中,用它们创建一个Dataset最简单的方式就是将它们转换成tf.Tensor对象,并用Dataset.from_tensor_slices()。

# Load the training data into two NumPy arrays, for example using `np.load()`.
with np.load("/var/data/training_data.npy") as data:
  features = data["features"]
  labels = data["labels"]

# Assume that each row of `features` corresponds to the same row as `labels`.
assert features.shape[0] == labels.shape[0]

dataset = tf.data.Dataset.from_tensor_slices((features, labels))

注意上述代码片段会将features和labels数组作为tf.constant()操作嵌入你的TensorFlow的图中。这对于小的数据集而言运行良好,但是浪费内存——因为数组的内容会被拷贝两次——并且会达到tf.GraphDef协议缓冲的2GB限制。
作为一种替代,你可以按照tf.placeholder()张量来定义Dataset,然后在初始化Iterator的时候供给Numpy数组给这个数据集。

# Load the training data into two NumPy arrays, for example using `np.load()`.
with np.load("/var/data/training_data.npy") as data:
  features = data["features"]
  labels = data["labels"]

# Assume that each row of `features` corresponds to the same row as `labels`.
assert features.shape[0] == labels.shape[0]

features_placeholder = tf.placeholder(features.dtype, features.shape)
labels_placeholder = tf.placeholder(labels.dtype, labels.shape)

dataset = tf.data.Dataset.from_tensor_slices((features_placeholder, labels_placeholder))
# [Other transformations on `dataset`...]
dataset = ...
iterator = dataset.make_initializable_iterator()

sess.run(iterator.initializer, feed_dict={features_placeholder: features,
                                          labels_placeholder: labels})

Consuming TFRecord data

Dataset的API支持各种文件格式,因而你可以处理那些不能装进内存的大型数据集。例如,TFRecord文件格式是一种简单的记录式二进制格式,很多的TensorFlow应用将其格式用于训练数据。tf.data.TFRecordDataset类可以让你将一个或多个TFRecord文件的内容作为输入管道的一部分进行流式处理。

# Creates a dataset that reads all of the examples from two files.
filenames = ["/var/data/file1.tfrecord", "/var/data/file2.tfrecord"]
dataset = tf.data.TFRecordDataset(filenames)

传给TFRecordDataset初始化的filenames参数既可以是一个字符串,或者是一个字符串列表,或者是字符串的tf.Tensor。因此如果你有用于训练和验证两个数据集,你可以使用tf.placeholder(tf.string)来当做filenames参数,然后用合适的filenames参数来初始化迭代器:

filenames = tf.placeholder(tf.string, shape=[None])
dataset = tf.data.TFRecordDataset(filenames)
dataset = dataset.map(...)  # Parse the record into tensors.
dataset = dataset.repeat()  # Repeat the input indefinitely.
dataset = dataset.batch(32)
iterator = dataset.make_initializable_iterator()

# You can feed the initializer with the appropriate filenames for the current
# phase of execution, e.g. training vs. validation.

# Initialize `iterator` with training data.
training_filenames = ["/var/data/file1.tfrecord", "/var/data/file2.tfrecord"]
sess.run(iterator.initializer, feed_dict={filenames: training_filenames})

# Initialize `iterator` with validation data.
validation_filenames = ["/var/data/validation1.tfrecord", ...]
sess.run(iterator.initializer, feed_dict={filenames: validation_filenames})

Consuming text data

很多数据集分布在一个或多个文本文件中。
tf.data.TextLineDataset提供了从一个或多个文本文件中获取每行数据的简单方式。给定一个或多个文件名,TextLineDataset会为这些文件的每一行产生一个字符串-数值元素。与TFRecordDataset相同,TextLineDataset接收tf.Tensor类型的数据作为filenames参数,因此你可以通过传一个tf.placeholder(tf.string)来参数化配置filenames这个参数。

filenames = ["/var/data/file1.txt", "/var/data/file2.txt"]
dataset = tf.data.TextLineDataset(filenames)

默认情况下,TextLineDataset会遍历文件的每一行,这样可能是不必要的,比如文件如果开始有个标题行,或者包含评论。可以使用Dataset.skip()和Dataset.filter()转换移除这些行。为了对每个文件分别应用这些转换,我们使用Dataset.flat_map()为每个文件创建一个嵌套的Dataset。

filenames = ["/var/data/file1.txt", "/var/data/file2.txt"]

dataset = tf.data.Dataset.from_tensor_slices(filenames)

# Use `Dataset.flat_map()` to transform each file as a separate nested dataset,
# and then concatenate their contents sequentially into a single "flat" dataset.
# * Skip the first line (header row).
# * Filter out lines beginning with "#" (comments).
dataset = dataset.flat_map(
    lambda filename: (
        tf.data.TextLineDataset(filename)
        .skip(1)
        .filter(lambda line: tf.not_equal(tf.substr(line, 0, 1), "#"))))

有关使用datasets解析CSV文件的完全例子,查看Regression Examples中的 imports85.py

Preprocessing data with Dataset.map()

Dataset.map(f)转换通过对输入的数据集中每个元素应用给定的函数f,来产生一个新的数据集。这基于在函数式编程语言中经常应用于列表(和其他结构)的map()函数。函数f获得在输入中表示单个元素的tf.Tensor对象,然后返回其在新的数据集中代表的单个元素的tf.Tensor对象。这个实现使用了标准的TensorFlow的操作来将一个元素转换成另一个。
这节包含了如何使用Dataset.map()的常用例子。

Parsing tf.Example protocol buffer messages

很多输入管道从TFRecord格式的文件中提取tf.train.Example协议缓存消息(例如用tf.python_io.TFRecordWriter写的)每个tf.train.Example记录包含一个或多个“特征”,一般输入管道将这些特征转换成张量。

# Transforms a scalar string `example_proto` into a pair of a scalar string and
# a scalar integer, representing an image and its label, respectively.
def _parse_function(example_proto):
  features = {"image": tf.FixedLenFeature((), tf.string, default_value=""),
              "label": tf.FixedLenFeature((), tf.int32, default_value=0)}
  parsed_features = tf.parse_single_example(example_proto, features)
  return parsed_features["image"], parsed_features["label"]

# Creates a dataset that reads all of the examples from two files, and extracts
# the image and label features.
filenames = ["/var/data/file1.tfrecord", "/var/data/file2.tfrecord"]
dataset = tf.data.TFRecordDataset(filenames)
dataset = dataset.map(_parse_function)

Decoding image data and resizing it

当使用真实世界的图片数据来训练一个神经网络的时候,经常需要将不同大小的图片转换成一个统一的大小,这样使它们能够合批到一个固定的大小。

# Reads an image from a file, decodes it into a dense tensor, and resizes it
# to a fixed shape.
def _parse_function(filename, label):
  image_string = tf.read_file(filename)
  image_decoded = tf.image.decode_image(image_string)
  image_resized = tf.image.resize_images(image_decoded, [28, 28])
  return image_resized, label

# A vector of filenames.
filenames = tf.constant(["/var/data/image1.jpg", "/var/data/image2.jpg", ...])

# `labels[i]` is the label for the image in `filenames[i].
labels = tf.constant([0, 37, ...])

dataset = tf.data.Dataset.from_tensor_slices((filenames, labels))
dataset = dataset.map(_parse_function)

Applying arbitrary Python logic with tf.py_func()

因为性能的原因,我们鼓励你尽可能使用TensorFlow的操作来预处理你的数据。但是,当你解析你的输入数据的时候,有时候需要调用额外的Python库。为此,在Dataset.map()转换中调用tf.py_func()。

import cv2

# Use a custom OpenCV function to read the image, instead of the standard
# TensorFlow `tf.read_file()` operation.
def _read_py_function(filename, label):
  image_decoded = cv2.imread(image_string, cv2.IMREAD_GRAYSCALE)
  return image_decoded, label

# Use standard TensorFlow operations to resize the image to a fixed shape.
def _resize_function(image_decoded, label):
  image_decoded.set_shape([None, None, None])
  image_resized = tf.image.resize_images(image_decoded, [28, 28])
  return image_resized, label

filenames = ["/var/data/image1.jpg", "/var/data/image2.jpg", ...]
labels = [0, 37, 29, 1, ...]

dataset = tf.data.Dataset.from_tensor_slices((filenames, labels))
dataset = dataset.map(
    lambda filename, label: tuple(tf.py_func(
        _read_py_function, [filename, label], [tf.uint8, label.dtype])))
dataset = dataset.map(_resize_function)

Batching dataset elements

Simple batching

最简单的合批形式就是将一个数据集中的连续n个元素堆叠成一个元素。Dataset.batch()就是这么坐的,它与tf.stack()操作有同样的约束,被应用于每个元素的组件:就是说对于每个组件i,所有的元素必须是有确切形状的张量。

inc_dataset = tf.data.Dataset.range(100)
dec_dataset = tf.data.Dataset.range(0, -100, -1)
dataset = tf.data.Dataset.zip((inc_dataset, dec_dataset))
batched_dataset = dataset.batch(4)

iterator = batched_dataset.make_one_shot_iterator()
next_element = iterator.get_next()

print(sess.run(next_element))  # ==> ([0, 1, 2,   3],   [ 0, -1,  -2,  -3])
print(sess.run(next_element))  # ==> ([4, 5, 6,   7],   [-4, -5,  -6,  -7])
print(sess.run(next_element))  # ==> ([8, 9, 10, 11],   [-8, -9, -10, -11])

Batching tensors with padding

上述办法对于有一样大小的张量有用。但是,很多模型(比如序列模型)处理的输入数据会有不同的大小(比如不同长度的序列)。为了处理这种情况,Dataset.padded_batch()转换能使你通过指定一个或多个维度来填充,从而合批不同形状的张量。

dataset = tf.data.Dataset.range(100)
dataset = dataset.map(lambda x: tf.fill([tf.cast(x, tf.int32)], x))
dataset = dataset.padded_batch(4, padded_shapes=[None])

iterator = dataset.make_one_shot_iterator()
next_element = iterator.get_next()

print(sess.run(next_element))  # ==> [[0, 0, 0], [1, 0, 0], [2, 2, 0], [3, 3, 3]]
print(sess.run(next_element))  # ==> [[4, 4, 4, 4, 0, 0, 0],
                               #      [5, 5, 5, 5, 5, 0, 0],
                               #      [6, 6, 6, 6, 6, 6, 0],
                               #      [7, 7, 7, 7, 7, 7, 7]]

Dataset.padded_batch()转换允许你对于每个组件的每个维度设置不同的填充,并且这些填充可以是变长度的(如上述例子中的用None指明)或者固定长度的。它也可以设置填充的数值,这个数值默认为0。

Training workflows

Processing multiple epochs

Dataset的API主要提供两种方式来处理同样的数据多代使用的情况。

在一个数据集上迭代多代次的最简单版本使用Dataset.repeat()转换。例如创建一个数据集,重复输入10代次:

filenames = ["/var/data/file1.tfrecord", "/var/data/file2.tfrecord"]
dataset = tf.data.TFRecordDataset(filenames)
dataset = dataset.map(...)
dataset = dataset.repeat(10)
dataset = dataset.batch(32)

没有参数的应用Dataset.repeat()将重复输出无限次。Dataset.repeat()转换连接其参数,不会在一代结束和下一代开始的时候发信号。
如果你想要在每代结束的时候接收信号,你可以编写训练循环来捕获数据集末尾的tf.errors.OutOfRangeError。这时你可以为该代收集一些统计信息(比如验证错误)。

filenames = ["/var/data/file1.tfrecord", "/var/data/file2.tfrecord"]
dataset = tf.data.TFRecordDataset(filenames)
dataset = dataset.map(...)
dataset = dataset.batch(32)
iterator = dataset.make_initializable_iterator()
next_element = iterator.get_next()

# Compute for 100 epochs.
for _ in range(100):
  sess.run(iterator.initializer)
  while True:
    try:
      sess.run(next_element)
    except tf.errors.OutOfRangeError:
      break

  # [Perform end-of-epoch calculations here.]

Randomly shuffling input data

Dataset.shuffle()转换使用与tf.RandomShuffleQueue相似的算法来随机打乱输入的数据集:它维护了一个固定大小的缓存,并等概率的随机的从中选择下一个元素。

filenames = ["/var/data/file1.tfrecord", "/var/data/file2.tfrecord"]
dataset = tf.data.TFRecordDataset(filenames)
dataset = dataset.map(...)
dataset = dataset.shuffle(buffer_size=10000)
dataset = dataset.batch(32)
dataset = dataset.repeat()

Using high-level APIs

tf.train.MonitoredTrainingSession的API简化了在分布式设置上运行TensorFlow的很多方面。MonitoredTraingSession抛出tf.errors.OutOfRangeError来通知训练的完成,因此配合它使用Dataset的API,我们推荐使用Dataset.make_one_shot_iterator()。例如

filenames = ["/var/data/file1.tfrecord", "/var/data/file2.tfrecord"]
dataset = tf.data.TFRecordDataset(filenames)
dataset = dataset.map(...)
dataset = dataset.shuffle(buffer_size=10000)
dataset = dataset.batch(32)
dataset = dataset.repeat(num_epochs)
iterator = dataset.make_one_shot_iterator()

next_example, next_label = iterator.get_next()
loss = model_function(next_example, next_label)

training_op = tf.train.AdagradOptimizer(...).minimize(loss)

with tf.train.MonitoredTrainingSession(...) as sess:
  while not sess.should_stop():
    sess.run(training_op)

为了在tf.estimator.Estimator的input_fn中使用Dataset,我们也推荐使用Dataset.make_one_shot_iterator()。例如:

def dataset_input_fn():
  filenames = ["/var/data/file1.tfrecord", "/var/data/file2.tfrecord"]
  dataset = tf.data.TFRecordDataset(filenames)

  # Use `tf.parse_single_example()` to extract data from a `tf.Example`
  # protocol buffer, and perform any additional per-record preprocessing.
  def parser(record):
    keys_to_features = {
        "image_data": tf.FixedLenFeature((), tf.string, default_value=""),
        "date_time": tf.FixedLenFeature((), tf.int64, default_value=""),
        "label": tf.FixedLenFeature((), tf.int64,
                                    default_value=tf.zeros([], dtype=tf.int64)),
    }
    parsed = tf.parse_single_example(record, keys_to_features)

    # Perform additional preprocessing on the parsed data.
    image = tf.decode_jpeg(parsed["image_data"])
    image = tf.reshape(image, [299, 299, 1])
    label = tf.cast(parsed["label"], tf.int32)

    return {"image_data": image, "date_time": parsed["date_time"]}, label

  # Use `Dataset.map()` to build a pair of a feature dictionary and a label
  # tensor for each example.
  dataset = dataset.map(parser)
  dataset = dataset.shuffle(buffer_size=10000)
  dataset = dataset.batch(32)
  dataset = dataset.repeat(num_epochs)
  iterator = dataset.make_one_shot_iterator()

  # `features` is a dictionary in which each value is a batch of values for
  # that feature; `labels` is a batch of labels.
  features, labels = iterator.get_next()
  return features, labels
最后编辑于
©著作权归作者所有,转载或内容合作请联系作者
  • 序言:七十年代末,一起剥皮案震惊了整个滨河市,随后出现的几起案子,更是在滨河造成了极大的恐慌,老刑警刘岩,带你破解...
    沈念sama阅读 212,884评论 6 492
  • 序言:滨河连续发生了三起死亡事件,死亡现场离奇诡异,居然都是意外死亡,警方通过查阅死者的电脑和手机,发现死者居然都...
    沈念sama阅读 90,755评论 3 385
  • 文/潘晓璐 我一进店门,熙熙楼的掌柜王于贵愁眉苦脸地迎上来,“玉大人,你说我怎么就摊上这事。” “怎么了?”我有些...
    开封第一讲书人阅读 158,369评论 0 348
  • 文/不坏的土叔 我叫张陵,是天一观的道长。 经常有香客问我,道长,这世上最难降的妖魔是什么? 我笑而不...
    开封第一讲书人阅读 56,799评论 1 285
  • 正文 为了忘掉前任,我火速办了婚礼,结果婚礼上,老公的妹妹穿的比我还像新娘。我一直安慰自己,他们只是感情好,可当我...
    茶点故事阅读 65,910评论 6 386
  • 文/花漫 我一把揭开白布。 她就那样静静地躺着,像睡着了一般。 火红的嫁衣衬着肌肤如雪。 梳的纹丝不乱的头发上,一...
    开封第一讲书人阅读 50,096评论 1 291
  • 那天,我揣着相机与录音,去河边找鬼。 笑死,一个胖子当着我的面吹牛,可吹牛的内容都是我干的。 我是一名探鬼主播,决...
    沈念sama阅读 39,159评论 3 411
  • 文/苍兰香墨 我猛地睁开眼,长吁一口气:“原来是场噩梦啊……” “哼!你这毒妇竟也来了?” 一声冷哼从身侧响起,我...
    开封第一讲书人阅读 37,917评论 0 268
  • 序言:老挝万荣一对情侣失踪,失踪者是张志新(化名)和其女友刘颖,没想到半个月后,有当地人在树林里发现了一具尸体,经...
    沈念sama阅读 44,360评论 1 303
  • 正文 独居荒郊野岭守林人离奇死亡,尸身上长有42处带血的脓包…… 初始之章·张勋 以下内容为张勋视角 年9月15日...
    茶点故事阅读 36,673评论 2 327
  • 正文 我和宋清朗相恋三年,在试婚纱的时候发现自己被绿了。 大学时的朋友给我发了我未婚夫和他白月光在一起吃饭的照片。...
    茶点故事阅读 38,814评论 1 341
  • 序言:一个原本活蹦乱跳的男人离奇死亡,死状恐怖,灵堂内的尸体忽然破棺而出,到底是诈尸还是另有隐情,我是刑警宁泽,带...
    沈念sama阅读 34,509评论 4 334
  • 正文 年R本政府宣布,位于F岛的核电站,受9级特大地震影响,放射性物质发生泄漏。R本人自食恶果不足惜,却给世界环境...
    茶点故事阅读 40,156评论 3 317
  • 文/蒙蒙 一、第九天 我趴在偏房一处隐蔽的房顶上张望。 院中可真热闹,春花似锦、人声如沸。这庄子的主人今日做“春日...
    开封第一讲书人阅读 30,882评论 0 21
  • 文/苍兰香墨 我抬头看了看天上的太阳。三九已至,却和暖如春,着一层夹袄步出监牢的瞬间,已是汗流浃背。 一阵脚步声响...
    开封第一讲书人阅读 32,123评论 1 267
  • 我被黑心中介骗来泰国打工, 没想到刚下飞机就差点儿被人妖公主榨干…… 1. 我叫王不留,地道东北人。 一个月前我还...
    沈念sama阅读 46,641评论 2 362
  • 正文 我出身青楼,却偏偏与公主长得像,于是被迫代替她去往敌国和亲。 传闻我的和亲对象是个残疾皇子,可洞房花烛夜当晚...
    茶点故事阅读 43,728评论 2 351

推荐阅读更多精彩内容