package mcp import ( "context" "encoding/json" "fmt" "log" "net/http" "git.x2erp.com/qdy/go-base/ctx" "git.x2erp.com/qdy/go-db/factory/database" mcpsdk "github.com/modelcontextprotocol/go-sdk/mcp" ) // Server 包装 MCP SDK 服务器,提供自动注册和依赖注入 type Server struct { port int serviceName string sdkServer *mcpsdk.Server transport mcpsdk.Transport dbFactory *database.DBFactory //baseCtx *ctx.RequestContext httpServer *http.Server handler http.Handler } // Config 服务器配置 type Config struct { Name string Version string Port int ServiceName string Description string DBFactory *database.DBFactory //BaseCtx *ctx.RequestContext } // NewServer 创建新的 MCP 服务器 func NewServer(cfg Config) (*Server, error) { impl := &mcpsdk.Implementation{ Name: cfg.Name, Version: cfg.Version, } sdkServer := mcpsdk.NewServer(impl, nil) server := &Server{ sdkServer: sdkServer, dbFactory: cfg.DBFactory, //baseCtx: cfg.BaseCtx, port: cfg.Port, serviceName: cfg.ServiceName, } // 自动注册所有已注册的工具 if err := server.registerAllTools(); err != nil { return nil, fmt.Errorf("failed to register tools: %w", err) } return server, nil } // registerAllTools 将注册表中的所有工具注册到 MCP 服务器 func (s *Server) registerAllTools() error { tools := ListTools() for _, tool := range tools { if err := s.registerTool(tool); err != nil { return fmt.Errorf("failed to register tool %s: %w", tool.Name, err) } } log.Printf("Registered %d MCP tools", len(tools)) return nil } // registerTool 注册单个工具到 MCP 服务器 func (s *Server) registerTool(tool ToolDefinition) error { // 创建工具处理器 handler := s.createToolHandler(tool) // 创建 MCP 工具 mcpTool := &mcpsdk.Tool{ Name: tool.Name, Description: tool.Description, InputSchema: tool.InputSchema, } // 注册工具到服务器 mcpsdk.AddTool(s.sdkServer, mcpTool, handler) return nil } // createToolHandler 创建 MCP 工具处理器 func (s *Server) createToolHandler(tool ToolDefinition) mcpsdk.ToolHandlerFor[map[string]interface{}, interface{}] { return func(ctxMcp context.Context, request *mcpsdk.CallToolRequest, input map[string]interface{}) (*mcpsdk.CallToolResult, interface{}, error) { // 将输入转换为 JSON inputJSON, err := json.Marshal(input) if err != nil { return nil, nil, fmt.Errorf("failed to marshal input: %w", err) } // 提取请求上下文信息 //reqCtx := s.extractRequestContext(request) reqCtx := ctx.FromContext(ctxMcp) // 创建工具依赖项 toolDeps := &ToolDependencies{ DBFactory: s.dbFactory, ReqCtx: reqCtx, } // 执行工具 result, err := tool.Execute(json.RawMessage(inputJSON), toolDeps) if err != nil { // 返回工具错误(非协议错误) return &mcpsdk.CallToolResult{ IsError: true, Content: []mcpsdk.Content{ &mcpsdk.TextContent{Text: fmt.Sprintf("tool error: %v", err)}, }, }, nil, nil } // 返回成功结果 return nil, result, nil } } // // extractRequestContext 从 MCP 请求中提取上下文信息 // func (s *Server) extractRequestContext(request *mcpsdk.CallToolRequest) *ctx.RequestContext { // reqCtx := &ctx.RequestContext{} // if s.baseCtx != nil { // // 复制基础上下文 // *reqCtx = *s.baseCtx // } // // 从请求的 Extra 数据中提取自定义项目 ID // extra := request.GetExtra() // if extra != nil && extra.Header != nil { // // 确定项目 ID 头名称 // projectIDHeader := os.Getenv("MCP_PROJECT_ID_HEADER") // if projectIDHeader == "" { // projectIDHeader = "X-Project-ID" // } // if projectID := extra.Header.Get(projectIDHeader); projectID != "" { // // 将项目 ID 存储在 TraceID 中 // reqCtx.ProjectID = projectID // } // } // return reqCtx // } // SetTransport 设置传输层 func (s *Server) SetTransport(transport mcpsdk.Transport) { s.transport = transport } // GetSDKServer 返回底层的 SDK 服务器实例 func (s *Server) GetSDKServer() *mcpsdk.Server { return s.sdkServer } // startHTTPServer 启动 HTTP 服务器 func (s *Server) Run(handler http.Handler) { s.handler = handler addr := fmt.Sprintf(":%d", s.port) s.httpServer = &http.Server{ Addr: addr, Handler: s.handler, } log.Printf("%s listening on %s", s.serviceName, addr) // 在 goroutine 中启动服务器 go func() { if err := s.httpServer.ListenAndServe(); err != nil && err != http.ErrServerClosed { log.Fatalf("%s failed to start: %v", s.serviceName, err) } }() } // GetHTTPServer 返回内部的 HTTP 服务器实例 func (s *Server) GetHTTPServer() *http.Server { return s.httpServer }