Aucune description
Vous ne pouvez pas sélectionner plus de 25 sujets. Les noms de sujets doivent commencer par une lettre ou un nombre, peuvent contenir des tirets ('-') et peuvent comporter jusqu'à 35 caractères.

registry.go 2.8KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111
  1. package mcp
import (
	"encoding/json"
	"fmt"
	"reflect"
	"sort"
	"sync"

	"git.x2erp.com/qdy/go-base/ctx"
	"git.x2erp.com/qdy/go-db/factory/database"
)
  10. // Tool 接口,工具实现此接口以支持自动注册
  11. type Tool interface {
  12. Name() string
  13. Description() string
  14. InputSchema() map[string]interface{}
  15. Execute(input json.RawMessage, dep *ToolDependencies) (interface{}, error)
  16. }
  17. // ToolDefinition 定义 MCP 工具
  18. type ToolDefinition struct {
  19. Name string `json:"name"`
  20. Description string `json:"description"`
  21. InputSchema map[string]interface{} `json:"inputSchema"`
  22. Execute ToolExecuteFunc `json:"-"`
  23. }
  24. // ToolExecuteFunc 工具执行函数签名
  25. type ToolExecuteFunc func(input json.RawMessage, dep *ToolDependencies) (interface{}, error)
  26. // ToolDependencies 工具执行依赖项
  27. type ToolDependencies struct {
  28. DBFactory *database.DBFactory
  29. ReqCtx *ctx.RequestContext
  30. }
  31. // globalRegistry 全局工具注册表
  32. var (
  33. globalRegistry = make(map[string]ToolDefinition)
  34. registryMu sync.RWMutex
  35. dependencies *ToolDependencies
  36. dependenciesOnce sync.Once
  37. )
  38. // Register 注册一个工具
  39. func Register(name, description string, inputSchema map[string]interface{}, execute ToolExecuteFunc) {
  40. registryMu.Lock()
  41. defer registryMu.Unlock()
  42. if _, exists := globalRegistry[name]; exists {
  43. panic(fmt.Sprintf("tool already registered: %s", name))
  44. }
  45. globalRegistry[name] = ToolDefinition{
  46. Name: name,
  47. Description: description,
  48. InputSchema: inputSchema,
  49. Execute: execute,
  50. }
  51. }
  52. // GetTool 获取工具定义
  53. func GetTool(name string) (ToolDefinition, bool) {
  54. registryMu.RLock()
  55. defer registryMu.RUnlock()
  56. tool, ok := globalRegistry[name]
  57. return tool, ok
  58. }
  59. // ListTools 返回所有工具定义
  60. func ListTools() []ToolDefinition {
  61. registryMu.RLock()
  62. defer registryMu.RUnlock()
  63. tools := make([]ToolDefinition, 0, len(globalRegistry))
  64. for _, tool := range globalRegistry {
  65. tools = append(tools, tool)
  66. }
  67. return tools
  68. }
  69. // // SetDependencies 设置全局依赖项
  70. // func SetDependencies(dbFactory *database.DBFactory, reqCtx *ctx.RequestContext) {
  71. // dependenciesOnce.Do(func() {
  72. // dependencies = &ToolDependencies{
  73. // DBFactory: dbFactory,
  74. // ReqCtx: reqCtx,
  75. // }
  76. // })
  77. // }
  78. // GetDependencies 获取依赖项(如果已设置)
  79. func GetDependencies() *ToolDependencies {
  80. return dependencies
  81. }
  82. // AutoRegister 自动注册实现 Tool 接口的类型
  83. func AutoRegister(tool interface{}) {
  84. val := reflect.ValueOf(tool)
  85. typ := val.Type()
  86. // 检查是否实现了 Tool 接口
  87. if tool, ok := tool.(Tool); ok {
  88. Register(tool.Name(), tool.Description(), tool.InputSchema(), tool.Execute)
  89. return
  90. }
  91. // 检查是否具有适当方法的其他接口
  92. // 这里可以根据需要扩展
  93. panic(fmt.Sprintf("type %v does not implement Tool interface", typ))
  94. }