Ingen beskrivning
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

mysql.go 4.5KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200
  1. package generators
  2. import (
  3. "fmt"
  4. "strings"
  5. )
  6. // MySQLGenerator MySQL SQL生成器
  7. type MySQLGenerator struct{}
  8. // NewMySQLGenerator 创建MySQL生成器实例
  9. func NewMySQLGenerator() *MySQLGenerator {
  10. return &MySQLGenerator{}
  11. }
  12. func (mg *MySQLGenerator) DBType() string {
  13. return "mysql"
  14. }
  15. func (mg *MySQLGenerator) TableExistsSQL(tableName string) string {
  16. return fmt.Sprintf(
  17. "SELECT COUNT(*) FROM information_schema.tables WHERE table_schema = DATABASE() AND table_name = '%s'",
  18. tableName,
  19. )
  20. }
  21. func (mg *MySQLGenerator) DropTableSQL(tableName string) string {
  22. return fmt.Sprintf("DROP TABLE IF EXISTS %s", tableName)
  23. }
  24. func (mg *MySQLGenerator) GenerateCreateTable(table TableDDL) string {
  25. if table.Schema == nil {
  26. return ""
  27. }
  28. var sql strings.Builder
  29. var pkColumns []string
  30. // 表头
  31. sql.WriteString(fmt.Sprintf("CREATE TABLE IF NOT EXISTS %s (\n", table.Name))
  32. // 第一遍:收集主键列
  33. for _, col := range table.Schema.Columns {
  34. for _, opt := range col.Options {
  35. if strings.ToUpper(opt) == "PRIMARY KEY" {
  36. pkColumns = append(pkColumns, col.Name)
  37. break
  38. }
  39. }
  40. }
  41. // 第二遍:构建列定义
  42. columnDefs := make([]string, 0, len(table.Schema.Columns))
  43. for _, col := range table.Schema.Columns {
  44. colDef := fmt.Sprintf(" %s %s", col.Name, mg.getMySQLType(col))
  45. // 处理列选项
  46. for _, opt := range col.Options {
  47. upperOpt := strings.ToUpper(opt)
  48. // PRIMARY KEY 选项不在列级定义
  49. if upperOpt != "PRIMARY KEY" {
  50. colDef += " " + opt
  51. }
  52. }
  53. // 如果是主键列,确保有 NOT NULL
  54. isPrimaryKey := false
  55. for _, pk := range pkColumns {
  56. if pk == col.Name {
  57. isPrimaryKey = true
  58. break
  59. }
  60. }
  61. if isPrimaryKey {
  62. hasNotNull := false
  63. for _, opt := range col.Options {
  64. if strings.ToUpper(opt) == "NOT NULL" {
  65. hasNotNull = true
  66. break
  67. }
  68. }
  69. if !hasNotNull {
  70. colDef += " NOT NULL"
  71. }
  72. }
  73. // 添加默认值
  74. if col.Default != "" {
  75. upperDefault := strings.ToUpper(col.Default)
  76. if strings.Contains(upperDefault, "CURRENT_TIMESTAMP") ||
  77. strings.Contains(upperDefault, "NOW()") ||
  78. strings.Contains(upperDefault, "UUID()") ||
  79. strings.Contains(upperDefault, "NULL") {
  80. colDef += " DEFAULT " + col.Default
  81. } else {
  82. colDef += fmt.Sprintf(" DEFAULT '%s'", col.Default)
  83. }
  84. }
  85. // 添加注释
  86. if col.Comment != "" {
  87. colDef += fmt.Sprintf(" COMMENT '%s'", escapeSingleQuote(col.Comment))
  88. }
  89. columnDefs = append(columnDefs, colDef)
  90. }
  91. // 将所有定义部分合并
  92. allParts := make([]string, 0)
  93. allParts = append(allParts, columnDefs...)
  94. // 添加主键
  95. if len(pkColumns) > 0 {
  96. allParts = append(allParts, fmt.Sprintf(" PRIMARY KEY (%s)", strings.Join(pkColumns, ", ")))
  97. }
  98. // 添加索引
  99. for _, idx := range table.Schema.Indexes {
  100. if idx.Unique {
  101. allParts = append(allParts, fmt.Sprintf(" UNIQUE KEY %s (%s)",
  102. idx.Name, strings.Join(idx.Columns, ", ")))
  103. } else {
  104. allParts = append(allParts, fmt.Sprintf(" KEY %s (%s)",
  105. idx.Name, strings.Join(idx.Columns, ", ")))
  106. }
  107. }
  108. // 用逗号连接所有部分
  109. sql.WriteString(strings.Join(allParts, ",\n"))
  110. sql.WriteString("\n)")
  111. // 表注释
  112. if table.Schema.Comment != "" {
  113. sql.WriteString(fmt.Sprintf(" COMMENT='%s'", escapeSingleQuote(table.Schema.Comment)))
  114. }
  115. // 引擎和字符集
  116. sql.WriteString(" ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci;")
  117. return sql.String()
  118. }
  119. // getMySQLType 获取MySQL数据类型
  120. func (mg *MySQLGenerator) getMySQLType(col ColumnSchema) string {
  121. switch col.Type {
  122. case "DECIMAL":
  123. if col.Precision > 0 && col.Scale > 0 {
  124. return fmt.Sprintf("DECIMAL(%d,%d)", col.Precision, col.Scale)
  125. }
  126. return "DECIMAL"
  127. case "VARCHAR":
  128. if col.Length > 0 {
  129. return fmt.Sprintf("VARCHAR(%d)", col.Length)
  130. }
  131. return "VARCHAR(255)"
  132. case "CHAR":
  133. if col.Length > 0 {
  134. return fmt.Sprintf("CHAR(%d)", col.Length)
  135. }
  136. return "CHAR(1)"
  137. case "INT":
  138. return "INT"
  139. case "BIGINT":
  140. return "BIGINT"
  141. case "TINYINT":
  142. return "TINYINT"
  143. case "BOOL":
  144. return "TINYINT(1)"
  145. case "DATETIME":
  146. return "DATETIME"
  147. case "TIMESTAMP":
  148. return "TIMESTAMP"
  149. case "DATE":
  150. return "DATE"
  151. case "TIME":
  152. return "TIME"
  153. case "TEXT":
  154. return "TEXT"
  155. case "JSON":
  156. return "JSON"
  157. case "BLOB":
  158. return "BLOB"
  159. case "FLOAT":
  160. return "FLOAT"
  161. case "DOUBLE":
  162. return "DOUBLE"
  163. default:
  164. return col.Type
  165. }
  166. }
  167. // escapeSingleQuote 转义单引号
  168. func escapeSingleQuote(str string) string {
  169. return strings.ReplaceAll(str, "'", "''")
  170. }
  171. func init() {
  172. RegisterGenerator(NewMySQLGenerator())
  173. }