// Package router provides a reflection-based HTTP router layered on
// http.ServeMux. Handlers are plain Go functions returning (T, error); their
// parameters are bound from the request (path, query string, JSON body, and a
// few injected framework types) and T is written back as the response.
package router

import (
	"context"
	"encoding/json"
	"errors"
	"io"
	"net/http"
	"reflect"
	"sort"
	"strconv"
	"strings"
	"sync"

	"git.x2erp.com/qdy/go-base/ctx"
)

// Reflect types resolved once so per-request argument binding does not
// recompute them.
var (
	responseWriterType    = reflect.TypeOf((*http.ResponseWriter)(nil)).Elem()
	requestPtrType        = reflect.TypeOf((*http.Request)(nil))
	contextIfaceType      = reflect.TypeOf((*context.Context)(nil)).Elem()
	requestContextPtrType = reflect.TypeOf((*ctx.RequestContext)(nil))
	errorIfaceType        = reflect.TypeOf((*error)(nil)).Elem()
)

// RouterService is a method-aware router on top of an *http.ServeMux.
// Each distinct path is registered with the mux exactly once; requests to that
// path are then dispatched by HTTP method to the handler stored under
// "METHOD path".
type RouterService struct {
	router          *http.ServeMux
	middlewares     []func(http.Handler) http.Handler
	handlers        map[string]http.Handler // key: method + " " + path
	registeredPaths map[string]bool         // paths already bound to the mux dispatcher
	mu              sync.RWMutex            // guards handlers and registeredPaths
}

// NewWebService creates a RouterService that registers its routes on router.
func NewWebService(router *http.ServeMux) *RouterService {
	return &RouterService{
		router:          router,
		handlers:        make(map[string]http.Handler),
		registeredPaths: make(map[string]bool),
	}
}

// Use appends a global middleware and returns ws for chaining.
//
// Global middlewares are captured when RouteBuilder.Register runs, so they only
// wrap routes registered after this call. The first middleware added is the
// outermost one.
func (ws *RouterService) Use(middleware func(http.Handler) http.Handler) *RouterService {
	ws.middlewares = append(ws.middlewares, middleware)
	return ws
}

// GET starts registration of a GET route. Call Register on the result.
func (ws *RouterService) GET(path string, handler any) *RouteBuilder {
	return ws.handle(http.MethodGet, path, handler)
}

// POST starts registration of a POST route. Call Register on the result.
func (ws *RouterService) POST(path string, handler any) *RouteBuilder {
	return ws.handle(http.MethodPost, path, handler)
}

// PUT starts registration of a PUT route. Call Register on the result.
func (ws *RouterService) PUT(path string, handler any) *RouteBuilder {
	return ws.handle(http.MethodPut, path, handler)
}

// DELETE starts registration of a DELETE route. Call Register on the result.
func (ws *RouterService) DELETE(path string, handler any) *RouteBuilder {
	return ws.handle(http.MethodDelete, path, handler)
}

// handle validates handler and returns a RouteBuilder for method and path.
// The route is not active until RouteBuilder.Register is called.
//
// handler must be a non-nil function returning exactly (T, E) where E
// implements error. Violations are programming errors and panic here, at
// registration time, rather than on every request.
func (ws *RouterService) handle(method, path string, handler any) *RouteBuilder {
	handlerValue := reflect.ValueOf(handler)
	if !handlerValue.IsValid() || handlerValue.Kind() != reflect.Func {
		panic("handler must be a function")
	}
	if handlerValue.IsNil() {
		panic("handler must not be nil")
	}

	handlerType := handlerValue.Type()
	if handlerType.NumOut() != 2 {
		panic("handler must return exactly 2 values: (T, error)")
	}
	if !handlerType.Out(1).Implements(errorIfaceType) {
		panic("handler's second return value must implement error")
	}

	return &RouteBuilder{
		ws:          ws,
		method:      method,
		path:        path,
		handlerFunc: handlerValue,
		paramNames:  extractPathParams(path),
	}
}

// RouteBuilder accumulates per-route options (middlewares, description) for a
// single method+path until Register installs the route.
type RouteBuilder struct {
	ws          *RouterService
	method      string
	path        string
	handlerFunc reflect.Value
	paramNames  []string
	middlewares []func(http.Handler) http.Handler
	description string
}

// WrapMiddleware adapts a middleware that needs an extra dependency into the
// standard func(http.Handler) http.Handler shape by binding dep.
func WrapMiddleware[T any](middleware func(http.Handler, T) http.Handler, dep T) func(http.Handler) http.Handler {
	return func(next http.Handler) http.Handler {
		return middleware(next, dep)
	}
}

// Use appends route-level middlewares. They run inside the global middlewares,
// and the first one added is the outermost of the route-level set.
func (rb *RouteBuilder) Use(middleware ...func(http.Handler) http.Handler) *RouteBuilder {
	rb.middlewares = append(rb.middlewares, middleware...)
	return rb
}

// Desc records a human-readable description for the route. It is stored on the
// builder only; nothing in this file reads it.
func (rb *RouteBuilder) Desc(description string) *RouteBuilder {
	rb.description = description
	return rb
}

// registerPathIfNeeded binds path on the underlying mux to a dispatcher that
// selects the handler by request method. It is a no-op if path is already bound.
func (ws *RouterService) registerPathIfNeeded(path string) {
	ws.mu.Lock()
	defer ws.mu.Unlock()

	if ws.registeredPaths[path] {
		return
	}

	dispatcher := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		ws.mu.RLock()
		handler, ok := ws.handlers[r.Method+" "+path]
		var allowed []string
		if !ok {
			allowed = ws.allowedMethodsLocked(path)
		}
		ws.mu.RUnlock()

		if !ok {
			// A 405 response must list the methods the resource supports.
			w.Header().Set("Allow", strings.Join(allowed, ", "))
			http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
			return
		}
		handler.ServeHTTP(w, r)
	})

	ws.router.Handle(path, dispatcher)
	ws.registeredPaths[path] = true
}

// allowedMethodsLocked returns the sorted list of methods registered for path.
// The caller must hold ws.mu (read or write).
func (ws *RouterService) allowedMethodsLocked(path string) []string {
	var methods []string
	for key := range ws.handlers {
		if method, p, found := strings.Cut(key, " "); found && p == path {
			methods = append(methods, method)
		}
	}
	sort.Strings(methods)
	return methods
}

// Register installs the route. The chain is built as
// global middlewares → route middlewares → handler adapter, using the global
// middlewares present at the time of this call. Registering the same
// method+path twice replaces the earlier handler.
func (rb *RouteBuilder) Register() {
	var handler http.Handler = &handlerAdapter{
		ws:          rb.ws,
		method:      rb.method,
		pathPattern: rb.path,
		paramNames:  rb.paramNames,
		handlerFunc: rb.handlerFunc,
	}

	// Wrap from the innermost outwards: route-level first, then global, each
	// iterated back to front so index 0 ends up outermost.
	for i := len(rb.middlewares) - 1; i >= 0; i-- {
		handler = rb.middlewares[i](handler)
	}
	for i := len(rb.ws.middlewares) - 1; i >= 0; i-- {
		handler = rb.ws.middlewares[i](handler)
	}

	key := rb.method + " " + rb.path
	rb.ws.mu.Lock()
	rb.ws.handlers[key] = handler
	rb.ws.mu.Unlock()

	rb.ws.registerPathIfNeeded(rb.path)
}

// handlerAdapter turns a reflected (T, error) handler function into an
// http.Handler by binding its arguments from the request and encoding its
// results.
type handlerAdapter struct {
	ws          *RouterService
	method      string
	pathPattern string
	paramNames  []string
	handlerFunc reflect.Value
}

// ServeHTTP binds arguments, invokes the handler, and writes the response.
// It responds 405 for a method other than the route's, and 400 when a bound
// value cannot be converted or the JSON body is malformed.
func (ha *handlerAdapter) ServeHTTP(w http.ResponseWriter, r *http.Request) {
	if r.Method != ha.method {
		w.Header().Set("Allow", ha.method)
		http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
		return
	}

	pathParams := ha.parsePathParams(r.URL.Path)

	// buildArgs returns nil only on a binding failure; a handler with no
	// parameters gets a non-nil empty slice.
	args := ha.buildArgs(w, r, pathParams)
	if args == nil {
		http.Error(w, "Invalid parameters", http.StatusBadRequest)
		return
	}

	results := ha.handlerFunc.Call(args)
	ha.handleResponse(w, results)
}

// parsePathParams matches requestPath segment-by-segment against the route
// pattern and returns the value of each "{name}" segment, keyed by name.
// Pattern segments with no corresponding request segment are omitted.
func (ha *handlerAdapter) parsePathParams(requestPath string) map[string]string {
	params := make(map[string]string)

	patternParts := strings.Split(strings.Trim(ha.pathPattern, "/"), "/")
	requestParts := strings.Split(strings.Trim(requestPath, "/"), "/")

	for i, part := range patternParts {
		if !strings.HasPrefix(part, "{") || !strings.HasSuffix(part, "}") {
			continue
		}
		if i < len(requestParts) {
			params[part[1:len(part)-1]] = requestParts[i]
		}
	}
	return params
}

// buildArgs produces the argument list for the handler, one value per
// parameter, using this precedence:
//
//  1. http.ResponseWriter, *http.Request, context.Context are injected.
//  2. *ctx.RequestContext is injected via ctx.GetContext.
//  3. Scalar parameters (string, bool, ints, uints, floats — including named
//     types with those kinds) consume path-parameter names in order; once
//     those run out, the name falls back to getParamName. The value is taken
//     from the path, then the query string.
//  4. The first struct or *struct parameter of a POST/PUT/PATCH request is
//     decoded from the JSON body.
//  5. Anything else gets its zero value.
//
// It returns nil when a scalar value cannot be converted or the body is
// present but not valid JSON, so the caller can respond 400.
func (ha *handlerAdapter) buildArgs(w http.ResponseWriter, r *http.Request, pathParams map[string]string) []reflect.Value {
	handlerType := ha.handlerFunc.Type()
	numIn := handlerType.NumIn()
	args := make([]reflect.Value, numIn)

	query := r.URL.Query() // parse once, not once per parameter
	bodyAllowed := r.Method == http.MethodPost || r.Method == http.MethodPut || r.Method == http.MethodPatch
	bodyConsumed := false
	nextPathParam := 0 // index into ha.paramNames for the next scalar parameter

	for i := 0; i < numIn; i++ {
		paramType := handlerType.In(i)

		// 1. Framework-injected values.
		if arg := ha.bindSpecialType(paramType, w, r); arg.IsValid() {
			args[i] = arg
			continue
		}

		// 2. The project's per-request context.
		if paramType == requestContextPtrType {
			args[i] = reflect.ValueOf(ctx.GetContext(r))
			continue
		}

		// 3. Scalars from the path, then the query string. Counting only
		// scalar parameters keeps injected arguments (e.g. a leading
		// context.Context) from shifting the path-parameter mapping.
		if isScalarKind(paramType.Kind()) {
			var name string
			if nextPathParam < len(ha.paramNames) {
				name = ha.paramNames[nextPathParam]
				nextPathParam++
			} else {
				name = getParamName(i, handlerType, nil)
			}

			raw, found := pathParams[name]
			if !found {
				raw = query.Get(name)
				found = raw != ""
			}
			if !found {
				args[i] = reflect.Zero(paramType)
				continue
			}

			value := convertToType(raw, paramType)
			if !value.IsValid() {
				return nil // malformed value
			}
			args[i] = value
			continue
		}

		// 4. A struct from the JSON body; the body can only be read once.
		if bodyAllowed && !bodyConsumed && isStructOrStructPtr(paramType) {
			bodyConsumed = true
			value, err := ha.decodeBody(r, paramType)
			if err != nil {
				return nil // malformed JSON
			}
			if value.IsValid() {
				args[i] = value
				continue
			}
		}

		// 5. Nothing to bind.
		args[i] = reflect.Zero(paramType)
	}
	return args
}

// bindSpecialType returns the request-scoped value for the framework-injected
// types http.ResponseWriter, *http.Request and context.Context (the request's
// context), or an invalid Value for any other type.
func (ha *handlerAdapter) bindSpecialType(paramType reflect.Type, w http.ResponseWriter, r *http.Request) reflect.Value {
	switch paramType {
	case responseWriterType:
		return reflect.ValueOf(w)
	case requestPtrType:
		return reflect.ValueOf(r)
	case contextIfaceType:
		return reflect.ValueOf(r.Context())
	}
	return reflect.Value{}
}

// getParamName returns the lookup name for the handler parameter at index:
// pathParamNames[index] when present, otherwise the snake_case form of the
// parameter's type name (MyInterface -> my_interface), or the decimal index
// for unnamed types such as []string.
func getParamName(index int, handlerType reflect.Type, pathParamNames []string) string {
	if index < len(pathParamNames) {
		return pathParamNames[index]
	}

	typeName := handlerType.In(index).Name()
	if typeName == "" {
		return strconv.Itoa(index)
	}

	var b strings.Builder
	for i, r := range typeName {
		if i > 0 && 'A' <= r && r <= 'Z' {
			b.WriteByte('_')
		}
		b.WriteRune(r)
	}
	return strings.ToLower(b.String())
}

// convertToType parses str into a value of targetType, which may be any type
// whose kind is string, bool, a signed/unsigned integer or a float (named
// types included). Integers are parsed in base 10 with targetType's bit size,
// so out-of-range input fails. It returns an invalid Value when parsing fails
// or the kind is unsupported.
func convertToType(str string, targetType reflect.Type) reflect.Value {
	if !isScalarKind(targetType.Kind()) {
		return reflect.Value{}
	}

	value := reflect.New(targetType).Elem()
	switch targetType.Kind() {
	case reflect.String:
		value.SetString(str)
	case reflect.Bool:
		b, err := strconv.ParseBool(str)
		if err != nil {
			return reflect.Value{}
		}
		value.SetBool(b)
	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
		n, err := strconv.ParseInt(str, 10, targetType.Bits())
		if err != nil {
			return reflect.Value{}
		}
		value.SetInt(n)
	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
		n, err := strconv.ParseUint(str, 10, targetType.Bits())
		if err != nil {
			return reflect.Value{}
		}
		value.SetUint(n)
	case reflect.Float32, reflect.Float64:
		f, err := strconv.ParseFloat(str, targetType.Bits())
		if err != nil {
			return reflect.Value{}
		}
		value.SetFloat(f)
	}
	return value
}

// isScalarKind reports whether convertToType can produce a value of kind.
func isScalarKind(kind reflect.Kind) bool {
	switch kind {
	case reflect.String, reflect.Bool,
		reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64,
		reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64,
		reflect.Float32, reflect.Float64:
		return true
	}
	return false
}

// isStructOrStructPtr reports whether t is a struct or a pointer to a struct.
func isStructOrStructPtr(t reflect.Type) bool {
	return t.Kind() == reflect.Struct || (t.Kind() == reflect.Ptr && t.Elem().Kind() == reflect.Struct)
}

// parseBody decodes the JSON body into a new value of paramType and returns an
// invalid Value on any failure (including an empty body). Kept for callers that
// only need a best-effort decode; buildArgs uses decodeBody to tell an empty
// body apart from malformed JSON.
func (ha *handlerAdapter) parseBody(r *http.Request, paramType reflect.Type) reflect.Value {
	value, err := ha.decodeBody(r, paramType)
	if err != nil {
		return reflect.Value{}
	}
	return value
}

// decodeBody decodes the JSON request body into a new value of paramType (a
// struct or pointer to struct). An absent or empty body yields an invalid Value
// and a nil error; a body that is present but not valid JSON for paramType
// yields a non-nil error. The body is not closed here: net/http closes
// server-side request bodies itself.
func (ha *handlerAdapter) decodeBody(r *http.Request, paramType reflect.Type) (reflect.Value, error) {
	if r.Body == nil {
		return reflect.Value{}, nil
	}

	target := paramType
	if target.Kind() == reflect.Ptr {
		target = target.Elem()
	}
	ptr := reflect.New(target)

	if err := json.NewDecoder(r.Body).Decode(ptr.Interface()); err != nil {
		if errors.Is(err, io.EOF) {
			return reflect.Value{}, nil // empty body: treat as not provided
		}
		return reflect.Value{}, err
	}

	if paramType.Kind() == reflect.Ptr {
		return ptr, nil
	}
	return ptr.Elem(), nil
}

// handleResponse writes the handler's (T, error) results. A non-nil error
// becomes a 500 with the error text; a []byte result is written verbatim;
// anything else is JSON-encoded with Content-Type application/json.
func (ha *handlerAdapter) handleResponse(w http.ResponseWriter, results []reflect.Value) {
	if errResult := results[1]; !isNilValue(errResult) {
		err := errResult.Interface().(error)
		http.Error(w, err.Error(), http.StatusInternalServerError)
		return
	}

	data := results[0].Interface()

	if raw, ok := data.([]byte); ok {
		_, _ = w.Write(raw) // a failed write to the client cannot be reported back
		return
	}

	// Marshal before writing so an encoding failure can still become a 500
	// rather than a 200 with an empty or truncated body.
	body, err := json.Marshal(data)
	if err != nil {
		http.Error(w, err.Error(), http.StatusInternalServerError)
		return
	}
	w.Header().Set("Content-Type", "application/json")
	_, _ = w.Write(append(body, '\n')) // trailing newline matches json.Encoder output
}

// isNilValue reports whether v holds nil for the kinds that can be nil; it is
// false for every other kind.
func isNilValue(v reflect.Value) bool {
	switch v.Kind() {
	case reflect.Interface, reflect.Ptr, reflect.Map, reflect.Slice, reflect.Func, reflect.Chan:
		return v.IsNil()
	}
	return false
}

// extractPathParams returns the names of the "{name}" segments in path, in
// order of appearance.
func extractPathParams(path string) []string {
	var params []string
	for _, part := range strings.Split(strings.Trim(path, "/"), "/") {
		if strings.HasPrefix(part, "{") && strings.HasSuffix(part, "}") {
			params = append(params, part[1:len(part)-1])
		}
	}
	return params
}