如果所有数据都在内存中,则根据输入数据创建 Dataset 的最简便的方式就是将他们转换成 tf.Tensor() 对象,并使用 Dataset.from_tensor_slices()
# 加载训练数据
with np.load(data_path) as data:
features = data['features']
labels = data['labels']
# 看features和labels的length是不是一样的
assert features.shape[0] == labels.shape[0]
dataset = tf.data.Dataset.from_tensor_slices((features, labels))
注意上面的代码段会将 features 和 labels numpy数组作为 tf.constant() 指令嵌入在 tensorflow 图中。这种方式比较适用于比较小型的数据集,但是会浪费内存,因为存在多次复制数组的内容,并可能会达到 tf.GraphDef协议缓冲区的2GB上限。
代替方案 ---> 可以根据tf.placeholder() 张量定义Dataset,并在对数据集初始化 Iterator 时馈送 Numpy 数组。
# 加载训练数据
with np.load(data_path) as data:
features = data['features']
labels = data['labels']
# 看features和labels的length是不是一样的
assert features.shape[0] == labels.shape[0]
features_placeholder = tf.placeholder(features.dtype, features.shape)
labels_placeholder = tf.placeholder(labels.dtype, labels.shape)
dataset= tf.data.Dataset.from_tensor_slices((features_placeholder, labels_placeholder))
# 下面对dataset 对一些处理,如shuffle, batch, repeat...
dataset = ...
iterator = dataset.make_initializable_iterator()
sess.run(iterator.initializer, feed_dict={features_placeholder: features,
labels_placeholder: labels})
...