写在前面
SSE是LLM进行流式通信常用的技术方案, 下图是 kimi 的示例
kimi回答时使用SSE
SSE 简介
Server-Sent Events(SSE)是一种允许服务器向客户端实时推送数据的技术。它基于HTTP协议,允许服务器通过一个持久的HTTP连接向客户端发送事件流。以下是SSE的一些关键点:
SSE的本质:SSE利用HTTP协议的流信息(streaming)特性,实现服务器向客户端的单向通信。客户端保持连接打开,等待服务器发送新的数据流。
-
SSE的特点:
- 使用HTTP协议,现有的服务器软件都支持。
- 轻量级,使用简单,与WebSocket相比,协议相对简单。
- 默认支持断线重连,而WebSocket需要自己实现。
- 一般只用来传送文本数据,二进制数据需要编码后传送。
- 支持自定义发送的消息类型。
-
客户端API:
-
EventSource
对象用于创建与服务器的连接并接收事件。 - 通过监听
message
事件接收服务器发送的消息。 - 可以监听自定义事件,不仅限于
message
事件。
-
-
服务器端发送事件:
- 服务器端脚本需要使用
text/event-stream
MIME类型响应内容。 - 每个通知以文本块形式发送,并以一对换行符结尾。
- 消息由字段组成,包括
event
、data
、id
和retry
等。
- 服务器端脚本需要使用
-
事件流格式:
- 事件流是一个简单的文本数据流,使用UTF-8编码。
- 消息由一对换行符分开,以冒号开头的行为注释行,会被忽略。
- 每条消息由一行或多行文字组成,列出该消息的字段。
-
浏览器兼容性:
- SSE在现代浏览器中得到了广泛支持,除了IE/Edge外,其他浏览器如Firefox、Chrome、Safari等都支持SSE。
SSE适用于需要服务器向客户端单向实时推送数据的场景,如实时通知、股票行情、新闻推送等。它是一种有效降低服务器负载和网络资源消耗的技术,通过服务器主动向客户端发送更新事件,实现实时通信。
py 中使用 SSE
- py 中异步:
async + await
- py 中流式接收 SSE:
httpx
包- py 中流式返回 SSE:
from fastapi.responses import StreamingResponse as FastapiStreamingResponse
- 路由定义
@router.post("/stream", tags=["chat"])
async def streaming_chat(
params: QuestionParams, current_user: TokenData = Depends(get_current_user)
):
if not params.user_id:
params.user_id = current_user.uid
async_generator = RetrievalController().stream_answer(params)
return StreamingResponse(async_generator)
- 流式输出定义
from typing import Mapping
from fastapi.responses import StreamingResponse as FastapiStreamingResponse
from starlette.background import BackgroundTask
from starlette.responses import ContentStream
class StreamingResponse(FastapiStreamingResponse):
def __init__(
self,
content: ContentStream,
status_code: int = 200,
headers: Mapping[str, str] | None = None,
media_type: str | None = None,
background: BackgroundTask | None = None,
) -> None:
default_headers = {"Content-Type": "text/event-stream", "Cache-Control": "no-cache", "X-Accel-Buffering": "no"}
default_headers.update(headers or {})
super().__init__(content, status_code, default_headers, media_type, background)
- 流式接收并流式返回
@LogDecorate(
func_name="retrieval_controller::process_stream_answer", raise_exc=True
)
async def stream_answer(self, params: QuestionParams, model: int = 1):
"""
:param model: 1-8B 2-32B
"""
session_id = params.session_id
if params.new_session:
session_id = str(uuid.uuid1()).replace("-", "")
request_body = dict(
messages=msgs,
user_id=params.user_id,
)
stream_answer_api = f"{AI_DOMAIN}{STREAM_ANSWER_API}"
answer = ""
# 流式接收
async with httpx.AsyncClient() as client:
async with client.stream(
"POST",
stream_answer_api,
json=request_body,
timeout=60,
headers=dict(trace_id=get_req_ctx("trace_id")),
) as response:
async for chunk in response.aiter_text():
answer += chunk
yield self.get_yield_data(
{"content": chunk, "create_at": int(time.time() * 1000)}
)
yield self.get_yield_data("[DONE]")
yield self.get_yield_data({"session_id": session_id})
yield self.get_yield_data("[END]")
# 落库
await user_qa_dao.save_user_qa(params.q, answer, session_id, params.user_id)
Go中使用SSE
使用
https://github.com/hertz-contrib/sse
import (
"context"
"encoding/json"
"fmt"
"time"
"github.com/cloudwego/hertz/pkg/app"
"github.com/cloudwego/hertz/pkg/common/hlog"
"github.com/google/uuid"
"github.com/hertz-contrib/sse"
"github.com/spf13/cast"
)
func ChatStream(ctx context.Context, c *app.RequestContext) {
u := ctl.CtxUser(c)
var req struct {
Query string `form:"query" json:"query"`
Model int `form:"model" json:"model"`
Sid string `form:"sid" json:"sid"` // session id
}
if err := c.BindAndValidate(&req); err != nil {
utils.RespErr(c, err)
return
}
// 聊天消息支持多轮对话
var sid string
if req.Sid != "" {
sid = req.Sid
} else {
sid = uuid.New().String()
}
msg := chat.SaveUserMsg(ctx, sid, req.Query)
content := &chat.Content{
Messages: msg,
UserId: cast.ToString(u.ID),
UserName: u.Name,
}
b, _ := json.Marshal(content)
// https://github.com/hertz-contrib/sse/blob/main/examples/client/quickstart/main.go
cli := sse.NewClient(conf.GetConf().Dev.AIDomain + "xxx")
cli.SetMethod("POST")
cli.SetHeaders(map[string]string{"Content-Type": "application/json", "trace_id": httpx.TraceId()})
cli.SetBody(b)
var ans, allAns string // AI 返回内容
var flag bool // reply正文标识
events := make(chan *sse.Event)
errChan := make(chan error)
s := sse.NewStream(c)
go func() {
cErr := cli.Subscribe(func(msg *sse.Event) {
if msg != nil && msg.Data != nil {
events <- msg
return
}
})
errChan <- cErr
}()
for {
select {
case e := <-events:
m := map[string]any{}
_ = json.Unmarshal(e.Data, &m)
if v, ok := m["content"]; ok {
allAns += v.(string)
if flag {
ans += v.(string)
}
if v == "__REPLY_START__" {
flag = true
}
da := map[string]any{
"content": v,
"create_at": time.Now().Unix(),
}
jsonData, _ := json.Marshal(da)
hlog.Info("publish event data = %s", string(jsonData))
_ = s.Publish(&sse.Event{Data: jsonData})
} else {
hlog.Info("invalid event data = %s", string(e.Data))
}
case err := <-errChan:
if err != nil {
hlog.CtxErrorf(context.Background(), "err = %s", err.Error())
}
chat.SaveAssistantMsg(ctx, sid, ans, msg)
chat.SaveQA(u.ID, sid, req.Query, allAns)
_ = s.Publish(&sse.Event{Data: []byte("[DONE]")})
_ = s.Publish(&sse.Event{Data: []byte(fmt.Sprintf(`{"session_id": "%s"}`, sid))})
_ = s.Publish(&sse.Event{Data: []byte("[END]")})
hlog.Info("cli get all event")
return
}
}
}
写在最后
需要注意的点
- py 使用
httpx
接收 SSE 流式数据, 对数据结构没有要求, 比如 SSE event 常见的data: xxx
, 可以不带data
标识返回 - go 中使用
https://github.com/hertz-contrib/sse
接收 SSE 流式数据- 底层会解析 SSE 数据格式, 需要判断
data
标识, 如果没有, 会导致解析失败 - 如果数据包含
\n
换行, 也会导致数据解析失败, 比较简单的做法data: json 格式数据
- 底层会解析 SSE 数据格式, 需要判断
// go 中对应 SSE 库数据解析源码
func (c *Client) processEvent(msg []byte) (event *Event, err error) {
var e Event
if len(msg) < 1 {
return nil, fmt.Errorf("event message was empty")
}
// Normalize the crlf to lf to make it easier to split the lines.
// Split the line by "\n" or "\r", per the spec.
for _, line := range bytes.FieldsFunc(msg, func(r rune) bool { return r == '\n' || r == '\r' }) {
switch {
case bytes.HasPrefix(line, headerID):
e.ID = string(append([]byte(nil), trimHeader(len(headerID), line)...))
case bytes.HasPrefix(line, headerData):
// The spec allows for multiple data fields per event, concatenated them with "\n".
e.Data = append(e.Data[:], append(trimHeader(len(headerData), line), byte('\n'))...)
// The spec says that a line that simply contains the string "data" should be treated as a data field with an empty body.
case bytes.Equal(line, bytes.TrimSuffix(headerData, []byte(":"))):
e.Data = append(e.Data, byte('\n'))
case bytes.HasPrefix(line, headerEvent):
e.Event = string(append([]byte(nil), trimHeader(len(headerEvent), line)...))
case bytes.HasPrefix(line, headerRetry):
e.Retry, err = strconv.ParseUint(b2s(append([]byte(nil), trimHeader(len(headerRetry), line)...)), 10, 64)
if err != nil {
return nil, fmt.Errorf("process message `retry` failed, err is %s", err)
}
default:
// Ignore any garbage that doesn't match what we're looking for.
}
}
// Trim the last "\n" per the spec.
e.Data = bytes.TrimSuffix(e.Data, []byte("\n"))
if c.encodingBase64 {
buf := make([]byte, base64.StdEncoding.DecodedLen(len(e.Data)))
n, err := base64.StdEncoding.Decode(buf, e.Data)
if err != nil {
err = fmt.Errorf("failed to decode event message: %s", err)
return &e, err
}
e.Data = buf[:n]
}
return &e, err
}