middleware.go 3.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144
  1. package middleware
  2. import (
  3. "fmt"
  4. "net/http"
  5. "runtime"
  6. "strings"
  7. "time"
  8. "git.shuncheng.lu/bigthing/gocommon/pkg/internal/util"
  9. "git.shuncheng.lu/bigthing/gocommon/pkg/logger"
  10. "git.shuncheng.lu/bigthing/gocommon/pkg/trace"
  11. "git.shuncheng.lu/bigthing/gocommon/pkg/net/engines"
  12. )
  13. // 记录响应时常
  14. func AccessLogMiddleware() engines.HandlerFunc {
  15. return func(c *engines.Context) {
  16. var (
  17. beginTime = time.Now()
  18. statusCode = http.StatusOK
  19. )
  20. defer func() {
  21. // recover 一定要在最上面,同时记录错误的堆栈信息,记住堆栈信息从下往上看是调用过程
  22. if err := recover(); err != nil {
  23. statusCode = http.StatusInternalServerError
  24. stackBuffer := make([]byte, 64<<10) // 最多打印64k的堆栈信息
  25. stackBuffer = stackBuffer[:runtime.Stack(stackBuffer, false)]
  26. logger.Errorc(c, "[AccessLogMiddleware] panic err, path: %s, panic_info: %+v, stack trace:\n %s", c.Request.URL.Path, err, stackBuffer)
  27. }
  28. // spend time
  29. duration := time.Now().Sub(beginTime)
  30. // 超过20s 记录504 todo 变更可以进行设置
  31. if duration.Seconds() >= 20 {
  32. statusCode = http.StatusGatewayTimeout
  33. logger.Errorc(c, "[AccessLogMiddleware] timeout, path: %s, spend: %.4fs", c.Request.URL.Path, duration.Seconds())
  34. }
  35. // 记录access log
  36. logger.AccessInfof(`%s %s %s [%s] "%s %s %s" %d %s %d`,
  37. c.ClientIP(),
  38. "-",
  39. "-",
  40. time.Now().Format("02/Jan/2006:15:04:05 -0700"),
  41. c.Request.Method,
  42. c.Request.URL.Path,
  43. c.Request.Proto,
  44. statusCode,
  45. "0",
  46. int64(duration/time.Millisecond)+1, // 加1毫秒,防止go整除舍去导致记录为0
  47. )
  48. }()
  49. c.Next()
  50. }
  51. }
  52. // 添加trace-id
  53. func AddTraceId() engines.HandlerFunc {
  54. return func(ctx *engines.Context) {
  55. traceIds := make([]string, 0, 2)
  56. skyWalkingTraceId := trace.GinGetTraceId(ctx)
  57. zipKinTraceId := getZipKinTraceId(ctx)
  58. if skyWalkingTraceId != "" {
  59. traceIds = append(traceIds, "s:"+skyWalkingTraceId)
  60. }
  61. if zipKinTraceId != "" {
  62. traceIds = append(traceIds, "z:"+zipKinTraceId)
  63. }
  64. if len(traceIds) == 0 {
  65. traceIds = append(traceIds, logger.GenerateTraceId())
  66. }
  67. ctx.Set(logger.TraceIdKey, strings.Join(traceIds, ","))
  68. ctx.Next()
  69. }
  70. }
  71. func getZipKinTraceId(ctx *engines.Context) string {
  72. return ctx.Request.Header.Get("x-b3-traceid")
  73. }
  74. //添加 skywalking
  75. func SkyWalkingTracerMiddleware(engine *engines.Engine) engines.HandlerFunc {
  76. return func(ctx *engines.Context) {
  77. trace.GinMiddleware(engine, trace.GetSkyWalkingTracer())(ctx)
  78. }
  79. }
  80. /**
  81. 必须全部绑定,不然抛出异常(同时解码header,使用url_decoder)
  82. */
  83. func MustBindHeaderMiddleware(headers ...string) func(ctx *engines.Context) {
  84. return func(ctx *engines.Context) {
  85. if headers == nil || len(headers) == 0 {
  86. ctx.Next()
  87. return
  88. }
  89. result, _ := ctx.Value(util.HeaderData).(map[string]string)
  90. if result == nil {
  91. result = make(map[string]string, len(headers))
  92. }
  93. for _, elem := range headers {
  94. header := ctx.GetHeader(elem)
  95. if header == "" {
  96. _ = ctx.JSON(util.HttpStatus, util.NewFailMessageHttpResponse(fmt.Sprintf("must bind header %s find err", elem)))
  97. ctx.Abort()
  98. return
  99. }
  100. result[header] = util.UrlDecode(elem)
  101. }
  102. ctx.Set(util.HeaderData, result)
  103. logger.Infoc(ctx, "[Middleware] MustBindHeaderMiddleware end, bind data: %v", result)
  104. ctx.Next()
  105. return
  106. }
  107. }
  108. /**
  109. 不需要全部绑定
  110. */
  111. func ShouldBindHeaderMiddleware(headers ...string) func(ctx *engines.Context) {
  112. return func(ctx *engines.Context) {
  113. if headers == nil || len(headers) == 0 {
  114. ctx.Next()
  115. return
  116. }
  117. result, _ := ctx.Value(util.HeaderData).(map[string]string)
  118. if result == nil {
  119. result = make(map[string]string, len(headers))
  120. }
  121. for _, elem := range headers {
  122. header := ctx.GetHeader(elem)
  123. result[header] = util.UrlDecode(elem)
  124. }
  125. ctx.Set(util.HeaderData, result)
  126. logger.Infoc(ctx, "[Middleware] ShouldBindHeaderMiddleware end, bind data: %v", result)
  127. ctx.Next()
  128. return
  129. }
  130. }