Skip to content

Commit

Permalink
feat: Added Global rate limiting middleware - Implemented RateLimiter…
Browse files Browse the repository at this point in the history
… middleware to limit requests to the /command endpoint, integrated Dice for tracking request counts and managing rate limits, introduced structured logging with slog for error tracking. Future improvements: dynamic rate limits, improved error handling, and integration with monitoring tools.
  • Loading branch information
RishabhC-137 committed Sep 27, 2024
1 parent 7fec0f0 commit 34e1f27
Show file tree
Hide file tree
Showing 5 changed files with 165 additions and 4 deletions.
42 changes: 42 additions & 0 deletions config/config.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
package config

import (
"os"
"strconv"
)

// Config holds the application configuration
type Config struct {
RedisAddr string
ServerPort string
RequestLimit int // Field for the request limit
RequestWindow int // Field for the time window in seconds
}

// LoadConfig loads the application configuration from environment variables or defaults
func LoadConfig() *Config {
return &Config{
RedisAddr: getEnv("REDIS_ADDR", "localhost:7379"), // Default Redis address
ServerPort: getEnv("SERVER_PORT", ":8080"), // Default server port
RequestLimit: getEnvInt("REQUEST_LIMIT", 1000), // Default request limit
RequestWindow: getEnvInt("REQUEST_WINDOW", 60), // Default request window in seconds
}
}

// getEnv retrieves an environment variable or returns a default value
func getEnv(key, fallback string) string {
if value, exists := os.LookupEnv(key); exists {
return value
}
return fallback
}

// getEnvInt retrieves an environment variable as an integer or returns a default value
func getEnvInt(key string, fallback int) int {
if value, exists := os.LookupEnv(key); exists {
if intValue, err := strconv.Atoi(value); err == nil {
return intValue
}
}
return fallback
}
7 changes: 7 additions & 0 deletions go.mod
Original file line number Diff line number Diff line change
@@ -1,3 +1,10 @@
module server

go 1.22.5

require github.com/dicedb/go-dice v0.0.0-20240820180649-d97f15fca831

require (
github.com/cespare/xxhash/v2 v2.3.0 // indirect
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect
)
6 changes: 6 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs=
github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78=
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc=
github.com/dicedb/go-dice v0.0.0-20240820180649-d97f15fca831 h1:Cqyj9WCtoobN6++bFbDSe27q94SPwJD9Z0wmu+SDRuk=
github.com/dicedb/go-dice v0.0.0-20240820180649-d97f15fca831/go.mod h1:8+VZrr14c2LW8fW4tWZ8Bv3P2lfvlg+PpsSn5cWWuiQ=
74 changes: 73 additions & 1 deletion internal/middleware/ratelimiter.go
Original file line number Diff line number Diff line change
@@ -1,11 +1,83 @@
package middleware

import (
"context"
"fmt"
"log/slog" // Import the slog package for structured logging
"net/http"
"strconv"
"time"

redis "github.com/dicedb/go-dice"
)

func RateLimiter(next http.Handler) http.Handler {
// RateLimiter middleware to limit requests based on a specified limit and duration
func RateLimiter(diceClient *redis.Client, next http.Handler, limit int, window int) http.Handler {

Check failure on line 15 in internal/middleware/ratelimiter.go

View workflow job for this annotation

GitHub Actions / lint

paramTypeCombine: func(diceClient *redis.Client, next http.Handler, limit int, window int) http.Handler could be replaced with func(diceClient *redis.Client, next http.Handler, limit, window int) http.Handler (gocritic)
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()

// Check Redis connection health
if err := diceClient.Ping(ctx).Err(); err != nil {
slog.Error("Redis connection is down", "error", err)
http.Error(w, "Service Unavailable", http.StatusServiceUnavailable)
return
}

// Skip rate limiting for non-command endpoints
if r.URL.Path != "/command" {
next.ServeHTTP(w, r)
return
}

// Get the current time window as a unique key
currentWindow := time.Now().Unix() / int64(window)
key := fmt.Sprintf("request_count:%d", currentWindow)

// Fetch the current request count
val, err := diceClient.Get(ctx, key).Result()
if err != nil && err != redis.Nil {
slog.Error("Error fetching request count", "error", err)
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
return
}

// Initialize request count
requestCount := 0
if val != "" {
requestCount, err = strconv.Atoi(val)
if err != nil {
slog.Error("Error converting request count", "error", err)
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
return
}
}

// Check if the request count exceeds the limit
if requestCount >= limit {
slog.Warn("Request limit exceeded", "count", requestCount)
http.Error(w, "429 - Too Many Requests", http.StatusTooManyRequests)
return
}

// Increment the request count
if _, err := diceClient.Incr(ctx, key).Result(); err != nil {
slog.Error("Error incrementing request count", "error", err)
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
return
}

// Set the key expiry if it's newly created
if requestCount == 0 {
if err := diceClient.Expire(ctx, key, time.Duration(window)*time.Second).Err(); err != nil {
slog.Error("Error setting expiry for request count", "error", err)
}
}

// Log the successful request increment
slog.Info("Request processed", "count", requestCount+1)

// Call the next handler
next.ServeHTTP(w, r)
})
}
40 changes: 37 additions & 3 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,25 +4,45 @@ import (
"context"
"log"
"net/http"
"server/config"
"sync"
"time"

"server/internal/api"
"server/internal/middleware"

redis "github.com/dicedb/go-dice"
)

type HTTPServer struct {
httpServer *http.Server
diceClient *redis.Client
}

func NewHTTPServer(addr string, mux *http.ServeMux) *HTTPServer {
func NewHTTPServer(addr string, mux *http.ServeMux, client *redis.Client) *HTTPServer {
return &HTTPServer{
httpServer: &http.Server{
Addr: addr,
Handler: mux,
ReadHeaderTimeout: 5 * time.Second,
},
diceClient: client,
}
}

func initDiceClient(config *config.Config) (*redis.Client, error) {
client := redis.NewClient(&redis.Options{

Check failure on line 34 in main.go

View workflow job for this annotation

GitHub Actions / lint

importShadow: shadow of imported from 'server/config' package 'config' (gocritic)
Addr: config.RedisAddr,
DialTimeout: 10 * time.Second,
MaxRetries: 10,
})

// Ping the Redis server to verify the connection
if err := client.Ping(context.Background()).Err(); err != nil {
return nil, err
}

return client, nil
}

func (s *HTTPServer) Run(ctx context.Context) error {
Expand All @@ -39,16 +59,30 @@ func (s *HTTPServer) Run(ctx context.Context) error {

<-ctx.Done()
log.Println("Shutting down server...")
return s.Shutdown()
}

func (s *HTTPServer) Shutdown() error {
// Additional cleanup if necessary
if err := s.diceClient.Close(); err != nil {
log.Printf("Failed to close Redis client: %v", err)
}
return s.httpServer.Shutdown(context.Background())
}

func main() {
config := config.LoadConfig()
diceClient, err := initDiceClient(config)

Check failure on line 75 in main.go

View workflow job for this annotation

GitHub Actions / lint

importShadow: shadow of imported from 'server/config' package 'config' (gocritic)
if err != nil {
log.Fatalf("Failed to initialize Redis client: %v", err)
}

mux := http.NewServeMux()

mux.Handle("/", middleware.RateLimiter(http.HandlerFunc(api.HealthCheck)))
mux.Handle("/", middleware.RateLimiter(diceClient, http.HandlerFunc(api.HealthCheck), config.RequestLimit, config.RequestWindow))
api.RegisterRoutes(mux)

httpServer := NewHTTPServer(":8080", mux)
httpServer := NewHTTPServer(":8080", mux, diceClient)

// context for graceful shutdown
ctx, cancel := context.WithCancel(context.Background())
Expand Down

0 comments on commit 34e1f27

Please sign in to comment.