原文:https://blog.csdn.net/angel_hben/article/details/84341421
import numpy as np
import tensorflow as tf
p.random.seed(0)
x = np.random.sample((11,2)) # make a dataset from a numpy array
print(x)
dataset = tf.data.Dataset.from_tensor_slices(x) ####
#####最好用 dataset = tf.data.Dataset.from_generator(self.generator, (tf.float32, tf.int32,tf.int32, tf.string)),使用例子如下:
‘’‘
def generator(self):
for index in range(len(self.data_list)):
file_basename_image,file_basename_label = self.data_list[index]
image_path = os.path.join(self.data_dir, file_basename_image)
label_path= os.path.join(self.data_dir, file_basename_label)
image= self.read_data(image_path)
label = self.read_data(label_path)
label_pixel,label=self.label_preprocess(label)
image = (np.array(image[:, :, np.newaxis]))
label_pixel = (np.array(label_pixel[:, :, np.newaxis]))
yield image, label_pixel,label, file_basename_image
’‘’
dataset = dataset.shuffle(2)#将数据打乱,数值越大,混乱程度越大
dataset = dataset.batch(4)#按照顺序取出4行数据,最后一次输出可能小于batch
dataset = dataset.repeat()#数据集重复了指定次数
# repeat()在batch操作输出完毕后再执行,若在之前,相当于先把整个数据集复制两次 #为了配合输出次数,一般默认repeat()空
# create the iterator
iter = dataset.make_one_shot_iterator()
el = iter.get_next()
with tf.Session() as sess:
for i in range(6):
value = sess.run(el)
print(value)
---------------------