package webx import ( "encoding/json" "net/http" "reflect" "strings" ) // WebService 路由服务 type WebService struct { router *http.ServeMux } // NewWebService 创建WebService func NewWebService(router *http.ServeMux) *WebService { return &WebService{router: router} } // RouteBuilder 路由构建器 type RouteBuilder struct { method string path string handlerFunc reflect.Value paramNames []string paramTypes []reflect.Type middlewares []func(http.Handler) http.Handler router *http.ServeMux } // GET 注册GET请求 func (ws *WebService) GET(path string, handler interface{}) *RouteBuilder { return ws.handle("GET", path, handler) } // POST 注册POST请求 func (ws *WebService) POST(path string, handler interface{}) *RouteBuilder { return ws.handle("POST", path, handler) } // PUT 注册PUT请求 func (ws *WebService) PUT(path string, handler interface{}) *RouteBuilder { return ws.handle("PUT", path, handler) } // DELETE 注册DELETE请求 func (ws *WebService) DELETE(path string, handler interface{}) *RouteBuilder { return ws.handle("DELETE", path, handler) } // handle 统一处理方法 func (ws *WebService) handle(method, path string, handler interface{}) *RouteBuilder { // 解析路径参数名 paramNames := extractPathParams(path) // 获取处理器函数信息 handlerValue := reflect.ValueOf(handler) handlerType := handlerValue.Type() // 验证处理器函数 if handlerType.Kind() != reflect.Func { panic("handler must be a function") } // 获取参数类型 paramTypes := make([]reflect.Type, handlerType.NumIn()) for i := 0; i < handlerType.NumIn(); i++ { paramTypes[i] = handlerType.In(i) } // 验证返回值 if handlerType.NumOut() != 2 { panic("handler must return exactly 2 values: (T, error)") } return &RouteBuilder{ method: method, path: path, handlerFunc: handlerValue, paramNames: paramNames, paramTypes: paramTypes, router: ws.router, } } // Use 添加中间件 func (rb *RouteBuilder) Use(middleware ...func(http.Handler) http.Handler) *RouteBuilder { rb.middlewares = append(rb.middlewares, middleware...) return rb } // Register 注册路由 func (rb *RouteBuilder) Register() { // 创建适配器 adapter := &handlerAdapter{ method: rb.method, pathPattern: rb.path, paramNames: rb.paramNames, paramTypes: rb.paramTypes, handlerFunc: rb.handlerFunc, } // 应用中间件 var handler http.Handler = adapter for i := len(rb.middlewares) - 1; i >= 0; i-- { handler = rb.middlewares[i](handler) } // 注册到路由器 rb.router.Handle(rb.path, handler) } // handlerAdapter 处理器适配器 type handlerAdapter struct { method string pathPattern string paramNames []string paramTypes []reflect.Type handlerFunc reflect.Value } func (ha *handlerAdapter) ServeHTTP(w http.ResponseWriter, r *http.Request) { // 只处理指定方法的请求 if r.Method != ha.method { http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) return } // 1. 解析路径参数 pathParams := ha.parsePathParams(r.URL.Path) // 2. 构建函数参数 args := ha.buildArgs(r, pathParams) if args == nil { http.Error(w, "Invalid parameters", http.StatusBadRequest) return } // 3. 调用处理器函数 results := ha.handlerFunc.Call(args) // 4. 处理返回结果 ha.handleResponse(w, results) } // parsePathParams 解析路径参数 func (ha *handlerAdapter) parsePathParams(requestPath string) map[string]string { params := make(map[string]string) // 简单实现:按/分割,然后匹配 patternParts := strings.Split(ha.pathPattern, "/") requestParts := strings.Split(requestPath, "/") if len(patternParts) != len(requestParts) { return params } for i, patternPart := range patternParts { if strings.HasPrefix(patternPart, "{") && strings.HasSuffix(patternPart, "}") { paramName := patternPart[1 : len(patternPart)-1] params[paramName] = requestParts[i] } } return params } // buildArgs 构建函数参数 func (ha *handlerAdapter) buildArgs(r *http.Request, pathParams map[string]string) []reflect.Value { args := make([]reflect.Value, len(ha.paramTypes)) // 参数索引 pathParamIndex := 0 for i, paramType := range ha.paramTypes { // 情况1:路径参数(必须是string类型) if pathParamIndex < len(ha.paramNames) && paramType.Kind() == reflect.String { paramName := ha.paramNames[pathParamIndex] if value, ok := pathParams[paramName]; ok { args[i] = reflect.ValueOf(value) pathParamIndex++ continue } } // 情况2:Body参数(POST/PUT等请求的结构体) if r.Method == "POST" || r.Method == "PUT" || r.Method == "PATCH" { // 检查是否是结构体类型 if paramType.Kind() == reflect.Struct || (paramType.Kind() == reflect.Ptr && paramType.Elem().Kind() == reflect.Struct) { bodyParam, err := ha.parseBody(r, paramType) if err != nil { return nil } args[i] = bodyParam continue } } // 无法处理的参数类型 return nil } return args } // parseBody 解析请求体 func (ha *handlerAdapter) parseBody(r *http.Request, paramType reflect.Type) (reflect.Value, error) { // 创建参数实例 var paramValue reflect.Value if paramType.Kind() == reflect.Ptr { paramValue = reflect.New(paramType.Elem()) } else { paramValue = reflect.New(paramType) } // 解析JSON if err := json.NewDecoder(r.Body).Decode(paramValue.Interface()); err != nil { return reflect.Value{}, err } // 如果是非指针类型,需要解引用 if paramType.Kind() != reflect.Ptr { paramValue = paramValue.Elem() } return paramValue, nil } // handleResponse 处理响应 func (ha *handlerAdapter) handleResponse(w http.ResponseWriter, results []reflect.Value) { // 第一个返回值是数据 data := results[0].Interface() // 第二个返回值是error errVal := results[1].Interface() if errVal != nil { err := errVal.(error) http.Error(w, err.Error(), http.StatusInternalServerError) return } // 返回JSON响应 w.Header().Set("Content-Type", "application/json") json.NewEncoder(w).Encode(data) } // extractPathParams 从路径中提取参数名 func extractPathParams(path string) []string { var params []string parts := strings.Split(path, "/") for _, part := range parts { if strings.HasPrefix(part, "{") && strings.HasSuffix(part, "}") { paramName := part[1 : len(part)-1] params = append(params, paramName) } } return params }