diff --git a/logger.go b/logger.go index c894c60..bef3ff9 100644 --- a/logger.go +++ b/logger.go @@ -1,10 +1,12 @@ package logger import ( + "bytes" "io" "net/http" "os" "regexp" + "strings" "time" "github.com/gin-gonic/gin" @@ -30,6 +32,26 @@ type config struct { clientErrorLevel zerolog.Level // the log level used for request with status code >= 500 serverErrorLevel zerolog.Level + // whether to log response body for request with status code >= 400 + logErrorResponseBody bool + // whether to log response body for request with status code < 400 + logResponseBody bool + // max len of response body message (whatever the status code) + maxResponseBodyLen int + // whether to log request body + logRequestBody bool + // max len of log request body + maxRequestBodyLen int +} + +type bodyLogWriter struct { + gin.ResponseWriter + body *bytes.Buffer +} + +func (w bodyLogWriter) Write(b []byte) (int, error) { + w.body.Write(b) + return w.ResponseWriter.Write(b) } var isTerm bool = isatty.IsTerminal(os.Stdout.Fd()) @@ -37,10 +59,15 @@ var isTerm bool = isatty.IsTerminal(os.Stdout.Fd()) // SetLogger initializes the logging middleware. func SetLogger(opts ...Option) gin.HandlerFunc { cfg := &config{ - defaultLevel: zerolog.InfoLevel, - clientErrorLevel: zerolog.WarnLevel, - serverErrorLevel: zerolog.ErrorLevel, - output: gin.DefaultWriter, + defaultLevel: zerolog.InfoLevel, + clientErrorLevel: zerolog.WarnLevel, + serverErrorLevel: zerolog.ErrorLevel, + output: gin.DefaultWriter, + logErrorResponseBody: false, + logResponseBody: false, + maxResponseBodyLen: 50, + logRequestBody: false, + maxRequestBodyLen: 50, } // Loop through each option @@ -80,6 +107,26 @@ func SetLogger(opts ...Option) gin.HandlerFunc { path = path + "?" + raw } + var blw *bodyLogWriter + if cfg.logErrorResponseBody || cfg.logResponseBody { + blw = &bodyLogWriter{body: bytes.NewBufferString(""), ResponseWriter: c.Writer} + c.Writer = blw + } + + var requestBody string + if cfg.logRequestBody && c.Request.Body != nil { + body, err := io.ReadAll(c.Request.Body) + if err != nil { + body = []byte(err.Error()) + } + + requestBody = string(body) + if len(requestBody) > cfg.maxRequestBodyLen { + requestBody = requestBody[:cfg.maxRequestBodyLen] + "..." + } + c.Request.Body = io.NopCloser(bytes.NewReader(body)) + } + c.Next() track := true @@ -100,10 +147,29 @@ func SetLogger(opts ...Option) gin.HandlerFunc { } latency := end.Sub(start) - l = l.With(). - Int("status", c.Writer.Status()). + statusCode := c.Writer.Status() + var response string + withResponse := (cfg.logErrorResponseBody && statusCode >= 400) || (cfg.logResponseBody && statusCode < 400) + if withResponse && blw.body != nil { + response = blw.body.String() + response = strings.TrimPrefix(response, "\"") + response = strings.TrimSuffix(response, "\"") + if len(response) > cfg.maxResponseBodyLen { + response = response[:cfg.maxResponseBodyLen] + "..." + } + } + + ctx := l.With(). + Int("status", statusCode). Str("method", c.Request.Method). - Str("path", c.Request.URL.Path). + Str("path", c.Request.URL.Path) + if cfg.logRequestBody { + ctx = ctx.Logger().With().Str("body", requestBody) + } + if withResponse { + ctx = ctx.Logger().With().Str("response", response) + } + l = ctx.Logger().With(). Str("ip", c.ClientIP()). Dur("latency", latency). Str("user_agent", c.Request.UserAgent()).Logger() @@ -114,12 +180,12 @@ func SetLogger(opts ...Option) gin.HandlerFunc { } switch { - case c.Writer.Status() >= http.StatusBadRequest && c.Writer.Status() < http.StatusInternalServerError: + case statusCode >= http.StatusBadRequest && statusCode < http.StatusInternalServerError: { l.WithLevel(cfg.clientErrorLevel). Msg(msg) } - case c.Writer.Status() >= http.StatusInternalServerError: + case statusCode >= http.StatusInternalServerError: { l.WithLevel(cfg.serverErrorLevel). Msg(msg) diff --git a/logger_test.go b/logger_test.go index a9cd711..79bb476 100644 --- a/logger_test.go +++ b/logger_test.go @@ -5,6 +5,7 @@ import ( "net/http" "net/http/httptest" "regexp" + "strings" "testing" "github.com/gin-gonic/gin" @@ -18,7 +19,16 @@ type header struct { } func performRequest(r http.Handler, method, path string, headers ...header) *httptest.ResponseRecorder { - req := httptest.NewRequest(method, path, nil) + return performRequestWithBody(r, method, path, "", headers...) +} + +func performRequestWithBody(r http.Handler, method, path string, body string, headers ...header) *httptest.ResponseRecorder { + var req *http.Request + if body != "" { + req = httptest.NewRequest(method, path, bytes.NewBuffer([]byte(body))) + } else { + req = httptest.NewRequest(method, path, nil) + } for _, h := range headers { req.Header.Add(h.Key, h.Value) } @@ -183,6 +193,113 @@ func TestLoggerParseLevel(t *testing.T) { } } +func TestLoggerWithErrorResponse(t *testing.T) { + buffer := new(bytes.Buffer) + gin.SetMode(gin.ReleaseMode) + r := gin.New() + r.Use(SetLogger(WithWriter(buffer), WithLogErrorResponseBody(true))) + r.GET("/example", func(c *gin.Context) {}) + r.POST("/example", func(c *gin.Context) { + c.String(http.StatusBadRequest, "bad status") + }) + + performRequest(r, "GET", "/example?a=100") + assert.NotContains(t, buffer.String(), "response= ") + + buffer.Reset() + performRequest(r, "POST", "/example?a=100") + assert.Contains(t, buffer.String(), "response=") + assert.Contains(t, buffer.String(), "\"bad status\"") +} + +func TestLoggerWithResponse(t *testing.T) { + buffer := new(bytes.Buffer) + gin.SetMode(gin.ReleaseMode) + r := gin.New() + r.Use(SetLogger(WithWriter(buffer), WithLogResponseBody(true))) + r.GET("/example", func(c *gin.Context) {}) + r.POST("/example", func(c *gin.Context) { + c.String(http.StatusOK, "example response") + }) + + performRequest(r, "GET", "/example?a=100") + assert.Contains(t, buffer.String(), "response=") + + buffer.Reset() + performRequest(r, "POST", "/example?a=100") + assert.Contains(t, buffer.String(), "response=") + assert.Contains(t, buffer.String(), "\"example response\"") +} + +func TestLoggerWithTruncatedResponse(t *testing.T) { + longMessage := strings.Repeat("X", 20) + truncatedMessage := strings.Repeat("X", 10) + "..." + buffer := new(bytes.Buffer) + gin.SetMode(gin.ReleaseMode) + r := gin.New() + r.Use(SetLogger(WithWriter(buffer), WithLogErrorResponseBody(true), WithLogResponseBody(true), WithMaxResponseBodyLen(10))) + r.GET("/example", func(c *gin.Context) { + c.String(http.StatusBadRequest, longMessage) + }) + r.POST("/example", func(c *gin.Context) { + // c.String(http.StatusOK, strings.Repeat("X", 20)) + c.String(http.StatusOK, longMessage) + }) + + performRequest(r, "GET", "/example?a=100") + assert.Contains(t, buffer.String(), "response=") + assert.Contains(t, buffer.String(), truncatedMessage) + + buffer.Reset() + performRequest(r, "POST", "/example?a=100") + assert.Contains(t, buffer.String(), "response=") + assert.Contains(t, buffer.String(), truncatedMessage) +} + +func TestLoggerWithRequest(t *testing.T) { + buffer := new(bytes.Buffer) + gin.SetMode(gin.ReleaseMode) + r := gin.New() + r.Use(SetLogger(WithWriter(buffer), WithLogRequestBody(true))) + r.GET("/example", func(c *gin.Context) {}) + r.POST("/example", func(c *gin.Context) { + c.String(http.StatusOK, "example response") + }) + + performRequestWithBody(r, "GET", "/example?a=100", "GET body") + assert.Contains(t, buffer.String(), "body=\"GET body\"") + + buffer.Reset() + performRequestWithBody(r, "POST", "/example?a=100", "POST body") + assert.Contains(t, buffer.String(), "body=\"POST body\"") + + buffer.Reset() + longBody := strings.Repeat("X", 20) + performRequestWithBody(r, "POST", "/example?a=100", longBody) + assert.Contains(t, buffer.String(), "body="+longBody+" ") +} + +func TestLoggerWithTruncatedRequest(t *testing.T) { + longBody := strings.Repeat("X", 20) + truncatedBody := strings.Repeat("X", 10) + "..." + + buffer := new(bytes.Buffer) + gin.SetMode(gin.ReleaseMode) + r := gin.New() + r.Use(SetLogger(WithWriter(buffer), WithLogRequestBody(true), WithMaxRequestBodyLen(10))) + r.GET("/example", func(c *gin.Context) {}) + r.POST("/example", func(c *gin.Context) { + c.String(http.StatusOK, "example response") + }) + + performRequestWithBody(r, "GET", "/example?a=100", longBody) + assert.Contains(t, buffer.String(), "body="+truncatedBody+"") + + buffer.Reset() + performRequestWithBody(r, "POST", "/example?a=100", longBody) + assert.Contains(t, buffer.String(), "body="+truncatedBody+"") +} + func BenchmarkLogger(b *testing.B) { gin.SetMode(gin.ReleaseMode) r := gin.New() diff --git a/options.go b/options.go index 362e4db..009bb37 100644 --- a/options.go +++ b/options.go @@ -74,3 +74,37 @@ func WithServerErrorLevel(lvl zerolog.Level) Option { c.serverErrorLevel = lvl }) } + +func WithLogErrorResponseBody(logErrorResponseBody bool) Option { + return optionFunc(func(c *config) { + c.logErrorResponseBody = logErrorResponseBody + }) +} + +func WithLogResponseBody(logResponseBody bool) Option { + return optionFunc(func(c *config) { + c.logResponseBody = logResponseBody + }) +} + +func WithMaxResponseBodyLen(maxResponseBodyLen int) Option { + return optionFunc(func(c *config) { + if maxResponseBodyLen > 0 { + c.maxResponseBodyLen = maxResponseBodyLen + } + }) +} + +func WithLogRequestBody(logRequestBody bool) Option { + return optionFunc(func(c *config) { + c.logRequestBody = logRequestBody + }) +} + +func WithMaxRequestBodyLen(maxRequestBodyLen int) Option { + return optionFunc(func(c *config) { + if maxRequestBodyLen > 0 { + c.maxRequestBodyLen = maxRequestBodyLen + } + }) +}