Reposted from: https://blog.csdn.net/foreseerwang/article/details/80572182
Note: Dataset.from_generator is not available in older versions of TensorFlow; tf.data.Dataset only exists in version 1.4 and above.
TensorFlow's basic principle is to first build a computation graph and then run it all at once. To this end, TF reimplements almost every common function for graph construction, and the graph does not directly support ordinary loops or conditional (if) branching the way a general-purpose language does. This makes programming with it rather inconvenient.
Dataset.from_generator can build a Dataset from an external generator function written in plain Python, so it is almost unaffected by these inconveniences of TensorFlow programming. Here is the simplest possible example:
import numpy as np
import tensorflow as tf
def data_generator():
    dataset = np.array(range(5))
    for d in dataset:
        yield d

dataset = tf.data.Dataset.from_generator(data_generator, (tf.int32), (tf.TensorShape([])))
dataset = dataset.repeat(3)
dataset = dataset.batch(4)

# The one-shot iterator is the simplest kind of iterator: it only traverses a
# single dataset and needs no explicit initialization.
iterator = dataset.make_one_shot_iterator()
one_element = iterator.get_next()

with tf.Session() as sess:
    try:
        batch_num = 0
        while True:
            one_batch = sess.run(one_element)
            print('Batch No. %d:' % batch_num)
            print(one_batch)
            print('')
            batch_num += 1
    except tf.errors.OutOfRangeError:
        print('end!')
As expected, the output is:
Batch No. 0:
[0 1 2 3]
Batch No. 1:
[4 0 1 2]
Batch No. 2:
[3 4 0 1]
Batch No. 3:
[2 3 4]
end!
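A natural follow-up question is how to pass arguments to the generator, since from_generator expects a callable that takes no arguments. A minimal sketch of one common pattern is to wrap the call in a lambda (the parameter n is illustrative and not part of the original example):

def data_generator(n):
    # Yield the integers 0 .. n-1 one at a time, like the example above,
    # but with the range length made configurable.
    for d in np.array(range(n)):
        yield d

# The lambda gives from_generator the zero-argument callable it expects.
dataset = tf.data.Dataset.from_generator(lambda: data_generator(5),
                                         tf.int32,
                                         tf.TensorShape([]))

Note also that the last batch above contains only 3 elements because 15 is not divisible by 4; newer TensorFlow releases allow batch(4, drop_remainder=True) to discard such incomplete batches.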
Now consider a harder problem. Suppose the input is a sequence of lines such as: A B, A C, B C, ... where A/B/C each stand for a file, e.g. an image or a text file. Each line is one record; records are read line by line and grouped into batches, say 4 lines per batch. There are two difficulties here: 1. each line/record has a different number of elements; 2. after reading an element A/B/C, it must also be used as a filename to read that file's contents. Existing data-feeding mechanisms can hardly handle both difficulties at once, except for Dataset.from_generator.
import io
import numpy as np
import tensorflow as tf
class DataFeeder:
    def __init__(self, filenames):
        self.filenames = filenames

    def file_readline(self):
        for filename in self.filenames:
            fr = io.open(filename, 'r', encoding='utf-8')
            while True:
                file_line = fr.readline()
                if not file_line:
                    break
                datalist = file_line.split()
                # If datalist were a list of filenames, the file contents
                # could be read and appended here.
                yield np.asarray(datalist, dtype='int32')
            fr.close()

    def generate_batch(self, batch_size, num_epochs=None):
        dataset = tf.data.Dataset.from_generator(self.file_readline,
                                                 tf.int32,
                                                 tf.TensorShape([None]))
        dataset = dataset.repeat(num_epochs)
        dataset = dataset.padded_batch(
            batch_size,
            padded_shapes=tf.TensorShape([3]),
            padding_values=-1)
        iterator = dataset.make_one_shot_iterator()
        out_batch = iterator.get_next()
        return out_batch
filenames = ['a.txt', 'b.txt', 'c.txt']
data_feeder = DataFeeder(filenames)
one_batch = data_feeder.generate_batch(batch_size=2, num_epochs=1)
with tf.Session() as sess:
    try:
        batch_num = 0
        while True:
            data_batch = sess.run(one_batch)
            print('Batch No. %d:' % batch_num)
            print(data_batch)
            print('')
            batch_num += 1
    except tf.errors.OutOfRangeError:
        print('end!')
The contents of the three text files a.txt/b.txt/c.txt are:
a.txt:
1 2 3
2 3
3
b.txt:
4 5
6 7 8
9
c.txt:
10 11 12
13 14
15
Running the code above produces:
Batch No. 0:
[[ 1 2 3]
[ 2 3 -1]]
Batch No. 1:
[[ 3 -1 -1]
[ 4 5 -1]]
Batch No. 2:
[[ 6 7 8]
[ 9 -1 -1]]
Batch No. 3:
[[10 11 12]
[13 14 -1]]
Batch No. 4:
[[15 -1 -1]]
end!
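As the comment inside file_readline hints, the tokens on each line could themselves be treated as filenames whose contents are read inside the generator. Below is a minimal sketch of that idea, not part of the original post: the function name read_tokens_as_files and the assumption that each token names a small binary file of int32 values (with a '.bin' suffix) are invented for illustration.

def read_tokens_as_files(filenames):
    # Same looping structure as DataFeeder.file_readline, but each token is
    # treated as a filename whose contents are loaded and yielded instead of
    # the token itself.
    for filename in filenames:
        with io.open(filename, 'r', encoding='utf-8') as fr:
            for file_line in fr:
                names = file_line.split()
                # Hypothetical storage scheme: each token names a binary file
                # holding raw int32 values.
                contents = [np.fromfile(name + '.bin', dtype='int32')
                            for name in names]
                yield np.concatenate(contents)

Such a generator can be plugged in the same way, e.g. tf.data.Dataset.from_generator(lambda: read_tokens_as_files(filenames), tf.int32, tf.TensorShape([None])). Since the resulting records can have arbitrary lengths, padded_shapes=tf.TensorShape([None]) would then pad each batch to the length of its longest record instead of the fixed length 3 used above.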