tf.nn.embedding_lookup() 详解

tf.nn.embedding_lookup() 的用途主要是选取一个张量里面索引对应的元素。

原理
  假设一共有m个物体,每个物体有自己唯一的id,那么从物体的集合到R^m有一个trivial的嵌入,就是把它映射到R^m中的标准基,这种嵌入叫做 One-hot embedding/encoding.
  应用中一般将物体嵌入到一个低维空间R^n(n << m),只需要在compose上一个从R^mR^n的线性映射就好了。每一个n*m的矩阵M都定义了一个从R^mR^n的线性映射:x \mapsto Mx。当x是一个标准基向量的时候Mx对应矩阵M中的一列,这就是对应id的向量表示。这个概念用神经网络图来表示如下:


从id(索引)找到对应的One-hot encoding,然后红色的weight就直接对应了输出节点的值(注意这里没有activation function),也就是对应的embedding向量。

函数原型:

tf.nn.embedding_lookup(
     params,
     ids,
     partition_strategy='mod',
     name=None,
     validate_indices=True,
     max_norm=None
)
  • params:由一个tensor或者多个tensor组成的列表(多个tensor组成时,每个tensor除了第一个维度其他维度需相等);
  • ids:一个类型为int32或int64的Tensor,包含要在params中查找的id;
  • partition_strategy:逻辑index是由partition_strategy指定,partition_strategy用来设定ids的切分方式,目前有两种切分方式’div’和’mod’.
  • name:操作名称(可选)
  • validate_indices: 是否验证收集索引
  • max_norm: 如果不是None,嵌入值将被l2归一化为max_norm的值

返回值是一个dense tensor。返回的shape为shape(ids)+shape(params)[1:]


实际上tf.nn.embedding_lookup的作用就是找到要寻找的embedding data中的对应的行下的vector。

import numpy as np
import tensorflow as tf
data = np.array([[[2],[1]],[[3],[4]],[[6],[7]]])
data = tf.convert_to_tensor(data)
lk = [[0,1],[1,0],[0,0]]
lookup_data = tf.nn.embedding_lookup(data, lk)
init = tf.global_variables_initializer()

先让我们看下不同数据对应的维度:

In [76]: data.shape
Out[76]: (3, 2, 1)
In [77]: np.array(lk).shape
Out[77]: (3, 2)
In [78]: lookup_data
Out[78]: <tf.Tensor 'embedding_lookup_8:0' shape=(3, 2, 2, 1) dtype=int64>

这个是怎么做到的呢?关键的部分来了,看下图:



lk中的值,在要寻找的embedding数据中找对应的index下的vector进行拼接。永远是look(lk)部分的维度+embedding(data)部分的除了第一维后的维度拼接。很明显,我们也可以得到,lk里面值是必须要小于等于embedding(data)的最大维度减一的。

以上的结果就是:

In [79]: data
Out[79]:
array([[[2],
        [1]],

       [[3],
        [4]],

       [[6],
        [7]]])

In [80]: lk
Out[80]: [[0, 1], [1, 0], [0, 0]]

# lk[0]也就是[0,1]对应着下面sess.run(lookup_data)的结果恰好是把data中的[[2],[1]],[[3],[4]]

In [81]: sess.run(lookup_data)
Out[81]:
array([[[[2],
         [1]],

        [[3],
         [4]]],


       [[[3],
         [4]],

        [[2],
         [1]]],


       [[[2],
         [1]],

        [[2],
         [1]]]])

最后,partition_strategy是用于当len(params) > 1,params的元素分割不能整分的话,则前(max_id + 1) % len(params)多分一个id.
当partition_strategy = 'mod'的时候,13个ids划分为5个分区:[[0, 5, 10], [1, 6, 11], [2, 7, 12], [3, 8], [4, 9]],也就是是按照数据列进行映射,然后再进行look_up操作。
当partition_strategy = 'div'的时候,13个ids划分为5个分区:[[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10], [11, 12]],也就是是按照数据先后进行排序标序,然后再进行look_up操作。

参考:https://www.jianshu.com/p/abea0d9d2436

©著作权归作者所有,转载或内容合作请联系作者
平台声明:文章内容(如有图片或视频亦包括在内)由作者上传并发布,文章内容仅代表作者本人观点,简书系信息发布平台,仅提供信息存储服务。

推荐阅读更多精彩内容

  • 6月28日 坚持分享第158天 看了如何与压力做朋友后,分享给朋友们视频开篇,从业十多年的健康心理学家就...
    周老师成长记录仪阅读 197评论 0 0
  • 我住的宿舍中华上下5000年中的《将相和》。战国时,蔺相如因完璧归赵有功。被赵王封为大夫。有一次他带什么课坐车出门...
    邵铭琳阅读 182评论 0 0
  • 读狂人日记 文/小月二水 狂人日记是一份喧嚣的孤寂,在寒冷中祭奠温热的情怀 一个狂人至死不休欢唱生命 当他一面面撕...
    小月二水阅读 190评论 0 1