tf.nn.embedding_lookup

tf.nn.embedding_lookup( params, ids, partition_strategy='mod', name=None, validate_indices=True, max_norm=None)

查找张量中序号为ids的
params:可以是张量,也可以是数组(embedding矩阵)
ids:

params = [[0, 0, 0, 0], [1, 2, 3, 4],[2, 3, 4, 5],[3, 4, 5, 6],[4, 5, 6, 7]]
params = np.asarray(params)
index = [[1, 2, 3, 4, 0], [3, 4, 2, 1, 0]]
t0 = tf.nn.embedding_lookup(params, [2, 1, 3, 4, 0])
t1 = tf.nn.embedding_lookup(params, index)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print(sess.run(t0))
    print('='*20)
    print(sess.run(t1))

[[2 3 4 5]
 [1 2 3 4]
 [3 4 5 6]
 [4 5 6 7]
 [0 0 0 0]]
====================
[[[1 2 3 4]
  [2 3 4 5]
  [3 4 5 6]
  [4 5 6 7]
  [0 0 0 0]]

 [[3 4 5 6]
  [4 5 6 7]
  [2 3 4 5]
  [1 2 3 4]
  [0 0 0 0]]]


©著作权归作者所有,转载或内容合作请联系作者
平台声明:文章内容(如有图片或视频亦包括在内)由作者上传并发布,文章内容仅代表作者本人观点,简书系信息发布平台,仅提供信息存储服务。

推荐阅读更多精彩内容