Skip to content

Commit 1fb5fed

Browse files
committed
Add function to expose allowed methods for use in custom 405-handlers
Fixes go-chi#870
1 parent 58ca6d6 commit 1fb5fed

File tree

2 files changed

+51
-5
lines changed

2 files changed

+51
-5
lines changed

context.go

+17-4
Original file line numberDiff line numberDiff line change
@@ -34,10 +34,23 @@ func NewRouteContext() *Context {
3434
return &Context{}
3535
}
3636

37-
var (
38-
// RouteCtxKey is the context.Context key to store the request context.
39-
RouteCtxKey = &contextKey{"RouteContext"}
40-
)
37+
// WithRouteContext returns the list of methods allowed for the current
38+
// request, based on the current routing context.
39+
func AllowedMethods(ctx context.Context) []string {
40+
if rctx := RouteContext(ctx); rctx != nil {
41+
result := make([]string, 0, len(rctx.methodsAllowed))
42+
for _, method := range rctx.methodsAllowed {
43+
if method := methodTypString(method); method != "" {
44+
result = append(result, method)
45+
}
46+
}
47+
return result
48+
}
49+
return nil
50+
}
51+
52+
// RouteCtxKey is the context.Context key to store the request context.
53+
var RouteCtxKey = &contextKey{"RouteContext"}
4154

4255
// Context is the default routing context set on the root node of a
4356
// request context to track route patterns, URL parameters and

context_test.go

+34-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,10 @@
11
package chi
22

3-
import "testing"
3+
import (
4+
"context"
5+
"strings"
6+
"testing"
7+
)
48

59
// TestRoutePattern tests correct in-the-middle wildcard removals.
610
// If user organizes a router like this:
@@ -85,3 +89,32 @@ func TestRoutePattern(t *testing.T) {
8589
t.Fatal("unexpected route pattern for root: " + p)
8690
}
8791
}
92+
93+
func TestAllowedMethods(t *testing.T) {
94+
t.Run("no chi context", func(t *testing.T) {
95+
got := AllowedMethods(context.Background())
96+
if got != nil {
97+
t.Errorf("Unexpected allowed methods: %v", got)
98+
}
99+
})
100+
t.Run("expected methods", func(t *testing.T) {
101+
want := "GET HEAD"
102+
ctx := context.WithValue(context.Background(), RouteCtxKey, &Context{
103+
methodsAllowed: []methodTyp{mGET, mHEAD},
104+
})
105+
got := strings.Join(AllowedMethods(ctx), " ")
106+
if want != got {
107+
t.Errorf("Unexpected allowed methods: %s, want: %s", got, want)
108+
}
109+
})
110+
t.Run("unexpected methods", func(t *testing.T) {
111+
want := "GET HEAD"
112+
ctx := context.WithValue(context.Background(), RouteCtxKey, &Context{
113+
methodsAllowed: []methodTyp{mGET, mHEAD, 9000},
114+
})
115+
got := strings.Join(AllowedMethods(ctx), " ")
116+
if want != got {
117+
t.Errorf("Unexpected allowed methods: %s, want: %s", got, want)
118+
}
119+
})
120+
}

0 commit comments

Comments
 (0)