pytorch.nn.Embadding 详解

数据和枚举的对应关系:{A:1, B:2, C:3, ...}

网络层输入按照枚举方式,比如是A的话,那么输入层就是
A: [1, 0, 0, 0, 0, 0, ...], 如果是B,输入就是:
B: [0, 1, 0, 0, 0, 0, ...], 依次类推:
C: [0, 0, 1, 0, 0, 0, ...]
有多少枚举,就有多少个输入。

从输入到Hidden层,因为只有一个1,其他都是0,如下图:


其实没必要把所有输入i*w+b都计算了。因为其余都是0,只计算所选的那个i就好了。

这就是 nn.Embadding(num_embaddings, num_dim)的意义。

  • num_embaddings就是枚举的个数,也是输入节点数,他会根据输入自动转换为枚举,比如输入2,输入为[0,0,1,0,0,0, \cdots]
  • num_dim是hidden层的数量。
  • padding_idx 就是说这个index不用,作为补齐的。那么遇到这个index,所有输入都是[0,\cdots],就等于什么也不运算。
©著作权归作者所有,转载或内容合作请联系作者
【社区内容提示】社区部分内容疑似由AI辅助生成,浏览时请结合常识与多方信息审慎甄别。
平台声明:文章内容(如有图片或视频亦包括在内)由作者上传并发布,文章内容仅代表作者本人观点,简书系信息发布平台,仅提供信息存储服务。

友情链接更多精彩内容