| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406 |
- package sqldef
-
import (
	"fmt"
	"sort"
	"strconv"
	"strings"
	"sync"

	"git.x2erp.com/qdy/go-db/sqldef/generators"
)
-
// SQL column type names used by the TableBuilder column helpers.
const (
	// Character/string types.
	TypeVarchar = "VARCHAR"
	TypeChar    = "CHAR"
	TypeText    = "TEXT"

	// Integer types.
	TypeInt     = "INT"
	TypeBigInt  = "BIGINT"
	TypeTinyInt = "TINYINT"

	// Fractional numeric types.
	TypeDecimal = "DECIMAL"
	TypeFloat   = "FLOAT"
	TypeDouble  = "DOUBLE"

	// Boolean.
	TypeBool = "BOOL"

	// Date/time types.
	TypeDateTime  = "DATETIME"
	TypeTimestamp = "TIMESTAMP"
	TypeDate      = "DATE"
	TypeTime      = "TIME"

	// Binary and document types.
	TypeBlob = "BLOB"
	TypeJson = "JSON"
)
-
// Column is a single column definition as accumulated by a ColumnBuilder.
type Column struct {
	// name is the column name, sqlType the full SQL type spelling
	// (e.g. "VARCHAR(64)"), comment the column comment.
	name, sqlType, comment string
	// options holds constraint keywords such as "NOT NULL" or "PRIMARY KEY".
	options []string
	// defaultValue is the value recorded by ColumnBuilder.Default.
	defaultValue string
}
-
- // ColumnBuilder 列构建器
- type ColumnBuilder struct {
- table *TableBuilder
- column Column
- }
-
- // TableBuilder 表构建器
- type TableBuilder struct {
- name string
- comment string
- columns []Column
- indexes []string
- }
-
- // Registry 注册表(懒加载)
- type Registry struct {
- tables map[string]generators.TableDDL
- mu sync.RWMutex
- once sync.Once
- regFns []func(*Registry)
- }
-
- var globalRegistry = &Registry{
- tables: make(map[string]generators.TableDDL),
- }
-
- // AddRegistration 添加注册函数
- func AddRegistration(fn func(*Registry)) {
- globalRegistry.regFns = append(globalRegistry.regFns, fn)
- }
-
- func (r *Registry) lazyInit() {
- r.mu.Lock()
- defer r.mu.Unlock()
-
- for _, fn := range r.regFns {
- fn(r)
- }
- }
-
- func (r *Registry) ensureInit() {
- r.once.Do(r.lazyInit)
- }
-
- // RegisterTable 注册表
- func (r *Registry) RegisterTable(table generators.TableDDL) {
- if table.Name != "" {
- r.tables[table.Name] = table
- }
- }
-
- // GetAll 获取所有表定义
- func GetAll() []generators.TableDDL {
- globalRegistry.ensureInit()
-
- globalRegistry.mu.RLock()
- defer globalRegistry.mu.RUnlock()
-
- result := make([]generators.TableDDL, 0, len(globalRegistry.tables))
- for _, table := range globalRegistry.tables {
- result = append(result, table)
- }
- return result
- }
-
- // Get 获取表定义
- func Get(tableName string) (generators.TableDDL, bool) {
- globalRegistry.ensureInit()
-
- globalRegistry.mu.RLock()
- defer globalRegistry.mu.RUnlock()
-
- table, exists := globalRegistry.tables[tableName]
- return table, exists
- }
-
- // ================== TableBuilder 表构建器 ==================
-
- // NewTable 创建新表
- func NewTable(name string, comment ...string) *TableBuilder {
- tb := &TableBuilder{name: name}
- if len(comment) > 0 {
- tb.comment = comment[0]
- }
- return tb
- }
-
- // ================== 列定义方法 ==================
-
- // 默认64位长度
- func (t *TableBuilder) ID(name string, length ...int) *ColumnBuilder {
- size := 64
- if len(length) > 0 {
- size = length[0]
- }
- return t.column(name, TypeVarchar, size).PrimaryKey()
- }
-
- func (t *TableBuilder) String(name string, length int) *ColumnBuilder {
- return t.column(name, TypeVarchar, length)
- }
-
- func (t *TableBuilder) Char(name string, length int) *ColumnBuilder {
- return t.column(name, TypeChar, length)
- }
-
- func (t *TableBuilder) Text(name string) *ColumnBuilder {
- return t.column(name, TypeText, 0)
- }
-
- func (t *TableBuilder) Int(name string) *ColumnBuilder {
- return t.column(name, TypeInt, 0)
- }
-
- func (t *TableBuilder) TinyInt(name string) *ColumnBuilder {
- return t.column(name, TypeTinyInt, 0)
- }
-
- func (t *TableBuilder) BigInt(name string) *ColumnBuilder {
- return t.column(name, TypeBigInt, 0)
- }
-
- func (t *TableBuilder) Bool(name string) *ColumnBuilder {
- return t.column(name, TypeBool, 0)
- }
-
- func (t *TableBuilder) JSON(name string) *ColumnBuilder {
- return t.column(name, TypeJson, 0)
- }
-
- func (t *TableBuilder) Decimal(name string, precision, scale int) *ColumnBuilder {
- col := &ColumnBuilder{
- table: t,
- column: Column{
- name: name,
- sqlType: fmt.Sprintf("%s(%d,%d)", TypeDecimal, precision, scale),
- },
- }
- return col
- }
-
- func (t *TableBuilder) Float(name string) *ColumnBuilder {
- return t.column(name, TypeFloat, 0)
- }
-
- func (t *TableBuilder) Double(name string) *ColumnBuilder {
- return t.column(name, TypeDouble, 0)
- }
-
- func (t *TableBuilder) DateTime(name string) *ColumnBuilder {
- return t.column(name, TypeDateTime, 0)
- }
-
- func (t *TableBuilder) Timestamp(name string) *ColumnBuilder {
- return t.column(name, TypeTimestamp, 0)
- }
-
- func (t *TableBuilder) Time(name string) *ColumnBuilder {
- return t.column(name, TypeTime, 0)
- }
-
- func (t *TableBuilder) Date(name string) *ColumnBuilder {
- return t.column(name, TypeDate, 0)
- }
-
- func (t *TableBuilder) Blob(name string) *ColumnBuilder {
- return t.column(name, TypeBlob, 0)
- }
-
- // 私有辅助方法
- func (t *TableBuilder) column(name string, dataType string, length int) *ColumnBuilder {
- sqlType := dataType
- if length > 0 {
- sqlType = fmt.Sprintf("%s(%d)", dataType, length)
- }
-
- col := &ColumnBuilder{
- table: t,
- column: Column{
- name: name,
- sqlType: sqlType,
- },
- }
- return col
- }
-
- // ================== 列构建器方法 ==================
-
- func (c *ColumnBuilder) NotNull() *ColumnBuilder {
- c.column.options = append(c.column.options, "NOT NULL")
- return c
- }
-
- func (c *ColumnBuilder) Default(value string) *ColumnBuilder {
- c.column.defaultValue = value
- return c
- }
-
- func (c *ColumnBuilder) PrimaryKey() *ColumnBuilder {
- c.column.options = append(c.column.options, "PRIMARY KEY")
- return c
- }
-
- func (c *ColumnBuilder) AutoIncrement() *ColumnBuilder {
- c.column.options = append(c.column.options, "AUTO_INCREMENT")
- return c
- }
-
- func (c *ColumnBuilder) Unique() *ColumnBuilder {
- c.column.options = append(c.column.options, "UNIQUE")
- return c
- }
-
- func (c *ColumnBuilder) Comment(comment string) *ColumnBuilder {
- c.column.comment = comment
- return c
- }
-
- // End 结束列定义,返回TableBuilder继续定义其他列
- func (c *ColumnBuilder) End() *TableBuilder {
- c.table.columns = append(c.table.columns, c.column)
- return c.table
- }
-
- // ================== 索引方法 ==================
-
- func (t *TableBuilder) AddIndex(name string, columns ...string) *TableBuilder {
- idx := fmt.Sprintf("INDEX %s (%s)", name, strings.Join(columns, ", "))
- t.indexes = append(t.indexes, idx)
- return t
- }
-
- func (t *TableBuilder) AddUniqueIndex(name string, columns ...string) *TableBuilder {
- idx := fmt.Sprintf("UNIQUE INDEX %s (%s)", name, strings.Join(columns, ", "))
- t.indexes = append(t.indexes, idx)
- return t
- }
-
- // Build 构建表定义,包含完整的Schema信息
- func (t *TableBuilder) Build() generators.TableDDL {
- // 构建列Schema
- columns := make([]generators.ColumnSchema, 0, len(t.columns))
-
- for _, col := range t.columns {
- // 解析列类型
- colType, length, precision, scale := parseColumnType(col.sqlType)
-
- // 提取默认值
- var defaultValue string
- var options []string
- for _, opt := range col.options {
- if strings.HasPrefix(opt, "DEFAULT ") {
- defaultValue = strings.TrimPrefix(opt, "DEFAULT ")
- } else {
- options = append(options, opt)
- }
- }
-
- columns = append(columns, generators.ColumnSchema{
- Name: col.name,
- Type: colType,
- Length: length,
- Precision: precision,
- Scale: scale,
- Comment: col.comment,
- Options: options,
- Default: defaultValue,
- })
- }
-
- // 构建索引Schema
- indexes := make([]generators.IndexSchema, 0, len(t.indexes))
-
- for _, idx := range t.indexes {
- // 解析索引字符串,例如: "INDEX idx_name (col1, col2)" 或 "UNIQUE INDEX idx_name (col1, col2)"
- indexName, isUnique, columns := parseIndex(idx)
- if indexName != "" && len(columns) > 0 {
- indexes = append(indexes, generators.IndexSchema{
- Name: indexName,
- Columns: columns,
- Unique: isUnique,
- })
- }
- }
-
- return generators.TableDDL{
- Name: t.name,
- Comment: t.comment,
- Schema: &generators.TableSchema{
- Name: t.name,
- Comment: t.comment,
- Columns: columns,
- Indexes: indexes,
- },
- }
- }
-
// parseColumnType splits a SQL type spelling such as "VARCHAR(255)" or
// "DECIMAL(10,2)" into its upper-cased base type and numeric parameters.
//
// For DECIMAL the parameters are returned as precision and scale;
// "DECIMAL(8)" yields precision 8 and scale 0. For every other type the first
// parameter is returned as length. A parameter that is not an integer is
// reported as 0. A missing closing parenthesis is tolerated: the parameters
// run to the end of the string instead of causing a slice-bounds panic.
func parseColumnType(sqlType string) (colType string, length, precision, scale int) {
	sqlType = strings.ToUpper(strings.TrimSpace(sqlType))

	open := strings.Index(sqlType, "(")
	if open < 0 {
		return sqlType, 0, 0, 0
	}

	// Trim so "VARCHAR (64)" is recognized as "VARCHAR".
	colType = strings.TrimSpace(sqlType[:open])
	params := sqlType[open+1:]
	if end := strings.Index(params, ")"); end >= 0 {
		params = params[:end]
	}

	parts := strings.Split(params, ",")
	if colType == "DECIMAL" {
		precision = atoiOrZero(parts[0])
		if len(parts) == 2 {
			scale = atoiOrZero(parts[1])
		}
		return colType, 0, precision, scale
	}
	return colType, atoiOrZero(parts[0]), 0, 0
}

// atoiOrZero parses s (ignoring surrounding whitespace) as a decimal integer,
// returning 0 when s is not a valid integer.
func atoiOrZero(s string) int {
	n, err := strconv.Atoi(strings.TrimSpace(s))
	if err != nil {
		return 0
	}
	return n
}
-
// parseIndex parses an index definition string of the form
// "INDEX name (col1, col2)" or "UNIQUE INDEX name (col1, col2)", as produced
// by AddIndex and AddUniqueIndex.
//
// It returns the index name, whether the index is unique, and the column names
// with surrounding whitespace removed. Empty column entries are dropped, so
// "INDEX idx ()" yields no columns. The result is ("", false, nil) in any of
// these cases:
//   - the string does not start with "INDEX " or "UNIQUE INDEX ";
//   - it has no parenthesised column list;
//   - the name is empty.
func parseIndex(indexStr string) (name string, isUnique bool, columns []string) {
	rest := strings.TrimSpace(indexStr)
	switch {
	case strings.HasPrefix(rest, "UNIQUE INDEX "):
		isUnique = true
		rest = strings.TrimPrefix(rest, "UNIQUE INDEX ")
	case strings.HasPrefix(rest, "INDEX "):
		rest = strings.TrimPrefix(rest, "INDEX ")
	default:
		return "", false, nil
	}

	open := strings.Index(rest, "(")
	closeIdx := strings.LastIndex(rest, ")")
	if open < 0 || closeIdx < open {
		return "", false, nil
	}

	name = strings.TrimSpace(rest[:open])
	if name == "" {
		return "", false, nil
	}

	for _, col := range strings.Split(rest[open+1:closeIdx], ",") {
		if col = strings.TrimSpace(col); col != "" {
			columns = append(columns, col)
		}
	}
	return name, isUnique, columns
}
-
- // Register 快捷方法:直接注册表
- func (t *TableBuilder) Register() {
- table := t.Build()
- AddRegistration(func(r *Registry) {
- r.RegisterTable(table)
- })
- }
|