one-hot encoding
import torch
import torch.nn as nn
class LinearEmbedding(nn.Module):
def __init__(self, vocab_size, embedding_dim):
super(LinearEmbedding, self).__init__()
self.embed = nn.Linear(vocab_size, embedding_dim, bias=False)
def forward(self, input):
# Create one-hot encoding
one_hot = torch.nn.functional.one_hot(input, num_classes=self.embed.in_features)
return self.embed(one_hot.float())
直接输入小数
import torch
import torch.nn as nn
class LinearEmbedding(nn.Module):
def __init__(self, embedding_dim):
super(LinearEmbedding, self).__init__()
self.embed = nn.Linear(1, embedding_dim, bias=False)
def forward(self, input):
# Convert input to float
input = input.float().unsqueeze(-1) # Add extra dimension for nn.Linear
return self.embed(input)
是的,你可以使用 nn.Linear(1, embedding_dim) 来创建一个类似嵌入层的结构,而不需要将输入视为独热编码。但是,有几点需要注意:
nn.Linear(1, embedding_dim) 的输入应该是一个浮点张量。如果你的输入索引是整数(通常单词索引是这样),你需要将它们转换为浮点数。
nn.Linear(1, embedding_dim) 层将为其在输入中遇到的每个唯一浮点值学习一个单独的嵌入。如果你的词汇量很大,这可能会导致参数数量非常大。
与 nn.Embedding 学习整数索引的嵌入不同,nn.Linear(1, embedding_dim) 会在输入值之间进行插值。例如,如果它学习了值 1.0 和 2.0 的嵌入,它也可以生成值 1.5 的嵌入,作为 1.0 和 2.0 的嵌入的线性组合。这对于单词嵌入可能不太理想,因为每个单词都是一个独立的实体,单词之间没有固有的顺序或连续性。