Skip to content

Commit

Permalink
ClientIP middleware proposal, intended to replace RealIP
Browse files Browse the repository at this point in the history
  • Loading branch information
VojtechVitek committed Dec 15, 2024
1 parent 0a20a0e commit c2354ea
Show file tree
Hide file tree
Showing 2 changed files with 326 additions and 0 deletions.
185 changes: 185 additions & 0 deletions middleware/client_ip.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,185 @@
package middleware

import (
"context"
"net"
"net/http"
"net/netip"
"strings"
)

var (
// clientIPCtxKey is the context key used to store the client IP address.
clientIPCtxKey = &contextKey{"clientIP"}
)

// ClientIPFromHeader parses the client IP address from a specified HTTP header
// (e.g., X-Real-IP, CF-Connecting-IP) and injects it into the request context
// if it is not already set. The parsed IP address can be retrieved using GetClientIP().
//
// The middleware validates the IP address to ignore loopback, private, and unspecified addresses.
//
// ### Important Notice:
// - Use this middleware only when your infrastructure sets a trusted header containing the client IP.
// - If the specified header is not securely set by your infrastructure, malicious clients could spoof it.
//
// Example trusted headers:
// - "X-Real-IP" - Nginx (ngx_http_realip_module)
// - "X-Client-IP" - Apache (mod_remoteip)
// - "CF-Connecting-IP" - Cloudflare
// - "True-Client-IP" - Akamai, Cloudflare Enterprise
// - "X-Azure-ClientIP" - Azure Front Door
// - "Fastly-Client-IP" - Fastly
func ClientIPFromHeader(trustedHeader string) func(http.Handler) http.Handler {
return func(h http.Handler) http.Handler {
fn := func(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()

// Check if the client IP is already set in the context.
if _, ok := ctx.Value(clientIPCtxKey).(netip.Addr); ok {
h.ServeHTTP(w, r)
return
}

// Parse the IP address from the trusted header.
ip, err := netip.ParseAddr(r.Header.Get(trustedHeader))
if err != nil || ip.IsLoopback() || ip.IsUnspecified() || ip.IsPrivate() {
// Ignore invalid or private IPs.
h.ServeHTTP(w, r)
return
}

// Store the valid client IP in the context.
ctx = context.WithValue(ctx, clientIPCtxKey, ip)
h.ServeHTTP(w, r.WithContext(ctx))
}
return http.HandlerFunc(fn)
}
}

// ClientIPFromXFFHeader parses the client IP address from the X-Forwarded-For
// header and injects it into the request context if it is not already set. The
// parsed IP address can be retrieved using GetClientIP().
//
// The middleware traverses the X-Forwarded-For chain (rightmost untrusted IP)
// and excludes loopback, private, unspecified, and trusted IP ranges.
//
// ### Important Notice:
// - Use this middleware only when your infrastructure sets and validates the X-Forwarded-For header.
// - Malicious clients can spoof the header unless a trusted reverse proxy or load balancer sanitizes it.
//
// Parameters:
// - `trustedIPPrefixes`: A list of CIDR prefixes that define trusted proxy IP ranges.
//
// Example trusted IP ranges:
// - "203.0.113.0/24" - Example corporate proxy
// - "198.51.100.0/24" - Example data center or hosting provider
// - "2400:cb00::/32" - Cloudflare IPv6 range
// - "2606:4700::/32" - Cloudflare IPv6 range
// - "192.0.2.0/24" - Example VPN gateway
//
// Note: Private IP ranges (e.g., "10.0.0.0/8", "192.168.0.0/16", "172.16.0.0/12")
// are automatically excluded by netip.Addr.IsPrivate() and do not need to be added here.
func ClientIPFromXFFHeader(trustedIPPrefixes ...string) func(http.Handler) http.Handler {
// Pre-parse trusted prefixes.
trustedPrefixes := make([]netip.Prefix, len(trustedIPPrefixes))
for i, ipRange := range trustedIPPrefixes {
trustedPrefixes[i] = netip.MustParsePrefix(ipRange)
}

return func(h http.Handler) http.Handler {
fn := func(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()

// Check if the client IP is already set in the context.
if _, ok := ctx.Value(clientIPCtxKey).(netip.Addr); ok {
h.ServeHTTP(w, r)
return
}

// Parse and split the X-Forwarded-For header(s).
xff := strings.Split(strings.Join(r.Header.Values("X-Forwarded-For"), ","), ",")
nextValue:
for i := len(xff) - 1; i >= 0; i-- {
ip, err := netip.ParseAddr(strings.TrimSpace(xff[i]))
if err != nil {
continue
}

// Ignore loopback, private, or unspecified addresses.
if ip.IsLoopback() || ip.IsPrivate() || ip.IsUnspecified() {
continue
}

// Ignore trusted IPs within the given ranges.
for _, prefix := range trustedPrefixes {
if prefix.Contains(ip) {
continue nextValue
}
}

// Store the valid client IP in the context.
ctx = context.WithValue(ctx, clientIPCtxKey, ip)
h.ServeHTTP(w, r.WithContext(ctx))
return
}

h.ServeHTTP(w, r)
}
return http.HandlerFunc(fn)
}
}

// ClientIPFromRemoteAddr extracts the client IP address from the RemoteAddr
// field of the HTTP request and injects it into the request context if it is
// not already set. The parsed IP address can be retrieved using GetClientIP().
//
// The middleware ignores invalid or private IPs.
//
// ### Use Case:
// This middleware is useful when the client IP cannot be determined from headers
// such as X-Forwarded-For or X-Real-IP, and you need to fall back to RemoteAddr.
func ClientIPFromRemoteAddr(h http.Handler) http.Handler {
fn := func(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()

// Check if the client IP is already set in the context.
if _, ok := ctx.Value(clientIPCtxKey).(netip.Addr); ok {
h.ServeHTTP(w, r)
return
}

// Extract the IP from request RemoteAddr.
host, _, err := net.SplitHostPort(r.RemoteAddr)
if err != nil {
h.ServeHTTP(w, r)
return
}

ip, err := netip.ParseAddr(host)
if err != nil {
h.ServeHTTP(w, r)
return
}

// Store the valid client IP in the context.
ctx = context.WithValue(ctx, clientIPCtxKey, ip)
h.ServeHTTP(w, r.WithContext(ctx))
}
return http.HandlerFunc(fn)
}

// GetClientIP retrieves the client IP address from the given context.
// The IP address is set by one of the following middlewares:
// - ClientIPFromHeader
// - ClientIPFromXFFHeader
// - ClientIPFromRemoteAddr
//
// Returns an empty string if no valid IP is found.
func GetClientIP(ctx context.Context) string {
ip, ok := ctx.Value(clientIPCtxKey).(netip.Addr)
if !ok || !ip.IsValid() {
return ""
}
return ip.String()
}
141 changes: 141 additions & 0 deletions middleware/client_ip_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,141 @@
package middleware

import (
"net/http"
"net/http/httptest"
"testing"

"github.com/go-chi/chi/v5"
)

func TestClientIPFromHeader(t *testing.T) {
tt := []struct {
name string
in string
out string
}{
// Empty header.
{name: "empty", in: "", out: ""},

// Valid X-Real-IP header values.
{name: "valid_ipv4", in: "100.100.100.100", out: "100.100.100.100"},
{name: "valid_ipv4", in: "178.25.203.2", out: "178.25.203.2"},
{name: "valid_ipv6_lower", in: "2345:0425:2ca1:0000:0000:0567:5673:23b5", out: "2345:425:2ca1::567:5673:23b5"},
{name: "valid_ipv6_upper", in: "2345:0425:2CA1:0000:0000:0567:5673:23B5", out: "2345:425:2ca1::567:5673:23b5"},
{name: "valid_ipv6_lower_short", in: "2345:425:2ca1::567:5673:23b5", out: "2345:425:2ca1::567:5673:23b5"},
{name: "valid_ipv6_upper_short", in: "2345:425:2CA1::567:5673:23B5", out: "2345:425:2ca1::567:5673:23b5"},

// Invalid X-Real-IP header values.
{name: "invalid_ip", in: "invalid", out: ""},
{name: "invalid_ip_with_port", in: "100.100.100.100:80", out: ""},
{name: "invalid_multiple_ips", in: "100.100.100.100;100.100.100.101;100.100.100.102", out: ""},
{name: "invalid_loopback", in: "127.0.0.1", out: ""},
{name: "invalid_zeroes", in: "0.0.0.0", out: ""},
{name: "invalid_loopback", in: "127.0.0.1", out: ""},
{name: "invalid_private_ipv4_1", in: "192.168.0.1", out: ""},
{name: "invalid_private_ipv4_2", in: "192.168.10.12", out: ""},
{name: "invalid_private_ipv4_3", in: "172.16.0.0", out: ""},
{name: "invalid_private_ipv4_4", in: "172.25.203.2", out: ""},
{name: "invalid_private_ipv4_5", in: "10.0.0.0", out: ""},
{name: "invalid_private_ipv4_6", in: "10.0.1.10", out: ""},
{name: "invalid_private_ipv6_1", in: "fc00::1", out: ""},
{name: "invalid_private_ipv6_2", in: "fc00:0425:2ca1:0000:0000:0567:5673:23b5", out: ""},
}

for _, tc := range tt {
t.Run(tc.name, func(t *testing.T) {
req, _ := http.NewRequest("GET", "/", nil)
req.Header.Add("X-Real-IP", tc.in)
w := httptest.NewRecorder()

r := chi.NewRouter()
r.Use(ClientIPFromHeader("X-Real-IP"))

var clientIP string
r.Get("/", func(w http.ResponseWriter, r *http.Request) {
clientIP = GetClientIP(r.Context())
w.Write([]byte("Hello World"))
})
r.ServeHTTP(w, req)

if w.Code != 200 {
t.Errorf("Response Code should be 200")
}

if clientIP != tc.out {
t.Errorf("expected %v, got %v", tc.out, clientIP)
}
})
}
}

func TestClientIPFromXFFHeader(t *testing.T) {
tt := []struct {
name string
xff []string
out string
}{
{name: "empty", xff: []string{""}, out: ""},

{name: "", xff: []string{"100.100.100.100"}, out: "100.100.100.100"},
{name: "", xff: []string{"100.100.100.100, 200.200.200.200"}, out: "200.200.200.200"},
{name: "", xff: []string{"100.100.100.100,200.200.200.200"}, out: "200.200.200.200"},
{name: "", xff: []string{"100.100.100.100", "200.200.200.200"}, out: "200.200.200.200"},
{name: "", xff: []string{"2001:db8:85a3:8d3:1319:8a2e:370:7348"}, out: "2001:db8:85a3:8d3:1319:8a2e:370:7348"},
{name: "", xff: []string{"203.0.113.195, 2001:db8:85a3:8d3:1319:8a2e:370:7348"}, out: "2001:db8:85a3:8d3:1319:8a2e:370:7348"},
{name: "", xff: []string{"5.5.5.5, 203.0.113.195, 2001:db8:85a3:8d3:1319:8a2e:370:7348", "7.7.7.7, 4.4.4.4"}, out: "4.4.4.4"},
}

r := chi.NewRouter()
r.Use(ClientIPFromXFFHeader())

for _, tc := range tt {
req, _ := http.NewRequest("GET", "/", nil)
for _, v := range tc.xff {
req.Header.Add("X-Forwarded-For", v)
}

w := httptest.NewRecorder()

clientIP := ""
r.Get("/", func(w http.ResponseWriter, r *http.Request) {
clientIP = GetClientIP(r.Context())
w.Write([]byte("Hello World"))
})
r.ServeHTTP(w, req)

if w.Code != 200 {
t.Errorf("Response Code should be 200")
}

if clientIP != tc.out {
t.Errorf("expected %v, got %v", tc.out, clientIP)
}
}
}

func TestClientIPFromRemoteAddr(t *testing.T) {
req, _ := http.NewRequest("GET", "/", nil)
req.RemoteAddr = "192.0.2.1:1234" // Simulate the remote address set by http.Server.

w := httptest.NewRecorder()

r := chi.NewRouter()
r.Use(ClientIPFromRemoteAddr)

var clientIP string
r.Get("/", func(w http.ResponseWriter, r *http.Request) {
clientIP = GetClientIP(r.Context())
w.Write([]byte("Hello World"))
})
r.ServeHTTP(w, req)

if w.Code != 200 {
t.Errorf("Response Code should be 200")
}

expected := "192.0.2.1"
if clientIP != expected {
t.Errorf("expected %v, got %v", expected, clientIP)
}
}

0 comments on commit c2354ea

Please sign in to comment.