Ei kuvausta
You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

prompt_stream_routes.go 3.2KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116
  1. package routes
  2. import (
  3. "context"
  4. "encoding/json"
  5. "fmt"
  6. "io"
  7. "net/http"
  8. "time"
  9. "git.x2erp.com/qdy/go-base/ctx"
  10. "git.x2erp.com/qdy/go-base/model/response"
  11. "git.x2erp.com/qdy/go-base/webx/router"
  12. "git.x2erp.com/qdy/go-svc-code/internal/opencode"
  13. )
  14. // PromptStreamRequest 流式对话请求
  15. type PromptStreamRequest struct {
  16. SessionID string `json:"sessionID" binding:"required"`
  17. Parts []opencode.TextPart `json:"parts" binding:"required"`
  18. Agent string `json:"agent,omitempty"`
  19. Model *opencode.ModelInfo `json:"model,omitempty"`
  20. }
  21. // RegisterPromptStreamRoutes 注册流式对话路由
  22. func RegisterPromptStreamRoutes(ws *router.RouterService, client *opencode.Client) {
  23. // 流式对话
  24. ws.POST("/api/prompt/stream",
  25. func(req *PromptStreamRequest, ctx context.Context, reqCtx *ctx.RequestContext) (*response.QueryResult[map[string]interface{}], error) {
  26. // 流式响应需要直接写入 HTTP 响应,不能使用标准的路由返回值
  27. // 这里返回一个特殊结果,指示上层使用流式处理
  28. return &response.QueryResult[map[string]interface{}]{
  29. Success: false,
  30. Message: "流式端点需要特殊处理",
  31. }, nil
  32. },
  33. ).Desc("流式对话(Server-Sent Events)").Register()
  34. // 流式对话的原始 HTTP 处理器
  35. // 注意:这个路由需要直接处理 HTTP 流,不能使用标准的 router 包装
  36. // 我们将注册一个原始的 HTTP 处理器到 gin 引擎
  37. // 这个函数将由 main.go 中的额外注册调用
  38. }
  39. // StreamPromptHandler 流式对话的 HTTP 处理器
  40. func StreamPromptHandler(client opencode.OpenCodeClient) http.HandlerFunc {
  41. return func(w http.ResponseWriter, r *http.Request) {
  42. // 解析请求
  43. var req PromptStreamRequest
  44. if err := BindJSON(r, &req); err != nil {
  45. http.Error(w, fmt.Sprintf("解析请求失败: %v", err), http.StatusBadRequest)
  46. return
  47. }
  48. // 创建 prompt 请求
  49. prompt := &opencode.PromptRequest{
  50. Parts: req.Parts,
  51. Agent: req.Agent,
  52. Model: req.Model,
  53. }
  54. // 设置 SSE 头
  55. w.Header().Set("Content-Type", "text/event-stream")
  56. w.Header().Set("Cache-Control", "no-cache")
  57. w.Header().Set("Connection", "keep-alive")
  58. w.Header().Set("Access-Control-Allow-Origin", "*")
  59. // 创建带超时的上下文
  60. ctx, cancel := context.WithTimeout(r.Context(), 5*time.Minute)
  61. defer cancel()
  62. // 获取流式响应通道
  63. ch, err := client.SendPromptStream(ctx, req.SessionID, prompt)
  64. if err != nil {
  65. http.Error(w, fmt.Sprintf("发送流式请求失败: %v", err), http.StatusInternalServerError)
  66. return
  67. }
  68. // 发送流式响应
  69. flusher, ok := w.(http.Flusher)
  70. if !ok {
  71. http.Error(w, "Streaming unsupported", http.StatusInternalServerError)
  72. return
  73. }
  74. for {
  75. select {
  76. case data, ok := <-ch:
  77. if !ok {
  78. // 通道关闭,发送结束标记
  79. fmt.Fprintf(w, "data: [DONE]\n\n")
  80. flusher.Flush()
  81. return
  82. }
  83. // 发送 SSE 数据
  84. fmt.Fprintf(w, "data: %s\n\n", data)
  85. flusher.Flush()
  86. case <-ctx.Done():
  87. return
  88. case <-r.Context().Done():
  89. return
  90. }
  91. }
  92. }
  93. }
  94. // BindJSON 简单的 JSON 绑定函数
  95. func BindJSON(r *http.Request, v interface{}) error {
  96. body, err := io.ReadAll(r.Body)
  97. if err != nil {
  98. return err
  99. }
  100. defer r.Body.Close()
  101. return json.Unmarshal(body, v)
  102. }