数据和枚举的对应关系:{A:1, B:2, C:3, ...}
网络层输入按照枚举方式,比如是A的话,那么输入层就是
A: [1, 0, 0, 0, 0, 0, ...], 如果是B,输入就是:
B: [0, 1, 0, 0, 0, 0, ...], 依次类推:
C: [0, 0, 1, 0, 0, 0, ...]
有多少枚举,就有多少个输入。
从输入到Hidden层,因为只有一个1,其他都是0,如下图:
其实没必要把所有输入都计算了。因为其余都是0,只计算所选的那个就好了。
这就是 nn.Embadding(num_embaddings, num_dim)的意义。
- num_embaddings就是枚举的个数,也是输入节点数,他会根据输入自动转换为枚举,比如输入2,输入为。
- num_dim是hidden层的数量。
- padding_idx 就是说这个index不用,作为补齐的。那么遇到这个index,所有输入都是,就等于什么也不运算。