qdy 2 месяцев назад
Сommit
43ef29540f
3 измененных файлов: 195 добавлений и 44 удалений
  1. 55
    16
      dbstart/db_bootstrapper.go
  2. 138
    26
      factory/database/db_factory.go
  3. 2
    2
      test.go

+ 55
- 16
dbstart/db_bootstrapper.go Просмотреть файл

4
 	"log"
4
 	"log"
5
 
5
 
6
 	"git.x2erp.com/qdy/go-base/config"
6
 	"git.x2erp.com/qdy/go-base/config"
7
+	"git.x2erp.com/qdy/go-base/logger"
7
 	"git.x2erp.com/qdy/go-db/factory/database"
8
 	"git.x2erp.com/qdy/go-db/factory/database"
8
 )
9
 )
9
 
10
 
10
 // DBBootstrapper 数据库启动器
11
 // DBBootstrapper 数据库启动器
11
 type DBBootstrapper struct {
12
 type DBBootstrapper struct {
12
-	DBFactory *database.DBFactory
13
-	cfg       config.IConfig
13
+	DBFactories map[string]*database.DBFactory // 改为map存储多个数据库实例
14
+	cfg         config.IConfig
14
 }
15
 }
15
 
16
 
16
 // NewDBBootstrapper 创建数据库启动器
17
 // NewDBBootstrapper 创建数据库启动器
17
 func NewDBBootstrapper(cfg config.IConfig) *DBBootstrapper {
18
 func NewDBBootstrapper(cfg config.IConfig) *DBBootstrapper {
18
 	return &DBBootstrapper{
19
 	return &DBBootstrapper{
19
-		cfg: cfg,
20
+		DBFactories: make(map[string]*database.DBFactory),
21
+		cfg:         cfg,
20
 	}
22
 	}
21
 }
23
 }
22
 
24
 
23
-// Init 初始化数据库
25
+// Init 初始化默认数据库
24
 func (db *DBBootstrapper) Init() *DBBootstrapper {
26
 func (db *DBBootstrapper) Init() *DBBootstrapper {
25
 	if db.cfg == nil {
27
 	if db.cfg == nil {
26
 		log.Fatal("配置未初始化,请先传入配置")
28
 		log.Fatal("配置未初始化,请先传入配置")
27
 	}
29
 	}
28
 
30
 
31
+	// 初始化默认数据库
29
 	dbCfg := db.cfg.GetDatabase()
32
 	dbCfg := db.cfg.GetDatabase()
33
+	if dbCfg == nil {
34
+		log.Fatal("默认数据库配置未找到")
35
+	}
30
 
36
 
31
-	log.Printf("正在连接数据库: %s:%d/%s",
37
+	log.Printf("正在连接默认数据库: %s:%d/%s",
32
 		dbCfg.Host, dbCfg.Port, dbCfg.Database)
38
 		dbCfg.Host, dbCfg.Port, dbCfg.Database)
33
 
39
 
34
-	dbFactory, err := database.GetDBFactory()
40
+	dbFactory, err := database.GetDefaultDBFactory()
35
 	if err != nil {
41
 	if err != nil {
36
-		log.Fatalf("数据库连接失败: %v", err)
42
+		log.Fatalf("默认数据库连接失败: %v", err)
37
 	}
43
 	}
38
 
44
 
39
-	db.DBFactory = dbFactory
40
-	log.Println("数据库连接成功")
45
+	db.DBFactories["default"] = dbFactory
46
+	log.Println("默认数据库连接成功")
41
 
47
 
42
 	return db
48
 	return db
43
 }
49
 }
44
 
50
 
45
-// Close 关闭数据库连接
51
+// GetDBFactory 获取数据库工厂
52
+func (db *DBBootstrapper) GetDBFactory(dbName string) *database.DBFactory {
53
+	// 如果已经初始化,直接返回
54
+	if factory, exists := db.DBFactories[dbName]; exists {
55
+		return factory
56
+	}
57
+
58
+	// 如果还没初始化,按需初始化
59
+	dbFactory, err := database.GetDBFactory(dbName)
60
+	if err != nil {
61
+		log.Printf("❌ 获取数据库 '%s' 失败: %v", dbName, err)
62
+		return nil
63
+	}
64
+
65
+	db.DBFactories[dbName] = dbFactory
66
+	return dbFactory
67
+}
68
+
69
+// GetDefaultDBFactory 获取默认数据库工厂
70
+func (db *DBBootstrapper) GetDefaultDBFactory() *database.DBFactory {
71
+	return db.GetDBFactory("default")
72
+}
73
+
74
+// Close 关闭所有数据库连接
46
 func (db *DBBootstrapper) Close() {
75
 func (db *DBBootstrapper) Close() {
47
-	if db.DBFactory != nil {
48
-		db.DBFactory.Close()
49
-		log.Println("数据库连接已关闭")
76
+
77
+	for name, factory := range db.DBFactories {
78
+		logger.Info("正在关闭数据库: %s, factory: %v", name, factory)
79
+		if factory != nil {
80
+			factory.Close()
81
+			logger.Info("数据库 '%s' 连接已关闭", name)
82
+		} else {
83
+			logger.Error("⚠️  警告: 数据库 '%s' 的 factory 为 nil", name)
84
+		}
50
 	}
85
 	}
86
+
87
+	// 清空map
88
+	db.DBFactories = make(map[string]*database.DBFactory)
89
+	logger.Info("所有数据库连接已关闭.")
51
 }
90
 }
52
 
91
 
53
-// GetDBFactory 获取数据库工厂
54
-func (db *DBBootstrapper) GetDBFactory() *database.DBFactory {
55
-	return db.DBFactory
92
+// 在 dbstart 包中
93
+func (db *DBBootstrapper) OnShutdown() {
94
+	db.Close()
56
 }
95
 }

+ 138
- 26
factory/database/db_factory.go Просмотреть файл

15
 )
15
 )
16
 
16
 
17
 type DBFactory struct {
17
 type DBFactory struct {
18
-	db *sqlx.DB
18
+	db   *sqlx.DB
19
+	name string // 记录数据库配置名称
19
 }
20
 }
20
 
21
 
21
 var (
22
 var (
22
-	instanceDBFactory *DBFactory
23
-	once              sync.Once
23
+	// 多实例存储:配置名称 -> DBFactory 实例
24
+	instances = make(map[string]*DBFactory)
25
+	// 每个配置名称对应的once,确保线程安全
26
+	onceMap = make(map[string]*sync.Once)
27
+	// 保护instances和onceMap的读写锁
28
+	instancesMutex sync.RWMutex
24
 )
29
 )
25
 
30
 
26
-// GetDBFactory 创建数据库工厂单例
27
-func GetDBFactory() (*DBFactory, error) {
31
+// GetDBFactory 获取指定名称的数据库工厂单例
32
+func GetDBFactory(dbName string) (*DBFactory, error) {
33
+	// 获取或创建该名称的once对象
34
+	instancesMutex.Lock()
35
+	once, exists := onceMap[dbName]
36
+	if !exists {
37
+		once = &sync.Once{}
38
+		onceMap[dbName] = once
39
+	}
40
+	instancesMutex.Unlock()
41
+
28
 	var initErr error
42
 	var initErr error
29
-	var msg = "DBFactory instance retrieved from memory.\n"
43
+	var instance *DBFactory
44
+	var msg = fmt.Sprintf("DBFactory '%s' instance retrieved from memory.\n", dbName)
30
 
45
 
31
 	once.Do(func() {
46
 	once.Do(func() {
32
 		// 使用配置单例
47
 		// 使用配置单例
33
 		cfg, err := config.GetConfig()
48
 		cfg, err := config.GetConfig()
34
-
35
-		// 检查配置初始化是否有错误
36
 		if err != nil {
49
 		if err != nil {
37
 			initErr = fmt.Errorf("failed to load config: %v", err)
50
 			initErr = fmt.Errorf("failed to load config: %v", err)
38
 			return
51
 			return
39
 		}
52
 		}
40
 
53
 
41
-		// 检查数据库配置是否完整
42
-		if !cfg.IsDatabaseConfigured() {
43
-			initErr = fmt.Errorf("database configuration is incomplete")
54
+		// 获取指定名称的数据库配置
55
+		dbConfig := cfg.GetDatabaseConfig(dbName)
56
+		if dbConfig == nil {
57
+			initErr = fmt.Errorf("database configuration '%s' not found", dbName)
44
 			return
58
 			return
45
 		}
59
 		}
46
 
60
 
47
-		// 显示所支持的数据库驱动
48
-		//driversStr := drivers.GetAllDrivers()
61
+		// // 检查数据库配置是否完整
62
+		// if !dbConfig.IsConfigured() {
63
+		// 	initErr = fmt.Errorf("database configuration '%s' is incomplete", dbName)
64
+		// 	return
65
+		// }
49
 
66
 
50
-		dbConfig := cfg.GetDatabase()
67
+		// 获取数据库类型
51
 		dbType := dbConfig.Type
68
 		dbType := dbConfig.Type
52
-		log.Printf("Available database drivers: %v\n", dbType)
69
+		log.Printf("Creating database connection for '%s' with type: %s\n", dbName, dbType)
70
+
53
 		// 获取对应的驱动
71
 		// 获取对应的驱动
54
 		dbDriver, err := drivers.Get(dbType)
72
 		dbDriver, err := drivers.Get(dbType)
55
 		if err != nil {
73
 		if err != nil {
73
 		// 创建数据库连接
91
 		// 创建数据库连接
74
 		db, err := dbDriver.Open(driverConfig)
92
 		db, err := dbDriver.Open(driverConfig)
75
 		if err != nil {
93
 		if err != nil {
76
-			initErr = fmt.Errorf("failed to open database connection: %v", err)
94
+			initErr = fmt.Errorf("failed to open database connection for '%s': %v", dbName, err)
77
 			return
95
 			return
78
 		}
96
 		}
79
 
97
 
80
 		// 测试连接
98
 		// 测试连接
81
 		if err := functions.TestConnection(db, dbType); err != nil {
99
 		if err := functions.TestConnection(db, dbType); err != nil {
82
 			db.Close()
100
 			db.Close()
83
-			initErr = fmt.Errorf("database connection test failed: %v", err)
101
+			initErr = fmt.Errorf("database connection test failed for '%s': %v", dbName, err)
84
 			return
102
 			return
85
 		}
103
 		}
86
 
104
 
87
-		msg = "DBFactory is successfully created.\n"
88
-		instanceDBFactory = &DBFactory{db: db}
105
+		msg = fmt.Sprintf("DBFactory '%s' is successfully created.\n", dbName)
106
+		instance = &DBFactory{
107
+			db:   db,
108
+			name: dbName,
109
+		}
110
+
111
+		// 保存实例到map
112
+		instancesMutex.Lock()
113
+		instances[dbName] = instance
114
+		instancesMutex.Unlock()
89
 	})
115
 	})
90
 
116
 
91
 	if initErr != nil {
117
 	if initErr != nil {
94
 
120
 
95
 	log.Print(msg)
121
 	log.Print(msg)
96
 
122
 
97
-	return instanceDBFactory, nil
123
+	// 从map中获取实例
124
+	instancesMutex.RLock()
125
+	instance = instances[dbName]
126
+	instancesMutex.RUnlock()
127
+
128
+	return instance, nil
129
+}
130
+
131
+// GetDefaultDBFactory 获取默认数据库工厂(向后兼容)
132
+func GetDefaultDBFactory() (*DBFactory, error) {
133
+	return GetDBFactory("default")
134
+}
135
+
136
+// GetAllDBFactories 获取所有已创建的数据库工厂实例
137
+func GetAllDBFactories() map[string]*DBFactory {
138
+	instancesMutex.RLock()
139
+	defer instancesMutex.RUnlock()
140
+
141
+	// 创建副本,避免外部修改
142
+	result := make(map[string]*DBFactory)
143
+	for k, v := range instances {
144
+		result[k] = v
145
+	}
146
+	return result
98
 }
147
 }
99
 
148
 
149
+// GetDBFactoryNames 获取所有可用的数据库配置名称
150
+func GetDBFactoryNames() []string {
151
+	cfg, err := config.GetConfig()
152
+	if err != nil {
153
+		return []string{}
154
+	}
155
+
156
+	dbs := cfg.GetDatabases()
157
+	if dbs == nil {
158
+		return []string{}
159
+	}
160
+
161
+	return dbs.GetAllDatabaseNames()
162
+}
163
+
164
+// CloseInstance 关闭指定名称的数据库连接
165
+func CloseInstance(dbName string) error {
166
+	instancesMutex.Lock()
167
+	defer instancesMutex.Unlock()
168
+
169
+	if instance, exists := instances[dbName]; exists {
170
+		err := instance.Close()
171
+		delete(instances, dbName)
172
+		delete(onceMap, dbName)
173
+		return err
174
+	}
175
+
176
+	return fmt.Errorf("database instance '%s' not found", dbName)
177
+}
178
+
179
+// CloseAll 关闭所有数据库连接
180
+func CloseAll() {
181
+	instancesMutex.Lock()
182
+	defer instancesMutex.Unlock()
183
+
184
+	for name, instance := range instances {
185
+		if err := instance.Close(); err != nil {
186
+			log.Printf("Error closing database instance '%s': %v\n", name, err)
187
+		}
188
+		delete(instances, name)
189
+		delete(onceMap, name)
190
+	}
191
+
192
+	// 重新初始化maps
193
+	instances = make(map[string]*DBFactory)
194
+	onceMap = make(map[string]*sync.Once)
195
+
196
+	log.Println("All database connections closed gracefully")
197
+}
198
+
199
+// ========== DBFactory 实例方法 ==========
200
+
100
 // GetDB 获取数据库连接(线程安全)
201
 // GetDB 获取数据库连接(线程安全)
101
 func (f *DBFactory) GetDB() interface{} {
202
 func (f *DBFactory) GetDB() interface{} {
102
 	return f.db
203
 	return f.db
107
 	if f.db != nil {
208
 	if f.db != nil {
108
 		err := f.db.Close()
209
 		err := f.db.Close()
109
 		f.db = nil
210
 		f.db = nil
110
-		log.Println("Database connection closed gracefully")
211
+		log.Printf("Database connection '%s' closed gracefully\n", f.name)
111
 		return err
212
 		return err
112
 	}
213
 	}
113
 	return nil
214
 	return nil
115
 
216
 
116
 // GetDBType 得到当前使用数据库类型
217
 // GetDBType 得到当前使用数据库类型
117
 func (f *DBFactory) GetDBType() string {
218
 func (f *DBFactory) GetDBType() string {
118
-	dbConfig := config.GetDatabase()
219
+	// 通过配置获取当前数据库的类型
220
+	cfg, err := config.GetConfig()
221
+	if err != nil {
222
+		return ""
223
+	}
224
+
225
+	dbConfig := cfg.GetDatabaseConfig(f.name)
226
+	if dbConfig == nil {
227
+		return ""
228
+	}
229
+
119
 	return dbConfig.Type
230
 	return dbConfig.Type
120
 }
231
 }
121
 
232
 
233
+// GetDBName 获取数据库配置名称
234
+func (f *DBFactory) GetDBName() string {
235
+	return f.name
236
+}
237
+
122
 // QueryToJSON 快捷查询,直接返回 JSON 字节流
238
 // QueryToJSON 快捷查询,直接返回 JSON 字节流
123
 func (f *DBFactory) QueryToJSON(sql string, reqCtx *ctx.RequestContext) *types.QueryResult[[]map[string]interface{}] {
239
 func (f *DBFactory) QueryToJSON(sql string, reqCtx *ctx.RequestContext) *types.QueryResult[[]map[string]interface{}] {
124
 	return functions.QueryToJSON(f.db, sql, reqCtx)
240
 	return functions.QueryToJSON(f.db, sql, reqCtx)
126
 
242
 
127
 // QueryParamsToJSON 位置参数查询并返回 JSON 字节数据
243
 // QueryParamsToJSON 位置参数查询并返回 JSON 字节数据
128
 func (f *DBFactory) QueryPositionalToJSON(sql string, params []interface{}, reqCtx *ctx.RequestContext) *types.QueryResult[[]map[string]interface{}] {
244
 func (f *DBFactory) QueryPositionalToJSON(sql string, params []interface{}, reqCtx *ctx.RequestContext) *types.QueryResult[[]map[string]interface{}] {
129
-
130
 	return functions.QueryPositionalToJSON(f.db, sql, params, reqCtx)
245
 	return functions.QueryPositionalToJSON(f.db, sql, params, reqCtx)
131
 }
246
 }
132
 
247
 
133
 // QueryParamsNameToJSON 命名参数查询并返回 JSON 字节数据
248
 // QueryParamsNameToJSON 命名参数查询并返回 JSON 字节数据
134
 // params 可以是 map[string]interface{} 或结构体
249
 // params 可以是 map[string]interface{} 或结构体
135
 func (f *DBFactory) QueryParamsNameToJSON(sql string, params map[string]interface{}, reqCtx *ctx.RequestContext) *types.QueryResult[[]map[string]interface{}] {
250
 func (f *DBFactory) QueryParamsNameToJSON(sql string, params map[string]interface{}, reqCtx *ctx.RequestContext) *types.QueryResult[[]map[string]interface{}] {
136
-
137
 	return functions.QueryParamsNameToJSON(f.db, sql, params, reqCtx)
251
 	return functions.QueryParamsNameToJSON(f.db, sql, params, reqCtx)
138
 }
252
 }
139
 
253
 
144
 
258
 
145
 // QueryParamsToCSV 位置参数查询并返回 CSV 字节数据
259
 // QueryParamsToCSV 位置参数查询并返回 CSV 字节数据
146
 func (f *DBFactory) QueryPositionalToCSV(sql string, writerHeader bool, params []interface{}, reqCtx *ctx.RequestContext) ([]byte, error) {
260
 func (f *DBFactory) QueryPositionalToCSV(sql string, writerHeader bool, params []interface{}, reqCtx *ctx.RequestContext) ([]byte, error) {
147
-
148
 	return functions.QueryPositionalToCSV(f.db, sql, writerHeader, params, reqCtx)
261
 	return functions.QueryPositionalToCSV(f.db, sql, writerHeader, params, reqCtx)
149
 }
262
 }
150
 
263
 
151
 // QueryParamsNameToCSV 命名参数查询并返回 CSV 字节数据
264
 // QueryParamsNameToCSV 命名参数查询并返回 CSV 字节数据
152
 // params 可以是 map[string]interface{} 或结构体
265
 // params 可以是 map[string]interface{} 或结构体
153
 func (f *DBFactory) QueryParamsNameToCSV(sql string, writerHeader bool, params map[string]interface{}, reqCtx *ctx.RequestContext) ([]byte, error) {
266
 func (f *DBFactory) QueryParamsNameToCSV(sql string, writerHeader bool, params map[string]interface{}, reqCtx *ctx.RequestContext) ([]byte, error) {
154
-
155
 	return functions.QueryParamsNameToCSV(f.db, sql, writerHeader, params, reqCtx)
267
 	return functions.QueryParamsNameToCSV(f.db, sql, writerHeader, params, reqCtx)
156
 }
268
 }
157
 
269
 

+ 2
- 2
test.go Просмотреть файл

25
 
25
 
26
 	// 创建数据库工厂
26
 	// 创建数据库工厂
27
 	fmt.Printf("第1次.\n")
27
 	fmt.Printf("第1次.\n")
28
-	dbFactory, err := database.GetDBFactory()
28
+	dbFactory, err := database.GetDefaultDBFactory()
29
 
29
 
30
 	if err != nil {
30
 	if err != nil {
31
 		log.Fatalf("Failed to create DB factory: %v", err)
31
 		log.Fatalf("Failed to create DB factory: %v", err)
33
 
33
 
34
 	//测试单例是否生效
34
 	//测试单例是否生效
35
 	fmt.Printf("第2次.\n")
35
 	fmt.Printf("第2次.\n")
36
-	dbFactory1, err1 := database.GetDBFactory()
36
+	dbFactory1, err1 := database.GetDefaultDBFactory()
37
 
37
 
38
 	if err1 != nil {
38
 	if err1 != nil {
39
 		log.Fatalf("Failed to create DB factory: %v", err1)
39
 		log.Fatalf("Failed to create DB factory: %v", err1)

Загрузка…
Отмена
Сохранить