include "chat.h"
include <pybind11/pybind11.h>
include <pybind11/stl.h>
include "models.cpp"
namespace chatllm {
namespace py = pybind11;
using namespace pybind11::literals;
// class PyBaseTokenizer : public BaseTokenizer {
// public:
// using BaseTokenizer::BaseTokenizer;
// std::vector<int> encode(const std::string &text, int max_length) const override {
// PYBIND11_OVERRIDE_PURE(std::vector<int>, BaseTokenizer, encode, text, max_length);
// }
// std::string decode(const std::vector<int> &ids) const override {
// PYBIND11_OVERLOAD_PURE(std::string, BaseTokenizer, decode, ids);
// }
// std::vector<int> encode_messages(const std::vector<ChatMessage> &history, int max_length) const override {
// PYBIND11_OVERLOAD_PURE(std::vector<int>, BaseTokenizer, encode_messages, history, max_length);
// }
// };
// class PyBaseModelForCausalLM : public BaseModelForCausalLM {
// public:
// using BaseModelForCausalLM::BaseModelForCausalLM;
// void load(ModelLoader &loader) override { PYBIND11_OVERLOAD_PURE(void, PyBaseModelForCausalLM, load, loader); }
// ggml_tensor *forward(ModelContext *ctx, ggml_tensor *input_ids, int n_past, int n_ctx,
// bool is_decoding) const override {
// PYBIND11_OVERLOAD_PURE(ggml_tensor *, PyBaseModelForCausalLM, forward, ctx, input_ids, n_past, n_ctx,
// is_decoding)
// }
// };
/// Render any value that supports `operator<<` into a std::string.
/// Handy for backing Python `__repr__` / `__str__` bindings.
template <typename T>
static inline std::string to_string(const T &obj) {
    std::ostringstream sink;
    sink << obj;
    std::string text = sink.str();
    return text;
}
/// Python extension module `_C`.
///
/// Registers ModelType, MiniCPMConfig, GenerationConfig, BaseTokenizer,
/// MiniCPMTokenizer, MiniCPMModel and Pipeline; these registered names are the
/// Python-facing API.
PYBIND11_MODULE(_C, m) {
    m.doc() = "ChatLLM.cpp python binding";

    py::enum_<ModelType>(m, "ModelType")
        .value("MINICPM", ModelType::MODEL_TYPE_MINICPM);

    // Read-only view of the fields of minicpm::Config.
    py::class_<minicpm::Config>(m, "MiniCPMConfig")
        // .def_readonly("dtype", &BaseConfig::dtype)
        .def_readonly("vocab_size", &minicpm::Config::vocab_size)
        .def_readonly("hidden_size", &minicpm::Config::hidden_size)
        .def_readonly("num_attention_heads", &minicpm::Config::num_attention_heads)
        .def_readonly("num_hidden_layers", &minicpm::Config::num_hidden_layers)
        .def_readonly("intermediate_size", &minicpm::Config::intermediate_size)
        .def_readonly("max_length", &minicpm::Config::max_length)
        .def_readonly("bos_token_id", &minicpm::Config::bos_token_id)
        .def_readonly("eos_token_id", &minicpm::Config::eos_token_id)
        .def_readonly("pad_token_id", &minicpm::Config::pad_token_id)
        .def_readonly("sep_token_id", &minicpm::Config::sep_token_id)
        .def_readonly("num_key_value_heads", &minicpm::Config::num_key_value_heads)
        .def_readonly("rope_scaling", &minicpm::Config::rope_scaling)
        .def_readonly("rope_theta", &minicpm::Config::rope_theta)
        .def_readonly("scale_depth", &minicpm::Config::scale_depth);

    // NOTE(review): the py::init<...> template arguments are assumed to match the
    // parameter order of GenerationConfig's constructor (declared outside this file).
    py::class_<GenerationConfig>(m, "GenerationConfig")
        .def(py::init<int, int, bool, int, float, float, int>(), "max_length"_a = 2048,
             "max_context_length"_a = 512, "do_sample"_a = true, "top_k"_a = 0,
             "top_p"_a = 0.7, "temperature"_a = 0.95, "num_threads"_a = 0)
        .def_readwrite("max_length", &GenerationConfig::max_length)
        .def_readwrite("max_context_length", &GenerationConfig::max_context_length)
        .def_readwrite("do_sample", &GenerationConfig::do_sample)
        .def_readwrite("top_k", &GenerationConfig::top_k)
        .def_readwrite("top_p", &GenerationConfig::top_p)
        .def_readwrite("temperature", &GenerationConfig::temperature)
        .def_readwrite("num_threads", &GenerationConfig::num_threads);

    // Disabled bindings, kept for reference.
    // py::class_<ChatMessage>(m, "ChatMessage")
    //     .def(py::init<std::string, std::string, std::vector<ToolCallMessage>>(), "role"_a, "content"_a,
    //          "tool_calls"_a = std::vector<ToolCallMessage>{})
    //     .def("__repr__", &to_string<ChatMessage>)
    //     .def("__str__", &to_string<ChatMessage>)
    //     .def_readonly_static("ROLE_SYSTEM", &ChatMessage::ROLE_SYSTEM)
    //     .def_readonly_static("ROLE_USER", &ChatMessage::ROLE_USER)
    //     .def_readonly_static("ROLE_ASSISTANT", &ChatMessage::ROLE_ASSISTANT)
    //     .def_readonly_static("ROLE_OBSERVATION", &ChatMessage::ROLE_OBSERVATION)
    //     .def_readwrite("role", &ChatMessage::role)
    //     .def_readwrite("content", &ChatMessage::content)
    //     .def_readwrite("tool_calls", &ChatMessage::tool_calls);
    // py::class_<chatllm::BaseHistoryEncoder>(m, "BaseHistoryEncoder");
    // py::class_<chatllm::BaseStreamer>(m, "BaseStreamer");
    // py::class_<chatllm::TextStreamer>(m, "TextStreamer");
    //     .def(py::init<chatllm::BaseTokenizer>(), "tokenizer"_a); // broken

    // The base must be registered before any class that lists it as a base; otherwise
    // importing the module fails with "referenced unknown base type". It exposes no
    // methods of its own.
    // NOTE(review): assumes BaseTokenizer is polymorphic, so that a BaseTokenizer*
    // returned from C++ is presented to Python as its most-derived registered type.
    py::class_<chatllm::BaseTokenizer>(m, "BaseTokenizer");

    // Template order is <Derived, Base>. The previous <BaseTokenizer, minicpm::Tokenizer>
    // made pybind11 treat minicpm::Tokenizer as a trampoline alias of BaseTokenizer,
    // so these methods would have reinterpreted *any* BaseTokenizer as a
    // minicpm::Tokenizer.
    py::class_<minicpm::Tokenizer, chatllm::BaseTokenizer>(m, "MiniCPMTokenizer")
        .def(
            "encode",
            [](minicpm::Tokenizer &tokenizer, const std::string &text) {
                return tokenizer.encode(text);
            },
            "text"_a, "Tokenize `text` and return its token ids.")
        .def(
            "decode",
            [](minicpm::Tokenizer &tokenizer, const std::vector<int> &ids) {
                return tokenizer.decode(ids);
            },
            "ids"_a, "Convert a list of token ids back to text.");
    // .def("load", [](minicpm::Tokenizer& tokenizer, const char *buffer, int n_vocab){
    //     return tokenizer.load(buffer, n_vocab);
    // });

    py::class_<minicpm::ConditionalGeneration>(m, "MiniCPMModel")
        .def(
            "generate_next_token",
            [](minicpm::ConditionalGeneration &generation, const std::vector<int> &input_ids,
               const GenerationConfig &gen_config) {
                // Without this guard, input_ids.back() below is undefined behaviour
                // on an empty vector.
                if (input_ids.empty()) {
                    throw py::value_error("input_ids must not be empty");
                }
                int gen_token = -1;
                if (generation.get_n_past() == 0) {
                    // Fresh state: feed the whole prompt.
                    gen_token = generation.generate_next_token(input_ids, gen_config);
                    generation.set_n_past(generation.get_n_past() + input_ids.size());
                } else {
                    // Earlier tokens are already counted in n_past; feed only the newest
                    // one. The caller is expected to append each returned token to
                    // input_ids before the next call.
                    const std::vector<int> last_token{input_ids.back()};
                    gen_token = generation.generate_next_token(last_token, gen_config);
                    generation.set_n_past(generation.get_n_past() + 1);
                }
                return gen_token;
            },
            "input_ids"_a, "gen_config"_a,
            "Generate one token. When n_past is 0 (e.g. after reset_n_past()) the whole "
            "input_ids is fed; otherwise only input_ids[-1]. Raises ValueError if "
            "input_ids is empty.")
        .def(
            "reset_n_past",
            [](minicpm::ConditionalGeneration &generation) { generation.set_n_past(0); },
            "Set n_past to 0 so the next generate_next_token call feeds the full prompt.")
        .def_readonly("config", &minicpm::ConditionalGeneration::config);
    // .def("generate", [](minicpm::ConditionalGeneration& generation, const std::vector<int> &input_ids,
    //                     const GenerationConfig &gen_config, const bool continuous, bool &completed){
    // });

    // ===== ChatGLM3 =====
    // py::class_<ChatGLM3Tokenizer, BaseTokenizer>(m, "ChatGLM3Tokenizer");

    // ===== Pipeline =====
    // NOTE(review): if model/tokenizer are raw pointers, def_property_readonly's default
    // reference_internal policy keeps the Pipeline alive while Python holds them.
    py::class_<Pipeline>(m, "Pipeline")
        .def(py::init<const std::string &>(), "path"_a)
        .def_property_readonly("model", [](const Pipeline &self) { return self.model; })
        .def_property_readonly("tokenizer", [](const Pipeline &self) { return self.tokenizer; })
        .def(
            "chat",
            // `history` is a fresh copy converted from the Python list, so any changes
            // made by Pipeline::chat are not reflected back into the caller's list.
            [](Pipeline &pipeline, std::vector<std::string> history, const GenerationConfig &gen_config) {
                return pipeline.chat(history, gen_config);
            },
            "history"_a, "gen_config"_a);
}
} // namespace chatllm
from pathlib import Path
import chatllm_cpp._C as _C
class Pipeline(_C.Pipeline):
    """Python convenience wrapper around the native ``_C.Pipeline``.

    Validates the model path before loading and offers a keyword-argument
    based :meth:`chat` that builds the ``GenerationConfig`` for the caller.
    """

    def __init__(self, model_path: str) -> None:
        """Load the ggml model stored at ``model_path``.

        Args:
            model_path: path to the model file.

        Raises:
            RuntimeError: if ``model_path`` is not an existing regular file.
        """
        # Previously named ``init`` (calling ``super().init``), which Python never
        # invokes on construction, so the path check below was dead code.
        if Path(model_path).is_file():
            # load ggml model
            super().__init__(str(model_path))
        else:
            # Message text means "invalid argument".
            raise RuntimeError("参数错误")

    def chat(
        self,
        message: str,
        *,
        max_length: int = 2048,
        max_context_length: int = 512,
        do_sample: bool = True,
        top_k: int = 0,
        top_p: float = 0.7,
        temperature: float = 0.95,
        num_threads: int = 0,
        # stream: bool = False,
    ):
        """Send a single user ``message`` through the native pipeline.

        The keyword arguments are forwarded to ``_C.GenerationConfig``; only the
        fields that binding accepts are passed.

        Returns:
            Whatever the native ``Pipeline.chat(history, gen_config)`` returns
            (NOTE(review): presumably the reply text — confirm in the C++ side).
        """
        gen_config = _C.GenerationConfig(
            max_length=max_length,
            max_context_length=max_context_length,
            do_sample=do_sample,
            top_k=top_k,
            top_p=top_p,
            temperature=temperature,
            num_threads=num_threads,
        )
        # The native chat takes the conversation history as a list of strings.
        return super().chat([message], gen_config)
# Manual smoke test: token-by-token generation through the raw model binding,
# followed by the same question through Pipeline.chat.
import sys

import _C

# The model path may be passed as the first command-line argument; otherwise the
# original local file is used.
DEFAULT_MODEL_PATH = r"C:\Users\KyoDa\Downloads\chatllm.cpp\quantized_16.bin"
model_path = sys.argv[1] if len(sys.argv) > 1 else DEFAULT_MODEL_PATH

pipeline = _C.Pipeline(model_path)
question = "Hello."
# Prompt with role markers "<用户>" ("<user>") and "<AI>".
ids = pipeline.tokenizer.encode(f" <用户>{question}<AI>")
config = _C.GenerationConfig()
eos_token_id = pipeline.model.config.eos_token_id

pipeline.model.reset_n_past()
print(eos_token_id, "<< id")
# Bound the loop by max_length so a model that never emits EOS cannot spin forever.
for _ in range(max(config.max_length - len(ids), 0)):
    new_token = pipeline.model.generate_next_token(ids, config)
    ids.append(new_token)
    print(new_token, end=",", flush=True)
    if new_token == eos_token_id:
        break
print(pipeline.tokenizer.decode(ids))

# The loop above advanced n_past; reset it so chat starts from a fresh state.
# NOTE(review): confirm whether Pipeline.chat manages n_past itself.
pipeline.model.reset_n_past()
print(pipeline.chat([question], config))