diff --git a/config/config.go b/config/config.go index a056938..78407bf 100644 --- a/config/config.go +++ b/config/config.go @@ -124,4 +124,4 @@ func splitString(s string) []string { array = append(array, strings.TrimSpace(v)) } return array -} +} \ No newline at end of file diff --git a/internal/middleware/healthcheck.go b/internal/middleware/healthcheck.go new file mode 100644 index 0000000..058ef91 --- /dev/null +++ b/internal/middleware/healthcheck.go @@ -0,0 +1,94 @@ +package middleware + +import ( + "context" + "errors" + "encoding/json" + "fmt" + "log/slog" + "net/http" + "time" + "server/config" + "strconv" + "github.com/dicedb/dicedb-go" + "github.com/gin-gonic/gin" + "server/internal/db" +) + +// HealthCheckMiddleware is a middleware that performs a health check on the server +// and applies rate limiting if necessary using the RateLimiterMiddleware. +type ( + HealthCheckMiddleware struct { + client *db.DiceDB + limit int64 + window float64 + cronFrequencyInterval time.Duration + } +) + +// NewHealthCheckMiddleware creates a new instance of HealthCheckMiddleware. +func NewHealthCheckMiddleware(client *db.DiceDB, limit int64, window float64) *HealthCheckMiddleware { + // Initialize RateLimiterMiddleware + h := &HealthCheckMiddleware{ + client: client, + limit: limit, + window: window, + cronFrequencyInterval: config.LoadConfig().Server.CronCleanupFrequency, + } + return h +} + +// Exec handles the health check request. +func (h *HealthCheckMiddleware) Exec(c *gin.Context) { + // Only allow rate limiting for specific paths, here health check path + + if c.Request.URL.Path != "/health" { + // If the path is not "/health", return immediately without further processing + c.Next() + return + } + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + currentWindow := time.Now().Unix() / int64(h.window) + key := fmt.Sprintf("request_count:%d", currentWindow) + slog.Debug("Created rate limiter key", slog.Any("key", key)) + + // Get the current request count for this window + val, err := h.client.Client.Get(ctx, key).Result() + if err != nil && !errors.Is(err, dicedb.Nil) { + slog.Error("Error fetching request count", "error", err) + http.Error(c.Writer, "Internal Server Error", http.StatusInternalServerError) + return + } + + // Parse the current request count or initialize to 0 + var requestCount int64 = 0 + if val != "" { + requestCount, err = strconv.ParseInt(val, 10, 64) + if err != nil { + slog.Error("Error converting request count", "error", err) + http.Error(c.Writer, "Internal Server Error", http.StatusInternalServerError) + return + } + } + slog.Info("Fetched and parsed request count successfully", "key", key, "requestCount", requestCount) + + secondsDifference, err := calculateNextCleanupTime(ctx, h.client, h.cronFrequencyInterval) + if err != nil { + slog.Error("Error calculating next cleanup time", "error", err) + } + + AddRateLimitHeaders(c.Writer, h.limit, h.limit-requestCount, requestCount, currentWindow+int64(h.window), + secondsDifference) + + + c.Writer.Header().Set("Content-Type", "application/json") + c.Writer.WriteHeader(http.StatusOK) + if err := json.NewEncoder(c.Writer).Encode(map[string]string{"message": "server is running"}); err != nil { + http.Error(c.Writer, err.Error(), http.StatusInternalServerError) + } + + c.Next() +} + diff --git a/internal/middleware/ratelimiter.go b/internal/middleware/ratelimiter.go index 4d8a01b..c3ef6d6 100644 --- a/internal/middleware/ratelimiter.go +++ b/internal/middleware/ratelimiter.go @@ -79,7 +79,7 @@ func (rl *RateLimiterMiddleware) Exec(c *gin.Context) { // Check if the request count exceeds the limit if requestCount >= rl.limit { slog.Warn("Request limit exceeded", "count", requestCount) - addRateLimitHeaders(c.Writer, rl.limit, rl.limit-(requestCount+1), requestCount+1, currentWindow+int64(rl.window), 0) + AddRateLimitHeaders(c.Writer, rl.limit, rl.limit-(requestCount+1), requestCount+1, currentWindow+int64(rl.window), 0) http.Error(c.Writer, "429 - Too Many Requests", http.StatusTooManyRequests) return } @@ -103,7 +103,7 @@ func (rl *RateLimiterMiddleware) Exec(c *gin.Context) { slog.Error("Error calculating next cleanup time", "error", err) } - addRateLimitHeaders(c.Writer, rl.limit, rl.limit-(requestCount+1), requestCount+1, currentWindow+int64(rl.window), + AddRateLimitHeaders(c.Writer, rl.limit, rl.limit-(requestCount+1), requestCount+1, currentWindow+int64(rl.window), secondsDifference) slog.Info("Request processed", "count", requestCount+1) @@ -173,7 +173,7 @@ func MockRateLimiter(client *mock.DiceDBMock, next http.Handler, limit int64, wi // Check if the request limit has been exceeded if requestCount >= limit { slog.Warn("Request limit exceeded", "count", requestCount) - addRateLimitHeaders(w, limit, limit-(requestCount+1), requestCount+1, currentWindow+int64(window), 0) + AddRateLimitHeaders(w, limit, limit-(requestCount+1), requestCount+1, currentWindow+int64(window), 0) http.Error(w, "429 - Too Many Requests", http.StatusTooManyRequests) return } @@ -194,14 +194,14 @@ func MockRateLimiter(client *mock.DiceDBMock, next http.Handler, limit int64, wi } } - addRateLimitHeaders(w, limit, limit-(requestCount+1), requestCount+1, currentWindow+int64(window), 0) + AddRateLimitHeaders(w, limit, limit-(requestCount+1), requestCount+1, currentWindow+int64(window), 0) slog.Info("Request processed", "count", requestCount) next.ServeHTTP(w, r) }) } -func addRateLimitHeaders(w http.ResponseWriter, limit, remaining, used, resetTime, secondsLeftForCleanup int64) { +func AddRateLimitHeaders(w http.ResponseWriter, limit, remaining, used, resetTime, secondsLeftForCleanup int64) { w.Header().Set("x-ratelimit-limit", strconv.FormatInt(limit, 10)) w.Header().Set("x-ratelimit-remaining", strconv.FormatInt(remaining, 10)) w.Header().Set("x-ratelimit-used", strconv.FormatInt(used, 10)) diff --git a/internal/server/http.go b/internal/server/http.go index b9c6d03..368a38a 100644 --- a/internal/server/http.go +++ b/internal/server/http.go @@ -80,10 +80,6 @@ func (s *HTTPServer) Shutdown() error { return s.httpServer.Shutdown(context.Background()) } -func (s *HTTPServer) HealthCheck(w http.ResponseWriter, request *http.Request) { - util.JSONResponse(w, http.StatusOK, map[string]string{"message": "server is running"}) -} - func (s *HTTPServer) CliHandler(w http.ResponseWriter, r *http.Request) { diceCmd, err := util.ParseHTTPRequest(r) if err != nil { diff --git a/main.go b/main.go index 1106be4..d703df6 100644 --- a/main.go +++ b/main.go @@ -63,6 +63,12 @@ func main() { configValue.Server.RequestWindowSec, ).Exec)) + // HealthCheck middleware + router.Use((middleware.NewHealthCheckMiddleware(diceDBAdminClient, + configValue.Server.RequestLimitPerMin, + configValue.Server.RequestWindowSec, + ).Exec)) + httpServer := server.NewHTTPServer( router, diceDBAdminClient, @@ -72,7 +78,6 @@ func main() { ) // Register routes - router.GET("/health", gin.WrapF(httpServer.HealthCheck)) router.POST("/shell/exec/:cmd", gin.WrapF(httpServer.CliHandler)) router.GET("/search", gin.WrapF(httpServer.SearchHandler))