现在市面上Java实现的流式输出代码很少,只能自己动手丰衣足食。
一、为什么大语言模型使用流式输出内容
大语言模型采用流式输出内容的原因主要有以下几点:
提高用户体验:流式输出使得模型的回复不是一次性生成整个回答,而是逐字逐句地生成。这种方式避免了用户长时间等待整个回复生成完毕的情况,从而提升了用户体验。
提升交互响应速度:通过逐字蹦出回复,可以实现更快的交互响应。这意味着在用户输入消息后,模型可以快速开始生成回答的开头,并根据上下文逐渐细化回答。
增强对话透明度:流式输出可以让用户看到模型逐步构建回答的过程,这有助于用户理解模型是如何形成回答的,提高了对话的透明度和可解释性。
优化性能表现:对于大型语言模型来说,生成完整的内容可能需要较长的计算时间。流式输出允许模型边计算边输出,这样即使模型推理效率不是很高,也能保证用户体验不会受到太大影响。
实现动画效果:流式输出还可以模仿打字机的动画效果,即一个字或一个词的输出,给用户一种答案逐渐出现的视觉效果。
综上所述,流式输出是大语言模型在交互过程中的一种有效策略,它兼顾了效率和用户体验,同时也增强了模型的互动性和透明度。
二、关于SSE技术
当然WebSocket也可以达到效果,本文使用更轻量的SSE来实现。
SSE,全称为Server-Sent Events(服务器发送事件),是一种允许服务器向浏览器客户端推送实时信息的Web技术。这种机制基于HTTP协议,利用长轮询的方式,让服务器可以主动向客户端发送更新的数据,而无需客户端不断地发起请求去询问是否有新的数据。这个特点正好符合我都需求,看了这么多大语言模型的应用,一直在琢磨底层实现。
SSE的工作原理是建立在传统的HTTP请求之上,但与传统HTTP请求不同的是,一旦建立连接,服务器就可以持续地向客户端发送消息,直到连接被关闭。客户端接收到的消息通常以JSON或其他格式编码,并且每条消息都包含一个事件类型和数据负载。
SSE具有以下特点:
单向通信:SSE主要用于服务器向客户端发送数据,而不是双向通信。如果需要客户端向服务器发送数据,通常需要另外的机制,如WebSocket。
简单高效:SSE使用标准的HTTP协议,不需要额外的库或插件,且相比于WebSocket,它在某些情况下可能更加高效,因为它只需要服务器发送数据,而不需要保持全双工的连接。
自动重连:如果连接中断,SSE会自动尝试重新连接,这对于需要高可靠性的实时数据推送非常有用。
跨浏览器支持:大多数现代浏览器都支持SSE,包括Chrome、Firefox、Safari和Edge。
SSE常用于需要实时数据更新的应用场景,如股票价格更新、新闻推送、社交媒体通知等。通过SSE,开发者可以轻松地构建实时交互式的Web应用程序,为用户提供更加丰富和动态的体验。
三、SSE代码实现
在Java中,SseEmitter
是 Spring 框架提供的一个用于服务器发送事件(Server-Sent Events, SSE)的工具。SSE 允许服务器向客户端推送实时信息,客户端通过一个持久的HTTP连接接收这些信息。
SpringBoot的pom依赖如下:
<dependency>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-web</artifactId>
</dependency>
<!--
<dependency>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-webflux</artifactId>
</dependency>
-->
在使用SSE之前也测试了webflux框架实现打字机,效果不好放弃。
为了使用 SseEmitter
实现类似打字机效果的流式输出,你需要创建一个 SseEmitter
实例,然后逐步发送数据给客户端。下面是一个简单的例子:
import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
import org.springframework.web.bind.annotation.GetMapping;
import org.springframework.web.bind.annotation.RestController;
@RestController
public class TypewriterController {
@GetMapping("/typewriter")
public SseEmitter typewriter() {
SseEmitter emitter = new SseEmitter();
// 模拟从大语言模型获取数据的过程
String[] data = {"Hello", "World", "from", "the", "large", "language", "model"};
for (String word : data) {
try {
// 模拟打字机效果,每个单词之间暂停100毫秒
Thread.sleep(100);
emitter.send(SseEmitter.event().data(word));
} catch (InterruptedException e) {
emitter.completeWithError(e);
return emitter;
}
}
emitter.complete();
return emitter;
}
}
在这个例子中,我们定义了一个 TypewriterController
类,它有一个 /typewriter
端点。当这个端点被访问时,它会创建一个新的 SseEmitter
对象,并逐个发送字符串数组中的单词。每发送一个单词后,线程会暂停100毫秒来模拟打字机的效果。
客户端可以通过建立一个到 /typewriter
端点的持久连接来接收这些事件。例如,如果你使用JavaScript作为客户端,你可以这样写:
<!DOCTYPE html>
<html>
<head>
<title>Typewriter Effect</title>
</head>
<body>
<div id="output"></div>
<script>
var source = new EventSource('/typewriter');
source.onmessage = function(event) {
var outputDiv = document.getElementById('output');
outputDiv.innerHTML += event.data + ' '; // 将接收到的数据添加到页面中
};
</script>
</body>
</html>
这段HTML和JavaScript代码会打开一个到服务器的SSE连接,并在接收到新数据时更新页面的内容。每次收到数据时,都会将其追加到 <div id="output">
元素中,从而实现类似于打字机逐字显示文本的效果。
请注意,在实际应用中,你可能需要处理更多的细节,比如错误处理、连接关闭时的清理工作等。此外,如果你的大语言模型是通过异步方式生成数据的,你可能还需要考虑如何与 SseEmitter
进行集成,以确保数据能够正确地流式传输到客户端。
四、API方式对接SSE服务端
上面基于JavaScript作为客户端获取返回的流式数据,通过后端的Java代码也可以对接SSE服务端,增加其他逻辑,比如:鉴权、计费、敏感词过滤等。
pom文件:
<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 https://maven.apache.org/xsd/maven-4.0.0.xsd">
<modelVersion>4.0.0</modelVersion>
<parent>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-parent</artifactId>
<version>2.3.8.RELEASE</version>
<relativePath/> <!-- lookup parent from repository -->
</parent>
<groupId>com.guo.test</groupId>
<artifactId>streamClient</artifactId>
<version>0.0.1-SNAPSHOT</version>
<name>streamClient</name>
<description>streamClient</description>
<properties>
<java.version>11</java.version>
</properties>
<dependencies>
<dependency>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-web</artifactId>
</dependency>
<dependency>
<groupId>com.squareup.okhttp3</groupId>
<artifactId>okhttp</artifactId>
<version>4.10.0</version>
</dependency>
<dependency>
<groupId>com.squareup.okhttp3</groupId>
<artifactId>okhttp-sse</artifactId>
<version>4.10.0</version>
</dependency>
<dependency>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-test</artifactId>
<scope>test</scope>
</dependency>
</dependencies>
<build>
<plugins>
<plugin>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-maven-plugin</artifactId>
</plugin>
</plugins>
</build>
</project>
要注意下面代码设置接收媒体类型为:
text/event-stream
,建议使用`MediaType.TEXT_EVENT_STREAM_VALUE`代替字符串编码。
package com.guo.test.streamclient.client;
import okhttp3.OkHttpClient;
import okhttp3.Request;
import okhttp3.Response;
import okhttp3.ResponseBody;
import org.springframework.http.MediaType;
import java.io.IOException;
import java.util.concurrent.TimeUnit;
public class streamClient {
private static final String SSE_URL = "http://localhost:8080/llm/stream/query?query=地铁安全门的规范"; // 替换为你的SSE端点地址
public static void main(String[] args) {
OkHttpClient client = new OkHttpClient.Builder()
.readTimeout(0, TimeUnit.MILLISECONDS) // 设置无限读取超时,因为SSE是长连接
.build();
Request request = new Request.Builder()
.url(SSE_URL)
.header("Accept", "text/event-stream") // 设置接收SSE媒体类型
.build();
try (Response response = client.newCall(request).execute()) {
if (!response.isSuccessful()) {
throw new IOException("Unexpected code " + response);
}
// 获取响应体并读取流
ResponseBody responseBody = response.body();
if (responseBody == null) {
return;
}
try (java.io.Reader reader = responseBody.charStream()) {
char[] buffer = new char[1024];
int bytesRead;
while ((bytesRead = reader.read(buffer)) != -1) {
String data = new String(buffer, 0, bytesRead);
// 处理接收到的SSE数据
System.out.println("Received data: " + data);
// 查找SSE事件边界(通常是"\n\n")
/*int eventBoundary = data.indexOf("\n\n");
if (eventBoundary != -1) {
String event = data.substring(0, eventBoundary);
String eventData = data.substring(eventBoundary + 2);
// 处理事件头部和事件数据
System.out.println("Event: " + event);
System.out.println("Event Data: " + eventData);
}*/
}
}
} catch (IOException e) {
e.printStackTrace();
}
}
}