[tf]使用tf.where进行选择,tf.greater(v1,v2)进行大小比较

  • tf.greater:输入时两个张量,这个函数会比较两个输入张量中每一个元素的大小,然后返回一个bool类型的tensor,如果两个张量的维度不一致的话,会进行类似Numpy一样的广播操作。
  • tf.where:函数有三个参数,根据第一个条件是否成立,当为True的时候选择第二个参数中的值,否则使用第三个参数中的值。
loss  = tf.reduce_sum (tf.where(tf.greater(v1,v2), (v1 - v2 ) * a, (v2 -v1) * b))
©著作权归作者所有,转载或内容合作请联系作者
平台声明:文章内容(如有图片或视频亦包括在内)由作者上传并发布,文章内容仅代表作者本人观点,简书系信息发布平台,仅提供信息存储服务。

推荐阅读更多精彩内容