Nav apraksta
Nevar pievienot vairāk kā 25 tēmas. Tēmai ir jāsākas ar burtu vai ciparu, tā var saturēt domu zīmes ('-') un var būt līdz 35 simboliem gara.

server.go 4.6KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186
  1. package mcp
import (
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"log"
	"net"
	"net/http"
	"os"
	"time"

	"git.x2erp.com/qdy/go-base/ctx"
	"git.x2erp.com/qdy/go-db/factory/database"
	mcpsdk "github.com/modelcontextprotocol/go-sdk/mcp"
)
  13. // Server 包装 MCP SDK 服务器,提供自动注册和依赖注入
  14. type Server struct {
  15. port int
  16. serviceName string
  17. sdkServer *mcpsdk.Server
  18. transport mcpsdk.Transport
  19. dbFactory *database.DBFactory
  20. baseCtx *ctx.RequestContext
  21. httpServer *http.Server
  22. handler http.Handler
  23. }
  24. // Config 服务器配置
  25. type Config struct {
  26. Name string
  27. Version string
  28. Port int
  29. ServiceName string
  30. Description string
  31. DBFactory *database.DBFactory
  32. BaseCtx *ctx.RequestContext
  33. }
  // NewServer builds a Server from cfg and registers on it every tool that
// ListTools currently returns.
//
// Only cfg.Name and cfg.Version are forwarded to the SDK, as its
// Implementation; Port and ServiceName are kept for Run, and DBFactory and
// BaseCtx are kept for tool calls. cfg.Description is not used here. No
// transport is set and nothing is started. A non-nil error is returned,
// with a nil *Server, if any tool fails to register.
func NewServer(cfg Config) (*Server, error) {
	sdk := mcpsdk.NewServer(&mcpsdk.Implementation{
		Name:    cfg.Name,
		Version: cfg.Version,
	}, nil)

	s := &Server{
		port:        cfg.Port,
		serviceName: cfg.ServiceName,
		sdkServer:   sdk,
		dbFactory:   cfg.DBFactory,
		baseCtx:     cfg.BaseCtx,
	}

	err := s.registerAllTools()
	if err != nil {
		return nil, fmt.Errorf("failed to register tools: %w", err)
	}
	return s, nil
}
  // registerAllTools adds every tool returned by ListTools to the SDK server.
// It stops at the first failure and returns that error wrapped with the
// tool's name. On success it logs how many tools were registered.
func (s *Server) registerAllTools() error {
	defs := ListTools()
	for _, def := range defs {
		err := s.registerTool(def)
		if err == nil {
			continue
		}
		return fmt.Errorf("failed to register tool %s: %w", def.Name, err)
	}
	log.Printf("Registered %d MCP tools", len(defs))
	return nil
}
  // registerTool adds a single tool definition to the underlying MCP server.
//
// The tool's name, description and input schema are copied into an
// mcpsdk.Tool. The call itself is dispatched to the handler built by
// createToolHandler.
//
// mcpsdk.AddTool reports an invalid tool definition by panicking, for
// example when the input schema is not an object schema. That panic is
// converted into the returned error here. As a result, NewServer can report
// a bad tool to its caller instead of crashing the whole process during
// construction.
func (s *Server) registerTool(tool ToolDefinition) (err error) {
	handler := s.createToolHandler(tool)

	mcpTool := &mcpsdk.Tool{
		Name:        tool.Name,
		Description: tool.Description,
		InputSchema: tool.InputSchema,
	}

	// Scope the recover to the AddTool call only: it is the one place
	// where the SDK signals invalid input by panicking.
	defer func() {
		if r := recover(); r != nil {
			err = fmt.Errorf("invalid tool definition: %v", r)
		}
	}()
	mcpsdk.AddTool(s.sdkServer, mcpTool, handler)
	return nil
}
  // createToolHandler adapts a registry ToolDefinition to the typed handler
// signature that mcpsdk.AddTool expects.
//
// On each call, the returned handler does the following:
//   - It re-encodes the decoded arguments map as JSON and passes that to
//     tool.Execute.
//   - It bundles the server's DB factory and a freshly built per-request
//     context (extractRequestContext) into ToolDependencies.
//   - If Execute fails, it returns a CallToolResult with IsError set and the
//     error text as content. This reports the failure as a tool error rather
//     than a protocol error. The handler's own error return is used only when
//     the arguments cannot be re-encoded.
//   - If Execute succeeds, it returns a nil *CallToolResult together with
//     Execute's result as the output value.
//
// The context.Context supplied by the SDK is ignored, so cancellation is not
// propagated into tool.Execute.
func (s *Server) createToolHandler(tool ToolDefinition) mcpsdk.ToolHandlerFor[map[string]interface{}, interface{}] {
	handle := func(_ context.Context, req *mcpsdk.CallToolRequest, args map[string]interface{}) (*mcpsdk.CallToolResult, interface{}, error) {
		raw, marshalErr := json.Marshal(args)
		if marshalErr != nil {
			return nil, nil, fmt.Errorf("failed to marshal input: %w", marshalErr)
		}

		deps := &ToolDependencies{
			DBFactory: s.dbFactory,
			ReqCtx:    s.extractRequestContext(req),
		}

		out, execErr := tool.Execute(json.RawMessage(raw), deps)
		if execErr == nil {
			return nil, out, nil
		}

		failure := &mcpsdk.CallToolResult{
			IsError: true,
			Content: []mcpsdk.Content{
				&mcpsdk.TextContent{Text: fmt.Sprintf("tool error: %v", execErr)},
			},
		}
		return failure, nil, nil
	}
	return handle
}
  109. // extractRequestContext 从 MCP 请求中提取上下文信息
  110. func (s *Server) extractRequestContext(request *mcpsdk.CallToolRequest) *ctx.RequestContext {
  111. reqCtx := &ctx.RequestContext{}
  112. if s.baseCtx != nil {
  113. // 复制基础上下文
  114. *reqCtx = *s.baseCtx
  115. }
  116. // 从请求的 Extra 数据中提取自定义项目 ID
  117. extra := request.GetExtra()
  118. if extra != nil && extra.Header != nil {
  119. // 确定项目 ID 头名称
  120. projectIDHeader := os.Getenv("MCP_PROJECT_ID_HEADER")
  121. if projectIDHeader == "" {
  122. projectIDHeader = "X-Project-ID"
  123. }
  124. if projectID := extra.Header.Get(projectIDHeader); projectID != "" {
  125. // 将项目 ID 存储在 TraceID 中
  126. reqCtx.ProjectID = projectID
  127. }
  128. }
  129. return reqCtx
  130. }
  131. // SetTransport 设置传输层
  132. func (s *Server) SetTransport(transport mcpsdk.Transport) {
  133. s.transport = transport
  134. }
  135. // GetSDKServer 返回底层的 SDK 服务器实例
  136. func (s *Server) GetSDKServer() *mcpsdk.Server {
  137. return s.sdkServer
  138. }
  // Run serves handler over HTTP on the port from Config.
//
// Run returns once the listening socket has been bound, and requests are
// then served on a background goroutine. Both handler and the created
// *http.Server are stored on s; the server is available through
// GetHTTPServer, for example to call Shutdown.
//
// The process is terminated via log.Fatalf in two cases: the port cannot be
// bound, or serving stops with any error other than http.ErrServerClosed.
func (s *Server) Run(handler http.Handler) {
	s.handler = handler
	addr := fmt.Sprintf(":%d", s.port)
	srv := &http.Server{
		Addr:    addr,
		Handler: s.handler,
		// Bound how long a client may take to send request headers
		// (Slowloris mitigation). ReadTimeout and WriteTimeout are
		// deliberately left unset: responses may be streamed over
		// long-lived connections, and a write deadline would cut them off.
		ReadHeaderTimeout: 10 * time.Second,
	}
	s.httpServer = srv

	// Bind synchronously. A port that is in use or not permitted then fails
	// here, not in the background after Run has already returned, and the
	// "listening" line is logged only once the socket actually exists.
	ln, err := net.Listen("tcp", addr)
	if err != nil {
		log.Fatalf("%s failed to start: %v", s.serviceName, err)
	}
	log.Printf("%s listening on %s", s.serviceName, addr)

	// Serve closes ln when it returns, including after Shutdown.
	go func() {
		if err := srv.Serve(ln); err != nil && !errors.Is(err, http.ErrServerClosed) {
			log.Fatalf("%s failed to start: %v", s.serviceName, err)
		}
	}()
}
  // GetHTTPServer returns the *http.Server created by Run, or nil if Run has
// not been called yet. Callers can use it, for example, to call Shutdown.
func (s *Server) GetHTTPServer() *http.Server {
	srv := s.httpServer
	return srv
}