在训练神经网络模型时,每过一遍数据都要反向传播来更新网络参数,但有时候我们还想做其他操作,比如:更新参数的滑动平均、BN操作等。为了一次完成多个操作而且是有顺序的执行,TF提供了tf.control_dependencies和tf.group两种机制。
tf.control_dependencies(control_inputs)
参数:
control_inputs:在运行上下文中定义的操作之前必须执行或计算的列表Operation或Tensor对象;也可以None清除控件依赖关系
例子
import tensorflow as tf
a = tf.constant([0])
for i in range(10):
a = a + 1
with tf.control_dependencies([a]):
c = tf.identity(a)
with tf.Session() as sess:
print(sess.run(c)) # 输出[10]
import tensorflow as tf
sess = tf.InteractiveSession()
tag =tf.reduce_any([True])
a = tf.Variable([1])
b = tf.Variable([2])
c = tf.Variable([10])
do_updates = tf.group(a.assign(a+9),
b.assign(b+8),
c)
d = a + b + c
sess.run(tf.global_variables_initializer())
# 执行group内所有ops
sess.run([do_updates])
# 检查,也就是现在a、b的值已经改变(执行了+运算)
print(sess.run(a)) # [10]
print(sess.run(b)) # [10]
print(sess.run(c)) # [10]
# 我们想要的结果
print(sess.run(d)) # [30]
有时候可以把d看作是我们想要的最终结果,把a、b、c看作是d的前提或基础(不执行前提就得不到我们想要的结果)
所以对多个操作进行打包,进而有序的执行;再比如BP时loss进行参数更新,但是我们希望下一批次loss更新的参数是进行滑动平均后的参数,所以这就产生了依赖关系,而这两个API可以很好的处理这种情况!
笔记,如有错误,请加以指正,感谢~~~~~~~~~~~~~