説明なし
選択できるのは25トピックまでです。 トピックは、先頭が英数字で、英数字とダッシュ('-')を使用した35文字以内のものにしてください。

prompt_stream_routes.go 8.6KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271
  1. package routes
  2. import (
  3. "context"
  4. "encoding/json"
  5. "fmt"
  6. "io"
  7. "net/http"
  8. "time"
  9. "git.x2erp.com/qdy/go-base/authbase"
  10. "git.x2erp.com/qdy/go-base/webx"
  11. "git.x2erp.com/qdy/go-base/webx/router"
  12. "git.x2erp.com/qdy/go-svc-code/internal/opencode"
  13. "git.x2erp.com/qdy/go-svc-code/internal/service/event"
  14. )
  15. // PromptStreamRequest 流式对话请求
  16. type PromptStreamRequest struct {
  17. SessionID string `json:"sessionID" binding:"required"`
  18. Parts []opencode.TextPart `json:"parts" binding:"required"`
  19. Agent string `json:"agent,omitempty"`
  20. Model *opencode.ModelInfo `json:"model,omitempty"`
  21. }
  22. // RegisterPromptStreamRoutes 注册流式对话路由
  23. func RegisterPromptStreamRoutes(ws *router.RouterService, webService *webx.WebService, client opencode.OpenCodeClient) {
  24. // 流式对话需要直接处理 HTTP 流式响应,不能使用标准的路由包装
  25. // 我们直接注册到 webService 的底层路由器
  26. webService.GetRouter().Handle("/api/prompt/stream", StreamPromptHandler(client))
  27. }
  28. // StreamPromptHandler 流式对话的 HTTP 处理器(已包含TokenAuth认证)
  29. func StreamPromptHandler(client opencode.OpenCodeClient) http.HandlerFunc {
  30. // 创建内部处理器
  31. handler := func(w http.ResponseWriter, r *http.Request) {
  32. fmt.Printf("🔍 [StreamPromptHandler] 收到流式对话请求: %s %s\n", r.Method, r.URL.String())
  33. // 解析请求
  34. var req PromptStreamRequest
  35. if err := BindJSON(r, &req); err != nil {
  36. fmt.Printf("🔍 [StreamPromptHandler] 解析请求失败: %v\n", err)
  37. http.Error(w, fmt.Sprintf("解析请求失败: %v", err), http.StatusBadRequest)
  38. return
  39. }
  40. fmt.Printf("🔍 [StreamPromptHandler] 请求数据: sessionID=%s, agent=%v, parts=%d\n",
  41. req.SessionID, req.Agent, len(req.Parts))
  42. if len(req.Parts) > 0 && req.Parts[0].Text != "" {
  43. fmt.Printf("🔍 [StreamPromptHandler] 用户消息: %s\n", req.Parts[0].Text)
  44. }
  45. // 创建 prompt 请求
  46. prompt := &opencode.PromptRequest{
  47. Parts: req.Parts,
  48. Agent: "code-sql", //req.Agent,
  49. Model: req.Model,
  50. }
  51. // 设置 SSE 头
  52. w.Header().Set("Content-Type", "text/event-stream")
  53. w.Header().Set("Cache-Control", "no-cache")
  54. w.Header().Set("Connection", "keep-alive")
  55. w.Header().Set("Access-Control-Allow-Origin", "*")
  56. // 创建带超时的上下文 - 增加超时时间确保AI有足够时间生成完整响应
  57. ctx, cancel := context.WithTimeout(r.Context(), 15*time.Minute)
  58. defer cancel()
  59. // 获取事件分发器实例
  60. dispatcher := event.GetEventDispatcher(client.GetBaseURL(), client.GetPort())
  61. // 从认证上下文中获取用户ID(用于缓存,按sessionID分发事件)
  62. userID := "unknown-user"
  63. // TokenAuth中间件通常将用户信息存储在context中
  64. if userVal := r.Context().Value("user"); userVal != nil {
  65. if userMap, ok := userVal.(map[string]interface{}); ok {
  66. if id, ok := userMap["id"].(string); ok {
  67. userID = id
  68. } else if id, ok := userMap["user_id"].(string); ok {
  69. userID = id
  70. }
  71. }
  72. }
  73. fmt.Printf("🔍 [StreamPromptHandler] 用户ID: %s, 会话ID: %s\n", userID, req.SessionID)
  74. // 注册会话到缓存
  75. dispatcher.RegisterSession(req.SessionID, userID)
  76. // 订阅该会话的事件
  77. ch, err := dispatcher.Subscribe(req.SessionID, userID)
  78. if err != nil {
  79. fmt.Printf("🔍 [StreamPromptHandler] 订阅事件失败: %v\n", err)
  80. http.Error(w, fmt.Sprintf("订阅事件失败: %v", err), http.StatusInternalServerError)
  81. return
  82. }
  83. defer dispatcher.Unsubscribe(req.SessionID, ch)
  84. // 发送异步请求到 opencode(触发AI处理)
  85. fmt.Printf("🔍 [StreamPromptHandler] 发送异步请求到 opencode, sessionID=%s\n", req.SessionID)
  86. // 忽略返回的通道,事件通过EventDispatcher分发
  87. _, err = client.SendPromptStream(ctx, req.SessionID, prompt)
  88. if err != nil {
  89. fmt.Printf("🔍 [StreamPromptHandler] 发送请求失败: %v\n", err)
  90. http.Error(w, fmt.Sprintf("发送请求失败: %v", err), http.StatusInternalServerError)
  91. return
  92. }
  93. fmt.Printf("🔍 [StreamPromptHandler] 异步请求发送成功,等待事件流\n")
  94. // 发送流式响应
  95. flusher, ok := w.(http.Flusher)
  96. if !ok {
  97. http.Error(w, "Streaming unsupported", http.StatusInternalServerError)
  98. return
  99. }
  100. fmt.Printf("🔍 [StreamPromptHandler] 开始发送流式响应\n")
  101. eventCount := 0
  102. // 创建心跳定时器,每30秒发送一次心跳保活(SSE注释格式)
  103. heartbeatTicker := time.NewTicker(30 * time.Second)
  104. defer heartbeatTicker.Stop()
  105. // 发送初始心跳,确保连接立即活跃
  106. fmt.Fprintf(w, ": heartbeat\n\n")
  107. flusher.Flush()
  108. for {
  109. select {
  110. case data, ok := <-ch:
  111. if !ok {
  112. // 通道关闭,发送结束标记
  113. fmt.Printf("🔍 [StreamPromptHandler] 流式通道关闭,发送DONE标记,共发送 %d 个事件\n", eventCount)
  114. fmt.Fprintf(w, "data: [DONE]\n\n")
  115. flusher.Flush()
  116. return
  117. }
  118. // 发送 SSE 数据
  119. eventCount++
  120. fmt.Printf("🔍 [StreamPromptHandler] 发送SSE数据[%d]: %s\n", eventCount, data)
  121. // 发送 SSE 数据,opencode 数据已包含 payload 字段,不需要额外包装
  122. var wrappedData string
  123. if data == "[DONE]" {
  124. wrappedData = "[DONE]"
  125. } else {
  126. // 尝试解析为JSON,检查是否已有payload字段
  127. var jsonData interface{}
  128. if err := json.Unmarshal([]byte(data), &jsonData); err == nil {
  129. // 去重处理:移除message.updated事件中的重复content,并过滤不必要的事件
  130. jsonData = removeDuplicateContent(jsonData)
  131. // 如果返回nil,跳过此事件
  132. if jsonData == nil {
  133. continue
  134. }
  135. // 检查是否是对象且包含payload字段
  136. if obj, ok := jsonData.(map[string]interface{}); ok && obj["payload"] != nil {
  137. // 已有payload字段,直接发送原始数据
  138. wrappedData = data
  139. } else {
  140. // 没有payload字段,包装在payload对象中
  141. wrapped := map[string]interface{}{
  142. "payload": jsonData,
  143. }
  144. wrappedBytes, _ := json.Marshal(wrapped)
  145. wrappedData = string(wrappedBytes)
  146. }
  147. } else {
  148. // 不是JSON,按原样发送
  149. wrappedData = data
  150. }
  151. }
  152. fmt.Fprintf(w, "data: %s\n\n", wrappedData)
  153. flusher.Flush()
  154. case <-ctx.Done():
  155. fmt.Printf("🔍 [StreamPromptHandler] 上下文超时\n")
  156. return
  157. case <-heartbeatTicker.C:
  158. // 发送心跳保活(SSE注释格式)
  159. fmt.Printf("🔍 [StreamPromptHandler] 发送心跳保活\n")
  160. fmt.Fprintf(w, ": heartbeat\n\n")
  161. flusher.Flush()
  162. case <-r.Context().Done():
  163. fmt.Printf("🔍 [StreamPromptHandler] 客户端断开连接\n")
  164. return
  165. }
  166. }
  167. }
  168. // 包装TokenAuth中间件
  169. return authbase.TokenAuth(http.HandlerFunc(handler)).ServeHTTP
  170. }
  171. // BindJSON 简单的 JSON 绑定函数
  172. func BindJSON(r *http.Request, v interface{}) error {
  173. body, err := io.ReadAll(r.Body)
  174. if err != nil {
  175. return err
  176. }
  177. defer r.Body.Close()
  178. return json.Unmarshal(body, v)
  179. }
  180. // removeDuplicateContent 移除message.updated事件中的重复content,避免前端重复显示
  181. func removeDuplicateContent(data interface{}) interface{} {
  182. // 检查是否为map
  183. obj, ok := data.(map[string]interface{})
  184. if !ok {
  185. return data
  186. }
  187. // 递归处理payload字段
  188. if payload, ok := obj["payload"].(map[string]interface{}); ok {
  189. obj["payload"] = removeDuplicateContent(payload)
  190. }
  191. // 如果payload是数组(可能嵌套),处理每个元素
  192. if payloadArr, ok := obj["payload"].([]interface{}); ok {
  193. for i, item := range payloadArr {
  194. if itemMap, ok := item.(map[string]interface{}); ok {
  195. payloadArr[i] = removeDuplicateContent(itemMap)
  196. }
  197. }
  198. }
  199. // 检查type字段
  200. typeVal, hasType := obj["type"]
  201. if !hasType {
  202. return obj
  203. }
  204. typeStr, ok := typeVal.(string)
  205. if !ok {
  206. return obj
  207. }
  208. // 事件过滤策略:保守过滤,保留大多数事件以确保连接稳定
  209. switch typeStr {
  210. case "message.updated":
  211. // 对于message.updated事件,移除content字段避免重复,但不过滤事件本身
  212. if properties, ok := obj["properties"].(map[string]interface{}); ok {
  213. if info, ok := properties["info"].(map[string]interface{}); ok {
  214. // 移除content字段,避免重复
  215. delete(info, "content")
  216. }
  217. }
  218. // 保留事件,不过滤
  219. return obj
  220. case "session.status":
  221. // 保留状态事件,可能包含重要状态信息
  222. return obj
  223. case "session.diff":
  224. // 保留差异事件,避免客户端逻辑中断
  225. return obj
  226. case "server.heartbeat":
  227. // 保留心跳事件,保持连接活跃 - 非常重要!
  228. return obj
  229. case "session.idle":
  230. // 保留事件,不过滤
  231. return obj
  232. }
  233. // 其他所有事件类型都保留,包括message.part.updated(关键事件)
  234. return obj
  235. }