| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144 |
- package middleware
- import (
- "fmt"
- "net/http"
- "runtime"
- "strings"
- "time"
- "git.shuncheng.lu/bigthing/gocommon/pkg/internal/util"
- "git.shuncheng.lu/bigthing/gocommon/pkg/logger"
- "git.shuncheng.lu/bigthing/gocommon/pkg/trace"
- "git.shuncheng.lu/bigthing/gocommon/pkg/net/engines"
- )
- // 记录响应时常
- func AccessLogMiddleware() engines.HandlerFunc {
- return func(c *engines.Context) {
- var (
- beginTime = time.Now()
- statusCode = http.StatusOK
- )
- defer func() {
- // recover 一定要在最上面,同时记录错误的堆栈信息,记住堆栈信息从下往上看是调用过程
- if err := recover(); err != nil {
- statusCode = http.StatusInternalServerError
- stackBuffer := make([]byte, 64<<10) // 最多打印64k的堆栈信息
- stackBuffer = stackBuffer[:runtime.Stack(stackBuffer, false)]
- logger.Errorc(c, "[AccessLogMiddleware] panic err, path: %s, panic_info: %+v, stack trace:\n %s", c.Request.URL.Path, err, stackBuffer)
- }
- // spend time
- duration := time.Now().Sub(beginTime)
- // 超过20s 记录504 todo 变更可以进行设置
- if duration.Seconds() >= 20 {
- statusCode = http.StatusGatewayTimeout
- logger.Errorc(c, "[AccessLogMiddleware] timeout, path: %s, spend: %.4fs", c.Request.URL.Path, duration.Seconds())
- }
- // 记录access log
- logger.AccessInfof(`%s %s %s [%s] "%s %s %s" %d %s %d`,
- c.ClientIP(),
- "-",
- "-",
- time.Now().Format("02/Jan/2006:15:04:05 -0700"),
- c.Request.Method,
- c.Request.URL.Path,
- c.Request.Proto,
- statusCode,
- "0",
- int64(duration/time.Millisecond)+1, // 加1毫秒,防止go整除舍去导致记录为0
- )
- }()
- c.Next()
- }
- }
- // 添加trace-id
- func AddTraceId() engines.HandlerFunc {
- return func(ctx *engines.Context) {
- traceIds := make([]string, 0, 2)
- skyWalkingTraceId := trace.GinGetTraceId(ctx)
- zipKinTraceId := getZipKinTraceId(ctx)
- if skyWalkingTraceId != "" {
- traceIds = append(traceIds, "s:"+skyWalkingTraceId)
- }
- if zipKinTraceId != "" {
- traceIds = append(traceIds, "z:"+zipKinTraceId)
- }
- if len(traceIds) == 0 {
- traceIds = append(traceIds, logger.GenerateTraceId())
- }
- ctx.Set(logger.TraceIdKey, strings.Join(traceIds, ","))
- ctx.Next()
- }
- }
- func getZipKinTraceId(ctx *engines.Context) string {
- return ctx.Request.Header.Get("x-b3-traceid")
- }
- //添加 skywalking
- func SkyWalkingTracerMiddleware(engine *engines.Engine) engines.HandlerFunc {
- return func(ctx *engines.Context) {
- trace.GinMiddleware(engine, trace.GetSkyWalkingTracer())(ctx)
- }
- }
- /**
- 必须全部绑定,不然抛出异常(同时解码header,使用url_decoder)
- */
- func MustBindHeaderMiddleware(headers ...string) func(ctx *engines.Context) {
- return func(ctx *engines.Context) {
- if headers == nil || len(headers) == 0 {
- ctx.Next()
- return
- }
- result, _ := ctx.Value(util.HeaderData).(map[string]string)
- if result == nil {
- result = make(map[string]string, len(headers))
- }
- for _, elem := range headers {
- header := ctx.GetHeader(elem)
- if header == "" {
- _ = ctx.JSON(util.HttpStatus, util.NewFailMessageHttpResponse(fmt.Sprintf("must bind header %s find err", elem)))
- ctx.Abort()
- return
- }
- result[header] = util.UrlDecode(elem)
- }
- ctx.Set(util.HeaderData, result)
- logger.Infoc(ctx, "[Middleware] MustBindHeaderMiddleware end, bind data: %v", result)
- ctx.Next()
- return
- }
- }
- /**
- 不需要全部绑定
- */
- func ShouldBindHeaderMiddleware(headers ...string) func(ctx *engines.Context) {
- return func(ctx *engines.Context) {
- if headers == nil || len(headers) == 0 {
- ctx.Next()
- return
- }
- result, _ := ctx.Value(util.HeaderData).(map[string]string)
- if result == nil {
- result = make(map[string]string, len(headers))
- }
- for _, elem := range headers {
- header := ctx.GetHeader(elem)
- result[header] = util.UrlDecode(elem)
- }
- ctx.Set(util.HeaderData, result)
- logger.Infoc(ctx, "[Middleware] ShouldBindHeaderMiddleware end, bind data: %v", result)
- ctx.Next()
- return
- }
- }
|