// Package sqldef provides a fluent builder for table definitions and a
// lazily initialized registry of the resulting generators.TableDDL values.
package sqldef

import (
	"fmt"
	"sort"
	"strings"
	"sync"

	"git.x2erp.com/qdy/go-db/sqldef/generators"
)

// SQL data type names used when building column definitions.
const (
	TypeVarchar   = "VARCHAR"
	TypeChar      = "CHAR"
	TypeText      = "TEXT"
	TypeInt       = "INT"
	TypeBigInt    = "BIGINT"
	TypeTinyInt   = "TINYINT"
	TypeDecimal   = "DECIMAL"
	TypeFloat     = "FLOAT"
	TypeDouble    = "DOUBLE"
	TypeBool      = "BOOL"
	TypeDateTime  = "DATETIME"
	TypeTimestamp = "TIMESTAMP"
	TypeDate      = "DATE"
	TypeTime      = "TIME"
	TypeBlob      = "BLOB"
	TypeJson      = "JSON"
)

// Column is the internal description of a single table column.
type Column struct {
	name         string
	chineseName  string   // Chinese display name of the column.
	aliases      []string // Alternative names for the column.
	sqlType      string   // Type text, e.g. "VARCHAR(64)" or "DECIMAL(10,2)".
	comment      string
	options      []string // Constraint keywords such as "NOT NULL".
	defaultValue string   // Default value set through ColumnBuilder.Default.
}

// ColumnBuilder configures one column and hands it back to its table via End.
type ColumnBuilder struct {
	table  *TableBuilder
	column Column
}

// TableBuilder accumulates columns and indexes for one table.
type TableBuilder struct {
	name        string
	chineseName string   // Chinese display name of the table.
	aliases     []string // Alternative names for the table.
	comment     string
	columns     []Column
	indexes     []string // Index definitions in "INDEX name (cols)" text form.
}

// Registry holds table definitions. Registration functions are queued and
// executed on first read access.
type Registry struct {
	tables map[string]generators.TableDDL
	mu     sync.RWMutex
	once   sync.Once
	regFns []func(*Registry)
	// initialized is set under mu once lazyInit has taken the queued
	// functions; registrations arriving afterwards run immediately.
	initialized bool
}

var globalRegistry = &Registry{
	tables: make(map[string]generators.TableDDL),
}

// AddRegistration queues fn to run against the global registry when it is
// first read. If the registry has already been initialized, fn runs
// immediately so that late registrations are not lost. A nil fn is ignored.
func AddRegistration(fn func(*Registry)) {
	if fn == nil {
		return
	}
	r := globalRegistry

	r.mu.Lock()
	if !r.initialized {
		r.regFns = append(r.regFns, fn)
		r.mu.Unlock()
		return
	}
	r.mu.Unlock()

	// Run outside the lock: fn is expected to call RegisterTable, which
	// acquires the lock itself.
	fn(r)
}

// lazyInit runs every queued registration function exactly once.
func (r *Registry) lazyInit() {
	r.mu.Lock()
	fns := r.regFns
	r.regFns = nil
	r.initialized = true
	r.mu.Unlock()

	// The lock is released before invoking the functions so that they can
	// call RegisterTable (or AddRegistration) without deadlocking.
	for _, fn := range fns {
		fn(r)
	}
}

// ensureInit triggers lazyInit on first use; concurrent callers block until
// it has completed.
func (r *Registry) ensureInit() {
	r.once.Do(r.lazyInit)
}

// RegisterTable stores table under its name, replacing any previous entry
// with the same name. Tables with an empty name are ignored.
func (r *Registry) RegisterTable(table generators.TableDDL) {
	if table.Name == "" {
		return
	}
	r.mu.Lock()
	defer r.mu.Unlock()
	if r.tables == nil {
		// Allow a zero-value Registry to be used.
		r.tables = make(map[string]generators.TableDDL)
	}
	r.tables[table.Name] = table
}

// GetAll returns every registered table definition, sorted by table name so
// that the result is deterministic.
func GetAll() []generators.TableDDL {
	globalRegistry.ensureInit()

	globalRegistry.mu.RLock()
	result := make([]generators.TableDDL, 0, len(globalRegistry.tables))
	for _, table := range globalRegistry.tables {
		result = append(result, table)
	}
	globalRegistry.mu.RUnlock()

	sort.Slice(result, func(i, j int) bool {
		return result[i].Name < result[j].Name
	})
	return result
}

// Get returns the table definition registered under tableName and whether
// it exists.
func Get(tableName string) (generators.TableDDL, bool) {
	globalRegistry.ensureInit()

	globalRegistry.mu.RLock()
	defer globalRegistry.mu.RUnlock()
	table, exists := globalRegistry.tables[tableName]
	return table, exists
}

// ================== TableBuilder ==================

// NewTable starts a new table definition. Only the first comment argument,
// if any, is used.
func NewTable(name string, comment ...string) *TableBuilder {
	tb := &TableBuilder{name: name}
	if len(comment) > 0 {
		tb.comment = comment[0]
	}
	return tb
}

// ================== TableBuilder extensions ==================

// ChineseName sets the table's Chinese display name.
func (t *TableBuilder) ChineseName(name string) *TableBuilder {
	t.chineseName = name
	return t
}

// Aliases replaces the table's alias list. The arguments are copied so that
// later Alias calls cannot write into the caller's slice.
func (t *TableBuilder) Aliases(aliases ...string) *TableBuilder {
	t.aliases = append([]string(nil), aliases...)
	return t
}

// Alias appends a single alias to the table.
func (t *TableBuilder) Alias(alias string) *TableBuilder {
	t.aliases = append(t.aliases, alias)
	return t
}

// ================== Column definition methods ==================

// ID starts a VARCHAR primary-key column. The length defaults to 64 unless
// an explicit length is supplied.
func (t *TableBuilder) ID(name string, length ...int) *ColumnBuilder {
	size := 64
	if len(length) > 0 {
		size = length[0]
	}
	return t.column(name, TypeVarchar, size).PrimaryKey()
}

// String starts a VARCHAR(length) column.
func (t *TableBuilder) String(name string, length int) *ColumnBuilder {
	return t.column(name, TypeVarchar, length)
}

// Char starts a CHAR(length) column.
func (t *TableBuilder) Char(name string, length int) *ColumnBuilder {
	return t.column(name, TypeChar, length)
}

// Text starts a TEXT column.
func (t *TableBuilder) Text(name string) *ColumnBuilder {
	return t.column(name, TypeText, 0)
}

// Int starts an INT column.
func (t *TableBuilder) Int(name string) *ColumnBuilder {
	return t.column(name, TypeInt, 0)
}

// TinyInt starts a TINYINT column.
func (t *TableBuilder) TinyInt(name string) *ColumnBuilder {
	return t.column(name, TypeTinyInt, 0)
}

// BigInt starts a BIGINT column.
func (t *TableBuilder) BigInt(name string) *ColumnBuilder {
	return t.column(name, TypeBigInt, 0)
}

// Bool starts a BOOL column.
func (t *TableBuilder) Bool(name string) *ColumnBuilder {
	return t.column(name, TypeBool, 0)
}

// JSON starts a JSON column.
func (t *TableBuilder) JSON(name string) *ColumnBuilder {
	return t.column(name, TypeJson, 0)
}

// Decimal starts a DECIMAL(precision,scale) column.
func (t *TableBuilder) Decimal(name string, precision, scale int) *ColumnBuilder {
	return &ColumnBuilder{
		table: t,
		column: Column{
			name:    name,
			sqlType: fmt.Sprintf("%s(%d,%d)", TypeDecimal, precision, scale),
		},
	}
}

// Float starts a FLOAT column.
func (t *TableBuilder) Float(name string) *ColumnBuilder {
	return t.column(name, TypeFloat, 0)
}

// Double starts a DOUBLE column.
func (t *TableBuilder) Double(name string) *ColumnBuilder {
	return t.column(name, TypeDouble, 0)
}

// DateTime starts a DATETIME column.
func (t *TableBuilder) DateTime(name string) *ColumnBuilder {
	return t.column(name, TypeDateTime, 0)
}

// Timestamp starts a TIMESTAMP column.
func (t *TableBuilder) Timestamp(name string) *ColumnBuilder {
	return t.column(name, TypeTimestamp, 0)
}

// Time starts a TIME column.
func (t *TableBuilder) Time(name string) *ColumnBuilder {
	return t.column(name, TypeTime, 0)
}

// Date starts a DATE column.
func (t *TableBuilder) Date(name string) *ColumnBuilder {
	return t.column(name, TypeDate, 0)
}

// Blob starts a BLOB column.
func (t *TableBuilder) Blob(name string) *ColumnBuilder {
	return t.column(name, TypeBlob, 0)
}

// column is the shared constructor behind the typed helpers. A positive
// length is rendered as "TYPE(length)"; otherwise the bare type is used.
func (t *TableBuilder) column(name string, dataType string, length int) *ColumnBuilder {
	sqlType := dataType
	if length > 0 {
		sqlType = fmt.Sprintf("%s(%d)", dataType, length)
	}
	return &ColumnBuilder{
		table: t,
		column: Column{
			name:    name,
			sqlType: sqlType,
		},
	}
}

// ================== ColumnBuilder methods ==================

// NotNull adds the NOT NULL option.
func (c *ColumnBuilder) NotNull() *ColumnBuilder {
	c.column.options = append(c.column.options, "NOT NULL")
	return c
}

// Default sets the column's default value.
func (c *ColumnBuilder) Default(value string) *ColumnBuilder {
	c.column.defaultValue = value
	return c
}

// PrimaryKey adds the PRIMARY KEY option.
func (c *ColumnBuilder) PrimaryKey() *ColumnBuilder {
	c.column.options = append(c.column.options, "PRIMARY KEY")
	return c
}

// AutoIncrement adds the AUTO_INCREMENT option.
func (c *ColumnBuilder) AutoIncrement() *ColumnBuilder {
	c.column.options = append(c.column.options, "AUTO_INCREMENT")
	return c
}

// Unique adds the UNIQUE option.
func (c *ColumnBuilder) Unique() *ColumnBuilder {
	c.column.options = append(c.column.options, "UNIQUE")
	return c
}

// Comment sets the column comment.
func (c *ColumnBuilder) Comment(comment string) *ColumnBuilder {
	c.column.comment = comment
	return c
}

// ChineseName sets the column's Chinese display name.
func (c *ColumnBuilder) ChineseName(name string) *ColumnBuilder {
	c.column.chineseName = name
	return c
}

// Aliases replaces the column's alias list. The arguments are copied so that
// later Alias calls cannot write into the caller's slice.
func (c *ColumnBuilder) Aliases(aliases ...string) *ColumnBuilder {
	c.column.aliases = append([]string(nil), aliases...)
	return c
}

// Alias appends a single alias to the column.
func (c *ColumnBuilder) Alias(alias string) *ColumnBuilder {
	c.column.aliases = append(c.column.aliases, alias)
	return c
}

// End finishes the column, appends it to the owning table and returns the
// TableBuilder so further columns can be defined.
func (c *ColumnBuilder) End() *TableBuilder {
	c.table.columns = append(c.table.columns, c.column)
	return c.table
}

// ================== Index methods ==================

// AddIndex records a non-unique index over the given columns.
func (t *TableBuilder) AddIndex(name string, columns ...string) *TableBuilder {
	idx := fmt.Sprintf("INDEX %s (%s)", name, strings.Join(columns, ", "))
	t.indexes = append(t.indexes, idx)
	return t
}

// AddUniqueIndex records a unique index over the given columns.
func (t *TableBuilder) AddUniqueIndex(name string, columns ...string) *TableBuilder {
	idx := fmt.Sprintf("UNIQUE INDEX %s (%s)", name, strings.Join(columns, ", "))
	t.indexes = append(t.indexes, idx)
	return t
}

// Build converts the accumulated columns and indexes into a
// generators.TableDDL carrying the complete schema information.
func (t *TableBuilder) Build() generators.TableDDL {
	columns := make([]generators.ColumnSchema, 0, len(t.columns))
	for _, col := range t.columns {
		colType, length, precision, scale := parseColumnType(col.sqlType)

		// Start from the value set through ColumnBuilder.Default. An option
		// of the form "DEFAULT <value>" is still honoured and is removed
		// from the option list.
		defaultValue := col.defaultValue
		var options []string
		for _, opt := range col.options {
			if strings.HasPrefix(opt, "DEFAULT ") {
				defaultValue = strings.TrimPrefix(opt, "DEFAULT ")
				continue
			}
			options = append(options, opt)
		}

		columns = append(columns, generators.ColumnSchema{
			Name:        col.name,
			ChineseName: col.chineseName,
			Aliases:     col.aliases,
			Type:        colType,
			Length:      length,
			Precision:   precision,
			Scale:       scale,
			Comment:     col.comment,
			Options:     options,
			Default:     defaultValue,
		})
	}

	indexes := make([]generators.IndexSchema, 0, len(t.indexes))
	for _, idx := range t.indexes {
		// Index text looks like "INDEX idx_name (col1, col2)" or
		// "UNIQUE INDEX idx_name (col1, col2)".
		indexName, isUnique, indexColumns := parseIndex(idx)
		if indexName == "" || len(indexColumns) == 0 {
			continue
		}
		indexes = append(indexes, generators.IndexSchema{
			Name:    indexName,
			Columns: indexColumns,
			Unique:  isUnique,
		})
	}

	return generators.TableDDL{
		Name:    t.name,
		Comment: t.comment,
		Schema: &generators.TableSchema{
			Name:        t.name,
			ChineseName: t.chineseName,
			Aliases:     t.aliases,
			Comment:     t.comment,
			Columns:     columns,
			Indexes:     indexes,
		},
	}
}

func parseColumnType(sqlType string) (colType string, length, precision, scale int) { sqlType = strings.ToUpper(strings.TrimSpace(sqlType)) // 处理带括号的类型,如 VARCHAR(255), DECIMAL(10,2) if strings.Contains(sqlType, "(") { openParen := strings.Index(sqlType, "(") closeParen := strings.Index(sqlType, ")") colType = sqlType[:openParen] params := strings.TrimSpace(sqlType[openParen+1 : closeParen]) switch colType { case "DECIMAL": if strings.Contains(params, ",") { parts := strings.Split(params, ",") if len(parts) == 2 { fmt.Sscanf(parts[0], "%d", &precision) fmt.Sscanf(parts[1], "%d", &scale) } } case "VARCHAR", "CHAR": fmt.Sscanf(params, "%d", &length) default: fmt.Sscanf(params, "%d", &length) } } else { colType = sqlType } return } // parseIndex 解析索引字符串 func parseIndex(indexStr string) (name string, isUnique bool, columns []string) { indexStr = strings.TrimSpace(indexStr) // 检查是否是唯一索引 if strings.HasPrefix(indexStr, "UNIQUE INDEX") { isUnique = true indexStr = strings.TrimPrefix(indexStr, "UNIQUE INDEX ") } else if strings.HasPrefix(indexStr, "INDEX") { indexStr = strings.TrimPrefix(indexStr, "INDEX ") } else { return "", false, nil } // 分割索引名和列 parts := strings.Split(indexStr, " ") if len(parts) < 2 { return "", false, nil } name = parts[0] // 提取列,如 (col1, col2) colsPart := strings.Join(parts[1:], " ") colsPart = strings.Trim(colsPart, "()") columns = strings.Split(colsPart, ", ") return } // Register 快捷方法:直接注册表 func (t *TableBuilder) Register() { table := t.Build() AddRegistration(func(r *Registry) { r.RegisterTable(table) }) }