Browse Source

自动建立表通过

qdy 2 months ago
parent
commit
18e4bcfbc2

drivers/postgres.go → drivers/postgresql.go View File

@@ -10,7 +10,7 @@ import (
10 10
 type PostgresDriver struct{}
11 11
 
12 12
 func (d *PostgresDriver) Name() string {
13
-	return "postgres"
13
+	return "postgresql"
14 14
 }
15 15
 
16 16
 func (d *PostgresDriver) BuildDSN(config DBConfig) string {

functions/execute.go → functions/execute_sql.go View File


+ 30
- 17
functions/query_csv.go View File

@@ -7,65 +7,78 @@ import (
7 7
 	"strings"
8 8
 
9 9
 	"git.x2erp.com/qdy/go-base/ctx"
10
+	"git.x2erp.com/qdy/go-base/logger"
10 11
 	"github.com/jmoiron/sqlx"
12
+	"go.uber.org/zap"
11 13
 )
12 14
 
13 15
 // QueryToCSV 无参数查询并返回 CSV 字节数据
14 16
 func QueryToCSV(db *sqlx.DB, sql string, writerHeader bool, reqCtx *ctx.RequestContext) ([]byte, error) {
17
+	logger.DebugC(reqCtx, "Executing QueryToCSV",
18
+		zap.String("sql", sql),
19
+		zap.Bool("writerHeader", writerHeader))
20
+
15 21
 	if sql == "" {
16
-		return nil, fmt.Errorf("SQL query cannot be empty")
22
+		return nil, logger.ErrorCf(reqCtx, "SQL query cannot be empty")
17 23
 	}
18 24
 
19 25
 	rows, err := db.Query(sql)
20 26
 	if err != nil {
21
-		return nil, fmt.Errorf("query execution failed: %v", err)
27
+		return nil, logger.ErrorCf(reqCtx, "query execution failed: %v", err)
22 28
 	}
23 29
 
24
-	return rowsToCSV(rows, writerHeader)
30
+	return rowsToCSV(rows, writerHeader, reqCtx)
25 31
 }
26 32
 
27 33
 // QueryParamsToCSV 位置参数查询并返回 CSV 字节数据
28 34
 func QueryPositionalToCSV(db *sqlx.DB, sql string, writerHeader bool, params []interface{}, reqCtx *ctx.RequestContext) ([]byte, error) {
35
+
36
+	logger.DebugC(reqCtx, "Executing QueryToCSV: sql=%s, writerHeader=%v", sql, writerHeader)
37
+
29 38
 	if sql == "" {
30
-		return nil, fmt.Errorf("SQL query cannot be empty")
39
+		return nil, logger.ErrorCf(reqCtx, "SQL query cannot be empty")
31 40
 	}
32 41
 
33 42
 	rows, err := db.Query(sql, params...)
34 43
 	if err != nil {
35
-		return nil, fmt.Errorf("query execution failed: %v", err)
44
+		return nil, logger.ErrorCf(reqCtx, "query execution failed: %v", err)
36 45
 	}
37 46
 
38
-	return rowsToCSV(rows, writerHeader)
47
+	return rowsToCSV(rows, writerHeader, reqCtx)
39 48
 }
40 49
 
41 50
 // QueryParamsNameToCSV 命名参数查询并返回 CSV 字节数据
42 51
 // params 可以是 map[string]interface{} 或结构体
43 52
 func QueryParamsNameToCSV(db *sqlx.DB, sql string, writerHeader bool, params map[string]interface{}, reqCtx *ctx.RequestContext) ([]byte, error) {
53
+
54
+	logger.DebugC(reqCtx, "Executing QueryToCSV",
55
+		zap.String("sql", sql),
56
+		zap.Bool("writerHeader", writerHeader))
44 57
 	if sql == "" {
45
-		return nil, fmt.Errorf("SQL query cannot be empty")
58
+		return nil, logger.ErrorCf(reqCtx, "SQL query cannot be empty")
46 59
 	}
47 60
 
48 61
 	query, args, err := sqlx.Named(sql, params)
49 62
 	if err != nil {
50
-		return nil, fmt.Errorf("failed to bind named parameters: %v", err)
63
+		return nil, logger.ErrorCf(reqCtx, "query execution failed: %v", err)
51 64
 	}
52 65
 
53 66
 	query = db.Rebind(query)
54 67
 	rows, err := db.Query(query, args...)
55 68
 	if err != nil {
56
-		return nil, fmt.Errorf("query execution failed: %v", err)
69
+		return nil, logger.ErrorCf(reqCtx, "query execution failed: %v", err)
57 70
 	}
58 71
 
59
-	return rowsToCSV(rows, writerHeader)
72
+	return rowsToCSV(rows, writerHeader, reqCtx)
60 73
 }
61 74
 
62 75
 // rowsToCSV 公共方法:将查询结果转换为 CSV 字节数据
63
-func rowsToCSV(rows *sql.Rows, writerHeader bool) ([]byte, error) {
76
+func rowsToCSV(rows *sql.Rows, writerHeader bool, reqCtx *ctx.RequestContext) ([]byte, error) {
64 77
 	defer rows.Close()
65 78
 
66 79
 	columns, err := rows.Columns()
67 80
 	if err != nil {
68
-		return nil, fmt.Errorf("failed to get columns: %v", err)
81
+		return nil, logger.ErrorCf(reqCtx, "failed to get columns: %v", err)
69 82
 	}
70 83
 
71 84
 	var builder strings.Builder
@@ -74,7 +87,7 @@ func rowsToCSV(rows *sql.Rows, writerHeader bool) ([]byte, error) {
74 87
 	// 根据参数决定是否写入表头
75 88
 	if writerHeader {
76 89
 		if err := writer.Write(columns); err != nil {
77
-			return nil, fmt.Errorf("failed to write CSV header: %v", err)
90
+			return nil, logger.ErrorCf(reqCtx, "failed to write CSV header: %v", err)
78 91
 		}
79 92
 	}
80 93
 
@@ -86,7 +99,7 @@ func rowsToCSV(rows *sql.Rows, writerHeader bool) ([]byte, error) {
86 99
 		}
87 100
 
88 101
 		if err := rows.Scan(valuePtrs...); err != nil {
89
-			return nil, fmt.Errorf("failed to scan row: %v", err)
102
+			return nil, logger.ErrorCf(reqCtx, "failed to scan row: %v", err)
90 103
 		}
91 104
 
92 105
 		// 所有值转为字符串
@@ -100,17 +113,17 @@ func rowsToCSV(rows *sql.Rows, writerHeader bool) ([]byte, error) {
100 113
 		}
101 114
 
102 115
 		if err := writer.Write(row); err != nil {
103
-			return nil, fmt.Errorf("failed to write CSV row: %v", err)
116
+			return nil, logger.ErrorCf(reqCtx, "failed to write CSV row: %v", err)
104 117
 		}
105 118
 	}
106 119
 
107 120
 	writer.Flush()
108 121
 	if err := writer.Error(); err != nil {
109
-		return nil, fmt.Errorf("failed to flush CSV: %v", err)
122
+		return nil, logger.ErrorCf(reqCtx, "failed to flush CSV: %v", err)
110 123
 	}
111 124
 
112 125
 	if err := rows.Err(); err != nil {
113
-		return nil, fmt.Errorf("row iteration error: %v", err)
126
+		return nil, logger.ErrorCf(reqCtx, "row iteration error: %v", err)
114 127
 	}
115 128
 
116 129
 	return []byte(builder.String()), nil

+ 15
- 6
myhandle/query_handler_bytes.go View File

@@ -2,6 +2,7 @@ package myhandle
2 2
 
3 3
 import (
4 4
 	"encoding/json"
5
+	"fmt"
5 6
 	"net/http"
6 7
 
7 8
 	"git.x2erp.com/qdy/go-base/ctx"
@@ -12,24 +13,32 @@ func QueryHandlerBytes[T any, F any](
12 13
 	w http.ResponseWriter,
13 14
 	r *http.Request,
14 15
 	factory F,
15
-	handlerFunc func(F, T, *ctx.RequestContext) []byte,
16
+	handlerFunc func(F, T, *ctx.RequestContext) ([]byte, error),
16 17
 ) {
17 18
 	// 解析请求参数
19
+
20
+	w.Header().Set("Content-Type", "text/csv")
21
+	w.Header().Set("Content-Disposition", "attachment; filename=query_result.csv")
18 22
 	var req T
19 23
 	if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
20 24
 		// 返回 CSV 格式的错误信息
21 25
 		errorCSV := "error,Invalid request body\n"
22
-		w.Header().Set("Content-Type", "text/csv")
26
+		//		w.Header().Set("Content-Type", "text/csv")
23 27
 		w.WriteHeader(http.StatusBadRequest)
24 28
 		w.Write([]byte(errorCSV))
25 29
 		return
26 30
 	}
27 31
 
32
+	reqCtx := ctx.GetContext(r)
28 33
 	// 调用业务逻辑函数
29
-	csvData := handlerFunc(factory, req, ctx.GetContext(r))
30
-
34
+	csvData, err := handlerFunc(factory, req, reqCtx)
35
+	if err != nil {
36
+		w.WriteHeader(http.StatusBadRequest)
37
+		w.Write([]byte(fmt.Sprintf("%v", err)))
38
+		return
39
+	}
31 40
 	// 直接返回 CSV 数据(包含错误信息时也会被正确处理)
32
-	w.Header().Set("Content-Type", "text/csv")
33
-	w.Header().Set("Content-Disposition", "attachment; filename=query_result.csv")
41
+	//w.Header().Set("Content-Type", "text/csv")
42
+	//w.Header().Set("Content-Disposition", "attachment; filename=query_result.csv")
34 43
 	w.Write(csvData)
35 44
 }

+ 148
- 0
sqldef/generators/mysql.go View File

@@ -0,0 +1,148 @@
1
+// mysql.go 根据//table_defintion.go 定义的表结构,编写mysql建立表和索引的代码。
2
+package generators
3
+
4
+import (
5
+	"fmt"
6
+	"strings"
7
+)
8
+
9
+// MySQLGenerator MySQL SQL生成器
10
+type MySQLGenerator struct{}
11
+
12
+// NewMySQLGenerator 创建MySQL生成器实例
13
+func NewMySQLGenerator() *MySQLGenerator {
14
+	return &MySQLGenerator{}
15
+}
16
+
17
+func (mg *MySQLGenerator) DBType() string {
18
+	return "mysql"
19
+}
20
+
21
+func (mg *MySQLGenerator) TableExistsSQL(tableName string) string {
22
+	return fmt.Sprintf(
23
+		"SELECT COUNT(*) FROM information_schema.tables WHERE table_schema = DATABASE() AND table_name = '%s'",
24
+		tableName,
25
+	)
26
+}
27
+
28
+func (mg *MySQLGenerator) DropTableSQL(tableName string) string {
29
+	return fmt.Sprintf("DROP TABLE IF EXISTS %s", tableName)
30
+}
31
+
32
+func (mg *MySQLGenerator) GenerateCreateTable(table TableDDL) string {
33
+	if table.Schema == nil {
34
+		return ""
35
+	}
36
+
37
+	var sql strings.Builder
38
+
39
+	// 表头
40
+	sql.WriteString(fmt.Sprintf("CREATE TABLE IF NOT EXISTS %s (\n", table.Name))
41
+
42
+	// 列定义
43
+	columns := table.Schema.Columns
44
+	for i, col := range columns {
45
+		sql.WriteString(fmt.Sprintf("  %s %s", col.Name, mg.getMySQLType(col)))
46
+
47
+		// 添加选项
48
+		for _, opt := range col.Options {
49
+			sql.WriteString(" " + opt)
50
+		}
51
+
52
+		// 添加默认值
53
+		if col.Default != "" {
54
+			// 检查是否是函数调用(如CURRENT_TIMESTAMP)
55
+			if strings.Contains(strings.ToUpper(col.Default), "CURRENT_TIMESTAMP") ||
56
+				strings.Contains(strings.ToUpper(col.Default), "NOW()") {
57
+				sql.WriteString(" DEFAULT " + col.Default)
58
+			} else {
59
+				sql.WriteString(fmt.Sprintf(" DEFAULT '%s'", col.Default))
60
+			}
61
+		}
62
+
63
+		// 添加注释
64
+		if col.Comment != "" {
65
+			sql.WriteString(fmt.Sprintf(" COMMENT '%s'", col.Comment))
66
+		}
67
+
68
+		if i < len(columns)-1 {
69
+			sql.WriteString(",")
70
+		}
71
+		sql.WriteString("\n")
72
+	}
73
+
74
+	// 添加索引(在MySQL中,索引可以在CREATE TABLE语句中定义)
75
+	for _, idx := range table.Schema.Indexes {
76
+		if idx.Unique {
77
+			sql.WriteString(fmt.Sprintf("  ,UNIQUE KEY %s (%s)\n",
78
+				idx.Name, strings.Join(idx.Columns, ", ")))
79
+		} else {
80
+			sql.WriteString(fmt.Sprintf("  ,KEY %s (%s)\n",
81
+				idx.Name, strings.Join(idx.Columns, ", ")))
82
+		}
83
+	}
84
+
85
+	sql.WriteString(")")
86
+
87
+	// 表选项
88
+	if table.Schema.Comment != "" {
89
+		sql.WriteString(fmt.Sprintf(" COMMENT='%s'", table.Schema.Comment))
90
+	}
91
+
92
+	sql.WriteString(" ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci;")
93
+
94
+	return sql.String()
95
+}
96
+
97
+// getMySQLType 获取MySQL数据类型
98
+func (mg *MySQLGenerator) getMySQLType(col ColumnSchema) string {
99
+	switch col.Type {
100
+	case "DECIMAL":
101
+		if col.Precision > 0 && col.Scale > 0 {
102
+			return fmt.Sprintf("DECIMAL(%d,%d)", col.Precision, col.Scale)
103
+		}
104
+		return "DECIMAL"
105
+	case "VARCHAR":
106
+		if col.Length > 0 {
107
+			return fmt.Sprintf("VARCHAR(%d)", col.Length)
108
+		}
109
+		return "VARCHAR(255)"
110
+	case "CHAR":
111
+		if col.Length > 0 {
112
+			return fmt.Sprintf("CHAR(%d)", col.Length)
113
+		}
114
+		return "CHAR(1)"
115
+	case "INT":
116
+		return "INT"
117
+	case "BIGINT":
118
+		return "BIGINT"
119
+	case "TINYINT":
120
+		return "TINYINT"
121
+	case "BOOL":
122
+		return "TINYINT(1)"
123
+	case "DATETIME":
124
+		return "DATETIME"
125
+	case "TIMESTAMP":
126
+		return "TIMESTAMP"
127
+	case "DATE":
128
+		return "DATE"
129
+	case "TIME":
130
+		return "TIME"
131
+	case "TEXT":
132
+		return "TEXT"
133
+	case "JSON":
134
+		return "JSON"
135
+	case "BLOB":
136
+		return "BLOB"
137
+	case "FLOAT":
138
+		return "FLOAT"
139
+	case "DOUBLE":
140
+		return "DOUBLE"
141
+	default:
142
+		return col.Type
143
+	}
144
+}
145
+
146
+func init() {
147
+	RegisterGenerator(NewMySQLGenerator())
148
+}

+ 236
- 0
sqldef/generators/postgresql.go View File

@@ -0,0 +1,236 @@
1
+// postgresql.go 根据//table_defintion.go 定义的表结构,编写PostgreSQL建立表和索引的代码。
2
+package generators
3
+
4
+import (
5
+	"fmt"
6
+	"strings"
7
+)
8
+
9
+// PostgreSQLGenerator PostgreSQL SQL生成器
10
+type PostgreSQLGenerator struct{}
11
+
12
+// NewPostgreSQLGenerator 创建PostgreSQL生成器实例
13
+func NewPostgreSQLGenerator() *PostgreSQLGenerator {
14
+	return &PostgreSQLGenerator{}
15
+}
16
+
17
+func (pg *PostgreSQLGenerator) DBType() string {
18
+	return "postgresql"
19
+}
20
+
21
+func (pg *PostgreSQLGenerator) TableExistsSQL(tableName string) string {
22
+	return fmt.Sprintf(
23
+		"SELECT EXISTS (SELECT FROM information_schema.tables WHERE table_schema = 'public' AND table_name = '%s')",
24
+		strings.ToLower(tableName),
25
+	)
26
+}
27
+
28
+func (pg *PostgreSQLGenerator) DropTableSQL(tableName string) string {
29
+	return fmt.Sprintf("DROP TABLE IF EXISTS %s CASCADE", tableName)
30
+}
31
+
32
+func (pg *PostgreSQLGenerator) GenerateCreateTable(table TableDDL) string {
33
+	if table.Schema == nil {
34
+		return ""
35
+	}
36
+
37
+	var sql strings.Builder
38
+
39
+	// 表头
40
+	sql.WriteString(fmt.Sprintf("CREATE TABLE IF NOT EXISTS %s (\n", table.Name))
41
+
42
+	// 列定义
43
+	columns := table.Schema.Columns
44
+	for i, col := range columns {
45
+		sql.WriteString(fmt.Sprintf("  %s %s", col.Name, pg.getPostgreSQLType(col)))
46
+
47
+		// 添加选项(转换MySQL选项到PostgreSQL)
48
+		for _, opt := range col.Options {
49
+			pgOpt := pg.convertOption(opt, col)
50
+			if pgOpt != "" {
51
+				sql.WriteString(" " + pgOpt)
52
+			}
53
+		}
54
+
55
+		// 添加默认值
56
+		if col.Default != "" {
57
+			pgDefault := pg.convertDefaultValue(col.Default, col.Type)
58
+			if pgDefault != "" {
59
+				sql.WriteString(" DEFAULT " + pgDefault)
60
+			}
61
+		}
62
+
63
+		// 列注释将在表创建后单独添加
64
+
65
+		if i < len(columns)-1 {
66
+			sql.WriteString(",")
67
+		}
68
+		sql.WriteString("\n")
69
+	}
70
+
71
+	// 添加主键约束(从列选项中提取)
72
+	primaryKeys := pg.extractPrimaryKeys(columns)
73
+	if len(primaryKeys) > 0 {
74
+		sql.WriteString(fmt.Sprintf("  ,PRIMARY KEY (%s)\n", strings.Join(primaryKeys, ", ")))
75
+	}
76
+
77
+	sql.WriteString(");\n")
78
+
79
+	// 添加表注释
80
+	if table.Schema.Comment != "" {
81
+		sql.WriteString(fmt.Sprintf("COMMENT ON TABLE %s IS '%s';\n",
82
+			table.Name, table.Schema.Comment))
83
+	}
84
+
85
+	// 添加列注释
86
+	for _, col := range columns {
87
+		if col.Comment != "" {
88
+			sql.WriteString(fmt.Sprintf("COMMENT ON COLUMN %s.%s IS '%s';\n",
89
+				table.Name, col.Name, col.Comment))
90
+		}
91
+	}
92
+
93
+	// 添加索引(PostgreSQL中索引在表外创建)
94
+	for _, idx := range table.Schema.Indexes {
95
+		if idx.Unique {
96
+			sql.WriteString(fmt.Sprintf("CREATE UNIQUE INDEX IF NOT EXISTS %s ON %s (%s);\n",
97
+				idx.Name, table.Name, strings.Join(idx.Columns, ", ")))
98
+		} else {
99
+			sql.WriteString(fmt.Sprintf("CREATE INDEX IF NOT EXISTS %s ON %s (%s);\n",
100
+				idx.Name, table.Name, strings.Join(idx.Columns, ", ")))
101
+		}
102
+	}
103
+
104
+	return sql.String()
105
+}
106
+
107
+// getPostgreSQLType 获取PostgreSQL数据类型
108
+func (pg *PostgreSQLGenerator) getPostgreSQLType(col ColumnSchema) string {
109
+	switch col.Type {
110
+	case "TINYINT":
111
+		if col.Length == 1 {
112
+			return "BOOLEAN"
113
+		}
114
+		return "SMALLINT"
115
+	case "BOOL":
116
+		return "BOOLEAN"
117
+	case "DATETIME":
118
+		return "TIMESTAMP"
119
+	case "TIMESTAMP":
120
+		return "TIMESTAMP"
121
+	case "JSON":
122
+		return "JSONB"
123
+	case "BLOB":
124
+		return "BYTEA"
125
+	case "INT":
126
+		return "INTEGER"
127
+	case "BIGINT":
128
+		return "BIGINT"
129
+	case "DECIMAL":
130
+		if col.Precision > 0 && col.Scale > 0 {
131
+			return fmt.Sprintf("DECIMAL(%d,%d)", col.Precision, col.Scale)
132
+		}
133
+		return "DECIMAL"
134
+	case "VARCHAR":
135
+		if col.Length > 0 {
136
+			return fmt.Sprintf("VARCHAR(%d)", col.Length)
137
+		}
138
+		return "VARCHAR"
139
+	case "CHAR":
140
+		if col.Length > 0 {
141
+			return fmt.Sprintf("CHAR(%d)", col.Length)
142
+		}
143
+		return "CHAR"
144
+	case "TEXT":
145
+		return "TEXT"
146
+	case "FLOAT":
147
+		return "REAL"
148
+	case "DOUBLE":
149
+		return "DOUBLE PRECISION"
150
+	case "DATE":
151
+		return "DATE"
152
+	case "TIME":
153
+		return "TIME"
154
+	default:
155
+		return col.Type
156
+	}
157
+}
158
+
159
+// convertOption 转换选项到PostgreSQL语法
160
+func (pg *PostgreSQLGenerator) convertOption(option string, col ColumnSchema) string {
161
+	option = strings.ToUpper(option)
162
+
163
+	switch option {
164
+	case "NOT NULL":
165
+		return "NOT NULL"
166
+	case "UNIQUE":
167
+		return "UNIQUE"
168
+	case "AUTO_INCREMENT":
169
+		// PostgreSQL使用SERIAL/BIGSERIAL/SMALLSERIAL
170
+		// 已在getPostgreSQLType中处理
171
+		return ""
172
+	case "PRIMARY KEY":
173
+		// 主键将在表级别定义
174
+		return ""
175
+	default:
176
+		return option
177
+	}
178
+}
179
+
180
+// convertDefaultValue 转换默认值
181
+func (pg *PostgreSQLGenerator) convertDefaultValue(value, colType string) string {
182
+	value = strings.TrimSpace(value)
183
+
184
+	// 移除引号(如果有)
185
+	if strings.HasPrefix(value, "'") && strings.HasSuffix(value, "'") {
186
+		value = value[1 : len(value)-1]
187
+	}
188
+
189
+	// 处理布尔值
190
+	if strings.EqualFold(colType, "BOOLEAN") || strings.EqualFold(colType, "BOOL") {
191
+		switch strings.ToUpper(value) {
192
+		case "1", "TRUE", "'TRUE'":
193
+			return "TRUE"
194
+		case "0", "FALSE", "'FALSE'":
195
+			return "FALSE"
196
+		}
197
+	}
198
+
199
+	// 处理时间戳
200
+	if strings.Contains(strings.ToUpper(value), "CURRENT_TIMESTAMP") {
201
+		return "CURRENT_TIMESTAMP"
202
+	}
203
+
204
+	// 处理数字
205
+	if _, ok := pg.isNumber(value); ok {
206
+		return value
207
+	}
208
+
209
+	// 其他情况加单引号
210
+	return "'" + value + "'"
211
+}
212
+
213
+// isNumber 检查字符串是否为数字
214
+func (pg *PostgreSQLGenerator) isNumber(s string) (float64, bool) {
215
+	var f float64
216
+	_, err := fmt.Sscanf(s, "%f", &f)
217
+	return f, err == nil
218
+}
219
+
220
+// extractPrimaryKeys 从列中提取主键
221
+func (pg *PostgreSQLGenerator) extractPrimaryKeys(columns []ColumnSchema) []string {
222
+	var primaryKeys []string
223
+	for _, col := range columns {
224
+		for _, opt := range col.Options {
225
+			if strings.ToUpper(opt) == "PRIMARY KEY" {
226
+				primaryKeys = append(primaryKeys, col.Name)
227
+				break
228
+			}
229
+		}
230
+	}
231
+	return primaryKeys
232
+}
233
+
234
+func init() {
235
+	RegisterGenerator(NewPostgreSQLGenerator())
236
+}

+ 137
- 0
sqldef/generators/sql_generator.go View File

@@ -0,0 +1,137 @@
1
+// sql_generator.go 根据数据库类型,自动选择对应注册的建立表的sql编写器,进行编写。
2
+package generators
3
+
4
+import (
5
+	"fmt"
6
+	"strings"
7
+	"sync"
8
+)
9
+
10
+// ColumnSchema 列结构定义
11
+type ColumnSchema struct {
12
+	Name      string
13
+	Type      string // 基本类型:VARCHAR, INT, DECIMAL 等
14
+	Length    int    // 长度
15
+	Precision int    // 精度(用于 DECIMAL)
16
+	Scale     int    // 小数位数(用于 DECIMAL)
17
+	Comment   string
18
+	Options   []string // 选项:NOT NULL, PRIMARY KEY 等
19
+	Default   string   // 默认值字符串
20
+}
21
+
22
+// IndexSchema 索引结构定义
23
+type IndexSchema struct {
24
+	Name    string
25
+	Columns []string
26
+	Unique  bool
27
+}
28
+
29
+// TableSchema 表结构定义
30
+type TableSchema struct {
31
+	Name    string
32
+	Comment string
33
+	Columns []ColumnSchema
34
+	Indexes []IndexSchema
35
+}
36
+
37
+// TableDDL 表定义
38
+type TableDDL struct {
39
+	Name    string
40
+	SQL     string
41
+	Comment string
42
+	Schema  *TableSchema // 新增Schema字段
43
+}
44
+
45
+// SQLGenerator SQL生成器接口
46
+type SQLGenerator interface {
47
+	// 生成检查表是否存在的SQL
48
+	TableExistsSQL(tableName string) string
49
+
50
+	// 生成删除表的SQL
51
+	DropTableSQL(tableName string) string
52
+
53
+	// 生成创建表的SQL
54
+	GenerateCreateTable(table TableDDL) string
55
+
56
+	// 获取数据库类型
57
+	DBType() string
58
+}
59
+
60
+// SQLGeneratorRegistry SQL生成器注册表
61
+type SQLGeneratorRegistry struct {
62
+	generators map[string]SQLGenerator
63
+	mu         sync.RWMutex
64
+}
65
+
66
+var (
67
+	globalGeneratorRegistry *SQLGeneratorRegistry
68
+	generatorRegistryOnce   sync.Once
69
+)
70
+
71
+// GetSQLGeneratorRegistry 获取SQL生成器注册表单例
72
+func GetSQLGeneratorRegistry() *SQLGeneratorRegistry {
73
+	generatorRegistryOnce.Do(func() {
74
+		globalGeneratorRegistry = &SQLGeneratorRegistry{
75
+			generators: make(map[string]SQLGenerator),
76
+		}
77
+	})
78
+	return globalGeneratorRegistry
79
+}
80
+
81
+// RegisterGenerator 注册SQL生成器
82
+func (r *SQLGeneratorRegistry) RegisterGenerator(generator SQLGenerator) {
83
+	r.mu.Lock()
84
+	defer r.mu.Unlock()
85
+
86
+	dbType := strings.ToLower(generator.DBType())
87
+	r.generators[dbType] = generator
88
+}
89
+
90
+// GetGenerator 获取SQL生成器
91
+func (r *SQLGeneratorRegistry) GetGenerator(dbType string) (SQLGenerator, error) {
92
+	r.mu.RLock()
93
+	defer r.mu.RUnlock()
94
+
95
+	generator, exists := r.generators[strings.ToLower(dbType)]
96
+	if !exists {
97
+		return nil, fmt.Errorf("不支持的数据类型: %s,已注册的类型: %v", dbType, r.GetRegisteredTypes())
98
+	}
99
+	return generator, nil
100
+}
101
+
102
+// GetRegisteredTypes 获取已注册的数据库类型
103
+func (r *SQLGeneratorRegistry) GetRegisteredTypes() []string {
104
+	r.mu.RLock()
105
+	defer r.mu.RUnlock()
106
+
107
+	types := make([]string, 0, len(r.generators))
108
+	for dbType := range r.generators {
109
+		types = append(types, dbType)
110
+	}
111
+	return types
112
+}
113
+
114
+// 包级便捷函数
115
+
116
+// RegisterGenerator 包级便捷函数:注册SQL生成器
117
+func RegisterGenerator(generator SQLGenerator) {
118
+	GetSQLGeneratorRegistry().RegisterGenerator(generator)
119
+}
120
+
121
+// GetGenerator 包级便捷函数:获取SQL生成器
122
+func GetGenerator(dbType string) (SQLGenerator, error) {
123
+	return GetSQLGeneratorRegistry().GetGenerator(dbType)
124
+}
125
+
126
+// GetAllGenerators 获取所有已注册的生成器
127
+func GetAllGenerators() []SQLGenerator {
128
+	registry := GetSQLGeneratorRegistry()
129
+	registry.mu.RLock()
130
+	defer registry.mu.RUnlock()
131
+
132
+	generators := make([]SQLGenerator, 0, len(registry.generators))
133
+	for _, generator := range registry.generators {
134
+		generators = append(generators, generator)
135
+	}
136
+	return generators
137
+}

+ 99
- 0
sqldef/index_manager.go View File

@@ -0,0 +1,99 @@
1
+// index_manager.go
2
+package sqldef
3
+
4
+import (
5
+	"fmt"
6
+	"strings"
7
+)
8
+
9
+// IndexManager 索引管理器
10
+type IndexManager struct {
11
+	ddlExecutor DDLExecutor
12
+}
13
+
14
+// IndexDefinition 索引定义
15
+type IndexDefinition struct {
16
+	TableName string
17
+	IndexName string
18
+	Columns   []string
19
+	IsUnique  bool
20
+	IsPrimary bool
21
+}
22
+
23
+// NewIndexManager 创建索引管理器
24
+func NewIndexManager(executor DDLExecutor) *IndexManager {
25
+	return &IndexManager{
26
+		ddlExecutor: executor,
27
+	}
28
+}
29
+
30
+// CreateIndex 创建索引
31
+func (im *IndexManager) CreateIndex(idx *IndexDefinition) error {
32
+	var indexType string
33
+	if idx.IsPrimary {
34
+		indexType = "PRIMARY KEY"
35
+	} else if idx.IsUnique {
36
+		indexType = "UNIQUE INDEX"
37
+	} else {
38
+		indexType = "INDEX"
39
+	}
40
+
41
+	columns := strings.Join(idx.Columns, ", ")
42
+	sql := fmt.Sprintf("CREATE %s %s ON %s (%s)",
43
+		indexType, idx.IndexName, idx.TableName, columns)
44
+
45
+	return im.ddlExecutor.ExecuteDDL(sql)
46
+}
47
+
48
+// DropIndex 删除索引
49
+func (im *IndexManager) DropIndex(tableName, indexName string) error {
50
+	sql := fmt.Sprintf("DROP INDEX %s ON %s", indexName, tableName)
51
+	return im.ddlExecutor.ExecuteDDL(sql)
52
+}
53
+
54
+// TableManager 添加索引相关方法
55
+func (tm *TableManager) SetIndexManager(im *IndexManager) {
56
+	// 可以在这里关联索引管理器
57
+}
58
+
59
+// CreateIndexesFromDDL 从DDL语句中解析并创建索引
60
+func (tm *TableManager) CreateIndexesFromDDL() error {
61
+	// 获取所有表
62
+	tables := tm.GetRegisteredTables()
63
+
64
+	for _, tableName := range tables {
65
+		ddl, exists := tm.GetTableDDL(tableName)
66
+		if !exists {
67
+			continue
68
+		}
69
+
70
+		// 解析DDL中的索引信息
71
+		indexes := parseIndexesFromDDL(ddl)
72
+
73
+		// 创建索引
74
+		for _, index := range indexes {
75
+			if err := tm.ddlExecutor.ExecuteDDL(index); err != nil {
76
+				return fmt.Errorf("创建索引失败: %w", err)
77
+			}
78
+		}
79
+	}
80
+
81
+	return nil
82
+}
83
+
84
+// parseIndexesFromDDL 从DDL语句中解析索引
85
+func parseIndexesFromDDL(ddl string) []string {
86
+	var indexes []string
87
+	lines := strings.Split(ddl, "\n")
88
+
89
+	for _, line := range lines {
90
+		line = strings.TrimSpace(line)
91
+		if strings.Contains(line, "INDEX") || strings.Contains(line, "PRIMARY KEY") {
92
+			// 移除末尾的逗号
93
+			line = strings.TrimSuffix(line, ",")
94
+			indexes = append(indexes, line)
95
+		}
96
+	}
97
+
98
+	return indexes
99
+}

+ 168
- 0
sqldef/table_create.go View File

@@ -0,0 +1,168 @@
1
+// table_create.go 执行建表代码
2
+package sqldef
3
+
4
+import (
5
+	"fmt"
6
+	"log"
7
+
8
+	"git.x2erp.com/qdy/go-db/sqldef/generators"
9
+	"github.com/jmoiron/sqlx"
10
+)
11
+
12
+// TableSyncer 表同步器
13
+type TableSyncer struct {
14
+	db        *sqlx.DB
15
+	dbType    string
16
+	generator generators.SQLGenerator
17
+}
18
+
19
+// NewTableSyncer 创建表同步器
20
+func NewTableSyncer(db *sqlx.DB, dbType string) (*TableSyncer, error) {
21
+	// 获取SQL生成器
22
+	generator, err := generators.GetGenerator(dbType)
23
+	if err != nil {
24
+		return nil, fmt.Errorf("获取SQL生成器失败: %w", err)
25
+	}
26
+
27
+	return &TableSyncer{
28
+		db:        db,
29
+		dbType:    dbType,
30
+		generator: generator,
31
+	}, nil
32
+}
33
+
34
+// CreateAllTables 同步所有注册的表
35
+// recreate: true - 表存在则删除重建;false - 表不存在则创建
36
+func (ts *TableSyncer) CreateAllTables(recreate bool) error {
37
+	// 确保注册表已初始化
38
+	globalRegistry.ensureInit()
39
+
40
+	// 测试数据库连接
41
+	if err := ts.db.Ping(); err != nil {
42
+		return fmt.Errorf("数据库连接失败: %w", err)
43
+	}
44
+
45
+	// 获取所有注册的表定义
46
+	tables := GetAll()
47
+
48
+	log.Printf("开始创建 %d 个表到 %s 数据库...\n", len(tables), ts.dbType)
49
+
50
+	// 处理每个表
51
+	for i, table := range tables {
52
+		tableName := table.Name
53
+
54
+		// 检查表是否存在
55
+		exists, err := ts.tableExists(tableName)
56
+		if err != nil {
57
+			return fmt.Errorf("检查表 %s 是否存在失败: %w", tableName, err)
58
+		}
59
+
60
+		if recreate {
61
+			// 如果存在则删除
62
+			if exists {
63
+				if err := ts.dropTable(tableName); err != nil {
64
+					return fmt.Errorf("删除表 %s 失败: %w", tableName, err)
65
+				}
66
+				log.Printf("[%d/%d] 表 %s 已删除\n", i+1, len(tables), tableName)
67
+			}
68
+
69
+			// 创建表
70
+			if err := ts.createTable(table); err != nil {
71
+				return fmt.Errorf("创建表 %s 失败: %w", tableName, err)
72
+			}
73
+			// 修正:Go没有三元运算符,使用if-else
74
+			var action string
75
+			if exists {
76
+				action = "重建"
77
+			} else {
78
+				action = "创建"
79
+			}
80
+			log.Printf("[%d/%d] 表 %s 已%s\n", i+1, len(tables), tableName, action)
81
+		} else {
82
+			// 只创建不存在的表
83
+			if !exists {
84
+				if err := ts.createTable(table); err != nil {
85
+					return fmt.Errorf("创建表 %s 失败: %w", tableName, err)
86
+				}
87
+				log.Printf("[%d/%d] 表 %s 已创建\n", i+1, len(tables), tableName)
88
+			} else {
89
+				log.Printf("[%d/%d] 表 %s 已存在,跳过\n", i+1, len(tables), tableName)
90
+			}
91
+		}
92
+	}
93
+
94
+	log.Println("所有表创建完成!")
95
+	return nil
96
+}
97
+
98
+// tableExists 检查表是否存在
99
+func (ts *TableSyncer) tableExists(tableName string) (bool, error) {
100
+	sql := ts.generator.TableExistsSQL(tableName)
101
+
102
+	var exists bool
103
+	if ts.dbType == "postgresql" {
104
+		// PostgreSQL返回boolean
105
+		err := ts.db.QueryRow(sql).Scan(&exists)
106
+		return exists, err
107
+	} else {
108
+		// MySQL返回count
109
+		var count int
110
+		err := ts.db.QueryRow(sql).Scan(&count)
111
+		return count > 0, err
112
+	}
113
+}
114
+
115
+// dropTable 删除表
116
+func (ts *TableSyncer) dropTable(tableName string) error {
117
+	sql := ts.generator.DropTableSQL(tableName)
118
+	_, err := ts.db.Exec(sql)
119
+	return err
120
+}
121
+
122
+// createTable 创建表
123
+func (ts *TableSyncer) createTable(table generators.TableDDL) error {
124
+	// 使用生成器生成适合当前数据库的SQL
125
+	sql := ts.generator.GenerateCreateTable(table)
126
+
127
+	// 打印SQL用于调试
128
+	log.Printf("执行SQL: %s", sql)
129
+
130
+	_, err := ts.db.Exec(sql)
131
+	if err != nil {
132
+		// 记录详细的错误信息
133
+		log.Printf("执行SQL失败: %v", err)
134
+		return fmt.Errorf("执行SQL失败: %w", err)
135
+	}
136
+	return nil
137
+}
138
+
139
+// CreateTables 创建所有不存在的表
140
+func (ts *TableSyncer) CreateTables() error {
141
+	return ts.CreateAllTables(false)
142
+}
143
+
144
+// RecreateTables 重建所有表
145
+func (ts *TableSyncer) RecreateTables() error {
146
+	return ts.CreateAllTables(true)
147
+}
148
+
149
+// 包级便捷函数
150
+
151
+// SyncTables 同步所有表
152
+func SyncTables(db *sqlx.DB, dbType string, recreate bool) error {
153
+	syncer, err := NewTableSyncer(db, dbType)
154
+	if err != nil {
155
+		return err
156
+	}
157
+	return syncer.CreateAllTables(recreate)
158
+}
159
+
160
+// CreateTables 创建所有不存在的表
161
+func CreateTables(db *sqlx.DB, dbType string) error {
162
+	return SyncTables(db, dbType, false)
163
+}
164
+
165
+// RecreateTables 重建所有表
166
+func RecreateTables(db *sqlx.DB, dbType string) error {
167
+	return SyncTables(db, dbType, true)
168
+}

+ 406
- 0
sqldef/table_definition.go View File

@@ -0,0 +1,406 @@
1
+package sqldef
2
+
3
+import (
4
+	"fmt"
5
+	"strings"
6
+	"sync"
7
+
8
+	"git.x2erp.com/qdy/go-db/sqldef/generators"
9
+)
10
+
11
+// 数据类型常量
12
+const (
13
+	TypeVarchar   = "VARCHAR"
14
+	TypeChar      = "CHAR"
15
+	TypeText      = "TEXT"
16
+	TypeInt       = "INT"
17
+	TypeBigInt    = "BIGINT"
18
+	TypeTinyInt   = "TINYINT"
19
+	TypeDecimal   = "DECIMAL"
20
+	TypeFloat     = "FLOAT"
21
+	TypeDouble    = "DOUBLE"
22
+	TypeBool      = "BOOL"
23
+	TypeDateTime  = "DATETIME"
24
+	TypeTimestamp = "TIMESTAMP"
25
+	TypeDate      = "DATE"
26
+	TypeTime      = "TIME"
27
+	TypeBlob      = "BLOB"
28
+	TypeJson      = "JSON"
29
+)
30
+
31
+// Column 列定义
32
+type Column struct {
33
+	name         string
34
+	sqlType      string
35
+	comment      string
36
+	options      []string
37
+	defaultValue string // 新增:默认值
38
+}
39
+
40
+// ColumnBuilder 列构建器
41
+type ColumnBuilder struct {
42
+	table  *TableBuilder
43
+	column Column
44
+}
45
+
46
+// TableBuilder 表构建器
47
+type TableBuilder struct {
48
+	name    string
49
+	comment string
50
+	columns []Column
51
+	indexes []string
52
+}
53
+
54
+// Registry 注册表(懒加载)
55
+type Registry struct {
56
+	tables map[string]generators.TableDDL
57
+	mu     sync.RWMutex
58
+	once   sync.Once
59
+	regFns []func(*Registry)
60
+}
61
+
62
+var globalRegistry = &Registry{
63
+	tables: make(map[string]generators.TableDDL),
64
+}
65
+
66
+// AddRegistration 添加注册函数
67
+func AddRegistration(fn func(*Registry)) {
68
+	globalRegistry.regFns = append(globalRegistry.regFns, fn)
69
+}
70
+
71
+func (r *Registry) lazyInit() {
72
+	r.mu.Lock()
73
+	defer r.mu.Unlock()
74
+
75
+	for _, fn := range r.regFns {
76
+		fn(r)
77
+	}
78
+}
79
+
80
+func (r *Registry) ensureInit() {
81
+	r.once.Do(r.lazyInit)
82
+}
83
+
84
+// RegisterTable 注册表
85
+func (r *Registry) RegisterTable(table generators.TableDDL) {
86
+	if table.Name != "" {
87
+		r.tables[table.Name] = table
88
+	}
89
+}
90
+
91
+// GetAll 获取所有表定义
92
+func GetAll() []generators.TableDDL {
93
+	globalRegistry.ensureInit()
94
+
95
+	globalRegistry.mu.RLock()
96
+	defer globalRegistry.mu.RUnlock()
97
+
98
+	result := make([]generators.TableDDL, 0, len(globalRegistry.tables))
99
+	for _, table := range globalRegistry.tables {
100
+		result = append(result, table)
101
+	}
102
+	return result
103
+}
104
+
105
+// Get 获取表定义
106
+func Get(tableName string) (generators.TableDDL, bool) {
107
+	globalRegistry.ensureInit()
108
+
109
+	globalRegistry.mu.RLock()
110
+	defer globalRegistry.mu.RUnlock()
111
+
112
+	table, exists := globalRegistry.tables[tableName]
113
+	return table, exists
114
+}
115
+
116
+// ================== TableBuilder 表构建器 ==================
117
+
118
+// NewTable 创建新表
119
+func NewTable(name string, comment ...string) *TableBuilder {
120
+	tb := &TableBuilder{name: name}
121
+	if len(comment) > 0 {
122
+		tb.comment = comment[0]
123
+	}
124
+	return tb
125
+}
126
+
127
+// ================== 列定义方法 ==================
128
+
129
+// 默认64位长度
130
+func (t *TableBuilder) ID(name string, length ...int) *ColumnBuilder {
131
+	size := 64
132
+	if len(length) > 0 {
133
+		size = length[0]
134
+	}
135
+	return t.column(name, TypeVarchar, size).PrimaryKey()
136
+}
137
+
138
+func (t *TableBuilder) String(name string, length int) *ColumnBuilder {
139
+	return t.column(name, TypeVarchar, length)
140
+}
141
+
142
+func (t *TableBuilder) Char(name string, length int) *ColumnBuilder {
143
+	return t.column(name, TypeChar, length)
144
+}
145
+
146
+func (t *TableBuilder) Text(name string) *ColumnBuilder {
147
+	return t.column(name, TypeText, 0)
148
+}
149
+
150
+func (t *TableBuilder) Int(name string) *ColumnBuilder {
151
+	return t.column(name, TypeInt, 0)
152
+}
153
+
154
+func (t *TableBuilder) TinyInt(name string) *ColumnBuilder {
155
+	return t.column(name, TypeTinyInt, 0)
156
+}
157
+
158
+func (t *TableBuilder) BigInt(name string) *ColumnBuilder {
159
+	return t.column(name, TypeBigInt, 0)
160
+}
161
+
162
+func (t *TableBuilder) Bool(name string) *ColumnBuilder {
163
+	return t.column(name, TypeBool, 0)
164
+}
165
+
166
+func (t *TableBuilder) JSON(name string) *ColumnBuilder {
167
+	return t.column(name, TypeJson, 0)
168
+}
169
+
170
+func (t *TableBuilder) Decimal(name string, precision, scale int) *ColumnBuilder {
171
+	col := &ColumnBuilder{
172
+		table: t,
173
+		column: Column{
174
+			name:    name,
175
+			sqlType: fmt.Sprintf("%s(%d,%d)", TypeDecimal, precision, scale),
176
+		},
177
+	}
178
+	return col
179
+}
180
+
181
+func (t *TableBuilder) Float(name string) *ColumnBuilder {
182
+	return t.column(name, TypeFloat, 0)
183
+}
184
+
185
+func (t *TableBuilder) Double(name string) *ColumnBuilder {
186
+	return t.column(name, TypeDouble, 0)
187
+}
188
+
189
+func (t *TableBuilder) DateTime(name string) *ColumnBuilder {
190
+	return t.column(name, TypeDateTime, 0)
191
+}
192
+
193
+func (t *TableBuilder) Timestamp(name string) *ColumnBuilder {
194
+	return t.column(name, TypeTimestamp, 0)
195
+}
196
+
197
+func (t *TableBuilder) Time(name string) *ColumnBuilder {
198
+	return t.column(name, TypeTime, 0)
199
+}
200
+
201
+func (t *TableBuilder) Date(name string) *ColumnBuilder {
202
+	return t.column(name, TypeDate, 0)
203
+}
204
+
205
+func (t *TableBuilder) Blob(name string) *ColumnBuilder {
206
+	return t.column(name, TypeBlob, 0)
207
+}
208
+
209
+// 私有辅助方法
210
+func (t *TableBuilder) column(name string, dataType string, length int) *ColumnBuilder {
211
+	sqlType := dataType
212
+	if length > 0 {
213
+		sqlType = fmt.Sprintf("%s(%d)", dataType, length)
214
+	}
215
+
216
+	col := &ColumnBuilder{
217
+		table: t,
218
+		column: Column{
219
+			name:    name,
220
+			sqlType: sqlType,
221
+		},
222
+	}
223
+	return col
224
+}
225
+
226
+// ================== 列构建器方法 ==================
227
+
228
+func (c *ColumnBuilder) NotNull() *ColumnBuilder {
229
+	c.column.options = append(c.column.options, "NOT NULL")
230
+	return c
231
+}
232
+
233
+func (c *ColumnBuilder) Default(value string) *ColumnBuilder {
234
+	c.column.defaultValue = value
235
+	return c
236
+}
237
+
238
+func (c *ColumnBuilder) PrimaryKey() *ColumnBuilder {
239
+	c.column.options = append(c.column.options, "PRIMARY KEY")
240
+	return c
241
+}
242
+
243
+func (c *ColumnBuilder) AutoIncrement() *ColumnBuilder {
244
+	c.column.options = append(c.column.options, "AUTO_INCREMENT")
245
+	return c
246
+}
247
+
248
+func (c *ColumnBuilder) Unique() *ColumnBuilder {
249
+	c.column.options = append(c.column.options, "UNIQUE")
250
+	return c
251
+}
252
+
253
+func (c *ColumnBuilder) Comment(comment string) *ColumnBuilder {
254
+	c.column.comment = comment
255
+	return c
256
+}
257
+
258
+// End 结束列定义,返回TableBuilder继续定义其他列
259
+func (c *ColumnBuilder) End() *TableBuilder {
260
+	c.table.columns = append(c.table.columns, c.column)
261
+	return c.table
262
+}
263
+
264
+// ================== 索引方法 ==================
265
+
266
+func (t *TableBuilder) AddIndex(name string, columns ...string) *TableBuilder {
267
+	idx := fmt.Sprintf("INDEX %s (%s)", name, strings.Join(columns, ", "))
268
+	t.indexes = append(t.indexes, idx)
269
+	return t
270
+}
271
+
272
+func (t *TableBuilder) AddUniqueIndex(name string, columns ...string) *TableBuilder {
273
+	idx := fmt.Sprintf("UNIQUE INDEX %s (%s)", name, strings.Join(columns, ", "))
274
+	t.indexes = append(t.indexes, idx)
275
+	return t
276
+}
277
+
278
+// Build 构建表定义,包含完整的Schema信息
279
+func (t *TableBuilder) Build() generators.TableDDL {
280
+	// 构建列Schema
281
+	columns := make([]generators.ColumnSchema, 0, len(t.columns))
282
+
283
+	for _, col := range t.columns {
284
+		// 解析列类型
285
+		colType, length, precision, scale := parseColumnType(col.sqlType)
286
+
287
+		// 提取默认值
288
+		var defaultValue string
289
+		var options []string
290
+		for _, opt := range col.options {
291
+			if strings.HasPrefix(opt, "DEFAULT ") {
292
+				defaultValue = strings.TrimPrefix(opt, "DEFAULT ")
293
+			} else {
294
+				options = append(options, opt)
295
+			}
296
+		}
297
+
298
+		columns = append(columns, generators.ColumnSchema{
299
+			Name:      col.name,
300
+			Type:      colType,
301
+			Length:    length,
302
+			Precision: precision,
303
+			Scale:     scale,
304
+			Comment:   col.comment,
305
+			Options:   options,
306
+			Default:   defaultValue,
307
+		})
308
+	}
309
+
310
+	// 构建索引Schema
311
+	indexes := make([]generators.IndexSchema, 0, len(t.indexes))
312
+
313
+	for _, idx := range t.indexes {
314
+		// 解析索引字符串,例如: "INDEX idx_name (col1, col2)" 或 "UNIQUE INDEX idx_name (col1, col2)"
315
+		indexName, isUnique, columns := parseIndex(idx)
316
+		if indexName != "" && len(columns) > 0 {
317
+			indexes = append(indexes, generators.IndexSchema{
318
+				Name:    indexName,
319
+				Columns: columns,
320
+				Unique:  isUnique,
321
+			})
322
+		}
323
+	}
324
+
325
+	return generators.TableDDL{
326
+		Name:    t.name,
327
+		Comment: t.comment,
328
+		Schema: &generators.TableSchema{
329
+			Name:    t.name,
330
+			Comment: t.comment,
331
+			Columns: columns,
332
+			Indexes: indexes,
333
+		},
334
+	}
335
+}
336
+
337
+// parseColumnType 解析列类型字符串
338
+func parseColumnType(sqlType string) (colType string, length, precision, scale int) {
339
+	sqlType = strings.ToUpper(strings.TrimSpace(sqlType))
340
+
341
+	// 处理带括号的类型,如 VARCHAR(255), DECIMAL(10,2)
342
+	if strings.Contains(sqlType, "(") {
343
+		openParen := strings.Index(sqlType, "(")
344
+		closeParen := strings.Index(sqlType, ")")
345
+
346
+		colType = sqlType[:openParen]
347
+		params := strings.TrimSpace(sqlType[openParen+1 : closeParen])
348
+
349
+		switch colType {
350
+		case "DECIMAL":
351
+			if strings.Contains(params, ",") {
352
+				parts := strings.Split(params, ",")
353
+				if len(parts) == 2 {
354
+					fmt.Sscanf(parts[0], "%d", &precision)
355
+					fmt.Sscanf(parts[1], "%d", &scale)
356
+				}
357
+			}
358
+		case "VARCHAR", "CHAR":
359
+			fmt.Sscanf(params, "%d", &length)
360
+		default:
361
+			fmt.Sscanf(params, "%d", &length)
362
+		}
363
+	} else {
364
+		colType = sqlType
365
+	}
366
+
367
+	return
368
+}
369
+
370
+// parseIndex 解析索引字符串
371
+func parseIndex(indexStr string) (name string, isUnique bool, columns []string) {
372
+	indexStr = strings.TrimSpace(indexStr)
373
+
374
+	// 检查是否是唯一索引
375
+	if strings.HasPrefix(indexStr, "UNIQUE INDEX") {
376
+		isUnique = true
377
+		indexStr = strings.TrimPrefix(indexStr, "UNIQUE INDEX ")
378
+	} else if strings.HasPrefix(indexStr, "INDEX") {
379
+		indexStr = strings.TrimPrefix(indexStr, "INDEX ")
380
+	} else {
381
+		return "", false, nil
382
+	}
383
+
384
+	// 分割索引名和列
385
+	parts := strings.Split(indexStr, " ")
386
+	if len(parts) < 2 {
387
+		return "", false, nil
388
+	}
389
+
390
+	name = parts[0]
391
+
392
+	// 提取列,如 (col1, col2)
393
+	colsPart := strings.Join(parts[1:], " ")
394
+	colsPart = strings.Trim(colsPart, "()")
395
+	columns = strings.Split(colsPart, ", ")
396
+
397
+	return
398
+}
399
+
400
+// Register 快捷方法:直接注册表
401
+func (t *TableBuilder) Register() {
402
+	table := t.Build()
403
+	AddRegistration(func(r *Registry) {
404
+		r.RegisterTable(table)
405
+	})
406
+}

+ 250
- 0
sqldef/table_manager.go View File

@@ -0,0 +1,250 @@
1
+// table_manager.go
2
+package sqldef
3
+
4
+import (
5
+	"fmt"
6
+	"sync"
7
+
8
+	"git.x2erp.com/qdy/go-db/sqldef/generators"
9
+)
10
+
11
+// TableManager 表管理器
12
+type TableManager struct {
13
+	ddlExecutor DDLExecutor
14
+	initialized bool
15
+	mu          sync.RWMutex
16
+}
17
+
18
+// DDLExecutor DDL执行器接口
19
+type DDLExecutor interface {
20
+	ExecuteDDL(ddl string) error
21
+	TableExists(tableName string) (bool, error)
22
+}
23
+
24
+// TableManagerFactory 表管理器工厂
25
+type TableManagerFactory struct {
26
+	instance *TableManager
27
+	mu       sync.RWMutex
28
+}
29
+
30
+var factory = &TableManagerFactory{}
31
+
32
+// GetTableManager 获取或创建表管理器实例(懒加载)
33
+func GetTableManager() *TableManager {
34
+	return factory.GetInstance()
35
+}
36
+
37
+// GetTableManagerWithExecutor 使用指定的执行器获取表管理器
38
+func GetTableManagerWithExecutor(executor DDLExecutor) *TableManager {
39
+	return factory.GetInstanceWithExecutor(executor)
40
+}
41
+
42
+// GetInstance 获取单例实例(懒加载)
43
+func (f *TableManagerFactory) GetInstance() *TableManager {
44
+	f.mu.RLock()
45
+	if f.instance != nil && f.instance.initialized {
46
+		f.mu.RUnlock()
47
+		return f.instance
48
+	}
49
+	f.mu.RUnlock()
50
+
51
+	f.mu.Lock()
52
+	defer f.mu.Unlock()
53
+
54
+	// 双重检查
55
+	if f.instance != nil && f.instance.initialized {
56
+		return f.instance
57
+	}
58
+
59
+	// 创建新实例(但未初始化执行器)
60
+	f.instance = &TableManager{}
61
+	return f.instance
62
+}
63
+
64
+// GetInstanceWithExecutor 使用执行器获取实例
65
+func (f *TableManagerFactory) GetInstanceWithExecutor(executor DDLExecutor) *TableManager {
66
+	f.mu.Lock()
67
+	defer f.mu.Unlock()
68
+
69
+	// 如果已存在实例且有执行器,则直接返回
70
+	if f.instance != nil && f.instance.initialized && f.instance.ddlExecutor != nil {
71
+		return f.instance
72
+	}
73
+
74
+	// 创建或更新实例
75
+	if f.instance == nil {
76
+		f.instance = &TableManager{
77
+			ddlExecutor: executor,
78
+			initialized: true,
79
+		}
80
+	} else {
81
+		f.instance.ddlExecutor = executor
82
+		f.instance.initialized = true
83
+	}
84
+
85
+	return f.instance
86
+}
87
+
88
+// SyncTables 同步所有注册的表结构(懒加载执行器)
89
+// recreate: true - 如果表存在则删除重建;false - 如果表不存在则创建
90
+func (tm *TableManager) SyncTables(recreate bool) error {
91
+	// 确保注册表已初始化
92
+	globalRegistry.ensureInit()
93
+
94
+	// 检查执行器是否初始化
95
+	if !tm.isExecutorInitialized() {
96
+		return fmt.Errorf("DDL执行器未初始化,请先调用SetExecutor方法")
97
+	}
98
+
99
+	// 获取所有注册的表定义
100
+	tables := GetAll()
101
+
102
+	// 按顺序处理所有表
103
+	for _, table := range tables {
104
+		err := tm.syncTable(table, recreate)
105
+		if err != nil {
106
+			return fmt.Errorf("同步表 %s 失败: %w", table.Name, err)
107
+		}
108
+	}
109
+
110
+	return nil
111
+}
112
+
113
+// SetExecutor 设置DDL执行器
114
+func (tm *TableManager) SetExecutor(executor DDLExecutor) {
115
+	tm.mu.Lock()
116
+	defer tm.mu.Unlock()
117
+
118
+	tm.ddlExecutor = executor
119
+	tm.initialized = true
120
+}
121
+
122
+// isExecutorInitialized 检查执行器是否初始化
123
+func (tm *TableManager) isExecutorInitialized() bool {
124
+	tm.mu.RLock()
125
+	defer tm.mu.RUnlock()
126
+
127
+	return tm.initialized && tm.ddlExecutor != nil
128
+}
129
+
130
+// syncTable 同步单个表
131
+func (tm *TableManager) syncTable(table generators.TableDDL, recreate bool) error {
132
+	// 检查表是否存在
133
+	exists, err := tm.ddlExecutor.TableExists(table.Name)
134
+	if err != nil {
135
+		return fmt.Errorf("检查表 %s 是否存在失败: %w", table.Name, err)
136
+	}
137
+
138
+	if recreate {
139
+		// 如果存在就删除重建
140
+		if exists {
141
+			// 删除表
142
+			dropSQL := fmt.Sprintf("DROP TABLE IF EXISTS %s", table.Name)
143
+			if err := tm.ddlExecutor.ExecuteDDL(dropSQL); err != nil {
144
+				return fmt.Errorf("删除表 %s 失败: %w", table.Name, err)
145
+			}
146
+			fmt.Printf("表 %s 已删除\n", table.Name)
147
+
148
+			// 重新创建表
149
+			if err := tm.ddlExecutor.ExecuteDDL(table.SQL); err != nil {
150
+				return fmt.Errorf("创建表 %s 失败: %w", table.Name, err)
151
+			}
152
+			fmt.Printf("表 %s 已创建\n", table.Name)
153
+		} else {
154
+			// 不存在直接创建
155
+			if err := tm.ddlExecutor.ExecuteDDL(table.SQL); err != nil {
156
+				return fmt.Errorf("创建表 %s 失败: %w", table.Name, err)
157
+			}
158
+			fmt.Printf("表 %s 已创建\n", table.Name)
159
+		}
160
+	} else {
161
+		// 如果不存在就建立
162
+		if !exists {
163
+			if err := tm.ddlExecutor.ExecuteDDL(table.SQL); err != nil {
164
+				return fmt.Errorf("创建表 %s 失败: %w", table.Name, err)
165
+			}
166
+			fmt.Printf("表 %s 已创建\n", table.Name)
167
+		} else {
168
+			fmt.Printf("表 %s 已存在,跳过创建\n", table.Name)
169
+		}
170
+	}
171
+
172
+	return nil
173
+}
174
+
175
+// CreateTables 创建所有表(不存在则创建)
176
+func (tm *TableManager) CreateTables() error {
177
+	return tm.SyncTables(false)
178
+}
179
+
180
+// RecreateTables 重建所有表(存在则删除重建)
181
+func (tm *TableManager) RecreateTables() error {
182
+	return tm.SyncTables(true)
183
+}
184
+
185
+// CreateIndexes 创建索引
186
+func (tm *TableManager) CreateIndexes() error {
187
+	// 这里可以添加索引创建的逻辑
188
+	// 由于TableDDL中没有包含索引信息,所以暂时为空
189
+
190
+	// 暂时返回nil,表示索引创建成功(实际上是空的)
191
+	fmt.Println("创建索引方法(空实现)")
192
+	return nil
193
+}
194
+
195
+// CreateTablesAndIndexes 创建表及其索引(便捷方法)
196
+func (tm *TableManager) CreateTablesAndIndexes(recreateTables bool) error {
197
+	if err := tm.SyncTables(recreateTables); err != nil {
198
+		return err
199
+	}
200
+
201
+	return tm.CreateIndexes()
202
+}
203
+
204
+// GetRegisteredTables 获取所有注册的表名
205
+func (tm *TableManager) GetRegisteredTables() []string {
206
+	globalRegistry.ensureInit()
207
+
208
+	globalRegistry.mu.RLock()
209
+	defer globalRegistry.mu.RUnlock()
210
+
211
+	tables := make([]string, 0, len(globalRegistry.tables))
212
+	for name := range globalRegistry.tables {
213
+		tables = append(tables, name)
214
+	}
215
+	return tables
216
+}
217
+
218
+// GetTableDDL 获取指定表的DDL语句
219
+func (tm *TableManager) GetTableDDL(tableName string) (string, bool) {
220
+	table, exists := Get(tableName)
221
+	return table.SQL, exists
222
+}
223
+
224
+// PrintRegisteredTables 打印所有注册的表信息
225
+func (tm *TableManager) PrintRegisteredTables() {
226
+	tables := tm.GetRegisteredTables()
227
+	fmt.Printf("注册的表数量: %d\n", len(tables))
228
+	for i, name := range tables {
229
+		fmt.Printf("%d. %s\n", i+1, name)
230
+	}
231
+}
232
+
233
+// 包级便捷函数
234
+// CreateAllTables 创建所有表(便捷函数,自动获取单例)
235
+func CreateAllTables(executor DDLExecutor) error {
236
+	tm := GetTableManagerWithExecutor(executor)
237
+	return tm.CreateTables()
238
+}
239
+
240
+// RecreateAllTables 重建所有表(便捷函数,自动获取单例)
241
+func RecreateAllTables(executor DDLExecutor) error {
242
+	tm := GetTableManagerWithExecutor(executor)
243
+	return tm.RecreateTables()
244
+}
245
+
246
+// SyncAllTables 同步所有表(便捷函数,自动获取单例)
247
+func SyncAllTables(executor DDLExecutor, recreate bool) error {
248
+	tm := GetTableManagerWithExecutor(executor)
249
+	return tm.SyncTables(recreate)
250
+}

Loading…
Cancel
Save