// postgresql.go 根据//table_defintion.go 定义的表结构,编写PostgreSQL建立表和索引的代码。 package generators import ( "fmt" "strings" ) // PostgreSQLGenerator PostgreSQL SQL生成器 type PostgreSQLGenerator struct{} // NewPostgreSQLGenerator 创建PostgreSQL生成器实例 func NewPostgreSQLGenerator() *PostgreSQLGenerator { return &PostgreSQLGenerator{} } func (pg *PostgreSQLGenerator) DBType() string { return "postgresql" } func (pg *PostgreSQLGenerator) TableExistsSQL(tableName string) string { return fmt.Sprintf( "SELECT EXISTS (SELECT FROM information_schema.tables WHERE table_schema = 'public' AND table_name = '%s')", strings.ToLower(tableName), ) } func (pg *PostgreSQLGenerator) DropTableSQL(tableName string) string { return fmt.Sprintf("DROP TABLE IF EXISTS %s CASCADE", tableName) } func (pg *PostgreSQLGenerator) GenerateCreateTable(table TableDDL) string { if table.Schema == nil { return "" } var sql strings.Builder // 表头 sql.WriteString(fmt.Sprintf("CREATE TABLE IF NOT EXISTS %s (\n", table.Name)) // 列定义 columns := table.Schema.Columns for i, col := range columns { sql.WriteString(fmt.Sprintf(" %s %s", col.Name, pg.getPostgreSQLType(col))) // 添加选项(转换MySQL选项到PostgreSQL) for _, opt := range col.Options { pgOpt := pg.convertOption(opt, col) if pgOpt != "" { sql.WriteString(" " + pgOpt) } } // 添加默认值 if col.Default != "" { pgDefault := pg.convertDefaultValue(col.Default, col.Type) if pgDefault != "" { sql.WriteString(" DEFAULT " + pgDefault) } } // 列注释将在表创建后单独添加 if i < len(columns)-1 { sql.WriteString(",") } sql.WriteString("\n") } // 添加主键约束(从列选项中提取) primaryKeys := pg.extractPrimaryKeys(columns) if len(primaryKeys) > 0 { sql.WriteString(fmt.Sprintf(" ,PRIMARY KEY (%s)\n", strings.Join(primaryKeys, ", "))) } sql.WriteString(");\n") // 添加表注释 if table.Schema.Comment != "" { sql.WriteString(fmt.Sprintf("COMMENT ON TABLE %s IS '%s';\n", table.Name, table.Schema.Comment)) } // 添加列注释 for _, col := range columns { if col.Comment != "" { sql.WriteString(fmt.Sprintf("COMMENT ON COLUMN %s.%s IS '%s';\n", table.Name, col.Name, col.Comment)) } } // 添加索引(PostgreSQL中索引在表外创建) for _, idx := range table.Schema.Indexes { if idx.Unique { sql.WriteString(fmt.Sprintf("CREATE UNIQUE INDEX IF NOT EXISTS %s ON %s (%s);\n", idx.Name, table.Name, strings.Join(idx.Columns, ", "))) } else { sql.WriteString(fmt.Sprintf("CREATE INDEX IF NOT EXISTS %s ON %s (%s);\n", idx.Name, table.Name, strings.Join(idx.Columns, ", "))) } } return sql.String() } // getPostgreSQLType 获取PostgreSQL数据类型 func (pg *PostgreSQLGenerator) getPostgreSQLType(col ColumnSchema) string { switch col.Type { case "TINYINT": if col.Length == 1 { return "BOOLEAN" } return "SMALLINT" case "BOOL": return "BOOLEAN" case "DATETIME": return "TIMESTAMP" case "TIMESTAMP": return "TIMESTAMP" case "JSON": return "JSONB" case "BLOB": return "BYTEA" case "INT": return "INTEGER" case "BIGINT": return "BIGINT" case "DECIMAL": if col.Precision > 0 && col.Scale > 0 { return fmt.Sprintf("DECIMAL(%d,%d)", col.Precision, col.Scale) } return "DECIMAL" case "VARCHAR": if col.Length > 0 { return fmt.Sprintf("VARCHAR(%d)", col.Length) } return "VARCHAR" case "CHAR": if col.Length > 0 { return fmt.Sprintf("CHAR(%d)", col.Length) } return "CHAR" case "TEXT": return "TEXT" case "FLOAT": return "REAL" case "DOUBLE": return "DOUBLE PRECISION" case "DATE": return "DATE" case "TIME": return "TIME" default: return col.Type } } // convertOption 转换选项到PostgreSQL语法 func (pg *PostgreSQLGenerator) convertOption(option string, col ColumnSchema) string { option = strings.ToUpper(option) switch option { case "NOT NULL": return "NOT NULL" case "UNIQUE": return "UNIQUE" case "AUTO_INCREMENT": // PostgreSQL使用SERIAL/BIGSERIAL/SMALLSERIAL // 已在getPostgreSQLType中处理 return "" case "PRIMARY KEY": // 主键将在表级别定义 return "" default: return option } } // convertDefaultValue 转换默认值 func (pg *PostgreSQLGenerator) convertDefaultValue(value, colType string) string { value = strings.TrimSpace(value) // 移除引号(如果有) if strings.HasPrefix(value, "'") && strings.HasSuffix(value, "'") { value = value[1 : len(value)-1] } // 处理布尔值 if strings.EqualFold(colType, "BOOLEAN") || strings.EqualFold(colType, "BOOL") { switch strings.ToUpper(value) { case "1", "TRUE", "'TRUE'": return "TRUE" case "0", "FALSE", "'FALSE'": return "FALSE" } } // 处理时间戳 if strings.Contains(strings.ToUpper(value), "CURRENT_TIMESTAMP") { return "CURRENT_TIMESTAMP" } // 处理数字 if _, ok := pg.isNumber(value); ok { return value } // 其他情况加单引号 return "'" + value + "'" } // isNumber 检查字符串是否为数字 func (pg *PostgreSQLGenerator) isNumber(s string) (float64, bool) { var f float64 _, err := fmt.Sscanf(s, "%f", &f) return f, err == nil } // extractPrimaryKeys 从列中提取主键 func (pg *PostgreSQLGenerator) extractPrimaryKeys(columns []ColumnSchema) []string { var primaryKeys []string for _, col := range columns { for _, opt := range col.Options { if strings.ToUpper(opt) == "PRIMARY KEY" { primaryKeys = append(primaryKeys, col.Name) break } } } return primaryKeys } func init() { RegisterGenerator(NewPostgreSQLGenerator()) }