Aucune description
Vous ne pouvez pas sélectionner plus de 25 sujets. Les noms de sujets doivent commencer par une lettre ou un chiffre, peuvent contenir des tirets ('-') et peuvent comporter jusqu'à 35 caractères.

registry.go 2.8KB

Lignes 1 à 112
  1. package mcp
  import (
	"encoding/json"
	"fmt"
	"reflect"
	"sort"
	"sync"

	"git.x2erp.com/qdy/go-base/ctx"
	"git.x2erp.com/qdy/go-db/factory/database"
)
  // Tool is the interface a tool type implements so that AutoRegister can
// register it automatically: AutoRegister passes the results of Name,
// Description and InputSchema, plus the Execute method value, to Register.
type Tool interface {
	// Name returns the name the tool is registered under; GetTool looks
	// tools up by this same key.
	Name() string
	// Description returns the tool's description, stored in
	// ToolDefinition.Description.
	Description() string
	// InputSchema returns the tool's input schema, stored in
	// ToolDefinition.InputSchema (serialized as "inputSchema").
	// NOTE(review): presumably a JSON Schema object per MCP — confirm.
	InputSchema() map[string]interface{}
	// Execute runs the tool on its raw JSON input with the given
	// dependencies; its signature matches ToolExecuteFunc.
	Execute(input json.RawMessage, dep *ToolDependencies) (interface{}, error)
}
  // ToolDefinition describes an MCP tool as stored in the registry.
//
// Name, Description and InputSchema are serialized to JSON as "name",
// "description" and "inputSchema"; Execute is excluded from JSON output
// (tag "-") and is the function invoked to run the tool.
type ToolDefinition struct {
	Name        string                 `json:"name"`
	Description string                 `json:"description"`
	InputSchema map[string]interface{} `json:"inputSchema"`
	Execute     ToolExecuteFunc        `json:"-"`
}
  24. // ToolExecuteFunc 工具执行函数签名
  25. type ToolExecuteFunc func(input json.RawMessage, dep *ToolDependencies) (interface{}, error)
  // ToolDependencies bundles the values handed to a tool's Execute function.
//
// DBFactory and DBSFactory come from the go-db database package and ReqCtx
// is a go-base ctx.RequestContext.
// NOTE(review): the difference between DBFactory and DBSFactory, and whether
// any field may be nil at execution time, is not visible in this file —
// confirm with the callers that build this struct.
type ToolDependencies struct {
	DBFactory  *database.DBFactory
	DBSFactory *database.DBSFactory
	ReqCtx     *ctx.RequestContext
}
  32. // globalRegistry 全局工具注册表
  33. var (
  34. globalRegistry = make(map[string]ToolDefinition)
  35. registryMu sync.RWMutex
  36. dependencies *ToolDependencies
  37. //dependenciesOnce sync.Once
  38. )
  39. // Register 注册一个工具
  40. func Register(name, description string, inputSchema map[string]interface{}, execute ToolExecuteFunc) {
  41. registryMu.Lock()
  42. defer registryMu.Unlock()
  43. if _, exists := globalRegistry[name]; exists {
  44. panic(fmt.Sprintf("tool already registered: %s", name))
  45. }
  46. globalRegistry[name] = ToolDefinition{
  47. Name: name,
  48. Description: description,
  49. InputSchema: inputSchema,
  50. Execute: execute,
  51. }
  52. }
  53. // GetTool 获取工具定义
  54. func GetTool(name string) (ToolDefinition, bool) {
  55. registryMu.RLock()
  56. defer registryMu.RUnlock()
  57. tool, ok := globalRegistry[name]
  58. return tool, ok
  59. }
  60. // ListTools 返回所有工具定义
  61. func ListTools() []ToolDefinition {
  62. registryMu.RLock()
  63. defer registryMu.RUnlock()
  64. tools := make([]ToolDefinition, 0, len(globalRegistry))
  65. for _, tool := range globalRegistry {
  66. tools = append(tools, tool)
  67. }
  68. return tools
  69. }
  70. // // SetDependencies 设置全局依赖项
  71. // func SetDependencies(dbFactory *database.DBFactory, reqCtx *ctx.RequestContext) {
  72. // dependenciesOnce.Do(func() {
  73. // dependencies = &ToolDependencies{
  74. // DBFactory: dbFactory,
  75. // ReqCtx: reqCtx,
  76. // }
  77. // })
  78. // }
  79. // GetDependencies 获取依赖项(如果已设置)
  80. func GetDependencies() *ToolDependencies {
  81. return dependencies
  82. }
  83. // AutoRegister 自动注册实现 Tool 接口的类型
  84. func AutoRegister(tool interface{}) {
  85. val := reflect.ValueOf(tool)
  86. typ := val.Type()
  87. // 检查是否实现了 Tool 接口
  88. if tool, ok := tool.(Tool); ok {
  89. Register(tool.Name(), tool.Description(), tool.InputSchema(), tool.Execute)
  90. return
  91. }
  92. // 检查是否具有适当方法的其他接口
  93. // 这里可以根据需要扩展
  94. panic(fmt.Sprintf("type %v does not implement Tool interface", typ))
  95. }