package functions

import (
	"bytes"
	"database/sql"
	"encoding/csv"
	"fmt"
	"strings"

	"github.com/jmoiron/sqlx"
)

// QueryToCSV runs a query that takes no parameters and returns the result set
// encoded as CSV: a header row of column names followed by one record per
// result row. SQL NULL is written as an empty field.
//
// An error is returned if db is nil, if sql is empty or whitespace-only, or if
// executing the query, scanning rows or encoding CSV fails. Underlying errors
// are wrapped with %w so callers can inspect them with errors.Is / errors.As.
func QueryToCSV(db *sqlx.DB, sql string) ([]byte, error) {
	if err := validateCSVQuery(db, sql); err != nil {
		return nil, err
	}
	return execQueryToCSV(db, sql)
}

// QueryPositionalToCSV runs a query with positional bind parameters and returns
// the result set encoded as CSV (same layout and error behavior as QueryToCSV).
//
// The query text is passed to the driver unchanged (no rebinding), so it must
// use the driver's own placeholder syntax; params are bound in order.
func QueryPositionalToCSV(db *sqlx.DB, sql string, params ...interface{}) ([]byte, error) {
	if err := validateCSVQuery(db, sql); err != nil {
		return nil, err
	}
	return execQueryToCSV(db, sql, params...)
}

// QueryParamsNameToCSV runs a query with named parameters and returns the
// result set encoded as CSV (same layout and error behavior as QueryToCSV).
//
// Named placeholders are resolved against the keys of params by sqlx.Named,
// and the resulting query is rebound to the placeholder style of db's driver
// via db.Rebind before execution.
func QueryParamsNameToCSV(db *sqlx.DB, sql string, params map[string]interface{}) ([]byte, error) {
	if err := validateCSVQuery(db, sql); err != nil {
		return nil, err
	}
	query, args, err := sqlx.Named(sql, params)
	if err != nil {
		return nil, fmt.Errorf("failed to bind named parameters: %w", err)
	}
	return execQueryToCSV(db, db.Rebind(query), args...)
}

// validateCSVQuery rejects a nil database handle and blank query text up front,
// so callers get a descriptive error instead of a nil-pointer panic or a
// driver-specific syntax error.
func validateCSVQuery(db *sqlx.DB, query string) error {
	if db == nil {
		return fmt.Errorf("database handle cannot be nil")
	}
	if strings.TrimSpace(query) == "" {
		return fmt.Errorf("SQL query cannot be empty")
	}
	return nil
}

// execQueryToCSV executes an already-validated query with args and converts
// the resulting rows to CSV.
func execQueryToCSV(db *sqlx.DB, query string, args ...any) ([]byte, error) {
	rows, err := db.Query(query, args...)
	if err != nil {
		return nil, fmt.Errorf("query execution failed: %w", err)
	}
	return rowsToCSV(rows)
}

// rowsToCSV drains rows and encodes them as CSV: a header row with the column
// names, then one record per row. It always closes rows.
func rowsToCSV(rows *sql.Rows) ([]byte, error) {
	defer rows.Close()

	columns, err := rows.Columns()
	if err != nil {
		return nil, fmt.Errorf("failed to get columns: %w", err)
	}

	var buf bytes.Buffer
	writer := csv.NewWriter(&buf)

	if err := writer.Write(columns); err != nil {
		return nil, fmt.Errorf("failed to write CSV header: %w", err)
	}

	// Scan targets and the output record are reused for every row: Scan into
	// *any stores a copy of driver-owned []byte, every slot is overwritten on
	// each Scan, and csv.Writer.Write does not retain the record slice.
	values := make([]any, len(columns))
	valuePtrs := make([]any, len(columns))
	for i := range values {
		valuePtrs[i] = &values[i]
	}
	record := make([]string, len(columns))

	for rows.Next() {
		if err := rows.Scan(valuePtrs...); err != nil {
			return nil, fmt.Errorf("failed to scan row: %w", err)
		}
		for i, v := range values {
			record[i] = csvCellValue(v)
		}
		if err := writer.Write(record); err != nil {
			return nil, fmt.Errorf("failed to write CSV row: %w", err)
		}
	}
	// Report iteration failures (e.g. a dropped connection mid-result) in
	// preference to anything that happens while flushing partial output.
	if err := rows.Err(); err != nil {
		return nil, fmt.Errorf("row iteration error: %w", err)
	}

	writer.Flush()
	if err := writer.Error(); err != nil {
		return nil, fmt.Errorf("failed to flush CSV: %w", err)
	}
	// buf is local and never reused, so handing out its bytes avoids copying
	// the whole document.
	return buf.Bytes(), nil
}

// csvCellValue renders one scanned column value as a CSV field. NULL becomes
// the empty string. []byte is converted as text, because some drivers return
// character columns as []byte when scanned into any; %v would print them as
// a list of byte values. Other types use their default fmt representation.
func csvCellValue(v any) string {
	switch x := v.(type) {
	case nil:
		return ""
	case []byte:
		return string(x)
	case string:
		return x
	default:
		return fmt.Sprint(x)
	}
}