diff --git a/docs/api/fiber.md b/docs/api/fiber.md index 07d24dcc58..2c50339d50 100644 --- a/docs/api/fiber.md +++ b/docs/api/fiber.md @@ -110,6 +110,7 @@ app.Listen(":8080", fiber.ListenConfig{ | EnablePrefork | `bool` | When set to true, this will spawn multiple Go processes listening on the same port. | `false` | | EnablePrintRoutes | `bool` | If set to true, will print all routes with their method, path, and handler. | `false` | | GracefulContext | `context.Context` | Field to shutdown Fiber by given context gracefully. | `nil` | +| ShutdownTimeout | `time.Duration` | Specifies the maximum duration to wait for the server to gracefully shutdown. When the timeout is reached, the graceful shutdown process is interrupted and forcibly terminated, and the `context.DeadlineExceeded` error is passed to the `OnShutdownError` callback. Set to 0 to disable the timeout and wait indefinitely. | `10 * time.Second` | | ListenerAddrFunc | `func(addr net.Addr)` | Allows accessing and customizing `net.Listener`. | `nil` | | ListenerNetwork | `string` | Known networks are "tcp", "tcp4" (IPv4-only), "tcp6" (IPv6-only). WARNING: When prefork is set to true, only "tcp4" and "tcp6" can be chosen. | `tcp4` | | OnShutdownError | `func(err error)` | Allows to customize error behavior when gracefully shutting down the server by given signal. Prints error with `log.Fatalf()` | `nil` | diff --git a/listen.go b/listen.go index 0df4e1a060..624b2539f3 100644 --- a/listen.go +++ b/listen.go @@ -18,6 +18,7 @@ import ( "strconv" "strings" "text/tabwriter" + "time" "github.com/gofiber/fiber/v3/log" "github.com/mattn/go-colorable" @@ -37,8 +38,6 @@ const ( ) // ListenConfig is a struct to customize startup of Fiber. -// -// TODO: Add timeout for graceful shutdown. type ListenConfig struct { // GracefulContext is a field to shutdown Fiber by given context gracefully. // @@ -94,6 +93,13 @@ type ListenConfig struct { // Default : "" CertClientFile string `json:"cert_client_file"` + // When the graceful shutdown begins, use this field to set the timeout + // duration. If the timeout is reached, OnShutdownError will be called. + // Set to 0 to disable the timeout and wait indefinitely. + // + // Default: 10 * time.Second + ShutdownTimeout time.Duration `json:"shutdown_timeout"` + // When set to true, it will not print out the «Fiber» ASCII art and listening address. // // Default: false @@ -116,8 +122,9 @@ func listenConfigDefault(config ...ListenConfig) ListenConfig { return ListenConfig{ ListenerNetwork: NetworkTCP4, OnShutdownError: func(err error) { - log.Fatalf("shutdown: %v", err) //nolint:revive // It's an optipn + log.Fatalf("shutdown: %v", err) //nolint:revive // It's an option }, + ShutdownTimeout: 10 * time.Second, } } @@ -128,7 +135,7 @@ func listenConfigDefault(config ...ListenConfig) ListenConfig { if cfg.OnShutdownError == nil { cfg.OnShutdownError = func(err error) { - log.Fatalf("shutdown: %v", err) //nolint:revive // It's an optipn + log.Fatalf("shutdown: %v", err) //nolint:revive // It's an option } } @@ -472,8 +479,17 @@ func (app *App) printRoutesMessage() { func (app *App) gracefulShutdown(ctx context.Context, cfg ListenConfig) { <-ctx.Done() - if err := app.Shutdown(); err != nil { //nolint:contextcheck // TODO: Implement it + var err error + + if cfg.ShutdownTimeout != 0 { + err = app.ShutdownWithTimeout(cfg.ShutdownTimeout) //nolint:contextcheck // TODO: Implement it + } else { + err = app.Shutdown() //nolint:contextcheck // TODO: Implement it + } + + if err != nil { cfg.OnShutdownError(err) + return } if success := cfg.OnShutdownSuccess; success != nil { diff --git a/listen_test.go b/listen_test.go index c828a911cb..123cf2b3b8 100644 --- a/listen_test.go +++ b/listen_test.go @@ -85,7 +85,7 @@ func Test_Listen_Graceful_Shutdown(t *testing.T) { ExpectedStatusCode int }{ {Time: 500 * time.Millisecond, ExpectedBody: "example.com", ExpectedStatusCode: StatusOK, ExpectedErr: nil}, - {Time: 3 * time.Second, ExpectedBody: "", ExpectedStatusCode: StatusOK, ExpectedErr: errors.New("InmemoryListener is already closed: use of closed network connection")}, + {Time: 3 * time.Second, ExpectedBody: "", ExpectedStatusCode: StatusOK, ExpectedErr: fasthttputil.ErrInmemoryListenerClosed}, } for _, tc := range testCases { @@ -115,6 +115,118 @@ func Test_Listen_Graceful_Shutdown(t *testing.T) { mu.Unlock() } +// go test -run Test_Listen_Graceful_Shutdown_Timeout +func Test_Listen_Graceful_Shutdown_Timeout(t *testing.T) { + var mu sync.Mutex + var shutdownSuccess bool + var shutdownTimeoutError error + + app := New() + + app.Get("/", func(c Ctx) error { + return c.SendString(c.Hostname()) + }) + + ln := fasthttputil.NewInmemoryListener() + errs := make(chan error) + + go func() { + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + + errs <- app.Listener(ln, ListenConfig{ + DisableStartupMessage: true, + GracefulContext: ctx, + ShutdownTimeout: 500 * time.Millisecond, + OnShutdownSuccess: func() { + mu.Lock() + shutdownSuccess = true + mu.Unlock() + }, + OnShutdownError: func(err error) { + mu.Lock() + shutdownTimeoutError = err + mu.Unlock() + }, + }) + }() + + // Server readiness check + for i := 0; i < 10; i++ { + conn, err := ln.Dial() + // To test a graceful shutdown timeout, do not close the connection. + if err == nil { + _ = conn + break + } + // Wait a bit before retrying + time.Sleep(100 * time.Millisecond) + if i == 9 { + t.Fatalf("Server did not become ready in time: %v", err) + } + } + + testCases := []struct { + ExpectedErr error + ExpectedShutdownError error + ExpectedBody string + Time time.Duration + ExpectedStatusCode int + ExpectedShutdownSuccess bool + }{ + { + Time: 100 * time.Millisecond, + ExpectedBody: "example.com", + ExpectedStatusCode: StatusOK, + ExpectedErr: nil, + ExpectedShutdownError: nil, + ExpectedShutdownSuccess: false, + }, + { + Time: 3 * time.Second, + ExpectedBody: "", + ExpectedStatusCode: StatusOK, + ExpectedErr: fasthttputil.ErrInmemoryListenerClosed, + ExpectedShutdownError: context.DeadlineExceeded, + ExpectedShutdownSuccess: false, + }, + } + + for _, tc := range testCases { + time.Sleep(tc.Time) + + req := fasthttp.AcquireRequest() + req.SetRequestURI("http://example.com") + + client := fasthttp.HostClient{} + client.Dial = func(_ string) (net.Conn, error) { return ln.Dial() } + + resp := fasthttp.AcquireResponse() + err := client.Do(req, resp) + + if err == nil { + require.NoError(t, err) + require.Equal(t, tc.ExpectedStatusCode, resp.StatusCode()) + require.Equal(t, tc.ExpectedBody, string(resp.Body())) + } else { + require.ErrorIs(t, err, tc.ExpectedErr) + } + + mu.Lock() + require.Equal(t, tc.ExpectedShutdownSuccess, shutdownSuccess) + require.Equal(t, tc.ExpectedShutdownError, shutdownTimeoutError) + mu.Unlock() + + fasthttp.ReleaseRequest(req) + fasthttp.ReleaseResponse(resp) + } + + mu.Lock() + err := <-errs + require.NoError(t, err) + mu.Unlock() +} + // go test -run Test_Listen_Prefork func Test_Listen_Prefork(t *testing.T) { testPreforkMaster = true