| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114 |
- package functions
-
- import (
- "database/sql"
- "encoding/csv"
- "fmt"
- "strings"
-
- "github.com/jmoiron/sqlx"
- )
-
- // QueryToCSV 无参数查询并返回 CSV 字节数据
- func QueryToCSV(db *sqlx.DB, sql string) ([]byte, error) {
- if sql == "" {
- return nil, fmt.Errorf("SQL query cannot be empty")
- }
-
- rows, err := db.Query(sql)
- if err != nil {
- return nil, fmt.Errorf("query execution failed: %v", err)
- }
-
- return rowsToCSV(rows)
- }
-
- // QueryParamsToCSV 位置参数查询并返回 CSV 字节数据
- func QueryPositionalToCSV(db *sqlx.DB, sql string, positionalParams ...interface{}) ([]byte, error) {
- if sql == "" {
- return nil, fmt.Errorf("SQL query cannot be empty")
- }
-
- rows, err := db.Query(sql, positionalParams)
- if err != nil {
- return nil, fmt.Errorf("query execution failed: %v", err)
- }
-
- return rowsToCSV(rows)
- }
-
- // QueryParamsNameToCSV 命名参数查询并返回 CSV 字节数据
- // params 可以是 map[string]interface{} 或结构体
- func QueryParamsNameToCSV(db *sqlx.DB, sql string, params map[string]interface{}) ([]byte, error) {
- if sql == "" {
- return nil, fmt.Errorf("SQL query cannot be empty")
- }
-
- query, args, err := sqlx.Named(sql, params)
- if err != nil {
- return nil, fmt.Errorf("failed to bind named parameters: %v", err)
- }
-
- query = db.Rebind(query)
- rows, err := db.Query(query, args...)
- if err != nil {
- return nil, fmt.Errorf("query execution failed: %v", err)
- }
-
- return rowsToCSV(rows)
- }
-
// rowsToCSV drains rows into an in-memory CSV document and always closes rows.
//
// The first record holds the column names. Each following record is one row:
//   - SQL NULL becomes an empty field;
//   - []byte and string values are written as text;
//   - any other driver value is formatted with %v.
//
// The returned slice is nil whenever the error is non-nil. Errors are wrapped
// with %w.
func rowsToCSV(rows *sql.Rows) ([]byte, error) {
	defer rows.Close()

	columns, err := rows.Columns()
	if err != nil {
		return nil, fmt.Errorf("failed to get columns: %w", err)
	}

	var builder strings.Builder
	writer := csv.NewWriter(&builder)

	// Header record.
	if err := writer.Write(columns); err != nil {
		return nil, fmt.Errorf("failed to write CSV header: %w", err)
	}

	// Scan targets and the output record are allocated once and reused.
	// Scan assigns every destination on each call, and csv.Writer.Write
	// copies the record out immediately, so nothing carries over between rows.
	values := make([]any, len(columns))
	valuePtrs := make([]any, len(columns))
	for i := range values {
		valuePtrs[i] = &values[i]
	}
	record := make([]string, len(columns))

	for rows.Next() {
		if err := rows.Scan(valuePtrs...); err != nil {
			return nil, fmt.Errorf("failed to scan row: %w", err)
		}
		for i, val := range values {
			record[i] = csvField(val)
		}
		if err := writer.Write(record); err != nil {
			return nil, fmt.Errorf("failed to write CSV row: %w", err)
		}
	}
	if err := rows.Err(); err != nil {
		return nil, fmt.Errorf("row iteration error: %w", err)
	}

	writer.Flush()
	if err := writer.Error(); err != nil {
		return nil, fmt.Errorf("failed to flush CSV: %w", err)
	}

	return []byte(builder.String()), nil
}

// csvField renders one value scanned into an *any destination as a CSV field.
func csvField(val any) string {
	switch v := val.(type) {
	case nil:
		// SQL NULL.
		return ""
	case []byte:
		// Drivers may return text columns as []byte; formatting that with %v
		// would print the byte values ("[104 105]") instead of the text.
		return string(v)
	case string:
		return v
	default:
		return fmt.Sprintf("%v", v)
	}
}
|