Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(protoio): Better error when passing in unexpected type #48

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
fix(protoio): Better error when passing in unexpected type
MarcoPolo committed Feb 3, 2025
commit 29cd5a3643c8043505337c754dcb9bb19f13c332
51 changes: 49 additions & 2 deletions msgio_test.go
Original file line number Diff line number Diff line change
@@ -6,10 +6,15 @@
"fmt"
"io"
"math/rand"
str "strings"
"strings"
"sync"
"testing"
"time"

"github.com/libp2p/go-msgio/pbio/pb"
"github.com/libp2p/go-msgio/protoio"

Check failure on line 15 in msgio_test.go

GitHub Actions / go-check / All

"github.com/libp2p/go-msgio/protoio" is deprecated: GoGo Protobuf is deprecated and unmaintained. (SA1019)
"github.com/multiformats/go-varint"
"google.golang.org/protobuf/proto"
)

func randBuf(r *rand.Rand, size int) []byte {
@@ -79,7 +84,7 @@
}

twoErrors := multiErr([]error{errors.New("one"), errors.New("two")})
if eStr := twoErrors.Error(); !str.Contains(eStr, "one") && !str.Contains(eStr, "two") {
if eStr := twoErrors.Error(); !strings.Contains(eStr, "one") && !strings.Contains(eStr, "two") {
t.Fatal("Expected error messages not included")
}
}
@@ -328,3 +333,45 @@
t.Fatal("Expected short buffer error")
}
}

func TestHandleProtoGeneratedByGoogleProtobufInProtoio(t *testing.T) {
record := &pb.TestRecord{
Uint32: 42,
Uint64: 84,
Bytes: []byte("test bytes"),
String_: "test string",
Int32: -42,
Int64: -84,
}

recordBytes, err := proto.Marshal(record)
if err != nil {
t.Fatal(err)
}

for _, tc := range []string{"read", "write"} {
t.Run(tc, func(t *testing.T) {
var buf bytes.Buffer
readRecord := &pb.TestRecord{}
switch tc {
case "read":
buf.Write(varint.ToUvarint(uint64(len(recordBytes))))
buf.Write(recordBytes)

reader := protoio.NewDelimitedReader(&buf, 1024)
defer reader.Close()
err = reader.ReadMsg(readRecord)
case "write":
writer := protoio.NewDelimitedWriter(&buf)
err = writer.WriteMsg(record)
}
if err == nil {
t.Fatal("expected error")
}
expectedError := "google Protobuf message passed into a GoGo Protobuf"
if !strings.Contains(err.Error(), expectedError) {
t.Fatalf("expected error to contain '%s'", expectedError)
}
})
}
}
15 changes: 15 additions & 0 deletions protoio/isgoog.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
package protoio

import (
"github.com/gogo/protobuf/proto"
"google.golang.org/protobuf/reflect/protoreflect"
)

// isGoogleProtobufMsg checks if the given proto.Message was
// generated by the official Google protobuf compiler
func isGoogleProtobufMsg(msg proto.Message) bool {
_, ok := msg.(interface {
ProtoReflect() protoreflect.Message
})
return ok
}
19 changes: 18 additions & 1 deletion protoio/uvarint_reader.go
Original file line number Diff line number Diff line change
@@ -34,6 +34,7 @@ package protoio

import (
"bufio"
"errors"
"fmt"
"io"
"os"
@@ -82,7 +83,23 @@ func (ur *uvarintReader) ReadMsg(msg proto.Message) (err error) {
if _, err := io.ReadFull(ur.r, buf); err != nil {
return err
}
return proto.Unmarshal(buf, msg)

// Hoist up gogo's proto.Unmarshal logic so we can also check if this is a google protobuf message
msg.Reset()
if u, ok := msg.(interface {
XXX_Unmarshal([]byte) error
}); ok {
return u.XXX_Unmarshal(buf)
} else if u, ok := msg.(interface {
Unmarshal([]byte) error
}); ok {
return u.Unmarshal(buf)
} else if isGoogleProtobufMsg(msg) {
return errors.New("google Protobuf message passed into a GoGo Protobuf reader. Use github.com/libp2p/go-msgio/pbio instead of github.com/gogo/protobuf/proto")
}

// Fallback to GoGo's proto.Unmarshal around this buffer
return proto.NewBuffer(buf).Unmarshal(msg)
}

func (ur *uvarintReader) Close() error {
5 changes: 5 additions & 0 deletions protoio/uvarint_writer.go
Original file line number Diff line number Diff line change
@@ -33,6 +33,7 @@
package protoio

import (
"errors"
"fmt"
"io"
"os"
@@ -80,6 +81,10 @@ func (uw *uvarintWriter) WriteMsg(msg proto.Message) (err error) {
}
}

if isGoogleProtobufMsg(msg) {
return errors.New("google Protobuf message passed into a GoGo Protobuf writer. Use github.com/libp2p/go-msgio/pbio instead of github.com/gogo/protobuf/proto")
}

// fallback
data, err = proto.Marshal(msg)
if err != nil {