tensorflow静态图真的是非常难用,尤其是当想要使用控制语句或者循环的时候。
(一) 循环 : while_loop
1)函数定义
tf.while_loop(cond, body, initial_val),其中:
- cond:循环条件,不满足条件时终止
- body:循环体,每次循环都执行一次
- initial_val:参数列表,其中cond和body的参数列表都是initial_val,并且body的返回值与initial_val一一对应
一次循环操作中,1)检查cond是否满足;2)执行循环体,返回initial_val; 3) 不满足cond后循环推出。
2) TensorArray使用
在循环中,可以采用TensorArray来构建具有动态大小的张量。下面就构造了一个有5个元素的tensor,clear_after_read设为false的话,可以多次读,否则只能读一次。需要注意的是,读操作一定要在写操作之后,并且只能写一次。
row = tensor_array_ops.TensorArray(
dtype=tf.float32,
size=5,
dynamic_size=False,
clear_after_read=False)
关于写操作,一个容易犯的错误是:
row.write(1, tf.constant(1., tf.float32, [5])
这样写是无法更新row的值的,正确的写法是:
row = row.write(1, tf.constant(1., tf.float32, [5])
3) 使用张量shape作为循环次数
如果张量T的大小是不确定的,可以采用tf.shape:
def cond(dim, i):
return tf.less(i, dim)
def body(dim, i):
# loop body...
return dim, i + 1
tf.while_loop(cond, body, [tf.shape(T)[i_dim], 0])
(二) 控制语句
1) 函数
tf.cond(pred, fn1, fn2, name=None),其中:
- pred: 判断函数,true则执行fn1,false执行fn2
- fn1, fn2: 条件分支,两个函数的返回值格式应该相同
2) 注意
tensorflow在建图的时候会把两个函数执行的操作都build出来,往图里面传入数据后,算出pred的值,再决定执行哪个函数。所以,如果报ValueError,可能是因为某个分支输入或者输出的数据是空的,要确保默认情况下,两个分支都是可以执行的。
(三) 示例代码
下面的代码将控制语句和循环语句结合起来,执行3次循环,前2次做conv操作,最后一次做加法。有时间还可以研究一下namescope reuse的问题,这里就不讨论了。
import tensorflow as tf
from tensorflow.python.ops import tensor_array_ops
import os
gpu_options = tf.GPUOptions(allow_growth=True)
iters = tf.constant(4)
row = tensor_array_ops.TensorArray(
dtype=tf.float32,
size=5,
dynamic_size=False,
clear_after_read=False)
# element_shape=[2])
def conv(name, x, filter_size, in_filters, out_filters, strides):
with tf.variable_scope(name):
w = tf.get_variable('DW', [filter_size, filter_size, in_filters, out_filters],
initializer=tf.contrib.layers.xavier_initializer_conv2d())
b = tf.get_variable('biases', out_filters, initializer=tf.constant_initializer(0.))
return tf.nn.conv2d(x, w, strides, padding='SAME') + b
mat_adj = tf.constant(1., tf.float32, [1, 2, 2])
initial_feat = tf.constant(1, tf.float32, [1, 2, 2, 4])
iter_rate = tf.placeholder(tf.int32)
def cond(i, row):
return tf.less(i, iters)
def body(i, row):
def f1():
# out = graph_conv(row.read(i - 1), 2, 4, mat_adj, "g", "s")
out = conv("conv", row.read(i - 1), 1, 4, 4, [1,1,1,1])
row_update = row.write(i, out)
return i + 1, row_update
def f2():
out = row.read(i - 1) + 1
# out = graph_conv(row.read(i - 1), 2, 4, mat_adj, "g", "s")
row_update = row.write(i, out)
return i + 1, row_update
i, row_update = tf.cond(tf.less(i, iter_rate), f1, f2)
return [i, row_update]
row = row.write(0, initial_feat)
res = tf.while_loop(cond, body, [1, row])
val = res[1].read(0)
config = tf.ConfigProto(allow_soft_placement=True, gpu_options=gpu_options)
with tf.Session(config=config) as sess:
sess.run(tf.global_variables_initializer())
for i in range(3):
print("{}th".format(i))
print(sess.run([val, res[1].read(1), res[1].read(2), res[1].read(3)], feed_dict={iter_rate:3}))