Skip to content

Commit

Permalink
Merge branch '1.1.0'. Fixes #2.
Browse files Browse the repository at this point in the history
  • Loading branch information
livingsilver94 committed Jan 25, 2022
2 parents 80a58ae + 34bea50 commit 3695d3c
Show file tree
Hide file tree
Showing 5 changed files with 255 additions and 79 deletions.
5 changes: 3 additions & 2 deletions error.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,9 @@ var (
ErrNilValue = errors.New("nil")
// ErrNotPointer is retuned when an argument passed is not a pointer but it should be.
ErrNotPointer = errors.New("not a pointer")
// ErrNotStruct is returned when an argument passed is not a struct but it should be.
ErrNotStruct = errors.New("not a struct type")
// ErrNoCodec is returned when a type cannot be marshaled or unmarshaled,
// e.g. it is neither a struct nor implements StringMap(Un)marshaler.
ErrNoCodec = errors.New("not an encodable or decodable type")
)

func errIs(something interface{}, err error) error {
Expand Down
55 changes: 36 additions & 19 deletions marshal.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,17 @@ import (

const inlineSep = "."

// Marshal returns the map[string]string representation of v, which must be a struct,
// by reading every exported field and translating it into a (key, value) pair to be
// added to the resulting map. Interfaces or pointers to struct are also accepted.
// StringMapMarshaler is the interface implemented by types that can marshal themselves into a map of strings.
type StringMapMarshaler interface {
MarshalStringMap() (map[string]string, error)
}

// Marshal returns the map[string]string representation of v, which must be a struct
// or implementing StringMapMarshaler. When implementing the interface, the map is returned verbatim
// from its method along with the error value. When not, Marshal reads every exported field and translates it
// into a (key, value) pair to be added to the resulting map. Interfaces or pointers to struct are also accepted.
//
// Marshal converts all non-reference built-in types except arrays, plus
// Marshal converts all fields of non-reference built-in types except arrays, plus
// structs implementing encoding.TextMarshaler or fmt.Stringer, checked in this exact order.
//
// The encoding of each struct field can be customized by the format string stored under the "redmap"
Expand Down Expand Up @@ -42,34 +48,30 @@ const inlineSep = "."
// // Field appears in the map as key "-".
// Field int `redmap:"-,"`
//
// // Field must be a struct. Field is flattened and its fields
// // are added to the map as (key, value) pairs, where the keys
// // are constructed in the "customName.subFieldName" format.
// // Field must be a struct or implementing StringMapMarshaler.
// // The resulting map is added to the final map with keys flattened,
// // constructed in the "customName.subKeyName" format.
// Field int `redmap:"customName,inline"`
func Marshal(v interface{}) (map[string]string, error) {
val, err := structValue(v)
val, err := validValue(v)
if err != nil {
return nil, err
}
ret := make(map[string]string)
return ret, marshalRecursive(ret, "", val)
}

func structValue(v interface{}) (reflect.Value, error) {
func validValue(v interface{}) (reflect.Value, error) {
val := reflect.ValueOf(v)
kin := val.Kind()
for kin == reflect.Interface || kin == reflect.Ptr {
val = val.Elem()
kin = val.Kind()
}
switch kin {
case reflect.Struct:
return val, nil
case reflect.Invalid:
if kin == reflect.Invalid {
return reflect.Value{}, ErrNilValue
default:
return reflect.Value{}, errIs(val.Type(), ErrNotStruct)
}
return val, nil
}

// marshalRecursive marshal a struct represented by val into a map[string]string.
Expand All @@ -78,6 +80,12 @@ func structValue(v interface{}) (reflect.Value, error) {
// name in case of an inlined inner struct.
func marshalRecursive(mp map[string]string, prefix string, stru reflect.Value) error {
typ := stru.Type()
if typ.Implements(mapMarshalerType) {
return structToMap(mp, prefix, stru)
}
if stru.Kind() != reflect.Struct {
return errIs(stru.Type(), ErrNoCodec)
}
for i := 0; i < typ.NumField(); i++ {
field := typ.Field(i)
if field.PkgPath != "" {
Expand All @@ -99,9 +107,6 @@ func marshalRecursive(mp map[string]string, prefix string, stru reflect.Value) e
}

if tags.inline {
if kind := value.Kind(); kind != reflect.Struct {
return fmt.Errorf("cannot inline: %w", errIs(value.Type(), ErrNotStruct))
}
err := marshalRecursive(mp, prefix+tags.name+inlineSep, value)
if err != nil {
return err
Expand All @@ -117,6 +122,17 @@ func marshalRecursive(mp map[string]string, prefix string, stru reflect.Value) e
return nil
}

func structToMap(mp map[string]string, prefix string, stru reflect.Value) error {
conv, err := stru.Interface().(StringMapMarshaler).MarshalStringMap()
if err != nil {
return err
}
for k, v := range conv {
mp[prefix+k] = v
}
return nil
}

func fieldToString(val reflect.Value) (string, error) {
typ := val.Type()
if typ.Implements(textMarshalerType) {
Expand Down Expand Up @@ -149,6 +165,7 @@ func fieldToString(val reflect.Value) (string, error) {
}

var (
textMarshalerType = reflect.TypeOf(new(encoding.TextMarshaler)).Elem()
mapMarshalerType = reflect.TypeOf(new(StringMapMarshaler)).Elem()
stringerType = reflect.TypeOf(new(fmt.Stringer)).Elem()
textMarshalerType = reflect.TypeOf(new(encoding.TextMarshaler)).Elem()
)
104 changes: 83 additions & 21 deletions marshal_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,28 +11,49 @@ import (
)

const (
stringerOut = "stub"
textMarshalerOut = "stubtext"
stringerOut = "stub" // stringerOut is the output of fmt.Stringer implementations.
textMarshalerOut = "stubtext" // textMarshalerOut is the output of encoding.TextMarshaler implementations.
)

type StubStringer struct{}
var (
// mapMarshalerOut is the output of redmap.StringMapMarshaler implementations.
mapMarshalerOut = map[string]string{"field1": "value1", "field2": "value2"}
)

// stubStringer implements the fmt.Stringer interface.
type stubStringer struct{}

func (s stubStringer) String() string { return stringerOut }

// stubIntStringer is an int that implements fmt.Stringer,
// so that we can test if a non-struct type is correctly handled as an interface.
type stubIntStringer int

func (s stubIntStringer) String() string { return stringerOut }

func (s StubStringer) String() string { return stringerOut }
// stubTextMarshaler implements the encoding.TextMarshaler interface.
type stubTextMarshaler struct{}

// StubIntStringer implements fmt.Stringer but doesn't rely on an
// underlying struct. Useful to test whether we can detect
// interfaces independently from their underlying type.
type StubIntStringer int
func (s stubTextMarshaler) MarshalText() ([]byte, error) { return []byte(textMarshalerOut), nil }

func (s StubIntStringer) String() string { return stringerOut }
// stubMapMarshaler implements the redmap.StringMapMarshaler interface.
type stubMapMarshaler struct{}

type StubTextMarshaler struct{}
func (s stubMapMarshaler) MarshalStringMap() (map[string]string, error) {
return mapMarshalerOut, nil
}

// stubIntMapMarshaler is an int that implements redmap.StringMapMarshaler,
// so that we can test if a non-struct type is correctly handled as an interface.
type stubIntMapMarshaler int

func (s StubTextMarshaler) MarshalText() ([]byte, error) { return []byte(textMarshalerOut), nil }
func (s stubIntMapMarshaler) MarshalStringMap() (map[string]string, error) {
return mapMarshalerOut, nil
}

func TestMarshalValidType(t *testing.T) {
var (
stub StubStringer = StubStringer{}
stub stubStringer = stubStringer{}
ifac fmt.Stringer = stub
)
types := []interface{}{
Expand All @@ -50,7 +71,7 @@ func TestMarshalValidType(t *testing.T) {

func TestMarshalNil(t *testing.T) {
var (
stub *StubStringer = nil
stub *stubStringer = nil
ifac fmt.Stringer = stub
)
nils := []interface{}{nil, stub, ifac}
Expand All @@ -66,8 +87,8 @@ func TestMarshalInvalidType(t *testing.T) {
tests := []interface{}{noStruct, &noStruct}
for _, test := range tests {
_, err := redmap.Marshal(test)
if !errors.Is(err, redmap.ErrNotStruct) {
t.Fatalf("Unmarshal returned error %q but %q was expected", err, redmap.ErrNotStruct)
if !errors.Is(err, redmap.ErrNoCodec) {
t.Fatalf("Unmarshal returned error %q but %q was expected", err, redmap.ErrNoCodec)
}
}
}
Expand All @@ -87,14 +108,14 @@ func TestMarshalScalars(t *testing.T) {
{In: struct{ V string }{"str"}, Out: map[string]string{"V": "str"}},

// // Marshal interfaces by passing the real value.
{In: struct{ V StubStringer }{StubStringer{}}, Out: map[string]string{"V": stringerOut}},
{In: struct{ V StubIntStringer }{StubIntStringer(100)}, Out: map[string]string{"V": stringerOut}},
{In: struct{ V StubTextMarshaler }{StubTextMarshaler{}}, Out: map[string]string{"V": textMarshalerOut}},
{In: struct{ V stubStringer }{stubStringer{}}, Out: map[string]string{"V": stringerOut}},
{In: struct{ V stubIntStringer }{stubIntStringer(100)}, Out: map[string]string{"V": stringerOut}},
{In: struct{ V stubTextMarshaler }{stubTextMarshaler{}}, Out: map[string]string{"V": textMarshalerOut}},

// Marshal interfaces by interfaces.
{In: struct{ V fmt.Stringer }{StubStringer{}}, Out: map[string]string{"V": stringerOut}},
{In: struct{ V fmt.Stringer }{StubIntStringer(100)}, Out: map[string]string{"V": stringerOut}},
{In: struct{ V encoding.TextMarshaler }{StubTextMarshaler{}}, Out: map[string]string{"V": textMarshalerOut}},
{In: struct{ V fmt.Stringer }{stubStringer{}}, Out: map[string]string{"V": stringerOut}},
{In: struct{ V fmt.Stringer }{stubIntStringer(100)}, Out: map[string]string{"V": stringerOut}},
{In: struct{ V encoding.TextMarshaler }{stubTextMarshaler{}}, Out: map[string]string{"V": textMarshalerOut}},
}
for _, test := range tests {
out, err := redmap.Marshal(test.In)
Expand Down Expand Up @@ -189,3 +210,44 @@ func TestMarshalWithTags(t *testing.T) {
t.Fatalf("Marshal's output doesn't respect struct tags\n\tExpected: %v\n\tOut: %v", expected, out)
}
}

func TestMapMarshaler(t *testing.T) {
tests := []struct {
In redmap.StringMapMarshaler
Out map[string]string
}{
{In: stubMapMarshaler{}, Out: mapMarshalerOut},
{In: stubIntMapMarshaler(666), Out: mapMarshalerOut},
}
for _, test := range tests {
out, err := redmap.Marshal(test.In)
if err != nil {
t.Fatalf("Marshal returned unexpected error %q", err)
}
if !reflect.DeepEqual(out, test.Out) {
t.Fatalf("Marshal's output doesn't match the expected value\n\tIn: %v\n\tExpected: %v\n\tOut: %v", test.In, test.Out, out)
}
}
}

func TestInnerMapMarshaler(t *testing.T) {
stru := struct {
RegularField string
Struct stubMapMarshaler `redmap:",inline"`
}{
RegularField: "regular",
Struct: stubMapMarshaler{},
}
expected := map[string]string{"RegularField": "regular"}
for k, v := range mapMarshalerOut {
expected["Struct."+k] = v
}

out, err := redmap.Marshal(stru)
if err != nil {
t.Fatalf("Marshal returned unexpected error %q", err)
}
if !reflect.DeepEqual(out, expected) {
t.Fatalf("Marshal's output doesn't respect struct tags\n\tExpected: %v\n\tOut: %v", expected, out)
}
}
56 changes: 37 additions & 19 deletions unmarshal.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,55 +5,60 @@ import (
"fmt"
"reflect"
"strconv"
"strings"
)

// StringMapUnmarshaler is the interface implemented by types that can unmarshal themselves
// from a map of strings. Implementations must copy the given map if they wish to modify it.
type StringMapUnmarshaler interface {
UnmarshalStringMap(map[string]string) error
}

// Unmarshal sets v's fields according to its map representation contained by data.
// v must be a pointer to struct or an interface. Neither data nor v can be nil.
//
// Unmarshal uses the inverse of the encodings that Marshal uses, so all the types supported
// by it are also supported in Unmarshal, except the interfaces: only encoding.TextUnmarshaler
// can be unmarshaled.
// by it are also supported in Unmarshal, except fmt.Stringer which doesn't have an inverse.
//
// The decoding of each struct field can be customized by the format string documented in Marshal.
func Unmarshal(data map[string]string, v interface{}) error {
if data == nil {
return errIs("map passed", ErrNilValue)
}
val, err := ptrStructValue(v)
val, err := ptrValidValue(v)
if err != nil {
return err
}
return unmarshalRecursive(data, "", val)
}

func ptrStructValue(v interface{}) (reflect.Value, error) {
func ptrValidValue(v interface{}) (reflect.Value, error) {
val := reflect.ValueOf(v)
kin := val.Kind()

switch kin {
case reflect.Ptr:
case reflect.Invalid:
return reflect.Value{}, errIs("argument provided", ErrNilValue)
default:
return reflect.Value{}, errIs(val.Type(), ErrNotPointer)
}

for kin == reflect.Ptr {
val = val.Elem()
kin = val.Kind()
}

switch kin {
case reflect.Struct:
return val, nil
case reflect.Invalid:
if kin == reflect.Invalid {
return reflect.Value{}, errIs(reflect.TypeOf(v), ErrNilValue)
default:
return reflect.Value{}, errIs(val.Type(), ErrNotStruct)
}
return val, nil
}

func unmarshalRecursive(data map[string]string, prefix string, stru reflect.Value) error {
func unmarshalRecursive(mp map[string]string, prefix string, stru reflect.Value) error {
if ptr := stru.Addr(); ptr.Type().Implements(mapUnmarshalerType) {
return mapToStruct(mp, prefix, ptr)
}
if stru.Kind() != reflect.Struct {
return errIs(stru.Type(), ErrNoCodec)
}
typ := stru.Type()
for i := 0; i < typ.NumField(); i++ {
field := typ.Field(i)
Expand Down Expand Up @@ -83,15 +88,12 @@ func unmarshalRecursive(data map[string]string, prefix string, stru reflect.Valu
}

if tags.inline {
if kind := value.Kind(); kind != reflect.Struct {
return fmt.Errorf("cannot inline: %w", errIs(value.Type(), ErrNotStruct))
}
err := unmarshalRecursive(data, tags.name+inlineSep, value)
err := unmarshalRecursive(mp, tags.name+inlineSep, value)
if err != nil {
return err
}
} else {
str, ok := data[tags.name]
str, ok := mp[tags.name]
if !ok {
continue
}
Expand All @@ -104,6 +106,21 @@ func unmarshalRecursive(data map[string]string, prefix string, stru reflect.Valu
return nil
}

func mapToStruct(mp map[string]string, prefix string, stru reflect.Value) error {
if prefix != "" {
// FIXME: Creating a submap is O(n). Can we think of a better algorithm?
subMP := make(map[string]string, len(mp))
for k, v := range mp {
if !strings.HasPrefix(k, prefix) {
continue
}
subMP[k[len(prefix):]] = v
}
mp = subMP
}
return stru.Interface().(StringMapUnmarshaler).UnmarshalStringMap(mp)
}

func stringToField(str string, field reflect.Value, omitempty bool) error {
addr := field.Addr() // Unmarshaling always requires a pointer receiver.
if addr.Type().Implements(textUnmarshalerType) {
Expand Down Expand Up @@ -176,5 +193,6 @@ func stringToField(str string, field reflect.Value, omitempty bool) error {
}

var (
mapUnmarshalerType = reflect.TypeOf(new(StringMapUnmarshaler)).Elem()
textUnmarshalerType = reflect.TypeOf(new(encoding.TextUnmarshaler)).Elem()
)
Loading

0 comments on commit 3695d3c

Please sign in to comment.