package middleware

import (
	"fmt"
	"net/http"
	"runtime"
	"strings"
	"time"

	"git.shuncheng.lu/bigthing/gocommon/pkg/internal/util"
	"git.shuncheng.lu/bigthing/gocommon/pkg/logger"
	"git.shuncheng.lu/bigthing/gocommon/pkg/net/engines"
	"git.shuncheng.lu/bigthing/gocommon/pkg/trace"
)

// AccessLogMiddleware writes one access-log line per request, including the
// time spent handling it.
//
// It also recovers panics raised by downstream handlers. The panic value and
// the current goroutine's stack trace are logged, and the request is recorded
// with status 500. A request that takes 20s or longer is logged as an error and
// recorded with status 504; this overrides a 500 from a panic.
//
// NOTE(review): statusCode starts at 200 and changes only for panics and slow
// requests. The status the handler actually wrote is not consulted.
// NOTE(review): after a recovered panic nothing is written to the client here.
// Confirm the engine sends a sensible response in that case.
func AccessLogMiddleware() engines.HandlerFunc {
	// slowRequestThreshold is the duration at or above which a request is
	// recorded as 504. TODO: make this configurable.
	const slowRequestThreshold = 20 * time.Second

	return func(c *engines.Context) {
		beginTime := time.Now()
		statusCode := http.StatusOK

		defer func() {
			// recover must come first so a panic from c.Next() is captured before
			// anything else in this deferred func runs. Read the stack trace
			// bottom-up to follow the call sequence.
			if err := recover(); err != nil {
				statusCode = http.StatusInternalServerError
				stackBuffer := make([]byte, 64<<10) // log at most 64KiB of stack
				stackBuffer = stackBuffer[:runtime.Stack(stackBuffer, false)]
				logger.Errorc(c, "[AccessLogMiddleware] panic err, path: %s, panic_info: %+v, stack trace:\n %s", c.Request.URL.Path, err, stackBuffer)
			}

			// Read the clock once so the logged timestamp and duration agree.
			endTime := time.Now()
			duration := endTime.Sub(beginTime)

			if duration >= slowRequestThreshold {
				statusCode = http.StatusGatewayTimeout
				logger.Errorc(c, "[AccessLogMiddleware] timeout, path: %s, spend: %.4fs", c.Request.URL.Path, duration.Seconds())
			}

			// Access-log line: client IP, two "-" placeholders, timestamp,
			// request line, status, a constant "0" field, and the elapsed time
			// in milliseconds.
			logger.AccessInfof(`%s %s %s [%s] "%s %s %s" %d %s %d`,
				c.ClientIP(),
				"-",
				"-",
				endTime.Format("02/Jan/2006:15:04:05 -0700"),
				c.Request.Method,
				c.Request.URL.Path,
				c.Request.Proto,
				statusCode,
				"0",
				int64(duration/time.Millisecond)+1, // +1ms so integer division never logs a fast request as 0
			)
		}()

		c.Next()
	}
}

// AddTraceId stores a trace id on the context under logger.TraceIdKey.
//
// The id is a comma-separated list of the SkyWalking id (prefixed "s:") and
// the Zipkin x-b3-traceid header (prefixed "z:"), including whichever are
// non-empty. If neither is available, a fresh id from logger.GenerateTraceId
// is used instead.
func AddTraceId() engines.HandlerFunc {
	return func(ctx *engines.Context) {
		traceIds := make([]string, 0, 2)
		skyWalkingTraceId := trace.GinGetTraceId(ctx)
		zipKinTraceId := getZipKinTraceId(ctx)
		if skyWalkingTraceId != "" {
			traceIds = append(traceIds, "s:"+skyWalkingTraceId)
		}
		if zipKinTraceId != "" {
			traceIds = append(traceIds, "z:"+zipKinTraceId)
		}
		if len(traceIds) == 0 {
			traceIds = append(traceIds, logger.GenerateTraceId())
		}
		ctx.Set(logger.TraceIdKey, strings.Join(traceIds, ","))
		ctx.Next()
	}
}

// getZipKinTraceId returns the request's x-b3-traceid header, or "" if it is
// absent.
func getZipKinTraceId(ctx *engines.Context) string {
	return ctx.Request.Header.Get("x-b3-traceid")
}
//添加 skywalking func SkyWalkingTracerMiddleware(engine *engines.Engine) engines.HandlerFunc { return func(ctx *engines.Context) { trace.GinMiddleware(engine, trace.GetSkyWalkingTracer())(ctx) } } /** 必须全部绑定,不然抛出异常(同时解码header,使用url_decoder) */ func MustBindHeaderMiddleware(headers ...string) func(ctx *engines.Context) { return func(ctx *engines.Context) { if headers == nil || len(headers) == 0 { ctx.Next() return } result, _ := ctx.Value(util.HeaderData).(map[string]string) if result == nil { result = make(map[string]string, len(headers)) } for _, elem := range headers { header := ctx.GetHeader(elem) if header == "" { _ = ctx.JSON(util.HttpStatus, util.NewFailMessageHttpResponse(fmt.Sprintf("must bind header %s find err", elem))) ctx.Abort() return } result[header] = util.UrlDecode(elem) } ctx.Set(util.HeaderData, result) logger.Infoc(ctx, "[Middleware] MustBindHeaderMiddleware end, bind data: %v", result) ctx.Next() return } } /** 不需要全部绑定 */ func ShouldBindHeaderMiddleware(headers ...string) func(ctx *engines.Context) { return func(ctx *engines.Context) { if headers == nil || len(headers) == 0 { ctx.Next() return } result, _ := ctx.Value(util.HeaderData).(map[string]string) if result == nil { result = make(map[string]string, len(headers)) } for _, elem := range headers { header := ctx.GetHeader(elem) result[header] = util.UrlDecode(elem) } ctx.Set(util.HeaderData, result) logger.Infoc(ctx, "[Middleware] ShouldBindHeaderMiddleware end, bind data: %v", result) ctx.Next() return } }