| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186 |
- package mcp
-
import (
	"context"
	"encoding/json"
	"fmt"
	"log"
	"net"
	"net/http"
	"os"
	"time"

	"git.x2erp.com/qdy/go-base/ctx"
	"git.x2erp.com/qdy/go-db/factory/database"
	mcpsdk "github.com/modelcontextprotocol/go-sdk/mcp"
)
-
- // Server 包装 MCP SDK 服务器,提供自动注册和依赖注入
- type Server struct {
- port int
- serviceName string
- sdkServer *mcpsdk.Server
- transport mcpsdk.Transport
- dbFactory *database.DBFactory
- baseCtx *ctx.RequestContext
- httpServer *http.Server
- handler http.Handler
- }
-
- // Config 服务器配置
- type Config struct {
- Name string
- Version string
- Port int
- ServiceName string
- Description string
- DBFactory *database.DBFactory
- BaseCtx *ctx.RequestContext
- }
-
// NewServer builds a Server from cfg. It creates the underlying MCP SDK
// server, using cfg.Name and cfg.Version as the implementation info, and then
// registers every tool returned by ListTools on it.
//
// cfg.Description is not read here. An error is returned, and no Server, if
// any tool fails to register.
func NewServer(cfg Config) (*Server, error) {
	s := &Server{
		port:        cfg.Port,
		serviceName: cfg.ServiceName,
		dbFactory:   cfg.DBFactory,
		baseCtx:     cfg.BaseCtx,
		sdkServer: mcpsdk.NewServer(&mcpsdk.Implementation{
			Name:    cfg.Name,
			Version: cfg.Version,
		}, nil),
	}

	err := s.registerAllTools()
	if err != nil {
		return nil, fmt.Errorf("failed to register tools: %w", err)
	}
	return s, nil
}
-
// registerAllTools registers every tool returned by ListTools with the MCP
// SDK server. It stops at, and returns, the first registration failure,
// wrapped with the tool's name. On success it logs how many tools were
// registered.
func (s *Server) registerAllTools() error {
	defs := ListTools()
	for _, def := range defs {
		err := s.registerTool(def)
		if err != nil {
			return fmt.Errorf("failed to register tool %s: %w", def.Name, err)
		}
	}

	log.Printf("Registered %d MCP tools", len(defs))
	return nil
}
-
// registerTool adds one tool definition to the MCP SDK server. The tool's
// name, description and input schema are copied into an mcpsdk.Tool, and
// calls are dispatched through the handler built by createToolHandler.
//
// mcpsdk.AddTool reports an unusable tool (for example an input schema it
// cannot accept) by panicking rather than by returning an error. That panic
// is recovered here and returned as an error, so registerAllTools and
// NewServer can surface it through their normal error path instead of
// crashing the process.
func (s *Server) registerTool(tool ToolDefinition) (err error) {
	mcpTool := &mcpsdk.Tool{
		Name:        tool.Name,
		Description: tool.Description,
		InputSchema: tool.InputSchema,
	}

	defer func() {
		if r := recover(); r != nil {
			err = fmt.Errorf("adding tool to MCP server: %v", r)
		}
	}()
	mcpsdk.AddTool(s.sdkServer, mcpTool, s.createToolHandler(tool))
	return nil
}
-
// createToolHandler adapts a ToolDefinition to the typed handler signature
// used by mcpsdk.AddTool.
//
// On every call the handler:
//   - re-encodes the decoded arguments map as JSON;
//   - builds ToolDependencies from the shared DB factory and a per-call
//     RequestContext (see extractRequestContext);
//   - invokes tool.Execute.
//
// Results are reported as follows:
//   - An Execute failure becomes a CallToolResult with IsError set and a text
//     message, not a Go (protocol) error.
//   - Only a JSON encoding failure is returned as a Go error.
//   - On success the CallToolResult is nil and Execute's value is returned
//     as the typed output.
//
// The incoming context.Context is ignored.
func (s *Server) createToolHandler(tool ToolDefinition) mcpsdk.ToolHandlerFor[map[string]interface{}, interface{}] {
	handle := func(_ context.Context, req *mcpsdk.CallToolRequest, args map[string]interface{}) (*mcpsdk.CallToolResult, interface{}, error) {
		raw, marshalErr := json.Marshal(args)
		if marshalErr != nil {
			return nil, nil, fmt.Errorf("failed to marshal input: %w", marshalErr)
		}

		deps := &ToolDependencies{
			DBFactory: s.dbFactory,
			ReqCtx:    s.extractRequestContext(req),
		}

		out, execErr := tool.Execute(json.RawMessage(raw), deps)
		if execErr == nil {
			return nil, out, nil
		}

		// Report the tool's own failure as an error result rather than as a
		// protocol-level error.
		failure := &mcpsdk.CallToolResult{IsError: true}
		failure.Content = []mcpsdk.Content{
			&mcpsdk.TextContent{Text: fmt.Sprintf("tool error: %v", execErr)},
		}
		return failure, nil, nil
	}
	return handle
}
-
// extractRequestContext builds the RequestContext for a single tool call.
//
// It starts from a copy of s.baseCtx, or from a zero RequestContext when
// baseCtx is nil. This is a struct value copy, so any reference-typed fields
// RequestContext may have are shared with baseCtx.
//
// When the call carries HTTP headers, ProjectID is overwritten with the value
// of the project-ID header. The header name comes from the
// MCP_PROJECT_ID_HEADER environment variable and defaults to "X-Project-ID".
// A missing or empty header leaves ProjectID as copied from the base.
//
// A nil request is tolerated and yields just the base copy.
func (s *Server) extractRequestContext(request *mcpsdk.CallToolRequest) *ctx.RequestContext {
	reqCtx := &ctx.RequestContext{}
	if s.baseCtx != nil {
		*reqCtx = *s.baseCtx
	}

	if request == nil {
		return reqCtx
	}
	extra := request.GetExtra()
	if extra == nil || extra.Header == nil {
		return reqCtx
	}

	projectIDHeader := os.Getenv("MCP_PROJECT_ID_HEADER")
	if projectIDHeader == "" {
		projectIDHeader = "X-Project-ID"
	}
	if projectID := extra.Header.Get(projectIDHeader); projectID != "" {
		// Caller-supplied project ID goes into RequestContext.ProjectID.
		reqCtx.ProjectID = projectID
	}
	return reqCtx
}
-
- // SetTransport 设置传输层
- func (s *Server) SetTransport(transport mcpsdk.Transport) {
- s.transport = transport
- }
-
- // GetSDKServer 返回底层的 SDK 服务器实例
- func (s *Server) GetSDKServer() *mcpsdk.Server {
- return s.sdkServer
- }
-
// Run serves handler over HTTP on all interfaces at s.port. It returns once
// the listening socket is open; requests are then served on a background
// goroutine. The created *http.Server is kept so callers can fetch it with
// GetHTTPServer, for example to call Shutdown.
//
// Failure to open the listener, or any serve error other than
// http.ErrServerClosed, terminates the process via log.Fatalf.
func (s *Server) Run(handler http.Handler) {
	s.handler = handler
	addr := fmt.Sprintf(":%d", s.port)

	s.httpServer = &http.Server{
		Addr:    addr,
		Handler: s.handler,
		// Bound how long a client may take to send request headers, to
		// protect against slow-header clients. ReadTimeout and WriteTimeout
		// are deliberately left unset: they bound the whole request or
		// response and could cut off long-lived streaming responses.
		ReadHeaderTimeout: 10 * time.Second,
	}

	// Bind synchronously. A bad port or an address already in use is then
	// reported before Run returns, and the server is accepting connections
	// by the time the caller continues.
	ln, err := net.Listen("tcp", addr)
	if err != nil {
		log.Fatalf("%s failed to start: %v", s.serviceName, err)
	}
	log.Printf("%s listening on %s", s.serviceName, addr)

	srv := s.httpServer
	go func() {
		// Serve returns http.ErrServerClosed after Shutdown or Close; that is
		// a normal stop, not a failure.
		if err := srv.Serve(ln); err != nil && err != http.ErrServerClosed {
			log.Fatalf("%s failed to start: %v", s.serviceName, err)
		}
	}()
}
-
// GetHTTPServer returns the *http.Server created by Run, or nil if Run has
// not been called yet (within this file only Run assigns it). Callers can use
// it for graceful shutdown.
func (s *Server) GetHTTPServer() *http.Server {
	srv := s.httpServer
	return srv
}
|