背景
上一篇文章中介绍了如何将word和md导出为Json, 这是第一步,有了Json才能给NLTK模型进行匹配,调试,训练。
设计
NLTK的设计分为以下几个步骤:
- 读取faq.json数据集
- 构建问题字典
- 文本预处理,分句,分词,去除停用词,词干提取
- 匹配问题函数,使用BM25召回算法,使用最高分进行问题回复
- 返回配置的回复
- 封装CRUD数据库,使用脚本将大Json写入到DB里,对DB操作大大提高查询性能
- 封装Python对外 Public API,供业务层使用
- 部署本机环境
- 编写一个简单的html进行测试
1.文本预处理,将文本拆分为句子列表
def split_sentences(text: str) -> List[str]:
    """Split *text* into a list of sentences.

    Sentences are delimited by '.', '!', '?' or newlines.

    Fix: ``re.split`` emits empty fragments for trailing/consecutive
    delimiters (e.g. "a." -> ['a', '']); whitespace-only fragments are
    now filtered out so callers only see real sentences.
    """
    return [part for part in re.split(r'[.!?\n]+', text) if part.strip()]
2.对文本进行预处理,包括分词,停用词删除
def preprocess_text(text: str) -> List[str]:
    """Tokenize *text* (lower-cased) and filter the tokens.

    Drops English stopwords, non-alphanumeric tokens, and tokens that are
    leftover HTML/markup artifacts from the exported FAQ content.
    """
    markup_leftovers = {'p', 'ol', 'li', 'img', 'https', 'strong'}
    stop_words = set(stopwords.words('english'))
    kept = []
    for token in word_tokenize(text.lower()):
        if token in stop_words:
            continue
        if not token.isalnum():
            continue
        if token in markup_leftovers:
            continue
        kept.append(token)
    return kept
3.使用SnowballStemmer 对单词进行词干提取
def stem_words(words: List[str]) -> List[str]:
    """Reduce each word to its stem via the English SnowballStemmer."""
    snowball = SnowballStemmer("english")
    return list(map(snowball.stem, words))
4.匹配问题,使用BM25相关性计算得分
def bm25_score(query: List[str], doc: List[str], question_list: List[dict],
               k1: float = 1.5, b: float = 0.75, avgdl: float = 300) -> float:
    """Compute the BM25 relevance score of *doc* against *query*.

    :param query: preprocessed/stemmed query terms
    :param doc: preprocessed/stemmed document terms
    :param question_list: full FAQ corpus (dicts with a 'question' key),
        used to estimate per-term document frequency
    :param k1: term-frequency saturation parameter
    :param b: length-normalisation parameter
    :param avgdl: assumed average document length (was a hard-coded 300;
        now a parameter with the same default, so behavior is unchanged)
    :return: BM25 score (0.0 for an empty query)

    NOTE(review): document frequency is estimated by *substring* membership
    of the (stemmed) term in each raw question text — stemmed terms are
    compared against unstemmed text; confirm this mismatch is intentional.
    """
    score = 0.0
    doc_length = len(doc)  # hoisted: reused for every term's normalisation
    term_freq = {term: doc.count(term) for term in query}  # precomputed tf
    n_questions = len(question_list)
    for term in query:
        tf = term_freq[term]
        # df: how many corpus questions contain the term as a substring
        df = sum(1 for q in question_list if term in q['question'])
        if df >= n_questions:
            # term is in every question: zero information, skip idf
            idf = 0
        else:
            idf = log((n_questions - df + 0.5) / (df + 0.5))
        score += idf * (tf * (k1 + 1)) / (tf + k1 * (1 - b + b * doc_length / avgdl))
    return score
5. 匹配问题逻辑
匹配问题并返回答案,业务API
- 优先匹配keyWords,
- 当keyWords匹配结果是None,则使用question匹配最高分
- 当前两个步骤都没匹配到,则最终匹配answer中的关键词,这里就会存在较大误差
- 最终任何一个都没匹配到则返回None
def match_question(query: str, q_platform: str) -> str:
    """Match *query* against the FAQ set for *q_platform* and return an answer.

    Matching is tiered:
      1. BM25 against each entry's ``keywords``.
      2. BM25 against the ``question`` text; when BM25 yields no signal,
         fall back to requiring every query word to appear in the question.
      3. BM25 against the ``answer`` text itself (widest recall, most error).
    Returns a canned apology when nothing matches.

    Fix: the original annotation claimed ``Tuple[str, str]`` but the function
    always returns a single string; the annotation is corrected to ``str``.
    """
    query_words = stem_words(preprocess_text(query))
    question_arr = get_question_dict(q_platform)

    # Tier 1: keywords
    best_answer = _best_answer_by_field(query_words, question_arr, 'keywords')
    if best_answer is None:
        # Tier 2: question text (+ exact containment fallback)
        best_answer = _match_by_question(query_words, question_arr)
    if best_answer is None:
        # Tier 3: answer text
        best_answer = _best_answer_by_field(query_words, question_arr, 'answer')
    return best_answer or "Sorry, I couldn't find a suitable answer."


def _best_answer_by_field(query_words: List[str], question_arr: List[dict],
                          field: str):
    """Return the answer of the entry whose *field* text has the highest
    positive BM25 score, or None when no entry scores above zero."""
    best_score = 0
    best_answer = None
    for entry in question_arr:
        doc_words = stem_words(preprocess_text(entry[field]))
        doc_words = [w for w in doc_words if w.isalnum()]  # drop symbol tokens
        score = bm25_score(query_words, doc_words, question_arr)
        # only keep entries that actually carry an answer
        if score > best_score and entry['answer']:
            best_score = score
            best_answer = entry['answer']
    return best_answer


def _match_by_question(query_words: List[str], question_arr: List[dict]):
    """BM25 over the raw question text; when a question scores zero but
    contains every query word, its answer wins immediately."""
    best_score = 0
    best_answer = None
    for entry in question_arr:
        doc_words = stem_words(preprocess_text(entry['question']))
        score = bm25_score(query_words, doc_words, question_arr)
        if score == 0:
            if all(word in doc_words for word in query_words):
                return entry['answer']  # exact containment short-circuits
        elif score > best_score and entry['answer']:
            best_score = score
            best_answer = entry['answer']
    return best_answer
6.封装一些CRUD DB方法,用来进行1级问题,2级问题,3级问题检索
def get_answer_by_id(question_ids: str, platform: str) -> List[Dict[str, Union[str, int]]]:
    """Look up answers for a comma-separated string of question IDs.

    :param question_ids: e.g. ``"1,2,3"``
    :param platform: platform name, mapped to its numeric id via platform_auto()
    :return: matching FAQ rows from the database (may be empty)

    Fix: the original had a redundant ``if answers: return answers`` followed
    by ``return answers`` — both branches returned the same value.
    """
    question_id_list = question_ids.split(',')
    return DB.sqlManager.get_faq_by_id(question_id_list, None, int(platform_auto(platform)))
#获取一级标题
def get_top_level_questions(platform: str) -> List[Tuple[int, str, str]]:
    """Fetch all top-level questions (parent_id == 0) for *platform*."""
    platform_id = int(platform_auto(platform))
    # parent_id=[0] selects root-level entries
    return DB.sqlManager.get_faq_by_id(None, [0], platform_id)
#获取id下的子问题
def get_subQuestions_by_id(p_id: int, platform: str) -> List[Tuple[int, int, str, str]]:
    """Fetch the child questions under parent id *p_id* for *platform*."""
    platform_id = int(platform_auto(platform))
    return DB.sqlManager.get_faq_by_id(None, [p_id], platform_id)
def upsert_faq_from_json(faq_data):
    # Thin pass-through to the DB layer: insert-or-update FAQ rows from parsed JSON.
    return DB.sqlManager.upsert_faq_from_json(faq_data)
7.封装Python sqlManager.py
- 建表,要对id,parentId, platform建立表索引,提高查询性能
class FAQ(Base):
    """SQLAlchemy model for a single FAQ entry."""
    __tablename__ = 'faqs'

    id = Column(Integer, primary_key=True, index=True)
    parentId = Column(Integer, index=True)  # indexed: parent lookups are frequent
    question = Column(String(255))
    answer = Column(Text)
    keywords = Column(Text)
    isLast = Column(Integer)
    platform = Column(Integer, index=True)  # indexed: every query filters by platform
    type = Column(Integer)

    def to_dict(self):
        """Serialize this row to a plain, JSON-friendly dict."""
        columns = ('id', 'parentId', 'question', 'answer',
                   'keywords', 'isLast', 'platform', 'type')
        return {name: getattr(self, name) for name in columns}
7.1封装的CRUD要加锁,保证数据的原子性
# 增加数据
def add_faq_from_json(json_data):
    """Bulk-insert FAQ rows from a list of dicts (one FAQ(**item) per element).

    Commits on success and rolls back on any SQLAlchemy error.
    NOTE(review): errors are only printed, never re-raised, so callers cannot
    tell that the insert failed — consider logging and propagating instead.
    """
    session = Session()
    try:
        faqs = [FAQ(**item) for item in json_data]
        session.bulk_save_objects(faqs)
        session.commit()
    except SQLAlchemyError as e:
        session.rollback()
        print(f"Error: {e}")
    finally:
        # always release the session, even after a failed commit
        session.close()
部署本机Python后台服务
新建一个run.py, 用来运行后台服务
- 封装查询API:
app = Flask(__name__)
CORS(app)  # allow cross-origin calls from the local test HTML page


@app.route('/v1/faq', methods=['GET'])
def handle_faq_request():
    """GET /v1/faq?query=...&platform=... — answer a user question via the FAQ matcher."""
    user_query = request.args.get('query')
    platform = request.args.get('platform')
    # guard clause: reject requests without a query string
    if not user_query:
        return jsonify({'code': 502, 'data': 'No query parameter provided.'}), 400
    try:
        answer = kami_ntlk.match_question(user_query, platform)
        return jsonify({'code': 200, 'data': answer})
    except Exception as e:
        # boundary handler: surface the failure to the client
        return jsonify({'code': 502, 'data': 'An error occurred: {}'.format(str(e))}), 400
- 使用WSGIServer创建服务
if __name__ == '__main__':
    # Debug: Flask's built-in dev server (kept for local troubleshooting)
    #app.run(host='0.0.0.0', port=8000, debug=False)
    # Production: WSGIServer listening on all interfaces, port 8000.
    # NOTE(review): presumably gevent.pywsgi.WSGIServer — confirm the import.
    http_server = WSGIServer(('', 8000), app)
    http_server.serve_forever()
- 运行
--python run.py
--python -m http.server 8000 --bind 0.0.0.0
开启本机服务