Нет описания
Вы не можете выбрать более 25 тем Темы должны начинаться с буквы или цифры, могут содержать дефисы(-) и должны содержать не более 35 символов.

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186
  1. package mcp
  2. import (
  3. "context"
  4. "encoding/json"
  5. "fmt"
  6. "log"
  7. "net/http"
  8. "git.x2erp.com/qdy/go-base/ctx"
  9. "git.x2erp.com/qdy/go-db/factory/database"
  10. mcpsdk "github.com/modelcontextprotocol/go-sdk/mcp"
  11. )
  12. // Server 包装 MCP SDK 服务器,提供自动注册和依赖注入
  13. type Server struct {
  14. port int
  15. serviceName string
  16. sdkServer *mcpsdk.Server
  17. transport mcpsdk.Transport
  18. dbFactory *database.DBFactory
  19. //baseCtx *ctx.RequestContext
  20. httpServer *http.Server
  21. handler http.Handler
  22. }
  23. // Config 服务器配置
  24. type Config struct {
  25. Name string
  26. Version string
  27. Port int
  28. ServiceName string
  29. Description string
  30. DBFactory *database.DBFactory
  31. //BaseCtx *ctx.RequestContext
  32. }
  33. // NewServer 创建新的 MCP 服务器
  34. func NewServer(cfg Config) (*Server, error) {
  35. impl := &mcpsdk.Implementation{
  36. Name: cfg.Name,
  37. Version: cfg.Version,
  38. }
  39. sdkServer := mcpsdk.NewServer(impl, nil)
  40. server := &Server{
  41. sdkServer: sdkServer,
  42. dbFactory: cfg.DBFactory,
  43. //baseCtx: cfg.BaseCtx,
  44. port: cfg.Port,
  45. serviceName: cfg.ServiceName,
  46. }
  47. // 自动注册所有已注册的工具
  48. if err := server.registerAllTools(); err != nil {
  49. return nil, fmt.Errorf("failed to register tools: %w", err)
  50. }
  51. return server, nil
  52. }
  53. // registerAllTools 将注册表中的所有工具注册到 MCP 服务器
  54. func (s *Server) registerAllTools() error {
  55. tools := ListTools()
  56. for _, tool := range tools {
  57. if err := s.registerTool(tool); err != nil {
  58. return fmt.Errorf("failed to register tool %s: %w", tool.Name, err)
  59. }
  60. }
  61. log.Printf("Registered %d MCP tools", len(tools))
  62. return nil
  63. }
  64. // registerTool 注册单个工具到 MCP 服务器
  65. func (s *Server) registerTool(tool ToolDefinition) error {
  66. // 创建工具处理器
  67. handler := s.createToolHandler(tool)
  68. // 创建 MCP 工具
  69. mcpTool := &mcpsdk.Tool{
  70. Name: tool.Name,
  71. Description: tool.Description,
  72. InputSchema: tool.InputSchema,
  73. }
  74. // 注册工具到服务器
  75. mcpsdk.AddTool(s.sdkServer, mcpTool, handler)
  76. return nil
  77. }
  78. // createToolHandler 创建 MCP 工具处理器
  79. func (s *Server) createToolHandler(tool ToolDefinition) mcpsdk.ToolHandlerFor[map[string]interface{}, interface{}] {
  80. return func(ctxMcp context.Context, request *mcpsdk.CallToolRequest, input map[string]interface{}) (*mcpsdk.CallToolResult, interface{}, error) {
  81. // 将输入转换为 JSON
  82. inputJSON, err := json.Marshal(input)
  83. if err != nil {
  84. return nil, nil, fmt.Errorf("failed to marshal input: %w", err)
  85. }
  86. // 提取请求上下文信息
  87. //reqCtx := s.extractRequestContext(request)
  88. reqCtx := ctx.FromContext(ctxMcp)
  89. // 创建工具依赖项
  90. toolDeps := &ToolDependencies{
  91. DBFactory: s.dbFactory,
  92. ReqCtx: reqCtx,
  93. }
  94. // 执行工具
  95. result, err := tool.Execute(json.RawMessage(inputJSON), toolDeps)
  96. if err != nil {
  97. // 返回工具错误(非协议错误)
  98. return &mcpsdk.CallToolResult{
  99. IsError: true,
  100. Content: []mcpsdk.Content{
  101. &mcpsdk.TextContent{Text: fmt.Sprintf("tool error: %v", err)},
  102. },
  103. }, nil, nil
  104. }
  105. // 返回成功结果
  106. return nil, result, nil
  107. }
  108. }
  109. // // extractRequestContext 从 MCP 请求中提取上下文信息
  110. // func (s *Server) extractRequestContext(request *mcpsdk.CallToolRequest) *ctx.RequestContext {
  111. // reqCtx := &ctx.RequestContext{}
  112. // if s.baseCtx != nil {
  113. // // 复制基础上下文
  114. // *reqCtx = *s.baseCtx
  115. // }
  116. // // 从请求的 Extra 数据中提取自定义项目 ID
  117. // extra := request.GetExtra()
  118. // if extra != nil && extra.Header != nil {
  119. // // 确定项目 ID 头名称
  120. // projectIDHeader := os.Getenv("MCP_PROJECT_ID_HEADER")
  121. // if projectIDHeader == "" {
  122. // projectIDHeader = "X-Project-ID"
  123. // }
  124. // if projectID := extra.Header.Get(projectIDHeader); projectID != "" {
  125. // // 将项目 ID 存储在 TraceID 中
  126. // reqCtx.ProjectID = projectID
  127. // }
  128. // }
  129. // return reqCtx
  130. // }
  131. // SetTransport 设置传输层
  132. func (s *Server) SetTransport(transport mcpsdk.Transport) {
  133. s.transport = transport
  134. }
  135. // GetSDKServer 返回底层的 SDK 服务器实例
  136. func (s *Server) GetSDKServer() *mcpsdk.Server {
  137. return s.sdkServer
  138. }
  139. // startHTTPServer 启动 HTTP 服务器
  140. func (s *Server) Run(handler http.Handler) {
  141. s.handler = handler
  142. addr := fmt.Sprintf(":%d", s.port)
  143. s.httpServer = &http.Server{
  144. Addr: addr,
  145. Handler: s.handler,
  146. }
  147. log.Printf("%s listening on %s", s.serviceName, addr)
  148. // 在 goroutine 中启动服务器
  149. go func() {
  150. if err := s.httpServer.ListenAndServe(); err != nil && err != http.ErrServerClosed {
  151. log.Fatalf("%s failed to start: %v", s.serviceName, err)
  152. }
  153. }()
  154. }
  155. // GetHTTPServer 返回内部的 HTTP 服务器实例
  156. func (s *Server) GetHTTPServer() *http.Server {
  157. return s.httpServer
  158. }