# Stand-ins for three large matrices; their actual definitions are elided
# (`...`) in this example.
a = ...
b = ...
c = ...
# Boolean placeholder; its value is fed at run time to pick which branch runs.
pred = tf.placeholder(tf.bool)
def if_true():
    """Build the true-branch op: the matrix product of `a` and `b`."""
    product_ab = tf.matmul(a, b)
    return product_ab
def if_false():
    """Build the false-branch op: the matrix product of `b` and `c`."""
    product_bc = tf.matmul(b, c)
    return product_bc
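# `tf.cond()` is called directly below. The `control_flow_ops` import dates from
# before `tf.cond()` was publicly available and is not referenced in this snippet.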
from tensorflow.python.ops import control_flow_ops
# Build the conditional op: `if_true` is used when `pred` is True, `if_false`
# otherwise.
result = tf.cond(pred, if_true, if_false)
# Run inside a `with` block so the session's resources are released on exit
# instead of staying allocated for the life of the process.
with tf.Session() as sess:
    sess.run(result, feed_dict={pred: True})  # ==> executes only (a x b)
    sess.run(result, feed_dict={pred: False})  # ==> executes only (b x c)
def is_false():
    """Build the reward-scaled loss and its training op.

    The base loss is a sparse softmax cross-entropy whose labels are the
    argmax (along axis 1) of the very logits being scored. `action_one` is
    used when `mode` equals 1, `action_two` otherwise.

    Returns:
        A 2-tuple `(loss, train_op)`: the base loss multiplied by
        `reward_score`, and whatever `optimize` returns for that loss and
        `self.global_step`.
    """

    def _self_labelled_loss(logits):
        # Labels come from the logits' own argmax, then feed the cross-entropy.
        return tf.nn.sparse_softmax_cross_entropy_with_logits(
            labels=tf.argmax(logits, 1),
            logits=logits,
        )

    use_first_action = tf.equal(mode, 1)
    base_loss = tf.cond(
        use_first_action,
        lambda: _self_labelled_loss(action_one),
        lambda: _self_labelled_loss(action_two),
    )
    scaled_loss = reward_score * base_loss
    train_op = optimize(scaled_loss, self.global_step)
    return scaled_loss, train_op
return tf.cond(tf.equal(mode,0),lambda:tf.zeros(shape=[]),tf.zeros(shape=[]),is_false)