ChatGLM微调后推理的两种方式并生成服务

模型下载

以上代码会由 transformers 自动下载模型实现和参数。完整的模型实现可以在 Hugging Face Hub。如果你的网络环境较差,下载模型参数可能会花费较长时间甚至失败。此时可以先将模型下载到本地,然后从本地加载。

从 Hugging Face Hub 下载模型需要先安装Git LFS,然后运行

git clone https://huggingface.co/THUDM/chatglm-6b

如果你从 Hugging Face Hub 上下载 checkpoint 的速度较慢,可以只下载模型实现

GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/THUDM/chatglm-6b

然后从这里手动下载模型参数文件,并将下载的文件替换到本地的 chatglm-6b 目录下。

将模型下载到本地之后,将以上代码中的 THUDM/chatglm-6b 替换为你本地的 chatglm-6b 文件夹的路径,即可从本地加载模型。

lora微调后推理

import torch
from transformers import AutoTokenizer, AutoModel

from peft import get_peft_model, LoraConfig, TaskType
import loralib as lora
import lora_utils.insert_lora

model_path = "/root/autodl-tmp/chatglm-6b"
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
model = AutoModel.from_pretrained(model_path, trust_remote_code=True).half().cuda()
torch.set_default_tensor_type(torch.cuda.HalfTensor)

# 加载基于belle  110万数据微调的lora权重
peft_path = "output/belle/chatglm-lora.pt"
lora_config = {
    'r': 8,
    'lora_alpha': 8,
    'lora_dropout': 0.1,
    'enable_lora': [True, True, True],
}
model = lora_utils.insert_lora.get_lora_model(model, lora_config)
model.load_state_dict(torch.load(peft_path), strict=False)
torch.set_default_tensor_type(torch.cuda.FloatTensor)

response, history = model.chat(tokenizer, "生成一个人野营旅行可能需要的十件物品的清单", history=[])
print(response)

ptuning微调推理

import torch
import os
from transformers import AutoTokenizer, AutoModel,AutoConfig

model_name_or_path = "/root/autodl-tmp/chatglm-6b"
config = AutoConfig.from_pretrained(model_name_or_path, trust_remote_code=True)
config.pre_seq_len = 128

tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, trust_remote_code=True)
model = AutoModel.from_pretrained(model_name_or_path, config=config, trust_remote_code=True).half().cuda()

prefix_state_dict = torch.load(os.path.join("/root/autodl-tmp/huopi-checkpoint-1000/", "pytorch_model.bin"))
new_prefix_state_dict = {}
for k, v in prefix_state_dict.items():
    if k.startswith("transformer.prefix_encoder."):
        new_prefix_state_dict[k[len("transformer.prefix_encoder."):]] = v
model.transformer.prefix_encoder.load_state_dict(new_prefix_state_dict)

model = model.eval()
schema = ['产品名称','技术','适应症','审批机构','类型','产品描述','注册证号']
news_topic = '''
2023年5月6日,生物技术公司Chiesi Global Rare Diseases与Protalix BioTherapeutics联合宣布,欧盟委员会(EC)已授予PRX-102(pegunigalsidase alfa)在欧盟治疗法布里病成人患者的上市许可。PRX-102是一种用于治疗法布里病的聚乙二醇化酶替代疗法,每两周注射一次,现已获欧盟批准,并正由美国FDA评估。
'''
str1 = news_topic + "\n\n提取上述句子中{}的实体,上述句子中不存在的信息不用强行生成,注意类型中的罗马数字(I II III),并按照JSON格式输出,多个值用数组表示。".format(schema)
response, history = model.chat(tokenizer, str1, history=[])
print(response)

通过fastapi生成一个服务

import uvicorn
from fastapi import FastAPI
from pydantic import BaseModel

import torch
import os
from transformers import AutoTokenizer, AutoModel,AutoConfig

#os.environ["CUDA_VISIBLE_DEVICES"] = "0"
model_name_or_path = "/root/autodl-tmp/chatglm-6b"
config = AutoConfig.from_pretrained(model_name_or_path, trust_remote_code=True)
config.pre_seq_len = 128

tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, trust_remote_code=True, revision="main")
model = AutoModel.from_pretrained(model_name_or_path, config=config, trust_remote_code=True,revision="main").quantize(8).half().cuda()
#model = AutoModel.from_pretrained(model_name_or_path, trust_remote_code=True, revision="main").quantize(4).half().cuda()
#
prefix_state_dict = torch.load(os.path.join("/root/autodl-tmp/huopi-checkpoint-1000/", "pytorch_model.bin"))
new_prefix_state_dict = {}
for k, v in prefix_state_dict.items():
    if k.startswith("transformer.prefix_encoder."):
        new_prefix_state_dict[k[len("transformer.prefix_encoder."):]] = v
model.transformer.prefix_encoder.load_state_dict(new_prefix_state_dict)
#model = model.cuda()
model = model.eval()

class NewsTopic(BaseModel):
    name: str = ""
    content: str = ""

app = FastAPI()
@app.post("/news_topic_ie/")
async def get_entity(topic: NewsTopic):
    schema = ['产品名称','技术','适应症','审批机构','类型','产品描述','注册证号']
    news_topic = topic.content
    str1 = news_topic + "\n\n提取上述句子中{}的实体,上述句子中不存在的信息不用强行生成,注意类型中的罗马数字(I II III),并按照JSON格式输出,多个值用数组表示。".format(schema)
    response, history = model.chat(tokenizer, str1, history=[]) 
    return {"msg": "success", "code": 200, "response": response}

if __name__ == "__main__":
    uvicorn.run(app, host='0.0.0.0', port=8100)

客户端测试

curl -H "Content-Type:application/json" -d '{"content":"2023年4月25日,三诺生物发布公告,公司于近日收到湖南省药品监督管理局颁发的一项《医疗器械注册证》,产品名称为血压血糖尿酸测试仪(注册证编号:湘械注准20232070361)。适用范围用于测量人体的收缩压、舒张压及脉率(12周岁以上的人),其数值供诊断参考;与配套血糖测试条或者尿酸测试条配合使用,分别用于测试毛细血管全血或静脉全血的葡萄糖、尿酸浓度。"}' http://127.0.0.1:8100/news_topic_ie/

参考:
https://github.com/THUDM/ChatGLM-6B
https://github.com/yanqiangmiffy/InstructGLM

©著作权归作者所有,转载或内容合作请联系作者
平台声明:文章内容(如有图片或视频亦包括在内)由作者上传并发布,文章内容仅代表作者本人观点,简书系信息发布平台,仅提供信息存储服务。

推荐阅读更多精彩内容