Нема описа
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

db_factory.go 8.8KB

(line-number gutter: lines 1–325)
  1. package database
import (
	"errors"
	"fmt"
	"log"
	"sync"

	"git.x2erp.com/qdy/go-base/config"
	"git.x2erp.com/qdy/go-base/config/subconfigs"
	"git.x2erp.com/qdy/go-base/ctx"
	"git.x2erp.com/qdy/go-base/logger"
	"git.x2erp.com/qdy/go-base/model/response"
	"git.x2erp.com/qdy/go-db/driver"
	"git.x2erp.com/qdy/go-db/function"
	"github.com/jmoiron/sqlx"
)
  15. // DBFactory 数据库工厂(全局单例模式)
  16. type DBFactory struct {
  17. db *sqlx.DB
  18. config *subconfigs.DatabaseConfig
  19. }
  20. var (
  21. instanceDB *DBFactory
  22. instanceDBOnce sync.Once
  23. initErrDB error
  24. )
  25. // CreateDBFactory 获取数据库工厂单例
  26. func CreateDBFactory(cfg config.IConfig) *DBFactory {
  27. config := cfg.GetDatabaseConfig()
  28. instanceDBOnce.Do(func() {
  29. instanceDB, initErrDB = createDBFactoryNew(config)
  30. })
  31. if initErrDB != nil {
  32. log.Fatalf("DBFactory is error: '%v'", initErrDB)
  33. }
  34. return instanceDB
  35. }
  36. // createDBFactoryNew 获取数据库工厂单例
  37. func createDBFactoryNew(config *subconfigs.DatabaseConfig) (*DBFactory, error) {
  38. if config == nil {
  39. log.Fatal("配置未初始化,请先在yaml进行配置")
  40. }
  41. if config.Type == "" {
  42. initErrDB = fmt.Errorf("database type must be configured")
  43. return nil, initErrDB
  44. }
  45. if config.Host == "" {
  46. initErrDB = fmt.Errorf("database host must be configured")
  47. return nil, initErrDB
  48. }
  49. if config.Database == "" {
  50. initErrDB = fmt.Errorf("database name must be configured")
  51. return nil, initErrDB
  52. }
  53. log.Printf("Creating database connection...")
  54. // 获取对应的驱动
  55. dbDriver, err := driver.Get(config.Type)
  56. if err != nil {
  57. initErrDB = fmt.Errorf("failed to get database driver: %v", err)
  58. return nil, initErrDB
  59. }
  60. // 将内部 DBConfig 转换为 drivers.DBConfig
  61. driverConfig := driver.DBConfig{
  62. Type: config.Type,
  63. Host: config.Host,
  64. Port: config.Port,
  65. Username: config.Username,
  66. Password: config.Password,
  67. Database: config.Database,
  68. MaxOpenConns: config.MaxOpenConns,
  69. MaxIdleConns: config.MaxIdleConns,
  70. ConnMaxLifetime: config.ConnMaxLifetime,
  71. }
  72. // 创建数据库连接
  73. db, err := dbDriver.Open(driverConfig)
  74. if err != nil {
  75. initErrDB = fmt.Errorf("failed to open database connection: %v", err)
  76. return nil, initErrDB
  77. }
  78. // 测试连接
  79. if err := function.TestConnection(db, config.Type); err != nil {
  80. db.Close()
  81. initErrDB = fmt.Errorf("database connection test failed: %v", err)
  82. return nil, initErrDB
  83. }
  84. log.Printf("DBFactory is successfully created.\n")
  85. instanceDB = &DBFactory{
  86. db: db,
  87. config: config,
  88. }
  89. return instanceDB, initErrDB
  90. }
  91. // ========== DBFactory 实例方法 ==========
  92. // GetDB 获取数据库连接
  93. func (f *DBFactory) GetDB() *sqlx.DB {
  94. return f.db
  95. }
  96. func (f *DBFactory) GetName() string {
  97. return "DBFactory"
  98. }
  99. // Close 关闭数据库连接
  100. func (f *DBFactory) Close() {
  101. if f.db != nil {
  102. err := f.db.Close()
  103. if err != nil {
  104. logger.Errorf("failed to close database connection: %v", err)
  105. }
  106. log.Printf("Database connection closed gracefully\n")
  107. f.db = nil
  108. }
  109. }
  110. // GetConfig 获取配置信息
  111. func (f *DBFactory) GetConfig() subconfigs.DatabaseConfig {
  112. return *f.config
  113. }
  114. // TestConnection 测试连接
  115. func (f *DBFactory) TestConnection() error {
  116. return function.TestConnection(f.db, f.config.Type)
  117. }
  118. // ========== 快捷操作方法 ==========
  119. // QueryToJSON 快捷查询,直接返回 JSON 字节流
  120. func (f *DBFactory) QueryToJSON(sql string, reqCtx *ctx.RequestContext) *response.QueryResult[[]map[string]interface{}] {
  121. return function.QueryToJSON(f.db, sql, reqCtx)
  122. }
  123. // QueryPositionalToJSON 位置参数查询并返回 JSON 字节数据
  124. func (f *DBFactory) QueryPositionalToJSON(sql string, params []interface{}, reqCtx *ctx.RequestContext) *response.QueryResult[[]map[string]interface{}] {
  125. return function.QueryPositionalToJSON(f.db, sql, params, reqCtx)
  126. }
  127. // QueryParamsNameToJSON 命名参数查询并返回 JSON 字节数据
  128. func (f *DBFactory) QueryParamsNameToJSON(sql string, params map[string]interface{}, reqCtx *ctx.RequestContext) *response.QueryResult[[]map[string]interface{}] {
  129. return function.QueryParamsNameToJSON(f.db, sql, params, reqCtx)
  130. }
  131. // QueryToCSV 快捷查询,直接返回 CSV 字符串(包含表头)
  132. func (f *DBFactory) QueryToCSV(sql string, writerHeader bool, reqCtx *ctx.RequestContext) ([]byte, error) {
  133. return function.QueryToCSV(f.db, sql, writerHeader, reqCtx)
  134. }
  135. // QueryPositionalToCSV 位置参数查询并返回 CSV 字节数据
  136. func (f *DBFactory) QueryPositionalToCSV(sql string, writerHeader bool, params []interface{}, reqCtx *ctx.RequestContext) ([]byte, error) {
  137. return function.QueryPositionalToCSV(f.db, sql, writerHeader, params, reqCtx)
  138. }
  139. // QueryParamsNameToCSV 命名参数查询并返回 CSV 字节数据
  140. func (f *DBFactory) QueryParamsNameToCSV(sql string, writerHeader bool, params map[string]interface{}, reqCtx *ctx.RequestContext) ([]byte, error) {
  141. return function.QueryParamsNameToCSV(f.db, sql, writerHeader, params, reqCtx)
  142. }
  143. // ExecuteDDL 快捷执行DDL语句
  144. func (f *DBFactory) ExecuteDDL(ddlSQL string) error {
  145. return function.ExecuteDDL(f.db, ddlSQL)
  146. }
  147. // ExecuteDDLWithTx 快捷在事务中执行DDL语句
  148. func (f *DBFactory) ExecuteDDLWithTx(ddlSQL string) error {
  149. return function.ExecuteDDLWithTx(f.db, ddlSQL)
  150. }
  151. // ExecuteMultipleDDL 快捷执行多个DDL语句
  152. func (f *DBFactory) ExecuteMultipleDDL(ddlSQLs []string) error {
  153. return function.ExecuteMultipleDDL(f.db, ddlSQLs)
  154. }
  155. // GetDBType 得到当前使用数据库类型
  156. func (f *DBFactory) GetDBType() string {
  157. return f.config.Type
  158. }
  159. // GetDatabaseName 获取数据库名称
  160. func (f *DBFactory) GetDatabaseName() string {
  161. return f.config.Database
  162. }
  163. // GetHost 获取数据库主机
  164. func (f *DBFactory) GetHost() string {
  165. return f.config.Host
  166. }
  167. // GetPort 获取数据库端口
  168. func (f *DBFactory) GetPort() int {
  169. return f.config.Port
  170. }
  171. // BeginTx 开始事务
  172. func (f *DBFactory) BeginTx() (*sqlx.Tx, error) {
  173. return f.db.Beginx()
  174. }
  175. // GetStats 获取数据库连接统计信息
  176. func (f *DBFactory) GetStats() interface{} {
  177. return f.db.Stats()
  178. }
  179. // Ping 测试数据库连接是否正常
  180. func (f *DBFactory) Ping() error {
  181. return f.db.Ping()
  182. }
  183. // GetAvailableDrivers 获取可用的数据库驱动
  184. func (f *DBFactory) GetAvailableDrivers() []string {
  185. return driver.GetAllDrivers()
  186. }
  187. // ========== 新增的简化操作方法 ==========
  188. // QueryOne 查询单条记录
  189. func (f *DBFactory) QueryOne(sql string, dest interface{}) error {
  190. return f.db.Get(dest, sql)
  191. }
  192. // QueryOneWithParams 带参数查询单条记录
  193. func (f *DBFactory) QueryOneWithParams(sql string, dest interface{}, params ...interface{}) error {
  194. return f.db.Get(dest, sql, params...)
  195. }
  196. // QueryMany 查询多条记录
  197. func (f *DBFactory) QueryMany(sql string, dest interface{}) error {
  198. return f.db.Select(dest, sql)
  199. }
  200. // QueryManyWithParams 带参数查询多条记录
  201. func (f *DBFactory) QueryManyWithParams(sql string, dest interface{}, params ...interface{}) error {
  202. return f.db.Select(dest, sql, params...)
  203. }
  204. // Execute 执行更新操作
  205. func (f *DBFactory) Execute(sql string) (int64, error) {
  206. result, err := f.db.Exec(sql)
  207. if err != nil {
  208. return 0, err
  209. }
  210. return result.RowsAffected()
  211. }
  212. // ExecuteWithParams 带参数执行更新操作
  213. func (f *DBFactory) ExecuteWithParams(sql string, params ...interface{}) (int64, error) {
  214. result, err := f.db.Exec(sql, params...)
  215. if err != nil {
  216. return 0, err
  217. }
  218. return result.RowsAffected()
  219. }
  220. // QueryMap 查询单条记录到map
  221. func (f *DBFactory) QueryMap(sql string) (map[string]interface{}, error) {
  222. result := make(map[string]interface{})
  223. err := f.db.QueryRowx(sql).MapScan(result)
  224. return result, err
  225. }
  226. // QueryMapWithParams 带参数查询单条记录到map
  227. func (f *DBFactory) QueryMapWithParams(sql string, params ...interface{}) (map[string]interface{}, error) {
  228. result := make(map[string]interface{})
  229. err := f.db.QueryRowx(sql, params...).MapScan(result)
  230. return result, err
  231. }
  232. // QuerySliceMap 查询多条记录到map切片
  233. func (f *DBFactory) QuerySliceMap(sql string) ([]map[string]interface{}, error) {
  234. rows, err := f.db.Queryx(sql)
  235. if err != nil {
  236. return nil, err
  237. }
  238. defer rows.Close()
  239. var results []map[string]interface{}
  240. for rows.Next() {
  241. result := make(map[string]interface{})
  242. if err := rows.MapScan(result); err != nil {
  243. return nil, err
  244. }
  245. results = append(results, result)
  246. }
  247. return results, nil
  248. }
  249. // QuerySliceMapWithParams 带参数查询多条记录到map切片
  250. func (f *DBFactory) QuerySliceMapWithParams(sql string, params ...interface{}) ([]map[string]interface{}, error) {
  251. rows, err := f.db.Queryx(sql, params...)
  252. if err != nil {
  253. return nil, err
  254. }
  255. defer rows.Close()
  256. var results []map[string]interface{}
  257. for rows.Next() {
  258. result := make(map[string]interface{})
  259. if err := rows.MapScan(result); err != nil {
  260. return nil, err
  261. }
  262. results = append(results, result)
  263. }
  264. return results, nil
  265. }