Skip to content

Commit 4342b77

Browse files
authored
Make request body readable and changeable in interceptFunc and beforeFunc (#81)
Hi! This PR gives users the ability to access and change `request.Body` in functions registered with `RegisterInterceptFunc` and `RegisterBeforeFunc`. I've already described it in issue #80. Also, I've added couple of simple tests to check ability to change request data in these functions. **Changes:** 1. In **rpc/v2** close `request.Body` after the execution of `beforeFunc`'s and `interceptFunc`'s. Update codec request info after calls to functions above. 2. Read request body bytes, decode it to codec format and provide `bytes.Buffer` and `request.Body` for underlying functions in **v2/json**, **v2/json2** and **v2/protorpc** codecs. Of course, exists a better way to do that, but it will require changes in the signature of `RegisterInterceptFunc` and `RegisterBeforeFunc` and it would be breaking changes. If both of these methods will have an original `*http.Request` as input parameter, they could be executed before the creation of codec. In this case, users can access and alter request data, and only after that it would be read by the codec and marshaled to service request params. But since these changes are breaking, it's not an option at the moment, probably it can fit the next version or release. Would be nice to know your opinion, thanks!
1 parent 39123e3 commit 4342b77

File tree

5 files changed

+190
-11
lines changed

5 files changed

+190
-11
lines changed

v2/json/server.go

+17-2
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,11 @@
66
package json
77

88
import (
9+
"bytes"
910
"encoding/json"
1011
"errors"
1112
"fmt"
13+
"io"
1214
"log"
1315
"net/http"
1416

@@ -79,9 +81,22 @@ func (c *Codec) NewRequest(r *http.Request) rpc.CodecRequest {
7981

8082
// newCodecRequest returns a new CodecRequest.
8183
func newCodecRequest(r *http.Request) rpc.CodecRequest {
82-
// Decode the request body and check if RPC method is valid.
8384
req := new(serverRequest)
84-
err := json.NewDecoder(r.Body).Decode(req)
85+
86+
// Copy request body for decoding and access of underlying methods
87+
b, err := io.ReadAll(r.Body)
88+
if err != nil {
89+
return &CodecRequest{request: req, err: err}
90+
}
91+
// Close original body
92+
r.Body.Close()
93+
94+
// Decode the request body and check if RPC method is valid.
95+
err = json.Unmarshal(b, req)
96+
97+
// Add close method to buffer and pass as request body
98+
r.Body = io.NopCloser(bytes.NewBuffer(b))
99+
85100
return &CodecRequest{request: req, err: err}
86101
}
87102

v2/json2/server.go

+21-2
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,9 @@
66
package json2
77

88
import (
9+
"bytes"
910
"encoding/json"
11+
"io"
1012
"net/http"
1113

1214
"github.com/gorilla/rpc/v2"
@@ -99,10 +101,24 @@ func (c *Codec) NewRequest(r *http.Request) rpc.CodecRequest {
99101

100102
// newCodecRequest returns a new CodecRequest.
101103
func newCodecRequest(r *http.Request, encoder rpc.Encoder, errorMapper func(error) error) rpc.CodecRequest {
102-
// Decode the request body and check if RPC method is valid.
103104
req := new(serverRequest)
104-
err := json.NewDecoder(r.Body).Decode(req)
105105

106+
// Copy request body for decoding and access of underlying methods
107+
b, err := io.ReadAll(r.Body)
108+
if err != nil {
109+
err = &Error{
110+
Code: E_PARSE,
111+
Message: err.Error(),
112+
Data: req,
113+
}
114+
115+
return &CodecRequest{request: req, err: err, encoder: encoder, errorMapper: errorMapper}
116+
}
117+
// Close original body
118+
r.Body.Close()
119+
120+
// Decode the request body and check if RPC method is valid.
121+
err = json.Unmarshal(b, req)
106122
if err != nil {
107123
err = &Error{
108124
Code: E_PARSE,
@@ -117,6 +133,9 @@ func newCodecRequest(r *http.Request, encoder rpc.Encoder, errorMapper func(erro
117133
}
118134
}
119135

136+
// Add close method to buffer and pass as request body
137+
r.Body = io.NopCloser(bytes.NewBuffer(b))
138+
120139
return &CodecRequest{request: req, err: err, encoder: encoder, errorMapper: errorMapper}
121140
}
122141

v2/protorpc/server.go

+15-1
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
package protorpc
77

88
import (
9+
"bytes"
910
"encoding/json"
1011
"errors"
1112
"fmt"
@@ -78,11 +79,24 @@ func newCodecRequest(r *http.Request) rpc.CodecRequest {
7879
return &CodecRequest{request: req, err: fmt.Errorf("rpc: no method: %s", path)}
7980
}
8081
req.Method = path[index+1:]
81-
err := json.NewDecoder(r.Body).Decode(&req.Params)
82+
83+
// Copy request body for decoding and access of underlying methods
84+
b, err := io.ReadAll(r.Body)
85+
if err != nil {
86+
return &CodecRequest{request: req, err: err}
87+
}
88+
// Close original body
89+
r.Body.Close()
90+
91+
err = json.Unmarshal(b, &req.Params)
8292
var codecErr error
8393
if err != io.EOF {
8494
codecErr = err
8595
}
96+
97+
// Add close method to buffer and pass as request body
98+
r.Body = io.NopCloser(bytes.NewBuffer(b))
99+
86100
return &CodecRequest{request: req, err: codecErr}
87101
}
88102

v2/server.go

+18-6
Original file line numberDiff line numberDiff line change
@@ -178,12 +178,6 @@ func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
178178
codecReq.WriteError(w, http.StatusBadRequest, errGet)
179179
return
180180
}
181-
// Decode the args.
182-
args := reflect.New(methodSpec.argsType)
183-
if errRead := codecReq.ReadRequest(args.Interface()); errRead != nil {
184-
codecReq.WriteError(w, http.StatusBadRequest, errRead)
185-
return
186-
}
187181

188182
// Call the registered Intercept Function
189183
if s.interceptFunc != nil {
@@ -206,6 +200,24 @@ func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
206200
s.beforeFunc(requestInfo)
207201
}
208202

203+
// Close request body after Intercept and Before Function if it exists
204+
// if it's already closed, error still would be nil
205+
if r.Body != nil {
206+
r.Body.Close()
207+
}
208+
209+
// Update codec request with request values after Intercept and Before functions if they exist
210+
if s.interceptFunc != nil || s.beforeFunc != nil {
211+
codecReq = codec.NewRequest(r)
212+
}
213+
214+
// Decode the args.
215+
args := reflect.New(methodSpec.argsType)
216+
if errRead := codecReq.ReadRequest(args.Interface()); errRead != nil {
217+
codecReq.WriteError(w, http.StatusBadRequest, errRead)
218+
return
219+
}
220+
209221
// Prepare the reply, we need it even if validation fails
210222
reply := reflect.New(methodSpec.replyType)
211223
errValue := []reflect.Value{nilErrorValue}

v2/server_test.go

+119
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,10 @@
66
package rpc
77

88
import (
9+
"bytes"
10+
"encoding/json"
911
"errors"
12+
"io"
1013
"log"
1114
"net/http"
1215
"strconv"
@@ -94,6 +97,30 @@ func (r MockCodecRequest) WriteError(w http.ResponseWriter, status int, err erro
9497
}
9598
}
9699

100+
type MockCodecJson struct {
101+
}
102+
103+
func (c MockCodecJson) NewRequest(r *http.Request) CodecRequest {
104+
if r.Body == nil {
105+
return MockCodecRequest{}
106+
}
107+
108+
inp := new(Service1Request)
109+
b, err := io.ReadAll(r.Body)
110+
if err != nil {
111+
return MockCodecRequest{}
112+
}
113+
r.Body.Close()
114+
115+
if err := json.Unmarshal(b, inp); err != nil {
116+
return MockCodecRequest{}
117+
}
118+
119+
r.Body = io.NopCloser(bytes.NewBuffer(b))
120+
121+
return MockCodecRequest{inp.A, inp.B}
122+
}
123+
97124
type MockResponseWriter struct {
98125
header http.Header
99126
Status int
@@ -211,6 +238,98 @@ func TestInterception(t *testing.T) {
211238
t.Errorf("Response body was %s, should be %s.", w.Body, strconv.Itoa(expected))
212239
}
213240
}
241+
242+
func TestInterceptionWithChange(t *testing.T) {
243+
const (
244+
A = 2
245+
B = 3
246+
C = 5
247+
)
248+
expectedBeforeChange := A * B
249+
expectedAfterChange := A * C
250+
251+
r2, err := http.NewRequest("POST", "mocked/request", bytes.NewBuffer([]byte(`{"A": 2, "B":5}`)))
252+
if err != nil {
253+
t.Fatal(err)
254+
}
255+
256+
s := NewServer()
257+
s.RegisterService(new(Service1), "")
258+
s.RegisterCodec(MockCodecJson{}, "mock")
259+
s.RegisterInterceptFunc(func(i *RequestInfo) *http.Request {
260+
return r2
261+
})
262+
263+
r, err := http.NewRequest("POST", "", bytes.NewBuffer([]byte(`{A: 2, B:3}`)))
264+
if err != nil {
265+
t.Fatal(err)
266+
}
267+
r.Header.Set("Content-Type", "mock; dummy")
268+
w := NewMockResponseWriter()
269+
s.ServeHTTP(w, r)
270+
if w.Status != 200 {
271+
t.Errorf("Status was %d, should be 200.", w.Status)
272+
}
273+
274+
if w.Body != strconv.Itoa(expectedBeforeChange) && w.Body == strconv.Itoa(expectedAfterChange) {
275+
return
276+
}
277+
278+
t.Errorf("Response body was %s, should be %s.", w.Body, strconv.Itoa(expectedAfterChange))
279+
}
280+
281+
func TestBeforeFunc(t *testing.T) {
282+
const (
283+
A = 2
284+
B = 3
285+
C = 5
286+
)
287+
expectedBeforeChange := A * B
288+
expectedAfterChange := A * C
289+
290+
s := NewServer()
291+
s.RegisterService(new(Service1), "")
292+
s.RegisterCodec(MockCodecJson{}, "mock")
293+
s.RegisterBeforeFunc(func(i *RequestInfo) {
294+
r := i.Request
295+
296+
inp := new(Service1Request)
297+
err := json.NewDecoder(r.Body).Decode(inp)
298+
if err != nil {
299+
t.Error(err)
300+
t.Fail()
301+
}
302+
303+
inp.B = C
304+
305+
b, err := json.Marshal(inp)
306+
if err != nil {
307+
t.Error(err)
308+
t.Fail()
309+
}
310+
311+
r.Body = io.NopCloser(bytes.NewBuffer(b))
312+
i.Request = r
313+
})
314+
315+
r, err := http.NewRequest("POST", "", bytes.NewBuffer([]byte(`{"A":2, "B":10}`)))
316+
if err != nil {
317+
t.Fatal(err)
318+
}
319+
r.Header.Set("Content-Type", "mock; dummy")
320+
w := NewMockResponseWriter()
321+
s.ServeHTTP(w, r)
322+
if w.Status != 200 {
323+
t.Errorf("Status was %d, should be 200.", w.Status)
324+
}
325+
326+
if w.Body != strconv.Itoa(expectedBeforeChange) && w.Body == strconv.Itoa(expectedAfterChange) {
327+
return
328+
}
329+
330+
t.Errorf("Response body was %s, should be %s.", w.Body, strconv.Itoa(expectedAfterChange))
331+
}
332+
214333
func TestValidationSuccessful(t *testing.T) {
215334
const (
216335
A = 2

0 commit comments

Comments
 (0)