Skip to content

Commit

Permalink
unmarshal: Fix unmarshaling of TextUnmarshaler types
Browse files Browse the repository at this point in the history
  • Loading branch information
livingsilver94 committed Nov 1, 2021
1 parent a3b6497 commit c835258
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 8 deletions.
18 changes: 11 additions & 7 deletions unmarshal.go
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,11 @@ func unmarshalRecursive(data map[string]string, prefix string, stru reflect.Valu
}

func stringToField(str string, field reflect.Value, omitempty bool) error {
addr := field.Addr() // Unmarshaling always requires a pointer receiver.
if addr.Type().Implements(textUnmarshalerType) {
return addr.Interface().(encoding.TextUnmarshaler).UnmarshalText([]byte(str))
}

var (
val reflect.Value
err error
Expand Down Expand Up @@ -157,13 +162,8 @@ func stringToField(str string, field reflect.Value, omitempty bool) error {
val, err = reflect.ValueOf(v), e
case reflect.String:
val, err = reflect.ValueOf(str), nil
case reflect.Struct:
f := field.Addr() // Unmarshaling always requires a pointer receiver.
t := f.Type()
if t.Implements(reflect.TypeOf((*encoding.TextUnmarshaler)(nil)).Elem()) {
return f.Interface().(encoding.TextUnmarshaler).UnmarshalText([]byte(str))
}
return fmt.Errorf("%s doesn't implement TextUnmarshaler", t)
default:
return fmt.Errorf("%s doesn't implement TextUnmarshaler", addr)
}
if err != nil {
return err
Expand All @@ -174,3 +174,7 @@ func stringToField(str string, field reflect.Value, omitempty bool) error {
field.Set(val)
return nil
}

var (
textUnmarshalerType = reflect.TypeOf(new(encoding.TextUnmarshaler)).Elem()
)
2 changes: 1 addition & 1 deletion unmarshal_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ func TestUnmarshalScalars(t *testing.T) {
{In: map[string]string{"V": "(100.1+80.1i)"}, Out: struct{ V complex128 }{100.1 + 80.1i}},
{In: map[string]string{"V": "str"}, Out: struct{ V string }{"str"}},
{In: map[string]string{"V": "a test"}, Out: struct{ V StubTextUnmarshaler }{StubTextUnmarshaler{S: "a test"}}},
{In: map[string]string{"V": "100"}, Out: struct{ V StubIntStringer }{StubIntStringer(100)}},
{In: map[string]string{"V": "100"}, Out: struct{ V StubIntUnmarshaler }{StubIntUnmarshaler(100)}},
}
for _, test := range tests {
zero := reflect.New(reflect.TypeOf(test.Out))
Expand Down

0 comments on commit c835258

Please sign in to comment.