| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130 |
- package function
-
- import (
- "database/sql"
- "encoding/csv"
- "fmt"
- "strings"
-
- "git.x2erp.com/qdy/go-base/ctx"
- "git.x2erp.com/qdy/go-base/logger"
- "github.com/jmoiron/sqlx"
- "go.uber.org/zap"
- )
-
- // QueryToCSV 无参数查询并返回 CSV 字节数据
- func QueryToCSV(db *sqlx.DB, sql string, writerHeader bool, reqCtx *ctx.RequestContext) ([]byte, error) {
- logger.DebugC(reqCtx, "Executing QueryToCSV",
- zap.String("sql", sql),
- zap.Bool("writerHeader", writerHeader))
-
- if sql == "" {
- return nil, logger.ErrorCf(reqCtx, "SQL query cannot be empty")
- }
-
- rows, err := db.Query(sql)
- if err != nil {
- return nil, logger.ErrorCf(reqCtx, "query execution failed: %v", err)
- }
-
- return rowsToCSV(rows, writerHeader, reqCtx)
- }
-
- // QueryParamsToCSV 位置参数查询并返回 CSV 字节数据
- func QueryPositionalToCSV(db *sqlx.DB, sql string, writerHeader bool, params []interface{}, reqCtx *ctx.RequestContext) ([]byte, error) {
-
- logger.DebugC(reqCtx, "Executing QueryToCSV: sql=%s, writerHeader=%v", sql, writerHeader)
-
- if sql == "" {
- return nil, logger.ErrorCf(reqCtx, "SQL query cannot be empty")
- }
-
- rows, err := db.Query(sql, params...)
- if err != nil {
- return nil, logger.ErrorCf(reqCtx, "query execution failed: %v", err)
- }
-
- return rowsToCSV(rows, writerHeader, reqCtx)
- }
-
- // QueryParamsNameToCSV 命名参数查询并返回 CSV 字节数据
- // params 可以是 map[string]interface{} 或结构体
- func QueryParamsNameToCSV(db *sqlx.DB, sql string, writerHeader bool, params map[string]interface{}, reqCtx *ctx.RequestContext) ([]byte, error) {
-
- logger.DebugC(reqCtx, "Executing QueryToCSV",
- zap.String("sql", sql),
- zap.Bool("writerHeader", writerHeader))
- if sql == "" {
- return nil, logger.ErrorCf(reqCtx, "SQL query cannot be empty")
- }
-
- query, args, err := sqlx.Named(sql, params)
- if err != nil {
- return nil, logger.ErrorCf(reqCtx, "query execution failed: %v", err)
- }
-
- query = db.Rebind(query)
- rows, err := db.Query(query, args...)
- if err != nil {
- return nil, logger.ErrorCf(reqCtx, "query execution failed: %v", err)
- }
-
- return rowsToCSV(rows, writerHeader, reqCtx)
- }
-
- // rowsToCSV 公共方法:将查询结果转换为 CSV 字节数据
- func rowsToCSV(rows *sql.Rows, writerHeader bool, reqCtx *ctx.RequestContext) ([]byte, error) {
- defer rows.Close()
-
- columns, err := rows.Columns()
- if err != nil {
- return nil, logger.ErrorCf(reqCtx, "failed to get columns: %v", err)
- }
-
- var builder strings.Builder
- writer := csv.NewWriter(&builder)
-
- // 根据参数决定是否写入表头
- if writerHeader {
- if err := writer.Write(columns); err != nil {
- return nil, logger.ErrorCf(reqCtx, "failed to write CSV header: %v", err)
- }
- }
-
- for rows.Next() {
- values := make([]interface{}, len(columns))
- valuePtrs := make([]any, len(columns))
- for i := range columns {
- valuePtrs[i] = &values[i]
- }
-
- if err := rows.Scan(valuePtrs...); err != nil {
- return nil, logger.ErrorCf(reqCtx, "failed to scan row: %v", err)
- }
-
- // 所有值转为字符串
- row := make([]string, len(columns))
- for i, val := range values {
- if val == nil {
- row[i] = ""
- } else {
- row[i] = fmt.Sprintf("%v", val)
- }
- }
-
- if err := writer.Write(row); err != nil {
- return nil, logger.ErrorCf(reqCtx, "failed to write CSV row: %v", err)
- }
- }
-
- writer.Flush()
- if err := writer.Error(); err != nil {
- return nil, logger.ErrorCf(reqCtx, "failed to flush CSV: %v", err)
- }
-
- if err := rows.Err(); err != nil {
- return nil, logger.ErrorCf(reqCtx, "row iteration error: %v", err)
- }
-
- return []byte(builder.String()), nil
- }
|