pybind11尝试编写

include "chat.h"

include <pybind11/pybind11.h>

include <pybind11/stl.h>

include "models.cpp"

namespace chatllm {

namespace py = pybind11;
using namespace pybind11::literals;

// class PyBaseTokenizer : public BaseTokenizer {
// public:
// using BaseTokenizer::BaseTokenizer;

// std::vector<int> encode(const std::string &text, int max_length) const override {
// PYBIND11_OVERRIDE_PURE(std::vector<int>, BaseTokenizer, encode, text, max_length);
// }
// std::string decode(const std::vector<int> &ids) const override {
// PYBIND11_OVERLOAD_PURE(std::string, BaseTokenizer, decode, ids);
// }
// std::vector<int> encode_messages(const std::vector<ChatMessage> &history, int max_length) const override {
// PYBIND11_OVERLOAD_PURE(std::vector<int>, BaseTokenizer, encode_messages, history, max_length);
// }
// };

// class PyBaseModelForCausalLM : public BaseModelForCausalLM {
// public:
// using BaseModelForCausalLM::BaseModelForCausalLM;

// void load(ModelLoader &loader) override { PYBIND11_OVERLOAD_PURE(void, PyBaseModelForCausalLM, load, loader); }

// ggml_tensor *forward(ModelContext *ctx, ggml_tensor *input_ids, int n_past, int n_ctx,
// bool is_decoding) const override {
// PYBIND11_OVERLOAD_PURE(ggml_tensor *, PyBaseModelForCausalLM, forward, ctx, input_ids, n_past, n_ctx,
// is_decoding)
// }
// };

template <typename T>
static inline std::string to_string(const T &obj) {
std::ostringstream oss;
oss << obj;
return oss.str();
}

PYBIND11_MODULE(_C, m) {
m.doc() = "ChatLLM.cpp python binding";

py::enum_<ModelType>(m, "ModelType")
    .value("MINICPM", ModelType::MODEL_TYPE_MINICPM);

py::class_<minicpm::Config>(m, "MiniCPMConfig")
    // .def_readonly("dtype", &BaseConfig::dtype)
    .def_readonly("vocab_size", &minicpm::Config::vocab_size)
    .def_readonly("hidden_size", &minicpm::Config::hidden_size)
    .def_readonly("num_attention_heads", &minicpm::Config::num_attention_heads)
    .def_readonly("num_hidden_layers", &minicpm::Config::num_hidden_layers)
    .def_readonly("intermediate_size", &minicpm::Config::intermediate_size)
    .def_readonly("max_length", &minicpm::Config::max_length)
    .def_readonly("bos_token_id", &minicpm::Config::bos_token_id)
    .def_readonly("eos_token_id", &minicpm::Config::eos_token_id)
    .def_readonly("pad_token_id", &minicpm::Config::pad_token_id)
    .def_readonly("sep_token_id", &minicpm::Config::sep_token_id)
    .def_readonly("num_key_value_heads", &minicpm::Config::num_key_value_heads)
    .def_readonly("rope_scaling", &minicpm::Config::rope_scaling)
    .def_readonly("rope_theta", &minicpm::Config::rope_theta)
    .def_readonly("scale_depth", &minicpm::Config::scale_depth);

py::class_<GenerationConfig>(m, "GenerationConfig")
    .def(py::init<int, int, bool, int, float, float, int>(), "max_length"_a = 2048,
        "max_context_length"_a = 512, "do_sample"_a = true, "top_k"_a = 0,
        "top_p"_a = 0.7, "temperature"_a = 0.95, "num_threads"_a = 0)
    .def_readwrite("max_length", &GenerationConfig::max_length)
    .def_readwrite("max_context_length", &GenerationConfig::max_context_length)
    .def_readwrite("do_sample", &GenerationConfig::do_sample)
    .def_readwrite("top_k", &GenerationConfig::top_k)
    .def_readwrite("top_p", &GenerationConfig::top_p)
    .def_readwrite("temperature", &GenerationConfig::temperature)
    .def_readwrite("num_threads", &GenerationConfig::num_threads);

// py::class_<ChatMessage>(m, "ChatMessage")
//     .def(py::init<std::string, std::string, std::vector<ToolCallMessage>>(), "role"_a, "content"_a,
//          "tool_calls"_a = std::vector<ToolCallMessage>{})
//     .def("__repr__", &to_string<ChatMessage>)
//     .def("__str__", &to_string<ChatMessage>)
//     .def_readonly_static("ROLE_SYSTEM", &ChatMessage::ROLE_SYSTEM)
//     .def_readonly_static("ROLE_USER", &ChatMessage::ROLE_USER)
//     .def_readonly_static("ROLE_ASSISTANT", &ChatMessage::ROLE_ASSISTANT)
//     .def_readonly_static("ROLE_OBSERVATION", &ChatMessage::ROLE_OBSERVATION)
//     .def_readwrite("role", &ChatMessage::role)
//     .def_readwrite("content", &ChatMessage::content)
//     .def_readwrite("tool_calls", &ChatMessage::tool_calls);

// py::class_<minicpm::Tokenizer>(m, "Tokenizer")
//     .def("encode", &minicpm::Tokenizer::encode, py::arg("text"))
//     .def("decode", &minicpm::Tokenizer::decode, "ids"_a);

// py::class_<chatllm::BaseHistoryEncoder>(m, "BaseHistoryEncoder");
// py::class_<chatllm::BaseTokenizer>(m, "BaseTokenizer")
//     .def("load", [](chatllm::BaseTokenizer& tokenizer, const char *buffer, int n_vocab){

//     });
// py::class_<chatllm::BaseStreamer>(m, "BaseStreamer");
// py::class_<chatllm::TextStreamer>(m, "TextStreamer");
    // .def(py::init<chatllm::BaseTokenizer>(), "tokenizer"_a); // 有bug

py::class_<chatllm::BaseTokenizer, minicpm::Tokenizer>(m, "MiniCPMTokenizer")
    .def("encode", [](minicpm::Tokenizer& tokenizer, const std::string& text){
        return tokenizer.encode(text);
    })
    .def("decode", [](minicpm::Tokenizer& tokenizer, const std::vector<int> &ids){
        return tokenizer.decode(ids);
    });
    // .def("load", [](minicpm::Tokenizer& tokenizer, const char *buffer, int n_vocab){
    //     return tokenizer.load(buffer, n_vocab);
    // });

// py::class_<minicpm::ConditionalGeneration>(m, "MiniCPMModel")
//     .def("generate_next_token", &minicpm::ConditionalGeneration::generate_next_token, 
//     "input_ids"_a, "gen_config"_a);

py::class_<minicpm::ConditionalGeneration>(m, "MiniCPMModel")
    .def("generate_next_token", [](minicpm::ConditionalGeneration& generation, const std::vector<int> &input_ids, const GenerationConfig &gen_config) {
        int gen_token = -1;
        if (generation.get_n_past() == 0) {
            gen_token = generation.generate_next_token(input_ids, gen_config);
            generation.set_n_past(generation.get_n_past() + input_ids.size());
        } else {
            int lastElement = input_ids.back();
            const std::vector<int> &lastElementVec = {lastElement};
            gen_token = generation.generate_next_token(lastElementVec, gen_config);
            generation.set_n_past(generation.get_n_past() + 1);
        }
        return gen_token;
    })
    .def("reset_n_past", [](minicpm::ConditionalGeneration& generation){
        generation.set_n_past(0);
    })
    .def_readonly("config", &minicpm::ConditionalGeneration::config);
    // .def("generate", [](minicpm::ConditionalGeneration& generation, const std::vector<int> &input_ids, const GenerationConfig &gen_config,
    //                           const bool continuous,
    //                           bool &completed){
        
    // });

// ===== ChatGLM3 =====

// py::class_<ChatGLM3Tokenizer, BaseTokenizer>(m, "ChatGLM3Tokenizer");

// ===== Pipeline ====

py::class_<Pipeline>(m, "Pipeline")
    .def(py::init<const std::string &>(), "path"_a)
    .def_property_readonly("model", [](const Pipeline &self) { return self.model; })
    .def_property_readonly("tokenizer", [](const Pipeline &self) { return self.tokenizer; })
    .def("chat", [](Pipeline& pipeline, std::vector<std::string> &history, const GenerationConfig &gen_config){
        return pipeline.chat(history, gen_config);
    });

}

} // namespace chatglm

from pathlib import Path
import chatllm_cpp._C as _C

class Pipeline(_C.Pipeline):
def init(self, model_path: str) -> None:
if Path(model_path).is_file():
# load ggml model
super().init(str(model_path))
else:
raise RuntimeError("参数错误")

def chat(
    self,
    message: str,
    *,
    max_length: int = 2048,
    max_context_length: int = 512,
    do_sample: bool = True,
    top_k: int = 0,
    top_p: float = 0.7,
    temperature: float = 0.95,
    num_threads: int = 0,
    # stream: bool = False,
):
    input_ids = self.tokenizer.encode(message)
    
    gen_config = _C.GenerationConfig(
        max_length=max_length,
        max_new_tokens=max_new_tokens,
        max_context_length=max_context_length,
        do_sample=do_sample,
        top_k=top_k,
        top_p=top_p,
        temperature=temperature,
        repetition_penalty=repetition_penalty,
        num_threads=num_threads,
    )
    _C.
    if stream:
        return self._stream_chat(input_ids=input_ids, gen_config=gen_config)
    return self._sync_chat(input_ids=input_ids, gen_config=gen_config)

import _C
pipeline = _C.Pipeline(r"C:\Users\KyoDa\Downloads\chatllm.cpp\quantized_16.bin")
question = "Hello."
ids = pipeline.tokenizer.encode(f" <用户>{question}<AI>")
config = _C.GenerationConfig()
new_token = 0
pipeline.model.reset_n_past()
print(pipeline.model.config.eos_token_id, "<< id")
while new_token != pipeline.model.config.eos_token_id:
new_token = pipeline.model.generate_next_token(ids, config)
ids.append(new_token);
print(new_token, end=',', flush=True)

print(pipeline.tokenizer.decode(ids))

pipeline.chat(["Hello."], config)

最后编辑于
©著作权归作者所有,转载或内容合作请联系作者
  • 序言:七十年代末,一起剥皮案震惊了整个滨河市,随后出现的几起案子,更是在滨河造成了极大的恐慌,老刑警刘岩,带你破解...
    沈念sama阅读 217,084评论 6 503
  • 序言:滨河连续发生了三起死亡事件,死亡现场离奇诡异,居然都是意外死亡,警方通过查阅死者的电脑和手机,发现死者居然都...
    沈念sama阅读 92,623评论 3 392
  • 文/潘晓璐 我一进店门,熙熙楼的掌柜王于贵愁眉苦脸地迎上来,“玉大人,你说我怎么就摊上这事。” “怎么了?”我有些...
    开封第一讲书人阅读 163,450评论 0 353
  • 文/不坏的土叔 我叫张陵,是天一观的道长。 经常有香客问我,道长,这世上最难降的妖魔是什么? 我笑而不...
    开封第一讲书人阅读 58,322评论 1 293
  • 正文 为了忘掉前任,我火速办了婚礼,结果婚礼上,老公的妹妹穿的比我还像新娘。我一直安慰自己,他们只是感情好,可当我...
    茶点故事阅读 67,370评论 6 390
  • 文/花漫 我一把揭开白布。 她就那样静静地躺着,像睡着了一般。 火红的嫁衣衬着肌肤如雪。 梳的纹丝不乱的头发上,一...
    开封第一讲书人阅读 51,274评论 1 300
  • 那天,我揣着相机与录音,去河边找鬼。 笑死,一个胖子当着我的面吹牛,可吹牛的内容都是我干的。 我是一名探鬼主播,决...
    沈念sama阅读 40,126评论 3 418
  • 文/苍兰香墨 我猛地睁开眼,长吁一口气:“原来是场噩梦啊……” “哼!你这毒妇竟也来了?” 一声冷哼从身侧响起,我...
    开封第一讲书人阅读 38,980评论 0 275
  • 序言:老挝万荣一对情侣失踪,失踪者是张志新(化名)和其女友刘颖,没想到半个月后,有当地人在树林里发现了一具尸体,经...
    沈念sama阅读 45,414评论 1 313
  • 正文 独居荒郊野岭守林人离奇死亡,尸身上长有42处带血的脓包…… 初始之章·张勋 以下内容为张勋视角 年9月15日...
    茶点故事阅读 37,599评论 3 334
  • 正文 我和宋清朗相恋三年,在试婚纱的时候发现自己被绿了。 大学时的朋友给我发了我未婚夫和他白月光在一起吃饭的照片。...
    茶点故事阅读 39,773评论 1 348
  • 序言:一个原本活蹦乱跳的男人离奇死亡,死状恐怖,灵堂内的尸体忽然破棺而出,到底是诈尸还是另有隐情,我是刑警宁泽,带...
    沈念sama阅读 35,470评论 5 344
  • 正文 年R本政府宣布,位于F岛的核电站,受9级特大地震影响,放射性物质发生泄漏。R本人自食恶果不足惜,却给世界环境...
    茶点故事阅读 41,080评论 3 327
  • 文/蒙蒙 一、第九天 我趴在偏房一处隐蔽的房顶上张望。 院中可真热闹,春花似锦、人声如沸。这庄子的主人今日做“春日...
    开封第一讲书人阅读 31,713评论 0 22
  • 文/苍兰香墨 我抬头看了看天上的太阳。三九已至,却和暖如春,着一层夹袄步出监牢的瞬间,已是汗流浃背。 一阵脚步声响...
    开封第一讲书人阅读 32,852评论 1 269
  • 我被黑心中介骗来泰国打工, 没想到刚下飞机就差点儿被人妖公主榨干…… 1. 我叫王不留,地道东北人。 一个月前我还...
    沈念sama阅读 47,865评论 2 370
  • 正文 我出身青楼,却偏偏与公主长得像,于是被迫代替她去往敌国和亲。 传闻我的和亲对象是个残疾皇子,可洞房花烛夜当晚...
    茶点故事阅读 44,689评论 2 354

推荐阅读更多精彩内容

  • 记录源码编译Tensorflow的曲折弯路 前言 通过tensorflow训练深度学习神经网络模型一般是pytho...
    V_爱一世春秋阅读 3,602评论 1 0
  • 前言 在之前的pybind11系列实践中,开发流程大致是这样的: 第一步: 首先在C/C++ IDE中编写C/C+...
    侠之大者_7d3f阅读 14,155评论 4 4
  • CPP 1、在main执行之前和之后执行的代码可能是什么? main函数执行之前,主要就是初始化系统相关资源: 设...
    voidFan阅读 1,695评论 1 6
  • 废话不多说,自己进入今天的主题 1、面向对象的特征有哪些方面? 答:面向对象的特征主要有以下几个方面: - 抽象:...
    传奇内服号阅读 2,349评论 1 31
  • 程序员面试宝典 一、C++ 基础 1. 位运算 返回x二进制数中的1的个数? 返回x,y的平均值? 返回绝对值?...
    小任同学an阅读 1,194评论 0 0