| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174 |
- package main
-
- import (
- "database/sql"
- "fmt"
- "log"
- "net/http"
- "strings"
- "time"
-
- "git.x2erp.com/qdy/go-base/types"
- "git.x2erp.com/qdy/go-db/factory"
- "git.x2erp.com/qdy/go-service-ck/auth"
- "git.x2erp.com/qdy/go-service-ck/routes"
- "github.com/gin-gonic/gin"
- )
-
- // 单例实例
- var (
- dbFactory *factory.DBFactory
- db *sql.DB
- )
-
- // initDB 初始化数据库连接(单例)
- func initDB() error {
- var err error
-
- // 创建数据库工厂(单例)
- if dbFactory == nil {
- dbFactory, err = factory.NewDBFactory()
- if err != nil {
- return fmt.Errorf("failed to create DB factory: %v", err)
- }
- }
-
- // 创建数据库连接(单例)
- if db == nil {
- db, err = dbFactory.CreateDB()
- if err != nil {
- return fmt.Errorf("failed to create database connection: %v", err)
- }
- }
-
- // 测试连接
- config := dbFactory.GetConfig()
- if err := factory.TestConnection(db, config.GetDatabase().Type); err != nil {
- return fmt.Errorf("database connection test failed: %v", err)
- }
-
- return nil
- }
-
- func main() {
- // 1. 初始化数据库(单例)
- if err := initDB(); err != nil {
- log.Fatalf("Database initialization failed: %v", err)
- }
-
- // 2. 显示基础信息
- drivers := dbFactory.GetAvailableDrivers()
- config := dbFactory.GetConfig()
-
- log.Printf("Service Port: %d", config.GetService().Port)
- log.Printf("Service IdleTimeout: %d", config.GetService().IdleTimeout)
- log.Printf("Service ReadTimeout: %d", config.GetService().ReadTimeout)
- log.Printf("Service WriteTimeout: %d", config.GetService().WriteTimeout)
- log.Printf("Service TrustedProxies: %s", config.GetService().TrustedProxies)
-
- log.Printf("Available database drivers: %v", drivers)
- log.Printf("Using database type: %s", config.GetDatabase().Type)
- log.Printf("Database host: %s:%d", config.GetDatabase().Host, config.GetDatabase().Port)
- log.Printf("Database name: %s", config.GetDatabase().Database)
- log.Println("Database connection test passed!")
-
- // 3. 启动Gin HTTP服务
- startHTTPServer()
- }
-
- // 启动HTTP服务器
- func startHTTPServer() {
- //建立路由
- router := gin.Default()
- config := dbFactory.GetConfig()
- serviceConfig := config.GetService()
-
- // 核心路由
- router.GET("/api/health", routes.HealthHandler(db, config.GetDatabase().Type))
- router.POST("/api/init/table", auth.AuthMiddleware(), routes.ExecuteDDLHandler(db))
- //router.POST("/api/query/csv", auth.AuthMiddleware(), withQueryRequest(routes.QueryHandlerCSV(db)))
- router.GET("/api/info", routes.InfoHandler(dbFactory))
-
- // 日志输出配置信息
- log.Printf("Service Port: %d", serviceConfig.Port)
- log.Printf("Service IdleTimeout: %d", serviceConfig.IdleTimeout)
- log.Printf("Service ReadTimeout: %d", serviceConfig.ReadTimeout)
- log.Printf("Service WriteTimeout: %d", serviceConfig.WriteTimeout)
- log.Printf("Service TrustedProxies: %s", serviceConfig.TrustedProxies)
-
- // 设置可信代理
- setupTrustedProxies(router, serviceConfig.TrustedProxies)
-
- // 启动服务
- log.Println("POST /api/query - Execute SQL query to JSON")
- log.Println("POST /api/query/csv - Execute SQL query to CSV")
- log.Println("GET /api/health - Health check")
- log.Println("GET /api/info - Database info")
-
- // 创建HTTP服务器配置
- server := &http.Server{
- Addr: fmt.Sprintf(":%d", serviceConfig.Port),
- Handler: router,
- IdleTimeout: time.Duration(serviceConfig.IdleTimeout) * time.Second,
- ReadTimeout: time.Duration(serviceConfig.ReadTimeout) * time.Second,
- WriteTimeout: time.Duration(serviceConfig.WriteTimeout) * time.Second,
- }
-
- log.Printf("Starting HTTP server on port %d", serviceConfig.Port)
- if err := server.ListenAndServe(); err != nil && err != http.ErrServerClosed {
- log.Fatalf("Failed to start server: %v", err)
- }
- }
-
- // 参数绑定包装器
- func withQueryRequest(handler func(c *gin.Context, req types.QueryRequest)) gin.HandlerFunc {
- return func(c *gin.Context) {
- var req types.QueryRequest
- if err := c.ShouldBindJSON(&req); err != nil {
- c.JSON(400, &types.QueryResult{
- Success: false,
- Error: "Invalid request: " + err.Error(),
- Data: nil,
- })
- return
- }
-
- handler(c, req)
- }
- }
-
- // 设置可信代理
- func setupTrustedProxies(router *gin.Engine, trustedProxiesStr string) {
- if trustedProxiesStr == "" {
- setupTrustedProxiesRouter(router, nil)
- return
- }
-
- // 按逗号分割字符串,并去除空格
- proxies := strings.Split(trustedProxiesStr, ",")
- trimmedProxies := make([]string, 0, len(proxies))
-
- for _, proxy := range proxies {
- trimmed := strings.TrimSpace(proxy)
- if trimmed != "" {
- trimmedProxies = append(trimmedProxies, trimmed)
- }
- }
-
- if len(trimmedProxies) > 0 {
- setupTrustedProxiesRouter(router, trimmedProxies)
- } else {
- setupTrustedProxiesRouter(router, nil)
- }
- }
-
- func setupTrustedProxiesRouter(router *gin.Engine, trimmedProxies []string) {
-
- err := router.SetTrustedProxies(trimmedProxies)
- if err != nil {
- log.Printf("Warning: Failed to set trusted proxies: %v", err)
- } else {
- log.Printf("Trusted proxies set: %v", trimmedProxies)
- }
-
- }
|