Aucune description
Vous ne pouvez pas sélectionner plus de 25 sujets Les noms de sujets doivent commencer par une lettre ou un nombre, peuvent contenir des tirets ('-') et peuvent comporter jusqu'à 35 caractères.

prompt_stream_routes.go 3.5KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121
  1. package routes
  2. import (
  3. "context"
  4. "encoding/json"
  5. "fmt"
  6. "io"
  7. "net/http"
  8. "time"
  9. "git.x2erp.com/qdy/go-base/authbase"
  10. "git.x2erp.com/qdy/go-base/ctx"
  11. "git.x2erp.com/qdy/go-base/model/response"
  12. "git.x2erp.com/qdy/go-base/webx/router"
  13. "git.x2erp.com/qdy/go-svc-code/internal/opencode"
  14. )
  15. // PromptStreamRequest 流式对话请求
  16. type PromptStreamRequest struct {
  17. SessionID string `json:"sessionID" binding:"required"`
  18. Parts []opencode.TextPart `json:"parts" binding:"required"`
  19. Agent string `json:"agent,omitempty"`
  20. Model *opencode.ModelInfo `json:"model,omitempty"`
  21. }
  22. // RegisterPromptStreamRoutes 注册流式对话路由
  23. func RegisterPromptStreamRoutes(ws *router.RouterService, client *opencode.Client) {
  24. // 流式对话(需要Token认证)
  25. ws.POST("/api/prompt/stream",
  26. func(req *PromptStreamRequest, ctx context.Context, reqCtx *ctx.RequestContext) (*response.QueryResult[map[string]interface{}], error) {
  27. // 流式响应需要直接写入 HTTP 响应,不能使用标准的路由返回值
  28. // 这里返回一个特殊结果,指示上层使用流式处理
  29. return &response.QueryResult[map[string]interface{}]{
  30. Success: false,
  31. Message: "流式端点需要特殊处理",
  32. }, nil
  33. },
  34. ).Use(authbase.TokenAuth).Desc("流式对话(Server-Sent Events)").Register()
  35. // 流式对话的原始 HTTP 处理器
  36. // 注意:这个路由需要直接处理 HTTP 流,不能使用标准的 router 包装
  37. // 我们将注册一个原始的 HTTP 处理器到 gin 引擎
  38. // 这个函数将由 main.go 中的额外注册调用
  39. }
  40. // StreamPromptHandler 流式对话的 HTTP 处理器(已包含TokenAuth认证)
  41. func StreamPromptHandler(client opencode.OpenCodeClient) http.HandlerFunc {
  42. // 创建内部处理器
  43. handler := func(w http.ResponseWriter, r *http.Request) {
  44. // 解析请求
  45. var req PromptStreamRequest
  46. if err := BindJSON(r, &req); err != nil {
  47. http.Error(w, fmt.Sprintf("解析请求失败: %v", err), http.StatusBadRequest)
  48. return
  49. }
  50. // 创建 prompt 请求
  51. prompt := &opencode.PromptRequest{
  52. Parts: req.Parts,
  53. Agent: req.Agent,
  54. Model: req.Model,
  55. }
  56. // 设置 SSE 头
  57. w.Header().Set("Content-Type", "text/event-stream")
  58. w.Header().Set("Cache-Control", "no-cache")
  59. w.Header().Set("Connection", "keep-alive")
  60. w.Header().Set("Access-Control-Allow-Origin", "*")
  61. // 创建带超时的上下文
  62. ctx, cancel := context.WithTimeout(r.Context(), 5*time.Minute)
  63. defer cancel()
  64. // 获取流式响应通道
  65. ch, err := client.SendPromptStream(ctx, req.SessionID, prompt)
  66. if err != nil {
  67. http.Error(w, fmt.Sprintf("发送流式请求失败: %v", err), http.StatusInternalServerError)
  68. return
  69. }
  70. // 发送流式响应
  71. flusher, ok := w.(http.Flusher)
  72. if !ok {
  73. http.Error(w, "Streaming unsupported", http.StatusInternalServerError)
  74. return
  75. }
  76. for {
  77. select {
  78. case data, ok := <-ch:
  79. if !ok {
  80. // 通道关闭,发送结束标记
  81. fmt.Fprintf(w, "data: [DONE]\n\n")
  82. flusher.Flush()
  83. return
  84. }
  85. // 发送 SSE 数据
  86. fmt.Fprintf(w, "data: %s\n\n", data)
  87. flusher.Flush()
  88. case <-ctx.Done():
  89. return
  90. case <-r.Context().Done():
  91. return
  92. }
  93. }
  94. }
  95. // 包装TokenAuth中间件
  96. return authbase.TokenAuth(http.HandlerFunc(handler)).ServeHTTP
  97. }
  98. // BindJSON 简单的 JSON 绑定函数
  99. func BindJSON(r *http.Request, v interface{}) error {
  100. body, err := io.ReadAll(r.Body)
  101. if err != nil {
  102. return err
  103. }
  104. defer r.Body.Close()
  105. return json.Unmarshal(body, v)
  106. }