Bladeren bron

Release v0.1.1220

qdy 2 maanden geleden
commit
43ef29540f
3 gewijzigde bestanden met toevoegingen van 195 en 44 verwijderingen
  1. 55
    16
      dbstart/db_bootstrapper.go
  2. 138
    26
      factory/database/db_factory.go
  3. 2
    2
      test.go

+ 55
- 16
dbstart/db_bootstrapper.go Bestand weergeven

@@ -4,53 +4,92 @@ import (
4 4
 	"log"
5 5
 
6 6
 	"git.x2erp.com/qdy/go-base/config"
7
+	"git.x2erp.com/qdy/go-base/logger"
7 8
 	"git.x2erp.com/qdy/go-db/factory/database"
8 9
 )
9 10
 
10 11
 // DBBootstrapper 数据库启动器
11 12
 type DBBootstrapper struct {
12
-	DBFactory *database.DBFactory
13
-	cfg       config.IConfig
13
+	DBFactories map[string]*database.DBFactory // 改为map存储多个数据库实例
14
+	cfg         config.IConfig
14 15
 }
15 16
 
16 17
 // NewDBBootstrapper 创建数据库启动器
17 18
 func NewDBBootstrapper(cfg config.IConfig) *DBBootstrapper {
18 19
 	return &DBBootstrapper{
19
-		cfg: cfg,
20
+		DBFactories: make(map[string]*database.DBFactory),
21
+		cfg:         cfg,
20 22
 	}
21 23
 }
22 24
 
23
-// Init 初始化数据库
25
+// Init 初始化默认数据库
24 26
 func (db *DBBootstrapper) Init() *DBBootstrapper {
25 27
 	if db.cfg == nil {
26 28
 		log.Fatal("配置未初始化,请先传入配置")
27 29
 	}
28 30
 
31
+	// 初始化默认数据库
29 32
 	dbCfg := db.cfg.GetDatabase()
33
+	if dbCfg == nil {
34
+		log.Fatal("默认数据库配置未找到")
35
+	}
30 36
 
31
-	log.Printf("正在连接数据库: %s:%d/%s",
37
+	log.Printf("正在连接默认数据库: %s:%d/%s",
32 38
 		dbCfg.Host, dbCfg.Port, dbCfg.Database)
33 39
 
34
-	dbFactory, err := database.GetDBFactory()
40
+	dbFactory, err := database.GetDefaultDBFactory()
35 41
 	if err != nil {
36
-		log.Fatalf("数据库连接失败: %v", err)
42
+		log.Fatalf("默认数据库连接失败: %v", err)
37 43
 	}
38 44
 
39
-	db.DBFactory = dbFactory
40
-	log.Println("数据库连接成功")
45
+	db.DBFactories["default"] = dbFactory
46
+	log.Println("默认数据库连接成功")
41 47
 
42 48
 	return db
43 49
 }
44 50
 
45
-// Close 关闭数据库连接
51
+// GetDBFactory 获取数据库工厂
52
+func (db *DBBootstrapper) GetDBFactory(dbName string) *database.DBFactory {
53
+	// 如果已经初始化,直接返回
54
+	if factory, exists := db.DBFactories[dbName]; exists {
55
+		return factory
56
+	}
57
+
58
+	// 如果还没初始化,按需初始化
59
+	dbFactory, err := database.GetDBFactory(dbName)
60
+	if err != nil {
61
+		log.Printf("❌ 获取数据库 '%s' 失败: %v", dbName, err)
62
+		return nil
63
+	}
64
+
65
+	db.DBFactories[dbName] = dbFactory
66
+	return dbFactory
67
+}
68
+
69
+// GetDefaultDBFactory 获取默认数据库工厂
70
+func (db *DBBootstrapper) GetDefaultDBFactory() *database.DBFactory {
71
+	return db.GetDBFactory("default")
72
+}
73
+
74
+// Close 关闭所有数据库连接
46 75
 func (db *DBBootstrapper) Close() {
47
-	if db.DBFactory != nil {
48
-		db.DBFactory.Close()
49
-		log.Println("数据库连接已关闭")
76
+
77
+	for name, factory := range db.DBFactories {
78
+		logger.Info("正在关闭数据库: %s, factory: %v", name, factory)
79
+		if factory != nil {
80
+			factory.Close()
81
+			logger.Info("数据库 '%s' 连接已关闭", name)
82
+		} else {
83
+			logger.Error("⚠️  警告: 数据库 '%s' 的 factory 为 nil", name)
84
+		}
50 85
 	}
86
+
87
+	// 清空map
88
+	db.DBFactories = make(map[string]*database.DBFactory)
89
+	logger.Info("所有数据库连接已关闭.")
51 90
 }
52 91
 
53
-// GetDBFactory 获取数据库工厂
54
-func (db *DBBootstrapper) GetDBFactory() *database.DBFactory {
55
-	return db.DBFactory
92
+// 在 dbstart 包中
93
+func (db *DBBootstrapper) OnShutdown() {
94
+	db.Close()
56 95
 }

+ 138
- 26
factory/database/db_factory.go Bestand weergeven

@@ -15,41 +15,59 @@ import (
15 15
 )
16 16
 
17 17
 type DBFactory struct {
18
-	db *sqlx.DB
18
+	db   *sqlx.DB
19
+	name string // 记录数据库配置名称
19 20
 }
20 21
 
21 22
 var (
22
-	instanceDBFactory *DBFactory
23
-	once              sync.Once
23
+	// 多实例存储:配置名称 -> DBFactory 实例
24
+	instances = make(map[string]*DBFactory)
25
+	// 每个配置名称对应的once,确保线程安全
26
+	onceMap = make(map[string]*sync.Once)
27
+	// 保护instances和onceMap的读写锁
28
+	instancesMutex sync.RWMutex
24 29
 )
25 30
 
26
-// GetDBFactory 创建数据库工厂单例
27
-func GetDBFactory() (*DBFactory, error) {
31
+// GetDBFactory 获取指定名称的数据库工厂单例
32
+func GetDBFactory(dbName string) (*DBFactory, error) {
33
+	// 获取或创建该名称的once对象
34
+	instancesMutex.Lock()
35
+	once, exists := onceMap[dbName]
36
+	if !exists {
37
+		once = &sync.Once{}
38
+		onceMap[dbName] = once
39
+	}
40
+	instancesMutex.Unlock()
41
+
28 42
 	var initErr error
29
-	var msg = "DBFactory instance retrieved from memory.\n"
43
+	var instance *DBFactory
44
+	var msg = fmt.Sprintf("DBFactory '%s' instance retrieved from memory.\n", dbName)
30 45
 
31 46
 	once.Do(func() {
32 47
 		// 使用配置单例
33 48
 		cfg, err := config.GetConfig()
34
-
35
-		// 检查配置初始化是否有错误
36 49
 		if err != nil {
37 50
 			initErr = fmt.Errorf("failed to load config: %v", err)
38 51
 			return
39 52
 		}
40 53
 
41
-		// 检查数据库配置是否完整
42
-		if !cfg.IsDatabaseConfigured() {
43
-			initErr = fmt.Errorf("database configuration is incomplete")
54
+		// 获取指定名称的数据库配置
55
+		dbConfig := cfg.GetDatabaseConfig(dbName)
56
+		if dbConfig == nil {
57
+			initErr = fmt.Errorf("database configuration '%s' not found", dbName)
44 58
 			return
45 59
 		}
46 60
 
47
-		// 显示所支持的数据库驱动
48
-		//driversStr := drivers.GetAllDrivers()
61
+		// // 检查数据库配置是否完整
62
+		// if !dbConfig.IsConfigured() {
63
+		// 	initErr = fmt.Errorf("database configuration '%s' is incomplete", dbName)
64
+		// 	return
65
+		// }
49 66
 
50
-		dbConfig := cfg.GetDatabase()
67
+		// 获取数据库类型
51 68
 		dbType := dbConfig.Type
52
-		log.Printf("Available database drivers: %v\n", dbType)
69
+		log.Printf("Creating database connection for '%s' with type: %s\n", dbName, dbType)
70
+
53 71
 		// 获取对应的驱动
54 72
 		dbDriver, err := drivers.Get(dbType)
55 73
 		if err != nil {
@@ -73,19 +91,27 @@ func GetDBFactory() (*DBFactory, error) {
73 91
 		// 创建数据库连接
74 92
 		db, err := dbDriver.Open(driverConfig)
75 93
 		if err != nil {
76
-			initErr = fmt.Errorf("failed to open database connection: %v", err)
94
+			initErr = fmt.Errorf("failed to open database connection for '%s': %v", dbName, err)
77 95
 			return
78 96
 		}
79 97
 
80 98
 		// 测试连接
81 99
 		if err := functions.TestConnection(db, dbType); err != nil {
82 100
 			db.Close()
83
-			initErr = fmt.Errorf("database connection test failed: %v", err)
101
+			initErr = fmt.Errorf("database connection test failed for '%s': %v", dbName, err)
84 102
 			return
85 103
 		}
86 104
 
87
-		msg = "DBFactory is successfully created.\n"
88
-		instanceDBFactory = &DBFactory{db: db}
105
+		msg = fmt.Sprintf("DBFactory '%s' is successfully created.\n", dbName)
106
+		instance = &DBFactory{
107
+			db:   db,
108
+			name: dbName,
109
+		}
110
+
111
+		// 保存实例到map
112
+		instancesMutex.Lock()
113
+		instances[dbName] = instance
114
+		instancesMutex.Unlock()
89 115
 	})
90 116
 
91 117
 	if initErr != nil {
@@ -94,9 +120,84 @@ func GetDBFactory() (*DBFactory, error) {
94 120
 
95 121
 	log.Print(msg)
96 122
 
97
-	return instanceDBFactory, nil
123
+	// 从map中获取实例
124
+	instancesMutex.RLock()
125
+	instance = instances[dbName]
126
+	instancesMutex.RUnlock()
127
+
128
+	return instance, nil
129
+}
130
+
131
+// GetDefaultDBFactory 获取默认数据库工厂(向后兼容)
132
+func GetDefaultDBFactory() (*DBFactory, error) {
133
+	return GetDBFactory("default")
134
+}
135
+
136
+// GetAllDBFactories 获取所有已创建的数据库工厂实例
137
+func GetAllDBFactories() map[string]*DBFactory {
138
+	instancesMutex.RLock()
139
+	defer instancesMutex.RUnlock()
140
+
141
+	// 创建副本,避免外部修改
142
+	result := make(map[string]*DBFactory)
143
+	for k, v := range instances {
144
+		result[k] = v
145
+	}
146
+	return result
98 147
 }
99 148
 
149
+// GetDBFactoryNames 获取所有可用的数据库配置名称
150
+func GetDBFactoryNames() []string {
151
+	cfg, err := config.GetConfig()
152
+	if err != nil {
153
+		return []string{}
154
+	}
155
+
156
+	dbs := cfg.GetDatabases()
157
+	if dbs == nil {
158
+		return []string{}
159
+	}
160
+
161
+	return dbs.GetAllDatabaseNames()
162
+}
163
+
164
+// CloseInstance 关闭指定名称的数据库连接
165
+func CloseInstance(dbName string) error {
166
+	instancesMutex.Lock()
167
+	defer instancesMutex.Unlock()
168
+
169
+	if instance, exists := instances[dbName]; exists {
170
+		err := instance.Close()
171
+		delete(instances, dbName)
172
+		delete(onceMap, dbName)
173
+		return err
174
+	}
175
+
176
+	return fmt.Errorf("database instance '%s' not found", dbName)
177
+}
178
+
179
+// CloseAll 关闭所有数据库连接
180
+func CloseAll() {
181
+	instancesMutex.Lock()
182
+	defer instancesMutex.Unlock()
183
+
184
+	for name, instance := range instances {
185
+		if err := instance.Close(); err != nil {
186
+			log.Printf("Error closing database instance '%s': %v\n", name, err)
187
+		}
188
+		delete(instances, name)
189
+		delete(onceMap, name)
190
+	}
191
+
192
+	// 重新初始化maps
193
+	instances = make(map[string]*DBFactory)
194
+	onceMap = make(map[string]*sync.Once)
195
+
196
+	log.Println("All database connections closed gracefully")
197
+}
198
+
199
+// ========== DBFactory 实例方法 ==========
200
+
100 201
 // GetDB 获取数据库连接(线程安全)
101 202
 func (f *DBFactory) GetDB() interface{} {
102 203
 	return f.db
@@ -107,7 +208,7 @@ func (f *DBFactory) Close() error {
107 208
 	if f.db != nil {
108 209
 		err := f.db.Close()
109 210
 		f.db = nil
110
-		log.Println("Database connection closed gracefully")
211
+		log.Printf("Database connection '%s' closed gracefully\n", f.name)
111 212
 		return err
112 213
 	}
113 214
 	return nil
@@ -115,10 +216,25 @@ func (f *DBFactory) Close() error {
115 216
 
116 217
 // GetDBType 得到当前使用数据库类型
117 218
 func (f *DBFactory) GetDBType() string {
118
-	dbConfig := config.GetDatabase()
219
+	// 通过配置获取当前数据库的类型
220
+	cfg, err := config.GetConfig()
221
+	if err != nil {
222
+		return ""
223
+	}
224
+
225
+	dbConfig := cfg.GetDatabaseConfig(f.name)
226
+	if dbConfig == nil {
227
+		return ""
228
+	}
229
+
119 230
 	return dbConfig.Type
120 231
 }
121 232
 
233
+// GetDBName 获取数据库配置名称
234
+func (f *DBFactory) GetDBName() string {
235
+	return f.name
236
+}
237
+
122 238
 // QueryToJSON 快捷查询,直接返回 JSON 字节流
123 239
 func (f *DBFactory) QueryToJSON(sql string, reqCtx *ctx.RequestContext) *types.QueryResult[[]map[string]interface{}] {
124 240
 	return functions.QueryToJSON(f.db, sql, reqCtx)
@@ -126,14 +242,12 @@ func (f *DBFactory) QueryToJSON(sql string, reqCtx *ctx.RequestContext) *types.Q
126 242
 
127 243
 // QueryParamsToJSON 位置参数查询并返回 JSON 字节数据
128 244
 func (f *DBFactory) QueryPositionalToJSON(sql string, params []interface{}, reqCtx *ctx.RequestContext) *types.QueryResult[[]map[string]interface{}] {
129
-
130 245
 	return functions.QueryPositionalToJSON(f.db, sql, params, reqCtx)
131 246
 }
132 247
 
133 248
 // QueryParamsNameToJSON 命名参数查询并返回 JSON 字节数据
134 249
 // params 可以是 map[string]interface{} 或结构体
135 250
 func (f *DBFactory) QueryParamsNameToJSON(sql string, params map[string]interface{}, reqCtx *ctx.RequestContext) *types.QueryResult[[]map[string]interface{}] {
136
-
137 251
 	return functions.QueryParamsNameToJSON(f.db, sql, params, reqCtx)
138 252
 }
139 253
 
@@ -144,14 +258,12 @@ func (f *DBFactory) QueryToCSV(sql string, writerHeader bool, reqCtx *ctx.Reques
144 258
 
145 259
 // QueryParamsToCSV 位置参数查询并返回 CSV 字节数据
146 260
 func (f *DBFactory) QueryPositionalToCSV(sql string, writerHeader bool, params []interface{}, reqCtx *ctx.RequestContext) ([]byte, error) {
147
-
148 261
 	return functions.QueryPositionalToCSV(f.db, sql, writerHeader, params, reqCtx)
149 262
 }
150 263
 
151 264
 // QueryParamsNameToCSV 命名参数查询并返回 CSV 字节数据
152 265
 // params 可以是 map[string]interface{} 或结构体
153 266
 func (f *DBFactory) QueryParamsNameToCSV(sql string, writerHeader bool, params map[string]interface{}, reqCtx *ctx.RequestContext) ([]byte, error) {
154
-
155 267
 	return functions.QueryParamsNameToCSV(f.db, sql, writerHeader, params, reqCtx)
156 268
 }
157 269
 

+ 2
- 2
test.go Bestand weergeven

@@ -25,7 +25,7 @@ func main() {
25 25
 
26 26
 	// 创建数据库工厂
27 27
 	fmt.Printf("第1次.\n")
28
-	dbFactory, err := database.GetDBFactory()
28
+	dbFactory, err := database.GetDefaultDBFactory()
29 29
 
30 30
 	if err != nil {
31 31
 		log.Fatalf("Failed to create DB factory: %v", err)
@@ -33,7 +33,7 @@ func main() {
33 33
 
34 34
 	//测试单例是否生效
35 35
 	fmt.Printf("第2次.\n")
36
-	dbFactory1, err1 := database.GetDBFactory()
36
+	dbFactory1, err1 := database.GetDefaultDBFactory()
37 37
 
38 38
 	if err1 != nil {
39 39
 		log.Fatalf("Failed to create DB factory: %v", err1)

Laden…
Annuleren
Opslaan