Allow decoding to struct field of interface type (#280)
Closes #260 Closes #275
This commit is contained in:
parent
3240b60c8b
commit
4a03f1c003
9
cache.go
9
cache.go
@ -31,6 +31,7 @@ const (
|
||||
specialTypeNone specialType = iota
|
||||
specialTypeUnmarshalerIface
|
||||
specialTypeEmptyIface
|
||||
specialTypeIface
|
||||
specialTypeTag
|
||||
specialTypeTime
|
||||
)
|
||||
@ -57,8 +58,12 @@ func newTypeInfo(t reflect.Type) *typeInfo {
|
||||
tInfo.nonPtrType = t
|
||||
tInfo.nonPtrKind = k
|
||||
|
||||
if k == reflect.Interface && t.NumMethod() == 0 {
|
||||
tInfo.spclType = specialTypeEmptyIface
|
||||
if k == reflect.Interface {
|
||||
if t.NumMethod() == 0 {
|
||||
tInfo.spclType = specialTypeEmptyIface
|
||||
} else {
|
||||
tInfo.spclType = specialTypeIface
|
||||
}
|
||||
} else if t == typeTag {
|
||||
tInfo.spclType = specialTypeTag
|
||||
} else if t == typeTime {
|
||||
|
||||
@ -542,6 +542,12 @@ const (
|
||||
// parseToValue decodes CBOR data to value. It assumes data is well-formed,
|
||||
// and does not perform bounds checking.
|
||||
func (d *decoder) parseToValue(v reflect.Value, tInfo *typeInfo) error { //nolint:gocyclo
|
||||
|
||||
if tInfo.spclType == specialTypeIface && !v.IsNil() {
|
||||
v = v.Elem()
|
||||
tInfo = getTypeInfo(v.Type())
|
||||
}
|
||||
|
||||
// Create new value for the pointer v to point to if CBOR value is not nil/undefined.
|
||||
if !d.nextCBORNil() {
|
||||
for v.Kind() == reflect.Ptr {
|
||||
|
||||
224
decode_test.go
224
decode_test.go
@ -5178,3 +5178,227 @@ func TestUnmarshalInvalidTagBignum(t *testing.T) {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
type Foo interface {
|
||||
Foo() string
|
||||
}
|
||||
|
||||
type UintFoo uint
|
||||
|
||||
func (f *UintFoo) Foo() string {
|
||||
return fmt.Sprint(f)
|
||||
}
|
||||
|
||||
type IntFoo int
|
||||
|
||||
func (f *IntFoo) Foo() string {
|
||||
return fmt.Sprint(*f)
|
||||
}
|
||||
|
||||
type ByteFoo []byte
|
||||
|
||||
func (f *ByteFoo) Foo() string {
|
||||
return fmt.Sprint(*f)
|
||||
}
|
||||
|
||||
type StringFoo string
|
||||
|
||||
func (f *StringFoo) Foo() string {
|
||||
return string(*f)
|
||||
}
|
||||
|
||||
type ArrayFoo []int
|
||||
|
||||
func (f *ArrayFoo) Foo() string {
|
||||
return fmt.Sprint(*f)
|
||||
}
|
||||
|
||||
type MapFoo map[int]int
|
||||
|
||||
func (f *MapFoo) Foo() string {
|
||||
return fmt.Sprint(*f)
|
||||
}
|
||||
|
||||
type StructFoo struct {
|
||||
Value int `cbor:"1,keyasint"`
|
||||
}
|
||||
|
||||
func (f *StructFoo) Foo() string {
|
||||
return fmt.Sprint(*f)
|
||||
}
|
||||
|
||||
type TestExample struct {
|
||||
Message string `cbor:"1,keyasint"`
|
||||
Foo Foo `cbor:"2,keyasint"`
|
||||
}
|
||||
|
||||
func TestUnmarshalToInterface(t *testing.T) {
|
||||
|
||||
uintFoo, uintFoo123 := UintFoo(0), UintFoo(123)
|
||||
intFoo, intFooNeg1 := IntFoo(0), IntFoo(-1)
|
||||
byteFoo, byteFoo123 := ByteFoo(nil), ByteFoo([]byte{1, 2, 3})
|
||||
stringFoo, stringFoo123 := StringFoo(""), StringFoo("123")
|
||||
arrayFoo, arrayFoo123 := ArrayFoo(nil), ArrayFoo([]int{1, 2, 3})
|
||||
mapFoo, mapFoo123 := MapFoo(nil), MapFoo(map[int]int{1: 1, 2: 2, 3: 3})
|
||||
|
||||
em, _ := EncOptions{Sort: SortCanonical}.EncMode()
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
data []byte
|
||||
v *TestExample
|
||||
unmarshalToObj *TestExample
|
||||
}{
|
||||
{
|
||||
name: "uint",
|
||||
data: hexDecode("a2016b736f6d65206d657373676502187b"), // {1: "some messge", 2: 123}
|
||||
v: &TestExample{
|
||||
Message: "some messge",
|
||||
Foo: &uintFoo123,
|
||||
},
|
||||
unmarshalToObj: &TestExample{Foo: &uintFoo},
|
||||
},
|
||||
{
|
||||
name: "int",
|
||||
data: hexDecode("a2016b736f6d65206d65737367650220"), // {1: "some messge", 2: -1}
|
||||
v: &TestExample{
|
||||
Message: "some messge",
|
||||
Foo: &intFooNeg1,
|
||||
},
|
||||
unmarshalToObj: &TestExample{Foo: &intFoo},
|
||||
},
|
||||
{
|
||||
name: "bytes",
|
||||
data: hexDecode("a2016b736f6d65206d65737367650243010203"), // {1: "some messge", 2: [1,2,3]}
|
||||
v: &TestExample{
|
||||
Message: "some messge",
|
||||
Foo: &byteFoo123,
|
||||
},
|
||||
unmarshalToObj: &TestExample{Foo: &byteFoo},
|
||||
},
|
||||
{
|
||||
name: "string",
|
||||
data: hexDecode("a2016b736f6d65206d65737367650263313233"), // {1: "some messge", 2: "123"}
|
||||
v: &TestExample{
|
||||
Message: "some messge",
|
||||
Foo: &stringFoo123,
|
||||
},
|
||||
unmarshalToObj: &TestExample{Foo: &stringFoo},
|
||||
},
|
||||
{
|
||||
name: "array",
|
||||
data: hexDecode("a2016b736f6d65206d65737367650283010203"), // {1: "some messge", 2: []int{1,2,3}}
|
||||
v: &TestExample{
|
||||
Message: "some messge",
|
||||
Foo: &arrayFoo123,
|
||||
},
|
||||
unmarshalToObj: &TestExample{Foo: &arrayFoo},
|
||||
},
|
||||
{
|
||||
name: "map",
|
||||
data: hexDecode("a2016b736f6d65206d657373676502a3010102020303"), // {1: "some messge", 2: map[int]int{1:1,2:2,3:3}}
|
||||
v: &TestExample{
|
||||
Message: "some messge",
|
||||
Foo: &mapFoo123,
|
||||
},
|
||||
unmarshalToObj: &TestExample{Foo: &mapFoo},
|
||||
},
|
||||
{
|
||||
name: "struct",
|
||||
data: hexDecode("a2016b736f6d65206d657373676502a1011901c8"), // {1: "some messge", 2: {1: 456}}
|
||||
v: &TestExample{
|
||||
Message: "some messge",
|
||||
Foo: &StructFoo{Value: 456},
|
||||
},
|
||||
unmarshalToObj: &TestExample{Foo: &StructFoo{}},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
|
||||
data, err := em.Marshal(tc.v)
|
||||
if err != nil {
|
||||
t.Errorf("Marshal(%+v) returned error %v", tc.v, err)
|
||||
} else if !bytes.Equal(data, tc.data) {
|
||||
t.Errorf("Marshal(%+v) = 0x%x, want 0x%x", tc.v, data, tc.v)
|
||||
}
|
||||
|
||||
// Unmarshal to empty interface
|
||||
var einterface TestExample
|
||||
if err = Unmarshal(data, &einterface); err == nil {
|
||||
t.Errorf("Unmarshal(0x%x) didn't return an error, want error (*UnmarshalTypeError)", data)
|
||||
} else if _, ok := err.(*UnmarshalTypeError); !ok {
|
||||
t.Errorf("Unmarshal(0x%x) returned wrong type of error %T, want (*UnmarshalTypeError)", data, err)
|
||||
}
|
||||
|
||||
// Unmarshal to interface value
|
||||
err = Unmarshal(data, tc.unmarshalToObj)
|
||||
if err != nil {
|
||||
t.Errorf("Unmarshal(0x%x) returned error %v", data, err)
|
||||
} else if !reflect.DeepEqual(tc.unmarshalToObj, tc.v) {
|
||||
t.Errorf("Unmarshal(0x%x) = %v, want %v", data, tc.unmarshalToObj, tc.v)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
type Bar struct {
|
||||
I int
|
||||
}
|
||||
|
||||
func (b *Bar) Foo() string {
|
||||
return fmt.Sprint(*b)
|
||||
}
|
||||
|
||||
type FooStruct struct {
|
||||
Foos []Foo
|
||||
}
|
||||
|
||||
func TestUnmarshalTaggedDataToInterface(t *testing.T) {
|
||||
|
||||
var tags = NewTagSet()
|
||||
err := tags.Add(
|
||||
TagOptions{EncTag: EncTagRequired, DecTag: DecTagRequired},
|
||||
reflect.TypeOf(&Bar{}),
|
||||
4,
|
||||
)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
v := &FooStruct{
|
||||
Foos: []Foo{&Bar{1}},
|
||||
}
|
||||
|
||||
want := hexDecode("a164466f6f7381c4a1614901") // {"Foos": [4({"I": 1})]}
|
||||
|
||||
em, _ := EncOptions{}.EncModeWithTags(tags)
|
||||
data, err := em.Marshal(v)
|
||||
if err != nil {
|
||||
t.Errorf("Marshal(%+v) returned error %v", v, err)
|
||||
} else if !bytes.Equal(data, want) {
|
||||
t.Errorf("Marshal(%+v) = 0x%x, want 0x%x", v, data, want)
|
||||
}
|
||||
|
||||
dm, _ := DecOptions{}.DecModeWithTags(tags)
|
||||
|
||||
// Unmarshal to empty interface
|
||||
var v1 Bar
|
||||
if err = dm.Unmarshal(data, &v1); err == nil {
|
||||
t.Errorf("Unmarshal(0x%x) didn't return an error, want error (*UnmarshalTypeError)", data)
|
||||
} else if _, ok := err.(*UnmarshalTypeError); !ok {
|
||||
t.Errorf("Unmarshal(0x%x) returned wrong type of error %T, want (*UnmarshalTypeError)", data, err)
|
||||
}
|
||||
|
||||
// Unmarshal to interface value
|
||||
v2 := &FooStruct{
|
||||
Foos: []Foo{&Bar{}},
|
||||
}
|
||||
err = dm.Unmarshal(data, v2)
|
||||
if err != nil {
|
||||
t.Errorf("Unmarshal(0x%x) returned error %v", data, err)
|
||||
} else if !reflect.DeepEqual(v2, v) {
|
||||
t.Errorf("Unmarshal(0x%x) = %v, want %v", data, v2, v)
|
||||
}
|
||||
}
|
||||
|
||||
Loading…
Reference in New Issue
Block a user