package functions

import (
	"bytes"
	"database/sql"
	"encoding/csv"
	"fmt"
	"strings"

	"git.x2erp.com/qdy/go-base/ctx"
	"git.x2erp.com/qdy/go-base/logger"
	"github.com/jmoiron/sqlx"
	"go.uber.org/zap"
)

// QueryToCSV runs a parameterless query and returns the result set encoded as
// CSV bytes.
//
// When writerHeader is true the first CSV record contains the column names.
// An empty (or whitespace-only) sql is rejected before touching the database.
// All errors are created through logger.ErrorCf.
func QueryToCSV(db *sqlx.DB, sql string, writerHeader bool, reqCtx *ctx.RequestContext) ([]byte, error) {
	logger.DebugC(reqCtx, "Executing QueryToCSV",
		zap.String("sql", sql),
		zap.Bool("writerHeader", writerHeader))

	if strings.TrimSpace(sql) == "" {
		return nil, logger.ErrorCf(reqCtx, "SQL query cannot be empty")
	}
	return execQueryToCSV(db, sql, nil, writerHeader, reqCtx)
}

// QueryPositionalToCSV runs a query with positional bind arguments and returns
// the result set encoded as CSV bytes.
//
// params are passed to the driver in order, so sql must use the driver's
// positional placeholder syntax. When writerHeader is true the first CSV
// record contains the column names. All errors are created through
// logger.ErrorCf.
func QueryPositionalToCSV(db *sqlx.DB, sql string, writerHeader bool, params []any, reqCtx *ctx.RequestContext) ([]byte, error) {
	// Log the parameter count only; parameter values may be sensitive.
	logger.DebugC(reqCtx, "Executing QueryPositionalToCSV",
		zap.String("sql", sql),
		zap.Bool("writerHeader", writerHeader),
		zap.Int("paramCount", len(params)))

	if strings.TrimSpace(sql) == "" {
		return nil, logger.ErrorCf(reqCtx, "SQL query cannot be empty")
	}
	return execQueryToCSV(db, sql, params, writerHeader, reqCtx)
}

// QueryParamsNameToCSV runs a query that uses named parameters (":name"
// placeholders, resolved by sqlx.Named) and returns the result set encoded as
// CSV bytes.
//
// params maps each placeholder name to its value. The expanded query is
// rebound to the bind-variable style of db before execution. When
// writerHeader is true the first CSV record contains the column names. All
// errors are created through logger.ErrorCf.
func QueryParamsNameToCSV(db *sqlx.DB, sql string, writerHeader bool, params map[string]any, reqCtx *ctx.RequestContext) ([]byte, error) {
	logger.DebugC(reqCtx, "Executing QueryParamsNameToCSV",
		zap.String("sql", sql),
		zap.Bool("writerHeader", writerHeader))

	if strings.TrimSpace(sql) == "" {
		return nil, logger.ErrorCf(reqCtx, "SQL query cannot be empty")
	}

	query, args, err := sqlx.Named(sql, params)
	if err != nil {
		return nil, logger.ErrorCf(reqCtx, "failed to bind named parameters: %v", err)
	}
	return execQueryToCSV(db, db.Rebind(query), args, writerHeader, reqCtx)
}

// execQueryToCSV executes query with positional args and encodes the result
// set as CSV. Callers are responsible for validating that query is non-empty
// and already uses db's bind-variable syntax.
func execQueryToCSV(db *sqlx.DB, query string, args []any, writerHeader bool, reqCtx *ctx.RequestContext) ([]byte, error) {
	rows, err := db.Query(query, args...)
	if err != nil {
		return nil, logger.ErrorCf(reqCtx, "query execution failed: %v", err)
	}
	return rowsToCSV(rows, writerHeader, reqCtx)
}

// rowsToCSV drains rows and encodes every row as one CSV record.
//
// rows is always closed before returning. When writerHeader is true the column
// names are written as the first record. Each column value is rendered by
// csvFieldFromValue.
func rowsToCSV(rows *sql.Rows, writerHeader bool, reqCtx *ctx.RequestContext) ([]byte, error) {
	defer rows.Close()

	columns, err := rows.Columns()
	if err != nil {
		return nil, logger.ErrorCf(reqCtx, "failed to get columns: %v", err)
	}

	var buf bytes.Buffer
	writer := csv.NewWriter(&buf)

	if writerHeader {
		if err := writer.Write(columns); err != nil {
			return nil, logger.ErrorCf(reqCtx, "failed to write CSV header: %v", err)
		}
	}

	// The scan destinations and the output record are reused for every row:
	// Scan overwrites each destination (including resetting it to nil for
	// NULL), and csv.Writer.Write copies the fields into its own buffer
	// before returning.
	values := make([]any, len(columns))
	valuePtrs := make([]any, len(columns))
	for i := range values {
		valuePtrs[i] = &values[i]
	}
	record := make([]string, len(columns))

	for rows.Next() {
		if err := rows.Scan(valuePtrs...); err != nil {
			return nil, logger.ErrorCf(reqCtx, "failed to scan row: %v", err)
		}
		for i, val := range values {
			record[i] = csvFieldFromValue(val)
		}
		if err := writer.Write(record); err != nil {
			return nil, logger.ErrorCf(reqCtx, "failed to write CSV row: %v", err)
		}
	}
	// An iteration error means the result set is incomplete; report it before
	// producing any output.
	if err := rows.Err(); err != nil {
		return nil, logger.ErrorCf(reqCtx, "row iteration error: %v", err)
	}

	writer.Flush()
	if err := writer.Error(); err != nil {
		return nil, logger.ErrorCf(reqCtx, "failed to flush CSV: %v", err)
	}
	// buf is local, so its bytes can be handed out without another copy.
	return buf.Bytes(), nil
}

// csvFieldFromValue converts a value scanned into an `any` destination into
// the text of a CSV field.
//
// NULL becomes an empty field and []byte is written as its raw contents.
// Everything else uses fmt's default (%v) formatting.
func csvFieldFromValue(val any) string {
	switch v := val.(type) {
	case nil:
		return ""
	case []byte:
		// Many drivers return text columns as []byte; %v would render them
		// as a list of byte values such as "[104 105]".
		return string(v)
	case string:
		return v
	default:
		return fmt.Sprint(v)
	}
}