Refactor to remove more magic numbers

This commit is contained in:
Faye Amacker 2024-05-27 21:14:04 -05:00
parent c5c66ba45d
commit fdf5bd8378
7 changed files with 134 additions and 44 deletions

View File

@ -53,7 +53,7 @@ const (
additionalInformationWith4ByteArgument = 26
additionalInformationWith8ByteArgument = 27
// additional information with major type 7
// For major type 7.
additionalInformationAsFalse = 20
additionalInformationAsTrue = 21
additionalInformationAsNull = 22
@ -62,9 +62,15 @@ const (
additionalInformationAsFloat32 = 26
additionalInformationAsFloat64 = 27
// For major type 2, 3, 4, 5.
additionalInformationAsIndefiniteLengthFlag = 31
)
const (
maxSimpleValueInAdditionalInformation = 23
minSimpleValueIn1ByteArgument = 32
)
func (ai additionalInformation) isIndefiniteLength() bool {
return ai == additionalInformationAsIndefiniteLengthFlag
}
@ -110,7 +116,11 @@ const (
)
const (
cborBreakFlag = byte(0xff)
cborBreakFlag = byte(0xff)
cborByteStringWithIndefiniteLengthHead = byte(0x5f)
cborTextStringWithIndefiniteLengthHead = byte(0x7f)
cborArrayWithIndefiniteLengthHead = byte(0x9f)
cborMapWithIndefiniteLengthHead = byte(0xbf)
)
var (

View File

@ -1358,8 +1358,10 @@ func (d *decoder) parseToValue(v reflect.Value, tInfo *typeInfo) error { //nolin
v.Set(reflect.ValueOf(iv))
}
return err
case specialTypeTag:
return d.parseToTag(v)
case specialTypeTime:
if d.nextCBORNil() {
// Decoding CBOR null and undefined to time.Time is no-op.
@ -1374,6 +1376,7 @@ func (d *decoder) parseToValue(v reflect.Value, tInfo *typeInfo) error { //nolin
v.Set(reflect.ValueOf(tm))
}
return nil
case specialTypeUnmarshalerIface:
return d.parseToUnmarshaler(v)
}
@ -1535,6 +1538,7 @@ func (d *decoder) parseToValue(v reflect.Value, tInfo *typeInfo) error { //nolin
}()
}
}
return d.parseToValue(v, tInfo)
case cborTypeArray:
@ -1628,6 +1632,7 @@ func (d *decoder) parseToTime() (time.Time, bool, error) {
return t, true, nil
}
return time.Time{}, false, &UnmarshalTypeError{CBORType: t.String(), GoType: typeTime.String()}
case cborTypeTextString:
s, err := d.parseTextString()
if err != nil {
@ -1638,6 +1643,7 @@ func (d *decoder) parseToTime() (time.Time, bool, error) {
return time.Time{}, false, errors.New("cbor: cannot set " + string(s) + " for time.Time: " + err.Error())
}
return t, true, nil
case cborTypePositiveInt:
_, _, val := d.getHead()
if val > math.MaxInt64 {
@ -1648,6 +1654,7 @@ func (d *decoder) parseToTime() (time.Time, bool, error) {
}
}
return time.Unix(int64(val), 0), true, nil
case cborTypeNegativeInt:
_, _, val := d.getHead()
if val > math.MaxInt64 {
@ -1667,6 +1674,7 @@ func (d *decoder) parseToTime() (time.Time, bool, error) {
}
}
return time.Unix(int64(-1)^int64(val), 0), true, nil
case cborTypePrimitives:
_, ai, val := d.getHead()
var f float64
@ -1690,6 +1698,7 @@ func (d *decoder) parseToTime() (time.Time, bool, error) {
}
seconds, fractional := math.Modf(f)
return time.Unix(int64(seconds), int64(fractional*1e9)), true, nil
default:
return time.Time{}, false, &UnmarshalTypeError{CBORType: t.String(), GoType: typeTime.String()}
}
@ -1822,8 +1831,10 @@ func (d *decoder) parse(skipSelfDescribedTag bool) (interface{}, error) { //noli
clone := make([]byte, len(b))
copy(clone, b)
return clone, nil
case typeString:
return string(b), nil
default:
if copied || d.dm.defaultByteStringType.Kind() == reflect.String {
// Avoid an unnecessary copy since the conversion to string must
@ -1834,12 +1845,14 @@ func (d *decoder) parse(skipSelfDescribedTag bool) (interface{}, error) { //noli
copy(clone, b)
return reflect.ValueOf(clone).Convert(d.dm.defaultByteStringType).Interface(), nil
}
case cborTypeTextString:
b, err := d.parseTextString()
if err != nil {
return nil, err
}
return string(b), nil
case cborTypeTag:
tagOff := d.off
_, _, tagNum := d.getHead()
@ -1852,9 +1865,11 @@ func (d *decoder) parse(skipSelfDescribedTag bool) (interface{}, error) { //noli
if err != nil {
return nil, err
}
switch d.dm.timeTagToAny {
case TimeTagToTime:
return tm, nil
case TimeTagToRFC3339:
if tagNum == 1 {
tm = tm.UTC()
@ -1866,6 +1881,7 @@ func (d *decoder) parse(skipSelfDescribedTag bool) (interface{}, error) { //noli
return nil, err
}
return string(text), nil
case TimeTagToRFC3339Nano:
if tagNum == 1 {
tm = tm.UTC()
@ -1877,6 +1893,7 @@ func (d *decoder) parse(skipSelfDescribedTag bool) (interface{}, error) { //noli
return nil, err
}
return string(text), nil
default:
// not reachable
}
@ -1953,6 +1970,7 @@ func (d *decoder) parse(skipSelfDescribedTag bool) (interface{}, error) { //noli
if ai < 20 || ai == 24 {
return SimpleValue(val), nil
}
switch ai {
case additionalInformationAsFalse,
additionalInformationAsTrue:
@ -1977,6 +1995,7 @@ func (d *decoder) parse(skipSelfDescribedTag bool) (interface{}, error) { //noli
case cborTypeArray:
return d.parseArray()
case cborTypeMap:
if d.dm.defaultMapType != nil {
m := reflect.New(d.dm.defaultMapType)
@ -1988,6 +2007,7 @@ func (d *decoder) parse(skipSelfDescribedTag bool) (interface{}, error) { //noli
}
return d.parseMap()
}
return nil, nil
}
@ -2035,19 +2055,23 @@ func (d *decoder) applyByteStringTextConversion(
encoded := make([]byte, base64.RawURLEncoding.EncodedLen(len(src)))
base64.RawURLEncoding.Encode(encoded, src)
return encoded, true, nil
case tagNumExpectedLaterEncodingBase64:
encoded := make([]byte, base64.StdEncoding.EncodedLen(len(src)))
base64.StdEncoding.Encode(encoded, src)
return encoded, true, nil
case tagNumExpectedLaterEncodingBase16:
encoded := make([]byte, hex.EncodedLen(len(src)))
hex.Encode(encoded, src)
return encoded, true, nil
default:
// If this happens, there is a bug: the decoder has pushed an invalid
// "expected later encoding" tag to the stack.
panic(fmt.Sprintf("unrecognized expected later encoding tag: %d", d.expectedLaterEncodingTags))
}
case reflect.Slice:
if dstType.Elem().Kind() != reflect.Uint8 || len(d.expectedLaterEncodingTags) > 0 {
// Either the destination is not a slice of bytes, or the encoder that
@ -2064,6 +2088,7 @@ func (d *decoder) applyByteStringTextConversion(
return nil, false, fmt.Errorf("cbor: failed to decode base64url string: %v", err)
}
return decoded[:n], true, nil
case ByteSliceExpectedEncodingBase64:
decoded := make([]byte, base64.StdEncoding.DecodedLen(len(src)))
n, err := base64.StdEncoding.Decode(decoded, src)
@ -2071,6 +2096,7 @@ func (d *decoder) applyByteStringTextConversion(
return nil, false, fmt.Errorf("cbor: failed to decode base64 string: %v", err)
}
return decoded[:n], true, nil
case ByteSliceExpectedEncodingBase16:
decoded := make([]byte, hex.DecodedLen(len(src)))
n, err := hex.Decode(decoded, src)
@ -2756,14 +2782,17 @@ func (d *decoder) skip() {
switch t {
case cborTypeByteString, cborTypeTextString:
d.off += int(val)
case cborTypeArray:
for i := 0; i < int(val); i++ {
d.skip()
}
case cborTypeMap:
for i := 0; i < int(val)*2; i++ {
d.skip()
}
case cborTypeTag:
d.skip()
}
@ -2893,6 +2922,7 @@ func fillPositiveInt(t cborType, val uint64, v reflect.Value) error {
}
v.SetInt(int64(val))
return nil
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
if v.OverflowUint(val) {
return &UnmarshalTypeError{
@ -2903,11 +2933,13 @@ func fillPositiveInt(t cborType, val uint64, v reflect.Value) error {
}
v.SetUint(val)
return nil
case reflect.Float32, reflect.Float64:
f := float64(val)
v.SetFloat(f)
return nil
}
if v.Type() == typeBigInt {
i := new(big.Int).SetUint64(val)
v.Set(reflect.ValueOf(*i))
@ -2928,6 +2960,7 @@ func fillNegativeInt(t cborType, val int64, v reflect.Value) error {
}
v.SetInt(val)
return nil
case reflect.Float32, reflect.Float64:
f := float64(val)
v.SetFloat(f)
@ -3026,6 +3059,7 @@ func isImmutableKind(k reflect.Kind) bool {
reflect.Float32, reflect.Float64,
reflect.String:
return true
default:
return false
}
@ -3035,6 +3069,7 @@ func isHashableValue(rv reflect.Value) bool {
switch rv.Kind() {
case reflect.Slice, reflect.Map, reflect.Func:
return false
case reflect.Struct:
switch rv.Type() {
case typeTag:
@ -3057,6 +3092,7 @@ func convertByteSliceToByteString(v interface{}) (interface{}, bool) {
switch v := v.(type) {
case []byte:
return ByteString(v), true
case Tag:
content, converted := convertByteSliceToByteString(v.Content)
if converted {

View File

@ -242,16 +242,17 @@ func (di *diagnose) wellformed(allowExtraData bool) error {
func (di *diagnose) item() error { //nolint:gocyclo
initialByte := di.d.data[di.d.off]
switch initialByte {
case 0x5f, 0x7f: // indefinite-length byte/text string
case cborByteStringWithIndefiniteLengthHead,
cborTextStringWithIndefiniteLengthHead: // indefinite-length byte/text string
di.d.off++
if isBreakFlag(di.d.data[di.d.off]) {
di.d.off++
switch initialByte {
case 0x5f:
case cborByteStringWithIndefiniteLengthHead:
// indefinite-length bytes with no chunks.
di.w.WriteString(`''_`)
return nil
case 0x7f:
case cborTextStringWithIndefiniteLengthHead:
// indefinite-length text with no chunks.
di.w.WriteString(`""_`)
return nil
@ -276,7 +277,7 @@ func (di *diagnose) item() error { //nolint:gocyclo
di.w.WriteByte(')')
return nil
case 0x9f: // indefinite-length array
case cborArrayWithIndefiniteLengthHead: // indefinite-length array
di.d.off++
di.w.WriteString("[_ ")
@ -295,7 +296,7 @@ func (di *diagnose) item() error { //nolint:gocyclo
di.w.WriteByte(']')
return nil
case 0xbf: // indefinite-length map
case cborMapWithIndefiniteLengthHead: // indefinite-length map
di.d.off++
di.w.WriteString("{_ ")
@ -573,7 +574,7 @@ func (di *diagnose) encodeByteString(val []byte) error {
}
}
var utf16SurrSelf = rune(0x10000)
const utf16SurrSelf = rune(0x10000)
// quote should be either `'` or `"`
func (di *diagnose) encodeTextString(val string, quote byte) error {
@ -678,16 +679,17 @@ func (di *diagnose) encodeFloat(ai byte, val uint64) error {
}
// Use ES6 number to string conversion which should match most JSON generators.
// Inspired by https://github.com/golang/go/blob/4df10fba1687a6d4f51d7238a403f8f2298f6a16/src/encoding/json/encode.go#L585
const bitSize = 64
b := make([]byte, 0, 32)
if abs := math.Abs(f64); abs != 0 && (abs < 1e-6 || abs >= 1e21) {
b = strconv.AppendFloat(b, f64, 'e', -1, 64)
b = strconv.AppendFloat(b, f64, 'e', -1, bitSize)
// clean up e-09 to e-9
n := len(b)
if n >= 4 && string(b[n-4:n-1]) == "e-0" {
b = append(b[:n-2], b[n-1])
}
} else {
b = strconv.AppendFloat(b, f64, 'f', -1, 64)
b = strconv.AppendFloat(b, f64, 'f', -1, bitSize)
}
// add decimal point and trailing zero if needed

View File

@ -195,6 +195,7 @@ func (st StringMode) cborType() (cborType, error) {
switch st {
case StringToTextString:
return cborTypeTextString, nil
case StringToByteString:
return cborTypeByteString, nil
}
@ -417,10 +418,13 @@ func (bsm ByteSliceMode) encodingTag() (uint64, error) {
switch bsm {
case ByteSliceToByteString:
return 0, nil
case ByteSliceToByteStringWithExpectedConversionToBase64URL:
return tagNumExpectedLaterEncodingBase64URL, nil
case ByteSliceToByteStringWithExpectedConversionToBase64:
return tagNumExpectedLaterEncodingBase64, nil
case ByteSliceToByteStringWithExpectedConversionToBase16:
return tagNumExpectedLaterEncodingBase16, nil
}
@ -978,9 +982,9 @@ func encodeFloat(e *bytes.Buffer, em *encMode, v reflect.Value) error {
// Encode float64
// Don't use encodeFloat64() because it cannot be inlined.
var scratch [9]byte
scratch[0] = byte(cborTypePrimitives) | byte(27)
scratch[0] = byte(cborTypePrimitives) | byte(additionalInformationAsFloat64)
binary.BigEndian.PutUint64(scratch[1:], math.Float64bits(f64))
e.Write(scratch[:9])
e.Write(scratch[:])
return nil
}
@ -1002,7 +1006,7 @@ func encodeFloat(e *bytes.Buffer, em *encMode, v reflect.Value) error {
// Encode float16
// Don't use encodeFloat16() because it cannot be inlined.
var scratch [3]byte
scratch[0] = byte(cborTypePrimitives) | byte(25)
scratch[0] = byte(cborTypePrimitives) | additionalInformationAsFloat16
binary.BigEndian.PutUint16(scratch[1:], uint16(f16))
e.Write(scratch[:3])
return nil
@ -1012,7 +1016,7 @@ func encodeFloat(e *bytes.Buffer, em *encMode, v reflect.Value) error {
// Encode float32
// Don't use encodeFloat32() because it cannot be inlined.
var scratch [5]byte
scratch[0] = byte(cborTypePrimitives) | byte(26)
scratch[0] = byte(cborTypePrimitives) | additionalInformationAsFloat32
binary.BigEndian.PutUint32(scratch[1:], math.Float32bits(f32))
e.Write(scratch[:5])
return nil
@ -1023,6 +1027,7 @@ func encodeInf(e *bytes.Buffer, em *encMode, v reflect.Value) error {
switch em.infConvert {
case InfConvertReject:
return &UnsupportedValueError{msg: "floating-point infinity"}
case InfConvertFloat16:
if f64 > 0 {
e.Write(cborPositiveInfinity)
@ -1100,7 +1105,7 @@ func encodeNaN(e *bytes.Buffer, em *encMode, v reflect.Value) error {
func encodeFloat16(e *bytes.Buffer, f16 float16.Float16) error {
var scratch [3]byte
scratch[0] = byte(cborTypePrimitives) | byte(25)
scratch[0] = byte(cborTypePrimitives) | additionalInformationAsFloat16
binary.BigEndian.PutUint16(scratch[1:], uint16(f16))
e.Write(scratch[:3])
return nil
@ -1108,7 +1113,7 @@ func encodeFloat16(e *bytes.Buffer, f16 float16.Float16) error {
func encodeFloat32(e *bytes.Buffer, f32 float32) error {
var scratch [5]byte
scratch[0] = byte(cborTypePrimitives) | byte(26)
scratch[0] = byte(cborTypePrimitives) | additionalInformationAsFloat32
binary.BigEndian.PutUint32(scratch[1:], math.Float32bits(f32))
e.Write(scratch[:5])
return nil
@ -1116,7 +1121,7 @@ func encodeFloat32(e *bytes.Buffer, f32 float32) error {
func encodeFloat64(e *bytes.Buffer, f64 float64) error {
var scratch [9]byte
scratch[0] = byte(cborTypePrimitives) | byte(27)
scratch[0] = byte(cborTypePrimitives) | additionalInformationAsFloat64
binary.BigEndian.PutUint64(scratch[1:], math.Float64bits(f64))
e.Write(scratch[:9])
return nil
@ -1478,10 +1483,12 @@ func encodeTime(e *bytes.Buffer, em *encMode, v reflect.Value) error {
case TimeUnix:
secs := t.Unix()
return encodeInt(e, em, reflect.ValueOf(secs))
case TimeUnixMicro:
t = t.UTC().Round(time.Microsecond)
f := float64(t.UnixNano()) / 1e9
return encodeFloat(e, em, reflect.ValueOf(f))
case TimeUnixDynamic:
t = t.UTC().Round(time.Microsecond)
secs, nsecs := t.Unix(), uint64(t.Nanosecond())
@ -1490,9 +1497,11 @@ func encodeTime(e *bytes.Buffer, em *encMode, v reflect.Value) error {
}
f := float64(secs) + float64(nsecs)/1e9
return encodeFloat(e, em, reflect.ValueOf(f))
case TimeRFC3339:
s := t.Format(time.RFC3339)
return encodeString(e, em, reflect.ValueOf(s))
default: // TimeRFC3339Nano
s := t.Format(time.RFC3339Nano)
return encodeString(e, em, reflect.ValueOf(s))
@ -1690,14 +1699,19 @@ func getEncodeFuncInternal(t reflect.Type) (ef encodeFunc, ief isEmptyFunc) {
switch t {
case typeSimpleValue:
return encodeMarshalerType, isEmptyUint
case typeTag:
return encodeTag, alwaysNotEmpty
case typeTime:
return encodeTime, alwaysNotEmpty
case typeBigInt:
return encodeBigInt, alwaysNotEmpty
case typeRawMessage:
return encodeMarshalerType, isEmptySlice
case typeByteString:
return encodeMarshalerType, isEmptyString
}
@ -1718,31 +1732,39 @@ func getEncodeFuncInternal(t reflect.Type) (ef encodeFunc, ief isEmptyFunc) {
switch k {
case reflect.Bool:
return encodeBool, isEmptyBool
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
return encodeInt, isEmptyInt
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
return encodeUint, isEmptyUint
case reflect.Float32, reflect.Float64:
return encodeFloat, isEmptyFloat
case reflect.String:
return encodeString, isEmptyString
case reflect.Slice:
if t.Elem().Kind() == reflect.Uint8 {
return encodeByteString, isEmptySlice
}
fallthrough
case reflect.Array:
f, _ := getEncodeFunc(t.Elem())
if f == nil {
return nil, nil
}
return arrayEncodeFunc{f: f}.encode, isEmptySlice
case reflect.Map:
f := getEncodeMapFunc(t)
if f == nil {
return nil, nil
}
return f, isEmptyMap
case reflect.Struct:
// Get struct's special field "_" tag options
if f, ok := t.FieldByName("_"); ok {
@ -1754,6 +1776,7 @@ func getEncodeFuncInternal(t reflect.Type) (ef encodeFunc, ief isEmptyFunc) {
}
}
return encodeStruct, isEmptyStruct
case reflect.Interface:
return encodeIntf, isEmptyIntf
}

View File

@ -449,27 +449,46 @@ func TestMarshalLargeMap(t *testing.T) {
}
func encodeCborHeader(t cborType, n uint64) []byte {
b := make([]byte, 9)
if n <= 23 {
if n <= maxAdditionalInformationWithoutArgument {
const headSize = 1
var b [headSize]byte
b[0] = byte(t) | byte(n)
return b[:1]
} else if n <= math.MaxUint8 {
b[0] = byte(t) | byte(24)
b[1] = byte(n)
return b[:2]
} else if n <= math.MaxUint16 {
b[0] = byte(t) | byte(25)
binary.BigEndian.PutUint16(b[1:], uint16(n))
return b[:3]
} else if n <= math.MaxUint32 {
b[0] = byte(t) | byte(26)
binary.BigEndian.PutUint32(b[1:], uint32(n))
return b[:5]
} else {
b[0] = byte(t) | byte(27)
binary.BigEndian.PutUint64(b[1:], n)
return b[:9]
return b[:]
}
if n <= math.MaxUint8 {
const argumentSize = 1
const headSize = 1 + argumentSize
var b [headSize]byte
b[0] = byte(t) | additionalInformationWith1ByteArgument
b[1] = byte(n)
return b[:]
}
if n <= math.MaxUint16 {
const argumentSize = 2
const headSize = 1 + argumentSize
var b [headSize]byte
b[0] = byte(t) | additionalInformationWith2ByteArgument
binary.BigEndian.PutUint16(b[1:], uint16(n))
return b[:]
}
if n <= math.MaxUint32 {
const argumentSize = 4
const headSize = 1 + argumentSize
var b [headSize]byte
b[0] = byte(t) | additionalInformationWith4ByteArgument
binary.BigEndian.PutUint32(b[1:], uint32(n))
return b[:]
}
const argumentSize = 8
const headSize = 1 + argumentSize
var b [headSize]byte
b[0] = byte(t) | additionalInformationWith8ByteArgument
binary.BigEndian.PutUint64(b[1:], n)
return b[:]
}
func testMarshal(t *testing.T, testCases []marshalTest) {

View File

@ -33,11 +33,11 @@ func (sv SimpleValue) MarshalCBOR() ([]byte, error) {
// only has a single representation variant)."
switch {
case sv <= 23:
case sv <= maxSimpleValueInAdditionalInformation:
return []byte{byte(cborTypePrimitives) | byte(sv)}, nil
case sv >= 32:
return []byte{byte(cborTypePrimitives) | byte(24), byte(sv)}, nil
case sv >= minSimpleValueIn1ByteArgument:
return []byte{byte(cborTypePrimitives) | additionalInformationWith1ByteArgument, byte(sv)}, nil
default:
return nil, &UnsupportedValueError{msg: fmt.Sprintf("SimpleValue(%d)", sv)}
@ -57,7 +57,7 @@ func (sv *SimpleValue) UnmarshalCBOR(data []byte) error {
if typ != cborTypePrimitives {
return &UnmarshalTypeError{CBORType: typ.String(), GoType: "SimpleValue"}
}
if ai > 24 {
if ai > additionalInformationWith1ByteArgument {
return &UnmarshalTypeError{CBORType: typ.String(), GoType: "SimpleValue", errorMsg: "not simple values"}
}

View File

@ -239,10 +239,10 @@ func (enc *Encoder) EndIndefinite() error {
}
var cborIndefHeader = map[cborType][]byte{
cborTypeByteString: {0x5f},
cborTypeTextString: {0x7f},
cborTypeArray: {0x9f},
cborTypeMap: {0xbf},
cborTypeByteString: {cborByteStringWithIndefiniteLengthHead},
cborTypeTextString: {cborTextStringWithIndefiniteLengthHead},
cborTypeArray: {cborArrayWithIndefiniteLengthHead},
cborTypeMap: {cborMapWithIndefiniteLengthHead},
}
func (enc *Encoder) startIndefinite(typ cborType) error {