| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200 |
- package generators
-
- import (
- "fmt"
- "strings"
- )
-
// MySQLGenerator produces MySQL-dialect SQL: table-existence checks,
// DROP TABLE and CREATE TABLE statements. It carries no fields.
type MySQLGenerator struct{}

// NewMySQLGenerator returns a new MySQLGenerator.
func NewMySQLGenerator() *MySQLGenerator {
	gen := new(MySQLGenerator)
	return gen
}
-
- func (mg *MySQLGenerator) DBType() string {
- return "mysql"
- }
-
- func (mg *MySQLGenerator) TableExistsSQL(tableName string) string {
- return fmt.Sprintf(
- "SELECT COUNT(*) FROM information_schema.tables WHERE table_schema = DATABASE() AND table_name = '%s'",
- tableName,
- )
- }
-
- func (mg *MySQLGenerator) DropTableSQL(tableName string) string {
- return fmt.Sprintf("DROP TABLE IF EXISTS %s", tableName)
- }
-
- func (mg *MySQLGenerator) GenerateCreateTable(table TableDDL) string {
- if table.Schema == nil {
- return ""
- }
-
- var sql strings.Builder
- var pkColumns []string
-
- // 表头
- sql.WriteString(fmt.Sprintf("CREATE TABLE IF NOT EXISTS %s (\n", table.Name))
-
- // 第一遍:收集主键列
- for _, col := range table.Schema.Columns {
- for _, opt := range col.Options {
- if strings.ToUpper(opt) == "PRIMARY KEY" {
- pkColumns = append(pkColumns, col.Name)
- break
- }
- }
- }
-
- // 第二遍:构建列定义
- columnDefs := make([]string, 0, len(table.Schema.Columns))
- for _, col := range table.Schema.Columns {
- colDef := fmt.Sprintf(" %s %s", col.Name, mg.getMySQLType(col))
-
- // 处理列选项
- for _, opt := range col.Options {
- upperOpt := strings.ToUpper(opt)
- // PRIMARY KEY 选项不在列级定义
- if upperOpt != "PRIMARY KEY" {
- colDef += " " + opt
- }
- }
-
- // 如果是主键列,确保有 NOT NULL
- isPrimaryKey := false
- for _, pk := range pkColumns {
- if pk == col.Name {
- isPrimaryKey = true
- break
- }
- }
-
- if isPrimaryKey {
- hasNotNull := false
- for _, opt := range col.Options {
- if strings.ToUpper(opt) == "NOT NULL" {
- hasNotNull = true
- break
- }
- }
- if !hasNotNull {
- colDef += " NOT NULL"
- }
- }
-
- // 添加默认值
- if col.Default != "" {
- upperDefault := strings.ToUpper(col.Default)
- if strings.Contains(upperDefault, "CURRENT_TIMESTAMP") ||
- strings.Contains(upperDefault, "NOW()") ||
- strings.Contains(upperDefault, "UUID()") ||
- strings.Contains(upperDefault, "NULL") {
- colDef += " DEFAULT " + col.Default
- } else {
- colDef += fmt.Sprintf(" DEFAULT '%s'", col.Default)
- }
- }
-
- // 添加注释
- if col.Comment != "" {
- colDef += fmt.Sprintf(" COMMENT '%s'", escapeSingleQuote(col.Comment))
- }
-
- columnDefs = append(columnDefs, colDef)
- }
-
- // 将所有定义部分合并
- allParts := make([]string, 0)
- allParts = append(allParts, columnDefs...)
-
- // 添加主键
- if len(pkColumns) > 0 {
- allParts = append(allParts, fmt.Sprintf(" PRIMARY KEY (%s)", strings.Join(pkColumns, ", ")))
- }
-
- // 添加索引
- for _, idx := range table.Schema.Indexes {
- if idx.Unique {
- allParts = append(allParts, fmt.Sprintf(" UNIQUE KEY %s (%s)",
- idx.Name, strings.Join(idx.Columns, ", ")))
- } else {
- allParts = append(allParts, fmt.Sprintf(" KEY %s (%s)",
- idx.Name, strings.Join(idx.Columns, ", ")))
- }
- }
-
- // 用逗号连接所有部分
- sql.WriteString(strings.Join(allParts, ",\n"))
- sql.WriteString("\n)")
-
- // 表注释
- if table.Schema.Comment != "" {
- sql.WriteString(fmt.Sprintf(" COMMENT='%s'", escapeSingleQuote(table.Schema.Comment)))
- }
-
- // 引擎和字符集
- sql.WriteString(" ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci;")
-
- return sql.String()
- }
-
- // getMySQLType 获取MySQL数据类型
- func (mg *MySQLGenerator) getMySQLType(col ColumnSchema) string {
- switch col.Type {
- case "DECIMAL":
- if col.Precision > 0 && col.Scale > 0 {
- return fmt.Sprintf("DECIMAL(%d,%d)", col.Precision, col.Scale)
- }
- return "DECIMAL"
- case "VARCHAR":
- if col.Length > 0 {
- return fmt.Sprintf("VARCHAR(%d)", col.Length)
- }
- return "VARCHAR(255)"
- case "CHAR":
- if col.Length > 0 {
- return fmt.Sprintf("CHAR(%d)", col.Length)
- }
- return "CHAR(1)"
- case "INT":
- return "INT"
- case "BIGINT":
- return "BIGINT"
- case "TINYINT":
- return "TINYINT"
- case "BOOL":
- return "TINYINT(1)"
- case "DATETIME":
- return "DATETIME"
- case "TIMESTAMP":
- return "TIMESTAMP"
- case "DATE":
- return "DATE"
- case "TIME":
- return "TIME"
- case "TEXT":
- return "TEXT"
- case "JSON":
- return "JSON"
- case "BLOB":
- return "BLOB"
- case "FLOAT":
- return "FLOAT"
- case "DOUBLE":
- return "DOUBLE"
- default:
- return col.Type
- }
- }
-
// escapeSingleQuote prepares str for use inside a single-quoted SQL string
// literal by doubling every single quote ("it's" -> "it''s"). All other
// characters, including backslashes, are left untouched.
// NOTE(review): MySQL also treats backslash as an escape character by
// default, so this alone does not fully escape MySQL literals.
func escapeSingleQuote(str string) string {
	pieces := strings.Split(str, "'")
	return strings.Join(pieces, "''")
}
-
- func init() {
- RegisterGenerator(NewMySQLGenerator())
- }
|