Aucune description
Vous ne pouvez pas sélectionner plus de 25 sujets Les noms de sujets doivent commencer par une lettre ou un nombre, peuvent contenir des tirets ('-') et peuvent comporter jusqu'à 35 caractères.

prompt_stream_routes.go 10KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314
  1. package routes
  2. import (
  3. "context"
  4. "encoding/json"
  5. "fmt"
  6. "io"
  7. "net/http"
  8. "time"
  9. "git.x2erp.com/qdy/go-base/authbase"
  10. goctx "git.x2erp.com/qdy/go-base/ctx"
  11. "git.x2erp.com/qdy/go-base/logger"
  12. "git.x2erp.com/qdy/go-base/webx"
  13. "git.x2erp.com/qdy/go-base/webx/router"
  14. "git.x2erp.com/qdy/go-svc-code/internal/opencode"
  15. "git.x2erp.com/qdy/go-svc-code/internal/service/event"
  16. )
  17. // PromptStreamRequest 流式对话请求
  18. type PromptStreamRequest struct {
  19. SessionID string `json:"sessionID" binding:"required"`
  20. Parts []opencode.TextPart `json:"parts" binding:"required"`
  21. Agent string `json:"agent,omitempty"`
  22. Model *opencode.ModelInfo `json:"model,omitempty"`
  23. }
  24. // RegisterPromptStreamRoutes 注册流式对话路由
  25. func RegisterPromptStreamRoutes(ws *router.RouterService, webService *webx.WebService, client opencode.OpenCodeClient) {
  26. // 流式对话需要直接处理 HTTP 流式响应,不能使用标准的路由包装
  27. // 我们直接注册到 webService 的底层路由器
  28. webService.GetRouter().Handle("/api/prompt/stream", StreamPromptHandler(client))
  29. }
  30. // StreamPromptHandler 流式对话的 HTTP 处理器(已包含TokenAuth认证)
  31. func StreamPromptHandler(client opencode.OpenCodeClient) http.HandlerFunc {
  32. // 创建内部处理器
  33. handler := func(w http.ResponseWriter, r *http.Request) {
  34. logger.Debug(fmt.Sprintf("🔍 [StreamPromptHandler] 收到流式对话请求: %s %s", r.Method, r.URL.String()))
  35. // 解析请求
  36. var req PromptStreamRequest
  37. if err := BindJSON(r, &req); err != nil {
  38. logger.Error(fmt.Sprintf("🔍 [StreamPromptHandler] 解析请求失败: %v", err))
  39. http.Error(w, fmt.Sprintf("解析请求失败: %v", err), http.StatusBadRequest)
  40. return
  41. }
  42. logger.Debug(fmt.Sprintf("🔍 [StreamPromptHandler] 请求数据: sessionID=%s, agent=%v, parts=%d",
  43. req.SessionID, req.Agent, len(req.Parts)))
  44. if len(req.Parts) > 0 && req.Parts[0].Text != "" {
  45. logger.Debug(fmt.Sprintf("🔍 [StreamPromptHandler] 用户消息: %s", req.Parts[0].Text))
  46. }
  47. // 创建 prompt 请求
  48. disablePasteSummary := true
  49. prompt := &opencode.PromptRequest{
  50. Parts: req.Parts,
  51. Agent: "code-sql", //req.Agent,
  52. Model: req.Model,
  53. Experimental: &struct {
  54. DisablePasteSummary *bool `json:"disable_paste_summary,omitempty"`
  55. }{
  56. DisablePasteSummary: &disablePasteSummary,
  57. },
  58. }
  59. // 设置 SSE 头
  60. w.Header().Set("Content-Type", "text/event-stream")
  61. w.Header().Set("Cache-Control", "no-cache")
  62. w.Header().Set("Connection", "keep-alive")
  63. w.Header().Set("Access-Control-Allow-Origin", "*")
  64. // 创建带超时的上下文 - 增加超时时间确保AI有足够时间生成完整响应
  65. ctx, cancel := context.WithTimeout(r.Context(), 15*time.Minute)
  66. defer cancel()
  67. // 获取事件分发器实例和订阅服务
  68. dispatcher := event.GetEventDispatcher(client.GetBaseURL(), client.GetPort())
  69. subscriptionService := event.GetSubscriptionService()
  70. // 从认证上下文中获取用户ID(用于缓存,按sessionID分发事件)
  71. userID := "unknown-user"
  72. username := "unknown-user"
  73. // TokenAuth中间件通常将用户信息存储在context中
  74. if userVal := r.Context().Value("user"); userVal != nil {
  75. if userMap, ok := userVal.(map[string]interface{}); ok {
  76. if id, ok := userMap["id"].(string); ok {
  77. userID = id
  78. } else if id, ok := userMap["user_id"].(string); ok {
  79. userID = id
  80. }
  81. if name, ok := userMap["username"].(string); ok {
  82. username = name
  83. }
  84. }
  85. }
  86. logger.Debug(fmt.Sprintf("🔍 [StreamPromptHandler] 用户ID: %s, 用户名: %s, 会话ID: %s", userID, username, req.SessionID))
  87. // 构建请求上下文
  88. reqCtx := &goctx.RequestContext{
  89. UserID: userID,
  90. Username: username,
  91. // 可以添加更多字段:TraceID、TenantID等
  92. TraceID: fmt.Sprintf("stream-%d", time.Now().UnixNano()),
  93. }
  94. // 使用订阅服务注册会话和订阅事件(如果可用)
  95. var ch <-chan string
  96. var err error
  97. if subscriptionService != nil {
  98. // 使用订阅服务(集成MongoDB和上下文)
  99. subscriptionService.RegisterSessionWithContext(req.SessionID, reqCtx)
  100. ch, err = subscriptionService.SubscribeWithContext(req.SessionID, reqCtx)
  101. if err != nil {
  102. logger.Error(fmt.Sprintf("🔍 [StreamPromptHandler] 订阅服务订阅失败: %v", err))
  103. http.Error(w, fmt.Sprintf("订阅失败: %v", err), http.StatusInternalServerError)
  104. return
  105. }
  106. defer subscriptionService.UnsubscribeWithContext(req.SessionID, ch, reqCtx)
  107. } else {
  108. // 回退到直接使用事件分发器
  109. logger.Warn("⚠️ [StreamPromptHandler] 订阅服务不可用,使用事件分发器回退")
  110. dispatcher.RegisterSession(req.SessionID, userID)
  111. ch, err = dispatcher.Subscribe(req.SessionID, userID)
  112. if err != nil {
  113. logger.Error(fmt.Sprintf("🔍 [StreamPromptHandler] 事件分发器订阅失败: %v", err))
  114. http.Error(w, fmt.Sprintf("订阅失败: %v", err), http.StatusInternalServerError)
  115. return
  116. }
  117. defer dispatcher.Unsubscribe(req.SessionID, ch)
  118. }
  119. // 发送异步请求到 opencode(触发AI处理)
  120. logger.Debug(fmt.Sprintf("🔍 [StreamPromptHandler] 发送异步请求到 opencode, sessionID=%s", req.SessionID))
  121. // 忽略返回的通道,事件通过EventDispatcher分发
  122. _, err = client.SendPromptStream(ctx, req.SessionID, prompt)
  123. if err != nil {
  124. logger.Error(fmt.Sprintf("🔍 [StreamPromptHandler] 发送请求失败: %v", err))
  125. http.Error(w, fmt.Sprintf("发送请求失败: %v", err), http.StatusInternalServerError)
  126. return
  127. }
  128. logger.Debug(fmt.Sprintf("🔍 [StreamPromptHandler] 异步请求发送成功,等待事件流"))
  129. // 发送流式响应
  130. flusher, ok := w.(http.Flusher)
  131. if !ok {
  132. http.Error(w, "Streaming unsupported", http.StatusInternalServerError)
  133. return
  134. }
  135. logger.Debug(fmt.Sprintf("🔍 [StreamPromptHandler] 开始发送流式响应"))
  136. eventCount := 0
  137. // 创建心跳定时器,每30秒发送一次心跳保活(SSE注释格式)
  138. heartbeatTicker := time.NewTicker(30 * time.Second)
  139. defer heartbeatTicker.Stop()
  140. // 发送初始心跳,确保连接立即活跃
  141. fmt.Fprintf(w, ": heartbeat\n\n")
  142. flusher.Flush()
  143. for {
  144. select {
  145. case data, ok := <-ch:
  146. if !ok {
  147. // 通道关闭,发送结束标记
  148. logger.Info(fmt.Sprintf("🔍 [StreamPromptHandler] 流式通道关闭,发送DONE标记,共发送 %d 个事件", eventCount))
  149. fmt.Fprintf(w, "data: [DONE]\n\n")
  150. flusher.Flush()
  151. return
  152. }
  153. // 发送 SSE 数据
  154. eventCount++
  155. // 只记录每第10个事件或重要事件,减少日志量
  156. if eventCount%10 == 0 {
  157. preview := data
  158. if len(preview) > 100 {
  159. preview = preview[:100] + "..."
  160. }
  161. logger.Debug(fmt.Sprintf("🔍 [StreamPromptHandler] 发送SSE数据[%d]: %s", eventCount, preview))
  162. }
  163. // 发送 SSE 数据,opencode 数据已包含 payload 字段,不需要额外包装
  164. var wrappedData string
  165. if data == "[DONE]" {
  166. wrappedData = "[DONE]"
  167. } else {
  168. // 尝试解析为JSON,检查是否已有payload字段
  169. var jsonData interface{}
  170. if err := json.Unmarshal([]byte(data), &jsonData); err == nil {
  171. // 去重处理:移除message.updated事件中的重复content,并过滤不必要的事件
  172. jsonData = removeDuplicateContent(jsonData)
  173. // 如果返回nil,跳过此事件
  174. if jsonData == nil {
  175. continue
  176. }
  177. // 检查是否是对象且包含payload字段
  178. if obj, ok := jsonData.(map[string]interface{}); ok && obj["payload"] != nil {
  179. // 已有payload字段,直接发送原始数据
  180. wrappedData = data
  181. } else {
  182. // 没有payload字段,包装在payload对象中
  183. wrapped := map[string]interface{}{
  184. "payload": jsonData,
  185. }
  186. wrappedBytes, _ := json.Marshal(wrapped)
  187. wrappedData = string(wrappedBytes)
  188. }
  189. } else {
  190. // 不是JSON,按原样发送
  191. wrappedData = data
  192. }
  193. }
  194. fmt.Fprintf(w, "data: %s\n\n", wrappedData)
  195. flusher.Flush()
  196. case <-ctx.Done():
  197. logger.Info(fmt.Sprintf("🔍 [StreamPromptHandler] 上下文超时"))
  198. return
  199. case <-heartbeatTicker.C:
  200. // 发送心跳保活(SSE注释格式)
  201. logger.Debug(fmt.Sprintf("🔍 [StreamPromptHandler] 发送心跳保活"))
  202. fmt.Fprintf(w, ": heartbeat\n\n")
  203. flusher.Flush()
  204. case <-r.Context().Done():
  205. logger.Info(fmt.Sprintf("🔍 [StreamPromptHandler] 客户端断开连接"))
  206. return
  207. }
  208. }
  209. }
  210. // 包装TokenAuth中间件
  211. return authbase.TokenAuth(http.HandlerFunc(handler)).ServeHTTP
  212. }
  213. // BindJSON 简单的 JSON 绑定函数
  214. func BindJSON(r *http.Request, v interface{}) error {
  215. body, err := io.ReadAll(r.Body)
  216. if err != nil {
  217. return err
  218. }
  219. defer r.Body.Close()
  220. return json.Unmarshal(body, v)
  221. }
  222. // removeDuplicateContent 移除message.updated事件中的重复content,避免前端重复显示
  223. func removeDuplicateContent(data interface{}) interface{} {
  224. // 检查是否为map
  225. obj, ok := data.(map[string]interface{})
  226. if !ok {
  227. return data
  228. }
  229. // 递归处理payload字段
  230. if payload, ok := obj["payload"].(map[string]interface{}); ok {
  231. obj["payload"] = removeDuplicateContent(payload)
  232. }
  233. // 如果payload是数组(可能嵌套),处理每个元素
  234. if payloadArr, ok := obj["payload"].([]interface{}); ok {
  235. for i, item := range payloadArr {
  236. if itemMap, ok := item.(map[string]interface{}); ok {
  237. payloadArr[i] = removeDuplicateContent(itemMap)
  238. }
  239. }
  240. }
  241. // 检查type字段
  242. typeVal, hasType := obj["type"]
  243. if !hasType {
  244. return obj
  245. }
  246. typeStr, ok := typeVal.(string)
  247. if !ok {
  248. return obj
  249. }
  250. // 事件过滤策略:保守过滤,保留大多数事件以确保连接稳定
  251. switch typeStr {
  252. case "message.updated":
  253. // 对于message.updated事件,移除content字段避免重复,但不过滤事件本身
  254. if properties, ok := obj["properties"].(map[string]interface{}); ok {
  255. if info, ok := properties["info"].(map[string]interface{}); ok {
  256. // 移除content字段,避免重复
  257. delete(info, "content")
  258. }
  259. }
  260. // 保留事件,不过滤
  261. return obj
  262. case "session.status":
  263. // 保留状态事件,可能包含重要状态信息
  264. return obj
  265. case "session.diff":
  266. // 保留差异事件,避免客户端逻辑中断
  267. return obj
  268. case "server.heartbeat":
  269. // 保留心跳事件,保持连接活跃 - 非常重要!
  270. return obj
  271. case "session.idle":
  272. // 保留事件,不过滤
  273. return obj
  274. }
  275. // 其他所有事件类型都保留,包括message.part.updated(关键事件)
  276. return obj
  277. }