package generators

import (
	"fmt"
	"strings"
)

// MySQLGenerator generates MySQL-dialect DDL/SQL statements.
type MySQLGenerator struct{}

// NewMySQLGenerator returns a new MySQLGenerator instance.
func NewMySQLGenerator() *MySQLGenerator {
	return &MySQLGenerator{}
}

// DBType reports the database type handled by this generator ("mysql").
func (mg *MySQLGenerator) DBType() string {
	return "mysql"
}

// TableExistsSQL returns a query that counts the entries in
// information_schema.tables named tableName within the current database
// (DATABASE()); a non-zero count means the table exists.
//
// tableName is embedded inside a single-quoted string literal, so any single
// quotes in it are escaped to keep the statement well-formed.
func (mg *MySQLGenerator) TableExistsSQL(tableName string) string {
	return fmt.Sprintf(
		"SELECT COUNT(*) FROM information_schema.tables WHERE table_schema = DATABASE() AND table_name = '%s'",
		escapeSingleQuote(tableName),
	)
}

// DropTableSQL returns a DROP TABLE IF EXISTS statement for tableName.
// The name is emitted verbatim (unquoted), as in GenerateCreateTable.
func (mg *MySQLGenerator) DropTableSQL(tableName string) string {
	return fmt.Sprintf("DROP TABLE IF EXISTS %s", tableName)
}

// GenerateCreateTable renders a CREATE TABLE IF NOT EXISTS statement for table.
//
// The statement lists, in order: one definition per column, a single
// table-level PRIMARY KEY clause built from every column carrying a
// "PRIMARY KEY" option, and one UNIQUE KEY / KEY clause per index. It ends with
// the optional table COMMENT and an InnoDB / utf8mb4 table definition plus a
// trailing ";". It returns "" when table.Schema is nil.
func (mg *MySQLGenerator) GenerateCreateTable(table TableDDL) string {
	if table.Schema == nil {
		return ""
	}

	pkColumns := mg.primaryKeyColumns(table.Schema.Columns)
	isPK := make(map[string]bool, len(pkColumns))
	for _, name := range pkColumns {
		isPK[name] = true
	}

	parts := make([]string, 0, len(table.Schema.Columns)+len(table.Schema.Indexes)+1)
	for _, col := range table.Schema.Columns {
		parts = append(parts, mg.columnDefinition(col, isPK[col.Name]))
	}

	// Primary keys are declared once at table level so composite keys work.
	if len(pkColumns) > 0 {
		parts = append(parts, fmt.Sprintf(" PRIMARY KEY (%s)", strings.Join(pkColumns, ", ")))
	}

	for _, idx := range table.Schema.Indexes {
		keyword := "KEY"
		if idx.Unique {
			keyword = "UNIQUE KEY"
		}
		parts = append(parts, fmt.Sprintf(" %s %s (%s)", keyword, idx.Name, strings.Join(idx.Columns, ", ")))
	}

	var sql strings.Builder
	fmt.Fprintf(&sql, "CREATE TABLE IF NOT EXISTS %s (\n", table.Name)
	sql.WriteString(strings.Join(parts, ",\n"))
	sql.WriteString("\n)")

	if table.Schema.Comment != "" {
		fmt.Fprintf(&sql, " COMMENT='%s'", escapeSingleQuote(table.Schema.Comment))
	}

	// Storage engine and character set.
	sql.WriteString(" ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci;")
	return sql.String()
}

// primaryKeyColumns returns, in declaration order, the names of the columns
// whose options include PRIMARY KEY (matched case-insensitively, ignoring
// extra whitespace).
func (mg *MySQLGenerator) primaryKeyColumns(cols []ColumnSchema) []string {
	var pk []string
	for _, col := range cols {
		for _, opt := range col.Options {
			if mg.normalizeOption(opt) == "PRIMARY KEY" {
				pk = append(pk, col.Name)
				break
			}
		}
	}
	return pk
}

// columnDefinition renders a single column line. The line contains the name
// and MySQL type, then the column's options except PRIMARY KEY (emitted at
// table level instead), then NOT NULL if this is a primary-key column that
// does not already declare it, and finally the DEFAULT and COMMENT clauses.
func (mg *MySQLGenerator) columnDefinition(col ColumnSchema, isPrimaryKey bool) string {
	var def strings.Builder
	fmt.Fprintf(&def, " %s %s", col.Name, mg.getMySQLType(col))

	hasNotNull := false
	for _, opt := range col.Options {
		norm := mg.normalizeOption(opt)
		if norm == "PRIMARY KEY" {
			continue
		}
		if norm == "NOT NULL" {
			hasNotNull = true
		}
		def.WriteString(" ")
		def.WriteString(opt)
	}

	// MySQL primary-key columns must be NOT NULL; add it if the caller didn't.
	if isPrimaryKey && !hasNotNull {
		def.WriteString(" NOT NULL")
	}

	if col.Default != "" {
		def.WriteString(" DEFAULT ")
		def.WriteString(mg.defaultLiteral(col.Default))
	}

	if col.Comment != "" {
		fmt.Fprintf(&def, " COMMENT '%s'", escapeSingleQuote(col.Comment))
	}

	return def.String()
}

// defaultLiteral formats the value of a DEFAULT clause. The value is emitted
// verbatim, as a raw SQL expression, when it is exactly the NULL keyword or
// when it contains CURRENT_TIMESTAMP, NOW() or UUID() (matched
// case-insensitively). Any other value becomes a single-quoted string literal
// with embedded quotes escaped.
func (mg *MySQLGenerator) defaultLiteral(value string) string {
	upper := strings.ToUpper(strings.TrimSpace(value))
	switch {
	case upper == "NULL",
		strings.Contains(upper, "CURRENT_TIMESTAMP"),
		strings.Contains(upper, "NOW()"),
		strings.Contains(upper, "UUID()"):
		return value
	}
	return "'" + escapeSingleQuote(value) + "'"
}

// normalizeOption upper-cases a column option and collapses runs of
// whitespace to single spaces, so "primary  key" compares equal to
// "PRIMARY KEY".
func (mg *MySQLGenerator) normalizeOption(opt string) string {
	return strings.ToUpper(strings.Join(strings.Fields(opt), " "))
}

// getMySQLType maps a column's logical type to a MySQL column type.
//
// Type names are matched case-insensitively. DECIMAL uses Precision/Scale, and
// VARCHAR/CHAR use Length, falling back to VARCHAR(255) and CHAR(1)
// respectively. Unrecognised types are returned unchanged.
func (mg *MySQLGenerator) getMySQLType(col ColumnSchema) string {
	upperType := strings.ToUpper(strings.TrimSpace(col.Type))
	switch upperType {
	case "DECIMAL":
		switch {
		case col.Precision > 0 && col.Scale > 0:
			return fmt.Sprintf("DECIMAL(%d,%d)", col.Precision, col.Scale)
		case col.Precision > 0:
			// Keep the precision even with a zero scale: a bare DECIMAL is
			// DECIMAL(10,0) in MySQL and would silently narrow the column.
			return fmt.Sprintf("DECIMAL(%d)", col.Precision)
		}
		return "DECIMAL"
	case "VARCHAR":
		if col.Length > 0 {
			return fmt.Sprintf("VARCHAR(%d)", col.Length)
		}
		return "VARCHAR(255)"
	case "CHAR":
		if col.Length > 0 {
			return fmt.Sprintf("CHAR(%d)", col.Length)
		}
		return "CHAR(1)"
	case "BOOL":
		return "TINYINT(1)"
	case "INT", "BIGINT", "TINYINT",
		"DATETIME", "TIMESTAMP", "DATE", "TIME",
		"TEXT", "JSON", "BLOB", "FLOAT", "DOUBLE":
		return upperType
	default:
		return col.Type
	}
}

// escapeSingleQuote escapes str for use inside a single-quoted SQL string
// literal by doubling every single quote.
//
// NOTE(review): backslashes are not escaped. Under MySQL's default sql_mode a
// backslash inside a string literal acts as an escape character, so input
// ending in "\" could still break the literal. This helper is package-level
// and may be shared with other generators, so confirm the intended sql_mode
// before changing it.
func escapeSingleQuote(str string) string {
	return strings.ReplaceAll(str, "'", "''")
}

// init registers a MySQLGenerator via RegisterGenerator when the package is
// loaded.
func init() {
	RegisterGenerator(NewMySQLGenerator())
}