Currently, unreleased changes in PR #636 and #645 cause the input data to be checked twice when Unmarshal() internally calls UnmarshalCBOR() for ByteString, RawTag, and SimpleValue. UnmarshalCBOR() checks its input because user apps can call it directly with bad data. However, the codec already checks the input data before it internally calls UnmarshalCBOR(), so the second check is redundant. This commit avoids the redundant check by having Unmarshal() call the private unmarshalCBOR() when it is implemented by ByteString, RawTag, SimpleValue, etc.: - Internally, the codec calls the private unmarshalCBOR() to avoid the redundant check on input data. - Externally, UnmarshalCBOR() remains available as a wrapper that checks input data before calling the private unmarshalCBOR(). The UnmarshalCBOR() methods of ByteString, RawTag, and SimpleValue are marked as deprecated; Unmarshal() should be used instead.
368 lines
9.1 KiB
Go
368 lines
9.1 KiB
Go
// Copyright (c) Faye Amacker. All rights reserved.
|
|
// Licensed under the MIT License. See LICENSE in the project root for license information.
|
|
|
|
package cbor
|
|
|
|
import (
|
|
"bytes"
|
|
"errors"
|
|
"fmt"
|
|
"reflect"
|
|
"sort"
|
|
"strconv"
|
|
"strings"
|
|
"sync"
|
|
)
|
|
|
|
// encodeFuncs bundles the three functions returned by getEncodeFunc for a
// type so they can be stored together as one value in encodeFuncCache.
//
// NOTE: getEncodeFunc builds this with a positional literal
// (encodeFuncs{ef, ief, izf}), so the field order must not change.
type encodeFuncs struct {
	ef  encodeFunc  // encode function for the type; nil means the type is unsupported
	ief isEmptyFunc // "is empty" predicate for the type (presumably used for omitempty — confirm)
	izf isZeroFunc  // "is zero" predicate for the type
}
|
|
|
|
// Lazily populated, concurrency-safe caches keyed by reflect.Type. Each get*
// helper in this file loads from its cache first and stores the computed value
// on a miss, so per-type reflection work is normally done once.
var (
	// getTypeInfo: map[reflect.Type]*typeInfo
	typeInfoCache sync.Map
	// getEncodeFunc: map[reflect.Type]encodeFuncs
	encodeFuncCache sync.Map
	// getEncodingStructType / getEncodingStructToArrayType:
	// map[reflect.Type]*encodingStructType
	encodingStructTypeCache sync.Map
	// getDecodingStructType: map[reflect.Type]*decodingStructType
	decodingStructTypeCache sync.Map
)
|
|
|
|
// specialType classifies a (pointer-stripped) Go type that needs special
// handling. newTypeInfo computes it and stores it in typeInfo.spclType.
type specialType int

// The values come from iota, so the declaration order below defines them.
const (
	// specialTypeNone: no special handling applies.
	specialTypeNone specialType = iota

	// specialTypeUnmarshalerIface: a pointer to the type implements typeUnmarshaler.
	specialTypeUnmarshalerIface

	// specialTypeUnexportedUnmarshalerIface: a pointer to the type implements
	// typeUnexportedUnmarshaler. newTypeInfo tests this before
	// typeUnmarshaler, so it wins when both are implemented.
	specialTypeUnexportedUnmarshalerIface

	// specialTypeEmptyIface: an interface type with no methods.
	specialTypeEmptyIface

	// specialTypeIface: an interface type with at least one method.
	specialTypeIface

	// specialTypeTag: the type is typeTag.
	specialTypeTag

	// specialTypeTime: the type is typeTime.
	specialTypeTime
)
|
|
|
|
// typeInfo holds reflection-derived facts about a Go type. newTypeInfo builds
// it and getTypeInfo caches it. getDecodingStructType attaches one to each
// struct field.
type typeInfo struct {
	// elemTypeInfo describes the element type when nonPtrKind is Array,
	// Slice or Map; otherwise nil.
	elemTypeInfo *typeInfo

	// keyTypeInfo describes the key type when nonPtrKind is Map; otherwise nil.
	keyTypeInfo *typeInfo

	// typ and kind describe the type exactly as requested (possibly a pointer).
	typ  reflect.Type
	kind reflect.Kind

	// nonPtrType and nonPtrKind describe typ after removing every level of
	// pointer indirection.
	nonPtrType reflect.Type
	nonPtrKind reflect.Kind

	// spclType classifies nonPtrType for special handling (see specialType).
	spclType specialType
}
|
|
|
|
func newTypeInfo(t reflect.Type) *typeInfo {
|
|
tInfo := typeInfo{typ: t, kind: t.Kind()}
|
|
|
|
for t.Kind() == reflect.Pointer {
|
|
t = t.Elem()
|
|
}
|
|
|
|
k := t.Kind()
|
|
|
|
tInfo.nonPtrType = t
|
|
tInfo.nonPtrKind = k
|
|
|
|
if k == reflect.Interface {
|
|
if t.NumMethod() == 0 {
|
|
tInfo.spclType = specialTypeEmptyIface
|
|
} else {
|
|
tInfo.spclType = specialTypeIface
|
|
}
|
|
} else if t == typeTag {
|
|
tInfo.spclType = specialTypeTag
|
|
} else if t == typeTime {
|
|
tInfo.spclType = specialTypeTime
|
|
} else if reflect.PointerTo(t).Implements(typeUnexportedUnmarshaler) {
|
|
tInfo.spclType = specialTypeUnexportedUnmarshalerIface
|
|
} else if reflect.PointerTo(t).Implements(typeUnmarshaler) {
|
|
tInfo.spclType = specialTypeUnmarshalerIface
|
|
}
|
|
|
|
switch k {
|
|
case reflect.Array, reflect.Slice:
|
|
tInfo.elemTypeInfo = getTypeInfo(t.Elem())
|
|
case reflect.Map:
|
|
tInfo.keyTypeInfo = getTypeInfo(t.Key())
|
|
tInfo.elemTypeInfo = getTypeInfo(t.Elem())
|
|
}
|
|
|
|
return &tInfo
|
|
}
|
|
|
|
// decodingStructType is the decoding metadata for a struct type. It is built
// and cached by getDecodingStructType.
type decodingStructType struct {
	// fields are the struct's fields as returned by getFields, each with
	// typInfo set (and nameAsInt for keyasint fields).
	fields fields

	// fieldIndicesByName maps a field name to its index in fields. When two
	// fields share a name, the first one is kept and an error is recorded.
	fieldIndicesByName map[string]int

	// err holds any problems found in the struct definition, or nil.
	err error

	// toArray is true when the struct carries the ",toarray" option.
	toArray bool
}
|
|
|
|
// multierror aggregates several errors into one. It is a minimal stand-in for
// errors.Join, which needs Go 1.20 while this package still supports Go 1.17.
type multierror []error

// Error joins the messages of all contained errors with ", ". An empty
// multierror yields "".
func (m multierror) Error() string {
	var b strings.Builder
	for idx, e := range m {
		if idx > 0 {
			b.WriteString(", ")
		}
		b.WriteString(e.Error())
	}
	return b.String()
}
|
|
|
|
func getDecodingStructType(t reflect.Type) *decodingStructType {
|
|
if v, _ := decodingStructTypeCache.Load(t); v != nil {
|
|
return v.(*decodingStructType)
|
|
}
|
|
|
|
flds, structOptions := getFields(t)
|
|
|
|
toArray := hasToArrayOption(structOptions)
|
|
|
|
var errs []error
|
|
for i := 0; i < len(flds); i++ {
|
|
if flds[i].keyAsInt {
|
|
nameAsInt, numErr := strconv.Atoi(flds[i].name)
|
|
if numErr != nil {
|
|
errs = append(errs, errors.New("cbor: failed to parse field name \""+flds[i].name+"\" to int ("+numErr.Error()+")"))
|
|
break
|
|
}
|
|
flds[i].nameAsInt = int64(nameAsInt)
|
|
}
|
|
|
|
flds[i].typInfo = getTypeInfo(flds[i].typ)
|
|
}
|
|
|
|
fieldIndicesByName := make(map[string]int, len(flds))
|
|
for i, fld := range flds {
|
|
if _, ok := fieldIndicesByName[fld.name]; ok {
|
|
errs = append(errs, fmt.Errorf("cbor: two or more fields of %v have the same name %q", t, fld.name))
|
|
continue
|
|
}
|
|
fieldIndicesByName[fld.name] = i
|
|
}
|
|
|
|
var err error
|
|
{
|
|
var multi multierror
|
|
for _, each := range errs {
|
|
if each != nil {
|
|
multi = append(multi, each)
|
|
}
|
|
}
|
|
if len(multi) == 1 {
|
|
err = multi[0]
|
|
} else if len(multi) > 1 {
|
|
err = multi
|
|
}
|
|
}
|
|
|
|
structType := &decodingStructType{
|
|
fields: flds,
|
|
fieldIndicesByName: fieldIndicesByName,
|
|
err: err,
|
|
toArray: toArray,
|
|
}
|
|
decodingStructTypeCache.Store(t, structType)
|
|
return structType
|
|
}
|
|
|
|
// encodingStructType is the encoding metadata for a struct type. It is built
// and cached by getEncodingStructType (or by getEncodingStructToArrayType for
// ",toarray" structs).
type encodingStructType struct {
	// fields are the struct's fields in getFields order, each with its encode
	// functions resolved.
	fields fields

	// bytewiseFields are the fields sorted by bytewise order of their encoded
	// keys (cborName). Not set for ",toarray" structs.
	bytewiseFields fields

	// lengthFirstFields are the fields sorted length-first by encoded key. This
	// may share its backing array with bytewiseFields when both orders
	// coincide. Not set for ",toarray" structs.
	lengthFirstFields fields

	// omitEmptyFieldsIdx lists the indices (into fields) of fields marked
	// omitEmpty.
	omitEmptyFieldsIdx []int

	// err is the error encountered while building this value, or nil.
	err error

	// toArray is true when the struct carries the ",toarray" option.
	toArray bool
}
|
|
|
|
func (st *encodingStructType) getFields(em *encMode) fields {
|
|
switch em.sort {
|
|
case SortNone, SortFastShuffle:
|
|
return st.fields
|
|
case SortLengthFirst:
|
|
return st.lengthFirstFields
|
|
default:
|
|
return st.bytewiseFields
|
|
}
|
|
}
|
|
|
|
type bytewiseFieldSorter struct {
|
|
fields fields
|
|
}
|
|
|
|
func (x *bytewiseFieldSorter) Len() int {
|
|
return len(x.fields)
|
|
}
|
|
|
|
func (x *bytewiseFieldSorter) Swap(i, j int) {
|
|
x.fields[i], x.fields[j] = x.fields[j], x.fields[i]
|
|
}
|
|
|
|
func (x *bytewiseFieldSorter) Less(i, j int) bool {
|
|
return bytes.Compare(x.fields[i].cborName, x.fields[j].cborName) <= 0
|
|
}
|
|
|
|
type lengthFirstFieldSorter struct {
|
|
fields fields
|
|
}
|
|
|
|
func (x *lengthFirstFieldSorter) Len() int {
|
|
return len(x.fields)
|
|
}
|
|
|
|
func (x *lengthFirstFieldSorter) Swap(i, j int) {
|
|
x.fields[i], x.fields[j] = x.fields[j], x.fields[i]
|
|
}
|
|
|
|
func (x *lengthFirstFieldSorter) Less(i, j int) bool {
|
|
if len(x.fields[i].cborName) != len(x.fields[j].cborName) {
|
|
return len(x.fields[i].cborName) < len(x.fields[j].cborName)
|
|
}
|
|
return bytes.Compare(x.fields[i].cborName, x.fields[j].cborName) <= 0
|
|
}
|
|
|
|
func getEncodingStructType(t reflect.Type) (*encodingStructType, error) {
|
|
if v, _ := encodingStructTypeCache.Load(t); v != nil {
|
|
structType := v.(*encodingStructType)
|
|
return structType, structType.err
|
|
}
|
|
|
|
flds, structOptions := getFields(t)
|
|
|
|
if hasToArrayOption(structOptions) {
|
|
return getEncodingStructToArrayType(t, flds)
|
|
}
|
|
|
|
var err error
|
|
var hasKeyAsInt bool
|
|
var hasKeyAsStr bool
|
|
var omitEmptyIdx []int
|
|
e := getEncodeBuffer()
|
|
for i := 0; i < len(flds); i++ {
|
|
// Get field's encodeFunc
|
|
flds[i].ef, flds[i].ief, flds[i].izf = getEncodeFunc(flds[i].typ)
|
|
if flds[i].ef == nil {
|
|
err = &UnsupportedTypeError{t}
|
|
break
|
|
}
|
|
|
|
// Encode field name
|
|
if flds[i].keyAsInt {
|
|
nameAsInt, numErr := strconv.Atoi(flds[i].name)
|
|
if numErr != nil {
|
|
err = errors.New("cbor: failed to parse field name \"" + flds[i].name + "\" to int (" + numErr.Error() + ")")
|
|
break
|
|
}
|
|
flds[i].nameAsInt = int64(nameAsInt)
|
|
if nameAsInt >= 0 {
|
|
encodeHead(e, byte(cborTypePositiveInt), uint64(nameAsInt))
|
|
} else {
|
|
n := nameAsInt*(-1) - 1
|
|
encodeHead(e, byte(cborTypeNegativeInt), uint64(n))
|
|
}
|
|
flds[i].cborName = make([]byte, e.Len())
|
|
copy(flds[i].cborName, e.Bytes())
|
|
e.Reset()
|
|
|
|
hasKeyAsInt = true
|
|
} else {
|
|
encodeHead(e, byte(cborTypeTextString), uint64(len(flds[i].name)))
|
|
flds[i].cborName = make([]byte, e.Len()+len(flds[i].name))
|
|
n := copy(flds[i].cborName, e.Bytes())
|
|
copy(flds[i].cborName[n:], flds[i].name)
|
|
e.Reset()
|
|
|
|
// If cborName contains a text string, then cborNameByteString contains a
|
|
// string that has the byte string major type but is otherwise identical to
|
|
// cborName.
|
|
flds[i].cborNameByteString = make([]byte, len(flds[i].cborName))
|
|
copy(flds[i].cborNameByteString, flds[i].cborName)
|
|
// Reset encoded CBOR type to byte string, preserving the "additional
|
|
// information" bits:
|
|
flds[i].cborNameByteString[0] = byte(cborTypeByteString) |
|
|
getAdditionalInformation(flds[i].cborNameByteString[0])
|
|
|
|
hasKeyAsStr = true
|
|
}
|
|
|
|
// Check if field can be omitted when empty
|
|
if flds[i].omitEmpty {
|
|
omitEmptyIdx = append(omitEmptyIdx, i)
|
|
}
|
|
}
|
|
putEncodeBuffer(e)
|
|
|
|
if err != nil {
|
|
structType := &encodingStructType{err: err}
|
|
encodingStructTypeCache.Store(t, structType)
|
|
return structType, structType.err
|
|
}
|
|
|
|
// Sort fields by canonical order
|
|
bytewiseFields := make(fields, len(flds))
|
|
copy(bytewiseFields, flds)
|
|
sort.Sort(&bytewiseFieldSorter{bytewiseFields})
|
|
|
|
lengthFirstFields := bytewiseFields
|
|
if hasKeyAsInt && hasKeyAsStr {
|
|
lengthFirstFields = make(fields, len(flds))
|
|
copy(lengthFirstFields, flds)
|
|
sort.Sort(&lengthFirstFieldSorter{lengthFirstFields})
|
|
}
|
|
|
|
structType := &encodingStructType{
|
|
fields: flds,
|
|
bytewiseFields: bytewiseFields,
|
|
lengthFirstFields: lengthFirstFields,
|
|
omitEmptyFieldsIdx: omitEmptyIdx,
|
|
}
|
|
|
|
encodingStructTypeCache.Store(t, structType)
|
|
return structType, structType.err
|
|
}
|
|
|
|
func getEncodingStructToArrayType(t reflect.Type, flds fields) (*encodingStructType, error) {
|
|
for i := 0; i < len(flds); i++ {
|
|
// Get field's encodeFunc
|
|
flds[i].ef, flds[i].ief, flds[i].izf = getEncodeFunc(flds[i].typ)
|
|
if flds[i].ef == nil {
|
|
structType := &encodingStructType{err: &UnsupportedTypeError{t}}
|
|
encodingStructTypeCache.Store(t, structType)
|
|
return structType, structType.err
|
|
}
|
|
}
|
|
|
|
structType := &encodingStructType{
|
|
fields: flds,
|
|
toArray: true,
|
|
}
|
|
encodingStructTypeCache.Store(t, structType)
|
|
return structType, structType.err
|
|
}
|
|
|
|
func getEncodeFunc(t reflect.Type) (encodeFunc, isEmptyFunc, isZeroFunc) {
|
|
if v, _ := encodeFuncCache.Load(t); v != nil {
|
|
fs := v.(encodeFuncs)
|
|
return fs.ef, fs.ief, fs.izf
|
|
}
|
|
ef, ief, izf := getEncodeFuncInternal(t)
|
|
encodeFuncCache.Store(t, encodeFuncs{ef, ief, izf})
|
|
return ef, ief, izf
|
|
}
|
|
|
|
func getTypeInfo(t reflect.Type) *typeInfo {
|
|
if v, _ := typeInfoCache.Load(t); v != nil {
|
|
return v.(*typeInfo)
|
|
}
|
|
tInfo := newTypeInfo(t)
|
|
typeInfoCache.Store(t, tInfo)
|
|
return tInfo
|
|
}
|
|
|
|
// hasToArrayOption reports whether the struct tag value contains the "toarray"
// option. The first comma-separated component is the name, not an option, so
// only components after the first comma are examined. Every option is
// checked, not just the first substring match, so ",toarrayx,toarray" is
// recognized.
func hasToArrayOption(tag string) bool {
	const opt = "toarray"

	comma := strings.IndexByte(tag, ',')
	if comma < 0 {
		return false
	}

	opts := tag[comma+1:]
	for {
		next := strings.IndexByte(opts, ',')
		if next < 0 {
			return opts == opt
		}
		if opts[:next] == opt {
			return true
		}
		opts = opts[next+1:]
	}
}
|