暫無描述
You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

db_factory.go 8.8KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326
  1. package database
  2. import (
  3. "fmt"
  4. "log"
  5. "sync"
  6. "git.x2erp.com/qdy/go-base/config"
  7. "git.x2erp.com/qdy/go-base/config/subconfigs"
  8. "git.x2erp.com/qdy/go-base/ctx"
  9. "git.x2erp.com/qdy/go-base/logger"
  10. "git.x2erp.com/qdy/go-base/model/response"
  11. "git.x2erp.com/qdy/go-db/drivers"
  12. "git.x2erp.com/qdy/go-db/functions"
  13. "github.com/jmoiron/sqlx"
  14. )
  15. // DBFactory 数据库工厂(全局单例模式)
  16. type DBFactory struct {
  17. db *sqlx.DB
  18. config *subconfigs.DatabaseConfig
  19. }
  20. var (
  21. instanceDB *DBFactory
  22. instanceDBOnce sync.Once
  23. initErrDB error
  24. )
  25. // CreateDBFactory 获取数据库工厂单例
  26. func CreateDBFactory(cfg config.IConfig) *DBFactory {
  27. config := cfg.GetDatabaseConfig()
  28. instanceDBOnce.Do(func() {
  29. if config == nil {
  30. log.Fatal("配置未初始化,请先在yaml进行配置")
  31. }
  32. // 设置默认值
  33. if config.MaxOpenConns == 0 {
  34. config.MaxOpenConns = 100
  35. }
  36. if config.MaxIdleConns == 0 {
  37. config.MaxIdleConns = 10
  38. }
  39. if config.ConnMaxLifetime == 0 {
  40. config.ConnMaxLifetime = 5 * 60 // 5分钟,单位秒
  41. }
  42. // 验证配置
  43. if config.Type == "" {
  44. initErrDB = fmt.Errorf("database type must be configured")
  45. return
  46. }
  47. if config.Host == "" {
  48. initErrDB = fmt.Errorf("database host must be configured")
  49. return
  50. }
  51. if config.Database == "" {
  52. initErrDB = fmt.Errorf("database name must be configured")
  53. return
  54. }
  55. log.Printf("Creating database connection...")
  56. // 获取对应的驱动
  57. dbDriver, err := drivers.Get(config.Type)
  58. if err != nil {
  59. initErrDB = fmt.Errorf("failed to get database driver: %v", err)
  60. return
  61. }
  62. // 将内部 DBConfig 转换为 drivers.DBConfig
  63. driverConfig := drivers.DBConfig{
  64. Type: config.Type,
  65. Host: config.Host,
  66. Port: config.Port,
  67. Username: config.Username,
  68. Password: config.Password,
  69. Database: config.Database,
  70. MaxOpenConns: config.MaxOpenConns,
  71. MaxIdleConns: config.MaxIdleConns,
  72. ConnMaxLifetime: config.ConnMaxLifetime,
  73. }
  74. // 创建数据库连接
  75. db, err := dbDriver.Open(driverConfig)
  76. if err != nil {
  77. initErrDB = fmt.Errorf("failed to open database connection: %v", err)
  78. return
  79. }
  80. // 测试连接
  81. if err := functions.TestConnection(db, config.Type); err != nil {
  82. db.Close()
  83. initErrDB = fmt.Errorf("database connection test failed: %v", err)
  84. return
  85. }
  86. log.Printf("DBFactory is successfully created.\n")
  87. instanceDB = &DBFactory{
  88. db: db,
  89. config: config,
  90. }
  91. })
  92. if initErrDB != nil {
  93. log.Fatalf("DBFactory is error: '%v'", initErrDB)
  94. }
  95. return instanceDB
  96. }
  97. // ========== DBFactory 实例方法 ==========
  98. // GetDB 获取数据库连接
  99. func (f *DBFactory) GetDB() *sqlx.DB {
  100. return f.db
  101. }
  102. func (f *DBFactory) GetName() string {
  103. return "DBFactory"
  104. }
  105. // Close 关闭数据库连接
  106. func (f *DBFactory) Close() {
  107. if f.db != nil {
  108. err := f.db.Close()
  109. if err != nil {
  110. logger.Errorf("failed to close database connection: %v", err)
  111. }
  112. log.Printf("Database connection closed gracefully\n")
  113. f.db = nil
  114. }
  115. }
  116. // GetConfig 获取配置信息
  117. func (f *DBFactory) GetConfig() subconfigs.DatabaseConfig {
  118. return *f.config
  119. }
  120. // TestConnection 测试连接
  121. func (f *DBFactory) TestConnection() error {
  122. return functions.TestConnection(f.db, f.config.Type)
  123. }
  124. // ========== 快捷操作方法 ==========
  125. // QueryToJSON 快捷查询,直接返回 JSON 字节流
  126. func (f *DBFactory) QueryToJSON(sql string, reqCtx *ctx.RequestContext) *response.QueryResult[[]map[string]interface{}] {
  127. return functions.QueryToJSON(f.db, sql, reqCtx)
  128. }
  129. // QueryPositionalToJSON 位置参数查询并返回 JSON 字节数据
  130. func (f *DBFactory) QueryPositionalToJSON(sql string, params []interface{}, reqCtx *ctx.RequestContext) *response.QueryResult[[]map[string]interface{}] {
  131. return functions.QueryPositionalToJSON(f.db, sql, params, reqCtx)
  132. }
  133. // QueryParamsNameToJSON 命名参数查询并返回 JSON 字节数据
  134. func (f *DBFactory) QueryParamsNameToJSON(sql string, params map[string]interface{}, reqCtx *ctx.RequestContext) *response.QueryResult[[]map[string]interface{}] {
  135. return functions.QueryParamsNameToJSON(f.db, sql, params, reqCtx)
  136. }
  137. // QueryToCSV 快捷查询,直接返回 CSV 字符串(包含表头)
  138. func (f *DBFactory) QueryToCSV(sql string, writerHeader bool, reqCtx *ctx.RequestContext) ([]byte, error) {
  139. return functions.QueryToCSV(f.db, sql, writerHeader, reqCtx)
  140. }
  141. // QueryPositionalToCSV 位置参数查询并返回 CSV 字节数据
  142. func (f *DBFactory) QueryPositionalToCSV(sql string, writerHeader bool, params []interface{}, reqCtx *ctx.RequestContext) ([]byte, error) {
  143. return functions.QueryPositionalToCSV(f.db, sql, writerHeader, params, reqCtx)
  144. }
  145. // QueryParamsNameToCSV 命名参数查询并返回 CSV 字节数据
  146. func (f *DBFactory) QueryParamsNameToCSV(sql string, writerHeader bool, params map[string]interface{}, reqCtx *ctx.RequestContext) ([]byte, error) {
  147. return functions.QueryParamsNameToCSV(f.db, sql, writerHeader, params, reqCtx)
  148. }
  149. // ExecuteDDL 快捷执行DDL语句
  150. func (f *DBFactory) ExecuteDDL(ddlSQL string) error {
  151. return functions.ExecuteDDL(f.db, ddlSQL)
  152. }
  153. // ExecuteDDLWithTx 快捷在事务中执行DDL语句
  154. func (f *DBFactory) ExecuteDDLWithTx(ddlSQL string) error {
  155. return functions.ExecuteDDLWithTx(f.db, ddlSQL)
  156. }
  157. // ExecuteMultipleDDL 快捷执行多个DDL语句
  158. func (f *DBFactory) ExecuteMultipleDDL(ddlSQLs []string) error {
  159. return functions.ExecuteMultipleDDL(f.db, ddlSQLs)
  160. }
  161. // GetDBType 得到当前使用数据库类型
  162. func (f *DBFactory) GetDBType() string {
  163. return f.config.Type
  164. }
  165. // GetDatabaseName 获取数据库名称
  166. func (f *DBFactory) GetDatabaseName() string {
  167. return f.config.Database
  168. }
  169. // GetHost 获取数据库主机
  170. func (f *DBFactory) GetHost() string {
  171. return f.config.Host
  172. }
  173. // GetPort 获取数据库端口
  174. func (f *DBFactory) GetPort() int {
  175. return f.config.Port
  176. }
  177. // BeginTx 开始事务
  178. func (f *DBFactory) BeginTx() (*sqlx.Tx, error) {
  179. return f.db.Beginx()
  180. }
  181. // GetStats 获取数据库连接统计信息
  182. func (f *DBFactory) GetStats() interface{} {
  183. return f.db.Stats()
  184. }
  185. // Ping 测试数据库连接是否正常
  186. func (f *DBFactory) Ping() error {
  187. return f.db.Ping()
  188. }
  189. // GetAvailableDrivers 获取可用的数据库驱动
  190. func (f *DBFactory) GetAvailableDrivers() []string {
  191. return drivers.GetAllDrivers()
  192. }
  193. // ========== 新增的简化操作方法 ==========
  194. // QueryOne 查询单条记录
  195. func (f *DBFactory) QueryOne(sql string, dest interface{}) error {
  196. return f.db.Get(dest, sql)
  197. }
  198. // QueryOneWithParams 带参数查询单条记录
  199. func (f *DBFactory) QueryOneWithParams(sql string, dest interface{}, params ...interface{}) error {
  200. return f.db.Get(dest, sql, params...)
  201. }
  202. // QueryMany 查询多条记录
  203. func (f *DBFactory) QueryMany(sql string, dest interface{}) error {
  204. return f.db.Select(dest, sql)
  205. }
  206. // QueryManyWithParams 带参数查询多条记录
  207. func (f *DBFactory) QueryManyWithParams(sql string, dest interface{}, params ...interface{}) error {
  208. return f.db.Select(dest, sql, params...)
  209. }
  210. // Execute 执行更新操作
  211. func (f *DBFactory) Execute(sql string) (int64, error) {
  212. result, err := f.db.Exec(sql)
  213. if err != nil {
  214. return 0, err
  215. }
  216. return result.RowsAffected()
  217. }
  218. // ExecuteWithParams 带参数执行更新操作
  219. func (f *DBFactory) ExecuteWithParams(sql string, params ...interface{}) (int64, error) {
  220. result, err := f.db.Exec(sql, params...)
  221. if err != nil {
  222. return 0, err
  223. }
  224. return result.RowsAffected()
  225. }
  226. // QueryMap 查询单条记录到map
  227. func (f *DBFactory) QueryMap(sql string) (map[string]interface{}, error) {
  228. result := make(map[string]interface{})
  229. err := f.db.QueryRowx(sql).MapScan(result)
  230. return result, err
  231. }
  232. // QueryMapWithParams 带参数查询单条记录到map
  233. func (f *DBFactory) QueryMapWithParams(sql string, params ...interface{}) (map[string]interface{}, error) {
  234. result := make(map[string]interface{})
  235. err := f.db.QueryRowx(sql, params...).MapScan(result)
  236. return result, err
  237. }
  238. // QuerySliceMap 查询多条记录到map切片
  239. func (f *DBFactory) QuerySliceMap(sql string) ([]map[string]interface{}, error) {
  240. rows, err := f.db.Queryx(sql)
  241. if err != nil {
  242. return nil, err
  243. }
  244. defer rows.Close()
  245. var results []map[string]interface{}
  246. for rows.Next() {
  247. result := make(map[string]interface{})
  248. if err := rows.MapScan(result); err != nil {
  249. return nil, err
  250. }
  251. results = append(results, result)
  252. }
  253. return results, nil
  254. }
  255. // QuerySliceMapWithParams 带参数查询多条记录到map切片
  256. func (f *DBFactory) QuerySliceMapWithParams(sql string, params ...interface{}) ([]map[string]interface{}, error) {
  257. rows, err := f.db.Queryx(sql, params...)
  258. if err != nil {
  259. return nil, err
  260. }
  261. defer rows.Close()
  262. var results []map[string]interface{}
  263. for rows.Next() {
  264. result := make(map[string]interface{})
  265. if err := rows.MapScan(result); err != nil {
  266. return nil, err
  267. }
  268. results = append(results, result)
  269. }
  270. return results, nil
  271. }