package routes import ( "context" "encoding/json" "fmt" "io" "net/http" "time" "git.x2erp.com/qdy/go-base/authbase" "git.x2erp.com/qdy/go-base/logger" "git.x2erp.com/qdy/go-base/webx" "git.x2erp.com/qdy/go-base/webx/router" "git.x2erp.com/qdy/go-svc-code/internal/opencode" "git.x2erp.com/qdy/go-svc-code/internal/service/event" ) // PromptStreamRequest 流式对话请求 type PromptStreamRequest struct { SessionID string `json:"sessionID" binding:"required"` Parts []opencode.TextPart `json:"parts" binding:"required"` Agent string `json:"agent,omitempty"` Model *opencode.ModelInfo `json:"model,omitempty"` } // RegisterPromptStreamRoutes 注册流式对话路由 func RegisterPromptStreamRoutes(ws *router.RouterService, webService *webx.WebService, client opencode.OpenCodeClient) { // 流式对话需要直接处理 HTTP 流式响应,不能使用标准的路由包装 // 我们直接注册到 webService 的底层路由器 webService.GetRouter().Handle("/api/prompt/stream", StreamPromptHandler(client)) } // StreamPromptHandler 流式对话的 HTTP 处理器(已包含TokenAuth认证) func StreamPromptHandler(client opencode.OpenCodeClient) http.HandlerFunc { // 创建内部处理器 handler := func(w http.ResponseWriter, r *http.Request) { logger.Debug(fmt.Sprintf("🔍 [StreamPromptHandler] 收到流式对话请求: %s %s", r.Method, r.URL.String())) // 解析请求 var req PromptStreamRequest if err := BindJSON(r, &req); err != nil { logger.Error(fmt.Sprintf("🔍 [StreamPromptHandler] 解析请求失败: %v", err)) http.Error(w, fmt.Sprintf("解析请求失败: %v", err), http.StatusBadRequest) return } logger.Debug(fmt.Sprintf("🔍 [StreamPromptHandler] 请求数据: sessionID=%s, agent=%v, parts=%d", req.SessionID, req.Agent, len(req.Parts))) if len(req.Parts) > 0 && req.Parts[0].Text != "" { logger.Debug(fmt.Sprintf("🔍 [StreamPromptHandler] 用户消息: %s", req.Parts[0].Text)) } // 创建 prompt 请求 disablePasteSummary := true prompt := &opencode.PromptRequest{ Parts: req.Parts, Agent: "code-sql", //req.Agent, Model: req.Model, Experimental: &struct { DisablePasteSummary *bool `json:"disable_paste_summary,omitempty"` }{ DisablePasteSummary: &disablePasteSummary, }, } // 设置 SSE 头 w.Header().Set("Content-Type", "text/event-stream") w.Header().Set("Cache-Control", "no-cache") w.Header().Set("Connection", "keep-alive") w.Header().Set("Access-Control-Allow-Origin", "*") // 创建带超时的上下文 - 增加超时时间确保AI有足够时间生成完整响应 ctx, cancel := context.WithTimeout(r.Context(), 15*time.Minute) defer cancel() // 获取事件分发器实例 dispatcher := event.GetEventDispatcher(client.GetBaseURL(), client.GetPort()) // 从认证上下文中获取用户ID(用于缓存,按sessionID分发事件) userID := "unknown-user" username := "unknown-user" // TokenAuth中间件通常将用户信息存储在context中 if userVal := r.Context().Value("user"); userVal != nil { if userMap, ok := userVal.(map[string]interface{}); ok { if id, ok := userMap["id"].(string); ok { userID = id } else if id, ok := userMap["user_id"].(string); ok { userID = id } if name, ok := userMap["username"].(string); ok { username = name } } } logger.Debug(fmt.Sprintf("🔍 [StreamPromptHandler] 用户ID: %s, 用户名: %s, 会话ID: %s", userID, username, req.SessionID)) // 使用事件分发器订阅会话事件 ch, err := dispatcher.Subscribe(req.SessionID, userID) if err != nil { logger.Error(fmt.Sprintf("🔍 [StreamPromptHandler] 事件分发器订阅失败: %v", err)) http.Error(w, fmt.Sprintf("订阅失败: %v", err), http.StatusInternalServerError) return } defer dispatcher.Unsubscribe(req.SessionID, ch) // 发送异步请求到 opencode(触发AI处理) logger.Debug(fmt.Sprintf("🔍 [StreamPromptHandler] 发送异步请求到 opencode, sessionID=%s", req.SessionID)) // 忽略返回的通道,事件通过EventDispatcher分发 _, err = client.SendPromptStream(ctx, req.SessionID, prompt) if err != nil { logger.Error(fmt.Sprintf("🔍 [StreamPromptHandler] 发送请求失败: %v", err)) http.Error(w, fmt.Sprintf("发送请求失败: %v", err), http.StatusInternalServerError) return } logger.Debug("🔍 [StreamPromptHandler] 异步请求发送成功,等待事件流") // 发送流式响应 flusher, ok := w.(http.Flusher) if !ok { http.Error(w, "Streaming unsupported", http.StatusInternalServerError) return } logger.Debug("🔍 [StreamPromptHandler] 开始发送流式响应") eventCount := 0 // 创建心跳定时器,每30秒发送一次心跳保活(SSE注释格式) heartbeatTicker := time.NewTicker(30 * time.Second) defer heartbeatTicker.Stop() // 发送初始心跳,确保连接立即活跃 fmt.Fprintf(w, ": heartbeat\n\n") flusher.Flush() for { select { case data, ok := <-ch: if !ok { // 通道关闭,发送结束标记 logger.Info(fmt.Sprintf("🔍 [StreamPromptHandler] 流式通道关闭,发送DONE标记,共发送 %d 个事件", eventCount)) fmt.Fprintf(w, "data: [DONE]\n\n") flusher.Flush() return } // 发送 SSE 数据 eventCount++ // 只记录每第10个事件或重要事件,减少日志量 if eventCount%10 == 0 { preview := data if len(preview) > 100 { preview = preview[:100] + "..." } //logger.Debug(fmt.Sprintf("🔍 [StreamPromptHandler] 发送SSE数据[%d]: %s", eventCount, preview)) } // 发送 SSE 数据,opencode 数据已包含 payload 字段,不需要额外包装 var wrappedData string if data == "[DONE]" { wrappedData = "[DONE]" } else { // 尝试解析为JSON,检查是否已有payload字段 var jsonData interface{} if err := json.Unmarshal([]byte(data), &jsonData); err == nil { // 去重处理:移除message.updated事件中的重复content,并过滤不必要的事件 jsonData = removeDuplicateContent(jsonData) // 如果返回nil,跳过此事件 if jsonData == nil { continue } // 检查是否是对象且包含payload字段 if obj, ok := jsonData.(map[string]interface{}); ok && obj["payload"] != nil { // 已有payload字段,直接发送原始数据 wrappedData = data } else { // 没有payload字段,包装在payload对象中 wrapped := map[string]interface{}{ "payload": jsonData, } wrappedBytes, _ := json.Marshal(wrapped) wrappedData = string(wrappedBytes) } } else { // 不是JSON,按原样发送 wrappedData = data } } fmt.Fprintf(w, "data: %s\n\n", wrappedData) flusher.Flush() case <-ctx.Done(): logger.Debug("🔍 [StreamPromptHandler] 上下文超时") return case <-heartbeatTicker.C: // 发送心跳保活(SSE注释格式) logger.Debug("🔍 [StreamPromptHandler] 发送心跳保活") fmt.Fprintf(w, ": heartbeat\n\n") flusher.Flush() case <-r.Context().Done(): logger.Debug("🔍 [StreamPromptHandler] 客户端断开连接") return } } } // 包装TokenAuth中间件 return authbase.TokenAuth(http.HandlerFunc(handler)).ServeHTTP } // BindJSON 简单的 JSON 绑定函数 func BindJSON(r *http.Request, v interface{}) error { body, err := io.ReadAll(r.Body) if err != nil { return err } defer r.Body.Close() return json.Unmarshal(body, v) } // removeDuplicateContent 移除message.updated事件中的重复content,避免前端重复显示 func removeDuplicateContent(data interface{}) interface{} { // 检查是否为map obj, ok := data.(map[string]interface{}) if !ok { return data } // 递归处理payload字段 if payload, ok := obj["payload"].(map[string]interface{}); ok { obj["payload"] = removeDuplicateContent(payload) } // 如果payload是数组(可能嵌套),处理每个元素 if payloadArr, ok := obj["payload"].([]interface{}); ok { for i, item := range payloadArr { if itemMap, ok := item.(map[string]interface{}); ok { payloadArr[i] = removeDuplicateContent(itemMap) } } } // 检查type字段 typeVal, hasType := obj["type"] if !hasType { return obj } typeStr, ok := typeVal.(string) if !ok { return obj } // 事件过滤策略:保守过滤,保留大多数事件以确保连接稳定 switch typeStr { case "message.updated": // 对于message.updated事件,移除content字段避免重复,但不过滤事件本身 if properties, ok := obj["properties"].(map[string]interface{}); ok { if info, ok := properties["info"].(map[string]interface{}); ok { // 移除content字段,避免重复 delete(info, "content") } } // 保留事件,不过滤 return obj case "session.status": // 保留状态事件,可能包含重要状态信息 return obj case "session.diff": // 保留差异事件,避免客户端逻辑中断 return obj case "server.heartbeat": // 保留心跳事件,保持连接活跃 - 非常重要! return obj case "session.idle": // 保留事件,不过滤 return obj } // 其他所有事件类型都保留,包括message.part.updated(关键事件) return obj }