설명 없음
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

db_factory.go 9.0KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336
  1. package database
  2. import (
  3. "fmt"
  4. "log"
  5. "sync"
  6. "git.x2erp.com/qdy/go-base/config"
  7. "git.x2erp.com/qdy/go-base/config/subconfigs"
  8. "git.x2erp.com/qdy/go-base/ctx"
  9. "git.x2erp.com/qdy/go-base/logger"
  10. "git.x2erp.com/qdy/go-base/model/response"
  11. "git.x2erp.com/qdy/go-db/drivers"
  12. "git.x2erp.com/qdy/go-db/functions"
  13. "github.com/jmoiron/sqlx"
  14. )
  15. // DBFactory 数据库工厂(全局单例模式)
  16. type DBFactory struct {
  17. db *sqlx.DB
  18. config *subconfigs.DatabaseConfig
  19. }
  20. var (
  21. instanceDB *DBFactory
  22. instanceDBOnce sync.Once
  23. initErrDB error
  24. )
  25. // CreateDBFactory 获取数据库工厂单例
  26. func CreateDBFactory(cfg config.IConfig) *DBFactory {
  27. config := cfg.GetDatabaseConfig()
  28. instanceDBOnce.Do(func() {
  29. instanceDB, initErrDB = createDBFactoryNew(config)
  30. })
  31. if initErrDB != nil {
  32. log.Fatalf("DBFactory is error: '%v'", initErrDB)
  33. }
  34. return instanceDB
  35. }
  36. // createDBFactoryNew 获取数据库工厂单例
  37. func createDBFactoryNew(config *subconfigs.DatabaseConfig) (*DBFactory, error) {
  38. if config == nil {
  39. log.Fatal("配置未初始化,请先在yaml进行配置")
  40. }
  41. // 设置默认值
  42. if config.MaxOpenConns == 0 {
  43. config.MaxOpenConns = 100
  44. }
  45. if config.MaxIdleConns == 0 {
  46. config.MaxIdleConns = 10
  47. }
  48. if config.ConnMaxLifetime == 0 {
  49. config.ConnMaxLifetime = 5 * 60 // 5分钟,单位秒
  50. }
  51. // 验证配置
  52. if config.Type == "" {
  53. initErrDB = fmt.Errorf("database type must be configured")
  54. return nil, initErrDB
  55. }
  56. if config.Host == "" {
  57. initErrDB = fmt.Errorf("database host must be configured")
  58. return nil, initErrDB
  59. }
  60. if config.Database == "" {
  61. initErrDB = fmt.Errorf("database name must be configured")
  62. return nil, initErrDB
  63. }
  64. log.Printf("Creating database connection...")
  65. // 获取对应的驱动
  66. dbDriver, err := drivers.Get(config.Type)
  67. if err != nil {
  68. initErrDB = fmt.Errorf("failed to get database driver: %v", err)
  69. return nil, initErrDB
  70. }
  71. // 将内部 DBConfig 转换为 drivers.DBConfig
  72. driverConfig := drivers.DBConfig{
  73. Type: config.Type,
  74. Host: config.Host,
  75. Port: config.Port,
  76. Username: config.Username,
  77. Password: config.Password,
  78. Database: config.Database,
  79. MaxOpenConns: config.MaxOpenConns,
  80. MaxIdleConns: config.MaxIdleConns,
  81. ConnMaxLifetime: config.ConnMaxLifetime,
  82. }
  83. // 创建数据库连接
  84. db, err := dbDriver.Open(driverConfig)
  85. if err != nil {
  86. initErrDB = fmt.Errorf("failed to open database connection: %v", err)
  87. return nil, initErrDB
  88. }
  89. // 测试连接
  90. if err := functions.TestConnection(db, config.Type); err != nil {
  91. db.Close()
  92. initErrDB = fmt.Errorf("database connection test failed: %v", err)
  93. return nil, initErrDB
  94. }
  95. log.Printf("DBFactory is successfully created.\n")
  96. instanceDB = &DBFactory{
  97. db: db,
  98. config: config,
  99. }
  100. return instanceDB, initErrDB
  101. }
  102. // ========== DBFactory 实例方法 ==========
  103. // GetDB 获取数据库连接
  104. func (f *DBFactory) GetDB() *sqlx.DB {
  105. return f.db
  106. }
  107. func (f *DBFactory) GetName() string {
  108. return "DBFactory"
  109. }
  110. // Close 关闭数据库连接
  111. func (f *DBFactory) Close() {
  112. if f.db != nil {
  113. err := f.db.Close()
  114. if err != nil {
  115. logger.Errorf("failed to close database connection: %v", err)
  116. }
  117. log.Printf("Database connection closed gracefully\n")
  118. f.db = nil
  119. }
  120. }
  121. // GetConfig 获取配置信息
  122. func (f *DBFactory) GetConfig() subconfigs.DatabaseConfig {
  123. return *f.config
  124. }
  125. // TestConnection 测试连接
  126. func (f *DBFactory) TestConnection() error {
  127. return functions.TestConnection(f.db, f.config.Type)
  128. }
  129. // ========== 快捷操作方法 ==========
  130. // QueryToJSON 快捷查询,直接返回 JSON 字节流
  131. func (f *DBFactory) QueryToJSON(sql string, reqCtx *ctx.RequestContext) *response.QueryResult[[]map[string]interface{}] {
  132. return functions.QueryToJSON(f.db, sql, reqCtx)
  133. }
  134. // QueryPositionalToJSON 位置参数查询并返回 JSON 字节数据
  135. func (f *DBFactory) QueryPositionalToJSON(sql string, params []interface{}, reqCtx *ctx.RequestContext) *response.QueryResult[[]map[string]interface{}] {
  136. return functions.QueryPositionalToJSON(f.db, sql, params, reqCtx)
  137. }
  138. // QueryParamsNameToJSON 命名参数查询并返回 JSON 字节数据
  139. func (f *DBFactory) QueryParamsNameToJSON(sql string, params map[string]interface{}, reqCtx *ctx.RequestContext) *response.QueryResult[[]map[string]interface{}] {
  140. return functions.QueryParamsNameToJSON(f.db, sql, params, reqCtx)
  141. }
  142. // QueryToCSV 快捷查询,直接返回 CSV 字符串(包含表头)
  143. func (f *DBFactory) QueryToCSV(sql string, writerHeader bool, reqCtx *ctx.RequestContext) ([]byte, error) {
  144. return functions.QueryToCSV(f.db, sql, writerHeader, reqCtx)
  145. }
  146. // QueryPositionalToCSV 位置参数查询并返回 CSV 字节数据
  147. func (f *DBFactory) QueryPositionalToCSV(sql string, writerHeader bool, params []interface{}, reqCtx *ctx.RequestContext) ([]byte, error) {
  148. return functions.QueryPositionalToCSV(f.db, sql, writerHeader, params, reqCtx)
  149. }
  150. // QueryParamsNameToCSV 命名参数查询并返回 CSV 字节数据
  151. func (f *DBFactory) QueryParamsNameToCSV(sql string, writerHeader bool, params map[string]interface{}, reqCtx *ctx.RequestContext) ([]byte, error) {
  152. return functions.QueryParamsNameToCSV(f.db, sql, writerHeader, params, reqCtx)
  153. }
  154. // ExecuteDDL 快捷执行DDL语句
  155. func (f *DBFactory) ExecuteDDL(ddlSQL string) error {
  156. return functions.ExecuteDDL(f.db, ddlSQL)
  157. }
  158. // ExecuteDDLWithTx 快捷在事务中执行DDL语句
  159. func (f *DBFactory) ExecuteDDLWithTx(ddlSQL string) error {
  160. return functions.ExecuteDDLWithTx(f.db, ddlSQL)
  161. }
  162. // ExecuteMultipleDDL 快捷执行多个DDL语句
  163. func (f *DBFactory) ExecuteMultipleDDL(ddlSQLs []string) error {
  164. return functions.ExecuteMultipleDDL(f.db, ddlSQLs)
  165. }
  166. // GetDBType 得到当前使用数据库类型
  167. func (f *DBFactory) GetDBType() string {
  168. return f.config.Type
  169. }
  170. // GetDatabaseName 获取数据库名称
  171. func (f *DBFactory) GetDatabaseName() string {
  172. return f.config.Database
  173. }
  174. // GetHost 获取数据库主机
  175. func (f *DBFactory) GetHost() string {
  176. return f.config.Host
  177. }
  178. // GetPort 获取数据库端口
  179. func (f *DBFactory) GetPort() int {
  180. return f.config.Port
  181. }
  182. // BeginTx 开始事务
  183. func (f *DBFactory) BeginTx() (*sqlx.Tx, error) {
  184. return f.db.Beginx()
  185. }
  186. // GetStats 获取数据库连接统计信息
  187. func (f *DBFactory) GetStats() interface{} {
  188. return f.db.Stats()
  189. }
  190. // Ping 测试数据库连接是否正常
  191. func (f *DBFactory) Ping() error {
  192. return f.db.Ping()
  193. }
  194. // GetAvailableDrivers 获取可用的数据库驱动
  195. func (f *DBFactory) GetAvailableDrivers() []string {
  196. return drivers.GetAllDrivers()
  197. }
  198. // ========== 新增的简化操作方法 ==========
  199. // QueryOne 查询单条记录
  200. func (f *DBFactory) QueryOne(sql string, dest interface{}) error {
  201. return f.db.Get(dest, sql)
  202. }
  203. // QueryOneWithParams 带参数查询单条记录
  204. func (f *DBFactory) QueryOneWithParams(sql string, dest interface{}, params ...interface{}) error {
  205. return f.db.Get(dest, sql, params...)
  206. }
  207. // QueryMany 查询多条记录
  208. func (f *DBFactory) QueryMany(sql string, dest interface{}) error {
  209. return f.db.Select(dest, sql)
  210. }
  211. // QueryManyWithParams 带参数查询多条记录
  212. func (f *DBFactory) QueryManyWithParams(sql string, dest interface{}, params ...interface{}) error {
  213. return f.db.Select(dest, sql, params...)
  214. }
  215. // Execute 执行更新操作
  216. func (f *DBFactory) Execute(sql string) (int64, error) {
  217. result, err := f.db.Exec(sql)
  218. if err != nil {
  219. return 0, err
  220. }
  221. return result.RowsAffected()
  222. }
  223. // ExecuteWithParams 带参数执行更新操作
  224. func (f *DBFactory) ExecuteWithParams(sql string, params ...interface{}) (int64, error) {
  225. result, err := f.db.Exec(sql, params...)
  226. if err != nil {
  227. return 0, err
  228. }
  229. return result.RowsAffected()
  230. }
  231. // QueryMap 查询单条记录到map
  232. func (f *DBFactory) QueryMap(sql string) (map[string]interface{}, error) {
  233. result := make(map[string]interface{})
  234. err := f.db.QueryRowx(sql).MapScan(result)
  235. return result, err
  236. }
  237. // QueryMapWithParams 带参数查询单条记录到map
  238. func (f *DBFactory) QueryMapWithParams(sql string, params ...interface{}) (map[string]interface{}, error) {
  239. result := make(map[string]interface{})
  240. err := f.db.QueryRowx(sql, params...).MapScan(result)
  241. return result, err
  242. }
  243. // QuerySliceMap 查询多条记录到map切片
  244. func (f *DBFactory) QuerySliceMap(sql string) ([]map[string]interface{}, error) {
  245. rows, err := f.db.Queryx(sql)
  246. if err != nil {
  247. return nil, err
  248. }
  249. defer rows.Close()
  250. var results []map[string]interface{}
  251. for rows.Next() {
  252. result := make(map[string]interface{})
  253. if err := rows.MapScan(result); err != nil {
  254. return nil, err
  255. }
  256. results = append(results, result)
  257. }
  258. return results, nil
  259. }
  260. // QuerySliceMapWithParams 带参数查询多条记录到map切片
  261. func (f *DBFactory) QuerySliceMapWithParams(sql string, params ...interface{}) ([]map[string]interface{}, error) {
  262. rows, err := f.db.Queryx(sql, params...)
  263. if err != nil {
  264. return nil, err
  265. }
  266. defer rows.Close()
  267. var results []map[string]interface{}
  268. for rows.Next() {
  269. result := make(map[string]interface{})
  270. if err := rows.MapScan(result); err != nil {
  271. return nil, err
  272. }
  273. results = append(results, result)
  274. }
  275. return results, nil
  276. }