- 构建词库(字典) — build the vocabulary (token → index dictionary)
def build_vocab(file_path, tokenizer, max_size, min_freq):
    """Build a token -> index vocabulary from a tab-separated corpus file.

    Each line of the file is expected to hold the text in its first
    tab-separated field (``line.split('\t')[0]``); blank lines are skipped.

    Args:
        file_path: path to the UTF-8 corpus file, one example per line.
        tokenizer: callable that splits a text string into tokens.
        max_size: keep at most this many most-frequent tokens.
        min_freq: drop tokens seen fewer than this many times.

    Returns:
        dict mapping token -> integer index, ordered by descending
        frequency, with the UNK and PAD sentinel tokens appended at the
        two highest indices.
    """
    from collections import Counter  # stdlib; imported locally as the file's import block is not visible here

    counts = Counter()
    with open(file_path, 'r', encoding='UTF-8') as f:
        for line in tqdm(f):  # tqdm only wraps iteration for a progress bar
            lin = line.strip()
            if not lin:
                continue
            content = lin.split('\t')[0]
            counts.update(tokenizer(content))  # count token occurrences
    # Keep only tokens at or above min_freq, most frequent first, capped at
    # max_size.  sorted() is stable, so ties keep first-seen order.
    vocab_list = sorted(
        (item for item in counts.items() if item[1] >= min_freq),
        key=lambda x: x[1],
        reverse=True,
    )[:max_size]
    # Map each surviving token to its rank index.
    vocab_dic = {word: idx for idx, (word, _count) in enumerate(vocab_list)}
    # Append the unknown-token and padding sentinels at the end.
    # NOTE(review): assumes UNK and PAD are module-level constants — confirm.
    vocab_dic.update({UNK: len(vocab_dic), PAD: len(vocab_dic) + 1})
    return vocab_dic
- 创建 Dataset（数据集） — create the dataset wrapper
class get_Dataset():
    """Minimal dataset wrapper: holds pre-segmented examples and a vocabulary.

    Implements the ``__len__`` / ``__getitem__`` protocol so it can be
    consumed by a data loader.
    """

    def __init__(self, idx, all_data):
        """Store the examples and build the vocabulary.

        Args:
            idx: identifier for this split/shard (stored as-is).
            all_data: sequence of pre-segmented examples to serve.
        """
        # BUG FIX: the original never stored its arguments, so
        # self.seg_data was undefined and __len__/__getitem__ raised
        # AttributeError.  Store them here.
        self.idx = idx
        self.seg_data = all_data
        # NOTE(review): rebuilding the vocab per dataset instance repeats a
        # full corpus pass — consider sharing one vocab across instances.
        self.vocab = build_vocab(config.train_path, tokenizer=tokenizer, max_size=MAX_VOCAB_SIZE, min_freq=1)

    def __len__(self):
        # Number of stored examples.
        return len(self.seg_data)

    def __getitem__(self, item):
        # Return the example at position ``item``.
        return self.seg_data[item]
# Module-level vocabulary built from the training corpus.
# NOTE(review): this duplicates the build done inside get_Dataset.__init__ —
# consider building once and passing it in.
vocab = build_vocab(config.train_path, tokenizer=tokenizer, max_size=MAX_VOCAB_SIZE, min_freq=1)