| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081 |
- package ctx
-
- import (
- "context"
- "net/http"
- )
-
- // UserContext 用户上下文
- type RequestContext struct {
- //StdCtx context.Context // 新增:标准上下文
- TraceID string `json:"trace_id"`
- ServiceName string `json:"service_name"`
- InstanceName string `json:"instance_name"`
- TenantID string `json:"tenant_id"`
- UserID string `json:"user_id"`
- ProjectID string `json:"project_id"`
- Username string `json:"username"`
- }
-
- // 内部key,不会和其他包冲突
- type ctxKey struct{}
-
- var loggerKey = ctxKey{}
-
- // // GetStdContext 获取标准上下文(方便使用)
- // func (rc *RequestContext) GetStdContext() context.Context {
- // if rc.StdCtx == nil {
- // return context.Background()
- // }
- // return rc.StdCtx
- // }
-
- // Save RequestContext
- func SaveContext(r *http.Request, requestContext *RequestContext) *http.Request {
- ctx := context.WithValue(r.Context(), loggerKey, requestContext)
- return r.WithContext(ctx)
- }
-
- // GetContext 从请求获取RequestContext
- func GetContext(r *http.Request) *RequestContext {
- if r == nil {
- return &RequestContext{}
- }
-
- if v := r.Context().Value(loggerKey); v != nil {
- // 修正:类型断言为 *RequestContext
- if loggerCtx, ok := v.(*RequestContext); ok {
- return loggerCtx
- }
- }
-
- return &RequestContext{}
- }
-
- // FromContext 从 context.Context 中提取 RequestContext
- func FromContext(ctx context.Context) *RequestContext {
- if ctx == nil {
- return &RequestContext{}
- }
-
- if v := ctx.Value(loggerKey); v != nil {
- if reqCtx, ok := v.(*RequestContext); ok {
- return reqCtx
- }
- }
-
- return &RequestContext{}
- }
-
- func GetContextTest() *RequestContext {
-
- return &RequestContext{
- TraceID: "test-TTraceID",
- TenantID: "test-TenantID",
- ServiceName: "test-ServiceName",
- InstanceName: "test-InstanceName",
- UserID: "test-UserID",
- Username: "test-Username",
- ProjectID: "test-ProjectID",
- }
- }
|