tf.nn.embedding_lookup的用法主要是选取一个张量里面索引对应的元素
原型:tf.nn.embedding_lookup(params, ids, partition_strategy='mod', name=None, validate_indices=True, max_norm=None)
params 代表输入的张量,ids代表要选取params里对应的那个维度的数据
简单来个例子(粘贴可直接运行)
import tensorflow as tf
import numpy as np
a = [[0.1, 0.2, 0.3], [1.1, 1.2, 1.3], [2.1, 2.2, 2.3], [3.1, 3.2, 3.3], [4.1, 4.2, 4.3]]
a = np.asarray(a)
idx1 = tf.Variable([0, 2, 3, 1], tf.int32)
idx2 = tf.Variable([[0, 2, 3, 1], [4, 0, 2, 2]], tf.int32)
b = [[0.1, 0.2, 1], [2.1, 1.2, 1]]
b = np.asarray(b)
idx3 = tf.placeholder(tf.int32, [None, 3], name="input_x")
out1 = tf.nn.embedding_lookup(a, idx1)
out2 = tf.nn.embedding_lookup(a, idx2)
out3 = tf.nn.embedding_lookup(a, idx3)
init = tf.global_variables_initializer()
with tf.Session() as sess:
sess.run(init)
print (sess.run(out1))
print (out1)
print ('==================')
print (sess.run(out2))
print (out2)
print (sess.run(out3, feed_dict ={idx3: b}))
print (out3)
其输入内容为
咱们一个一个分析
1.第一个out1代表从a中依次取第 0, 2, 3, 1维数据进行拼装,拼出来的shape还是(4,3)
2.第二个out2代表从a中依次取 第0, 2, 3, 1维数据拼装一个(4,3)的数据 接着再从a中依次取4, 0, 2, 2 维来进行拼装,之后再把两个(4, 3) 拼装在一起形成(2,4,3)的张量(tensor)
3.第三个使用了placeholder来输入ids,placeholder的shape为(?,3),代表从数据里先取3个数据出来,每个数据有3个元素,最后再 ?个(3, 3)拼接在一起组成(?,3,3)的tensor
自己多动手多跑跑例子就可以了。
如有问题欢迎大家指正,谢谢