package queryreq import ( "fmt" "strings" ) // FieldMapper 字段映射函数类型:前端字段名 -> 数据库列名 type FieldMapper func(string) string // DefaultFieldMapper 默认字段映射(直接使用原字段名) func DefaultFieldMapper(field string) string { return field } // BuildWhereClause 安全构建WHERE子句 // 返回: WHERE条件字符串, 参数值数组 func BuildWhereClause(filters []FilterParam, fieldMapper FieldMapper) (string, []interface{}) { if len(filters) == 0 { return "", nil } var conditions []string var args []interface{} for _, filter := range filters { if !filter.Operator.IsValid() { continue } // 获取安全的数据库列名 dbField := fieldMapper(filter.Field) if dbField == "" { continue // 无效字段,跳过 } // 根据运算符构建条件 condition, filterArgs := buildCondition(dbField, filter.Operator, filter.Value) if condition != "" { conditions = append(conditions, condition) args = append(args, filterArgs...) } } if len(conditions) == 0 { return "", nil } return "WHERE " + strings.Join(conditions, " AND "), args } // BuildOrderByClause 安全构建ORDER BY子句 // 返回: ORDER BY子句字符串 func BuildOrderByClause(sorts []SortParam, fieldMapper FieldMapper) string { if len(sorts) == 0 { return "" } var orderClauses []string for _, sort := range sorts { // 获取安全的数据库列名 dbField := fieldMapper(sort.Field) if dbField == "" { continue // 无效字段,跳过 } // 验证排序方向 order := strings.ToUpper(sort.Order) if order != "ASC" && order != "DESC" { order = "ASC" // 默认升序 } orderClauses = append(orderClauses, fmt.Sprintf("%s %s", dbField, order)) } if len(orderClauses) == 0 { return "" } return "ORDER BY " + strings.Join(orderClauses, ", ") } // buildCondition 构建单个筛选条件 func buildCondition(field string, operator Operator, value interface{}) (string, []interface{}) { switch operator { case OpEquals, OpNotEquals, OpGreaterThan, OpLessThan, OpGreaterOrEq, OpLessOrEq: return fmt.Sprintf("%s %s ?", field, operator.ToSQL()), []interface{}{value} case OpLike: // LIKE操作自动添加通配符 if strVal, ok := value.(string); ok { return fmt.Sprintf("%s LIKE ?", field), []interface{}{"%" + strVal + "%"} } // 非字符串类型,直接使用值 return fmt.Sprintf("%s LIKE ?", field), []interface{}{value} case OpIn: return buildInCondition(field, value) default: return "", nil } } // buildInCondition 构建IN条件 func buildInCondition(field string, value interface{}) (string, []interface{}) { // 处理IN操作符的值 values, ok := processInValue(value) if !ok || len(values) == 0 { return "", nil } // 构建占位符 (?, ?, ?) placeholders := strings.Repeat("?, ", len(values)) placeholders = placeholders[:len(placeholders)-2] // 移除最后的逗号和空格 return fmt.Sprintf("%s IN (%s)", field, placeholders), values } // processInValue 处理IN操作符的值,转换为参数数组 // 支持逗号分隔的字符串、字符串数组、接口数组等 func processInValue(value interface{}) ([]interface{}, bool) { switch v := value.(type) { case string: // 逗号分隔的字符串 if strings.Contains(v, ",") { parts := strings.Split(v, ",") var result []interface{} for _, part := range parts { trimmed := strings.TrimSpace(part) if trimmed != "" { // 尝试解析为数值 if num, err := parseNumber(trimmed); err == nil { result = append(result, num) } else { result = append(result, trimmed) } } } return result, len(result) > 0 } // 单个字符串值 return []interface{}{v}, true case []string: // 字符串数组 var result []interface{} for _, item := range v { if trimmed := strings.TrimSpace(item); trimmed != "" { if num, err := parseNumber(trimmed); err == nil { result = append(result, num) } else { result = append(result, trimmed) } } } return result, len(result) > 0 case []interface{}: // 接口数组 var result []interface{} for _, item := range v { if item != nil { result = append(result, item) } } return result, len(result) > 0 default: // 其他类型视为单个值 return []interface{}{v}, true } } // parseNumber 尝试将字符串解析为数值 func parseNumber(s string) (interface{}, error) { // 尝试解析为整数 if intVal, err := parseInt(s); err == nil { return intVal, nil } // 尝试解析为浮点数 if floatVal, err := parseFloat(s); err == nil { return floatVal, nil } return nil, fmt.Errorf("not a number") } // parseInt 解析整数 func parseInt(s string) (int64, error) { var val int64 _, err := fmt.Sscanf(s, "%d", &val) return val, err } // parseFloat 解析浮点数 func parseFloat(s string) (float64, error) { var val float64 _, err := fmt.Sscanf(s, "%f", &val) return val, err }