Why integrate a private model:
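1. OpenAI sent us a warning: before calling the GPT API, we must first call its moderation endpoint to check the content for violations, and only then send the request.
To avoid disrupting users, we had no choice but to integrate a private model: if the content fails the moderation check, the request is routed to our own model instead.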
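The routing rule above can be sketched as follows. This is a minimal sketch, not the author's implementation: `ModerationChecker` (standing in for a wrapper around OpenAI's moderation endpoint) and `ChatBackend` (standing in for the GPT client or the private Llama model) are hypothetical interfaces, and treating a failed moderation call as "not verified" is an assumed policy.

```java
final class ModerationRouter {

    /** Hypothetical wrapper around OpenAI's moderation endpoint. */
    interface ModerationChecker {
        /** Returns true if the text violates the usage policies. */
        boolean isFlagged(String text);
    }

    /** Hypothetical wrapper around a chat backend (OpenAI GPT or the private Llama model). */
    interface ChatBackend {
        String complete(String prompt);
    }

    private final ModerationChecker moderation;
    private final ChatBackend gpt;
    private final ChatBackend privateModel;

    ModerationRouter(ModerationChecker moderation, ChatBackend gpt, ChatBackend privateModel) {
        this.moderation = moderation;
        this.gpt = gpt;
        this.privateModel = privateModel;
    }

    /** Runs the moderation check first; only prompts that pass it are sent to GPT. */
    String complete(String prompt) {
        boolean flagged;
        try {
            flagged = moderation.isFlagged(prompt);
        } catch (RuntimeException e) {
            // Assumption: if the check itself fails, the prompt counts as not verified.
            flagged = true;
        }
        return flagged ? privateModel.complete(prompt) : gpt.complete(prompt);
    }
}
```

Keeping both backends behind the same interface means callers never need to know which model answered.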
Tasks:
- Retraining (fine-tuning) based on Llama (as long as you have data, there are plenty of training tutorials)
- Deployment
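2.1 Deployment with Hugging Face's own hosted service
Inference Endpoints documentation
Operation guide
2.2 Deployment with Docker
GitHub documentation - API calls
text-generation-inference documentation
Swagger documentation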
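With the Docker deployment, text-generation-inference serves `POST /generate` (and `POST /generate_stream` for streaming), taking a JSON body with `inputs` and `parameters`; the full schema is in its Swagger docs. Below is a minimal sketch using the JDK's `java.net.http` client (Java 11+). The class name `TgiClient`, the base URL and the optional bearer token (as used by protected Hugging Face Inference Endpoints) are assumptions, and the JSON body is expected to be serialized beforehand, for example with Gson.

```java
import java.io.IOException;
import java.net.URI;
import java.net.http.HttpClient;
import java.net.http.HttpRequest;
import java.net.http.HttpResponse;
import java.time.Duration;

final class TgiClient {
    private final HttpClient http = HttpClient.newBuilder()
            .connectTimeout(Duration.ofSeconds(10))
            .build();
    private final URI generateUri;
    private final String bearerToken; // assumption: only needed for protected endpoints; may be null

    TgiClient(String baseUrl, String bearerToken) {
        // Drop a trailing slash so "http://host:8080/" and "http://host:8080" behave the same.
        String base = baseUrl.endsWith("/") ? baseUrl.substring(0, baseUrl.length() - 1) : baseUrl;
        this.generateUri = URI.create(base + "/generate");
        this.bearerToken = bearerToken;
    }

    /** Builds the POST request; kept separate from sending so it can be inspected. */
    HttpRequest buildRequest(String jsonBody) {
        HttpRequest.Builder builder = HttpRequest.newBuilder(generateUri)
                .timeout(Duration.ofSeconds(120))
                .header("Content-Type", "application/json")
                .POST(HttpRequest.BodyPublishers.ofString(jsonBody));
        if (bearerToken != null) {
            builder.header("Authorization", "Bearer " + bearerToken);
        }
        return builder.build();
    }

    /** Sends the request and returns the raw JSON, e.g. {"generated_text": "..."}. */
    String generate(String jsonBody) throws IOException, InterruptedException {
        HttpResponse<String> response =
                http.send(buildRequest(jsonBody), HttpResponse.BodyHandlers.ofString());
        if (response.statusCode() != 200) {
            throw new IOException("TGI returned HTTP " + response.statusCode() + ": " + response.body());
        }
        return response.body();
    }
}
```

For a local container published on port 8080 (an assumed port mapping), usage would look like `new TgiClient("http://localhost:8080", null).generate(json)`. For a Hugging Face Inference Endpoint, check the endpoint page for the exact URL to call.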
Issues encountered:
Python solves many of these problems quickly, but Java takes a bit more effort.
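- Data export
The exported data is one very large JSON file. Serializing it in one go with fastjson ran out of memory (OOM); switching to Gson solved the problem.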
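```java
import com.google.gson.Gson;
import com.google.gson.GsonBuilder;
import com.google.gson.JsonIOException;
import java.io.IOException;
import java.io.Writer;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Paths;
import java.util.List;

private void exportJson(List<Object> resultVos, String path) {
    Gson gson = new GsonBuilder()
            .setPrettyPrinting()
            .disableHtmlEscaping()
            .create();
    // try-with-resources closes the file even when serialization fails;
    // the file is written as UTF-8 rather than the platform default charset.
    try (Writer writer = Files.newBufferedWriter(Paths.get(path), StandardCharsets.UTF_8)) {
        // Serializes straight into the writer, without first building a
        // JsonElement tree of the whole list (which gson.toJsonTree would do).
        gson.toJson(resultVos, writer);
    } catch (IOException | JsonIOException e) {
        e.printStackTrace();
    }
}
```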
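If the records cannot even be held as one `List` (for example, when they are read from the database page by page), the array can instead be written element by element. Below is a minimal JDK-only sketch. `JsonArrayStreamer` and `writeArray` are illustrative names, and the per-element serializer is left to the caller; with the Gson instance above it could be `item -> gson.toJson(item)`.

```java
import java.io.IOException;
import java.io.UncheckedIOException;
import java.io.Writer;
import java.util.Iterator;
import java.util.function.Function;

final class JsonArrayStreamer {

    /**
     * Writes the items as one JSON array, serializing a single element at a time,
     * so the full array is never built as one string or tree in memory.
     * The serializer must return valid JSON for one element.
     */
    static <T> void writeArray(Iterator<T> items, Function<? super T, String> toJson, Writer out) {
        try {
            out.write('[');
            boolean first = true;
            while (items.hasNext()) {
                if (!first) {
                    out.write(','); // separator between elements, none after the last
                }
                out.write(toJson.apply(items.next()));
                first = false;
            }
            out.write(']');
            out.flush();
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
    }
}
```

The iterator can be backed by a paged query (hypothetical here), and the `Writer` can come from `Files.newBufferedWriter` as in the export method, so memory use stays bounded by one page of records.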
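- About the max_new_tokens parameter
When deploying a model with Inference Endpoints, the following must hold (the prompt and the generated tokens share the same max_number_of_tokens budget):
max_new_tokens <= max_number_of_tokens - max_input_length
max_number_of_tokens must not exceed the maximum length the model supports; for example, at most 8192 for the 8B model.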
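The constraint above can be checked on the client before creating the endpoint or sending a request. Below is a minimal sketch. `TokenBudget` is an illustrative name, the three arguments correspond to the endpoint's max_number_of_tokens, max_input_length and the model's context length, and clamping an over-large max_new_tokens (rather than rejecting it) is an assumed policy.

```java
final class TokenBudget {

    /**
     * Returns the largest max_new_tokens the endpoint settings allow,
     * i.e. max_number_of_tokens - max_input_length, after validating the settings.
     */
    static int maxAllowedNewTokens(int maxNumberOfTokens, int maxInputLength, int modelContextLength) {
        if (maxNumberOfTokens > modelContextLength) {
            throw new IllegalArgumentException("max_number_of_tokens " + maxNumberOfTokens
                    + " exceeds the model's context length " + modelContextLength);
        }
        if (maxInputLength >= maxNumberOfTokens) {
            throw new IllegalArgumentException(
                    "max_input_length must be smaller than max_number_of_tokens");
        }
        return maxNumberOfTokens - maxInputLength;
    }

    /** Assumed policy: clamp a requested max_new_tokens into [1, allowed]. */
    static int clampNewTokens(int requested, int allowed) {
        return Math.max(1, Math.min(requested, allowed));
    }
}
```

For example, with max_number_of_tokens = 8192 and max_input_length = 4096, `maxAllowedNewTokens(8192, 4096, 8192)` returns 4096, so a request for 6000 new tokens would be clamped to 4096.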
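- About the inputs parameter
The OpenAI GPT API takes a list of messages, each with `role` and `content` fields, whereas the Llama endpoint takes a single `inputs` string. Code to convert the messages into `inputs`: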
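```java
import java.util.List;
// Assumed: StringUtils from commons-lang3, ChatMessageRole from the openai-java client;
// GptMessage is the project's own DTO with role and content fields.
import org.apache.commons.lang3.StringUtils;
import com.theokanning.openai.completion.chat.ChatMessageRole;

public class LlamaUtils {
    private static final String B_INST = "[INST]";
    private static final String E_INST = "[/INST]";
    // "\n" is a real newline here, as in the Llama 2 "<<SYS>>\n" markers.
    private static final String B_SYS = "<<SYS>>\n";
    private static final String E_SYS = "\n<</SYS>>\n\n";
    // End-of-sequence marker appended after system and assistant turns.
    public static final String sep = "</s>";

    /** Flattens OpenAI-style chat messages into a single Llama prompt string. */
    public static String messageToInput(List<GptMessage> messages) {
        StringBuilder builder = new StringBuilder();
        messages.forEach(message -> {
            if (ChatMessageRole.SYSTEM.value().equals(message.getRole())) {
                // System prompt wrapped in <<SYS>> ... <</SYS>>
                builder.append(B_SYS).append(message.getContent()).append(E_SYS).append(sep);
            } else if (ChatMessageRole.USER.value().equals(message.getRole())) {
                // User turn wrapped in [INST] ... [/INST]; empty user messages are skipped
                if (StringUtils.isNotEmpty(message.getContent())) {
                    builder.append(B_INST).append(message.getContent().trim()).append(E_INST);
                }
            } else {
                // Assistant (and any other role): raw content followed by </s>
                builder.append(message.getContent()).append(sep);
            }
        });
        return builder.toString();
    }
}
```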
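For comparison, Meta's reference Llama 2 code (`chat_completion` in the `llama` repository) folds the system prompt into the first user turn, puts spaces around `[INST]`/`[/INST]`, and wraps every completed user/assistant pair in `<s>` … `</s>`. Below is a minimal JDK-only sketch of that layout. It assumes a simple `Msg` class in place of `GptMessage` and that messages alternate user/assistant after an optional leading system message. Whether `<s>` should appear as text depends on your tokenizer (it may already add BOS itself), so verify against your deployment.

```java
import java.util.ArrayList;
import java.util.List;

final class Llama2ChatTemplate {
    static final String B_INST = "[INST]";
    static final String E_INST = "[/INST]";
    static final String B_SYS = "<<SYS>>\n";
    static final String E_SYS = "\n<</SYS>>\n\n";

    /** Minimal stand-in for GptMessage: role is "system", "user" or "assistant". */
    static final class Msg {
        final String role;
        final String content;

        Msg(String role, String content) {
            this.role = role;
            this.content = content;
        }
    }

    static String format(List<Msg> messages) {
        List<Msg> dialog = new ArrayList<>(messages);
        // Fold a leading system message into the first user message.
        if (dialog.size() >= 2 && "system".equals(dialog.get(0).role)) {
            Msg sys = dialog.remove(0);
            Msg firstUser = dialog.get(0);
            dialog.set(0, new Msg(firstUser.role, B_SYS + sys.content + E_SYS + firstUser.content));
        }
        StringBuilder sb = new StringBuilder();
        int i = 0;
        // Completed (user, assistant) pairs: "<s>[INST] q [/INST] a </s>"
        for (; i + 1 < dialog.size(); i += 2) {
            sb.append("<s>").append(B_INST).append(' ').append(dialog.get(i).content.trim())
              .append(' ').append(E_INST).append(' ').append(dialog.get(i + 1).content.trim())
              .append(" </s>");
        }
        // Final user turn the model should answer: "<s>[INST] q [/INST]"
        if (i < dialog.size()) {
            sb.append("<s>").append(B_INST).append(' ').append(dialog.get(i).content.trim())
              .append(' ').append(E_INST);
        }
        return sb.toString();
    }
}
```

For a system prompt "Be brief.", a user "Hi", an assistant "Hello!" and a final user "Bye", this produces `<s>[INST] <<SYS>>\nBe brief.\n<</SYS>>\n\nHi [/INST] Hello! </s><s>[INST] Bye [/INST]` (with `\n` as real newlines).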
Appendix
Java request/response entities
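```java
import java.util.List;
import lombok.Data;

@Data
public class Llama2Request {
    private String inputs;          // prompt built by LlamaUtils.messageToInput
    private Parameters parameters;
    private Boolean stream;         // true to receive the response as a token stream

    @Data
    public static class Parameters {
        private Double temperature;
        private Integer max_new_tokens;   // must satisfy max_new_tokens <= max_number_of_tokens - max_input_length
        private Integer max_time;
        private Boolean return_full_text; // false: return only the generated text, without the prompt
        // text-generation-inference's own schema names this field "stop";
        // check the Swagger docs of your deployment.
        private List<String> stop_sequences;
    }
}
```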
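Note: the response is returned as a list, so parse it as a JSON array.
```java
import lombok.Data;
import lombok.NoArgsConstructor;

@Data
@NoArgsConstructor
public class LlamaTextGenResponse {
    /**
     * generated_text : test
     */
    public String generated_text;
}
```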
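```java
import java.math.BigInteger;
import lombok.Data;
import lombok.NoArgsConstructor;

@Data
@NoArgsConstructor
public class LlamaTextGenSteamResponse {
    /**
     * details : {"finish_reason":"length","generated_tokens":1,"seed":42}
     * generated_text : test
     * token : {"id":0,"logprob":-0.34,"special":false,"text":"test"}
     */
    public Details details;        // null except on the final event of the stream
    public String generated_text;  // null except on the final event of the stream
    public Token token;            // the token produced by this event

    @NoArgsConstructor
    @Data
    public static class Details {
        /**
         * finish_reason : length
         * generated_tokens : 1
         * seed : 42
         */
        public String finish_reason;
        public int generated_tokens;
        // text-generation-inference reports the seed as an unsigned 64-bit value
        // (and may omit it), so int or even long can overflow.
        public BigInteger seed;
    }

    @NoArgsConstructor
    @Data
    public static class Token {
        /**
         * id : 0
         * logprob : -0.34
         * special : false
         * text : test
         */
        public int id;
        public double logprob;
        public boolean special;
        public String text;
    }
}
```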
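When streaming is enabled, text-generation-inference sends Server-Sent Events: each event is a `data:` line holding one JSON object shaped like the stream response class above, and the final event also carries `generated_text` and `details`. Below is a minimal JDK-only sketch that extracts those JSON payloads from the raw stream. The class name is illustrative, and it assumes one `data:` line per event (which is how TGI sends them); each payload can then be parsed with Gson into the stream response class.

```java
import java.io.BufferedReader;
import java.io.IOException;
import java.io.Reader;
import java.io.UncheckedIOException;
import java.util.function.Consumer;

final class SseDataReader {

    /**
     * Reads an SSE stream line by line and passes the payload of every "data:" line
     * to the callback. Blank separator lines and comment lines (":") are skipped.
     */
    static void forEachData(Reader source, Consumer<String> onData) {
        try (BufferedReader reader = new BufferedReader(source)) {
            String line;
            while ((line = reader.readLine()) != null) {
                if (line.startsWith("data:")) {
                    String payload = line.substring("data:".length()).trim();
                    if (!payload.isEmpty()) {
                        onData.accept(payload);
                    }
                }
            }
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
    }
}
```

With `java.net.http`, the source can be `new InputStreamReader(response.body(), StandardCharsets.UTF_8)` for a response obtained with `HttpResponse.BodyHandlers.ofInputStream()`, so tokens are handled as they arrive instead of after the whole answer is generated.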