Без опису
Ви не можете вибрати більше 25 тем Теми мають розпочинатися з літери або цифри, можуть містити дефіси (-) і не повинні перевищувати 35 символів.

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236
  1. // postgresql.go 根据//table_defintion.go 定义的表结构,编写PostgreSQL建立表和索引的代码。
  2. package generators
  3. import (
  4. "fmt"
  5. "strings"
  6. )
  7. // PostgreSQLGenerator PostgreSQL SQL生成器
  8. type PostgreSQLGenerator struct{}
  9. // NewPostgreSQLGenerator 创建PostgreSQL生成器实例
  10. func NewPostgreSQLGenerator() *PostgreSQLGenerator {
  11. return &PostgreSQLGenerator{}
  12. }
  13. func (pg *PostgreSQLGenerator) DBType() string {
  14. return "postgresql"
  15. }
  16. func (pg *PostgreSQLGenerator) TableExistsSQL(tableName string) string {
  17. return fmt.Sprintf(
  18. "SELECT EXISTS (SELECT FROM information_schema.tables WHERE table_schema = 'public' AND table_name = '%s')",
  19. strings.ToLower(tableName),
  20. )
  21. }
  22. func (pg *PostgreSQLGenerator) DropTableSQL(tableName string) string {
  23. return fmt.Sprintf("DROP TABLE IF EXISTS %s CASCADE", tableName)
  24. }
  25. func (pg *PostgreSQLGenerator) GenerateCreateTable(table TableDDL) string {
  26. if table.Schema == nil {
  27. return ""
  28. }
  29. var sql strings.Builder
  30. // 表头
  31. sql.WriteString(fmt.Sprintf("CREATE TABLE IF NOT EXISTS %s (\n", table.Name))
  32. // 列定义
  33. columns := table.Schema.Columns
  34. for i, col := range columns {
  35. sql.WriteString(fmt.Sprintf(" %s %s", col.Name, pg.getPostgreSQLType(col)))
  36. // 添加选项(转换MySQL选项到PostgreSQL)
  37. for _, opt := range col.Options {
  38. pgOpt := pg.convertOption(opt, col)
  39. if pgOpt != "" {
  40. sql.WriteString(" " + pgOpt)
  41. }
  42. }
  43. // 添加默认值
  44. if col.Default != "" {
  45. pgDefault := pg.convertDefaultValue(col.Default, col.Type)
  46. if pgDefault != "" {
  47. sql.WriteString(" DEFAULT " + pgDefault)
  48. }
  49. }
  50. // 列注释将在表创建后单独添加
  51. if i < len(columns)-1 {
  52. sql.WriteString(",")
  53. }
  54. sql.WriteString("\n")
  55. }
  56. // 添加主键约束(从列选项中提取)
  57. primaryKeys := pg.extractPrimaryKeys(columns)
  58. if len(primaryKeys) > 0 {
  59. sql.WriteString(fmt.Sprintf(" ,PRIMARY KEY (%s)\n", strings.Join(primaryKeys, ", ")))
  60. }
  61. sql.WriteString(");\n")
  62. // 添加表注释
  63. if table.Schema.Comment != "" {
  64. sql.WriteString(fmt.Sprintf("COMMENT ON TABLE %s IS '%s';\n",
  65. table.Name, table.Schema.Comment))
  66. }
  67. // 添加列注释
  68. for _, col := range columns {
  69. if col.Comment != "" {
  70. sql.WriteString(fmt.Sprintf("COMMENT ON COLUMN %s.%s IS '%s';\n",
  71. table.Name, col.Name, col.Comment))
  72. }
  73. }
  74. // 添加索引(PostgreSQL中索引在表外创建)
  75. for _, idx := range table.Schema.Indexes {
  76. if idx.Unique {
  77. sql.WriteString(fmt.Sprintf("CREATE UNIQUE INDEX IF NOT EXISTS %s ON %s (%s);\n",
  78. idx.Name, table.Name, strings.Join(idx.Columns, ", ")))
  79. } else {
  80. sql.WriteString(fmt.Sprintf("CREATE INDEX IF NOT EXISTS %s ON %s (%s);\n",
  81. idx.Name, table.Name, strings.Join(idx.Columns, ", ")))
  82. }
  83. }
  84. return sql.String()
  85. }
  86. // getPostgreSQLType 获取PostgreSQL数据类型
  87. func (pg *PostgreSQLGenerator) getPostgreSQLType(col ColumnSchema) string {
  88. switch col.Type {
  89. case "TINYINT":
  90. if col.Length == 1 {
  91. return "BOOLEAN"
  92. }
  93. return "SMALLINT"
  94. case "BOOL":
  95. return "BOOLEAN"
  96. case "DATETIME":
  97. return "TIMESTAMP"
  98. case "TIMESTAMP":
  99. return "TIMESTAMP"
  100. case "JSON":
  101. return "JSONB"
  102. case "BLOB":
  103. return "BYTEA"
  104. case "INT":
  105. return "INTEGER"
  106. case "BIGINT":
  107. return "BIGINT"
  108. case "DECIMAL":
  109. if col.Precision > 0 && col.Scale > 0 {
  110. return fmt.Sprintf("DECIMAL(%d,%d)", col.Precision, col.Scale)
  111. }
  112. return "DECIMAL"
  113. case "VARCHAR":
  114. if col.Length > 0 {
  115. return fmt.Sprintf("VARCHAR(%d)", col.Length)
  116. }
  117. return "VARCHAR"
  118. case "CHAR":
  119. if col.Length > 0 {
  120. return fmt.Sprintf("CHAR(%d)", col.Length)
  121. }
  122. return "CHAR"
  123. case "TEXT":
  124. return "TEXT"
  125. case "FLOAT":
  126. return "REAL"
  127. case "DOUBLE":
  128. return "DOUBLE PRECISION"
  129. case "DATE":
  130. return "DATE"
  131. case "TIME":
  132. return "TIME"
  133. default:
  134. return col.Type
  135. }
  136. }
  137. // convertOption 转换选项到PostgreSQL语法
  138. func (pg *PostgreSQLGenerator) convertOption(option string, col ColumnSchema) string {
  139. option = strings.ToUpper(option)
  140. switch option {
  141. case "NOT NULL":
  142. return "NOT NULL"
  143. case "UNIQUE":
  144. return "UNIQUE"
  145. case "AUTO_INCREMENT":
  146. // PostgreSQL使用SERIAL/BIGSERIAL/SMALLSERIAL
  147. // 已在getPostgreSQLType中处理
  148. return ""
  149. case "PRIMARY KEY":
  150. // 主键将在表级别定义
  151. return ""
  152. default:
  153. return option
  154. }
  155. }
  156. // convertDefaultValue 转换默认值
  157. func (pg *PostgreSQLGenerator) convertDefaultValue(value, colType string) string {
  158. value = strings.TrimSpace(value)
  159. // 移除引号(如果有)
  160. if strings.HasPrefix(value, "'") && strings.HasSuffix(value, "'") {
  161. value = value[1 : len(value)-1]
  162. }
  163. // 处理布尔值
  164. if strings.EqualFold(colType, "BOOLEAN") || strings.EqualFold(colType, "BOOL") {
  165. switch strings.ToUpper(value) {
  166. case "1", "TRUE", "'TRUE'":
  167. return "TRUE"
  168. case "0", "FALSE", "'FALSE'":
  169. return "FALSE"
  170. }
  171. }
  172. // 处理时间戳
  173. if strings.Contains(strings.ToUpper(value), "CURRENT_TIMESTAMP") {
  174. return "CURRENT_TIMESTAMP"
  175. }
  176. // 处理数字
  177. if _, ok := pg.isNumber(value); ok {
  178. return value
  179. }
  180. // 其他情况加单引号
  181. return "'" + value + "'"
  182. }
  183. // isNumber 检查字符串是否为数字
  184. func (pg *PostgreSQLGenerator) isNumber(s string) (float64, bool) {
  185. var f float64
  186. _, err := fmt.Sscanf(s, "%f", &f)
  187. return f, err == nil
  188. }
  189. // extractPrimaryKeys 从列中提取主键
  190. func (pg *PostgreSQLGenerator) extractPrimaryKeys(columns []ColumnSchema) []string {
  191. var primaryKeys []string
  192. for _, col := range columns {
  193. for _, opt := range col.Options {
  194. if strings.ToUpper(opt) == "PRIMARY KEY" {
  195. primaryKeys = append(primaryKeys, col.Name)
  196. break
  197. }
  198. }
  199. }
  200. return primaryKeys
  201. }
  202. func init() {
  203. RegisterGenerator(NewPostgreSQLGenerator())
  204. }