Нет описания
Вы не можете выбрать более 25 тем. Темы должны начинаться с буквы или цифры, могут содержать дефисы (-) и должны содержать не более 35 символов.

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257
  1. package dbs
  2. import (
  3. "encoding/json"
  4. "fmt"
  5. "strings"
  6. "time"
  7. "git.x2erp.com/qdy/go-svc-mcp/internal/mcp"
  8. )
  9. func init() {
  10. mcp.Register("get_postgresql_create_table", "在PostgreSQL数据库中创建新表,支持字段约束和注释",
  11. map[string]interface{}{
  12. "type": "object",
  13. "properties": map[string]interface{}{
  14. "table_name": map[string]interface{}{
  15. "type": "string",
  16. "description": "表名称",
  17. },
  18. "table_description": map[string]interface{}{
  19. "type": "string",
  20. "description": "表描述/注释",
  21. "default": "",
  22. },
  23. "schema": map[string]interface{}{
  24. "type": "string",
  25. "description": "模式名称(默认为public)",
  26. "default": "public",
  27. },
  28. "database_key": map[string]interface{}{
  29. "type": "string",
  30. "description": "数据库配置键名(如:business),可选,默认使用主数据库",
  31. "enum": []string{"warehouse", "business"},
  32. "default": "warehouse",
  33. },
  34. "fields": map[string]interface{}{
  35. "type": "array",
  36. "description": "字段定义",
  37. "items": map[string]interface{}{
  38. "type": "object",
  39. "properties": map[string]interface{}{
  40. "field_name": map[string]interface{}{
  41. "type": "string",
  42. "description": "字段名称",
  43. },
  44. "data_type": map[string]interface{}{
  45. "type": "string",
  46. "description": "数据类型(如 VARCHAR(255), INTEGER, TIMESTAMP 等)",
  47. },
  48. "is_nullable": map[string]interface{}{
  49. "type": "boolean",
  50. "description": "是否允许为空",
  51. "default": true,
  52. },
  53. "field_default": map[string]interface{}{
  54. "type": "string",
  55. "description": "默认值",
  56. "default": "",
  57. },
  58. "field_description": map[string]interface{}{
  59. "type": "string",
  60. "description": "字段描述/注释",
  61. "default": "",
  62. },
  63. "is_primary_key": map[string]interface{}{
  64. "type": "boolean",
  65. "description": "是否为主键",
  66. "default": false,
  67. },
  68. "is_unique": map[string]interface{}{
  69. "type": "boolean",
  70. "description": "是否唯一约束",
  71. "default": false,
  72. },
  73. "auto_increment": map[string]interface{}{
  74. "type": "boolean",
  75. "description": "是否自增(使用SERIAL或GENERATED ALWAYS AS IDENTITY)",
  76. "default": false,
  77. },
  78. },
  79. "required": []string{"field_name", "data_type"},
  80. },
  81. },
  82. },
  83. "required": []string{"table_name", "fields"},
  84. },
  85. func(input json.RawMessage, deps *mcp.ToolDependencies) (interface{}, error) {
  86. var params struct {
  87. TableName string `json:"table_name"`
  88. TableDescription string `json:"table_description"`
  89. Schema string `json:"schema"`
  90. DatabaseKey string `json:"database_key"`
  91. Fields []struct {
  92. FieldName string `json:"field_name"`
  93. DataType string `json:"data_type"`
  94. IsNullable bool `json:"is_nullable"`
  95. FieldDefault string `json:"field_default"`
  96. FieldDescription string `json:"field_description"`
  97. IsPrimaryKey bool `json:"is_primary_key"`
  98. IsUnique bool `json:"is_unique"`
  99. AutoIncrement bool `json:"auto_increment"`
  100. } `json:"fields"`
  101. }
  102. if len(input) > 0 {
  103. if err := json.Unmarshal(input, &params); err != nil {
  104. return nil, err
  105. }
  106. }
  107. // 获取数据库工厂
  108. dbFactory, err := GetDBFactory(params.DatabaseKey, deps)
  109. if err != nil {
  110. return nil, err
  111. }
  112. // 获取数据库类型,确保是PostgreSQL
  113. dbType := dbFactory.GetDBType()
  114. if dbType != "postgresql" {
  115. return nil, fmt.Errorf("当前数据库类型为 %s,此工具仅支持PostgreSQL数据库", dbType)
  116. }
  117. // 获取当前数据库名称
  118. _ = dbFactory.GetDatabaseName() // 保留但不使用
  119. schema := strings.TrimSpace(params.Schema)
  120. if schema == "" {
  121. schema = "public"
  122. }
  123. tableName := strings.TrimSpace(params.TableName)
  124. if tableName == "" {
  125. return nil, fmt.Errorf("表名称不能为空")
  126. }
  127. if len(params.Fields) == 0 {
  128. return nil, fmt.Errorf("至少需要定义一个字段")
  129. }
  130. // 检查表是否已存在
  131. tableExistsQuery := `SELECT COUNT(*) as table_count FROM information_schema.tables WHERE table_schema = $1 AND table_name = $2`
  132. tableCheckResults, err := dbFactory.QuerySliceMapWithParams(tableExistsQuery, schema, tableName)
  133. if err != nil {
  134. return nil, fmt.Errorf("检查表是否存在失败: %v", err)
  135. }
  136. tableExists := false
  137. if len(tableCheckResults) > 0 {
  138. if count, ok := tableCheckResults[0]["table_count"].(int64); ok && count > 0 {
  139. tableExists = true
  140. }
  141. }
  142. if tableExists {
  143. return nil, fmt.Errorf("表 '%s' 已存在于模式 '%s' 中", tableName, schema)
  144. }
  145. // 构建建表SQL
  146. var fieldDefinitions []string
  147. var primaryKeyFields []string
  148. for _, field := range params.Fields {
  149. fieldName := strings.TrimSpace(field.FieldName)
  150. dataType := strings.TrimSpace(field.DataType)
  151. if fieldName == "" || dataType == "" {
  152. return nil, fmt.Errorf("字段名称和数据类型不能为空")
  153. }
  154. // 构建字段定义
  155. fieldDef := fmt.Sprintf(`"%s" %s`, fieldName, dataType)
  156. // 添加NOT NULL约束
  157. if !field.IsNullable {
  158. fieldDef += " NOT NULL"
  159. }
  160. // 添加默认值
  161. if field.FieldDefault != "" {
  162. fieldDef += fmt.Sprintf(" DEFAULT %s", field.FieldDefault)
  163. }
  164. // 添加自增
  165. if field.AutoIncrement {
  166. // 检查数据类型是否为整数类型
  167. if strings.Contains(strings.ToUpper(dataType), "INT") {
  168. fieldDef += " GENERATED ALWAYS AS IDENTITY"
  169. }
  170. }
  171. // 添加唯一约束(如果不是主键)
  172. if field.IsUnique && !field.IsPrimaryKey {
  173. fieldDef += " UNIQUE"
  174. }
  175. fieldDefinitions = append(fieldDefinitions, fieldDef)
  176. // 收集主键字段
  177. if field.IsPrimaryKey {
  178. primaryKeyFields = append(primaryKeyFields, fmt.Sprintf(`"%s"`, fieldName))
  179. }
  180. }
  181. // 构建完整的CREATE TABLE语句
  182. createTableSQL := fmt.Sprintf(`CREATE TABLE "%s"."%s" (%s)`, schema, tableName, strings.Join(fieldDefinitions, ", "))
  183. // 添加主键约束
  184. if len(primaryKeyFields) > 0 {
  185. createTableSQL += fmt.Sprintf(` CONSTRAINT "%s_pkey" PRIMARY KEY (%s)`, tableName, strings.Join(primaryKeyFields, ", "))
  186. }
  187. // 执行建表SQL
  188. _, err = dbFactory.Execute(createTableSQL)
  189. if err != nil {
  190. return nil, fmt.Errorf("创建表失败: %v", err)
  191. }
  192. // 添加表注释
  193. tableDescription := strings.TrimSpace(params.TableDescription)
  194. if tableDescription != "" {
  195. commentSQL := fmt.Sprintf(`COMMENT ON TABLE "%s"."%s" IS '%s'`, schema, tableName, strings.ReplaceAll(tableDescription, "'", "''"))
  196. _, err = dbFactory.Execute(commentSQL)
  197. if err != nil {
  198. fmt.Printf("添加表注释失败: %v\n", err)
  199. }
  200. }
  201. // 添加字段注释
  202. for _, field := range params.Fields {
  203. fieldDescription := strings.TrimSpace(field.FieldDescription)
  204. if fieldDescription != "" {
  205. fieldName := strings.TrimSpace(field.FieldName)
  206. commentSQL := fmt.Sprintf(`COMMENT ON COLUMN "%s"."%s"."%s" IS '%s'`, schema, tableName, fieldName, strings.ReplaceAll(fieldDescription, "'", "''"))
  207. _, err = dbFactory.Execute(commentSQL)
  208. if err != nil {
  209. fmt.Printf("添加字段注释失败: %v\n", err)
  210. }
  211. }
  212. }
  213. return map[string]interface{}{
  214. "tenant_id": deps.ReqCtx.TenantID,
  215. "user_id": deps.ReqCtx.UserID,
  216. "database_type": dbType,
  217. "database_name": dbFactory.GetDatabaseName(),
  218. "schema": schema,
  219. "table_name": tableName,
  220. "table_description": tableDescription,
  221. "total_fields": len(params.Fields),
  222. "primary_key_fields": len(primaryKeyFields),
  223. "sql_statement": createTableSQL,
  224. "status": "success",
  225. "message": fmt.Sprintf("表 '%s' 创建成功", tableName),
  226. "timestamp": time.Now().Format(time.RFC3339),
  227. }, nil
  228. },
  229. )
  230. }