Skip to content

Commit 29cd5a3

Browse files
committed
fix(protoio): Better error when passing in unexpected type
1 parent 0967bbf commit 29cd5a3

File tree

4 files changed

+87
-3
lines changed

4 files changed

+87
-3
lines changed

msgio_test.go

+49-2
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,15 @@ import (
66
"fmt"
77
"io"
88
"math/rand"
9-
str "strings"
9+
"strings"
1010
"sync"
1111
"testing"
1212
"time"
13+
14+
"github.com/libp2p/go-msgio/pbio/pb"
15+
"github.com/libp2p/go-msgio/protoio"
16+
"github.com/multiformats/go-varint"
17+
"google.golang.org/protobuf/proto"
1318
)
1419

1520
func randBuf(r *rand.Rand, size int) []byte {
@@ -79,7 +84,7 @@ func TestMultiError(t *testing.T) {
7984
}
8085

8186
twoErrors := multiErr([]error{errors.New("one"), errors.New("two")})
82-
if eStr := twoErrors.Error(); !str.Contains(eStr, "one") && !str.Contains(eStr, "two") {
87+
if eStr := twoErrors.Error(); !strings.Contains(eStr, "one") && !strings.Contains(eStr, "two") {
8388
t.Fatal("Expected error messages not included")
8489
}
8590
}
@@ -328,3 +333,45 @@ func SubtestReadShortBuffer(t *testing.T, writer WriteCloser, reader ReadCloser)
328333
t.Fatal("Expected short buffer error")
329334
}
330335
}
336+
337+
func TestHandleProtoGeneratedByGoogleProtobufInProtoio(t *testing.T) {
338+
record := &pb.TestRecord{
339+
Uint32: 42,
340+
Uint64: 84,
341+
Bytes: []byte("test bytes"),
342+
String_: "test string",
343+
Int32: -42,
344+
Int64: -84,
345+
}
346+
347+
recordBytes, err := proto.Marshal(record)
348+
if err != nil {
349+
t.Fatal(err)
350+
}
351+
352+
for _, tc := range []string{"read", "write"} {
353+
t.Run(tc, func(t *testing.T) {
354+
var buf bytes.Buffer
355+
readRecord := &pb.TestRecord{}
356+
switch tc {
357+
case "read":
358+
buf.Write(varint.ToUvarint(uint64(len(recordBytes))))
359+
buf.Write(recordBytes)
360+
361+
reader := protoio.NewDelimitedReader(&buf, 1024)
362+
defer reader.Close()
363+
err = reader.ReadMsg(readRecord)
364+
case "write":
365+
writer := protoio.NewDelimitedWriter(&buf)
366+
err = writer.WriteMsg(record)
367+
}
368+
if err == nil {
369+
t.Fatal("expected error")
370+
}
371+
expectedError := "google Protobuf message passed into a GoGo Protobuf"
372+
if !strings.Contains(err.Error(), expectedError) {
373+
t.Fatalf("expected error to contain '%s'", expectedError)
374+
}
375+
})
376+
}
377+
}

protoio/isgoog.go

+15
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
package protoio
2+
3+
import (
4+
"github.com/gogo/protobuf/proto"
5+
"google.golang.org/protobuf/reflect/protoreflect"
6+
)
7+
8+
// isGoogleProtobufMsg checks if the given proto.Message was
9+
// generated by the official Google protobuf compiler
10+
func isGoogleProtobufMsg(msg proto.Message) bool {
11+
_, ok := msg.(interface {
12+
ProtoReflect() protoreflect.Message
13+
})
14+
return ok
15+
}

protoio/uvarint_reader.go

+18-1
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ package protoio
3434

3535
import (
3636
"bufio"
37+
"errors"
3738
"fmt"
3839
"io"
3940
"os"
@@ -82,7 +83,23 @@ func (ur *uvarintReader) ReadMsg(msg proto.Message) (err error) {
8283
if _, err := io.ReadFull(ur.r, buf); err != nil {
8384
return err
8485
}
85-
return proto.Unmarshal(buf, msg)
86+
87+
// Hoist up gogo's proto.Unmarshal logic so we can also check if this is a google protobuf message
88+
msg.Reset()
89+
if u, ok := msg.(interface {
90+
XXX_Unmarshal([]byte) error
91+
}); ok {
92+
return u.XXX_Unmarshal(buf)
93+
} else if u, ok := msg.(interface {
94+
Unmarshal([]byte) error
95+
}); ok {
96+
return u.Unmarshal(buf)
97+
} else if isGoogleProtobufMsg(msg) {
98+
return errors.New("google Protobuf message passed into a GoGo Protobuf reader. Use github.com/libp2p/go-msgio/pbio instead of github.com/gogo/protobuf/proto")
99+
}
100+
101+
// Fallback to GoGo's proto.Unmarshal around this buffer
102+
return proto.NewBuffer(buf).Unmarshal(msg)
86103
}
87104

88105
func (ur *uvarintReader) Close() error {

protoio/uvarint_writer.go

+5
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
package protoio
3434

3535
import (
36+
"errors"
3637
"fmt"
3738
"io"
3839
"os"
@@ -80,6 +81,10 @@ func (uw *uvarintWriter) WriteMsg(msg proto.Message) (err error) {
8081
}
8182
}
8283

84+
if isGoogleProtobufMsg(msg) {
85+
return errors.New("google Protobuf message passed into a GoGo Protobuf writer. Use github.com/libp2p/go-msgio/pbio instead of github.com/gogo/protobuf/proto")
86+
}
87+
8388
// fallback
8489
data, err = proto.Marshal(msg)
8590
if err != nil {

0 commit comments

Comments
 (0)