tf.nn.embedding_lookup( params, ids, partition_strategy='mod', name=None, validate_indices=True, max_norm=None)
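Looks up the rows of params at the positions given by ids, i.e. gathers params along its first dimension.
params: a tensor or array (the embedding matrix); in TF 1.x it may also be a list of tensors sharded along the first dimension, which is when partition_strategy matters.
ids: an integer tensor of indices into the first dimension of params. The result has shape ids.shape + params.shape[1:].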
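For a single params array (not a list of shards), the lookup behaves like NumPy integer-array indexing along axis 0. The sketch below is a NumPy reference written under that assumption. `lookup_reference` is a hypothetical helper, not TensorFlow code. Its `max_norm` branch mirrors the documented behavior of clipping each returned embedding whose L2 norm exceeds `max_norm`, and it assumes 2-D params so that each embedding is one row.

    import numpy as np


    def lookup_reference(params, ids, max_norm=None):
        """Hypothetical NumPy stand-in for tf.nn.embedding_lookup on one 2-D params array."""
        params = np.asarray(params)
        ids = np.asarray(ids)
        # Gather rows along axis 0: result shape is ids.shape + params.shape[1:].
        out = params[ids]
        if max_norm is not None:
            out = out.astype(np.float64)
            norms = np.linalg.norm(out, axis=-1, keepdims=True)
            # Same rule as clip-by-norm: rows with norm <= max_norm keep scale 1,
            # longer rows are shrunk so their norm becomes exactly max_norm.
            out = out * (max_norm / np.maximum(norms, max_norm))
        return out


    params = np.array([[0, 0, 0, 0], [1, 2, 3, 4], [2, 3, 4, 5], [3, 4, 5, 6], [4, 5, 6, 7]])
    print(lookup_reference(params, [[1, 2], [0, 4]]).shape)  # (2, 2, 4)
    # First row has norm 5 and is shrunk to norm 1.0; second row (norm 0.5) is unchanged.
    print(lookup_reference([[3.0, 4.0], [0.3, 0.4]], [0, 1], max_norm=1.0))

The shape rule is the useful takeaway: a 1-D ids gives a matrix of rows, while a 2-D ids of shape (batch, length) gives a 3-D tensor of shape (batch, length, embedding_dim).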
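import numpy as np
import tensorflow as tf  # TensorFlow 1.x API (uses tf.Session)

# A 5 x 4 embedding matrix: row i is the embedding vector for id i.
params = [[0, 0, 0, 0], [1, 2, 3, 4], [2, 3, 4, 5], [3, 4, 5, 6], [4, 5, 6, 7]]
params = np.asarray(params)
index = [[1, 2, 3, 4, 0], [3, 4, 2, 1, 0]]
t0 = tf.nn.embedding_lookup(params, [2, 1, 3, 4, 0])  # 1-D ids -> shape (5, 4)
t1 = tf.nn.embedding_lookup(params, index)  # 2-D ids of shape (2, 5) -> shape (2, 5, 4)
with tf.Session() as sess:
    # No variables are created here, so the initializer is a harmless no-op.
    sess.run(tf.global_variables_initializer())
    print(sess.run(t0))
    print('=' * 20)
    print(sess.run(t1))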
Output:
[[2 3 4 5]
 [1 2 3 4]
 [3 4 5 6]
 [4 5 6 7]
 [0 0 0 0]]
====================
[[[1 2 3 4]
  [2 3 4 5]
  [3 4 5 6]
  [4 5 6 7]
  [0 0 0 0]]

 [[3 4 5 6]
  [4 5 6 7]
  [2 3 4 5]
  [1 2 3 4]
  [0 0 0 0]]]
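When params is a list of shards rather than one tensor, partition_strategy decides which shard holds each id. Under the default 'mod' strategy, as documented for TF 1.x, id p lives in shard p % n at local row p // n, where n is the number of shards (the 'div' strategy, which assigns contiguous id ranges, is not sketched here). The NumPy sketch below only illustrates that routing. `lookup_mod` is a hypothetical helper, and it assumes every shard has the same trailing dimensions and that ids is non-empty.

    import numpy as np


    def lookup_mod(shards, ids):
        """Hypothetical sketch of 'mod'-partitioned lookup over a list of shards."""
        n = len(shards)
        ids = np.asarray(ids)
        # Under 'mod', id p is stored in shard p % n at local row p // n.
        rows = [np.asarray(shards[p % n])[p // n] for p in ids.reshape(-1)]
        return np.stack(rows).reshape(ids.shape + np.asarray(shards[0]).shape[1:])


    table = np.array([[0, 0, 0, 0], [1, 2, 3, 4], [2, 3, 4, 5], [3, 4, 5, 6], [4, 5, 6, 7]])
    # Lay the 5-row table out across 2 shards the way 'mod' expects:
    # shard 0 holds ids 0, 2, 4 and shard 1 holds ids 1, 3.
    shards = [table[0::2], table[1::2]]
    index = [[1, 2, 3, 4, 0], [3, 4, 2, 1, 0]]
    print(np.array_equal(lookup_mod(shards, index), table[np.asarray(index)]))  # True

Sharded this way, the lookup returns the same (2, 5, 4) result as the single-matrix example, which is the point of the 'mod' layout: callers keep using global ids while storage is spread across shards.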