| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121 |
- package routes
-
- import (
- "context"
- "encoding/json"
- "fmt"
- "io"
- "net/http"
- "time"
-
- "git.x2erp.com/qdy/go-base/authbase"
- "git.x2erp.com/qdy/go-base/ctx"
- "git.x2erp.com/qdy/go-base/model/response"
- "git.x2erp.com/qdy/go-base/webx/router"
- "git.x2erp.com/qdy/go-svc-code/internal/opencode"
- )
-
- // PromptStreamRequest 流式对话请求
- type PromptStreamRequest struct {
- SessionID string `json:"sessionID" binding:"required"`
- Parts []opencode.TextPart `json:"parts" binding:"required"`
- Agent string `json:"agent,omitempty"`
- Model *opencode.ModelInfo `json:"model,omitempty"`
- }
-
- // RegisterPromptStreamRoutes 注册流式对话路由
- func RegisterPromptStreamRoutes(ws *router.RouterService, client *opencode.Client) {
- // 流式对话(需要Token认证)
- ws.POST("/api/prompt/stream",
- func(req *PromptStreamRequest, ctx context.Context, reqCtx *ctx.RequestContext) (*response.QueryResult[map[string]interface{}], error) {
- // 流式响应需要直接写入 HTTP 响应,不能使用标准的路由返回值
- // 这里返回一个特殊结果,指示上层使用流式处理
- return &response.QueryResult[map[string]interface{}]{
- Success: false,
- Message: "流式端点需要特殊处理",
- }, nil
- },
- ).Use(authbase.TokenAuth).Desc("流式对话(Server-Sent Events)").Register()
-
- // 流式对话的原始 HTTP 处理器
- // 注意:这个路由需要直接处理 HTTP 流,不能使用标准的 router 包装
- // 我们将注册一个原始的 HTTP 处理器到 gin 引擎
- // 这个函数将由 main.go 中的额外注册调用
- }
-
- // StreamPromptHandler 流式对话的 HTTP 处理器(已包含TokenAuth认证)
- func StreamPromptHandler(client opencode.OpenCodeClient) http.HandlerFunc {
- // 创建内部处理器
- handler := func(w http.ResponseWriter, r *http.Request) {
- // 解析请求
- var req PromptStreamRequest
- if err := BindJSON(r, &req); err != nil {
- http.Error(w, fmt.Sprintf("解析请求失败: %v", err), http.StatusBadRequest)
- return
- }
-
- // 创建 prompt 请求
- prompt := &opencode.PromptRequest{
- Parts: req.Parts,
- Agent: req.Agent,
- Model: req.Model,
- }
-
- // 设置 SSE 头
- w.Header().Set("Content-Type", "text/event-stream")
- w.Header().Set("Cache-Control", "no-cache")
- w.Header().Set("Connection", "keep-alive")
- w.Header().Set("Access-Control-Allow-Origin", "*")
-
- // 创建带超时的上下文
- ctx, cancel := context.WithTimeout(r.Context(), 5*time.Minute)
- defer cancel()
-
- // 获取流式响应通道
- ch, err := client.SendPromptStream(ctx, req.SessionID, prompt)
- if err != nil {
- http.Error(w, fmt.Sprintf("发送流式请求失败: %v", err), http.StatusInternalServerError)
- return
- }
-
- // 发送流式响应
- flusher, ok := w.(http.Flusher)
- if !ok {
- http.Error(w, "Streaming unsupported", http.StatusInternalServerError)
- return
- }
-
- for {
- select {
- case data, ok := <-ch:
- if !ok {
- // 通道关闭,发送结束标记
- fmt.Fprintf(w, "data: [DONE]\n\n")
- flusher.Flush()
- return
- }
- // 发送 SSE 数据
- fmt.Fprintf(w, "data: %s\n\n", data)
- flusher.Flush()
- case <-ctx.Done():
- return
- case <-r.Context().Done():
- return
- }
- }
- }
-
- // 包装TokenAuth中间件
- return authbase.TokenAuth(http.HandlerFunc(handler)).ServeHTTP
- }
-
- // BindJSON 简单的 JSON 绑定函数
- func BindJSON(r *http.Request, v interface{}) error {
- body, err := io.ReadAll(r.Body)
- if err != nil {
- return err
- }
- defer r.Body.Close()
-
- return json.Unmarshal(body, v)
- }
|