Java语言实现大语言模型输出的打字机效果(Stream流式)

现在市面上Java实现的流式输出代码很少,只能自己动手丰衣足食。

一、为什么大语言模型使用流式输出内容

大语言模型采用流式输出内容的原因主要有以下几点:

  1. 提高用户体验:流式输出使得模型的回复不是一次性生成整个回答,而是逐字逐句地生成。这种方式避免了用户长时间等待整个回复生成完毕的情况,从而提升了用户体验。

  2. 提升交互响应速度:通过逐字蹦出回复,可以实现更快的交互响应。这意味着在用户输入消息后,模型可以快速开始生成回答的开头,并根据上下文逐渐细化回答。

  3. 增强对话透明度:流式输出可以让用户看到模型逐步构建回答的过程,这有助于用户理解模型是如何形成回答的,提高了对话的透明度和可解释性。

  4. 优化性能表现:对于大型语言模型来说,生成完整的内容可能需要较长的计算时间。流式输出允许模型边计算边输出,这样即使模型推理效率不是很高,也能保证用户体验不会受到太大影响。

  5. 实现动画效果:流式输出还可以模仿打字机的动画效果,即一个字或一个词的输出,给用户一种答案逐渐出现的视觉效果。

综上所述,流式输出是大语言模型在交互过程中的一种有效策略,它兼顾了效率和用户体验,同时也增强了模型的互动性和透明度。

二、关于SSE技术

当然WebSocket也可以达到效果,本文使用更轻量的SSE来实现。
SSE,全称为Server-Sent Events(服务器发送事件),是一种允许服务器向浏览器客户端推送实时信息的Web技术。这种机制基于HTTP协议,利用长轮询的方式,让服务器可以主动向客户端发送更新的数据,而无需客户端不断地发起请求去询问是否有新的数据。这个特点正好符合我都需求,看了这么多大语言模型的应用,一直在琢磨底层实现。

SSE的工作原理是建立在传统的HTTP请求之上,但与传统HTTP请求不同的是,一旦建立连接,服务器就可以持续地向客户端发送消息,直到连接被关闭。客户端接收到的消息通常以JSON或其他格式编码,并且每条消息都包含一个事件类型和数据负载。

SSE具有以下特点:

  1. 单向通信:SSE主要用于服务器向客户端发送数据,而不是双向通信。如果需要客户端向服务器发送数据,通常需要另外的机制,如WebSocket。

  2. 简单高效:SSE使用标准的HTTP协议,不需要额外的库或插件,且相比于WebSocket,它在某些情况下可能更加高效,因为它只需要服务器发送数据,而不需要保持全双工的连接。

  3. 自动重连:如果连接中断,SSE会自动尝试重新连接,这对于需要高可靠性的实时数据推送非常有用。

  4. 跨浏览器支持:大多数现代浏览器都支持SSE,包括Chrome、Firefox、Safari和Edge。

SSE常用于需要实时数据更新的应用场景,如股票价格更新、新闻推送、社交媒体通知等。通过SSE,开发者可以轻松地构建实时交互式的Web应用程序,为用户提供更加丰富和动态的体验。

三、SSE代码实现

在Java中,SseEmitter 是 Spring 框架提供的一个用于服务器发送事件(Server-Sent Events, SSE)的工具。SSE 允许服务器向客户端推送实时信息,客户端通过一个持久的HTTP连接接收这些信息。
SpringBoot的pom依赖如下:

        <dependency>
            <groupId>org.springframework.boot</groupId>
            <artifactId>spring-boot-starter-web</artifactId>
        </dependency>
        <!--
           <dependency>
           <groupId>org.springframework.boot</groupId>
           <artifactId>spring-boot-starter-webflux</artifactId>
       </dependency>
       -->

在使用SSE之前也测试了webflux框架实现打字机,效果不好放弃。
为了使用 SseEmitter 实现类似打字机效果的流式输出,你需要创建一个 SseEmitter 实例,然后逐步发送数据给客户端。下面是一个简单的例子:

import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
import org.springframework.web.bind.annotation.GetMapping;
import org.springframework.web.bind.annotation.RestController;

@RestController
public class TypewriterController {

    @GetMapping("/typewriter")
    public SseEmitter typewriter() {
        SseEmitter emitter = new SseEmitter();
        
        // 模拟从大语言模型获取数据的过程
        String[] data = {"Hello", "World", "from", "the", "large", "language", "model"};
        
        for (String word : data) {
            try {
                // 模拟打字机效果,每个单词之间暂停100毫秒
                Thread.sleep(100);
                emitter.send(SseEmitter.event().data(word));
            } catch (InterruptedException e) {
                emitter.completeWithError(e);
                return emitter;
            }
        }
        
        emitter.complete();
        return emitter;
    }
}

在这个例子中,我们定义了一个 TypewriterController 类,它有一个 /typewriter 端点。当这个端点被访问时,它会创建一个新的 SseEmitter 对象,并逐个发送字符串数组中的单词。每发送一个单词后,线程会暂停100毫秒来模拟打字机的效果。

客户端可以通过建立一个到 /typewriter 端点的持久连接来接收这些事件。例如,如果你使用JavaScript作为客户端,你可以这样写:

<!DOCTYPE html>
<html>
<head>
    <title>Typewriter Effect</title>
</head>
<body>
    <div id="output"></div>
    <script>
        var source = new EventSource('/typewriter');
        
        source.onmessage = function(event) {
            var outputDiv = document.getElementById('output');
            outputDiv.innerHTML += event.data + ' '; // 将接收到的数据添加到页面中
        };
    </script>
</body>
</html>

这段HTML和JavaScript代码会打开一个到服务器的SSE连接,并在接收到新数据时更新页面的内容。每次收到数据时,都会将其追加到 <div id="output"> 元素中,从而实现类似于打字机逐字显示文本的效果。

请注意,在实际应用中,你可能需要处理更多的细节,比如错误处理、连接关闭时的清理工作等。此外,如果你的大语言模型是通过异步方式生成数据的,你可能还需要考虑如何与 SseEmitter 进行集成,以确保数据能够正确地流式传输到客户端。

四、API方式对接SSE服务端

上面基于JavaScript作为客户端获取返回的流式数据,通过后端的Java代码也可以对接SSE服务端,增加其他逻辑,比如:鉴权、计费、敏感词过滤等。
pom文件:

<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
         xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 https://maven.apache.org/xsd/maven-4.0.0.xsd">
    <modelVersion>4.0.0</modelVersion>
    <parent>
        <groupId>org.springframework.boot</groupId>
        <artifactId>spring-boot-starter-parent</artifactId>
        <version>2.3.8.RELEASE</version>
        <relativePath/> <!-- lookup parent from repository -->
    </parent>
    <groupId>com.guo.test</groupId>
    <artifactId>streamClient</artifactId>
    <version>0.0.1-SNAPSHOT</version>
    <name>streamClient</name>
    <description>streamClient</description>
    <properties>
        <java.version>11</java.version>
    </properties>
    <dependencies>
        <dependency>
            <groupId>org.springframework.boot</groupId>
            <artifactId>spring-boot-starter-web</artifactId>
        </dependency>
        <dependency>
            <groupId>com.squareup.okhttp3</groupId>
            <artifactId>okhttp</artifactId>
            <version>4.10.0</version>
        </dependency>
        <dependency>
            <groupId>com.squareup.okhttp3</groupId>
            <artifactId>okhttp-sse</artifactId>
            <version>4.10.0</version>
        </dependency>
        <dependency>
            <groupId>org.springframework.boot</groupId>
            <artifactId>spring-boot-starter-test</artifactId>
            <scope>test</scope>
        </dependency>
    </dependencies>

    <build>
        <plugins>
            <plugin>
                <groupId>org.springframework.boot</groupId>
                <artifactId>spring-boot-maven-plugin</artifactId>
            </plugin>
        </plugins>
    </build>

</project>

要注意下面代码设置接收媒体类型为:
text/event-stream,建议使用`MediaType.TEXT_EVENT_STREAM_VALUE`代替字符串编码。

package com.guo.test.streamclient.client;

import okhttp3.OkHttpClient;
import okhttp3.Request;
import okhttp3.Response;
import okhttp3.ResponseBody;
import org.springframework.http.MediaType;

import java.io.IOException;
import java.util.concurrent.TimeUnit;


public class streamClient {
    private static final String SSE_URL = "http://localhost:8080/llm/stream/query?query=地铁安全门的规范"; // 替换为你的SSE端点地址

    public static void main(String[] args) {
        OkHttpClient client = new OkHttpClient.Builder()
                .readTimeout(0, TimeUnit.MILLISECONDS) // 设置无限读取超时,因为SSE是长连接
                .build();

        Request request = new Request.Builder()
                .url(SSE_URL)
                .header("Accept", "text/event-stream") // 设置接收SSE媒体类型
                .build();

        try (Response response = client.newCall(request).execute()) {
            if (!response.isSuccessful()) {
                throw new IOException("Unexpected code " + response);
            }

            // 获取响应体并读取流
            ResponseBody responseBody = response.body();
            if (responseBody == null) {
                return;
            }

            try (java.io.Reader reader = responseBody.charStream()) {
                char[] buffer = new char[1024];
                int bytesRead;
                while ((bytesRead = reader.read(buffer)) != -1) {
                    String data = new String(buffer, 0, bytesRead);
                    // 处理接收到的SSE数据
                    System.out.println("Received data: " + data);

                    // 查找SSE事件边界(通常是"\n\n")
                    /*int eventBoundary = data.indexOf("\n\n");
                    if (eventBoundary != -1) {
                        String event = data.substring(0, eventBoundary);
                        String eventData = data.substring(eventBoundary + 2);
                        // 处理事件头部和事件数据
                        System.out.println("Event: " + event);
                        System.out.println("Event Data: " + eventData);
                    }*/
                }
            }
        } catch (IOException e) {
            e.printStackTrace();
        }
    }

}

最后编辑于
©著作权归作者所有,转载或内容合作请联系作者
平台声明:文章内容(如有图片或视频亦包括在内)由作者上传并发布,文章内容仅代表作者本人观点,简书系信息发布平台,仅提供信息存储服务。

推荐阅读更多精彩内容