Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Health middleware #58

Open
wants to merge 27 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -124,4 +124,4 @@ func splitString(s string) []string {
array = append(array, strings.TrimSpace(v))
}
return array
}
}
94 changes: 94 additions & 0 deletions internal/middleware/healthcheck.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
package middleware

import (
"context"
"errors"
"encoding/json"
"fmt"
"log/slog"
"net/http"
"time"
"server/config"
"strconv"
"github.com/dicedb/dicedb-go"
"github.com/gin-gonic/gin"
"server/internal/db"
)

// HealthCheckMiddleware is a middleware that performs a health check on the server
// and applies rate limiting if necessary using the RateLimiterMiddleware.
type (
HealthCheckMiddleware struct {
client *db.DiceDB
limit int64
window float64
cronFrequencyInterval time.Duration
}
)

// NewHealthCheckMiddleware creates a new instance of HealthCheckMiddleware.
func NewHealthCheckMiddleware(client *db.DiceDB, limit int64, window float64) *HealthCheckMiddleware {
// Initialize RateLimiterMiddleware
h := &HealthCheckMiddleware{
client: client,
limit: limit,
window: window,
cronFrequencyInterval: config.LoadConfig().Server.CronCleanupFrequency,
}
return h
}

// Exec handles the health check request.
func (h *HealthCheckMiddleware) Exec(c *gin.Context) {
// Only allow rate limiting for specific paths, here health check path

if c.Request.URL.Path != "/health" {
// If the path is not "/health", return immediately without further processing
c.Next()
return
}
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()

currentWindow := time.Now().Unix() / int64(h.window)
key := fmt.Sprintf("request_count:%d", currentWindow)
slog.Debug("Created rate limiter key", slog.Any("key", key))

// Get the current request count for this window
val, err := h.client.Client.Get(ctx, key).Result()
if err != nil && !errors.Is(err, dicedb.Nil) {
slog.Error("Error fetching request count", "error", err)
http.Error(c.Writer, "Internal Server Error", http.StatusInternalServerError)
return
}

// Parse the current request count or initialize to 0
var requestCount int64 = 0
if val != "" {
requestCount, err = strconv.ParseInt(val, 10, 64)
if err != nil {
slog.Error("Error converting request count", "error", err)
http.Error(c.Writer, "Internal Server Error", http.StatusInternalServerError)
return
}
}
slog.Info("Fetched and parsed request count successfully", "key", key, "requestCount", requestCount)

secondsDifference, err := calculateNextCleanupTime(ctx, h.client, h.cronFrequencyInterval)
if err != nil {
slog.Error("Error calculating next cleanup time", "error", err)
}

AddRateLimitHeaders(c.Writer, h.limit, h.limit-requestCount, requestCount, currentWindow+int64(h.window),
secondsDifference)


c.Writer.Header().Set("Content-Type", "application/json")
c.Writer.WriteHeader(http.StatusOK)
if err := json.NewEncoder(c.Writer).Encode(map[string]string{"message": "server is running"}); err != nil {
http.Error(c.Writer, err.Error(), http.StatusInternalServerError)
}

c.Next()
}

10 changes: 5 additions & 5 deletions internal/middleware/ratelimiter.go
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ func (rl *RateLimiterMiddleware) Exec(c *gin.Context) {
// Check if the request count exceeds the limit
if requestCount >= rl.limit {
slog.Warn("Request limit exceeded", "count", requestCount)
addRateLimitHeaders(c.Writer, rl.limit, rl.limit-(requestCount+1), requestCount+1, currentWindow+int64(rl.window), 0)
AddRateLimitHeaders(c.Writer, rl.limit, rl.limit-(requestCount+1), requestCount+1, currentWindow+int64(rl.window), 0)
http.Error(c.Writer, "429 - Too Many Requests", http.StatusTooManyRequests)
return
}
Expand All @@ -103,7 +103,7 @@ func (rl *RateLimiterMiddleware) Exec(c *gin.Context) {
slog.Error("Error calculating next cleanup time", "error", err)
}

addRateLimitHeaders(c.Writer, rl.limit, rl.limit-(requestCount+1), requestCount+1, currentWindow+int64(rl.window),
AddRateLimitHeaders(c.Writer, rl.limit, rl.limit-(requestCount+1), requestCount+1, currentWindow+int64(rl.window),
secondsDifference)

slog.Info("Request processed", "count", requestCount+1)
Expand Down Expand Up @@ -173,7 +173,7 @@ func MockRateLimiter(client *mock.DiceDBMock, next http.Handler, limit int64, wi
// Check if the request limit has been exceeded
if requestCount >= limit {
slog.Warn("Request limit exceeded", "count", requestCount)
addRateLimitHeaders(w, limit, limit-(requestCount+1), requestCount+1, currentWindow+int64(window), 0)
AddRateLimitHeaders(w, limit, limit-(requestCount+1), requestCount+1, currentWindow+int64(window), 0)
http.Error(w, "429 - Too Many Requests", http.StatusTooManyRequests)
return
}
Expand All @@ -194,14 +194,14 @@ func MockRateLimiter(client *mock.DiceDBMock, next http.Handler, limit int64, wi
}
}

addRateLimitHeaders(w, limit, limit-(requestCount+1), requestCount+1, currentWindow+int64(window), 0)
AddRateLimitHeaders(w, limit, limit-(requestCount+1), requestCount+1, currentWindow+int64(window), 0)

slog.Info("Request processed", "count", requestCount)
next.ServeHTTP(w, r)
})
}

func addRateLimitHeaders(w http.ResponseWriter, limit, remaining, used, resetTime, secondsLeftForCleanup int64) {
func AddRateLimitHeaders(w http.ResponseWriter, limit, remaining, used, resetTime, secondsLeftForCleanup int64) {
w.Header().Set("x-ratelimit-limit", strconv.FormatInt(limit, 10))
w.Header().Set("x-ratelimit-remaining", strconv.FormatInt(remaining, 10))
w.Header().Set("x-ratelimit-used", strconv.FormatInt(used, 10))
Expand Down
4 changes: 0 additions & 4 deletions internal/server/http.go
Original file line number Diff line number Diff line change
Expand Up @@ -80,10 +80,6 @@ func (s *HTTPServer) Shutdown() error {
return s.httpServer.Shutdown(context.Background())
}

func (s *HTTPServer) HealthCheck(w http.ResponseWriter, request *http.Request) {
util.JSONResponse(w, http.StatusOK, map[string]string{"message": "server is running"})
}

func (s *HTTPServer) CliHandler(w http.ResponseWriter, r *http.Request) {
diceCmd, err := util.ParseHTTPRequest(r)
if err != nil {
Expand Down
7 changes: 6 additions & 1 deletion main.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,12 @@ func main() {
configValue.Server.RequestWindowSec,
).Exec))

// HealthCheck middleware
router.Use((middleware.NewHealthCheckMiddleware(diceDBAdminClient,
configValue.Server.RequestLimitPerMin,
configValue.Server.RequestWindowSec,
).Exec))

httpServer := server.NewHTTPServer(
router,
diceDBAdminClient,
Expand All @@ -72,7 +78,6 @@ func main() {
)

// Register routes
router.GET("/health", gin.WrapF(httpServer.HealthCheck))
router.POST("/shell/exec/:cmd", gin.WrapF(httpServer.CliHandler))
router.GET("/search", gin.WrapF(httpServer.SearchHandler))

Expand Down
Loading