package routes import ( "context" "encoding/json" "fmt" "io" "net/http" "time" "git.x2erp.com/qdy/go-base/authbase" "git.x2erp.com/qdy/go-base/ctx" "git.x2erp.com/qdy/go-base/model/response" "git.x2erp.com/qdy/go-base/webx/router" "git.x2erp.com/qdy/go-svc-code/internal/opencode" ) // PromptStreamRequest 流式对话请求 type PromptStreamRequest struct { SessionID string `json:"sessionID" binding:"required"` Parts []opencode.TextPart `json:"parts" binding:"required"` Agent string `json:"agent,omitempty"` Model *opencode.ModelInfo `json:"model,omitempty"` } // RegisterPromptStreamRoutes 注册流式对话路由 func RegisterPromptStreamRoutes(ws *router.RouterService, client *opencode.Client) { // 流式对话(需要Token认证) ws.POST("/api/prompt/stream", func(req *PromptStreamRequest, ctx context.Context, reqCtx *ctx.RequestContext) (*response.QueryResult[map[string]interface{}], error) { // 流式响应需要直接写入 HTTP 响应,不能使用标准的路由返回值 // 这里返回一个特殊结果,指示上层使用流式处理 return &response.QueryResult[map[string]interface{}]{ Success: false, Message: "流式端点需要特殊处理", }, nil }, ).Use(authbase.TokenAuth).Desc("流式对话(Server-Sent Events)").Register() // 流式对话的原始 HTTP 处理器 // 注意:这个路由需要直接处理 HTTP 流,不能使用标准的 router 包装 // 我们将注册一个原始的 HTTP 处理器到 gin 引擎 // 这个函数将由 main.go 中的额外注册调用 } // StreamPromptHandler 流式对话的 HTTP 处理器(已包含TokenAuth认证) func StreamPromptHandler(client opencode.OpenCodeClient) http.HandlerFunc { // 创建内部处理器 handler := func(w http.ResponseWriter, r *http.Request) { // 解析请求 var req PromptStreamRequest if err := BindJSON(r, &req); err != nil { http.Error(w, fmt.Sprintf("解析请求失败: %v", err), http.StatusBadRequest) return } // 创建 prompt 请求 prompt := &opencode.PromptRequest{ Parts: req.Parts, Agent: req.Agent, Model: req.Model, } // 设置 SSE 头 w.Header().Set("Content-Type", "text/event-stream") w.Header().Set("Cache-Control", "no-cache") w.Header().Set("Connection", "keep-alive") w.Header().Set("Access-Control-Allow-Origin", "*") // 创建带超时的上下文 ctx, cancel := context.WithTimeout(r.Context(), 5*time.Minute) defer cancel() // 获取流式响应通道 ch, err := client.SendPromptStream(ctx, req.SessionID, prompt) if err != nil { http.Error(w, fmt.Sprintf("发送流式请求失败: %v", err), http.StatusInternalServerError) return } // 发送流式响应 flusher, ok := w.(http.Flusher) if !ok { http.Error(w, "Streaming unsupported", http.StatusInternalServerError) return } for { select { case data, ok := <-ch: if !ok { // 通道关闭,发送结束标记 fmt.Fprintf(w, "data: [DONE]\n\n") flusher.Flush() return } // 发送 SSE 数据 fmt.Fprintf(w, "data: %s\n\n", data) flusher.Flush() case <-ctx.Done(): return case <-r.Context().Done(): return } } } // 包装TokenAuth中间件 return authbase.TokenAuth(http.HandlerFunc(handler)).ServeHTTP } // BindJSON 简单的 JSON 绑定函数 func BindJSON(r *http.Request, v interface{}) error { body, err := io.ReadAll(r.Body) if err != nil { return err } defer r.Body.Close() return json.Unmarshal(body, v) }