package dbs import ( "encoding/json" "fmt" "regexp" "strings" "time" "git.x2erp.com/qdy/go-svc-mcp/internal/mcp" ) func init() { mcp.Register("get_common_query", "通用数据库查询工具,支持参数化查询和结果列自定义", map[string]interface{}{ "type": "object", "properties": map[string]interface{}{ "query_sql": map[string]interface{}{ "type": "string", "description": "基础SQL查询语句(可包含WHERE、ORDER BY,或使用where_clause和order_by参数)", }, "where_clause": map[string]interface{}{ "type": "string", "description": "WHERE条件语句(可选,如果query_sql中已包含则不需要)", "default": "", }, "order_by": map[string]interface{}{ "type": "string", "description": "ORDER BY排序语句(可选,如果query_sql中已包含则不需要)", "default": "", }, "params": map[string]interface{}{ "type": "object", "description": "查询参数键值对,用于参数化查询(例如 {\"status\": \"ACTIVE\", \"min_age\": 18})", "default": map[string]interface{}{}, }, "columns_config": map[string]interface{}{ "type": "array", "description": "执行查询返回的列配置信息,定义字段显示名称和宽度", "items": map[string]interface{}{ "type": "object", "properties": map[string]interface{}{ "field_name": map[string]interface{}{ "type": "string", "description": "查询返回的字段名称", }, "display_name": map[string]interface{}{ "type": "string", "description": "字段显示名称(中文)", "default": "", }, "width": map[string]interface{}{ "type": "integer", "description": "前端显示宽度(像素)", "default": 100, }, "align": map[string]interface{}{ "type": "string", "description": "对齐方式(left/center/right)", "default": "left", }, }, "required": []string{"field_name"}, }, "default": []interface{}{}, }, "page": map[string]interface{}{ "type": "integer", "description": "页码(从1开始,0表示不分页)", "default": 0, "minimum": 0, }, "page_size": map[string]interface{}{ "type": "integer", "description": "每页记录数(最大100条)", "default": 10, "minimum": 1, "maximum": 100, }, "schema": map[string]interface{}{ "type": "string", "description": "数据库模式/名称(可选,部分数据库需要)", "default": "", }, "include_total_count": map[string]interface{}{ "type": "boolean", "description": "是否包含总记录数", "default": true, }, "database_key": map[string]interface{}{ "type": "string", "description": "数据库配置键名:warehouse(仓库数据库)或 business(业务数据库),可选,默认使用主数据库", "enum": []string{"warehouse", "business"}, "default": "", }, }, "required": []string{"query_sql"}, }, func(input json.RawMessage, deps *mcp.ToolDependencies) (interface{}, error) { var params struct { QuerySQL string `json:"query_sql"` WhereClause string `json:"where_clause"` OrderBy string `json:"order_by"` Params map[string]interface{} `json:"params"` ColumnsConfig []map[string]interface{} `json:"columns_config"` Page int `json:"page"` PageSize int `json:"page_size"` Schema string `json:"schema"` IncludeTotalCount bool `json:"include_total_count"` DatabaseKey string `json:"database_key"` } if len(input) > 0 { if err := json.Unmarshal(input, ¶ms); err != nil { return nil, err } } // 设置默认值 if params.PageSize == 0 { params.PageSize = 10 } if params.PageSize > 100 { params.PageSize = 100 } // 获取数据库工厂 dbFactory, err := GetDBFactory(params.DatabaseKey, deps) if err != nil { return nil, err } // 获取数据库类型 dbType := dbFactory.GetDBType() // 构建完整SQL finalSQL := params.QuerySQL if params.WhereClause != "" { // 检查query_sql是否已包含WHERE关键字 if strings.Contains(strings.ToUpper(finalSQL), "WHERE") { finalSQL += " AND " + params.WhereClause } else { finalSQL += " WHERE " + params.WhereClause } } if params.OrderBy != "" { finalSQL += " ORDER BY " + params.OrderBy } // 根据数据库类型处理参数绑定 processedSQL, queryParams, err := processParameters(finalSQL, params.Params, dbType) if err != nil { return nil, fmt.Errorf("参数处理失败: %v", err) } // 执行查询 results, err := dbFactory.QuerySliceMapWithParams(processedSQL, queryParams...) if err != nil { return nil, fmt.Errorf("查询执行失败: %v", err) } // 处理列配置 columnsMeta := processColumnsMetadata(results, params.ColumnsConfig) // 计算总记录数(如果需要) totalCount := int64(0) if params.IncludeTotalCount && params.Page > 0 { countSQL, countParams, err := buildCountSQL(finalSQL, params.Params, dbType) if err != nil { return nil, fmt.Errorf("构建计数SQL失败: %v", err) } countResults, err := dbFactory.QuerySliceMapWithParams(countSQL, countParams...) if err == nil && len(countResults) > 0 { if count, ok := countResults[0]["total_count"].(int64); ok { totalCount = count } } } // 应用分页(如果启用) var pagedResults []map[string]interface{} if params.Page > 0 && params.PageSize > 0 { offset := (params.Page - 1) * params.PageSize if offset < len(results) { end := offset + params.PageSize if end > len(results) { end = len(results) } pagedResults = results[offset:end] } } else { pagedResults = results } // 构建响应 response := map[string]interface{}{ "tenant_id": deps.ReqCtx.TenantID, "user_id": deps.ReqCtx.UserID, "database_type": dbType, "database_name": dbFactory.GetDatabaseName(), "schema": params.Schema, "original_sql": params.QuerySQL, "processed_sql": processedSQL, "parameters": params.Params, "columns_metadata": columnsMeta, "data": pagedResults, "data_count": len(pagedResults), "total_count": totalCount, "page": params.Page, "page_size": params.PageSize, "include_total_count": params.IncludeTotalCount, "timestamp": time.Now().Format(time.RFC3339), } // 如果启用了分页,添加分页信息 if params.Page > 0 { totalPages := 0 if totalCount > 0 { totalPages = int((totalCount + int64(params.PageSize) - 1) / int64(params.PageSize)) } response["total_pages"] = totalPages response["has_more"] = int64(params.Page*params.PageSize) < totalCount } return response, nil }, ) } // processParameters 处理SQL参数绑定,根据数据库类型转换命名参数 func processParameters(sql string, params map[string]interface{}, dbType string) (string, []interface{}, error) { if len(params) == 0 { return sql, []interface{}{}, nil } // 根据数据库类型选择参数处理器 switch dbType { case "mysql", "doris": return processMySQLParameters(sql, params) case "postgresql": return processPostgreSQLParameters(sql, params) case "oracle": return processOracleParameters(sql, params) case "sqlserver": return processSQLServerParameters(sql, params) default: return sql, []interface{}{}, fmt.Errorf("不支持的数据库类型: %s", dbType) } } // parseNamedParameters 解析命名参数,返回替换后的SQL和有序参数值 func parseNamedParameters(sql string, params map[string]interface{}, placeholderFunc func(int) string) (string, []interface{}, error) { // 正则匹配 :paramName 格式的参数 re := regexp.MustCompile(`:([a-zA-Z_][a-zA-Z0-9_]*)`) matches := re.FindAllStringSubmatch(sql, -1) if len(matches) == 0 { return sql, []interface{}{}, nil } // 确定参数名到索引的映射(按第一次出现的顺序) paramIndex := make(map[string]int) var paramOrder []string for _, match := range matches { paramName := match[1] if _, exists := paramIndex[paramName]; !exists { paramIndex[paramName] = len(paramOrder) paramOrder = append(paramOrder, paramName) } } // 构建参数值列表(按索引顺序) paramValues := make([]interface{}, len(paramOrder)) for i, paramName := range paramOrder { value, exists := params[paramName] if !exists { return "", nil, fmt.Errorf("参数 '%s' 未提供", paramName) } paramValues[i] = value } // 替换占位符 replacedSQL := re.ReplaceAllStringFunc(sql, func(match string) string { paramName := match[1:] index := paramIndex[paramName] return placeholderFunc(index + 1) // 通常占位符从1开始 }) return replacedSQL, paramValues, nil } // processColumnsMetadata 处理列元数据,合并查询结果和列配置 func processColumnsMetadata(results []map[string]interface{}, columnsConfig []map[string]interface{}) []map[string]interface{} { if len(results) == 0 { return []map[string]interface{}{} } // 从第一条结果中提取列名 var columns []map[string]interface{} if len(results) > 0 { for fieldName := range results[0] { columnMeta := map[string]interface{}{ "field_name": fieldName, "display_name": fieldName, "width": 100, "align": "left", } // 查找列配置 for _, config := range columnsConfig { if configFieldName, ok := config["field_name"].(string); ok && configFieldName == fieldName { if displayName, ok := config["display_name"].(string); ok && displayName != "" { columnMeta["display_name"] = displayName } if width, ok := config["width"].(float64); ok && width > 0 { columnMeta["width"] = int(width) } if align, ok := config["align"].(string); ok && align != "" { columnMeta["align"] = align } break } } columns = append(columns, columnMeta) } } return columns } // buildCountSQL 构建计数SQL func buildCountSQL(originalSQL string, params map[string]interface{}, dbType string) (string, []interface{}, error) { // 移除ORDER BY子句 countSQL := regexp.MustCompile(`(?i)\s+ORDER BY\s+.*$`).ReplaceAllString(originalSQL, "") // 构建COUNT查询 countSQL = fmt.Sprintf("SELECT COUNT(*) as total_count FROM (%s) as count_query", countSQL) // 处理参数 return processParameters(countSQL, params, dbType) } // processMySQLParameters 处理MySQL/Doris数据库参数绑定 func processMySQLParameters(sql string, params map[string]interface{}) (string, []interface{}, error) { if len(params) == 0 { return sql, []interface{}{}, nil } // 正则匹配 :paramName 格式的参数 re := regexp.MustCompile(`:([a-zA-Z_][a-zA-Z0-9_]*)`) matches := re.FindAllStringSubmatch(sql, -1) if len(matches) == 0 { return sql, []interface{}{}, nil } // 按出现顺序收集参数值(允许重复) var paramValues []interface{} replacedSQL := re.ReplaceAllStringFunc(sql, func(match string) string { paramName := match[1:] value, exists := params[paramName] if !exists { // 如果参数未提供,保留原占位符(后续会报错) return match } paramValues = append(paramValues, value) return "?" }) // 检查是否有参数未提供 if len(paramValues) != len(matches) { // 找出未提供的参数名 for _, match := range matches { paramName := match[1] if _, exists := params[paramName]; !exists { return "", nil, fmt.Errorf("参数 '%s' 未提供", paramName) } } } return replacedSQL, paramValues, nil } // processPostgreSQLParameters 处理PostgreSQL数据库参数绑定 func processPostgreSQLParameters(sql string, params map[string]interface{}) (string, []interface{}, error) { if len(params) == 0 { return sql, []interface{}{}, nil } // 正则匹配 :paramName 格式的参数 re := regexp.MustCompile(`:([a-zA-Z_][a-zA-Z0-9_]*)`) matches := re.FindAllStringSubmatch(sql, -1) if len(matches) == 0 { return sql, []interface{}{}, nil } // 确定参数名到索引的映射(按第一次出现的顺序) paramIndex := make(map[string]int) var paramOrder []string for _, match := range matches { paramName := match[1] if _, exists := paramIndex[paramName]; !exists { paramIndex[paramName] = len(paramOrder) paramOrder = append(paramOrder, paramName) } } // 构建参数值列表(按索引顺序) paramValues := make([]interface{}, len(paramOrder)) for i, paramName := range paramOrder { value, exists := params[paramName] if !exists { return "", nil, fmt.Errorf("参数 '%s' 未提供", paramName) } paramValues[i] = value } // 替换占位符为 $1, $2, ... replacedSQL := re.ReplaceAllStringFunc(sql, func(match string) string { paramName := match[1:] index := paramIndex[paramName] return fmt.Sprintf("$%d", index+1) }) return replacedSQL, paramValues, nil } // processOracleParameters 处理Oracle数据库参数绑定 func processOracleParameters(sql string, params map[string]interface{}) (string, []interface{}, error) { if len(params) == 0 { return sql, []interface{}{}, nil } // 正则匹配 :paramName 格式的参数 re := regexp.MustCompile(`:([a-zA-Z_][a-zA-Z0-9_]*)`) matches := re.FindAllStringSubmatch(sql, -1) if len(matches) == 0 { return sql, []interface{}{}, nil } // 确定参数名到索引的映射(按第一次出现的顺序) paramIndex := make(map[string]int) var paramOrder []string for _, match := range matches { paramName := match[1] if _, exists := paramIndex[paramName]; !exists { paramIndex[paramName] = len(paramOrder) paramOrder = append(paramOrder, paramName) } } // 构建参数值列表(按索引顺序) paramValues := make([]interface{}, len(paramOrder)) for i, paramName := range paramOrder { value, exists := params[paramName] if !exists { return "", nil, fmt.Errorf("参数 '%s' 未提供", paramName) } paramValues[i] = value } // 替换占位符为 :1, :2, ...(Oracle支持数字占位符) replacedSQL := re.ReplaceAllStringFunc(sql, func(match string) string { paramName := match[1:] index := paramIndex[paramName] return fmt.Sprintf(":%d", index+1) }) return replacedSQL, paramValues, nil } // processSQLServerParameters 处理SQL Server数据库参数绑定 func processSQLServerParameters(sql string, params map[string]interface{}) (string, []interface{}, error) { if len(params) == 0 { return sql, []interface{}{}, nil } // 正则匹配 :paramName 格式的参数 re := regexp.MustCompile(`:([a-zA-Z_][a-zA-Z0-9_]*)`) matches := re.FindAllStringSubmatch(sql, -1) if len(matches) == 0 { return sql, []interface{}{}, nil } // 确定参数名到索引的映射(按第一次出现的顺序) paramIndex := make(map[string]int) var paramOrder []string for _, match := range matches { paramName := match[1] if _, exists := paramIndex[paramName]; !exists { paramIndex[paramName] = len(paramOrder) paramOrder = append(paramOrder, paramName) } } // 构建参数值列表(按索引顺序) paramValues := make([]interface{}, len(paramOrder)) for i, paramName := range paramOrder { value, exists := params[paramName] if !exists { return "", nil, fmt.Errorf("参数 '%s' 未提供", paramName) } paramValues[i] = value } // 替换占位符为 @p1, @p2, ...(SQL Server支持命名参数) replacedSQL := re.ReplaceAllStringFunc(sql, func(match string) string { paramName := match[1:] index := paramIndex[paramName] return fmt.Sprintf("@p%d", index+1) }) return replacedSQL, paramValues, nil }