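Background
Integrating Spring Boot with GPT and streaming the chat response back to the client.
Backend
Controller layer
To stream the conversation, the endpoint must declare text/event-stream as its response type. The response headers set below are optional and can be left out.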
@PostMapping(value = "/sse", produces = MediaType.TEXT_EVENT_STREAM_VALUE)
public SseEmitter sse(@Validated @RequestBody AnalyzeChatVO vo, HttpServletResponse response) {
    // Optional headers; the streaming works without them
    response.setHeader("Cache-Control", "no-cache");
    response.setHeader("Connection", "keep-alive");
    return gptService.analyzeChatStream(vo);
}
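For context, text/event-stream is a plain-text framing: each event consists of one or more "data:" lines followed by a blank line. SseEmitter produces this framing for you, so the sketch below is only an illustration of the wire format. The class name SseFrames and its format method are hypothetical and not part of Spring.

```java
final class SseFrames {
    private SseFrames() {}

    /**
     * Formats a payload as a single SSE event: one "data:" line per
     * payload line, terminated by an empty line.
     * Assumes the payload uses "\n" as its line separator.
     */
    static String format(String payload) {
        StringBuilder frame = new StringBuilder();
        // limit -1 keeps trailing empty lines instead of dropping them
        for (String line : payload.split("\n", -1)) {
            frame.append("data: ").append(line).append('\n');
        }
        // the blank line marks the end of the event
        return frame.append('\n').toString();
    }
}
```

A multi-line payload therefore becomes several "data:" lines, which the client joins back together with newlines.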
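Implementation
Note that the output must be written asynchronously: the method returns the emitter right away while another thread sends the content. A thread pool or a plain thread both work, as long as the sending is asynchronous.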
public SseEmitter analyzeChatStream(AnalyzeChatVO vo) {
    // Very large timeout so that long answers are not cut off
    SseEmitter emitter = new SseEmitter(Long.MAX_VALUE);
    if (StringUtils.isEmpty(vo.getUser())) {
        vo.setUser(UsernameHolder.getUsername());
    }
    ThreadPoolExecutor executor = ThreadPoolUtil.simpleThreadPool("chat", 1, 1);
    try {
        // SseListener takes only the emitter, matching its constructor shown later
        CompletableFuture.runAsync(() -> streamRequest(vo, new SseListener(emitter)), executor).whenComplete((r, t) -> {
            if (t != null) {
                emitter.completeWithError(t);
                log.error("Stream request start error,", t);
            }
        });
    } finally {
        // shutdown() lets the already submitted task finish; it only rejects new ones
        executor.shutdown();
    }
    return emitter;
}
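The method above creates a dedicated pool for every request and shuts it down right away. A common alternative is one pool shared by all requests. The following is a minimal JDK-only sketch of that idea, not the original project's code: StreamDispatcher, the pool size of 4, and the onError callback (which stands in for emitter.completeWithError) are all assumptions made for illustration.

```java
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CompletionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.function.Consumer;

final class StreamDispatcher {
    // One pool shared by all requests; the size is an arbitrary assumption
    private static final ExecutorService POOL = Executors.newFixedThreadPool(4, runnable -> {
        Thread thread = new Thread(runnable, "chat-stream");
        thread.setDaemon(true); // do not keep the JVM alive for this pool
        return thread;
    });

    private StreamDispatcher() {}

    /** Runs the task on the shared pool and reports a failure to onError. */
    static CompletableFuture<Void> dispatch(Runnable task, Consumer<Throwable> onError) {
        return CompletableFuture.runAsync(task, POOL).whenComplete((ignored, error) -> {
            if (error != null) {
                // runAsync wraps the task's exception in a CompletionException
                Throwable cause = (error instanceof CompletionException && error.getCause() != null)
                        ? error.getCause() : error;
                onError.accept(cause);
            }
        });
    }
}
```

In the service this would be called as dispatch(() -> streamRequest(vo, listener), emitter::completeWithError), under the same assumptions.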
The listener here is built on OkHttp, so first add OkHttp's SSE module:
<dependency>
    <groupId>com.squareup.okhttp3</groupId>
    <artifactId>okhttp-sse</artifactId>
    <version>4.9.1</version>
</dependency>
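Then register the custom listener. Here baseUrl is the address of the model service, and the request carries the corresponding token.
The chatId sent with the model request is mainly used to keep users isolated from one another.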
private void streamRequest(AnalyzeChatVO vo, EventSourceListener listener) {
    GptClient client = getStreamClient();
    String url = client.getAttribute().getBaseUrl() + "/api/v1/chat/completions";
    log.info("Stream url:{}", url);
    OkHttpClient okHttpClient = client.getOkHttpClient();
    EventSource.Factory factory = EventSources.createFactory(okHttpClient);
    // The user name doubles as chatId; newlines are stripped from the question
    String requestBody = String.format("{\"chatId\": \"%s\",\"stream\": true, \"messages\": [{\"role\":\"user\", \"content\": \"%s\"}]}",
            vo.getUser(), vo.getQuestion().replace("\n", ""));
    Request.Builder builder = new Request.Builder()
            .url(url)
            .header("Authorization", client.getToken())
            .post(RequestBody.create(requestBody, okhttp3.MediaType.parse(MediaType.APPLICATION_JSON.toString())));
    Request request = builder.build();
    // Opens the connection; events are delivered to the listener asynchronously
    factory.newEventSource(request, listener);
}
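One caveat with building the request body through String.format: only newlines are removed, so a question containing a double quote or a backslash yields invalid JSON. In a real project the simplest fix is probably to serialize the body with the JSON library already in use. The JDK-only sketch below just shows what the escaping involves; JsonStrings, escape, and chatRequestBody are hypothetical names, and the field layout is copied from the request above.

```java
final class JsonStrings {
    private JsonStrings() {}

    /** Escapes a string so it can be placed between double quotes in JSON. */
    static String escape(String raw) {
        StringBuilder out = new StringBuilder(raw.length() + 16);
        for (int i = 0; i < raw.length(); i++) {
            char c = raw.charAt(i);
            switch (c) {
                case '"':  out.append("\\\""); break;
                case '\\': out.append("\\\\"); break;
                case '\n': out.append("\\n"); break;
                case '\r': out.append("\\r"); break;
                case '\t': out.append("\\t"); break;
                default:
                    if (c < 0x20) {
                        // remaining control characters use the \\uXXXX form
                        out.append(String.format("\\u%04x", (int) c));
                    } else {
                        out.append(c);
                    }
            }
        }
        return out.toString();
    }

    /** Builds the chat request body with both values escaped. */
    static String chatRequestBody(String chatId, String question) {
        return String.format(
                "{\"chatId\": \"%s\", \"stream\": true, \"messages\": [{\"role\": \"user\", \"content\": \"%s\"}]}",
                escape(chatId), escape(question));
    }
}
```

With escaping in place the newlines in the question no longer need to be stripped, since they are sent as \n sequences.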
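Listener
The custom listener mainly implements the callbacks of EventSourceListener. ChatCompletion defines the structure of the chat response, and lastMessage collects the complete reply: because the message arrives piece by piece, the pieces are concatenated here. This accumulation is optional and can be removed.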
@Slf4j
public abstract class AbstractStreamListener extends EventSourceListener {
    // Accumulates the full answer from the streamed pieces
    protected String lastMessage = "";
    private static final String STREAM_END = "[DONE]";
    @Setter
    @Getter
    protected Consumer<String> onComplete = s -> {
    };

    public abstract void onMsg(String message);

    public abstract void onError(Throwable throwable, String response);

    @Override
    public void onOpen(EventSource eventSource, Response response) {
        log.info("Open");
    }

    @Override
    public void onClosed(EventSource eventSource) {
        log.info("Closed");
    }

    @Override
    public void onEvent(EventSource eventSource, String id, String type, String data) {
        log.info("Event:{}", data);
        // The end marker is forwarded as-is, then the completion callback fires
        if (STREAM_END.equals(data)) {
            onMsg(data);
            onComplete.accept(lastMessage);
            return;
        }
        ChatCompletion response = JSON.parseObject(data, ChatCompletion.class);
        String text = response.toPlainStringStream();
        Map<String, String> dataToSend = Maps.newHashMap();
        dataToSend.put("content", text);
        if (StringUtils.isNotEmpty(text)) {
            lastMessage += text;
            // Send the text wrapped in JSON so that '\n' inside a chunk reaches the client unchanged
            onMsg(JSON.toJSONString(dataToSend));
        }
    }

    @SneakyThrows
    @Override
    public void onFailure(EventSource eventSource, Throwable throwable, Response response) {
        log.info("Fail", throwable);
        try {
            String responseText = "";
            if (Objects.nonNull(response) && Objects.nonNull(response.body())) {
                responseText = response.body().string();
            }
            log.error("Listener failure response:{}", responseText);
            this.onError(throwable, responseText);
        } catch (Exception e) {
            log.error("Listener on failure error,", e);
        } finally {
            eventSource.cancel();
        }
    }
}
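The accumulation logic of onEvent can also be looked at in isolation: chunks are appended until the "[DONE]" marker arrives, and at that point the completion callback receives the full text. The sketch below leaves out OkHttp and JSON parsing entirely and assumes each chunk is already plain text. ChunkAccumulator is a hypothetical name, and it uses a StringBuilder instead of repeated string concatenation.

```java
import java.util.function.Consumer;

final class ChunkAccumulator {
    private static final String STREAM_END = "[DONE]";

    private final StringBuilder full = new StringBuilder();
    private final Consumer<String> onComplete;

    ChunkAccumulator(Consumer<String> onComplete) {
        this.onComplete = onComplete;
    }

    /**
     * Feeds one streamed chunk.
     * Returns true once the end marker has been seen, false otherwise.
     */
    boolean accept(String chunk) {
        if (STREAM_END.equals(chunk)) {
            // hand the concatenated answer to the callback
            onComplete.accept(full.toString());
            return true;
        }
        if (chunk != null && !chunk.isEmpty()) {
            full.append(chunk);
        }
        return false;
    }
}
```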
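The final implementation lives in SseListener, which forwards the messages received by the listener to the SseEmitter and, once the message is complete, logs the full message content.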
@Slf4j
public class SseListener extends AbstractStreamListener {
    private SseEmitter emitter;

    public SseListener(SseEmitter emitter) {
        this.emitter = emitter;
        // Log the full answer and close the SSE connection when the stream ends
        super.setOnComplete((s) -> {
            log.info("Complete message:{}", s);
            emitter.complete();
        });
    }

    @Override
    public void onMsg(String message) {
        log.info(message);
        try {
            emitter.send(message);
        } catch (IOException e) {
            log.error("Send message error,", e);
        }
    }

    @Override
    public void onError(Throwable throwable, String response) {
        log.error("Listener error: {}", response, throwable);
        emitter.completeWithError(throwable);
    }
}
That completes a simple backend implementation.
Nginx
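If requests to your project are proxied through nginx, the nginx configuration also needs adjusting. The key directive is proxy_buffering off, which stops nginx from buffering the upstream response so that each event is forwarded to the client as soon as it arrives. proxy_cache off additionally disables response caching. proxy_redirect off is included in the configuration as well, although it only turns off rewriting of the Location and Refresh headers and has no effect on buffering.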
location /sse {
    add_header Access-Control-Allow-Origin *;
    add_header Access-Control-Allow-Methods 'GET, POST, OPTIONS';
    add_header Access-Control-Allow-Headers 'DNT,X-Mx-ReqToken,Keep-Alive,User-Agent,X-Requested-With,If-Modified-Since,Cache-Control,Content-Type,Authorization';
    proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
    proxy_set_header Host $http_host;
    proxy_redirect off;
    proxy_buffering off;
    proxy_cache off;
    proxy_pass http://upstream;
}