Model
Four variants
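- CNN-rand: word vectors are randomly initialized and updated as the model trains.
- CNN-static: pretrained word vectors are used and kept fixed during training.
- CNN-non-static: pretrained word vectors are used and fine-tuned during training.
- CNN-multichannel: static + fine-tuned. Both channels are initialized with the pretrained word vectors, every filter is applied to both channels, but backpropagation updates only one of them.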
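The multichannel variant is the least obvious of the four. As a minimal sketch, using random vectors as a stand-in for real pretrained embeddings (an assumption), the snippet below builds a fine-tuned and a frozen embedding from the same vectors, stacks them as the two input channels of one Conv2d, takes a single SGD step, and checks that only the fine-tuned copy changed. Each copy is cloned because nn.Parameter shares storage with the tensor it wraps, so two embeddings built from the very same tensor would be updated together.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical stand-in for pretrained word vectors: 10 words x 4 dimensions.
vocab_size, embed_dim = 10, 4
pretrained = torch.randn(vocab_size, embed_dim)

# Channel 1 is fine-tuned, channel 2 stays static. Each gets its own copy
# (clone) because nn.Parameter shares storage with the tensor it wraps.
tuned = nn.Embedding.from_pretrained(pretrained.clone(), freeze=False)
static = nn.Embedding.from_pretrained(pretrained.clone(), freeze=True)

# One filter of height 2 spanning the full embedding width, over both channels.
conv = nn.Conv2d(in_channels=2, out_channels=1, kernel_size=(2, embed_dim))
optimizer = torch.optim.SGD(list(tuned.parameters()) + list(conv.parameters()), lr=0.1)

tokens = torch.tensor([[1, 2, 3, 4, 5]])                 # (N=1, seq_len=5)
x = torch.stack([tuned(tokens), static(tokens)], dim=1)  # (1, 2, 5, 4)
out = conv(x)                                            # (1, 1, 4, 1)
out.sum().backward()
optimizer.step()

tuned_changed = not torch.equal(tuned.weight, pretrained)
static_changed = not torch.equal(static.weight, pretrained)
print(tuple(x.shape), tuple(out.shape))  # (1, 2, 5, 4) (1, 1, 4, 1)
print("tuned changed:", tuned_changed)
print("static changed:", static_changed)
```

The frozen embedding still contributes to the convolution's output; it simply receives no gradient and is not in the optimizer, so only the fine-tuned channel drifts away from the pretrained values.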
Code
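args.vocabulary_size = len(text_field.vocab)  # read by TextCNN below
# multichannel implies static + non_static, so resolve it first; otherwise
# embedding_dim and vectors are never set when only multichannel is given
if args.multichannel:
    args.static = True
    args.non_static = True
if args.static:  # use pretrained word vectors
    args.embedding_dim = text_field.vocab.vectors.size()[-1]
    args.vectors = text_field.vocab.vectors
# args.class_num = len(label_field.vocab)
args.class_num = len(label_field.vocab) - 1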
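How the four variants map onto the three flags can be summarized with a small lookup. resolve_flags is a hypothetical helper written only for illustration, not part of the original project; it restates what the flag logic above and the TextCNN constructor below do with static, non_static and multichannel.

```python
from types import SimpleNamespace


def resolve_flags(variant: str) -> SimpleNamespace:
    """Hypothetical helper: map a variant name to the static / non_static /
    multichannel flags read by the config code and by TextCNN."""
    table = {
        "CNN-rand": (False, False, False),       # random init, trained
        "CNN-static": (True, False, False),      # pretrained, frozen
        "CNN-non-static": (True, True, False),   # pretrained, fine-tuned
        "CNN-multichannel": (True, True, True),  # fine-tuned + frozen channel
    }
    static, non_static, multichannel = table[variant]
    return SimpleNamespace(static=static, non_static=non_static,
                           multichannel=multichannel)


for name in ["CNN-rand", "CNN-static", "CNN-non-static", "CNN-multichannel"]:
    flags = resolve_flags(name)
    # With static set, the first embedding is built with freeze=not non_static;
    # without it, the embedding is randomly initialized and trainable.
    first_frozen = flags.static and not flags.non_static
    print(f"{name}: {vars(flags)}, first embedding frozen={first_frozen}")
```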
import torch
import torch.nn as nn
import torch.nn.functional as F
class TextCNN(nn.Module):
    def __init__(self, args):
        super().__init__()
        self.args = args
        class_num = args.class_num
        channel_num = 1
        filter_num = args.filter_num
        filter_sizes = args.filter_sizes
        vocabulary_size = args.vocabulary_size
        embedding_dimension = args.embedding_dim

        if args.static:
            # pretrained vectors; frozen unless non_static (fine-tuning) is set
            self.embedding = nn.Embedding.from_pretrained(args.vectors, freeze=not args.non_static)
        else:
            # CNN-rand: randomly initialized, trainable
            self.embedding = nn.Embedding(vocabulary_size, embedding_dimension)
        if args.multichannel:
            # multichannel: non_static=True and static=True
            # channel 1: fine-tuned (self.embedding)
            # channel 2: static (from_pretrained freezes by default); clone() so it
            # does not share storage with channel 1, which the optimizer updates in place
            self.embedding2 = nn.Embedding.from_pretrained(args.vectors.clone())
            channel_num += 1
        else:
            self.embedding2 = None
        # ModuleList is a special Module that holds several sub-modules and can be
        # indexed and iterated like a list, but an input cannot be passed to it directly.
        # Each Conv2d: (N, C_in, seq_len, embedding_dim) => (N, filter_num, seq_len - size + 1, 1)
        self.convs = nn.ModuleList(
            [nn.Conv2d(channel_num, filter_num, (size, embedding_dimension)) for size in filter_sizes])
        self.dropout = nn.Dropout(args.dropout)
        self.fc = nn.Linear(len(filter_sizes) * filter_num, class_num)
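    def forward(self, x):
        # x: (N, seq_len) word indices
        if self.embedding2 is not None:
            # multichannel: stack both embeddings as channels => (N, 2, seq_len, embedding_dim)
            x = torch.stack([self.embedding(x), self.embedding2(x)], dim=1)
        else:
            # torch.unsqueeze() inserts a dimension of size 1 at the given position.
            # Add the channel dimension: (N, seq_len, embedding_dim) =>
            # (N, 1, seq_len, embedding_dim)
            x = self.embedding(x).unsqueeze(1)
        # convolve, then drop the width-1 dimension: (N, filter_num, seq_len - size + 1)
        x = [F.relu(conv(x)).squeeze(3) for conv in self.convs]
        # max-over-time pooling, then drop the length-1 dimension: (N, filter_num)
        # torch.squeeze() removes dimensions of size 1
        x = [F.max_pool1d(item, item.size(2)).squeeze(2) for item in x]
        # concatenate all filter sizes, e.g. 3 sizes x 100 (filter_num) values each
        # => (N, len(filter_sizes) * filter_num)
        x = torch.cat(x, 1)
        x = self.dropout(x)
        logits = self.fc(x)
        return logits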
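To check the shape comments in forward, the following sketch replays the single-channel path step by step with small, arbitrary sizes (N=2, seq_len=7, filter sizes 3/4/5, filter_num=4, class_num=2; all assumptions chosen for readability, not the project's settings). It also confirms that F.max_pool1d with a kernel as long as the sequence is the same as taking the maximum over the time dimension.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

# Illustrative sizes (assumptions, not values from the original project).
N, seq_len, vocab_size, embed_dim = 2, 7, 20, 5
filter_sizes, filter_num, class_num = [3, 4, 5], 4, 2

embedding = nn.Embedding(vocab_size, embed_dim)
convs = nn.ModuleList([nn.Conv2d(1, filter_num, (k, embed_dim)) for k in filter_sizes])
fc = nn.Linear(len(filter_sizes) * filter_num, class_num)

tokens = torch.randint(0, vocab_size, (N, seq_len))
x = embedding(tokens).unsqueeze(1)  # (N, 1, seq_len, embed_dim)

conv_shapes, pooled = [], []
for conv in convs:
    c = F.relu(conv(x)).squeeze(3)             # (N, filter_num, seq_len - k + 1)
    p = F.max_pool1d(c, c.size(2)).squeeze(2)  # (N, filter_num)
    # max_pool1d over the whole length is max-over-time pooling
    assert torch.equal(p, c.max(dim=2).values)
    conv_shapes.append(tuple(c.shape))
    pooled.append(p)

features = torch.cat(pooled, 1)  # (N, len(filter_sizes) * filter_num)
logits = fc(features)            # (N, class_num)
print(conv_shapes)                                 # [(2, 4, 5), (2, 4, 4), (2, 4, 3)]
print(tuple(features.shape), tuple(logits.shape))  # (2, 12) (2, 2)
```

Each filter size yields a different feature-map length, and pooling collapses all of them to filter_num values, which is why the fully connected layer's input size depends only on len(filter_sizes) * filter_num and not on the sentence length.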
Questions
- Why does the training code shift the targets with target = target.data.sub(1)?
- Why is len(label_field.vocab) equal to 3?
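A likely answer to both questions, assuming the label field was built with torchtext's legacy Field(sequential=False), whose default unk_token is "<unk>", and that the task is binary: the label vocabulary is ["<unk>", label_a, label_b], so its length is 3 and real labels are numericalized to 1 and 2. A cross-entropy loss over class_num = 2 logits expects targets 0 and 1, which is exactly what sub(1) produces, and it also explains the "- 1" in class_num. torchtext's LabelField sets unk_token=None and avoids the offset. The plain-Python simulation below mimics that vocabulary; the label names are hypothetical and the alphabetical ordering simplifies torchtext's frequency-based ordering.

```python
# Plain-Python simulation (no torchtext needed) of a label vocabulary whose
# index 0 is reserved for "<unk>", as a Field with the default unk_token builds it.
labels = ["negative", "positive", "positive", "negative"]  # hypothetical data

# Ordering simplified to alphabetical; torchtext sorts by frequency, then alphabetically.
itos = ["<unk>"] + sorted(set(labels))  # ['<unk>', 'negative', 'positive']
stoi = {s: i for i, s in enumerate(itos)}

vocab_len = len(itos)                   # 3 for a binary task
class_num = vocab_len - 1               # 2 real classes
targets = [stoi[y] for y in labels]     # real labels never map to index 0
shifted = [t - 1 for t in targets]      # what target.sub(1) does

print(vocab_len, class_num, targets, shifted)  # 3 2 [1, 2, 2, 1] [0, 1, 1, 0]
```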