tf.py_func灵活操作tensor

tensorflow中所有的tensor只是占位符,在没有用tf.Session().run接口填充值之前是没有实际值的,不能对其进行判值操作,如if ... else...等,在实际问题中,我们可能需要将一个tensor转换成numpy array 然后进行一些 np.的运算,然后返回tensor. 这样可以加强tensorflow的灵活性。

tf.py_func


其中func函数以numpy arrays 作为输入(或placeholder 需要feed),并以numpy arrays 作为输出。在函数中可自由使用针对 numpy arrays 的操作。

解释一下参数:
func: 是用户自定义函数,输入是numpy array 输出是numpy array
inp: 是func函数接受的输入,是一个列表
Tout: 指定numpy转化为tensor 后的形式

tf.py_func 返回值是一个tensor

注意:

tf.py_func中的func是脱离Graph的。在func中不能定义可训练的参数参与网络训练(反向传播)。

举个例子:

import tensorflow as tf

def add(x,y):
     return x+y,x-y,x.dot(y)

a = [[1,2],[3,4]]
b = [[1,2],[1,1]]
x = tf.placeholder(tf.float32,(2,2))
y = tf.placeholder(tf.float32,(2,2))
result1,result2,result3 = tf.py_func(add, [x,y], [tf.float32,tf.float32,tf.float32])

with tf.Session as sess:
    sess.run(tf.global_varbles_initializer())
    s1,s2,s3 = sess.run([result1,result2,result3],feed_dict = {x:a,y:b})
    print(s1,s2,s3)
最后编辑于
©著作权归作者所有,转载或内容合作请联系作者
平台声明:文章内容(如有图片或视频亦包括在内)由作者上传并发布,文章内容仅代表作者本人观点,简书系信息发布平台,仅提供信息存储服务。