package dbs import ( "encoding/json" "fmt" "strings" "time" "git.x2erp.com/qdy/go-svc-mcp/internal/mcp" ) func init() { mcp.Register("get_oracle_create_table", "在Oracle数据库中创建新表,支持字段约束和注释", map[string]interface{}{ "type": "object", "properties": map[string]interface{}{ "table_name": map[string]interface{}{ "type": "string", "description": "表名称", }, "table_description": map[string]interface{}{ "type": "string", "description": "表描述/注释", "default": "", }, "schema": map[string]interface{}{ "type": "string", "description": "模式名称(默认为当前用户)", "default": "", }, "database_key": map[string]interface{}{ "type": "string", "description": "数据库配置键名(如:business),可选,默认使用主数据库", "enum": []string{"warehouse", "business"}, "default": "warehouse", }, "fields": map[string]interface{}{ "type": "array", "description": "字段定义", "items": map[string]interface{}{ "type": "object", "properties": map[string]interface{}{ "field_name": map[string]interface{}{ "type": "string", "description": "字段名称", }, "data_type": map[string]interface{}{ "type": "string", "description": "数据类型(如 VARCHAR2(255), NUMBER, DATE 等)", }, "is_nullable": map[string]interface{}{ "type": "boolean", "description": "是否允许为空", "default": true, }, "field_default": map[string]interface{}{ "type": "string", "description": "默认值", "default": "", }, "field_description": map[string]interface{}{ "type": "string", "description": "字段描述/注释", "default": "", }, "is_primary_key": map[string]interface{}{ "type": "boolean", "description": "是否为主键", "default": false, }, "is_unique": map[string]interface{}{ "type": "boolean", "description": "是否唯一约束", "default": false, }, "auto_increment": map[string]interface{}{ "type": "boolean", "description": "是否自增(使用GENERATED ALWAYS AS IDENTITY)", "default": false, }, }, "required": []string{"field_name", "data_type"}, }, }, }, "required": []string{"table_name", "fields"}, }, func(input json.RawMessage, deps *mcp.ToolDependencies) (interface{}, error) { var params struct { TableName string `json:"table_name"` TableDescription string `json:"table_description"` Schema string `json:"schema"` DatabaseKey string `json:"database_key"` Fields []struct { FieldName string `json:"field_name"` DataType string `json:"data_type"` IsNullable bool `json:"is_nullable"` FieldDefault string `json:"field_default"` FieldDescription string `json:"field_description"` IsPrimaryKey bool `json:"is_primary_key"` IsUnique bool `json:"is_unique"` AutoIncrement bool `json:"auto_increment"` } `json:"fields"` } if len(input) > 0 { if err := json.Unmarshal(input, ¶ms); err != nil { return nil, err } } // 获取数据库工厂 dbFactory, err := GetDBFactory(params.DatabaseKey, deps) if err != nil { return nil, err } // 获取数据库类型,确保是Oracle dbType := dbFactory.GetDBType() if dbType != "oracle" { return nil, fmt.Errorf("当前数据库类型为 %s,此工具仅支持Oracle数据库", dbType) } // 获取当前用户名(Oracle中的schema) currentUser := dbFactory.GetDatabaseName() schema := strings.TrimSpace(params.Schema) if schema == "" { schema = currentUser } tableName := strings.TrimSpace(params.TableName) if tableName == "" { return nil, fmt.Errorf("表名称不能为空") } if len(params.Fields) == 0 { return nil, fmt.Errorf("至少需要定义一个字段") } // 检查表是否已存在 var tableExistsQuery string var tableCheckResults []map[string]interface{} if strings.ToUpper(schema) == strings.ToUpper(currentUser) { // 检查当前用户下的表 tableExistsQuery = `SELECT COUNT(*) as table_count FROM user_tables WHERE table_name = UPPER(:1)` tableCheckResults, err = dbFactory.QuerySliceMapWithParams(tableExistsQuery, tableName) } else { // 检查指定模式下的表 tableExistsQuery = `SELECT COUNT(*) as table_count FROM all_tables WHERE owner = UPPER(:1) AND table_name = UPPER(:2)` tableCheckResults, err = dbFactory.QuerySliceMapWithParams(tableExistsQuery, schema, tableName) } if err != nil { return nil, fmt.Errorf("检查表是否存在失败: %v", err) } tableExists := false if len(tableCheckResults) > 0 { if count, ok := tableCheckResults[0]["table_count"].(int64); ok && count > 0 { tableExists = true } } if tableExists { return nil, fmt.Errorf("表 '%s' 已存在于模式 '%s' 中", tableName, schema) } // 构建建表SQL var fieldDefinitions []string var primaryKeyFields []string for _, field := range params.Fields { fieldName := strings.TrimSpace(field.FieldName) dataType := strings.TrimSpace(field.DataType) if fieldName == "" || dataType == "" { return nil, fmt.Errorf("字段名称和数据类型不能为空") } // 构建字段定义 fieldDef := fmt.Sprintf(`"%s" %s`, fieldName, dataType) // 添加NOT NULL约束 if !field.IsNullable { fieldDef += " NOT NULL" } // 添加默认值 if field.FieldDefault != "" { fieldDef += fmt.Sprintf(" DEFAULT %s", field.FieldDefault) } // 添加自增(Oracle 12c+ 使用IDENTITY列) if field.AutoIncrement { fieldDef += " GENERATED ALWAYS AS IDENTITY" } // 添加唯一约束(如果不是主键) if field.IsUnique && !field.IsPrimaryKey { fieldDef += " UNIQUE" } fieldDefinitions = append(fieldDefinitions, fieldDef) // 收集主键字段 if field.IsPrimaryKey { primaryKeyFields = append(primaryKeyFields, fmt.Sprintf(`"%s"`, fieldName)) } } // 构建完整的CREATE TABLE语句 createTableSQL := fmt.Sprintf(`CREATE TABLE "%s"."%s" (%s)`, schema, tableName, strings.Join(fieldDefinitions, ", ")) // 添加主键约束 if len(primaryKeyFields) > 0 { createTableSQL += fmt.Sprintf(` CONSTRAINT "%s_PK" PRIMARY KEY (%s)`, tableName, strings.Join(primaryKeyFields, ", ")) } // 执行建表SQL _, err = dbFactory.Execute(createTableSQL) if err != nil { return nil, fmt.Errorf("创建表失败: %v", err) } // 添加表注释 tableDescription := strings.TrimSpace(params.TableDescription) if tableDescription != "" { commentSQL := fmt.Sprintf(`COMMENT ON TABLE "%s"."%s" IS '%s'`, schema, tableName, strings.ReplaceAll(tableDescription, "'", "''")) _, err = dbFactory.Execute(commentSQL) if err != nil { fmt.Printf("添加表注释失败: %v\n", err) } } // 添加字段注释 for _, field := range params.Fields { fieldDescription := strings.TrimSpace(field.FieldDescription) if fieldDescription != "" { fieldName := strings.TrimSpace(field.FieldName) commentSQL := fmt.Sprintf(`COMMENT ON COLUMN "%s"."%s"."%s" IS '%s'`, schema, tableName, fieldName, strings.ReplaceAll(fieldDescription, "'", "''")) _, err = dbFactory.Execute(commentSQL) if err != nil { fmt.Printf("添加字段注释失败: %v\n", err) } } } return map[string]interface{}{ "tenant_id": deps.ReqCtx.TenantID, "user_id": deps.ReqCtx.UserID, "database_type": dbType, "database_name": dbFactory.GetDatabaseName(), "schema": schema, "table_name": tableName, "table_description": tableDescription, "total_fields": len(params.Fields), "primary_key_fields": len(primaryKeyFields), "sql_statement": createTableSQL, "status": "success", "message": fmt.Sprintf("表 '%s' 创建成功", tableName), "timestamp": time.Now().Format(time.RFC3339), }, nil }, ) }