Tensorflow的线程同步和停止

Tensorflow的多线程使用

Tensorflow的计算主要在使用CPU/GPU和内存,而数据读取涉及磁盘操作,速度远低于前者操作。因此通常会使用多个线程读取数据,然后使用一个线程来使用这些数据,QueueRunner就是来管理这些读写队列的线程,而只用QueueRunner的话有时候会造成这种同步的卡壳,导致程序被强行关闭,因此需要QueueRunner和Coordinator的配合来进行调用,共同协作来停止绘画中的所有线程,并向在等待所有工作线程终止的程序报告。
示例如下所示:

import tensorflow as tf

q = tf.FIFOQueue(1000 , "float32")
counter = tf.Variable(0.0)
#   函数原型是tf.assign_add(ref,value,use_locking=None,name=None),作用是更新ref的值,通过增加value,即:ref = ref + value
add_op = tf.assign_add(counter , tf.constant(1.0))
#   通过enqueue函数将counter变量加入队列
enqueueData_op = q.enqueue(counter)

#   Session 是 Tensorflow 为了控制,和输出文件的执行的语句,意思就是将其加入tensorflow的对话,
#   运行 sess.run() 可以获得你要得知的运算结果, 或者是你所要运算的部分。
sess = tf.Session()
#   tf.train.QueueRunner是创建并运行线程的函数,q代表之前创建的队列,enqueue_ops代表需要加入到q的线程
#   add_op表示的是计数,enqueueData_op表示的是加入队列,这里实际创建了4个线程,两个增加计数,两个执行入队
#   这一步的作用是用多个线程向队列添加数据,这样的话就可以减少由于数据读取的慢速度影响程序整体的运行速度
qr = tf.train.QueueRunner(q , enqueue_ops=[add_op , enqueueData_op] * 2)
sess.run(tf.global_variables_initializer())

#   开启一个协调器
coord = tf.train.Coordinator()
#   启动队列运行器线程
enqueue_threads = qr.create_threads(sess , coord=coord , start=True)


for i in range(10):
    print(sess.run(q.dequeue()))

#   完成后,要求线程停止
coord.request_stop()
#   并等待这些线程完成
coord.join(enqueue_threads)
©著作权归作者所有,转载或内容合作请联系作者
【社区内容提示】社区部分内容疑似由AI辅助生成,浏览时请结合常识与多方信息审慎甄别。
平台声明:文章内容(如有图片或视频亦包括在内)由作者上传并发布,文章内容仅代表作者本人观点,简书系信息发布平台,仅提供信息存储服务。

相关阅读更多精彩内容

友情链接更多精彩内容