Nessuna descrizione
You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

prompt_stream_routes.go 6.9KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230
  1. package routes
  2. import (
  3. "context"
  4. "encoding/json"
  5. "fmt"
  6. "io"
  7. "net/http"
  8. "time"
  9. "git.x2erp.com/qdy/go-base/authbase"
  10. "git.x2erp.com/qdy/go-base/webx"
  11. "git.x2erp.com/qdy/go-base/webx/router"
  12. "git.x2erp.com/qdy/go-svc-code/internal/opencode"
  13. )
  14. // PromptStreamRequest 流式对话请求
  15. type PromptStreamRequest struct {
  16. SessionID string `json:"sessionID" binding:"required"`
  17. Parts []opencode.TextPart `json:"parts" binding:"required"`
  18. Agent string `json:"agent,omitempty"`
  19. Model *opencode.ModelInfo `json:"model,omitempty"`
  20. }
  21. // RegisterPromptStreamRoutes 注册流式对话路由
  22. func RegisterPromptStreamRoutes(ws *router.RouterService, webService *webx.WebService, client opencode.OpenCodeClient) {
  23. // 流式对话需要直接处理 HTTP 流式响应,不能使用标准的路由包装
  24. // 我们直接注册到 webService 的底层路由器
  25. webService.GetRouter().Handle("/api/prompt/stream", StreamPromptHandler(client))
  26. }
  27. // StreamPromptHandler 流式对话的 HTTP 处理器(已包含TokenAuth认证)
  28. func StreamPromptHandler(client opencode.OpenCodeClient) http.HandlerFunc {
  29. // 创建内部处理器
  30. handler := func(w http.ResponseWriter, r *http.Request) {
  31. fmt.Printf("🔍 [StreamPromptHandler] 收到流式对话请求: %s %s\n", r.Method, r.URL.String())
  32. // 解析请求
  33. var req PromptStreamRequest
  34. if err := BindJSON(r, &req); err != nil {
  35. fmt.Printf("🔍 [StreamPromptHandler] 解析请求失败: %v\n", err)
  36. http.Error(w, fmt.Sprintf("解析请求失败: %v", err), http.StatusBadRequest)
  37. return
  38. }
  39. fmt.Printf("🔍 [StreamPromptHandler] 请求数据: sessionID=%s, agent=%v, parts=%d\n",
  40. req.SessionID, req.Agent, len(req.Parts))
  41. if len(req.Parts) > 0 && req.Parts[0].Text != "" {
  42. fmt.Printf("🔍 [StreamPromptHandler] 用户消息: %s\n", req.Parts[0].Text)
  43. }
  44. // 创建 prompt 请求
  45. prompt := &opencode.PromptRequest{
  46. Parts: req.Parts,
  47. Agent: req.Agent,
  48. Model: req.Model,
  49. }
  50. // 设置 SSE 头
  51. w.Header().Set("Content-Type", "text/event-stream")
  52. w.Header().Set("Cache-Control", "no-cache")
  53. w.Header().Set("Connection", "keep-alive")
  54. w.Header().Set("Access-Control-Allow-Origin", "*")
  55. // 创建带超时的上下文
  56. ctx, cancel := context.WithTimeout(r.Context(), 5*time.Minute)
  57. defer cancel()
  58. fmt.Printf("🔍 [StreamPromptHandler] 调用 SendPromptStream, sessionID=%s\n", req.SessionID)
  59. // 获取流式响应通道
  60. ch, err := client.SendPromptStream(ctx, req.SessionID, prompt)
  61. if err != nil {
  62. fmt.Printf("🔍 [StreamPromptHandler] 发送流式请求失败: %v\n", err)
  63. http.Error(w, fmt.Sprintf("发送流式请求失败: %v", err), http.StatusInternalServerError)
  64. return
  65. }
  66. fmt.Printf("🔍 [StreamPromptHandler] 成功获取流式响应通道\n")
  67. // 发送流式响应
  68. flusher, ok := w.(http.Flusher)
  69. if !ok {
  70. http.Error(w, "Streaming unsupported", http.StatusInternalServerError)
  71. return
  72. }
  73. fmt.Printf("🔍 [StreamPromptHandler] 开始发送流式响应\n")
  74. eventCount := 0
  75. for {
  76. select {
  77. case data, ok := <-ch:
  78. if !ok {
  79. // 通道关闭,发送结束标记
  80. fmt.Printf("🔍 [StreamPromptHandler] 流式通道关闭,发送DONE标记,共发送 %d 个事件\n", eventCount)
  81. fmt.Fprintf(w, "data: [DONE]\n\n")
  82. flusher.Flush()
  83. return
  84. }
  85. // 发送 SSE 数据
  86. eventCount++
  87. fmt.Printf("🔍 [StreamPromptHandler] 发送SSE数据[%d]: %s\n", eventCount, data)
  88. // 发送 SSE 数据,opencode 数据已包含 payload 字段,不需要额外包装
  89. var wrappedData string
  90. if data == "[DONE]" {
  91. wrappedData = "[DONE]"
  92. } else {
  93. // 尝试解析为JSON,检查是否已有payload字段
  94. var jsonData interface{}
  95. if err := json.Unmarshal([]byte(data), &jsonData); err == nil {
  96. // 去重处理:移除message.updated事件中的重复content,并过滤不必要的事件
  97. jsonData = removeDuplicateContent(jsonData)
  98. // 如果返回nil,跳过此事件
  99. if jsonData == nil {
  100. continue
  101. }
  102. // 检查是否是对象且包含payload字段
  103. if obj, ok := jsonData.(map[string]interface{}); ok && obj["payload"] != nil {
  104. // 已有payload字段,直接发送原始数据
  105. wrappedData = data
  106. } else {
  107. // 没有payload字段,包装在payload对象中
  108. wrapped := map[string]interface{}{
  109. "payload": jsonData,
  110. }
  111. wrappedBytes, _ := json.Marshal(wrapped)
  112. wrappedData = string(wrappedBytes)
  113. }
  114. } else {
  115. // 不是JSON,按原样发送
  116. wrappedData = data
  117. }
  118. }
  119. fmt.Fprintf(w, "data: %s\n\n", wrappedData)
  120. flusher.Flush()
  121. case <-ctx.Done():
  122. fmt.Printf("🔍 [StreamPromptHandler] 上下文超时\n")
  123. return
  124. case <-r.Context().Done():
  125. fmt.Printf("🔍 [StreamPromptHandler] 客户端断开连接\n")
  126. return
  127. }
  128. }
  129. }
  130. // 包装TokenAuth中间件
  131. return authbase.TokenAuth(http.HandlerFunc(handler)).ServeHTTP
  132. }
  133. // BindJSON 简单的 JSON 绑定函数
  134. func BindJSON(r *http.Request, v interface{}) error {
  135. body, err := io.ReadAll(r.Body)
  136. if err != nil {
  137. return err
  138. }
  139. defer r.Body.Close()
  140. return json.Unmarshal(body, v)
  141. }
  142. // removeDuplicateContent 移除message.updated事件中的重复content,避免前端重复显示
  143. func removeDuplicateContent(data interface{}) interface{} {
  144. // 检查是否为map
  145. obj, ok := data.(map[string]interface{})
  146. if !ok {
  147. return data
  148. }
  149. // 递归处理payload字段
  150. if payload, ok := obj["payload"].(map[string]interface{}); ok {
  151. obj["payload"] = removeDuplicateContent(payload)
  152. }
  153. // 如果payload是数组(可能嵌套),处理每个元素
  154. if payloadArr, ok := obj["payload"].([]interface{}); ok {
  155. for i, item := range payloadArr {
  156. if itemMap, ok := item.(map[string]interface{}); ok {
  157. payloadArr[i] = removeDuplicateContent(itemMap)
  158. }
  159. }
  160. }
  161. // 检查type字段
  162. typeVal, hasType := obj["type"]
  163. if !hasType {
  164. return obj
  165. }
  166. typeStr, ok := typeVal.(string)
  167. if !ok {
  168. return obj
  169. }
  170. // 事件过滤策略:减少发送给前端的事件数量
  171. switch typeStr {
  172. case "message.updated":
  173. // 检查是否有properties字段
  174. if properties, ok := obj["properties"].(map[string]interface{}); ok {
  175. if info, ok := properties["info"].(map[string]interface{}); ok {
  176. // 移除content字段,避免重复
  177. delete(info, "content")
  178. // 检查是否有completed时间,如果没有则过滤掉(只发送最终状态)
  179. if timeInfo, ok := info["time"].(map[string]interface{}); ok {
  180. if timeInfo["completed"] == nil {
  181. // 没有completed时间,这是中间状态,过滤掉
  182. return nil
  183. }
  184. }
  185. }
  186. }
  187. case "session.status":
  188. // session.status事件很频繁但前端可能不需要,过滤掉
  189. return nil
  190. case "session.diff":
  191. // session.diff事件通常为空,过滤掉
  192. return nil
  193. case "server.heartbeat":
  194. // 心跳事件,过滤掉
  195. return nil
  196. case "session.idle":
  197. // 空闲事件,过滤掉
  198. return nil
  199. }
  200. return obj
  201. }