package factory import ( "database/sql" "encoding/csv" "encoding/json" "fmt" "io" "strings" "time" "git.x2erp.com/qdy/go-base/types" ) func typesFunction() *types.QueryResult { // 确保使用了该包导出的类型,例如 QueryResult return &types.QueryResult{} } // QueryExecutor 查询执行器 type QueryExecutor struct { db *sql.DB } // NewQueryExecutor 创建查询执行器 func NewQueryExecutor(db *sql.DB) *QueryExecutor { return &QueryExecutor{db: db} } // QueryToJSON 执行查询并返回JSON格式数据(统一返回QueryResult) func (e *QueryExecutor) QueryToJSON(sql string) *types.QueryResult { startTime := time.Now() result := &types.QueryResult{} if sql == "" { result.Success = false result.Error = "SQL query cannot be empty" result.Time = time.Since(startTime).String() return result } rows, err := e.db.Query(sql) if err != nil { result.Success = false result.Error = fmt.Sprintf("Query execution failed: %v", err) result.Time = time.Since(startTime).String() return result } defer rows.Close() columns, err := rows.Columns() if err != nil { result.Success = false result.Error = fmt.Sprintf("Failed to get columns: %v", err) result.Time = time.Since(startTime).String() return result } var results []map[string]interface{} count := 0 for rows.Next() { count++ values := make([]interface{}, len(columns)) valuePtrs := make([]interface{}, len(columns)) for i := range columns { valuePtrs[i] = &values[i] } if err := rows.Scan(valuePtrs...); err != nil { result.Success = false result.Error = fmt.Sprintf("Failed to scan row: %v", err) result.Time = time.Since(startTime).String() return result } resultMap := make(map[string]interface{}) for i, col := range columns { // 完全不处理类型,直接赋值,让 json.Marshal 自己处理 resultMap[col] = values[i] } results = append(results, resultMap) } if err := rows.Err(); err != nil { result.Success = false result.Error = fmt.Sprintf("Row iteration error: %v", err) result.Time = time.Since(startTime).String() return result } jsonData, err := json.Marshal(results) if err != nil { result.Success = false result.Error = fmt.Sprintf("JSON marshal failed: %v", err) result.Time = time.Since(startTime).String() return result } result.Success = true result.Data = map[string]interface{}{ "json": string(jsonData), "rows": results, "count": count, } result.Count = count result.Time = time.Since(startTime).String() return result } // QueryToCSV 查询并返回 CSV 字节数据(包含表头) func (e *QueryExecutor) QueryToCSV(sql string) ([]byte, error) { if sql == "" { return nil, fmt.Errorf("SQL query cannot be empty") } rows, err := e.db.Query(sql) if err != nil { return nil, fmt.Errorf("query execution failed: %v", err) } defer rows.Close() columns, err := rows.Columns() if err != nil { return nil, fmt.Errorf("failed to get columns: %v", err) } var builder strings.Builder writer := csv.NewWriter(&builder) // 写入表头 if err := writer.Write(columns); err != nil { return nil, fmt.Errorf("failed to write CSV header: %v", err) } count := 0 for rows.Next() { count++ values := make([]interface{}, len(columns)) valuePtrs := make([]any, len(columns)) for i := range columns { valuePtrs[i] = &values[i] } if err := rows.Scan(valuePtrs...); err != nil { return nil, fmt.Errorf("failed to scan row: %v", err) } // 所有值转为字符串 row := make([]string, len(columns)) for i, val := range values { if val == nil { row[i] = "" } else { row[i] = fmt.Sprintf("%v", val) } } if err := writer.Write(row); err != nil { return nil, fmt.Errorf("failed to write CSV row: %v", err) } } writer.Flush() if err := writer.Error(); err != nil { return nil, fmt.Errorf("failed to flush CSV: %v", err) } if err := rows.Err(); err != nil { return nil, fmt.Errorf("row iteration error: %v", err) } return []byte(builder.String()), nil } // ExecuteQueryWithColumns 执行查询并返回完整结果(包含列信息) func (e *QueryExecutor) ExecuteQueryWithColumns(sql string) *types.QueryResult { startTime := time.Now() result := &types.QueryResult{} if sql == "" { result.Success = false result.Error = "SQL query cannot be empty" result.Time = time.Since(startTime).String() return result } rows, err := e.db.Query(sql) if err != nil { result.Success = false result.Error = fmt.Sprintf("Query execution failed: %v", err) result.Time = time.Since(startTime).String() return result } defer rows.Close() columns, err := rows.Columns() if err != nil { result.Success = false result.Error = fmt.Sprintf("Failed to get columns: %v", err) result.Time = time.Since(startTime).String() return result } var results []map[string]interface{} count := 0 for rows.Next() { count++ values := make([]interface{}, len(columns)) valuePtrs := make([]interface{}, len(columns)) for i := range columns { valuePtrs[i] = &values[i] } if err := rows.Scan(valuePtrs...); err != nil { result.Success = false result.Error = fmt.Sprintf("Failed to scan row: %v", err) result.Time = time.Since(startTime).String() return result } resultRow := make(map[string]interface{}) for i, col := range columns { val := values[i] switch v := val.(type) { case []byte: resultRow[col] = string(v) case time.Time: resultRow[col] = v.Format(time.RFC3339) default: resultRow[col] = v } } results = append(results, resultRow) } if err := rows.Err(); err != nil { result.Success = false result.Error = fmt.Sprintf("Row iteration error: %v", err) result.Time = time.Since(startTime).String() return result } result.Success = true result.Data = results result.Count = count result.Time = time.Since(startTime).String() return result } // ExecuteQueryDataOnly 执行查询并返回纯数据(不包含列信息,性能更高) func (e *QueryExecutor) ExecuteQueryDataOnly(sql string) *types.QueryResult { startTime := time.Now() result := &types.QueryResult{} if sql == "" { result.Success = false result.Error = "SQL query cannot be empty" result.Time = time.Since(startTime).String() return result } rows, err := e.db.Query(sql) if err != nil { result.Success = false result.Error = fmt.Sprintf("Query execution failed: %v", err) result.Time = time.Since(startTime).String() return result } defer rows.Close() columns, err := rows.Columns() if err != nil { result.Success = false result.Error = fmt.Sprintf("Failed to get columns: %v", err) result.Time = time.Since(startTime).String() return result } var results []interface{} count := 0 for rows.Next() { count++ values := make([]interface{}, len(columns)) valuePtrs := make([]interface{}, len(columns)) for i := range columns { valuePtrs[i] = &values[i] } if err := rows.Scan(valuePtrs...); err != nil { result.Success = false result.Error = fmt.Sprintf("Failed to scan row: %v", err) result.Time = time.Since(startTime).String() return result } resultRow := make([]interface{}, len(columns)) for i, val := range values { switch v := val.(type) { case []byte: resultRow[i] = string(v) case time.Time: resultRow[i] = v.Format(time.RFC3339) default: resultRow[i] = v } } results = append(results, resultRow) } if err := rows.Err(); err != nil { result.Success = false result.Error = fmt.Sprintf("Row iteration error: %v", err) result.Time = time.Since(startTime).String() return result } result.Success = true result.Data = map[string]interface{}{ "rows": results, "count": count, } result.Count = count result.Time = time.Since(startTime).String() return result } // ExecuteQueryCSV 执行查询并返回CSV格式数据 func (e *QueryExecutor) ExecuteQueryCSV(sql string, includeHeader bool) *types.QueryResult { startTime := time.Now() result := &types.QueryResult{} if sql == "" { result.Success = false result.Error = "SQL query cannot be empty" result.Time = time.Since(startTime).String() return result } rows, err := e.db.Query(sql) if err != nil { result.Success = false result.Error = fmt.Sprintf("Query execution failed: %v", err) result.Time = time.Since(startTime).String() return result } defer rows.Close() columns, err := rows.Columns() if err != nil { result.Success = false result.Error = fmt.Sprintf("Failed to get columns: %v", err) result.Time = time.Since(startTime).String() return result } var csvBuilder strings.Builder writer := csv.NewWriter(&csvBuilder) if includeHeader { if err := writer.Write(columns); err != nil { result.Success = false result.Error = fmt.Sprintf("Failed to write CSV header: %v", err) result.Time = time.Since(startTime).String() return result } } count := 0 for rows.Next() { count++ values := make([]interface{}, len(columns)) valuePtrs := make([]interface{}, len(columns)) for i := range columns { valuePtrs[i] = &values[i] } if err := rows.Scan(valuePtrs...); err != nil { result.Success = false result.Error = fmt.Sprintf("Failed to scan row: %v", err) result.Time = time.Since(startTime).String() return result } rowData := make([]string, len(columns)) for i, val := range values { if val == nil { rowData[i] = "" continue } switch v := val.(type) { case []byte: rowData[i] = string(v) case string: rowData[i] = v case int, int8, int16, int32, int64: rowData[i] = fmt.Sprintf("%d", v) case uint, uint8, uint16, uint32, uint64: rowData[i] = fmt.Sprintf("%d", v) case float32, float64: rowData[i] = fmt.Sprintf("%f", v) case bool: if v { rowData[i] = "true" } else { rowData[i] = "false" } case time.Time: rowData[i] = v.Format(time.RFC3339) default: rowData[i] = fmt.Sprintf("%v", v) } } if err := writer.Write(rowData); err != nil { result.Success = false result.Error = fmt.Sprintf("Failed to write CSV row: %v", err) result.Time = time.Since(startTime).String() return result } } if err := rows.Err(); err != nil { result.Success = false result.Error = fmt.Sprintf("Row iteration error: %v", err) result.Time = time.Since(startTime).String() return result } writer.Flush() if err := writer.Error(); err != nil { result.Success = false result.Error = fmt.Sprintf("Failed to flush CSV: %v", err) result.Time = time.Since(startTime).String() return result } result.Success = true result.Data = map[string]interface{}{ "csv": csvBuilder.String(), "count": count, "includeHeader": includeHeader, } result.Count = count result.Time = time.Since(startTime).String() return result } // ExecuteQueryCSVStream 流式返回CSV数据 func (e *QueryExecutor) ExecuteQueryCSVStream(sql string, w io.Writer, includeHeader bool) (int, error) { rows, err := e.db.Query(sql) if err != nil { return 0, err } defer rows.Close() columns, err := rows.Columns() if err != nil { return 0, err } writer := csv.NewWriter(w) count := 0 if includeHeader { if err := writer.Write(columns); err != nil { return 0, err } } for rows.Next() { count++ values := make([]interface{}, len(columns)) valuePtrs := make([]interface{}, len(columns)) for i := range columns { valuePtrs[i] = &values[i] } if err := rows.Scan(valuePtrs...); err != nil { return count, err } rowData := make([]string, len(columns)) for i, val := range values { if val == nil { rowData[i] = "" continue } rowData[i] = fmt.Sprintf("%v", val) } if err := writer.Write(rowData); err != nil { return count, err } } writer.Flush() if err := writer.Error(); err != nil { return count, err } if err := rows.Err(); err != nil { return count, err } return count, nil }