Ei kuvausta
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

server.go 4.8KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190
  1. package mcp
import (
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"log"
	"net/http"
	"time"

	"git.x2erp.com/qdy/go-base/ctx"
	"git.x2erp.com/qdy/go-db/factory/database"
	mcpsdk "github.com/modelcontextprotocol/go-sdk/mcp"
)
  12. // Server 包装 MCP SDK 服务器,提供自动注册和依赖注入
  13. type Server struct {
  14. port int
  15. serviceName string
  16. sdkServer *mcpsdk.Server
  17. transport mcpsdk.Transport
  18. dbFactory *database.DBFactory
  19. dbsFactory *database.DBSFactory
  20. //baseCtx *ctx.RequestContext
  21. httpServer *http.Server
  22. handler http.Handler
  23. }
  24. // Config 服务器配置
  25. type Config struct {
  26. Name string
  27. Version string
  28. Port int
  29. ServiceName string
  30. Description string
  31. DBFactory *database.DBFactory
  32. DBSFactory *database.DBSFactory
  33. //BaseCtx *ctx.RequestContext
  34. }
  35. // NewServer 创建新的 MCP 服务器
  36. func NewServer(cfg Config) (*Server, error) {
  37. impl := &mcpsdk.Implementation{
  38. Name: cfg.Name,
  39. Version: cfg.Version,
  40. }
  41. sdkServer := mcpsdk.NewServer(impl, nil)
  42. server := &Server{
  43. sdkServer: sdkServer,
  44. dbFactory: cfg.DBFactory,
  45. dbsFactory: cfg.DBSFactory,
  46. //baseCtx: cfg.BaseCtx,
  47. port: cfg.Port,
  48. serviceName: cfg.ServiceName,
  49. }
  50. // 自动注册所有已注册的工具
  51. if err := server.registerAllTools(); err != nil {
  52. return nil, fmt.Errorf("failed to register tools: %w", err)
  53. }
  54. return server, nil
  55. }
  56. // registerAllTools 将注册表中的所有工具注册到 MCP 服务器
  57. func (s *Server) registerAllTools() error {
  58. tools := ListTools()
  59. for _, tool := range tools {
  60. if err := s.registerTool(tool); err != nil {
  61. return fmt.Errorf("failed to register tool %s: %w", tool.Name, err)
  62. }
  63. }
  64. log.Printf("Registered %d MCP tools", len(tools))
  65. return nil
  66. }
  67. // registerTool 注册单个工具到 MCP 服务器
  68. func (s *Server) registerTool(tool ToolDefinition) error {
  69. // 创建工具处理器
  70. handler := s.createToolHandler(tool)
  71. // 创建 MCP 工具
  72. mcpTool := &mcpsdk.Tool{
  73. Name: tool.Name,
  74. Description: tool.Description,
  75. InputSchema: tool.InputSchema,
  76. }
  77. // 注册工具到服务器
  78. mcpsdk.AddTool(s.sdkServer, mcpTool, handler)
  79. return nil
  80. }
  81. // createToolHandler 创建 MCP 工具处理器
  82. func (s *Server) createToolHandler(tool ToolDefinition) mcpsdk.ToolHandlerFor[map[string]interface{}, interface{}] {
  83. return func(ctxMcp context.Context, request *mcpsdk.CallToolRequest, input map[string]interface{}) (*mcpsdk.CallToolResult, interface{}, error) {
  84. // 将输入转换为 JSON
  85. inputJSON, err := json.Marshal(input)
  86. if err != nil {
  87. return nil, nil, fmt.Errorf("failed to marshal input: %w", err)
  88. }
  89. // 提取请求上下文信息
  90. //reqCtx := s.extractRequestContext(request)
  91. reqCtx := ctx.FromContext(ctxMcp)
  92. // 创建工具依赖项
  93. toolDeps := &ToolDependencies{
  94. DBFactory: s.dbFactory,
  95. DBSFactory: s.dbsFactory,
  96. ReqCtx: reqCtx,
  97. }
  98. // 执行工具
  99. result, err := tool.Execute(json.RawMessage(inputJSON), toolDeps)
  100. if err != nil {
  101. // 返回工具错误(非协议错误)
  102. return &mcpsdk.CallToolResult{
  103. IsError: true,
  104. Content: []mcpsdk.Content{
  105. &mcpsdk.TextContent{Text: fmt.Sprintf("tool error: %v", err)},
  106. },
  107. }, nil, nil
  108. }
  109. // 返回成功结果
  110. return nil, result, nil
  111. }
  112. }
  113. // // extractRequestContext 从 MCP 请求中提取上下文信息
  114. // func (s *Server) extractRequestContext(request *mcpsdk.CallToolRequest) *ctx.RequestContext {
  115. // reqCtx := &ctx.RequestContext{}
  116. // if s.baseCtx != nil {
  117. // // 复制基础上下文
  118. // *reqCtx = *s.baseCtx
  119. // }
  120. // // 从请求的 Extra 数据中提取自定义项目 ID
  121. // extra := request.GetExtra()
  122. // if extra != nil && extra.Header != nil {
  123. // // 确定项目 ID 头名称
  124. // projectIDHeader := os.Getenv("MCP_PROJECT_ID_HEADER")
  125. // if projectIDHeader == "" {
  126. // projectIDHeader = "X-Project-ID"
  127. // }
  128. // if projectID := extra.Header.Get(projectIDHeader); projectID != "" {
  129. // // 将项目 ID 存储在 TraceID 中
  130. // reqCtx.ProjectID = projectID
  131. // }
  132. // }
  133. // return reqCtx
  134. // }
  135. // SetTransport 设置传输层
  136. func (s *Server) SetTransport(transport mcpsdk.Transport) {
  137. s.transport = transport
  138. }
  139. // GetSDKServer 返回底层的 SDK 服务器实例
  140. func (s *Server) GetSDKServer() *mcpsdk.Server {
  141. return s.sdkServer
  142. }
  143. // startHTTPServer 启动 HTTP 服务器
  144. func (s *Server) Run(handler http.Handler) {
  145. s.handler = handler
  146. addr := fmt.Sprintf(":%d", s.port)
  147. s.httpServer = &http.Server{
  148. Addr: addr,
  149. Handler: s.handler,
  150. }
  151. log.Printf("%s listening on %s", s.serviceName, addr)
  152. // 在 goroutine 中启动服务器
  153. go func() {
  154. if err := s.httpServer.ListenAndServe(); err != nil && err != http.ErrServerClosed {
  155. log.Fatalf("%s failed to start: %v", s.serviceName, err)
  156. }
  157. }()
  158. }
  159. // GetHTTPServer 返回内部的 HTTP 服务器实例
  160. func (s *Server) GetHTTPServer() *http.Server {
  161. return s.httpServer
  162. }