Skip to content

Commit d3e516b

Browse files
Move websocket headers to opt function 'WithWebsocketHeaders' (#365)
Follow-up up on the discussion in #360 (review). Move websocket headers to and opt function 'WithWebsocketHeaders'. Note this is a breaking change for users using the main branch (but not for users on tagged releases). I have: - [x] Written a clear PR title and description (above) - [x] Signed the [Khan Academy CLA](https://www.khanacademy.org/r/cla) - [x] Added tests covering my changes, if applicable - [x] Included a link to the issue fixed, if applicable - [x] Included documentation, for new features - [x] Added an entry to the changelog
1 parent 5913cd6 commit d3e516b

File tree

5 files changed

+43
-13
lines changed

5 files changed

+43
-13
lines changed

graphql/client.go

+13-8
Original file line numberDiff line numberDiff line change
@@ -133,16 +133,10 @@ type WebSocketOption func(*webSocketClient)
133133
//
134134
// The client does not support queries nor mutations, and will return an error
135135
// if passed a request that attempts one.
136-
func NewClientUsingWebSocket(endpoint string, wsDialer Dialer, headers http.Header, opts ...WebSocketOption) WebSocketClient {
137-
if headers == nil {
138-
headers = http.Header{}
139-
}
140-
if headers.Get("Sec-WebSocket-Protocol") == "" {
141-
headers.Add("Sec-WebSocket-Protocol", "graphql-transport-ws")
142-
}
136+
func NewClientUsingWebSocket(endpoint string, wsDialer Dialer, opts ...WebSocketOption) WebSocketClient {
143137
client := &webSocketClient{
144138
Dialer: wsDialer,
145-
Header: headers,
139+
header: http.Header{},
146140
errChan: make(chan error),
147141
endpoint: endpoint,
148142
subscriptions: subscriptionMap{map_: make(map[string]subscription)},
@@ -152,6 +146,10 @@ func NewClientUsingWebSocket(endpoint string, wsDialer Dialer, headers http.Head
152146
opt(client)
153147
}
154148

149+
if client.header.Get("Sec-WebSocket-Protocol") == "" {
150+
client.header.Add("Sec-WebSocket-Protocol", "graphql-transport-ws")
151+
}
152+
155153
return client
156154
}
157155

@@ -163,6 +161,13 @@ func WithConnectionParams(connParams map[string]interface{}) WebSocketOption {
163161
}
164162
}
165163

164+
// WithWebsocketHeader sets a header to be sent to the server.
165+
func WithWebsocketHeader(header http.Header) WebSocketOption {
166+
return func(ws *webSocketClient) {
167+
ws.header = header
168+
}
169+
}
170+
166171
func newClient(endpoint string, httpClient Doer, method string) Client {
167172
if httpClient == nil || httpClient == (*http.Client)(nil) {
168173
httpClient = http.DefaultClient

graphql/websocket.go

+2-2
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ const (
4545

4646
type webSocketClient struct {
4747
Dialer Dialer
48-
Header http.Header
48+
header http.Header
4949
endpoint string
5050
conn WSConn
5151
connParams map[string]interface{}
@@ -169,7 +169,7 @@ func checkConnectionAckReceived(message []byte) (bool, error) {
169169
}
170170

171171
func (w *webSocketClient) Start(ctx context.Context) (errChan chan error, err error) {
172-
w.conn, err = w.Dialer.DialContext(ctx, w.endpoint, w.Header)
172+
w.conn, err = w.Dialer.DialContext(ctx, w.endpoint, w.header)
173173
if err != nil {
174174
return nil, err
175175
}

internal/integration/integration_test.go

+9-1
Original file line numberDiff line numberDiff line change
@@ -164,13 +164,21 @@ func TestSubscriptionConnectionParams(t *testing.T) {
164164
opts []graphql.WebSocketOption
165165
}{
166166
{
167-
name: "authorized_user_gets_counter",
167+
name: "connection_params_authorized_user_gets_counter",
168168
opts: []graphql.WebSocketOption{
169169
graphql.WithConnectionParams(map[string]interface{}{
170170
authKey: "authorized-user-token",
171171
}),
172172
},
173173
},
174+
{
175+
name: "http_header_authorized_user_gets_counter",
176+
opts: []graphql.WebSocketOption{
177+
graphql.WithWebsocketHeader(http.Header{
178+
authKey: []string{"authorized-user-token"},
179+
}),
180+
},
181+
},
174182
{
175183
name: "unauthorized_user_gets_error",
176184
expectedError: "input: countAuthorized unauthorized\n",

internal/integration/roundtrip.go

-1
Original file line numberDiff line numberDiff line change
@@ -167,7 +167,6 @@ func newRoundtripWebSocketClient(t *testing.T, endpoint string, opts ...graphql.
167167
wsWrapped: graphql.NewClientUsingWebSocket(
168168
endpoint,
169169
&MyDialer{Dialer: dialer},
170-
nil,
171170
opts...,
172171
),
173172
t: t,

internal/integration/server/server.go

+19-1
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ package server
33
import (
44
"context"
55
"fmt"
6+
"net/http"
67
"net/http/httptest"
78
"strconv"
89
"time"
@@ -198,6 +199,20 @@ func getAuthToken(ctx context.Context) string {
198199
return ""
199200
}
200201

202+
func authHeaderMiddleware(handler http.Handler) http.Handler {
203+
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
204+
ctx := r.Context()
205+
206+
token := r.Header.Get(AuthKey)
207+
if token != "" {
208+
ctx = withAuthToken(ctx, token)
209+
}
210+
211+
r = r.WithContext(ctx)
212+
handler.ServeHTTP(w, r)
213+
})
214+
}
215+
201216
func RunServer() *httptest.Server {
202217
gqlgenServer := handler.New(NewExecutableSchema(Config{Resolvers: &resolver{}}))
203218
gqlgenServer.AddTransport(transport.POST{})
@@ -216,7 +231,10 @@ func RunServer() *httptest.Server {
216231
graphql.RegisterExtension(ctx, "foobar", "test")
217232
return next(ctx)
218233
})
219-
return httptest.NewServer(gqlgenServer)
234+
235+
server := authHeaderMiddleware(gqlgenServer)
236+
237+
return httptest.NewServer(server)
220238
}
221239

222240
type (

0 commit comments

Comments
 (0)