-
-
Notifications
You must be signed in to change notification settings - Fork 998
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
ClientIP middleware proposal, intended to replace RealIP
- Loading branch information
1 parent
0a20a0e
commit c2354ea
Showing
2 changed files
with
326 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,185 @@ | ||
package middleware | ||
|
||
import ( | ||
"context" | ||
"net" | ||
"net/http" | ||
"net/netip" | ||
"strings" | ||
) | ||
|
||
var ( | ||
// clientIPCtxKey is the context key used to store the client IP address. | ||
clientIPCtxKey = &contextKey{"clientIP"} | ||
) | ||
|
||
// ClientIPFromHeader parses the client IP address from a specified HTTP header | ||
// (e.g., X-Real-IP, CF-Connecting-IP) and injects it into the request context | ||
// if it is not already set. The parsed IP address can be retrieved using GetClientIP(). | ||
// | ||
// The middleware validates the IP address to ignore loopback, private, and unspecified addresses. | ||
// | ||
// ### Important Notice: | ||
// - Use this middleware only when your infrastructure sets a trusted header containing the client IP. | ||
// - If the specified header is not securely set by your infrastructure, malicious clients could spoof it. | ||
// | ||
// Example trusted headers: | ||
// - "X-Real-IP" - Nginx (ngx_http_realip_module) | ||
// - "X-Client-IP" - Apache (mod_remoteip) | ||
// - "CF-Connecting-IP" - Cloudflare | ||
// - "True-Client-IP" - Akamai, Cloudflare Enterprise | ||
// - "X-Azure-ClientIP" - Azure Front Door | ||
// - "Fastly-Client-IP" - Fastly | ||
func ClientIPFromHeader(trustedHeader string) func(http.Handler) http.Handler { | ||
return func(h http.Handler) http.Handler { | ||
fn := func(w http.ResponseWriter, r *http.Request) { | ||
ctx := r.Context() | ||
|
||
// Check if the client IP is already set in the context. | ||
if _, ok := ctx.Value(clientIPCtxKey).(netip.Addr); ok { | ||
h.ServeHTTP(w, r) | ||
return | ||
} | ||
|
||
// Parse the IP address from the trusted header. | ||
ip, err := netip.ParseAddr(r.Header.Get(trustedHeader)) | ||
if err != nil || ip.IsLoopback() || ip.IsUnspecified() || ip.IsPrivate() { | ||
// Ignore invalid or private IPs. | ||
h.ServeHTTP(w, r) | ||
return | ||
} | ||
|
||
// Store the valid client IP in the context. | ||
ctx = context.WithValue(ctx, clientIPCtxKey, ip) | ||
h.ServeHTTP(w, r.WithContext(ctx)) | ||
} | ||
return http.HandlerFunc(fn) | ||
} | ||
} | ||
|
||
// ClientIPFromXFFHeader parses the client IP address from the X-Forwarded-For | ||
// header and injects it into the request context if it is not already set. The | ||
// parsed IP address can be retrieved using GetClientIP(). | ||
// | ||
// The middleware traverses the X-Forwarded-For chain (rightmost untrusted IP) | ||
// and excludes loopback, private, unspecified, and trusted IP ranges. | ||
// | ||
// ### Important Notice: | ||
// - Use this middleware only when your infrastructure sets and validates the X-Forwarded-For header. | ||
// - Malicious clients can spoof the header unless a trusted reverse proxy or load balancer sanitizes it. | ||
// | ||
// Parameters: | ||
// - `trustedIPPrefixes`: A list of CIDR prefixes that define trusted proxy IP ranges. | ||
// | ||
// Example trusted IP ranges: | ||
// - "203.0.113.0/24" - Example corporate proxy | ||
// - "198.51.100.0/24" - Example data center or hosting provider | ||
// - "2400:cb00::/32" - Cloudflare IPv6 range | ||
// - "2606:4700::/32" - Cloudflare IPv6 range | ||
// - "192.0.2.0/24" - Example VPN gateway | ||
// | ||
// Note: Private IP ranges (e.g., "10.0.0.0/8", "192.168.0.0/16", "172.16.0.0/12") | ||
// are automatically excluded by netip.Addr.IsPrivate() and do not need to be added here. | ||
func ClientIPFromXFFHeader(trustedIPPrefixes ...string) func(http.Handler) http.Handler { | ||
// Pre-parse trusted prefixes. | ||
trustedPrefixes := make([]netip.Prefix, len(trustedIPPrefixes)) | ||
for i, ipRange := range trustedIPPrefixes { | ||
trustedPrefixes[i] = netip.MustParsePrefix(ipRange) | ||
} | ||
|
||
return func(h http.Handler) http.Handler { | ||
fn := func(w http.ResponseWriter, r *http.Request) { | ||
ctx := r.Context() | ||
|
||
// Check if the client IP is already set in the context. | ||
if _, ok := ctx.Value(clientIPCtxKey).(netip.Addr); ok { | ||
h.ServeHTTP(w, r) | ||
return | ||
} | ||
|
||
// Parse and split the X-Forwarded-For header(s). | ||
xff := strings.Split(strings.Join(r.Header.Values("X-Forwarded-For"), ","), ",") | ||
nextValue: | ||
for i := len(xff) - 1; i >= 0; i-- { | ||
ip, err := netip.ParseAddr(strings.TrimSpace(xff[i])) | ||
if err != nil { | ||
continue | ||
} | ||
|
||
// Ignore loopback, private, or unspecified addresses. | ||
if ip.IsLoopback() || ip.IsPrivate() || ip.IsUnspecified() { | ||
continue | ||
} | ||
|
||
// Ignore trusted IPs within the given ranges. | ||
for _, prefix := range trustedPrefixes { | ||
if prefix.Contains(ip) { | ||
continue nextValue | ||
} | ||
} | ||
|
||
// Store the valid client IP in the context. | ||
ctx = context.WithValue(ctx, clientIPCtxKey, ip) | ||
h.ServeHTTP(w, r.WithContext(ctx)) | ||
return | ||
} | ||
|
||
h.ServeHTTP(w, r) | ||
} | ||
return http.HandlerFunc(fn) | ||
} | ||
} | ||
|
||
// ClientIPFromRemoteAddr extracts the client IP address from the RemoteAddr | ||
// field of the HTTP request and injects it into the request context if it is | ||
// not already set. The parsed IP address can be retrieved using GetClientIP(). | ||
// | ||
// The middleware ignores invalid or private IPs. | ||
// | ||
// ### Use Case: | ||
// This middleware is useful when the client IP cannot be determined from headers | ||
// such as X-Forwarded-For or X-Real-IP, and you need to fall back to RemoteAddr. | ||
func ClientIPFromRemoteAddr(h http.Handler) http.Handler { | ||
fn := func(w http.ResponseWriter, r *http.Request) { | ||
ctx := r.Context() | ||
|
||
// Check if the client IP is already set in the context. | ||
if _, ok := ctx.Value(clientIPCtxKey).(netip.Addr); ok { | ||
h.ServeHTTP(w, r) | ||
return | ||
} | ||
|
||
// Extract the IP from request RemoteAddr. | ||
host, _, err := net.SplitHostPort(r.RemoteAddr) | ||
if err != nil { | ||
h.ServeHTTP(w, r) | ||
return | ||
} | ||
|
||
ip, err := netip.ParseAddr(host) | ||
if err != nil { | ||
h.ServeHTTP(w, r) | ||
return | ||
} | ||
|
||
// Store the valid client IP in the context. | ||
ctx = context.WithValue(ctx, clientIPCtxKey, ip) | ||
h.ServeHTTP(w, r.WithContext(ctx)) | ||
} | ||
return http.HandlerFunc(fn) | ||
} | ||
|
||
// GetClientIP retrieves the client IP address from the given context. | ||
// The IP address is set by one of the following middlewares: | ||
// - ClientIPFromHeader | ||
// - ClientIPFromXFFHeader | ||
// - ClientIPFromRemoteAddr | ||
// | ||
// Returns an empty string if no valid IP is found. | ||
func GetClientIP(ctx context.Context) string { | ||
ip, ok := ctx.Value(clientIPCtxKey).(netip.Addr) | ||
if !ok || !ip.IsValid() { | ||
return "" | ||
} | ||
return ip.String() | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,141 @@ | ||
package middleware | ||
|
||
import ( | ||
"net/http" | ||
"net/http/httptest" | ||
"testing" | ||
|
||
"github.com/go-chi/chi/v5" | ||
) | ||
|
||
func TestClientIPFromHeader(t *testing.T) { | ||
tt := []struct { | ||
name string | ||
in string | ||
out string | ||
}{ | ||
// Empty header. | ||
{name: "empty", in: "", out: ""}, | ||
|
||
// Valid X-Real-IP header values. | ||
{name: "valid_ipv4", in: "100.100.100.100", out: "100.100.100.100"}, | ||
{name: "valid_ipv4", in: "178.25.203.2", out: "178.25.203.2"}, | ||
{name: "valid_ipv6_lower", in: "2345:0425:2ca1:0000:0000:0567:5673:23b5", out: "2345:425:2ca1::567:5673:23b5"}, | ||
{name: "valid_ipv6_upper", in: "2345:0425:2CA1:0000:0000:0567:5673:23B5", out: "2345:425:2ca1::567:5673:23b5"}, | ||
{name: "valid_ipv6_lower_short", in: "2345:425:2ca1::567:5673:23b5", out: "2345:425:2ca1::567:5673:23b5"}, | ||
{name: "valid_ipv6_upper_short", in: "2345:425:2CA1::567:5673:23B5", out: "2345:425:2ca1::567:5673:23b5"}, | ||
|
||
// Invalid X-Real-IP header values. | ||
{name: "invalid_ip", in: "invalid", out: ""}, | ||
{name: "invalid_ip_with_port", in: "100.100.100.100:80", out: ""}, | ||
{name: "invalid_multiple_ips", in: "100.100.100.100;100.100.100.101;100.100.100.102", out: ""}, | ||
{name: "invalid_loopback", in: "127.0.0.1", out: ""}, | ||
{name: "invalid_zeroes", in: "0.0.0.0", out: ""}, | ||
{name: "invalid_loopback", in: "127.0.0.1", out: ""}, | ||
{name: "invalid_private_ipv4_1", in: "192.168.0.1", out: ""}, | ||
{name: "invalid_private_ipv4_2", in: "192.168.10.12", out: ""}, | ||
{name: "invalid_private_ipv4_3", in: "172.16.0.0", out: ""}, | ||
{name: "invalid_private_ipv4_4", in: "172.25.203.2", out: ""}, | ||
{name: "invalid_private_ipv4_5", in: "10.0.0.0", out: ""}, | ||
{name: "invalid_private_ipv4_6", in: "10.0.1.10", out: ""}, | ||
{name: "invalid_private_ipv6_1", in: "fc00::1", out: ""}, | ||
{name: "invalid_private_ipv6_2", in: "fc00:0425:2ca1:0000:0000:0567:5673:23b5", out: ""}, | ||
} | ||
|
||
for _, tc := range tt { | ||
t.Run(tc.name, func(t *testing.T) { | ||
req, _ := http.NewRequest("GET", "/", nil) | ||
req.Header.Add("X-Real-IP", tc.in) | ||
w := httptest.NewRecorder() | ||
|
||
r := chi.NewRouter() | ||
r.Use(ClientIPFromHeader("X-Real-IP")) | ||
|
||
var clientIP string | ||
r.Get("/", func(w http.ResponseWriter, r *http.Request) { | ||
clientIP = GetClientIP(r.Context()) | ||
w.Write([]byte("Hello World")) | ||
}) | ||
r.ServeHTTP(w, req) | ||
|
||
if w.Code != 200 { | ||
t.Errorf("Response Code should be 200") | ||
} | ||
|
||
if clientIP != tc.out { | ||
t.Errorf("expected %v, got %v", tc.out, clientIP) | ||
} | ||
}) | ||
} | ||
} | ||
|
||
func TestClientIPFromXFFHeader(t *testing.T) { | ||
tt := []struct { | ||
name string | ||
xff []string | ||
out string | ||
}{ | ||
{name: "empty", xff: []string{""}, out: ""}, | ||
|
||
{name: "", xff: []string{"100.100.100.100"}, out: "100.100.100.100"}, | ||
{name: "", xff: []string{"100.100.100.100, 200.200.200.200"}, out: "200.200.200.200"}, | ||
{name: "", xff: []string{"100.100.100.100,200.200.200.200"}, out: "200.200.200.200"}, | ||
{name: "", xff: []string{"100.100.100.100", "200.200.200.200"}, out: "200.200.200.200"}, | ||
{name: "", xff: []string{"2001:db8:85a3:8d3:1319:8a2e:370:7348"}, out: "2001:db8:85a3:8d3:1319:8a2e:370:7348"}, | ||
{name: "", xff: []string{"203.0.113.195, 2001:db8:85a3:8d3:1319:8a2e:370:7348"}, out: "2001:db8:85a3:8d3:1319:8a2e:370:7348"}, | ||
{name: "", xff: []string{"5.5.5.5, 203.0.113.195, 2001:db8:85a3:8d3:1319:8a2e:370:7348", "7.7.7.7, 4.4.4.4"}, out: "4.4.4.4"}, | ||
} | ||
|
||
r := chi.NewRouter() | ||
r.Use(ClientIPFromXFFHeader()) | ||
|
||
for _, tc := range tt { | ||
req, _ := http.NewRequest("GET", "/", nil) | ||
for _, v := range tc.xff { | ||
req.Header.Add("X-Forwarded-For", v) | ||
} | ||
|
||
w := httptest.NewRecorder() | ||
|
||
clientIP := "" | ||
r.Get("/", func(w http.ResponseWriter, r *http.Request) { | ||
clientIP = GetClientIP(r.Context()) | ||
w.Write([]byte("Hello World")) | ||
}) | ||
r.ServeHTTP(w, req) | ||
|
||
if w.Code != 200 { | ||
t.Errorf("Response Code should be 200") | ||
} | ||
|
||
if clientIP != tc.out { | ||
t.Errorf("expected %v, got %v", tc.out, clientIP) | ||
} | ||
} | ||
} | ||
|
||
func TestClientIPFromRemoteAddr(t *testing.T) { | ||
req, _ := http.NewRequest("GET", "/", nil) | ||
req.RemoteAddr = "192.0.2.1:1234" // Simulate the remote address set by http.Server. | ||
|
||
w := httptest.NewRecorder() | ||
|
||
r := chi.NewRouter() | ||
r.Use(ClientIPFromRemoteAddr) | ||
|
||
var clientIP string | ||
r.Get("/", func(w http.ResponseWriter, r *http.Request) { | ||
clientIP = GetClientIP(r.Context()) | ||
w.Write([]byte("Hello World")) | ||
}) | ||
r.ServeHTTP(w, req) | ||
|
||
if w.Code != 200 { | ||
t.Errorf("Response Code should be 200") | ||
} | ||
|
||
expected := "192.0.2.1" | ||
if clientIP != expected { | ||
t.Errorf("expected %v, got %v", expected, clientIP) | ||
} | ||
} |