基于 Qwen3 + LoRA + SFT 的医疗问答系统

今天将分享基于 Qwen3 + LoRA + SFT 的医疗问答系统完整实现版本,为了方便大家学习理解整个流程,将整个流程步骤进行了整理,并给出详细的步骤结果。感兴趣的朋友赶紧动手试一试吧。 一、背景介绍

随着人工智能和自然语言处理(Natural Language Processing, NLP)的快速发展,大规模预训练语言模型(Large Language Models, LLMs)已在通用问答、文本生成与知识推理等任务中表现出卓越的能力。近年来,诸如 GPT 系列LLaMA 系列 和 Qwen 系列 等模型通过在海量语料上进行预训练,显著提升了模型的语言理解与生成能力。然而,这些通用语言模型通常基于开放域数据进行训练,缺乏医学专业知识,对医疗场景中涉及的疾病、药物、检查、治疗及诊断过程等专业问题理解不足。此外,医学语言具有高度专业化、语义精确性强、风险容忍度低等特点,使得直接应用通用大模型于医疗问答任务可能导致知识错误或回答不当。为此,学界与工业界提出了通过参数高效微调(Parameter-Efficient Fine-Tuning, PEFT)有监督微调(Supervised Fine-Tuning, SFT)相结合的方式,在医学领域数据上对通用大模型进行领域特化训练,从而在保留模型语言理解能力的同时,提高其医学知识掌握与临床表达准确性,为未来智能医疗辅助系统(如临床问答助手、病历摘要生成、影像诊断解释等)提供了重要的技术基础。

二、核心原理

1、Qwen3(由阿里巴巴推出的多语言通用大模型)基于 Transformer 解码器结构,具备以下特点:大规模多语料预训练:覆盖中英文及多语言语料;多任务能力迁移:在指令理解、逻辑推理、知识问答等任务中表现优异;高扩展性:模型参数规模涵盖 1.8B 到 72B,可适配不同计算资源;开放指令调优(Instruct-tuning)机制:为后续领域 SFT 提供了良好的基础。Qwen3 模型在预训练阶段学习到强大的语言建模与语义推理能力,但缺乏医学知识的显式表达,因此需要通过 领域微调 来引导模型学习医学语义特征与回答风格。

2、LoRA 是一种轻量级参数高效微调方法。其核心思想是:在保持原始模型参数冻结不变的情况下,仅在特定权重矩阵(如 Query、Value 投影层)中插入低秩适配矩阵,通过训练这些小规模矩阵实现知识注入。这种方法的优势是:显著降低显存占用与训练成本;避免灾难性遗忘,保留原模型语言能力;可插拔部署,LoRA 权重可独立加载与卸载。在医学场景下,这意味着我们只需微调极少参数即可令模型具备专业医学问答能力。

3、SFT 是指在高质量人工标注数据(如医生问答对、病历摘要等)上进行有监督训练,使模型学习输入–输出对应关系,从而形成符合医学逻辑的问答能力。在医疗问答系统中,SFT 的目标是:使模型理解病人提问语义;生成符合医学事实、符合临床逻辑的医生式回答;具备解释性与安全性 三、Qwen3 + LoRA + SFT 微调流程

1、基础模型加载。加载 Qwen3-7B 或 Qwen3-1.7B 模型,并冻结原始权重。

model_name_or_path: "D:/3dlib/LLM/Qwen3-1.7B"train_data_path: "D:/cjq/project/python/qwen3_Lora_sft_project/data/tokenized/train"eval_data_path: "D:/cjq/project/python/qwen3_Lora_sft_project/data/tokenized/eval"output_dir: "log/medicalQA_sft"# === 数据处理 ===max_seq_length: 2048# === 优化器与学习率 ===learning_rate: 2e-5lr_scheduler_type: "cosine"warmup_ratio: 0.03weight_decay: 0.01# === 训练参数 ===num_train_epochs: 3per_device_train_batch_size: 1per_device_eval_batch_size: 1gradient_accumulation_steps: 16gradient_checkpointing: truefp16: true# === 日志与保存 ===logging_steps: 10save_steps: 500eval_steps: 500save_strategy: "epoch"evaluation_strategy: "epoch"report_to: [ "tensorboard" ]logging_dir: "log/tensorboard"save_total_limit: 2load_best_model_at_end: truemetric_for_best_model: "eval_loss"greater_is_better: false
def __init__(self, sftcfg, loracfg):        # === Load Configs ===        with open(sftcfg, "r", encoding="utf-8"as f:            self.sft_cfg = yaml.safe_load(f)        with open(loracfg, "r", encoding="utf-8"as f:            self.lora_cfg = yaml.safe_load(f)        # === Load Model + LoRA ===        self.model = AutoModelForCausalLM.from_pretrained(self.sft_cfg["model_name_or_path"], device_map="auto")        self.tokenizer = AutoTokenizer.from_pretrained(self.sft_cfg["model_name_or_path"], trust_remote_code=True)        self.peft_config = LoraConfig(**self.lora_cfg)        self.model = get_peft_model(self.model, self.peft_config)

2、 LoRA 注入。在注意力层的 Query/Value 权重中插入低秩矩阵。

r8lora_alpha16lora_dropout0.05target_modules: ["q_proj", "v_proj", "k_proj", "o_proj"]bias"none"task_type"CAUSAL_LM"

3、数据构建。将医疗问答数据(患者提问、医生回答)格式化为监督样本。

import jsonfrom datasets import Datasetdef load_medical_data(path):    with open(path, 'r', encoding='utf-8'as f:        data = [json.loads(line) for line in f]    return datadef processjsontotokenized_dataset(train_path, eval_path, output_path):    train_data = load_medical_data(train_path)    eval_data = load_medical_data(eval_path)    train_dataset = Dataset.from_list(train_data)    valid_dataset = Dataset.from_list(eval_data)    train_dataset.save_to_disk(f"{output_path}/train")    valid_dataset.save_to_disk(f"{output_path}/eval")if __name__ == "__main__":    processjsontotokenized_dataset(r"D:\cjq\project\python\qwen3_Lora_sft_project\data/medical_train.jsonl",                                   r"D:\cjq\project\python\qwen3_Lora_sft_project\data/medical_eval.jsonl",                                   r"D:\cjq\project\python\qwen3_Lora_sft_project\data\tokenized")

4、SFT 训练。通过交叉熵损失训练 LoRA 层,使模型学习医学问答映射关系。

def Update(self):        # === Load Dataset ===        train_dataset = load_tokenized_dataset(self.sft_cfg["model_name_or_path"], self.sft_cfg["train_data_path"])        eval_dataset = load_tokenized_dataset(self.sft_cfg["model_name_or_path"], self.sft_cfg["eval_data_path"])        data_collator = DataCollatorForLanguageModeling(tokenizer=self.tokenizer, mlm=False)        # === Training Args ===        args = TrainingArguments(            output_dir=self.sft_cfg["output_dir"],            per_device_train_batch_size=int(self.sft_cfg["per_device_train_batch_size"]),            num_train_epochs=int(self.sft_cfg["num_train_epochs"]),            learning_rate=float(self.sft_cfg["learning_rate"]),            logging_steps=int(self.sft_cfg["logging_steps"]),            save_strategy=self.sft_cfg["save_strategy"],            report_to="none",        )        trainer = Trainer(            model=self.model,            args=args,            train_dataset=train_dataset,            eval_dataset=eval_dataset,            tokenizer=self.tokenizer,            data_collator=data_collator,        )        trainer.train()        self.model.save_pretrained(self.sft_cfg["output_dir"])

5、评估与推理。在未见过的医学问答集上测试生成结果,评价其专业性与安全性。

class LORASFTinferenceModel(object):    def __init__(self, model_dir):        # 1️⃣ 加载 tokenizer        self.tokenizer = AutoTokenizer.from_pretrained(model_dir, trust_remote_code=True)        # 2️⃣ 加载模型(自动选择 GPU/CPU,使用半精度加速)        self.model = AutoModelForCausalLM.from_pretrained(model_dir,                                                          torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,                                                          device_map="auto", trust_remote_code=True)        self.model.eval()    def infernece(self, prompt):        # inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)#会把输入也解码出来        # 3️⃣ 构造输入(Qwen3 推荐使用 apply_chat_template)        # 如果是多轮对话,可传入 messages=[{"role":"user","content":prompt}]        inputs = self.tokenizer.apply_chat_template(            [{"role""user""content": prompt}],            tokenize=True,            add_generation_prompt=True,            return_tensors="pt").to(self.model.device)        attention_mask = (inputs != self.tokenizer.pad_token_id).to(self.model.device)        # 4️⃣ 生成输出(加上 temperature、top_p 控制生成多样性)        outputs = self.model.generate(inputs, attention_mask=attention_mask, max_new_tokens=1024,                                      temperature=0.7, top_p=0.9, do_sample=True,                                      pad_token_id=self.tokenizer.eos_token_id)        # 5️⃣ 解码输出(取生成的新内容部分)        response = self.tokenizer.decode(outputs[0][inputs.shape[-1]:], skip_special_tokens=True)        print("🩺 模型回复:", response)        return response
输入:"我最近连续两周晚上咳嗽,尤其是干咳,没有发热,这是支气管炎吗?需要服用抗生素吗?"🩺 模型回复: <think>考虑是由于上呼吸道感染导致的这种情况建议你注意休息加强营养多喝白开水避免熬夜和劳累。可以用头孢拉定片或者阿奇霉素分散片治疗观察看看吧。如果效果不好的话再进一步检查一下确诊后再对症用药才好啊。另外就是要注意保暖不要着凉了。还有就是平时可以经常用热水泡脚促进血液循环改善睡眠质量增加抵抗力。饮食方面要多吃新鲜蔬菜水果补充维生素C等有利于提高免疫力。希望我的回答对你有帮助,

完整代码已经上传到github上:

https://github.com/junqiangchen/qwen3_Lora_sft_medicalQA点击阅读原文可以访问该项目,如果大家觉得这个项目还不错,希望大家给个Star并Fork,可以让更多的人学习。如果有任何问题,随时给我留言我会及时回复的。

本文使用 文章同步助手 同步

©著作权归作者所有,转载或内容合作请联系作者
【社区内容提示】社区部分内容疑似由AI辅助生成,浏览时请结合常识与多方信息审慎甄别。
平台声明:文章内容(如有图片或视频亦包括在内)由作者上传并发布,文章内容仅代表作者本人观点,简书系信息发布平台,仅提供信息存储服务。

相关阅读更多精彩内容

友情链接更多精彩内容