| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236 |
- // postgresql.go 根据//table_defintion.go 定义的表结构,编写PostgreSQL建立表和索引的代码。
- package generators
-
- import (
- "fmt"
- "strings"
- )
-
// PostgreSQLGenerator renders PostgreSQL DDL (CREATE TABLE, COMMENT ON,
// CREATE INDEX) and catalog queries from the table definitions.
// It holds no state, so a single instance can be shared.
type PostgreSQLGenerator struct{}

// NewPostgreSQLGenerator returns a ready-to-use PostgreSQLGenerator.
func NewPostgreSQLGenerator() *PostgreSQLGenerator {
	return new(PostgreSQLGenerator)
}
-
- func (pg *PostgreSQLGenerator) DBType() string {
- return "postgresql"
- }
-
- func (pg *PostgreSQLGenerator) TableExistsSQL(tableName string) string {
- return fmt.Sprintf(
- "SELECT EXISTS (SELECT FROM information_schema.tables WHERE table_schema = 'public' AND table_name = '%s')",
- strings.ToLower(tableName),
- )
- }
-
- func (pg *PostgreSQLGenerator) DropTableSQL(tableName string) string {
- return fmt.Sprintf("DROP TABLE IF EXISTS %s CASCADE", tableName)
- }
-
- func (pg *PostgreSQLGenerator) GenerateCreateTable(table TableDDL) string {
- if table.Schema == nil {
- return ""
- }
-
- var sql strings.Builder
-
- // 表头
- sql.WriteString(fmt.Sprintf("CREATE TABLE IF NOT EXISTS %s (\n", table.Name))
-
- // 列定义
- columns := table.Schema.Columns
- for i, col := range columns {
- sql.WriteString(fmt.Sprintf(" %s %s", col.Name, pg.getPostgreSQLType(col)))
-
- // 添加选项(转换MySQL选项到PostgreSQL)
- for _, opt := range col.Options {
- pgOpt := pg.convertOption(opt, col)
- if pgOpt != "" {
- sql.WriteString(" " + pgOpt)
- }
- }
-
- // 添加默认值
- if col.Default != "" {
- pgDefault := pg.convertDefaultValue(col.Default, col.Type)
- if pgDefault != "" {
- sql.WriteString(" DEFAULT " + pgDefault)
- }
- }
-
- // 列注释将在表创建后单独添加
-
- if i < len(columns)-1 {
- sql.WriteString(",")
- }
- sql.WriteString("\n")
- }
-
- // 添加主键约束(从列选项中提取)
- primaryKeys := pg.extractPrimaryKeys(columns)
- if len(primaryKeys) > 0 {
- sql.WriteString(fmt.Sprintf(" ,PRIMARY KEY (%s)\n", strings.Join(primaryKeys, ", ")))
- }
-
- sql.WriteString(");\n")
-
- // 添加表注释
- if table.Schema.Comment != "" {
- sql.WriteString(fmt.Sprintf("COMMENT ON TABLE %s IS '%s';\n",
- table.Name, table.Schema.Comment))
- }
-
- // 添加列注释
- for _, col := range columns {
- if col.Comment != "" {
- sql.WriteString(fmt.Sprintf("COMMENT ON COLUMN %s.%s IS '%s';\n",
- table.Name, col.Name, col.Comment))
- }
- }
-
- // 添加索引(PostgreSQL中索引在表外创建)
- for _, idx := range table.Schema.Indexes {
- if idx.Unique {
- sql.WriteString(fmt.Sprintf("CREATE UNIQUE INDEX IF NOT EXISTS %s ON %s (%s);\n",
- idx.Name, table.Name, strings.Join(idx.Columns, ", ")))
- } else {
- sql.WriteString(fmt.Sprintf("CREATE INDEX IF NOT EXISTS %s ON %s (%s);\n",
- idx.Name, table.Name, strings.Join(idx.Columns, ", ")))
- }
- }
-
- return sql.String()
- }
-
- // getPostgreSQLType 获取PostgreSQL数据类型
- func (pg *PostgreSQLGenerator) getPostgreSQLType(col ColumnSchema) string {
- switch col.Type {
- case "TINYINT":
- if col.Length == 1 {
- return "BOOLEAN"
- }
- return "SMALLINT"
- case "BOOL":
- return "BOOLEAN"
- case "DATETIME":
- return "TIMESTAMP"
- case "TIMESTAMP":
- return "TIMESTAMP"
- case "JSON":
- return "JSONB"
- case "BLOB":
- return "BYTEA"
- case "INT":
- return "INTEGER"
- case "BIGINT":
- return "BIGINT"
- case "DECIMAL":
- if col.Precision > 0 && col.Scale > 0 {
- return fmt.Sprintf("DECIMAL(%d,%d)", col.Precision, col.Scale)
- }
- return "DECIMAL"
- case "VARCHAR":
- if col.Length > 0 {
- return fmt.Sprintf("VARCHAR(%d)", col.Length)
- }
- return "VARCHAR"
- case "CHAR":
- if col.Length > 0 {
- return fmt.Sprintf("CHAR(%d)", col.Length)
- }
- return "CHAR"
- case "TEXT":
- return "TEXT"
- case "FLOAT":
- return "REAL"
- case "DOUBLE":
- return "DOUBLE PRECISION"
- case "DATE":
- return "DATE"
- case "TIME":
- return "TIME"
- default:
- return col.Type
- }
- }
-
- // convertOption 转换选项到PostgreSQL语法
- func (pg *PostgreSQLGenerator) convertOption(option string, col ColumnSchema) string {
- option = strings.ToUpper(option)
-
- switch option {
- case "NOT NULL":
- return "NOT NULL"
- case "UNIQUE":
- return "UNIQUE"
- case "AUTO_INCREMENT":
- // PostgreSQL使用SERIAL/BIGSERIAL/SMALLSERIAL
- // 已在getPostgreSQLType中处理
- return ""
- case "PRIMARY KEY":
- // 主键将在表级别定义
- return ""
- default:
- return option
- }
- }
-
- // convertDefaultValue 转换默认值
- func (pg *PostgreSQLGenerator) convertDefaultValue(value, colType string) string {
- value = strings.TrimSpace(value)
-
- // 移除引号(如果有)
- if strings.HasPrefix(value, "'") && strings.HasSuffix(value, "'") {
- value = value[1 : len(value)-1]
- }
-
- // 处理布尔值
- if strings.EqualFold(colType, "BOOLEAN") || strings.EqualFold(colType, "BOOL") {
- switch strings.ToUpper(value) {
- case "1", "TRUE", "'TRUE'":
- return "TRUE"
- case "0", "FALSE", "'FALSE'":
- return "FALSE"
- }
- }
-
- // 处理时间戳
- if strings.Contains(strings.ToUpper(value), "CURRENT_TIMESTAMP") {
- return "CURRENT_TIMESTAMP"
- }
-
- // 处理数字
- if _, ok := pg.isNumber(value); ok {
- return value
- }
-
- // 其他情况加单引号
- return "'" + value + "'"
- }
-
- // isNumber 检查字符串是否为数字
- func (pg *PostgreSQLGenerator) isNumber(s string) (float64, bool) {
- var f float64
- _, err := fmt.Sscanf(s, "%f", &f)
- return f, err == nil
- }
-
- // extractPrimaryKeys 从列中提取主键
- func (pg *PostgreSQLGenerator) extractPrimaryKeys(columns []ColumnSchema) []string {
- var primaryKeys []string
- for _, col := range columns {
- for _, opt := range col.Options {
- if strings.ToUpper(opt) == "PRIMARY KEY" {
- primaryKeys = append(primaryKeys, col.Name)
- break
- }
- }
- }
- return primaryKeys
- }
-
- func init() {
- RegisterGenerator(NewPostgreSQLGenerator())
- }
|