package database

import (
	"errors"
	"fmt"
	"log"
	"sync"

	"git.x2erp.com/qdy/go-base/config"
	"git.x2erp.com/qdy/go-base/config/subconfigs"
	"git.x2erp.com/qdy/go-base/ctx"
	"git.x2erp.com/qdy/go-base/logger"
	"git.x2erp.com/qdy/go-base/model/response"
	"git.x2erp.com/qdy/go-db/driver"
	"git.x2erp.com/qdy/go-db/function"
	"github.com/jmoiron/sqlx"
)

// DBFactory wraps one *sqlx.DB together with the configuration it was opened
// with. The process-wide instance is obtained through CreateDBFactory.
type DBFactory struct {
	db     *sqlx.DB
	config *subconfigs.DatabaseConfig
}

// Package-level singleton state, written only inside instanceDBOnce.Do.
var (
	instanceDB     *DBFactory
	instanceDBOnce sync.Once
	initErrDB      error
)

// CreateDBFactory returns the process-wide DBFactory singleton. The first call
// builds it from cfg's database configuration; later calls ignore cfg and
// return the same instance. If initialization failed, the process is
// terminated via log.Fatalf, so the returned value is never nil.
func CreateDBFactory(cfg config.IConfig) *DBFactory {
	instanceDBOnce.Do(func() {
		instanceDB, initErrDB = createDBFactoryNew(cfg.GetDatabaseConfig())
	})
	if initErrDB != nil {
		log.Fatalf("DBFactory is error: '%v'", initErrDB)
	}
	return instanceDB
}

// createDBFactoryNew validates cfg, opens a connection through the driver
// registered for cfg.Type, verifies it with function.TestConnection and
// returns a new DBFactory. Every failure is reported as a returned error
// (wrapping the underlying cause where there is one); it never exits the
// process and does not modify the package-level singleton variables.
func createDBFactoryNew(cfg *subconfigs.DatabaseConfig) (*DBFactory, error) {
	if err := validateDatabaseConfig(cfg); err != nil {
		return nil, err
	}

	log.Printf("Creating database connection...")

	dbDriver, err := driver.Get(cfg.Type)
	if err != nil {
		return nil, fmt.Errorf("failed to get database driver: %w", err)
	}

	// Translate the subconfig into the driver package's own config type.
	driverConfig := driver.DBConfig{
		Type:            cfg.Type,
		Host:            cfg.Host,
		Port:            cfg.Port,
		Username:        cfg.Username,
		Password:        cfg.Password,
		Database:        cfg.Database,
		MaxOpenConns:    cfg.MaxOpenConns,
		MaxIdleConns:    cfg.MaxIdleConns,
		ConnMaxLifetime: cfg.ConnMaxLifetime,
	}

	db, err := dbDriver.Open(driverConfig)
	if err != nil {
		return nil, fmt.Errorf("failed to open database connection: %w", err)
	}

	if err := function.TestConnection(db, cfg.Type); err != nil {
		// Best-effort cleanup: the connection-test failure is the error worth
		// reporting, so a secondary Close error is deliberately dropped.
		_ = db.Close()
		return nil, fmt.Errorf("database connection test failed: %w", err)
	}

	log.Printf("DBFactory is successfully created.\n")
	return &DBFactory{db: db, config: cfg}, nil
}

// validateDatabaseConfig reports an error when cfg is nil or when any of the
// fields needed to open a connection (Type, Host, Database) is empty.
func validateDatabaseConfig(cfg *subconfigs.DatabaseConfig) error {
	switch {
	case cfg == nil:
		return errors.New("配置未初始化,请先在yaml进行配置")
	case cfg.Type == "":
		return errors.New("database type must be configured")
	case cfg.Host == "":
		return errors.New("database host must be configured")
	case cfg.Database == "":
		return errors.New("database name must be configured")
	}
	return nil
}

// ========== DBFactory instance methods ==========

// GetDB returns the underlying connection pool (nil after Close).
func (f *DBFactory) GetDB() *sqlx.DB {
	return f.db
}

// GetName returns the fixed component name "DBFactory".
func (f *DBFactory) GetName() string {
	return "DBFactory"
}

// Close closes the underlying connection pool, if still open, and clears the
// handle so repeated calls are no-ops. A close error is logged, not returned;
// the success message is only logged when Close actually succeeded.
func (f *DBFactory) Close() {
	if f.db == nil {
		return
	}
	if err := f.db.Close(); err != nil {
		logger.Errorf("failed to close database connection: %v", err)
	} else {
		log.Printf("Database connection closed gracefully\n")
	}
	f.db = nil
}

// GetConfig returns a copy of the database configuration the factory was
// created with.
func (f *DBFactory) GetConfig() subconfigs.DatabaseConfig {
	return *f.config
}

// TestConnection re-runs function.TestConnection against the current pool.
func (f *DBFactory) TestConnection() error {
	return function.TestConnection(f.db, f.config.Type)
}

// ========== Shortcut operations ==========

// QueryToJSON runs sql without parameters and returns the rows as a
// QueryResult (delegates to function.QueryToJSON).
func (f *DBFactory) QueryToJSON(sql string, reqCtx *ctx.RequestContext) *response.QueryResult[[]map[string]interface{}] {
	return function.QueryToJSON(f.db, sql, reqCtx)
}

// QueryPositionalToJSON runs sql with positional parameters and returns the
// rows as a QueryResult (delegates to function.QueryPositionalToJSON).
func (f *DBFactory) QueryPositionalToJSON(sql string, params []interface{}, reqCtx *ctx.RequestContext) *response.QueryResult[[]map[string]interface{}] {
	return function.QueryPositionalToJSON(f.db, sql, params, reqCtx)
}

// QueryParamsNameToJSON runs sql with named parameters and returns the rows as
// a QueryResult (delegates to function.QueryParamsNameToJSON).
func (f *DBFactory) QueryParamsNameToJSON(sql string, params map[string]interface{}, reqCtx *ctx.RequestContext) *response.QueryResult[[]map[string]interface{}] {
	return function.QueryParamsNameToJSON(f.db, sql, params, reqCtx)
}

// QueryToCSV runs sql without parameters and returns the result as CSV bytes;
// writerHeader is forwarded to function.QueryToCSV.
func (f *DBFactory) QueryToCSV(sql string, writerHeader bool, reqCtx *ctx.RequestContext) ([]byte, error) {
	return function.QueryToCSV(f.db, sql, writerHeader, reqCtx)
}

// QueryPositionalToCSV runs sql with positional parameters and returns the
// result as CSV bytes (delegates to function.QueryPositionalToCSV).
func (f *DBFactory) QueryPositionalToCSV(sql string, writerHeader bool, params []interface{}, reqCtx *ctx.RequestContext) ([]byte, error) {
	return function.QueryPositionalToCSV(f.db, sql, writerHeader, params, reqCtx)
}

// QueryParamsNameToCSV runs sql with named parameters and returns the result
// as CSV bytes (delegates to function.QueryParamsNameToCSV).
func (f *DBFactory) QueryParamsNameToCSV(sql string, writerHeader bool, params map[string]interface{}, reqCtx *ctx.RequestContext) ([]byte, error) {
	return function.QueryParamsNameToCSV(f.db, sql, writerHeader, params, reqCtx)
}

// ExecuteDDL executes a single DDL statement (delegates to function.ExecuteDDL).
func (f *DBFactory) ExecuteDDL(ddlSQL string) error {
	return function.ExecuteDDL(f.db, ddlSQL)
}

// ExecuteDDLWithTx executes a single DDL statement via
// function.ExecuteDDLWithTx, which (per its name) wraps it in a transaction.
func (f *DBFactory) ExecuteDDLWithTx(ddlSQL string) error {
	return function.ExecuteDDLWithTx(f.db, ddlSQL)
}

// ExecuteMultipleDDL executes several DDL statements (delegates to
// function.ExecuteMultipleDDL).
func (f *DBFactory) ExecuteMultipleDDL(ddlSQLs []string) error {
	return function.ExecuteMultipleDDL(f.db, ddlSQLs)
}

// GetDBType returns the configured database type.
func (f *DBFactory) GetDBType() string {
	return f.config.Type
}

// GetDatabaseName returns the configured database name.
func (f *DBFactory) GetDatabaseName() string {
	return f.config.Database
}

// GetHost returns the configured database host.
func (f *DBFactory) GetHost() string {
	return f.config.Host
}

// GetPort returns the configured database port.
func (f *DBFactory) GetPort() int {
	return f.config.Port
}

// BeginTx starts a new transaction on the pool.
func (f *DBFactory) BeginTx() (*sqlx.Tx, error) {
	return f.db.Beginx()
}

// GetStats returns the pool statistics reported by f.db.Stats().
func (f *DBFactory) GetStats() interface{} {
	return f.db.Stats()
}

// Ping checks that the database is reachable.
func (f *DBFactory) Ping() error {
	return f.db.Ping()
}

// GetAvailableDrivers returns the names reported by driver.GetAllDrivers.
func (f *DBFactory) GetAvailableDrivers() []string {
	return driver.GetAllDrivers()
}

// ========== Simplified query/exec helpers ==========

// QueryOne scans a single row returned by sql into dest.
func (f *DBFactory) QueryOne(sql string, dest interface{}) error {
	return f.QueryOneWithParams(sql, dest)
}

// QueryOneWithParams scans a single row returned by sql (with bind params)
// into dest.
func (f *DBFactory) QueryOneWithParams(sql string, dest interface{}, params ...interface{}) error {
	return f.db.Get(dest, sql, params...)
}

// QueryMany scans all rows returned by sql into dest (a pointer to a slice).
func (f *DBFactory) QueryMany(sql string, dest interface{}) error {
	return f.QueryManyWithParams(sql, dest)
}

// QueryManyWithParams scans all rows returned by sql (with bind params) into
// dest (a pointer to a slice).
func (f *DBFactory) QueryManyWithParams(sql string, dest interface{}, params ...interface{}) error {
	return f.db.Select(dest, sql, params...)
}

// Execute runs a statement without parameters and returns the number of rows
// affected.
func (f *DBFactory) Execute(sql string) (int64, error) {
	return f.ExecuteWithParams(sql)
}

// ExecuteWithParams runs a statement with bind params and returns the number
// of rows affected.
func (f *DBFactory) ExecuteWithParams(sql string, params ...interface{}) (int64, error) {
	result, err := f.db.Exec(sql, params...)
	if err != nil {
		return 0, err
	}
	return result.RowsAffected()
}

// QueryMap scans the first row returned by sql into a column-name -> value map.
func (f *DBFactory) QueryMap(sql string) (map[string]interface{}, error) {
	return f.QueryMapWithParams(sql)
}

// QueryMapWithParams scans the first row returned by sql (with bind params)
// into a column-name -> value map.
func (f *DBFactory) QueryMapWithParams(sql string, params ...interface{}) (map[string]interface{}, error) {
	result := make(map[string]interface{})
	err := f.db.QueryRowx(sql, params...).MapScan(result)
	return result, err
}

// QuerySliceMap returns every row of sql as a column-name -> value map.
// The slice is nil when the query yields no rows.
func (f *DBFactory) QuerySliceMap(sql string) ([]map[string]interface{}, error) {
	return f.QuerySliceMapWithParams(sql)
}

// QuerySliceMapWithParams returns every row of sql (with bind params) as a
// column-name -> value map. The slice is nil when the query yields no rows.
func (f *DBFactory) QuerySliceMapWithParams(sql string, params ...interface{}) ([]map[string]interface{}, error) {
	rows, err := f.db.Queryx(sql, params...)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	var results []map[string]interface{}
	for rows.Next() {
		row := make(map[string]interface{})
		if err := rows.MapScan(row); err != nil {
			return nil, err
		}
		results = append(results, row)
	}
	// rows.Next returns false both on exhaustion and on an iteration error;
	// only rows.Err distinguishes a truncated result set from a complete one.
	if err := rows.Err(); err != nil {
		return nil, err
	}
	return results, nil
}