// mysql.go 根据//table_defintion.go 定义的表结构,编写mysql建立表和索引的代码。 package generators import ( "fmt" "strings" ) // MySQLGenerator MySQL SQL生成器 type MySQLGenerator struct{} // NewMySQLGenerator 创建MySQL生成器实例 func NewMySQLGenerator() *MySQLGenerator { return &MySQLGenerator{} } func (mg *MySQLGenerator) DBType() string { return "mysql" } func (mg *MySQLGenerator) TableExistsSQL(tableName string) string { return fmt.Sprintf( "SELECT COUNT(*) FROM information_schema.tables WHERE table_schema = DATABASE() AND table_name = '%s'", tableName, ) } func (mg *MySQLGenerator) DropTableSQL(tableName string) string { return fmt.Sprintf("DROP TABLE IF EXISTS %s", tableName) } func (mg *MySQLGenerator) GenerateCreateTable(table TableDDL) string { if table.Schema == nil { return "" } var sql strings.Builder // 表头 sql.WriteString(fmt.Sprintf("CREATE TABLE IF NOT EXISTS %s (\n", table.Name)) // 列定义 columns := table.Schema.Columns for i, col := range columns { sql.WriteString(fmt.Sprintf(" %s %s", col.Name, mg.getMySQLType(col))) // 添加选项 for _, opt := range col.Options { sql.WriteString(" " + opt) } // 添加默认值 if col.Default != "" { // 检查是否是函数调用(如CURRENT_TIMESTAMP) if strings.Contains(strings.ToUpper(col.Default), "CURRENT_TIMESTAMP") || strings.Contains(strings.ToUpper(col.Default), "NOW()") { sql.WriteString(" DEFAULT " + col.Default) } else { sql.WriteString(fmt.Sprintf(" DEFAULT '%s'", col.Default)) } } // 添加注释 if col.Comment != "" { sql.WriteString(fmt.Sprintf(" COMMENT '%s'", col.Comment)) } if i < len(columns)-1 { sql.WriteString(",") } sql.WriteString("\n") } // 添加索引(在MySQL中,索引可以在CREATE TABLE语句中定义) for _, idx := range table.Schema.Indexes { if idx.Unique { sql.WriteString(fmt.Sprintf(" ,UNIQUE KEY %s (%s)\n", idx.Name, strings.Join(idx.Columns, ", "))) } else { sql.WriteString(fmt.Sprintf(" ,KEY %s (%s)\n", idx.Name, strings.Join(idx.Columns, ", "))) } } sql.WriteString(")") // 表选项 if table.Schema.Comment != "" { sql.WriteString(fmt.Sprintf(" COMMENT='%s'", table.Schema.Comment)) } sql.WriteString(" ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci;") return sql.String() } // getMySQLType 获取MySQL数据类型 func (mg *MySQLGenerator) getMySQLType(col ColumnSchema) string { switch col.Type { case "DECIMAL": if col.Precision > 0 && col.Scale > 0 { return fmt.Sprintf("DECIMAL(%d,%d)", col.Precision, col.Scale) } return "DECIMAL" case "VARCHAR": if col.Length > 0 { return fmt.Sprintf("VARCHAR(%d)", col.Length) } return "VARCHAR(255)" case "CHAR": if col.Length > 0 { return fmt.Sprintf("CHAR(%d)", col.Length) } return "CHAR(1)" case "INT": return "INT" case "BIGINT": return "BIGINT" case "TINYINT": return "TINYINT" case "BOOL": return "TINYINT(1)" case "DATETIME": return "DATETIME" case "TIMESTAMP": return "TIMESTAMP" case "DATE": return "DATE" case "TIME": return "TIME" case "TEXT": return "TEXT" case "JSON": return "JSON" case "BLOB": return "BLOB" case "FLOAT": return "FLOAT" case "DOUBLE": return "DOUBLE" default: return col.Type } } func init() { RegisterGenerator(NewMySQLGenerator()) }