参考:怎么理解tensorflow中tf.train.shuffle_batch()函数?
TensorFlow学习--tf.train.batch与tf.train.shuffle_batch
(1)背景:
(2)用法
tf.train.shuffle_batch() 将队列中数据打乱后再读取出来.
函数是先将队列中数据打乱,然后再从队列里读取出来,因此队列中剩下的数据也是乱序的.
tensors:排列的张量或词典.
batch_size:从队列中提取新的批量大小.
capacity:队列中元素的最大数量.
min_after_dequeue:出队后队列中元素的最小数量,用于确保元素的混合级别.
num_threads:线程数量.
seed:队列内随机乱序的种子值.
enqueue_many:tensors中的张量是否都是一个例子.
shapes:每个示例的形状.(可选项)
allow_smaller_final_batch:为True时,若队列中没有足够的项目,则允许最终批次更小.(可选项)
shared_name:如果设置,则队列将在多个会话中以给定名称共享.(可选项)
name:操作的名称.(可选项)
(2)功能:
Creates batches by randomly shuffling tensors,但需要注意的是它是一种图运算,要跑在sess.run()里。具体地,
This function adds the following to the current Graph:
在运行这个函数时它会在当前图上创建如下的东西:
A shuffling queue into which tensors from tensors are enqueued.
一个乱序的队列,进队的正是传入的tensors
A dequeue_many operation to create batches from the queue.
一个dequeue_many的操作从队列中推出成batch的tensor
A QueueRunner to QUEUE_RUNNER collection, to enqueue the tensors from tensors.
一个QueueRunner的线程,正是这个线程将传入的数据推进队列中.
把数据放在队列里有很多好处,可以完成训练数据和测试数据的解耦,同时有利于写成分布式训练(个人理解),但需要注意的是在取数据的时候,容易造成堵塞的情况.
这时候,应该需要截获超时异常来强制停止线程.