瀏覽代碼

修改工厂,实现单例模式

qdy 3 月之前
父節點
當前提交
181cef278c
共有 12 個檔案被更改,包括 218 行新增219 行删除
  1. 4
    3
      drivers/driver.go
  2. 3
    3
      drivers/mysql.go
  3. 3
    3
      drivers/oracle.go
  4. 3
    3
      drivers/postgres.go
  5. 3
    3
      drivers/sqlserver.go
  6. 122
    117
      factory/db_factory.go
  7. 6
    14
      functions/execute.go
  8. 15
    28
      functions/query.go
  9. 33
    0
      functions/testConnection.go
  10. 7
    1
      go.mod
  11. 4
    0
      go.sum
  12. 15
    44
      test.go

+ 4
- 3
drivers/driver.go 查看文件

@@ -1,9 +1,10 @@
1 1
 package drivers
2 2
 
3 3
 import (
4
-	"database/sql"
5 4
 	"fmt"
6 5
 	"time" // 添加 time 包导入
6
+
7
+	"github.com/jmoiron/sqlx"
7 8
 )
8 9
 
9 10
 // DBConfig 数据库配置
@@ -24,7 +25,7 @@ type DBDriver interface {
24 25
 	// Name 返回驱动名称
25 26
 	Name() string
26 27
 	// Open 打开数据库连接
27
-	Open(config DBConfig) (*sql.DB, error)
28
+	Open(config DBConfig) (*sqlx.DB, error)
28 29
 	// BuildDSN 构建连接字符串
29 30
 	BuildDSN(config DBConfig) string
30 31
 }
@@ -62,7 +63,7 @@ func GetAllDrivers() []string {
62 63
 }
63 64
 
64 65
 // configureConnectionPool 配置连接池的公共函数
65
-func configureConnectionPool(db *sql.DB, config DBConfig) {
66
+func configureConnectionPool(db *sqlx.DB, config DBConfig) {
66 67
 	if config.MaxOpenConns > 0 {
67 68
 		db.SetMaxOpenConns(config.MaxOpenConns)
68 69
 	}

+ 3
- 3
drivers/mysql.go 查看文件

@@ -1,10 +1,10 @@
1 1
 package drivers
2 2
 
3 3
 import (
4
-	"database/sql"
5 4
 	"fmt"
6 5
 
7 6
 	_ "github.com/go-sql-driver/mysql"
7
+	"github.com/jmoiron/sqlx"
8 8
 )
9 9
 
10 10
 type MySQLDriver struct{}
@@ -22,9 +22,9 @@ func (d *MySQLDriver) BuildDSN(config DBConfig) string {
22 22
 		config.Database)
23 23
 }
24 24
 
25
-func (d *MySQLDriver) Open(config DBConfig) (*sql.DB, error) {
25
+func (d *MySQLDriver) Open(config DBConfig) (*sqlx.DB, error) {
26 26
 	dsn := d.BuildDSN(config)
27
-	db, err := sql.Open("mysql", dsn)
27
+	db, err := sqlx.Open("mysql", dsn)
28 28
 	if err != nil {
29 29
 		return nil, err
30 30
 	}

+ 3
- 3
drivers/oracle.go 查看文件

@@ -1,9 +1,9 @@
1 1
 package drivers
2 2
 
3 3
 import (
4
-	"database/sql"
5 4
 	"fmt"
6 5
 
6
+	"github.com/jmoiron/sqlx"
7 7
 	_ "github.com/sijms/go-ora/v2"
8 8
 )
9 9
 
@@ -22,9 +22,9 @@ func (d *OracleDriver) BuildDSN(config DBConfig) string {
22 22
 		config.Database)
23 23
 }
24 24
 
25
-func (d *OracleDriver) Open(config DBConfig) (*sql.DB, error) {
25
+func (d *OracleDriver) Open(config DBConfig) (*sqlx.DB, error) {
26 26
 	dsn := d.BuildDSN(config)
27
-	db, err := sql.Open("oracle", dsn)
27
+	db, err := sqlx.Open("oracle", dsn)
28 28
 	if err != nil {
29 29
 		return nil, err
30 30
 	}

+ 3
- 3
drivers/postgres.go 查看文件

@@ -1,9 +1,9 @@
1 1
 package drivers
2 2
 
3 3
 import (
4
-	"database/sql"
5 4
 	"fmt"
6 5
 
6
+	"github.com/jmoiron/sqlx"
7 7
 	_ "github.com/lib/pq"
8 8
 )
9 9
 
@@ -22,9 +22,9 @@ func (d *PostgresDriver) BuildDSN(config DBConfig) string {
22 22
 		config.Database)
23 23
 }
24 24
 
25
-func (d *PostgresDriver) Open(config DBConfig) (*sql.DB, error) {
25
+func (d *PostgresDriver) Open(config DBConfig) (*sqlx.DB, error) {
26 26
 	dsn := d.BuildDSN(config)
27
-	db, err := sql.Open("postgres", dsn)
27
+	db, err := sqlx.Open("postgres", dsn)
28 28
 	if err != nil {
29 29
 		return nil, err
30 30
 	}

+ 3
- 3
drivers/sqlserver.go 查看文件

@@ -1,9 +1,9 @@
1 1
 package drivers
2 2
 
3 3
 import (
4
-	"database/sql"
5 4
 	"fmt"
6 5
 
6
+	"github.com/jmoiron/sqlx"
7 7
 	_ "github.com/microsoft/go-mssqldb"
8 8
 )
9 9
 
@@ -22,9 +22,9 @@ func (d *SQLServerDriver) BuildDSN(config DBConfig) string {
22 22
 		config.Database)
23 23
 }
24 24
 
25
-func (d *SQLServerDriver) Open(config DBConfig) (*sql.DB, error) {
25
+func (d *SQLServerDriver) Open(config DBConfig) (*sqlx.DB, error) {
26 26
 	dsn := d.BuildDSN(config)
27
-	db, err := sql.Open("sqlserver", dsn)
27
+	db, err := sqlx.Open("sqlserver", dsn)
28 28
 	if err != nil {
29 29
 		return nil, err
30 30
 	}

+ 122
- 117
factory/db_factory.go 查看文件

@@ -1,162 +1,167 @@
1 1
 package factory
2 2
 
3 3
 import (
4
-	"database/sql"
5 4
 	"fmt"
6 5
 	"io"
6
+	"sync"
7 7
 
8 8
 	"git.x2erp.com/qdy/go-base/config"
9 9
 	"git.x2erp.com/qdy/go-base/types"
10 10
 	"git.x2erp.com/qdy/go-db/drivers"
11
+	"git.x2erp.com/qdy/go-db/functions"
12
+
13
+	"github.com/jmoiron/sqlx"
11 14
 )
12 15
 
13
-// DBFactory 数据库工厂
14 16
 type DBFactory struct {
15
-	config config.IConfig
16
-}
17
-
18
-// NewDBFactory 创建数据库工厂
19
-func NewDBFactory() (*DBFactory, error) {
20
-	// 使用配置单例 哪里都可以直接使用
21
-	cfg := config.GetConfig()
22
-
23
-	// 检查配置初始化是否有错误
24
-	if err := config.GetInitError(); err != nil {
25
-		return nil, fmt.Errorf("failed to load config: %v", err)
26
-	}
27
-
28
-	// 检查数据库配置是否完整
29
-	if !cfg.IsDatabaseConfigured() {
30
-		return nil, fmt.Errorf("database configuration is incomplete")
31
-	}
32
-
33
-	return &DBFactory{config: cfg}, nil
17
+	db *sqlx.DB
34 18
 }
35 19
 
36
-// CreateDB 创建数据库连接
37
-func (f *DBFactory) CreateDB() (*sql.DB, error) {
38
-	dbConfig := f.config.GetDatabase()
39
-	dbType := dbConfig.Type
40
-
41
-	// 获取对应的驱动
42
-	dbDriver, err := drivers.Get(dbType)
43
-	if err != nil {
44
-		return nil, fmt.Errorf("failed to get database driver: %v", err)
45
-	}
20
+var (
21
+	instanceDBFactory *DBFactory
22
+	once              sync.Once
23
+)
46 24
 
47
-	// 将内部 DBConfig 转换为 drivers.DBConfig
48
-	driverConfig := drivers.DBConfig{
49
-		Type:            dbConfig.Type,
50
-		Host:            dbConfig.Host,
51
-		Port:            dbConfig.Port,
52
-		Username:        dbConfig.Username,
53
-		Password:        dbConfig.Password,
54
-		Database:        dbConfig.Database,
55
-		MaxOpenConns:    dbConfig.MaxOpenConns,
56
-		MaxIdleConns:    dbConfig.MaxIdleConns,
57
-		ConnMaxLifetime: dbConfig.ConnMaxLifetime,
25
+// GetDBFactory 创建数据库工厂单例
26
+func GetDBFactory() (*DBFactory, error) {
27
+	var initErr error
28
+	var msg = "DBFactory instance retrieved from memory.\n"
29
+
30
+	once.Do(func() {
31
+		// 使用配置单例
32
+		cfg := config.GetConfig()
33
+
34
+		// 检查配置初始化是否有错误
35
+		if err := config.GetInitError(); err != nil {
36
+			initErr = fmt.Errorf("failed to load config: %v", err)
37
+			return
38
+		}
39
+
40
+		// 检查数据库配置是否完整
41
+		if !cfg.IsDatabaseConfigured() {
42
+			initErr = fmt.Errorf("database configuration is incomplete")
43
+			return
44
+		}
45
+
46
+		// 显示所支持的数据库驱动
47
+		driversStr := drivers.GetAllDrivers()
48
+		fmt.Printf("Available database drivers: %v\n", driversStr)
49
+
50
+		dbConfig := cfg.GetDatabase()
51
+		dbType := dbConfig.Type
52
+
53
+		// 获取对应的驱动
54
+		dbDriver, err := drivers.Get(dbType)
55
+		if err != nil {
56
+			initErr = fmt.Errorf("failed to get database driver: %v", err)
57
+			return
58
+		}
59
+
60
+		// 将内部 DBConfig 转换为 drivers.DBConfig
61
+		driverConfig := drivers.DBConfig{
62
+			Type:            dbConfig.Type,
63
+			Host:            dbConfig.Host,
64
+			Port:            dbConfig.Port,
65
+			Username:        dbConfig.Username,
66
+			Password:        dbConfig.Password,
67
+			Database:        dbConfig.Database,
68
+			MaxOpenConns:    dbConfig.MaxOpenConns,
69
+			MaxIdleConns:    dbConfig.MaxIdleConns,
70
+			ConnMaxLifetime: dbConfig.ConnMaxLifetime,
71
+		}
72
+
73
+		// 创建数据库连接
74
+		db, err := dbDriver.Open(driverConfig)
75
+		if err != nil {
76
+			initErr = fmt.Errorf("failed to open database connection: %v", err)
77
+			return
78
+		}
79
+
80
+		// 测试连接
81
+		if err := functions.TestConnection(db, dbType); err != nil {
82
+			db.Close()
83
+			initErr = fmt.Errorf("database connection test failed: %v", err)
84
+			return
85
+		}
86
+
87
+		msg = "DBFactory is successfully created.\n"
88
+		instanceDBFactory = &DBFactory{db: db}
89
+	})
90
+
91
+	if initErr != nil {
92
+		return nil, initErr
58 93
 	}
59 94
 
60
-	// 创建数据库连接
61
-	db, err := dbDriver.Open(driverConfig)
62
-	if err != nil {
63
-		return nil, fmt.Errorf("failed to open database connection: %v", err)
64
-	}
95
+	fmt.Print(msg)
65 96
 
66
-	return db, nil
97
+	return instanceDBFactory, nil
67 98
 }
68 99
 
69
-// GetConfig 获取配置信息
70
-func (f *DBFactory) GetConfig() config.IConfig {
71
-	return f.config
100
+// GetDB 获取数据库连接(线程安全)
101
+func (f *DBFactory) GetDB() *sqlx.DB {
102
+	return f.db
72 103
 }
73 104
 
74
-// GetAvailableDrivers 获取可用的数据库驱动
75
-func (f *DBFactory) GetAvailableDrivers() []string {
76
-	return drivers.GetAllDrivers()
105
+// Close 关闭数据库连接
106
+func (f *DBFactory) Close() error {
107
+	if f.db != nil {
108
+		err := f.db.Close()
109
+		f.db = nil
110
+		return err
111
+	}
112
+	return nil
77 113
 }
78 114
 
79
-// CreateQueryExecutor 创建查询执行器
80
-// func (f *DBFactory) CreateQueryExecutor(db *sql.DB) *QueryExecutor {
81
-// 	return newQueryExecutor(db)
82
-// }
83
-
84
-// -------------- 对外暴露的初始化方法(核心入口)--------------
85
-// NewDBQuery 初始化查询实例(对外提供唯一初始化入口)
86
-// db: 已初始化的数据库连接(由调用方传入,解耦数据库配置)
87
-func newDBQuery(db *sql.DB) *queryExecutor {
88
-	return newQueryExecutor(db)
115
+// QueryToJSON 快捷查询,直接返回 JSON 字节流
116
+func (f *DBFactory) QueryToJSON(sql string) *types.QueryResult {
117
+	return functions.QueryToJSON(f.db, sql)
89 118
 }
90 119
 
91
-// QuickQueryToJSON 快捷查询,直接返回 JSON 字节流
92
-func QueryToJSON(db *sql.DB, sql string) *types.QueryResult {
93
-	return newDBQuery(db).queryToJSON(sql)
120
+// QueryToCSV 快捷查询,直接返回 CSV 字符串(包含表头)
121
+func (f *DBFactory) QueryToCSV(sql string) ([]byte, error) {
122
+	return functions.QueryToCSV(f.db, sql)
94 123
 }
95 124
 
96
-// QuickQueryToCSV 快捷查询,直接返回 CSV 字符串(包含表头)
97
-func QueryToCSV(db *sql.DB, sql string) ([]byte, error) {
98
-	return newDBQuery(db).queryToCSV(sql)
125
+// QueryWithColumns 快捷查询,返回完整结果(含列信息
126
+func (f *DBFactory) QueryWithColumns(sql string) *types.QueryResult {
127
+	return functions.QueryWithColumns(f.db, sql)
99 128
 }
100 129
 
101
-// QuickExecuteWithColumns 快捷查询,返回完整结果(含列信息)
102
-func QueryWithColumns(db *sql.DB, sql string) *types.QueryResult {
103
-	return newDBQuery(db).queryWithColumns(sql)
130
+// QueryDataOnly 快捷查询,返回纯数据(性能优先
131
+func (f *DBFactory) QueryDataOnly(sql string) *types.QueryResult {
132
+	return functions.QueryDataOnly(f.db, sql)
104 133
 }
105 134
 
106
-// QuickExecuteDataOnly 快捷查询,返回纯数据(性能优先)
107
-func QueryDataOnly(db *sql.DB, sql string) *types.QueryResult {
108
-	return newDBQuery(db).queryDataOnly(sql)
135
+// QueryCSV 快捷查询,返回 CSV 格式结果(支持自定义是否包含表头
136
+func (f *DBFactory) QueryCSV(sql string, includeHeader bool) *types.QueryResult {
137
+	return functions.QueryCSV(f.db, sql, includeHeader)
109 138
 }
110 139
 
111
-// QuickExecuteCSV 快捷查询,返回 CSV 格式结果(支持自定义是否包含表头)
112
-func QueryCSV(db *sql.DB, sql string, includeHeader bool) *types.QueryResult {
113
-	return newDBQuery(db).queryCSV(sql, includeHeader)
140
+// QueryCSVStream 快捷流式输出 CSV(直接写入 io.Writer,适合大文件
141
+func (f *DBFactory) QueryCSVStream(sql string, w io.Writer, includeHeader bool) (int, error) {
142
+	return functions.QueryCSVStream(f.db, sql, w, includeHeader)
114 143
 }
115 144
 
116
-// QuickExecuteCSVStream 快捷流式输出 CSV(直接写入 io.Writer,适合大文件)
117
-func QueryCSVStream(db *sql.DB, sql string, w io.Writer, includeHeader bool) (int, error) {
118
-	return newDBQuery(db).queryCSVStream(sql, w, includeHeader)
145
+// ExecuteDDL 快捷执行DDL语句
146
+func (f *DBFactory) ExecuteDDL(ddlSQL string) error {
147
+	return functions.ExecuteDDL(f.db, ddlSQL)
119 148
 }
120 149
 
121
-// QuickExecuteDDL 快捷执行DDL语句
122
-func ExecuteDDL(db *sql.DB, ddlSQL string) error {
123
-	factory := &DBFactory{}
124
-	return factory.executeDDL(db, ddlSQL)
150
+// ExecuteDDLWithTx 快捷在事务中执行DDL语句
151
+func (f *DBFactory) ExecuteDDLWithTx(ddlSQL string) error {
152
+	return functions.ExecuteDDLWithTx(f.db, ddlSQL)
125 153
 }
126 154
 
127
-// QuickExecuteDDLWithTx 快捷在事务中执行DDL语句
128
-func ExecuteDDLWithTx(db *sql.DB, ddlSQL string) error {
129
-	factory := &DBFactory{}
130
-	return factory.executeDDLWithTx(db, ddlSQL)
155
+// ExecuteMultipleDDL 快捷执行多个DDL语句
156
+func (f *DBFactory) ExecuteMultipleDDL(ddlSQLs []string) error {
157
+	return functions.ExecuteMultipleDDL(f.db, ddlSQLs)
131 158
 }
132 159
 
133
-// QuickExecuteMultipleDDL 快捷执行多个DDL语句
134
-func ExecuteMultipleDDL(db *sql.DB, ddlSQLs []string) error {
135
-	factory := &DBFactory{}
136
-	return factory.executeMultipleDDL(db, ddlSQLs)
160
+// GetAvailableDrivers 获取可用的数据库驱动
161
+func (f *DBFactory) GetAvailableDrivers() []string {
162
+	return drivers.GetAllDrivers()
137 163
 }
138 164
 
139
-// testConnection 测试数据库连接
140
-func TestConnection(db *sql.DB, dbType string) error {
141
-	var query string
142
-	switch dbType {
143
-	case "mysql", "postgres", "sqlserver":
144
-		query = "SELECT 1"
145
-	case "oracle":
146
-		query = "SELECT 1 FROM DUAL"
147
-	default:
148
-		query = "SELECT 1"
149
-	}
150
-
151
-	var result int
152
-	err := db.QueryRow(query).Scan(&result)
153
-	if err != nil {
154
-		return err
155
-	}
156
-
157
-	if result != 1 {
158
-		return fmt.Errorf("unexpected test result: %d", result)
159
-	}
160
-
161
-	return nil
165
+func (f *DBFactory) TestConnection(dbType string) error {
166
+	return functions.TestConnection(f.db, dbType)
162 167
 }

factory/execute.go → functions/execute.go 查看文件

@@ -1,15 +1,13 @@
1
-package factory
1
+package functions
2 2
 
3 3
 import (
4
-	"database/sql"
5 4
 	"fmt"
5
+
6
+	"github.com/jmoiron/sqlx"
6 7
 )
7 8
 
8 9
 // ExecuteDDL 执行DDL语句(创建、删除、更新表等)
9
-func (f *DBFactory) executeDDL(db *sql.DB, ddlSQL string) error {
10
-	if db == nil {
11
-		return fmt.Errorf("database connection is nil")
12
-	}
10
+func ExecuteDDL(db *sqlx.DB, ddlSQL string) error {
13 11
 
14 12
 	if ddlSQL == "" {
15 13
 		return fmt.Errorf("DDL SQL statement is empty")
@@ -25,10 +23,7 @@ func (f *DBFactory) executeDDL(db *sql.DB, ddlSQL string) error {
25 23
 }
26 24
 
27 25
 // ExecuteDDLWithTx 在事务中执行DDL语句
28
-func (f *DBFactory) executeDDLWithTx(db *sql.DB, ddlSQL string) error {
29
-	if db == nil {
30
-		return fmt.Errorf("database connection is nil")
31
-	}
26
+func ExecuteDDLWithTx(db *sqlx.DB, ddlSQL string) error {
32 27
 
33 28
 	if ddlSQL == "" {
34 29
 		return fmt.Errorf("DDL SQL statement is empty")
@@ -57,10 +52,7 @@ func (f *DBFactory) executeDDLWithTx(db *sql.DB, ddlSQL string) error {
57 52
 }
58 53
 
59 54
 // ExecuteMultipleDDL 执行多个DDL语句
60
-func (f *DBFactory) executeMultipleDDL(db *sql.DB, ddlSQLs []string) error {
61
-	if db == nil {
62
-		return fmt.Errorf("database connection is nil")
63
-	}
55
+func ExecuteMultipleDDL(db *sqlx.DB, ddlSQLs []string) error {
64 56
 
65 57
 	if len(ddlSQLs) == 0 {
66 58
 		return fmt.Errorf("DDL SQL statements are empty")

factory/query.go → functions/query.go 查看文件

@@ -1,7 +1,6 @@
1
-package factory
1
+package functions
2 2
 
3 3
 import (
4
-	"database/sql"
5 4
 	"encoding/csv"
6 5
 	"encoding/json"
7 6
 	"fmt"
@@ -10,24 +9,11 @@ import (
10 9
 	"time"
11 10
 
12 11
 	"git.x2erp.com/qdy/go-base/types"
12
+	"github.com/jmoiron/sqlx"
13 13
 )
14 14
 
15
-// func typesFunction() *types.QueryResult { // 确保使用了该包导出的类型,例如 QueryResult
16
-// 	return &types.QueryResult{}
17
-// }
18
-
19
-// QueryExecutor 查询执行器
20
-type queryExecutor struct {
21
-	db *sql.DB
22
-}
23
-
24
-// NewQueryExecutor 创建查询执行器
25
-func newQueryExecutor(db *sql.DB) *queryExecutor {
26
-	return &queryExecutor{db: db}
27
-}
28
-
29 15
 // QueryToJSON 执行查询并返回JSON格式数据(统一返回QueryResult)
30
-func (e *queryExecutor) queryToJSON(sql string) *types.QueryResult {
16
+func QueryToJSON(db *sqlx.DB, sql string) *types.QueryResult {
31 17
 	startTime := time.Now()
32 18
 	result := &types.QueryResult{}
33 19
 
@@ -38,7 +24,7 @@ func (e *queryExecutor) queryToJSON(sql string) *types.QueryResult {
38 24
 		return result
39 25
 	}
40 26
 
41
-	rows, err := e.db.Query(sql)
27
+	rows, err := db.Query(sql)
42 28
 	if err != nil {
43 29
 		result.Success = false
44 30
 		result.Error = fmt.Sprintf("Query execution failed: %v", err)
@@ -108,13 +94,13 @@ func (e *queryExecutor) queryToJSON(sql string) *types.QueryResult {
108 94
 }
109 95
 
110 96
 // QueryToCSV 查询并返回 CSV 字节数据(包含表头)
111
-func (e *queryExecutor) queryToCSV(sql string) ([]byte, error) {
97
+func QueryToCSV(db *sqlx.DB, sql string) ([]byte, error) {
112 98
 
113 99
 	if sql == "" {
114 100
 		return nil, fmt.Errorf("SQL query cannot be empty")
115 101
 	}
116 102
 
117
-	rows, err := e.db.Query(sql)
103
+	rows, err := db.Query(sql)
118 104
 	if err != nil {
119 105
 		return nil, fmt.Errorf("query execution failed: %v", err)
120 106
 	}
@@ -174,7 +160,7 @@ func (e *queryExecutor) queryToCSV(sql string) ([]byte, error) {
174 160
 }
175 161
 
176 162
 // ExecuteQueryWithColumns 执行查询并返回完整结果(包含列信息)
177
-func (e *queryExecutor) queryWithColumns(sql string) *types.QueryResult {
163
+func QueryWithColumns(db *sqlx.DB, sql string) *types.QueryResult {
178 164
 	startTime := time.Now()
179 165
 	result := &types.QueryResult{}
180 166
 
@@ -185,7 +171,7 @@ func (e *queryExecutor) queryWithColumns(sql string) *types.QueryResult {
185 171
 		return result
186 172
 	}
187 173
 
188
-	rows, err := e.db.Query(sql)
174
+	rows, err := db.Query(sql)
189 175
 	if err != nil {
190 176
 		result.Success = false
191 177
 		result.Error = fmt.Sprintf("Query execution failed: %v", err)
@@ -253,7 +239,7 @@ func (e *queryExecutor) queryWithColumns(sql string) *types.QueryResult {
253 239
 }
254 240
 
255 241
 // ExecuteQueryDataOnly 执行查询并返回纯数据(不包含列信息,性能更高)
256
-func (e *queryExecutor) queryDataOnly(sql string) *types.QueryResult {
242
+func QueryDataOnly(db *sqlx.DB, sql string) *types.QueryResult {
257 243
 	startTime := time.Now()
258 244
 	result := &types.QueryResult{}
259 245
 
@@ -264,7 +250,7 @@ func (e *queryExecutor) queryDataOnly(sql string) *types.QueryResult {
264 250
 		return result
265 251
 	}
266 252
 
267
-	rows, err := e.db.Query(sql)
253
+	rows, err := db.Query(sql)
268 254
 	if err != nil {
269 255
 		result.Success = false
270 256
 		result.Error = fmt.Sprintf("Query execution failed: %v", err)
@@ -334,7 +320,7 @@ func (e *queryExecutor) queryDataOnly(sql string) *types.QueryResult {
334 320
 }
335 321
 
336 322
 // ExecuteQueryCSV 执行查询并返回CSV格式数据
337
-func (e *queryExecutor) queryCSV(sql string, includeHeader bool) *types.QueryResult {
323
+func QueryCSV(db *sqlx.DB, sql string, includeHeader bool) *types.QueryResult {
338 324
 	startTime := time.Now()
339 325
 	result := &types.QueryResult{}
340 326
 
@@ -345,7 +331,7 @@ func (e *queryExecutor) queryCSV(sql string, includeHeader bool) *types.QueryRes
345 331
 		return result
346 332
 	}
347 333
 
348
-	rows, err := e.db.Query(sql)
334
+	rows, err := db.Query(sql)
349 335
 	if err != nil {
350 336
 		result.Success = false
351 337
 		result.Error = fmt.Sprintf("Query execution failed: %v", err)
@@ -459,8 +445,9 @@ func (e *queryExecutor) queryCSV(sql string, includeHeader bool) *types.QueryRes
459 445
 }
460 446
 
461 447
 // ExecuteQueryCSVStream 流式返回CSV数据
462
-func (e *queryExecutor) queryCSVStream(sql string, w io.Writer, includeHeader bool) (int, error) {
463
-	rows, err := e.db.Query(sql)
448
+func QueryCSVStream(db *sqlx.DB, sql string, w io.Writer, includeHeader bool) (int, error) {
449
+
450
+	rows, err := db.Query(sql)
464 451
 	if err != nil {
465 452
 		return 0, err
466 453
 	}

+ 33
- 0
functions/testConnection.go 查看文件

@@ -0,0 +1,33 @@
1
+package functions
2
+
3
+import (
4
+	"fmt"
5
+
6
+	"github.com/jmoiron/sqlx"
7
+)
8
+
9
+// TestConnection 测试数据库连接
10
+func TestConnection(db *sqlx.DB, dbType string) error {
11
+	var query string
12
+	switch dbType {
13
+	case "mysql", "postgres", "sqlserver":
14
+		query = "SELECT 1"
15
+	case "oracle":
16
+		query = "SELECT 1 FROM DUAL"
17
+	default:
18
+		query = "SELECT 1"
19
+	}
20
+
21
+	var result int
22
+	err := db.QueryRow(query).Scan(&result)
23
+	if err != nil {
24
+		return err
25
+	}
26
+
27
+	if result != 1 {
28
+		return fmt.Errorf("unexpected test result: %d", result)
29
+	}
30
+
31
+	fmt.Println("test Connection  database is success.")
32
+	return nil
33
+}

+ 7
- 1
go.mod 查看文件

@@ -10,14 +10,20 @@ require (
10 10
 	github.com/sijms/go-ora/v2 v2.9.0
11 11
 )
12 12
 
13
-require gopkg.in/yaml.v2 v2.4.0 // indirect
13
+require (
14
+	github.com/cespare/xxhash/v2 v2.1.2 // indirect
15
+	github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect
16
+	gopkg.in/yaml.v2 v2.4.0 // indirect
17
+)
14 18
 
15 19
 require (
16 20
 	filippo.io/edwards25519 v1.1.0 // indirect
17 21
 	git.x2erp.com/qdy/go-base v0.1.10
22
+	github.com/go-redis/redis/v8 v8.11.5
18 23
 	github.com/golang-sql/civil v0.0.0-20220223132316-b832511892a9 // indirect
19 24
 	github.com/golang-sql/sqlexp v0.1.0 // indirect
20 25
 	github.com/google/uuid v1.6.0 // indirect
26
+	github.com/jmoiron/sqlx v1.4.0
21 27
 	golang.org/x/crypto v0.38.0 // indirect
22 28
 	golang.org/x/text v0.25.0 // indirect
23 29
 )

+ 4
- 0
go.sum 查看文件

@@ -24,6 +24,7 @@ github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/r
24 24
 github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc=
25 25
 github.com/go-redis/redis/v8 v8.11.5 h1:AcZZR7igkdvfVmQTPnu9WE37LRrO/YrBH5zWyjDC0oI=
26 26
 github.com/go-redis/redis/v8 v8.11.5/go.mod h1:gREzHqY1hg6oD9ngVRbLStwAWKhA0FEgq8Jd4h5lpwo=
27
+github.com/go-sql-driver/mysql v1.8.1/go.mod h1:wEBSXgmK//2ZFJyE+qWnIsVGmvmEKlqwuVSjsCm7DZg=
27 28
 github.com/go-sql-driver/mysql v1.9.3 h1:U/N249h2WzJ3Ukj8SowVFjdtZKfu9vlLZxjPXV1aweo=
28 29
 github.com/go-sql-driver/mysql v1.9.3/go.mod h1:qn46aNg1333BRMNU69Lq93t8du/dwxI64Gl8i5p1WMU=
29 30
 github.com/golang-jwt/jwt/v5 v5.2.2 h1:Rl4B7itRWVtYIHFrSNd7vhTiz9UpLdi6gZhZ3wEeDy8=
@@ -34,10 +35,13 @@ github.com/golang-sql/sqlexp v0.1.0 h1:ZCD6MBpcuOVfGVqsEmY5/4FtYiKz6tSyUv9LPEDei
34 35
 github.com/golang-sql/sqlexp v0.1.0/go.mod h1:J4ad9Vo8ZCWQ2GMrC4UCQy1JpCbwU9m3EOqtpKwwwHI=
35 36
 github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
36 37
 github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
38
+github.com/jmoiron/sqlx v1.4.0 h1:1PLqN7S1UYp5t4SrVVnt4nUVNemrDAtxlulVe+Qgm3o=
39
+github.com/jmoiron/sqlx v1.4.0/go.mod h1:ZrZ7UsYB/weZdl2Bxg6jCRO9c3YHl8r3ahlKmRT4JLY=
37 40
 github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0SNc=
38 41
 github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw=
39 42
 github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw=
40 43
 github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o=
44
+github.com/mattn/go-sqlite3 v1.14.22/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y=
41 45
 github.com/microsoft/go-mssqldb v1.9.4 h1:sHrj3GcdgkxytZ09aZ3+ys72pMeyEXJowT44j74pNgs=
42 46
 github.com/microsoft/go-mssqldb v1.9.4/go.mod h1:GBbW9ASTiDC+mpgWDGKdm3FnFLTUsLYN3iFL90lQ+PA=
43 47
 github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c h1:+mdjkGKdHQG3305AYmdv1U2eRNDiU2ErMBj1gwrq8eQ=

+ 15
- 44
test.go 查看文件

@@ -1,68 +1,39 @@
1 1
 package main
2 2
 
3 3
 import (
4
-	"database/sql"
5 4
 	"fmt"
6 5
 	"log"
7 6
 
7
+	"git.x2erp.com/qdy/go-base/config"
8 8
 	"git.x2erp.com/qdy/go-db/factory"
9 9
 )
10 10
 
11 11
 func main() {
12
-	// 创建数据库工厂
13
-	dbFactory, err := factory.NewDBFactory()
14
-	if err != nil {
15
-		log.Fatalf("Failed to create DB factory: %v", err)
16
-	}
17
-
18
-	// 显示可用的数据库驱动
19
-	drivers := dbFactory.GetAvailableDrivers()
20
-	fmt.Printf("Available database drivers: %v\n", drivers)
21 12
 
22 13
 	// 显示当前使用的数据库配置
23
-	config := dbFactory.GetConfig()
14
+	config := config.GetConfig()
24 15
 	dbConfig := config.GetDatabase() // 通过接口方法获取数据库配置
25 16
 	fmt.Printf("Using database type: %s\n", dbConfig.Type)
26 17
 	fmt.Printf("Database host: %s:%d\n", dbConfig.Host, dbConfig.Port)
27 18
 	fmt.Printf("Database name: %s\n", dbConfig.Database)
28 19
 
29
-	// 创建数据库连接
30
-	db, err := dbFactory.CreateDB()
31
-	if err != nil {
32
-		log.Fatalf("Failed to create database connection: %v", err)
33
-	}
34
-	defer db.Close()
35
-
36
-	fmt.Println("Successfully connected to database!")
37
-
38
-	// 测试连接
39
-	if err := testConnection(db, dbConfig.Type); err != nil {
40
-		log.Printf("Query test failed: %v", err)
41
-	} else {
42
-		fmt.Println("Database connection test passed!")
43
-	}
44
-}
45
-
46
-func testConnection(db *sql.DB, dbType string) error {
47
-	var query string
48
-	switch dbType {
49
-	case "mysql", "postgres", "sqlserver":
50
-		query = "SELECT 1"
51
-	case "oracle":
52
-		query = "SELECT 1 FROM DUAL"
53
-	default:
54
-		query = "SELECT 1"
55
-	}
20
+	// 创建数据库工厂
21
+	fmt.Printf("第1次.\n")
22
+	dbFactory, err := factory.GetDBFactory()
56 23
 
57
-	var result int
58
-	err := db.QueryRow(query).Scan(&result)
59 24
 	if err != nil {
60
-		return fmt.Errorf("test query failed: %v", err)
25
+		log.Fatalf("Failed to create DB factory: %v", err)
61 26
 	}
62 27
 
63
-	if result != 1 {
64
-		return fmt.Errorf("unexpected test result: %d", result)
28
+	//测试单例是否生效
29
+	fmt.Printf("第2次.\n")
30
+	dbFactory1, err1 := factory.GetDBFactory()
31
+
32
+	if err1 != nil {
33
+		log.Fatalf("Failed to create DB factory: %v", err1)
65 34
 	}
35
+	dbFactory1.TestConnection(dbConfig.Type)
36
+	defer dbFactory.Close()
37
+	defer dbFactory1.Close()
66 38
 
67
-	return nil
68 39
 }

Loading…
取消
儲存