可以参考官方的文件.
tf.data提供了一整套复杂的数据输入和使用的方法. 比如图片的数据可能包含图片的数据(image)和它的标签(label).
1. 创建
有两种创建的方式:
- 从内存中获取呢使用函数
tf.data.Dataset.from_tensors()ortf.data.Dataset.from_tensor_slices(). - 从TFRecord 格式的文件中
tf.data.TFRecordDataset().
1.1 .from_tensors
from_tensors会将数据压缩成一组元素.
1.2 .from_tensor_slices
from_tensor_slices会将数据压缩, 然后以他们的第一个维数进行分组.
dataset0 = tf.data.Dataset.from_tensors([8, 3, 0, 8, 2, 1])
dataset = tf.data.Dataset.from_tensor_slices([8, 3, 0, 8, 2, 1])
for elem in dataset:
print(elem.numpy())
from_tensors处理后的dataset0会包含一个元素, 而from_tensor_slices处理后会包含多个元素.
使用.numpy()的方式转换成NumPy. 这里使用遍历(for)的方式打印.
也可以使用迭代器
it = iter(dataset)
print(next(it).numpy())
可以使用reduce遍历所有的元素, 生成一个元素, 如
print(dataset.reduce(0, lambda state, value: state + value).numpy())
2. 数据结构
可以使用Dataset.element_spec查看数据类型
Dataset.map()和Dataset.filter()可以被执行用于所有的元素.
3. 批处理所有的元素
使用Dataset.batch()进行批处理
inc_dataset = tf.data.Dataset.range(100)
dec_dataset = tf.data.Dataset.range(0, -100, -1)
# dataset 为 ZipDataset 类似元组
dataset = tf.data.Dataset.zip((inc_dataset, dec_dataset))
# 将其4个一组打包
batched_dataset = dataset.batch(4)
# .take(4)选择前4个
for batch in batched_dataset.take(4):
print([arr.numpy() for arr in batch])
使用drop_remainder可以舍弃剩余项
batched_dataset = dataset.batch(7, drop_remainder=True)