-
-
Notifications
You must be signed in to change notification settings - Fork 289
/
Copy pathhandler_test.go
207 lines (195 loc) · 6.9 KB
/
handler_test.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
package templ_test
import (
"context"
"errors"
"io"
"net/http"
"net/http/httptest"
"testing"
"github.com/a-h/templ"
"github.com/google/go-cmp/cmp"
)
// TestHandler exercises templ.Handler through an httptest.ResponseRecorder.
//
// It covers three areas:
//   - a table of buffered-mode cases checking status code, Content-Type and body,
//     including the default and a customised error handler;
//   - streaming mode, checking that a flushed partial body is visible before the
//     component finishes rendering;
//   - streaming mode with a failing component, checking that the error handler
//     runs but can no longer change the status code or discard the written body.
func TestHandler(t *testing.T) {
	// The components and the custom error handler below are created with the
	// parent test's t but are executed inside subtests. FailNow (and therefore
	// Fatalf) must only be called from the goroutine running that test, so
	// failures are reported with Errorf instead.
	hello := templ.ComponentFunc(func(ctx context.Context, w io.Writer) error {
		if _, err := io.WriteString(w, "Hello"); err != nil {
			t.Errorf("failed to write string: %v", err)
			return err
		}
		return nil
	})
	// errorComponent writes a body and then fails, so tests can verify what
	// happens to partially rendered output.
	errorComponent := templ.ComponentFunc(func(ctx context.Context, w io.Writer) error {
		if _, err := io.WriteString(w, "Hello"); err != nil {
			t.Errorf("failed to write string: %v", err)
		}
		return errors.New("handler error")
	})

	tests := []struct {
		name             string
		input            *templ.ComponentHandler
		expectedStatus   int
		expectedMIMEType string
		expectedBody     string
	}{
		{
			name:             "handlers return OK by default",
			input:            templ.Handler(hello),
			expectedStatus:   http.StatusOK,
			expectedMIMEType: "text/html; charset=utf-8",
			expectedBody:     "Hello",
		},
		{
			// Named distinctly from the case above so the subtest can be
			// selected with -run and is not reported with a "#01" suffix.
			name:             "handlers return non-ASCII content unchanged",
			input:            templ.Handler(templ.Raw(`♠ ‘ ♠ ‘`)),
			expectedStatus:   http.StatusOK,
			expectedMIMEType: "text/html; charset=utf-8",
			expectedBody:     "♠ ‘ ♠ ‘",
		},
		{
			name:             "handlers can be configured to return an alternative status code",
			input:            templ.Handler(hello, templ.WithStatus(http.StatusNotFound)),
			expectedStatus:   http.StatusNotFound,
			expectedMIMEType: "text/html; charset=utf-8",
			expectedBody:     "Hello",
		},
		{
			name:             "handlers can be configured to return an alternative status code and content type",
			input:            templ.Handler(hello, templ.WithStatus(http.StatusOK), templ.WithContentType("text/csv")),
			expectedStatus:   http.StatusOK,
			expectedMIMEType: "text/csv",
			expectedBody:     "Hello",
		},
		{
			name:             "handlers that fail return a 500 error",
			input:            templ.Handler(errorComponent),
			expectedStatus:   http.StatusInternalServerError,
			expectedMIMEType: "text/plain; charset=utf-8",
			expectedBody:     "templ: failed to render template\n",
		},
		{
			name: "error handling can be customised",
			input: templ.Handler(errorComponent, templ.WithErrorHandler(func(r *http.Request, err error) http.Handler {
				// Because the error is received, it's possible to log the detail of the request.
				// log.Printf("template render error for %v %v: %v", r.Method, r.URL.String(), err)
				return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
					w.WriteHeader(http.StatusBadRequest)
					if _, err := io.WriteString(w, "custom body"); err != nil {
						t.Errorf("failed to write string: %v", err)
					}
				})
			})),
			expectedStatus:   http.StatusBadRequest,
			expectedMIMEType: "text/html; charset=utf-8",
			expectedBody:     "custom body",
		},
	}
	for _, tt := range tests {
		tt := tt
		t.Run(tt.name, func(t *testing.T) {
			w := httptest.NewRecorder()
			r := httptest.NewRequest(http.MethodGet, "/test", nil)
			tt.input.ServeHTTP(w, r)

			// Fetch the recorded response once and release its body when done.
			res := w.Result()
			defer res.Body.Close()

			if got := res.StatusCode; got != tt.expectedStatus {
				t.Errorf("expected status %d, got %d", tt.expectedStatus, got)
			}
			if mimeType := res.Header.Get("Content-Type"); mimeType != tt.expectedMIMEType {
				t.Errorf("expected content-type %s, got %s", tt.expectedMIMEType, mimeType)
			}
			body, err := io.ReadAll(res.Body)
			if err != nil {
				// Comparing a partially read body would only add noise.
				t.Fatalf("failed to read body: %v", err)
			}
			if diff := cmp.Diff(tt.expectedBody, string(body)); diff != "" {
				t.Error(diff)
			}
		})
	}

	t.Run("streaming mode allows responses to be flushed", func(t *testing.T) {
		w := httptest.NewRecorder()
		r := httptest.NewRequest(http.MethodGet, "/test", nil)

		component := templ.ComponentFunc(func(ctx context.Context, w io.Writer) error {
			// Write part 1.
			if _, err := io.WriteString(w, "Part 1"); err != nil {
				return err
			}
			// Flush.
			if f, ok := w.(http.Flusher); ok {
				f.Flush()
			}
			// Check partial response. The assertion is checked so that an
			// unexpected writer type is reported as a test failure, not a panic.
			if rec, ok := w.(*httptest.ResponseRecorder); !ok {
				t.Errorf("expected component writer to be *httptest.ResponseRecorder, got %T", w)
			} else if diff := cmp.Diff("Part 1", rec.Body.String()); diff != "" {
				t.Error(diff)
			}
			// Write part 2.
			if _, err := io.WriteString(w, "\nPart 2"); err != nil {
				return err
			}
			return nil
		})

		templ.Handler(component, templ.WithStatus(http.StatusCreated), templ.WithStreaming()).ServeHTTP(w, r)

		res := w.Result()
		defer res.Body.Close()

		if got := res.StatusCode; got != http.StatusCreated {
			t.Errorf("expected status %d, got %d", http.StatusCreated, got)
		}
		if mimeType := res.Header.Get("Content-Type"); mimeType != "text/html; charset=utf-8" {
			t.Errorf("expected content-type %s, got %s", "text/html; charset=utf-8", mimeType)
		}
		body, err := io.ReadAll(res.Body)
		if err != nil {
			t.Fatalf("failed to read body: %v", err)
		}
		if diff := cmp.Diff("Part 1\nPart 2", string(body)); diff != "" {
			t.Error(diff)
		}
	})

	t.Run("streaming mode handles errors", func(t *testing.T) {
		w := httptest.NewRecorder()
		r := httptest.NewRequest(http.MethodGet, "/test", nil)

		expectedErr := errors.New("streaming error")

		component := templ.ComponentFunc(func(ctx context.Context, w io.Writer) error {
			if _, err := io.WriteString(w, "Body"); err != nil {
				return err
			}
			return expectedErr
		})

		var errorHandlerCalled bool
		errorHandler := func(r *http.Request, err error) http.Handler {
			if !errors.Is(err, expectedErr) {
				t.Errorf("expected error %v, got %v", expectedErr, err)
			}
			errorHandlerCalled = true
			return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
				// This will be ignored, because the header has already been written.
				w.WriteHeader(http.StatusBadRequest)
				// This will be written, but will be appended to the written body.
				if _, err := io.WriteString(w, "Error message"); err != nil {
					t.Errorf("failed to write error message: %v", err)
				}
			})
		}

		h := templ.Handler(component,
			templ.WithStatus(http.StatusCreated),
			templ.WithStreaming(),
			templ.WithErrorHandler(errorHandler),
		)
		h.ServeHTTP(w, r)

		if !errorHandlerCalled {
			t.Error("expected error handler to be called")
		}

		res := w.Result()
		defer res.Body.Close()

		// Expect the status code to be 201, not 400, because in streaming mode,
		// we have to write the header before we can call the error handler.
		if actualResponseCode := res.StatusCode; actualResponseCode != http.StatusCreated {
			t.Errorf("expected status %d, got %d", http.StatusCreated, actualResponseCode)
		}

		// Expect the body to be "BodyError message", not just "Error message" because
		// in streaming mode, we've already written part of the body to the response, unlike in
		// standard mode where the body is written to a buffer before the response is written,
		// ensuring that partial responses are not sent.
		actualBody, err := io.ReadAll(res.Body)
		if err != nil {
			t.Fatalf("failed to read body: %v", err)
		}
		if diff := cmp.Diff("BodyError message", string(actualBody)); diff != "" {
			t.Error(diff)
		}
	})
}