From c8352586d1c2bc8605fa3b1cbb50621d52351612 Mon Sep 17 00:00:00 2001 From: Fabio Forni Date: Mon, 1 Nov 2021 16:20:56 +0100 Subject: [PATCH] unmarshal: Fix unmarshaling of TextUnmarshaler types --- unmarshal.go | 18 +++++++++++------- unmarshal_test.go | 2 +- 2 files changed, 12 insertions(+), 8 deletions(-) diff --git a/unmarshal.go b/unmarshal.go index 5fe9f63..80cb83d 100644 --- a/unmarshal.go +++ b/unmarshal.go @@ -105,6 +105,11 @@ func unmarshalRecursive(data map[string]string, prefix string, stru reflect.Valu } func stringToField(str string, field reflect.Value, omitempty bool) error { + addr := field.Addr() // Unmarshaling always requires a pointer receiver. + if addr.Type().Implements(textUnmarshalerType) { + return addr.Interface().(encoding.TextUnmarshaler).UnmarshalText([]byte(str)) + } + var ( val reflect.Value err error @@ -157,13 +162,8 @@ func stringToField(str string, field reflect.Value, omitempty bool) error { val, err = reflect.ValueOf(v), e case reflect.String: val, err = reflect.ValueOf(str), nil - case reflect.Struct: - f := field.Addr() // Unmarshaling always requires a pointer receiver. - t := f.Type() - if t.Implements(reflect.TypeOf((*encoding.TextUnmarshaler)(nil)).Elem()) { - return f.Interface().(encoding.TextUnmarshaler).UnmarshalText([]byte(str)) - } - return fmt.Errorf("%s doesn't implement TextUnmarshaler", t) + default: + return fmt.Errorf("%s doesn't implement TextUnmarshaler", addr) } if err != nil { return err @@ -174,3 +174,7 @@ func stringToField(str string, field reflect.Value, omitempty bool) error { field.Set(val) return nil } + +var ( + textUnmarshalerType = reflect.TypeOf(new(encoding.TextUnmarshaler)).Elem() +) diff --git a/unmarshal_test.go b/unmarshal_test.go index bb6eab3..e594172 100644 --- a/unmarshal_test.go +++ b/unmarshal_test.go @@ -117,7 +117,7 @@ func TestUnmarshalScalars(t *testing.T) { {In: map[string]string{"V": "(100.1+80.1i)"}, Out: struct{ V complex128 }{100.1 + 80.1i}}, {In: map[string]string{"V": "str"}, Out: struct{ V string }{"str"}}, {In: map[string]string{"V": "a test"}, Out: struct{ V StubTextUnmarshaler }{StubTextUnmarshaler{S: "a test"}}}, - {In: map[string]string{"V": "100"}, Out: struct{ V StubIntStringer }{StubIntStringer(100)}}, + {In: map[string]string{"V": "100"}, Out: struct{ V StubIntUnmarshaler }{StubIntUnmarshaler(100)}}, } for _, test := range tests { zero := reflect.New(reflect.TypeOf(test.Out))