| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511 |
- package dbs
-
import (
	"encoding/json"
	"fmt"
	"regexp"
	"sort"
	"strings"
	"time"

	"git.x2erp.com/qdy/go-svc-mcp/internal/mcp"
)
-
- func init() {
- mcp.Register("get_common_query", "通用数据库查询工具,支持参数化查询和结果列自定义",
- map[string]interface{}{
- "type": "object",
- "properties": map[string]interface{}{
- "query_sql": map[string]interface{}{
- "type": "string",
- "description": "基础SQL查询语句(可包含WHERE、ORDER BY,或使用where_clause和order_by参数)",
- },
- "where_clause": map[string]interface{}{
- "type": "string",
- "description": "WHERE条件语句(可选,如果query_sql中已包含则不需要)",
- "default": "",
- },
- "order_by": map[string]interface{}{
- "type": "string",
- "description": "ORDER BY排序语句(可选,如果query_sql中已包含则不需要)",
- "default": "",
- },
- "params": map[string]interface{}{
- "type": "object",
- "description": "查询参数键值对,用于参数化查询(例如 {\"status\": \"ACTIVE\", \"min_age\": 18})",
- "default": map[string]interface{}{},
- },
- "columns_config": map[string]interface{}{
- "type": "array",
- "description": "执行查询返回的列配置信息,定义字段显示名称和宽度",
- "items": map[string]interface{}{
- "type": "object",
- "properties": map[string]interface{}{
- "field_name": map[string]interface{}{
- "type": "string",
- "description": "查询返回的字段名称",
- },
- "display_name": map[string]interface{}{
- "type": "string",
- "description": "字段显示名称(中文)",
- "default": "",
- },
- "width": map[string]interface{}{
- "type": "integer",
- "description": "前端显示宽度(像素)",
- "default": 100,
- },
- "align": map[string]interface{}{
- "type": "string",
- "description": "对齐方式(left/center/right)",
- "default": "left",
- },
- },
- "required": []string{"field_name"},
- },
- "default": []interface{}{},
- },
- "page": map[string]interface{}{
- "type": "integer",
- "description": "页码(从1开始,0表示不分页)",
- "default": 0,
- "minimum": 0,
- },
- "page_size": map[string]interface{}{
- "type": "integer",
- "description": "每页记录数(最大100条)",
- "default": 10,
- "minimum": 1,
- "maximum": 100,
- },
- "schema": map[string]interface{}{
- "type": "string",
- "description": "数据库模式/名称(可选,部分数据库需要)",
- "default": "",
- },
- "include_total_count": map[string]interface{}{
- "type": "boolean",
- "description": "是否包含总记录数",
- "default": true,
- },
- "database_key": map[string]interface{}{
- "type": "string",
- "description": "数据库配置键名:warehouse(仓库数据库)或 business(业务数据库),可选,默认使用主数据库",
- "enum": []string{"warehouse", "business"},
- "default": "",
- },
- },
- "required": []string{"query_sql"},
- },
- func(input json.RawMessage, deps *mcp.ToolDependencies) (interface{}, error) {
- var params struct {
- QuerySQL string `json:"query_sql"`
- WhereClause string `json:"where_clause"`
- OrderBy string `json:"order_by"`
- Params map[string]interface{} `json:"params"`
- ColumnsConfig []map[string]interface{} `json:"columns_config"`
- Page int `json:"page"`
- PageSize int `json:"page_size"`
- Schema string `json:"schema"`
- IncludeTotalCount bool `json:"include_total_count"`
- DatabaseKey string `json:"database_key"`
- }
-
- if len(input) > 0 {
- if err := json.Unmarshal(input, ¶ms); err != nil {
- return nil, err
- }
- }
-
- // 设置默认值
- if params.PageSize == 0 {
- params.PageSize = 10
- }
- if params.PageSize > 100 {
- params.PageSize = 100
- }
-
- // 获取数据库工厂
- dbFactory, err := GetDBFactory(params.DatabaseKey, deps)
- if err != nil {
- return nil, err
- }
-
- // 获取数据库类型
- dbType := dbFactory.GetDBType()
-
- // 构建完整SQL
- finalSQL := params.QuerySQL
- if params.WhereClause != "" {
- // 检查query_sql是否已包含WHERE关键字
- if strings.Contains(strings.ToUpper(finalSQL), "WHERE") {
- finalSQL += " AND " + params.WhereClause
- } else {
- finalSQL += " WHERE " + params.WhereClause
- }
- }
- if params.OrderBy != "" {
- finalSQL += " ORDER BY " + params.OrderBy
- }
-
- // 根据数据库类型处理参数绑定
- processedSQL, queryParams, err := processParameters(finalSQL, params.Params, dbType)
- if err != nil {
- return nil, fmt.Errorf("参数处理失败: %v", err)
- }
-
- // 执行查询
- results, err := dbFactory.QuerySliceMapWithParams(processedSQL, queryParams...)
- if err != nil {
- return nil, fmt.Errorf("查询执行失败: %v", err)
- }
-
- // 处理列配置
- columnsMeta := processColumnsMetadata(results, params.ColumnsConfig)
-
- // 计算总记录数(如果需要)
- totalCount := int64(0)
- if params.IncludeTotalCount && params.Page > 0 {
- countSQL, countParams, err := buildCountSQL(finalSQL, params.Params, dbType)
- if err != nil {
- return nil, fmt.Errorf("构建计数SQL失败: %v", err)
- }
- countResults, err := dbFactory.QuerySliceMapWithParams(countSQL, countParams...)
- if err == nil && len(countResults) > 0 {
- if count, ok := countResults[0]["total_count"].(int64); ok {
- totalCount = count
- }
- }
- }
-
- // 应用分页(如果启用)
- var pagedResults []map[string]interface{}
- if params.Page > 0 && params.PageSize > 0 {
- offset := (params.Page - 1) * params.PageSize
- if offset < len(results) {
- end := offset + params.PageSize
- if end > len(results) {
- end = len(results)
- }
- pagedResults = results[offset:end]
- }
- } else {
- pagedResults = results
- }
-
- // 构建响应
- response := map[string]interface{}{
- "tenant_id": deps.ReqCtx.TenantID,
- "user_id": deps.ReqCtx.UserID,
- "database_type": dbType,
- "database_name": dbFactory.GetDatabaseName(),
- "schema": params.Schema,
- "original_sql": params.QuerySQL,
- "processed_sql": processedSQL,
- "parameters": params.Params,
- "columns_metadata": columnsMeta,
- "data": pagedResults,
- "data_count": len(pagedResults),
- "total_count": totalCount,
- "page": params.Page,
- "page_size": params.PageSize,
- "include_total_count": params.IncludeTotalCount,
- "timestamp": time.Now().Format(time.RFC3339),
- }
-
- // 如果启用了分页,添加分页信息
- if params.Page > 0 {
- totalPages := 0
- if totalCount > 0 {
- totalPages = int((totalCount + int64(params.PageSize) - 1) / int64(params.PageSize))
- }
- response["total_pages"] = totalPages
- response["has_more"] = int64(params.Page*params.PageSize) < totalCount
- }
-
- return response, nil
- },
- )
- }
-
- // processParameters 处理SQL参数绑定,根据数据库类型转换命名参数
- func processParameters(sql string, params map[string]interface{}, dbType string) (string, []interface{}, error) {
- if len(params) == 0 {
- return sql, []interface{}{}, nil
- }
-
- // 根据数据库类型选择参数处理器
- switch dbType {
- case "mysql", "doris":
- return processMySQLParameters(sql, params)
- case "postgresql":
- return processPostgreSQLParameters(sql, params)
- case "oracle":
- return processOracleParameters(sql, params)
- case "sqlserver":
- return processSQLServerParameters(sql, params)
- default:
- return sql, []interface{}{}, fmt.Errorf("不支持的数据库类型: %s", dbType)
- }
- }
-
// namedParamScanRe tokenises SQL for named-parameter substitution.
//
// Quoted string literals ('' escapes a quote), double-quoted identifiers, line
// and block comments, and PostgreSQL "::" casts are matched as opaque tokens
// with no capture group. A colon inside them is therefore never taken for a
// parameter; for example, col::text does not produce a parameter named "text".
// Only the last alternative captures a ":name" placeholder (group 1).
// Compiled once at package level instead of on every call.
var namedParamScanRe = regexp.MustCompile(
	`'(?:[^']|'')*'` + // '...' string literal
		`|"(?:[^"]|"")*"` + // "quoted identifier"
		`|--[^\n]*` + // -- line comment
		`|/\*(?s:.*?)\*/` + // /* block comment */
		`|::` + // PostgreSQL cast operator
		`|:([a-zA-Z_][a-zA-Z0-9_]*)`, // :name placeholder
)

// parseNamedParameters replaces :name placeholders in sql with
// placeholderFunc(n) and returns the rewritten SQL plus the bind values.
//
// Placeholders are numbered from 1 in order of each name's first appearance. A
// repeated name reuses its number, so each value appears once in the returned
// slice. A placeholder whose name is missing from params yields an error naming
// it; the first such name in left-to-right order is reported. SQL without
// placeholders is returned unchanged with an empty, non-nil value slice.
func parseNamedParameters(sql string, params map[string]interface{}, placeholderFunc func(int) string) (string, []interface{}, error) {
	paramIndex := make(map[string]int)
	var (
		paramValues []interface{}
		out         strings.Builder
		copied      int // end of the prefix of sql already written to out
	)
	out.Grow(len(sql))

	for _, loc := range namedParamScanRe.FindAllStringSubmatchIndex(sql, -1) {
		if loc[2] < 0 {
			// Literal, identifier, comment or cast: leave it verbatim.
			continue
		}
		name := sql[loc[2]:loc[3]]
		index, seen := paramIndex[name]
		if !seen {
			value, ok := params[name]
			if !ok {
				return "", nil, fmt.Errorf("参数 '%s' 未提供", name)
			}
			index = len(paramValues)
			paramIndex[name] = index
			paramValues = append(paramValues, value)
		}
		out.WriteString(sql[copied:loc[0]])
		out.WriteString(placeholderFunc(index + 1)) // placeholders are 1-based
		copied = loc[1]
	}

	if len(paramValues) == 0 {
		return sql, []interface{}{}, nil
	}
	out.WriteString(sql[copied:])
	return out.String(), paramValues, nil
}
-
// processColumnsMetadata describes the columns of a query result for display,
// merging caller-supplied columns_config entries over per-column defaults.
//
// Columns are taken from the keys of the first row. Rows are maps, so the
// SELECT-list order cannot be recovered here; instead the order is made
// deterministic. Columns named in columnsConfig come first, in configuration
// order, followed by the remaining columns sorted by name.
//
// Each column defaults to display_name = field name, width = 100 and
// align = "left". A config entry is matched on "field_name", and the first
// entry for a name wins. It overrides display_name and align when they are
// non-empty strings, and width when it is a positive number. Entries for fields
// absent from the result are ignored. An empty result yields an empty, non-nil
// slice.
func processColumnsMetadata(results []map[string]interface{}, columnsConfig []map[string]interface{}) []map[string]interface{} {
	if len(results) == 0 {
		return []map[string]interface{}{}
	}
	firstRow := results[0]

	// Index the config by field name once instead of rescanning it per column.
	configByField := make(map[string]map[string]interface{}, len(columnsConfig))
	order := make([]string, 0, len(firstRow))
	for _, cfg := range columnsConfig {
		name, ok := cfg["field_name"].(string)
		if !ok {
			continue
		}
		if _, dup := configByField[name]; dup {
			continue
		}
		configByField[name] = cfg
		if _, present := firstRow[name]; present {
			order = append(order, name)
		}
	}

	// Unconfigured columns follow, sorted so the output is stable across calls.
	unconfigured := make([]string, 0, len(firstRow))
	for name := range firstRow {
		if _, ok := configByField[name]; !ok {
			unconfigured = append(unconfigured, name)
		}
	}
	sort.Strings(unconfigured)
	order = append(order, unconfigured...)

	columns := make([]map[string]interface{}, 0, len(order))
	for _, name := range order {
		meta := map[string]interface{}{
			"field_name":   name,
			"display_name": name,
			"width":        100,
			"align":        "left",
		}
		if cfg, ok := configByField[name]; ok {
			if displayName, ok := cfg["display_name"].(string); ok && displayName != "" {
				meta["display_name"] = displayName
			}
			if width, ok := columnConfigWidth(cfg["width"]); ok {
				meta["width"] = width
			}
			if align, ok := cfg["align"].(string); ok && align != "" {
				meta["align"] = align
			}
		}
		columns = append(columns, meta)
	}
	return columns
}

// columnConfigWidth converts a configured width to int. Numbers decoded by
// encoding/json into interface{} arrive as float64; int and int64 are accepted
// too for Go callers. ok is false for non-numeric or non-positive values.
func columnConfigWidth(v interface{}) (int, bool) {
	switch w := v.(type) {
	case float64:
		return int(w), w > 0
	case int:
		return w, w > 0
	case int64:
		return int(w), w > 0
	default:
		return 0, false
	}
}
-
- // buildCountSQL 构建计数SQL
- func buildCountSQL(originalSQL string, params map[string]interface{}, dbType string) (string, []interface{}, error) {
- // 移除ORDER BY子句
- countSQL := regexp.MustCompile(`(?i)\s+ORDER BY\s+.*$`).ReplaceAllString(originalSQL, "")
-
- // 构建COUNT查询
- countSQL = fmt.Sprintf("SELECT COUNT(*) as total_count FROM (%s) as count_query", countSQL)
-
- // 处理参数
- return processParameters(countSQL, params, dbType)
- }
-
// mysqlParamScanRe tokenises MySQL/Doris SQL for named-parameter substitution.
//
// Quoted strings (backslash escapes and doubled quotes), backtick identifiers
// and comments are matched as opaque tokens with no capture group. A colon
// inside them is never taken for a parameter. Only the last alternative
// captures a ":name" placeholder (group 1). Compiled once at package level
// instead of on every call.
var mysqlParamScanRe = regexp.MustCompile(
	`'(?:[^'\\]|\\.|'')*'` + // '...' string literal
		`|"(?:[^"\\]|\\.|"")*"` + // "..." string literal
		"|`[^`]*`" + // `quoted identifier`
		`|#[^\n]*` + // # line comment
		`|--[ \t\r][^\n]*` + // -- line comment (MySQL requires whitespace after --)
		`|/\*(?s:.*?)\*/` + // /* block comment */
		`|:([a-zA-Z_][a-zA-Z0-9_]*)`, // :name placeholder
)

// processMySQLParameters binds :name placeholders for MySQL/Doris.
//
// Each placeholder occurrence becomes a positional "?". The returned values
// follow occurrence order, so a repeated name contributes its value once per
// occurrence. A placeholder whose name is missing from params yields an error
// naming the first such placeholder. With no params, or no placeholders, sql is
// returned unchanged with an empty, non-nil value slice.
func processMySQLParameters(sql string, params map[string]interface{}) (string, []interface{}, error) {
	if len(params) == 0 {
		return sql, []interface{}{}, nil
	}

	var (
		out    strings.Builder
		values []interface{}
		copied int // end of the prefix of sql already written to out
	)
	out.Grow(len(sql))

	for _, loc := range mysqlParamScanRe.FindAllStringSubmatchIndex(sql, -1) {
		if loc[2] < 0 {
			// Literal, identifier or comment: leave it verbatim.
			continue
		}
		name := sql[loc[2]:loc[3]]
		value, ok := params[name]
		if !ok {
			return "", nil, fmt.Errorf("参数 '%s' 未提供", name)
		}
		values = append(values, value)
		out.WriteString(sql[copied:loc[0]])
		out.WriteByte('?')
		copied = loc[1]
	}

	if len(values) == 0 {
		return sql, []interface{}{}, nil
	}
	out.WriteString(sql[copied:])
	return out.String(), values, nil
}
-
- // processPostgreSQLParameters 处理PostgreSQL数据库参数绑定
- func processPostgreSQLParameters(sql string, params map[string]interface{}) (string, []interface{}, error) {
- if len(params) == 0 {
- return sql, []interface{}{}, nil
- }
-
- // 正则匹配 :paramName 格式的参数
- re := regexp.MustCompile(`:([a-zA-Z_][a-zA-Z0-9_]*)`)
- matches := re.FindAllStringSubmatch(sql, -1)
- if len(matches) == 0 {
- return sql, []interface{}{}, nil
- }
-
- // 确定参数名到索引的映射(按第一次出现的顺序)
- paramIndex := make(map[string]int)
- var paramOrder []string
- for _, match := range matches {
- paramName := match[1]
- if _, exists := paramIndex[paramName]; !exists {
- paramIndex[paramName] = len(paramOrder)
- paramOrder = append(paramOrder, paramName)
- }
- }
-
- // 构建参数值列表(按索引顺序)
- paramValues := make([]interface{}, len(paramOrder))
- for i, paramName := range paramOrder {
- value, exists := params[paramName]
- if !exists {
- return "", nil, fmt.Errorf("参数 '%s' 未提供", paramName)
- }
- paramValues[i] = value
- }
-
- // 替换占位符为 $1, $2, ...
- replacedSQL := re.ReplaceAllStringFunc(sql, func(match string) string {
- paramName := match[1:]
- index := paramIndex[paramName]
- return fmt.Sprintf("$%d", index+1)
- })
-
- return replacedSQL, paramValues, nil
- }
-
- // processOracleParameters 处理Oracle数据库参数绑定
- func processOracleParameters(sql string, params map[string]interface{}) (string, []interface{}, error) {
- if len(params) == 0 {
- return sql, []interface{}{}, nil
- }
-
- // 正则匹配 :paramName 格式的参数
- re := regexp.MustCompile(`:([a-zA-Z_][a-zA-Z0-9_]*)`)
- matches := re.FindAllStringSubmatch(sql, -1)
- if len(matches) == 0 {
- return sql, []interface{}{}, nil
- }
-
- // 确定参数名到索引的映射(按第一次出现的顺序)
- paramIndex := make(map[string]int)
- var paramOrder []string
- for _, match := range matches {
- paramName := match[1]
- if _, exists := paramIndex[paramName]; !exists {
- paramIndex[paramName] = len(paramOrder)
- paramOrder = append(paramOrder, paramName)
- }
- }
-
- // 构建参数值列表(按索引顺序)
- paramValues := make([]interface{}, len(paramOrder))
- for i, paramName := range paramOrder {
- value, exists := params[paramName]
- if !exists {
- return "", nil, fmt.Errorf("参数 '%s' 未提供", paramName)
- }
- paramValues[i] = value
- }
-
- // 替换占位符为 :1, :2, ...(Oracle支持数字占位符)
- replacedSQL := re.ReplaceAllStringFunc(sql, func(match string) string {
- paramName := match[1:]
- index := paramIndex[paramName]
- return fmt.Sprintf(":%d", index+1)
- })
-
- return replacedSQL, paramValues, nil
- }
-
- // processSQLServerParameters 处理SQL Server数据库参数绑定
- func processSQLServerParameters(sql string, params map[string]interface{}) (string, []interface{}, error) {
- if len(params) == 0 {
- return sql, []interface{}{}, nil
- }
-
- // 正则匹配 :paramName 格式的参数
- re := regexp.MustCompile(`:([a-zA-Z_][a-zA-Z0-9_]*)`)
- matches := re.FindAllStringSubmatch(sql, -1)
- if len(matches) == 0 {
- return sql, []interface{}{}, nil
- }
-
- // 确定参数名到索引的映射(按第一次出现的顺序)
- paramIndex := make(map[string]int)
- var paramOrder []string
- for _, match := range matches {
- paramName := match[1]
- if _, exists := paramIndex[paramName]; !exists {
- paramIndex[paramName] = len(paramOrder)
- paramOrder = append(paramOrder, paramName)
- }
- }
-
- // 构建参数值列表(按索引顺序)
- paramValues := make([]interface{}, len(paramOrder))
- for i, paramName := range paramOrder {
- value, exists := params[paramName]
- if !exists {
- return "", nil, fmt.Errorf("参数 '%s' 未提供", paramName)
- }
- paramValues[i] = value
- }
-
- // 替换占位符为 @p1, @p2, ...(SQL Server支持命名参数)
- replacedSQL := re.ReplaceAllStringFunc(sql, func(match string) string {
- paramName := match[1:]
- index := paramIndex[paramName]
- return fmt.Sprintf("@p%d", index+1)
- })
-
- return replacedSQL, paramValues, nil
- }
|