qdy 3 місяці тому
коміт
7fa6b873ef
13 змінених файлів з 1134 додано та 0 видалено
  1. BIN
      .DS_Store
  2. 10
    0
      db.yaml
  3. 10
    0
      db_pg.yaml
  4. 76
    0
      drivers/driver.go
  5. 45
    0
      drivers/mysql.go
  6. 45
    0
      drivers/oracle.go
  7. 45
    0
      drivers/postgres.go
  8. 45
    0
      drivers/sqlserver.go
  9. 160
    0
      factory/factory.go
  10. 554
    0
      factory/query.go
  11. 21
    0
      go.mod
  12. 56
    0
      go.sum
  13. 67
    0
      test.go

+ 10
- 0
db.yaml Переглянути файл

@@ -0,0 +1,10 @@
1
+database:
2
+  type: "oracle"  # mysql, postgres, oracle, sqlserver
3
+  host: "161.189.7.134"
4
+  port: 1521
5
+  username: "x6_stock_dev"
6
+  password: "mosdev"
7
+  database: "ORCL"
8
+  max_open_conns: 100
9
+  max_idle_conns: 10
10
+  conn_max_lifetime: 300

+ 10
- 0
db_pg.yaml Переглянути файл

@@ -0,0 +1,10 @@
1
+database:
2
+  type: "postgres"  # mysql, postgres, oracle, sqlserver
3
+  host: "69.235.172.218"
4
+  port: 5432
5
+  username: "x3stock"
6
+  password: "mos8555"
7
+  database: "x3stock"
8
+  max_open_conns: 100
9
+  max_idle_conns: 10
10
+  conn_max_lifetime: 300

+ 76
- 0
drivers/driver.go Переглянути файл

@@ -0,0 +1,76 @@
1
+package drivers
2
+
3
+import (
4
+	"database/sql"
5
+	"fmt"
6
+	"time" // 添加 time 包导入
7
+)
8
+
9
+// DBConfig 数据库配置
10
+type DBConfig struct {
11
+	Type            string `yaml:"type"`
12
+	Host            string `yaml:"host"`
13
+	Port            int    `yaml:"port"`
14
+	Username        string `yaml:"username"`
15
+	Password        string `yaml:"password"`
16
+	Database        string `yaml:"database"`
17
+	MaxOpenConns    int    `yaml:"max_open_conns"`
18
+	MaxIdleConns    int    `yaml:"max_idle_conns"`
19
+	ConnMaxLifetime int    `yaml:"conn_max_lifetime"` // 单位:秒
20
+}
21
+
22
+// DBDriver 数据库驱动接口
23
+type DBDriver interface {
24
+	// Name 返回驱动名称
25
+	Name() string
26
+	// Open 打开数据库连接
27
+	Open(config DBConfig) (*sql.DB, error)
28
+	// BuildDSN 构建连接字符串
29
+	BuildDSN(config DBConfig) string
30
+}
31
+
32
+var driverss = make(map[string]DBDriver)
33
+
34
+// Register 注册数据库驱动
35
+func Register(driver DBDriver) {
36
+	if driver == nil {
37
+		panic("db driver: Register driver is nil")
38
+	}
39
+	name := driver.Name()
40
+	if _, dup := driverss[name]; dup {
41
+		panic("db driver: Register called twice for driver " + name)
42
+	}
43
+	driverss[name] = driver
44
+}
45
+
46
+// Get 获取数据库驱动
47
+func Get(name string) (DBDriver, error) {
48
+	driver, exists := driverss[name]
49
+	if !exists {
50
+		return nil, fmt.Errorf("db driver: unknown driver %q", name)
51
+	}
52
+	return driver, nil
53
+}
54
+
55
+// GetAllDrivers 获取所有已注册的驱动名称
56
+func GetAllDrivers() []string {
57
+	names := make([]string, 0, len(driverss))
58
+	for name := range driverss {
59
+		names = append(names, name)
60
+	}
61
+	return names
62
+}
63
+
64
+// configureConnectionPool 配置连接池的公共函数
65
+func configureConnectionPool(db *sql.DB, config DBConfig) {
66
+	if config.MaxOpenConns > 0 {
67
+		db.SetMaxOpenConns(config.MaxOpenConns)
68
+	}
69
+	if config.MaxIdleConns > 0 {
70
+		db.SetMaxIdleConns(config.MaxIdleConns)
71
+	}
72
+	if config.ConnMaxLifetime > 0 {
73
+		// 将秒转换为 time.Duration
74
+		db.SetConnMaxLifetime(time.Duration(config.ConnMaxLifetime) * time.Second)
75
+	}
76
+}

+ 45
- 0
drivers/mysql.go Переглянути файл

@@ -0,0 +1,45 @@
1
+package drivers
2
+
3
+import (
4
+	"database/sql"
5
+	"fmt"
6
+
7
+	_ "github.com/go-sql-driver/mysql"
8
+)
9
+
10
+type MySQLDriver struct{}
11
+
12
+func (d *MySQLDriver) Name() string {
13
+	return "mysql"
14
+}
15
+
16
+func (d *MySQLDriver) BuildDSN(config DBConfig) string {
17
+	return fmt.Sprintf("%s:%s@tcp(%s:%d)/%s?charset=utf8mb4&parseTime=True&loc=Local",
18
+		config.Username,
19
+		config.Password,
20
+		config.Host,
21
+		config.Port,
22
+		config.Database)
23
+}
24
+
25
+func (d *MySQLDriver) Open(config DBConfig) (*sql.DB, error) {
26
+	dsn := d.BuildDSN(config)
27
+	db, err := sql.Open("mysql", dsn)
28
+	if err != nil {
29
+		return nil, err
30
+	}
31
+
32
+	// 使用公共的连接池配置函数
33
+	configureConnectionPool(db, config)
34
+
35
+	// 测试连接
36
+	if err = db.Ping(); err != nil {
37
+		return nil, err
38
+	}
39
+
40
+	return db, nil
41
+}
42
+
43
+func init() {
44
+	Register(&MySQLDriver{})
45
+}

+ 45
- 0
drivers/oracle.go Переглянути файл

@@ -0,0 +1,45 @@
1
+package drivers
2
+
3
+import (
4
+	"database/sql"
5
+	"fmt"
6
+
7
+	_ "github.com/sijms/go-ora/v2"
8
+)
9
+
10
+type OracleDriver struct{}
11
+
12
+func (d *OracleDriver) Name() string {
13
+	return "oracle"
14
+}
15
+
16
+func (d *OracleDriver) BuildDSN(config DBConfig) string {
17
+	return fmt.Sprintf("oracle://%s:%s@%s:%d/%s",
18
+		config.Username,
19
+		config.Password,
20
+		config.Host,
21
+		config.Port,
22
+		config.Database)
23
+}
24
+
25
+func (d *OracleDriver) Open(config DBConfig) (*sql.DB, error) {
26
+	dsn := d.BuildDSN(config)
27
+	db, err := sql.Open("oracle", dsn)
28
+	if err != nil {
29
+		return nil, err
30
+	}
31
+
32
+	// 使用公共的连接池配置函数
33
+	configureConnectionPool(db, config)
34
+
35
+	// 测试连接
36
+	if err = db.Ping(); err != nil {
37
+		return nil, err
38
+	}
39
+
40
+	return db, nil
41
+}
42
+
43
+func init() {
44
+	Register(&OracleDriver{})
45
+}

+ 45
- 0
drivers/postgres.go Переглянути файл

@@ -0,0 +1,45 @@
1
+package drivers
2
+
3
+import (
4
+	"database/sql"
5
+	"fmt"
6
+
7
+	_ "github.com/lib/pq"
8
+)
9
+
10
+type PostgresDriver struct{}
11
+
12
+func (d *PostgresDriver) Name() string {
13
+	return "postgres"
14
+}
15
+
16
+func (d *PostgresDriver) BuildDSN(config DBConfig) string {
17
+	return fmt.Sprintf("host=%s port=%d user=%s password=%s dbname=%s sslmode=disable",
18
+		config.Host,
19
+		config.Port,
20
+		config.Username,
21
+		config.Password,
22
+		config.Database)
23
+}
24
+
25
+func (d *PostgresDriver) Open(config DBConfig) (*sql.DB, error) {
26
+	dsn := d.BuildDSN(config)
27
+	db, err := sql.Open("postgres", dsn)
28
+	if err != nil {
29
+		return nil, err
30
+	}
31
+
32
+	// 使用公共的连接池配置函数
33
+	configureConnectionPool(db, config)
34
+
35
+	// 测试连接
36
+	if err = db.Ping(); err != nil {
37
+		return nil, err
38
+	}
39
+
40
+	return db, nil
41
+}
42
+
43
+func init() {
44
+	Register(&PostgresDriver{})
45
+}

+ 45
- 0
drivers/sqlserver.go Переглянути файл

@@ -0,0 +1,45 @@
1
+package drivers
2
+
3
+import (
4
+	"database/sql"
5
+	"fmt"
6
+
7
+	_ "github.com/microsoft/go-mssqldb"
8
+)
9
+
10
+type SQLServerDriver struct{}
11
+
12
+func (d *SQLServerDriver) Name() string {
13
+	return "sqlserver"
14
+}
15
+
16
+func (d *SQLServerDriver) BuildDSN(config DBConfig) string {
17
+	return fmt.Sprintf("server=%s;port=%d;user id=%s;password=%s;database=%s;encrypt=disable",
18
+		config.Host,
19
+		config.Port,
20
+		config.Username,
21
+		config.Password,
22
+		config.Database)
23
+}
24
+
25
+func (d *SQLServerDriver) Open(config DBConfig) (*sql.DB, error) {
26
+	dsn := d.BuildDSN(config)
27
+	db, err := sql.Open("sqlserver", dsn)
28
+	if err != nil {
29
+		return nil, err
30
+	}
31
+
32
+	// 使用公共的连接池配置函数
33
+	configureConnectionPool(db, config)
34
+
35
+	// 测试连接
36
+	if err = db.Ping(); err != nil {
37
+		return nil, err
38
+	}
39
+
40
+	return db, nil
41
+}
42
+
43
+func init() {
44
+	Register(&SQLServerDriver{})
45
+}

+ 160
- 0
factory/factory.go Переглянути файл

@@ -0,0 +1,160 @@
1
+package factory
2
+
3
+import (
4
+	"database/sql"
5
+	"fmt"
6
+	"io"
7
+	"os"
8
+	"path/filepath"
9
+
10
+	"git.x2erp.com/qdy/go-base/types"
11
+	"git.x2erp.com/qdy/go-db/drivers"
12
+	"gopkg.in/yaml.v2"
13
+)
14
+
15
+// Config 总配置
16
+type Config struct {
17
+	Database drivers.DBConfig `yaml:"database"`
18
+}
19
+
20
+// DBFactory 数据库工厂
21
+type DBFactory struct {
22
+	config *Config
23
+}
24
+
25
+// NewDBFactory 创建数据库工厂
26
+func NewDBFactory() (*DBFactory, error) {
27
+	configFile, err := findConfigFile()
28
+	if err != nil {
29
+		return nil, err
30
+	}
31
+
32
+	fmt.Printf("✅ Using config file: %s\n", configFile)
33
+
34
+	// 读取配置文件
35
+	data, err := os.ReadFile(configFile)
36
+	if err != nil {
37
+		return nil, fmt.Errorf("failed to read config file %s: %v", configFile, err)
38
+	}
39
+
40
+	var config Config
41
+	err = yaml.Unmarshal(data, &config)
42
+	if err != nil {
43
+		return nil, fmt.Errorf("failed to parse config file: %v", err)
44
+	}
45
+
46
+	return &DBFactory{config: &config}, nil
47
+}
48
+
49
+// findConfigFile 查找配置文件
50
+func findConfigFile() (string, error) {
51
+	// 1. 首先尝试可执行文件同目录
52
+	exePath, err := os.Executable()
53
+	if err == nil {
54
+		exeDir := filepath.Dir(exePath)
55
+		configFile := filepath.Join(exeDir, "db.yaml")
56
+		if _, err := os.Stat(configFile); err == nil {
57
+			return configFile, nil
58
+		}
59
+	}
60
+
61
+	// 2. 尝试环境变量指定的路径
62
+	envConfigPath := os.Getenv("DB_CONFIG_PATH")
63
+	if envConfigPath != "" {
64
+		if _, err := os.Stat(envConfigPath); err == nil {
65
+			return envConfigPath, nil
66
+		}
67
+		return "", fmt.Errorf("DB_CONFIG_PATH file not found: %s", envConfigPath)
68
+	}
69
+
70
+	// 3. 如果都没有找到,返回错误
71
+	exeDir := "unknown"
72
+	if exePath, err := os.Executable(); err == nil {
73
+		exeDir = filepath.Dir(exePath)
74
+	}
75
+
76
+	return "", fmt.Errorf(`No configuration file found!
77
+
78
+Tried locations:
79
+1. Executable directory: %s/db.yaml
80
+2. Environment variable: DB_CONFIG_PATH
81
+
82
+Solutions:
83
+- Place db.yaml in the same directory as the executable
84
+- Or set DB_CONFIG_PATH environment variable to config file path
85
+
86
+Example:
87
+  export DB_CONFIG_PATH=/path/to/your/db.yaml`, exeDir)
88
+}
89
+
90
+// CreateDB 创建数据库连接
91
+func (f *DBFactory) CreateDB() (*sql.DB, error) {
92
+	dbType := f.config.Database.Type
93
+
94
+	// 获取对应的驱动
95
+	dbDriver, err := drivers.Get(dbType)
96
+	if err != nil {
97
+		return nil, fmt.Errorf("failed to get database driver: %v", err)
98
+	}
99
+
100
+	// 创建数据库连接
101
+	db, err := dbDriver.Open(f.config.Database)
102
+	if err != nil {
103
+		return nil, fmt.Errorf("failed to open database connection: %v", err)
104
+	}
105
+
106
+	return db, nil
107
+}
108
+
109
+// GetConfig 获取配置信息
110
+func (f *DBFactory) GetConfig() *Config {
111
+	return f.config
112
+}
113
+
114
+// GetAvailableDrivers 获取可用的数据库驱动
115
+func (f *DBFactory) GetAvailableDrivers() []string {
116
+	return drivers.GetAllDrivers()
117
+}
118
+
119
+// CreateQueryExecutor 创建查询执行器(新增方法)
120
+func (f *DBFactory) CreateQueryExecutor(db *sql.DB) *QueryExecutor {
121
+	return NewQueryExecutor(db)
122
+}
123
+
124
+// -------------- 对外暴露的初始化方法(核心入口)--------------
125
+// NewDBQuery 初始化查询实例(对外提供唯一初始化入口)
126
+// db: 已初始化的数据库连接(由调用方传入,解耦数据库配置)
127
+func NewDBQuery(db *sql.DB) *QueryExecutor {
128
+	// 直接复用原文件的构造函数,对外隐藏实现细节
129
+	return NewQueryExecutor(db)
130
+}
131
+
132
+// QuickQueryToJSON 快捷查询,直接返回 JSON 字节流
133
+func QuickQueryToJSON(db *sql.DB, sql string) *types.QueryResult {
134
+	return NewDBQuery(db).QueryToJSON(sql)
135
+}
136
+
137
+// QuickQueryToCSV 快捷查询,直接返回 CSV 字符串(包含表头)
138
+func QuickQueryToCSV(db *sql.DB, sql string) *types.QueryResult {
139
+	return NewDBQuery(db).QueryToCSV(sql)
140
+}
141
+
142
+// QuickExecuteWithColumns 快捷查询,返回完整结果(含列信息)
143
+func QuickExecuteWithColumns(db *sql.DB, sql string) *types.QueryResult {
144
+	return NewDBQuery(db).ExecuteQueryWithColumns(sql)
145
+}
146
+
147
+// QuickExecuteDataOnly 快捷查询,返回纯数据(性能优先)
148
+func QuickExecuteDataOnly(db *sql.DB, sql string) *types.QueryResult {
149
+	return NewDBQuery(db).ExecuteQueryDataOnly(sql)
150
+}
151
+
152
+// QuickExecuteCSV 快捷查询,返回 CSV 格式结果(支持自定义是否包含表头)
153
+func QuickExecuteCSV(db *sql.DB, sql string, includeHeader bool) *types.QueryResult {
154
+	return NewDBQuery(db).ExecuteQueryCSV(sql, includeHeader)
155
+}
156
+
157
+// QuickExecuteCSVStream 快捷流式输出 CSV(直接写入 io.Writer,适合大文件)
158
+func QuickExecuteCSVStream(db *sql.DB, sql string, w io.Writer, includeHeader bool) (int, error) {
159
+	return NewDBQuery(db).ExecuteQueryCSVStream(sql, w, includeHeader)
160
+}

+ 554
- 0
factory/query.go Переглянути файл

@@ -0,0 +1,554 @@
1
+package factory
2
+
3
+import (
4
+	"database/sql"
5
+	"encoding/csv"
6
+	"encoding/json"
7
+	"fmt"
8
+	"io"
9
+	"strings"
10
+	"time"
11
+
12
+	"git.x2erp.com/qdy/go-base/types"
13
+)
14
+
15
+func typesFunction() *types.QueryResult { // 确保使用了该包导出的类型,例如 QueryResult
16
+	return &types.QueryResult{}
17
+}
18
+
19
+// QueryExecutor 查询执行器
20
+type QueryExecutor struct {
21
+	db *sql.DB
22
+}
23
+
24
+// NewQueryExecutor 创建查询执行器
25
+func NewQueryExecutor(db *sql.DB) *QueryExecutor {
26
+	return &QueryExecutor{db: db}
27
+}
28
+
29
+// QueryToJSON 执行查询并返回JSON格式数据(统一返回QueryResult)
30
+func (e *QueryExecutor) QueryToJSON(sql string) *types.QueryResult {
31
+	startTime := time.Now()
32
+	result := &types.QueryResult{}
33
+
34
+	if sql == "" {
35
+		result.Success = false
36
+		result.Error = "SQL query cannot be empty"
37
+		result.Time = time.Since(startTime).String()
38
+		return result
39
+	}
40
+
41
+	rows, err := e.db.Query(sql)
42
+	if err != nil {
43
+		result.Success = false
44
+		result.Error = fmt.Sprintf("Query execution failed: %v", err)
45
+		result.Time = time.Since(startTime).String()
46
+		return result
47
+	}
48
+	defer rows.Close()
49
+
50
+	columns, err := rows.Columns()
51
+	if err != nil {
52
+		result.Success = false
53
+		result.Error = fmt.Sprintf("Failed to get columns: %v", err)
54
+		result.Time = time.Since(startTime).String()
55
+		return result
56
+	}
57
+
58
+	var results []map[string]interface{}
59
+	count := 0
60
+
61
+	for rows.Next() {
62
+		count++
63
+		values := make([]interface{}, len(columns))
64
+		valuePtrs := make([]interface{}, len(columns))
65
+		for i := range columns {
66
+			valuePtrs[i] = &values[i]
67
+		}
68
+
69
+		if err := rows.Scan(valuePtrs...); err != nil {
70
+			result.Success = false
71
+			result.Error = fmt.Sprintf("Failed to scan row: %v", err)
72
+			result.Time = time.Since(startTime).String()
73
+			return result
74
+		}
75
+
76
+		resultMap := make(map[string]interface{})
77
+		for i, col := range columns {
78
+			// 完全不处理类型,直接赋值,让 json.Marshal 自己处理
79
+			resultMap[col] = values[i]
80
+		}
81
+		results = append(results, resultMap)
82
+	}
83
+
84
+	if err := rows.Err(); err != nil {
85
+		result.Success = false
86
+		result.Error = fmt.Sprintf("Row iteration error: %v", err)
87
+		result.Time = time.Since(startTime).String()
88
+		return result
89
+	}
90
+
91
+	jsonData, err := json.Marshal(results)
92
+	if err != nil {
93
+		result.Success = false
94
+		result.Error = fmt.Sprintf("JSON marshal failed: %v", err)
95
+		result.Time = time.Since(startTime).String()
96
+		return result
97
+	}
98
+
99
+	result.Success = true
100
+	result.Data = map[string]interface{}{
101
+		"json":  string(jsonData),
102
+		"rows":  results,
103
+		"count": count,
104
+	}
105
+	result.Count = count
106
+	result.Time = time.Since(startTime).String()
107
+	return result
108
+}
109
+
110
+// QueryToCSV 查询并返回 CSV 字符串(包含表头,统一返回QueryResult)
111
+func (e *QueryExecutor) QueryToCSV(sql string) *types.QueryResult {
112
+	startTime := time.Now()
113
+	result := &types.QueryResult{}
114
+
115
+	if sql == "" {
116
+		result.Success = false
117
+		result.Error = "SQL query cannot be empty"
118
+		result.Time = time.Since(startTime).String()
119
+		return result
120
+	}
121
+
122
+	rows, err := e.db.Query(sql)
123
+	if err != nil {
124
+		result.Success = false
125
+		result.Error = fmt.Sprintf("Query execution failed: %v", err)
126
+		result.Time = time.Since(startTime).String()
127
+		return result
128
+	}
129
+	defer rows.Close()
130
+
131
+	columns, err := rows.Columns()
132
+	if err != nil {
133
+		result.Success = false
134
+		result.Error = fmt.Sprintf("Failed to get columns: %v", err)
135
+		result.Time = time.Since(startTime).String()
136
+		return result
137
+	}
138
+
139
+	var builder strings.Builder
140
+	writer := csv.NewWriter(&builder)
141
+	count := 0
142
+
143
+	// 写入表头
144
+	if err := writer.Write(columns); err != nil {
145
+		result.Success = false
146
+		result.Error = fmt.Sprintf("Failed to write CSV header: %v", err)
147
+		result.Time = time.Since(startTime).String()
148
+		return result
149
+	}
150
+
151
+	for rows.Next() {
152
+		count++
153
+		values := make([]interface{}, len(columns))
154
+		valuePtrs := make([]any, len(columns))
155
+		for i := range columns {
156
+			valuePtrs[i] = &values[i]
157
+		}
158
+
159
+		if err := rows.Scan(valuePtrs...); err != nil {
160
+			result.Success = false
161
+			result.Error = fmt.Sprintf("Failed to scan row: %v", err)
162
+			result.Time = time.Since(startTime).String()
163
+			return result
164
+		}
165
+
166
+		// 所有值转为字符串
167
+		row := make([]string, len(columns))
168
+		for i, val := range values {
169
+			if val == nil {
170
+				row[i] = ""
171
+			} else {
172
+				row[i] = fmt.Sprintf("%v", val)
173
+			}
174
+		}
175
+
176
+		if err := writer.Write(row); err != nil {
177
+			result.Success = false
178
+			result.Error = fmt.Sprintf("Failed to write CSV row: %v", err)
179
+			result.Time = time.Since(startTime).String()
180
+			return result
181
+		}
182
+	}
183
+
184
+	writer.Flush()
185
+	if err := writer.Error(); err != nil {
186
+		result.Success = false
187
+		result.Error = fmt.Sprintf("Failed to flush CSV: %v", err)
188
+		result.Time = time.Since(startTime).String()
189
+		return result
190
+	}
191
+
192
+	if err := rows.Err(); err != nil {
193
+		result.Success = false
194
+		result.Error = fmt.Sprintf("Row iteration error: %v", err)
195
+		result.Time = time.Since(startTime).String()
196
+		return result
197
+	}
198
+
199
+	result.Success = true
200
+	result.Data = map[string]interface{}{
201
+		"csv":           builder.String(),
202
+		"count":         count,
203
+		"includeHeader": true,
204
+	}
205
+	result.Count = count
206
+	result.Time = time.Since(startTime).String()
207
+	return result
208
+}
209
+
210
+// ExecuteQueryWithColumns 执行查询并返回完整结果(包含列信息)
211
+func (e *QueryExecutor) ExecuteQueryWithColumns(sql string) *types.QueryResult {
212
+	startTime := time.Now()
213
+	result := &types.QueryResult{}
214
+
215
+	if sql == "" {
216
+		result.Success = false
217
+		result.Error = "SQL query cannot be empty"
218
+		result.Time = time.Since(startTime).String()
219
+		return result
220
+	}
221
+
222
+	rows, err := e.db.Query(sql)
223
+	if err != nil {
224
+		result.Success = false
225
+		result.Error = fmt.Sprintf("Query execution failed: %v", err)
226
+		result.Time = time.Since(startTime).String()
227
+		return result
228
+	}
229
+	defer rows.Close()
230
+
231
+	columns, err := rows.Columns()
232
+	if err != nil {
233
+		result.Success = false
234
+		result.Error = fmt.Sprintf("Failed to get columns: %v", err)
235
+		result.Time = time.Since(startTime).String()
236
+		return result
237
+	}
238
+
239
+	var results []map[string]interface{}
240
+	count := 0
241
+
242
+	for rows.Next() {
243
+		count++
244
+
245
+		values := make([]interface{}, len(columns))
246
+		valuePtrs := make([]interface{}, len(columns))
247
+		for i := range columns {
248
+			valuePtrs[i] = &values[i]
249
+		}
250
+
251
+		if err := rows.Scan(valuePtrs...); err != nil {
252
+			result.Success = false
253
+			result.Error = fmt.Sprintf("Failed to scan row: %v", err)
254
+			result.Time = time.Since(startTime).String()
255
+			return result
256
+		}
257
+
258
+		resultRow := make(map[string]interface{})
259
+		for i, col := range columns {
260
+			val := values[i]
261
+			switch v := val.(type) {
262
+			case []byte:
263
+				resultRow[col] = string(v)
264
+			case time.Time:
265
+				resultRow[col] = v.Format(time.RFC3339)
266
+			default:
267
+				resultRow[col] = v
268
+			}
269
+		}
270
+
271
+		results = append(results, resultRow)
272
+	}
273
+
274
+	if err := rows.Err(); err != nil {
275
+		result.Success = false
276
+		result.Error = fmt.Sprintf("Row iteration error: %v", err)
277
+		result.Time = time.Since(startTime).String()
278
+		return result
279
+	}
280
+
281
+	result.Success = true
282
+	result.Data = results
283
+	result.Count = count
284
+	result.Time = time.Since(startTime).String()
285
+
286
+	return result
287
+}
288
+
289
+// ExecuteQueryDataOnly 执行查询并返回纯数据(不包含列信息,性能更高)
290
+func (e *QueryExecutor) ExecuteQueryDataOnly(sql string) *types.QueryResult {
291
+	startTime := time.Now()
292
+	result := &types.QueryResult{}
293
+
294
+	if sql == "" {
295
+		result.Success = false
296
+		result.Error = "SQL query cannot be empty"
297
+		result.Time = time.Since(startTime).String()
298
+		return result
299
+	}
300
+
301
+	rows, err := e.db.Query(sql)
302
+	if err != nil {
303
+		result.Success = false
304
+		result.Error = fmt.Sprintf("Query execution failed: %v", err)
305
+		result.Time = time.Since(startTime).String()
306
+		return result
307
+	}
308
+	defer rows.Close()
309
+
310
+	columns, err := rows.Columns()
311
+	if err != nil {
312
+		result.Success = false
313
+		result.Error = fmt.Sprintf("Failed to get columns: %v", err)
314
+		result.Time = time.Since(startTime).String()
315
+		return result
316
+	}
317
+
318
+	var results []interface{}
319
+	count := 0
320
+
321
+	for rows.Next() {
322
+		count++
323
+
324
+		values := make([]interface{}, len(columns))
325
+		valuePtrs := make([]interface{}, len(columns))
326
+		for i := range columns {
327
+			valuePtrs[i] = &values[i]
328
+		}
329
+
330
+		if err := rows.Scan(valuePtrs...); err != nil {
331
+			result.Success = false
332
+			result.Error = fmt.Sprintf("Failed to scan row: %v", err)
333
+			result.Time = time.Since(startTime).String()
334
+			return result
335
+		}
336
+
337
+		resultRow := make([]interface{}, len(columns))
338
+		for i, val := range values {
339
+			switch v := val.(type) {
340
+			case []byte:
341
+				resultRow[i] = string(v)
342
+			case time.Time:
343
+				resultRow[i] = v.Format(time.RFC3339)
344
+			default:
345
+				resultRow[i] = v
346
+			}
347
+		}
348
+
349
+		results = append(results, resultRow)
350
+	}
351
+
352
+	if err := rows.Err(); err != nil {
353
+		result.Success = false
354
+		result.Error = fmt.Sprintf("Row iteration error: %v", err)
355
+		result.Time = time.Since(startTime).String()
356
+		return result
357
+	}
358
+
359
+	result.Success = true
360
+	result.Data = map[string]interface{}{
361
+		"rows":  results,
362
+		"count": count,
363
+	}
364
+	result.Count = count
365
+	result.Time = time.Since(startTime).String()
366
+
367
+	return result
368
+}
369
+
370
+// ExecuteQueryCSV 执行查询并返回CSV格式数据
371
+func (e *QueryExecutor) ExecuteQueryCSV(sql string, includeHeader bool) *types.QueryResult {
372
+	startTime := time.Now()
373
+	result := &types.QueryResult{}
374
+
375
+	if sql == "" {
376
+		result.Success = false
377
+		result.Error = "SQL query cannot be empty"
378
+		result.Time = time.Since(startTime).String()
379
+		return result
380
+	}
381
+
382
+	rows, err := e.db.Query(sql)
383
+	if err != nil {
384
+		result.Success = false
385
+		result.Error = fmt.Sprintf("Query execution failed: %v", err)
386
+		result.Time = time.Since(startTime).String()
387
+		return result
388
+	}
389
+	defer rows.Close()
390
+
391
+	columns, err := rows.Columns()
392
+	if err != nil {
393
+		result.Success = false
394
+		result.Error = fmt.Sprintf("Failed to get columns: %v", err)
395
+		result.Time = time.Since(startTime).String()
396
+		return result
397
+	}
398
+
399
+	var csvBuilder strings.Builder
400
+	writer := csv.NewWriter(&csvBuilder)
401
+
402
+	if includeHeader {
403
+		if err := writer.Write(columns); err != nil {
404
+			result.Success = false
405
+			result.Error = fmt.Sprintf("Failed to write CSV header: %v", err)
406
+			result.Time = time.Since(startTime).String()
407
+			return result
408
+		}
409
+	}
410
+
411
+	count := 0
412
+
413
+	for rows.Next() {
414
+		count++
415
+
416
+		values := make([]interface{}, len(columns))
417
+		valuePtrs := make([]interface{}, len(columns))
418
+		for i := range columns {
419
+			valuePtrs[i] = &values[i]
420
+		}
421
+
422
+		if err := rows.Scan(valuePtrs...); err != nil {
423
+			result.Success = false
424
+			result.Error = fmt.Sprintf("Failed to scan row: %v", err)
425
+			result.Time = time.Since(startTime).String()
426
+			return result
427
+		}
428
+
429
+		rowData := make([]string, len(columns))
430
+		for i, val := range values {
431
+			if val == nil {
432
+				rowData[i] = ""
433
+				continue
434
+			}
435
+
436
+			switch v := val.(type) {
437
+			case []byte:
438
+				rowData[i] = string(v)
439
+			case string:
440
+				rowData[i] = v
441
+			case int, int8, int16, int32, int64:
442
+				rowData[i] = fmt.Sprintf("%d", v)
443
+			case uint, uint8, uint16, uint32, uint64:
444
+				rowData[i] = fmt.Sprintf("%d", v)
445
+			case float32, float64:
446
+				rowData[i] = fmt.Sprintf("%f", v)
447
+			case bool:
448
+				if v {
449
+					rowData[i] = "true"
450
+				} else {
451
+					rowData[i] = "false"
452
+				}
453
+			case time.Time:
454
+				rowData[i] = v.Format(time.RFC3339)
455
+			default:
456
+				rowData[i] = fmt.Sprintf("%v", v)
457
+			}
458
+		}
459
+
460
+		if err := writer.Write(rowData); err != nil {
461
+			result.Success = false
462
+			result.Error = fmt.Sprintf("Failed to write CSV row: %v", err)
463
+			result.Time = time.Since(startTime).String()
464
+			return result
465
+		}
466
+	}
467
+
468
+	if err := rows.Err(); err != nil {
469
+		result.Success = false
470
+		result.Error = fmt.Sprintf("Row iteration error: %v", err)
471
+		result.Time = time.Since(startTime).String()
472
+		return result
473
+	}
474
+
475
+	writer.Flush()
476
+	if err := writer.Error(); err != nil {
477
+		result.Success = false
478
+		result.Error = fmt.Sprintf("Failed to flush CSV: %v", err)
479
+		result.Time = time.Since(startTime).String()
480
+		return result
481
+	}
482
+
483
+	result.Success = true
484
+	result.Data = map[string]interface{}{
485
+		"csv":           csvBuilder.String(),
486
+		"count":         count,
487
+		"includeHeader": includeHeader,
488
+	}
489
+	result.Count = count
490
+	result.Time = time.Since(startTime).String()
491
+
492
+	return result
493
+}
494
+
495
+// ExecuteQueryCSVStream 流式返回CSV数据
496
+func (e *QueryExecutor) ExecuteQueryCSVStream(sql string, w io.Writer, includeHeader bool) (int, error) {
497
+	rows, err := e.db.Query(sql)
498
+	if err != nil {
499
+		return 0, err
500
+	}
501
+	defer rows.Close()
502
+
503
+	columns, err := rows.Columns()
504
+	if err != nil {
505
+		return 0, err
506
+	}
507
+
508
+	writer := csv.NewWriter(w)
509
+	count := 0
510
+
511
+	if includeHeader {
512
+		if err := writer.Write(columns); err != nil {
513
+			return 0, err
514
+		}
515
+	}
516
+
517
+	for rows.Next() {
518
+		count++
519
+
520
+		values := make([]interface{}, len(columns))
521
+		valuePtrs := make([]interface{}, len(columns))
522
+		for i := range columns {
523
+			valuePtrs[i] = &values[i]
524
+		}
525
+
526
+		if err := rows.Scan(valuePtrs...); err != nil {
527
+			return count, err
528
+		}
529
+
530
+		rowData := make([]string, len(columns))
531
+		for i, val := range values {
532
+			if val == nil {
533
+				rowData[i] = ""
534
+				continue
535
+			}
536
+			rowData[i] = fmt.Sprintf("%v", val)
537
+		}
538
+
539
+		if err := writer.Write(rowData); err != nil {
540
+			return count, err
541
+		}
542
+	}
543
+
544
+	writer.Flush()
545
+	if err := writer.Error(); err != nil {
546
+		return count, err
547
+	}
548
+
549
+	if err := rows.Err(); err != nil {
550
+		return count, err
551
+	}
552
+
553
+	return count, nil
554
+}

+ 21
- 0
go.mod Переглянути файл

@@ -0,0 +1,21 @@
1
+module git.x2erp.com/qdy/go-db
2
+
3
+go 1.25.4
4
+
5
+require (
6
+	github.com/go-sql-driver/mysql v1.9.3
7
+	github.com/lib/pq v1.10.9
8
+	github.com/microsoft/go-mssqldb v1.9.4
9
+	github.com/sijms/go-ora/v2 v2.9.0
10
+	gopkg.in/yaml.v2 v2.4.0
11
+)
12
+
13
+require (
14
+	filippo.io/edwards25519 v1.1.0 // indirect
15
+	git.x2erp.com/qdy/go-base v0.1.1
16
+	github.com/golang-sql/civil v0.0.0-20220223132316-b832511892a9 // indirect
17
+	github.com/golang-sql/sqlexp v0.1.0 // indirect
18
+	github.com/google/uuid v1.6.0 // indirect
19
+	golang.org/x/crypto v0.38.0 // indirect
20
+	golang.org/x/text v0.25.0 // indirect
21
+)

+ 56
- 0
go.sum Переглянути файл

@@ -0,0 +1,56 @@
1
+filippo.io/edwards25519 v1.1.0 h1:FNf4tywRC1HmFuKW5xopWpigGjJKiJSV0Cqo0cJWDaA=
2
+filippo.io/edwards25519 v1.1.0/go.mod h1:BxyFTGdWcka3PhytdK4V28tE5sGfRvvvRV7EaN4VDT4=
3
+git.x2erp.com/qdy/go-base v0.1.1 h1:Yj6rFTCL9CShqdQeE6mUYQMugUpI25nWNL0i09SFtXo=
4
+git.x2erp.com/qdy/go-base v0.1.1/go.mod h1:+NdHouWcxqex8sJUwDTGSl/OIh5fKOT6tJteefD1MMA=
5
+github.com/Azure/azure-sdk-for-go/sdk/azcore v1.18.0 h1:Gt0j3wceWMwPmiazCa8MzMA0MfhmPIz0Qp0FJ6qcM0U=
6
+github.com/Azure/azure-sdk-for-go/sdk/azcore v1.18.0/go.mod h1:Ot/6aikWnKWi4l9QB7qVSwa8iMphQNqkWALMoNT3rzM=
7
+github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.10.1 h1:B+blDbyVIG3WaikNxPnhPiJ1MThR03b3vKGtER95TP4=
8
+github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.10.1/go.mod h1:JdM5psgjfBf5fo2uWOZhflPWyDBZ/O/CNAH9CtsuZE4=
9
+github.com/Azure/azure-sdk-for-go/sdk/internal v1.11.1 h1:FPKJS1T+clwv+OLGt13a8UjqeRuh0O4SJ3lUriThc+4=
10
+github.com/Azure/azure-sdk-for-go/sdk/internal v1.11.1/go.mod h1:j2chePtV91HrC22tGoRX3sGY42uF13WzmmV80/OdVAA=
11
+github.com/Azure/azure-sdk-for-go/sdk/security/keyvault/azkeys v1.3.1 h1:Wgf5rZba3YZqeTNJPtvqZoBu1sBN/L4sry+u2U3Y75w=
12
+github.com/Azure/azure-sdk-for-go/sdk/security/keyvault/azkeys v1.3.1/go.mod h1:xxCBG/f/4Vbmh2XQJBsOmNdxWUY5j/s27jujKPbQf14=
13
+github.com/Azure/azure-sdk-for-go/sdk/security/keyvault/internal v1.1.1 h1:bFWuoEKg+gImo7pvkiQEFAc8ocibADgXeiLAxWhWmkI=
14
+github.com/Azure/azure-sdk-for-go/sdk/security/keyvault/internal v1.1.1/go.mod h1:Vih/3yc6yac2JzU4hzpaDupBJP0Flaia9rXXrU8xyww=
15
+github.com/AzureAD/microsoft-authentication-library-for-go v1.4.2 h1:oygO0locgZJe7PpYPXT5A29ZkwJaPqcva7BVeemZOZs=
16
+github.com/AzureAD/microsoft-authentication-library-for-go v1.4.2/go.mod h1:wP83P5OoQ5p6ip3ScPr0BAq0BvuPAvacpEuSzyouqAI=
17
+github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
18
+github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
19
+github.com/go-sql-driver/mysql v1.9.3 h1:U/N249h2WzJ3Ukj8SowVFjdtZKfu9vlLZxjPXV1aweo=
20
+github.com/go-sql-driver/mysql v1.9.3/go.mod h1:qn46aNg1333BRMNU69Lq93t8du/dwxI64Gl8i5p1WMU=
21
+github.com/golang-jwt/jwt/v5 v5.2.2 h1:Rl4B7itRWVtYIHFrSNd7vhTiz9UpLdi6gZhZ3wEeDy8=
22
+github.com/golang-jwt/jwt/v5 v5.2.2/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk=
23
+github.com/golang-sql/civil v0.0.0-20220223132316-b832511892a9 h1:au07oEsX2xN0ktxqI+Sida1w446QrXBRJ0nee3SNZlA=
24
+github.com/golang-sql/civil v0.0.0-20220223132316-b832511892a9/go.mod h1:8vg3r2VgvsThLBIFL93Qb5yWzgyZWhEmBwUJWevAkK0=
25
+github.com/golang-sql/sqlexp v0.1.0 h1:ZCD6MBpcuOVfGVqsEmY5/4FtYiKz6tSyUv9LPEDei6A=
26
+github.com/golang-sql/sqlexp v0.1.0/go.mod h1:J4ad9Vo8ZCWQ2GMrC4UCQy1JpCbwU9m3EOqtpKwwwHI=
27
+github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
28
+github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
29
+github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0SNc=
30
+github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw=
31
+github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw=
32
+github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o=
33
+github.com/microsoft/go-mssqldb v1.9.4 h1:sHrj3GcdgkxytZ09aZ3+ys72pMeyEXJowT44j74pNgs=
34
+github.com/microsoft/go-mssqldb v1.9.4/go.mod h1:GBbW9ASTiDC+mpgWDGKdm3FnFLTUsLYN3iFL90lQ+PA=
35
+github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c h1:+mdjkGKdHQG3305AYmdv1U2eRNDiU2ErMBj1gwrq8eQ=
36
+github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c/go.mod h1:7rwL4CYBLnjLxUqIJNnCWiEdr3bn6IUYi15bNlnbCCU=
37
+github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
38
+github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
39
+github.com/sijms/go-ora/v2 v2.9.0 h1:+iQbUeTeCOFMb5BsOMgUhV8KWyrv9yjKpcK4x7+MFrg=
40
+github.com/sijms/go-ora/v2 v2.9.0/go.mod h1:QgFInVi3ZWyqAiJwzBQA+nbKYKH77tdp1PYoCqhR2dU=
41
+github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA=
42
+github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
43
+golang.org/x/crypto v0.38.0 h1:jt+WWG8IZlBnVbomuhg2Mdq0+BBQaHbtqHEFEigjUV8=
44
+golang.org/x/crypto v0.38.0/go.mod h1:MvrbAqul58NNYPKnOra203SB9vpuZW0e+RRZV+Ggqjw=
45
+golang.org/x/net v0.40.0 h1:79Xs7wF06Gbdcg4kdCCIQArK11Z1hr5POQ6+fIYHNuY=
46
+golang.org/x/net v0.40.0/go.mod h1:y0hY0exeL2Pku80/zKK7tpntoX23cqL3Oa6njdgRtds=
47
+golang.org/x/sys v0.33.0 h1:q3i8TbbEz+JRD9ywIRlyRAQbM0qF7hu24q3teo2hbuw=
48
+golang.org/x/sys v0.33.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k=
49
+golang.org/x/text v0.25.0 h1:qVyWApTSYLk/drJRO5mDlNYskwQznZmkpV2c8q9zls4=
50
+golang.org/x/text v0.25.0/go.mod h1:WEdwpYrmk1qmdHvhkSTNPm3app7v4rsT8F2UD6+VHIA=
51
+gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
52
+gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
53
+gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY=
54
+gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ=
55
+gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
56
+gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=

+ 67
- 0
test.go Переглянути файл

@@ -0,0 +1,67 @@
1
+package main
2
+
3
+import (
4
+	"database/sql"
5
+	"fmt"
6
+	"log"
7
+
8
+	"git.x2erp.com/qdy/go-db/factory"
9
+)
10
+
11
+func main() {
12
+	// 创建数据库工厂
13
+	dbFactory, err := factory.NewDBFactory()
14
+	if err != nil {
15
+		log.Fatalf("Failed to create DB factory: %v", err)
16
+	}
17
+
18
+	// 显示可用的数据库驱动
19
+	drivers := dbFactory.GetAvailableDrivers()
20
+	fmt.Printf("Available database drivers: %v\n", drivers)
21
+
22
+	// 显示当前使用的数据库配置
23
+	config := dbFactory.GetConfig()
24
+	fmt.Printf("Using database type: %s\n", config.Database.Type)
25
+	fmt.Printf("Database host: %s:%d\n", config.Database.Host, config.Database.Port) // 修正这里
26
+	fmt.Printf("Database name: %s\n", config.Database.Database)                      // 修正这里
27
+
28
+	// 创建数据库连接
29
+	db, err := dbFactory.CreateDB()
30
+	if err != nil {
31
+		log.Fatalf("Failed to create database connection: %v", err)
32
+	}
33
+	defer db.Close()
34
+
35
+	fmt.Println("Successfully connected to database!")
36
+
37
+	// 测试连接
38
+	if err := testConnection(db, config.Database.Type); err != nil { // 修正这里
39
+		log.Printf("Query test failed: %v", err)
40
+	} else {
41
+		fmt.Println("Database connection test passed!")
42
+	}
43
+}
44
+
45
+func testConnection(db *sql.DB, dbType string) error {
46
+	var query string
47
+	switch dbType {
48
+	case "mysql", "postgres", "sqlserver":
49
+		query = "SELECT 1"
50
+	case "oracle":
51
+		query = "SELECT 1 FROM DUAL"
52
+	default:
53
+		query = "SELECT 1"
54
+	}
55
+
56
+	var result int
57
+	err := db.QueryRow(query).Scan(&result)
58
+	if err != nil {
59
+		return fmt.Errorf("test query failed: %v", err)
60
+	}
61
+
62
+	if result != 1 {
63
+		return fmt.Errorf("unexpected test result: %d", result)
64
+	}
65
+
66
+	return nil
67
+}

Завантаження…
Відмінити
Зберегти