| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111 |
- package mcp
-
import (
	"encoding/json"
	"fmt"
	"reflect"
	"sort"
	"sync"

	"git.x2erp.com/qdy/go-base/ctx"
	"git.x2erp.com/qdy/go-db/factory/database"
)
-
- // Tool 接口,工具实现此接口以支持自动注册
- type Tool interface {
- Name() string
- Description() string
- InputSchema() map[string]interface{}
- Execute(input json.RawMessage, dep *ToolDependencies) (interface{}, error)
- }
-
- // ToolDefinition 定义 MCP 工具
- type ToolDefinition struct {
- Name string `json:"name"`
- Description string `json:"description"`
- InputSchema map[string]interface{} `json:"inputSchema"`
- Execute ToolExecuteFunc `json:"-"`
- }
-
- // ToolExecuteFunc 工具执行函数签名
- type ToolExecuteFunc func(input json.RawMessage, dep *ToolDependencies) (interface{}, error)
-
- // ToolDependencies 工具执行依赖项
- type ToolDependencies struct {
- DBFactory *database.DBFactory
- ReqCtx *ctx.RequestContext
- }
-
- // globalRegistry 全局工具注册表
- var (
- globalRegistry = make(map[string]ToolDefinition)
- registryMu sync.RWMutex
- dependencies *ToolDependencies
- dependenciesOnce sync.Once
- )
-
- // Register 注册一个工具
- func Register(name, description string, inputSchema map[string]interface{}, execute ToolExecuteFunc) {
- registryMu.Lock()
- defer registryMu.Unlock()
-
- if _, exists := globalRegistry[name]; exists {
- panic(fmt.Sprintf("tool already registered: %s", name))
- }
-
- globalRegistry[name] = ToolDefinition{
- Name: name,
- Description: description,
- InputSchema: inputSchema,
- Execute: execute,
- }
- }
-
- // GetTool 获取工具定义
- func GetTool(name string) (ToolDefinition, bool) {
- registryMu.RLock()
- defer registryMu.RUnlock()
- tool, ok := globalRegistry[name]
- return tool, ok
- }
-
- // ListTools 返回所有工具定义
- func ListTools() []ToolDefinition {
- registryMu.RLock()
- defer registryMu.RUnlock()
- tools := make([]ToolDefinition, 0, len(globalRegistry))
- for _, tool := range globalRegistry {
- tools = append(tools, tool)
- }
- return tools
- }
-
// NOTE: SetDependencies (sets the global dependencies) is currently disabled;
// the implementation is kept below, commented out.
//
// func SetDependencies(dbFactory *database.DBFactory, reqCtx *ctx.RequestContext) {
// 	dependenciesOnce.Do(func() {
// 		dependencies = &ToolDependencies{
// 			DBFactory: dbFactory,
// 			ReqCtx:    reqCtx,
// 		}
// 	})
// }
-
- // GetDependencies 获取依赖项(如果已设置)
- func GetDependencies() *ToolDependencies {
- return dependencies
- }
-
- // AutoRegister 自动注册实现 Tool 接口的类型
- func AutoRegister(tool interface{}) {
- val := reflect.ValueOf(tool)
- typ := val.Type()
-
- // 检查是否实现了 Tool 接口
- if tool, ok := tool.(Tool); ok {
- Register(tool.Name(), tool.Description(), tool.InputSchema(), tool.Execute)
- return
- }
-
- // 检查是否具有适当方法的其他接口
- // 这里可以根据需要扩展
- panic(fmt.Sprintf("type %v does not implement Tool interface", typ))
- }
|