| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326 |
- package database
-
- import (
- "fmt"
- "log"
- "sync"
-
- "git.x2erp.com/qdy/go-base/config"
- "git.x2erp.com/qdy/go-base/config/subconfigs"
- "git.x2erp.com/qdy/go-base/ctx"
- "git.x2erp.com/qdy/go-base/logger"
- "git.x2erp.com/qdy/go-base/model/response"
- "git.x2erp.com/qdy/go-db/drivers"
- "git.x2erp.com/qdy/go-db/functions"
-
- "github.com/jmoiron/sqlx"
- )
-
- // DBFactory 数据库工厂(全局单例模式)
- type DBFactory struct {
- db *sqlx.DB
- config *subconfigs.DatabaseConfig
- }
-
- var (
- instanceDB *DBFactory
- instanceDBOnce sync.Once
- initErrDB error
- )
-
- // CreateDBFactory 获取数据库工厂单例
- func CreateDBFactory(cfg config.IConfig) *DBFactory {
- config := cfg.GetDatabaseConfig()
- instanceDBOnce.Do(func() {
-
- if config == nil {
- log.Fatal("配置未初始化,请先在yaml进行配置")
- }
-
- // 设置默认值
- if config.MaxOpenConns == 0 {
- config.MaxOpenConns = 100
- }
- if config.MaxIdleConns == 0 {
- config.MaxIdleConns = 10
- }
- if config.ConnMaxLifetime == 0 {
- config.ConnMaxLifetime = 5 * 60 // 5分钟,单位秒
- }
-
- // 验证配置
- if config.Type == "" {
- initErrDB = fmt.Errorf("database type must be configured")
- return
- }
- if config.Host == "" {
- initErrDB = fmt.Errorf("database host must be configured")
- return
- }
- if config.Database == "" {
- initErrDB = fmt.Errorf("database name must be configured")
- return
- }
-
- log.Printf("Creating database connection...")
-
- // 获取对应的驱动
- dbDriver, err := drivers.Get(config.Type)
- if err != nil {
- initErrDB = fmt.Errorf("failed to get database driver: %v", err)
- return
- }
-
- // 将内部 DBConfig 转换为 drivers.DBConfig
- driverConfig := drivers.DBConfig{
- Type: config.Type,
- Host: config.Host,
- Port: config.Port,
- Username: config.Username,
- Password: config.Password,
- Database: config.Database,
- MaxOpenConns: config.MaxOpenConns,
- MaxIdleConns: config.MaxIdleConns,
- ConnMaxLifetime: config.ConnMaxLifetime,
- }
-
- // 创建数据库连接
- db, err := dbDriver.Open(driverConfig)
- if err != nil {
- initErrDB = fmt.Errorf("failed to open database connection: %v", err)
- return
- }
-
- // 测试连接
- if err := functions.TestConnection(db, config.Type); err != nil {
- db.Close()
- initErrDB = fmt.Errorf("database connection test failed: %v", err)
- return
- }
-
- log.Printf("DBFactory is successfully created.\n")
-
- instanceDB = &DBFactory{
- db: db,
- config: config,
- }
- })
-
- if initErrDB != nil {
- log.Fatalf("DBFactory is error: '%v'", initErrDB)
- }
-
- return instanceDB
- }
-
- // ========== DBFactory 实例方法 ==========
-
- // GetDB 获取数据库连接
- func (f *DBFactory) GetDB() *sqlx.DB {
- return f.db
- }
-
- func (f *DBFactory) GetName() string {
- return "DBFactory"
- }
-
- // Close 关闭数据库连接
- func (f *DBFactory) Close() {
- if f.db != nil {
- err := f.db.Close()
- if err != nil {
- logger.Errorf("failed to close database connection: %v", err)
- }
- log.Printf("Database connection closed gracefully\n")
- f.db = nil
- }
- }
-
- // GetConfig 获取配置信息
- func (f *DBFactory) GetConfig() subconfigs.DatabaseConfig {
- return *f.config
- }
-
- // TestConnection 测试连接
- func (f *DBFactory) TestConnection() error {
- return functions.TestConnection(f.db, f.config.Type)
- }
-
- // ========== 快捷操作方法 ==========
-
- // QueryToJSON 快捷查询,直接返回 JSON 字节流
- func (f *DBFactory) QueryToJSON(sql string, reqCtx *ctx.RequestContext) *response.QueryResult[[]map[string]interface{}] {
- return functions.QueryToJSON(f.db, sql, reqCtx)
- }
-
- // QueryPositionalToJSON 位置参数查询并返回 JSON 字节数据
- func (f *DBFactory) QueryPositionalToJSON(sql string, params []interface{}, reqCtx *ctx.RequestContext) *response.QueryResult[[]map[string]interface{}] {
- return functions.QueryPositionalToJSON(f.db, sql, params, reqCtx)
- }
-
- // QueryParamsNameToJSON 命名参数查询并返回 JSON 字节数据
- func (f *DBFactory) QueryParamsNameToJSON(sql string, params map[string]interface{}, reqCtx *ctx.RequestContext) *response.QueryResult[[]map[string]interface{}] {
- return functions.QueryParamsNameToJSON(f.db, sql, params, reqCtx)
- }
-
- // QueryToCSV 快捷查询,直接返回 CSV 字符串(包含表头)
- func (f *DBFactory) QueryToCSV(sql string, writerHeader bool, reqCtx *ctx.RequestContext) ([]byte, error) {
- return functions.QueryToCSV(f.db, sql, writerHeader, reqCtx)
- }
-
- // QueryPositionalToCSV 位置参数查询并返回 CSV 字节数据
- func (f *DBFactory) QueryPositionalToCSV(sql string, writerHeader bool, params []interface{}, reqCtx *ctx.RequestContext) ([]byte, error) {
- return functions.QueryPositionalToCSV(f.db, sql, writerHeader, params, reqCtx)
- }
-
- // QueryParamsNameToCSV 命名参数查询并返回 CSV 字节数据
- func (f *DBFactory) QueryParamsNameToCSV(sql string, writerHeader bool, params map[string]interface{}, reqCtx *ctx.RequestContext) ([]byte, error) {
- return functions.QueryParamsNameToCSV(f.db, sql, writerHeader, params, reqCtx)
- }
-
- // ExecuteDDL 快捷执行DDL语句
- func (f *DBFactory) ExecuteDDL(ddlSQL string) error {
- return functions.ExecuteDDL(f.db, ddlSQL)
- }
-
- // ExecuteDDLWithTx 快捷在事务中执行DDL语句
- func (f *DBFactory) ExecuteDDLWithTx(ddlSQL string) error {
- return functions.ExecuteDDLWithTx(f.db, ddlSQL)
- }
-
- // ExecuteMultipleDDL 快捷执行多个DDL语句
- func (f *DBFactory) ExecuteMultipleDDL(ddlSQLs []string) error {
- return functions.ExecuteMultipleDDL(f.db, ddlSQLs)
- }
-
- // GetDBType 得到当前使用数据库类型
- func (f *DBFactory) GetDBType() string {
- return f.config.Type
- }
-
- // GetDatabaseName 获取数据库名称
- func (f *DBFactory) GetDatabaseName() string {
- return f.config.Database
- }
-
- // GetHost 获取数据库主机
- func (f *DBFactory) GetHost() string {
- return f.config.Host
- }
-
- // GetPort 获取数据库端口
- func (f *DBFactory) GetPort() int {
- return f.config.Port
- }
-
- // BeginTx 开始事务
- func (f *DBFactory) BeginTx() (*sqlx.Tx, error) {
- return f.db.Beginx()
- }
-
- // GetStats 获取数据库连接统计信息
- func (f *DBFactory) GetStats() interface{} {
- return f.db.Stats()
- }
-
- // Ping 测试数据库连接是否正常
- func (f *DBFactory) Ping() error {
- return f.db.Ping()
- }
-
- // GetAvailableDrivers 获取可用的数据库驱动
- func (f *DBFactory) GetAvailableDrivers() []string {
- return drivers.GetAllDrivers()
- }
-
- // ========== 新增的简化操作方法 ==========
-
- // QueryOne 查询单条记录
- func (f *DBFactory) QueryOne(sql string, dest interface{}) error {
- return f.db.Get(dest, sql)
- }
-
- // QueryOneWithParams 带参数查询单条记录
- func (f *DBFactory) QueryOneWithParams(sql string, dest interface{}, params ...interface{}) error {
- return f.db.Get(dest, sql, params...)
- }
-
- // QueryMany 查询多条记录
- func (f *DBFactory) QueryMany(sql string, dest interface{}) error {
- return f.db.Select(dest, sql)
- }
-
- // QueryManyWithParams 带参数查询多条记录
- func (f *DBFactory) QueryManyWithParams(sql string, dest interface{}, params ...interface{}) error {
- return f.db.Select(dest, sql, params...)
- }
-
- // Execute 执行更新操作
- func (f *DBFactory) Execute(sql string) (int64, error) {
- result, err := f.db.Exec(sql)
- if err != nil {
- return 0, err
- }
- return result.RowsAffected()
- }
-
- // ExecuteWithParams 带参数执行更新操作
- func (f *DBFactory) ExecuteWithParams(sql string, params ...interface{}) (int64, error) {
- result, err := f.db.Exec(sql, params...)
- if err != nil {
- return 0, err
- }
- return result.RowsAffected()
- }
-
- // QueryMap 查询单条记录到map
- func (f *DBFactory) QueryMap(sql string) (map[string]interface{}, error) {
- result := make(map[string]interface{})
- err := f.db.QueryRowx(sql).MapScan(result)
- return result, err
- }
-
- // QueryMapWithParams 带参数查询单条记录到map
- func (f *DBFactory) QueryMapWithParams(sql string, params ...interface{}) (map[string]interface{}, error) {
- result := make(map[string]interface{})
- err := f.db.QueryRowx(sql, params...).MapScan(result)
- return result, err
- }
-
- // QuerySliceMap 查询多条记录到map切片
- func (f *DBFactory) QuerySliceMap(sql string) ([]map[string]interface{}, error) {
- rows, err := f.db.Queryx(sql)
- if err != nil {
- return nil, err
- }
- defer rows.Close()
-
- var results []map[string]interface{}
- for rows.Next() {
- result := make(map[string]interface{})
- if err := rows.MapScan(result); err != nil {
- return nil, err
- }
- results = append(results, result)
- }
- return results, nil
- }
-
- // QuerySliceMapWithParams 带参数查询多条记录到map切片
- func (f *DBFactory) QuerySliceMapWithParams(sql string, params ...interface{}) ([]map[string]interface{}, error) {
- rows, err := f.db.Queryx(sql, params...)
- if err != nil {
- return nil, err
- }
- defer rows.Close()
-
- var results []map[string]interface{}
- for rows.Next() {
- result := make(map[string]interface{})
- if err := rows.MapScan(result); err != nil {
- return nil, err
- }
- results = append(results, result)
- }
- return results, nil
- }
|