tensorflow学习-embedding_lookup()用法

embedding_lookup( )的用法
关于tensorflow中embedding_lookup( )的用法,在Udacity的word2vec会涉及到,本文将通俗的进行解释

#!/usr/bin/env/python
# coding=utf-8
import tensorflow as tf
import numpy as np

input_ids = tf.placeholder(dtype=tf.int32, shape=[None])

embedding = tf.Variable(np.identity(5, dtype=np.int32))
input_embedding = tf.nn.embedding_lookup(embedding, input_ids)

sess = tf.InteractiveSession()
sess.run(tf.global_variables_initializer())
print(embedding.eval())
print(sess.run(input_embedding, feed_dict={input_ids:[1, 2, 3, 0, 3, 2, 1]}))

代码中先使用palceholder定义了一个未知变量input_ids用于存储索引,和一个已知变量embedding,是一个5*5的对角矩阵。
运行结果为:

embedding = [[1 0 0 0 0]
             [0 1 0 0 0]
             [0 0 1 0 0]
             [0 0 0 1 0]
             [0 0 0 0 1]]
input_embedding = [[0 1 0 0 0]
                   [0 0 1 0 0]
                   [0 0 0 1 0]
                   [1 0 0 0 0]
                   [0 0 0 1 0]
                   [0 0 1 0 0]
                   [0 1 0 0 0]]

简单的讲就是根据input_ids中的id,寻找embedding中的对应元素。比如,input_ids=[1,3,5],则找出embedding中下标为1,3,5的向量组成一个矩阵返回。
如果将input_ids改写成下面的格式:

input_embedding = tf.nn.embedding_lookup(embedding, input_ids)
print(sess.run(input_embedding, feed_dict={input_ids:[[1, 2], [2, 1], [3, 3]]}))

输出结果就会变成如下的格式:

[[[0 1 0 0 0]
  [0 0 1 0 0]]
 [[0 0 1 0 0]
  [0 1 0 0 0]]
 [[0 0 0 1 0]
  [0 0 0 1 0]]]

对比上下两个结果不难发现,相当于在np.array中直接采用下标数组获取数据。需要注意的细节是返回的tensor的dtype和传入的被查询的tensor的dtype保持一致;和ids的dtype无关。

©著作权归作者所有,转载或内容合作请联系作者
平台声明:文章内容(如有图片或视频亦包括在内)由作者上传并发布,文章内容仅代表作者本人观点,简书系信息发布平台,仅提供信息存储服务。

推荐阅读更多精彩内容

  • 1. tf函数 tensorflow 封装的工具类函数 | 操作组 | 操作 ||:-------------| ...
    南墙已破阅读 5,326评论 0 5
  • TF API数学计算tf...... :math(1)刚开始先给一个运行实例。tf是基于图(Graph)的计算系统...
    MachineLP阅读 3,669评论 0 1
  • 昨天还在低头苦算二重积分,今天就在挤地铁,发邮件。昨天周围的人还是单纯善良的同学朋友,你们之间真心总是能够换取真心...
    冯枝繁阅读 189评论 0 0
  • 最近,兰子夫妇开发了他们的第二职业——伟大的保险事业。 我不喜欢保险,特别讨厌。或许是因为一次有点奇葩的购买留下阴...
    小鸟笑笑阅读 542评论 0 3
  • 第一次接触到“团队”这个词,是我爸在跟我讲他自己在做生意时,带部下出去参加“拓展训练”的故事。 拓展训练,是一项在...
    皮卡弟弟阅读 461评论 0 1