Use TextUnmarshaler on byte strings with ByteStringToStringAllowed.

Decode modes that allow unmarshaling CBOR byte strings into Go strings and recognize
TextUnmarshalers should also allow unmarshaling CBOR byte strings into TextUnmarshalers. When
interoperating with an encode mode that marshals Go strings to CBOR byte strings, this preserves
roundtrippability of TextMarshaler -> CBOR text string -> Go string -> CBOR byte string ->
TextUnmarshaler.

Signed-off-by: Ben Luddy <bluddy@redhat.com>
This commit is contained in:
Ben Luddy 2025-07-11 11:47:08 -04:00
parent a89c3ce6ea
commit 9bdebd2c0b
No known key found for this signature in database
GPG Key ID: A6551E73A5974C30
2 changed files with 33 additions and 7 deletions

View File

@ -1570,7 +1570,7 @@ func (d *decoder) parseToValue(v reflect.Value, tInfo *typeInfo) error { //nolin
return err
}
copied = copied || converted
return fillByteString(t, b, !copied, v, d.dm.byteStringToString, d.dm.binaryUnmarshaler)
return fillByteString(t, b, !copied, v, d.dm.byteStringToString, d.dm.binaryUnmarshaler, d.dm.textUnmarshaler)
case cborTypeTextString:
b, err := d.parseTextString()
@ -1629,7 +1629,7 @@ func (d *decoder) parseToValue(v reflect.Value, tInfo *typeInfo) error { //nolin
return nil
}
if tInfo.nonPtrKind == reflect.Slice || tInfo.nonPtrKind == reflect.Array {
return fillByteString(t, b, !copied, v, ByteStringToStringForbidden, d.dm.binaryUnmarshaler)
return fillByteString(t, b, !copied, v, ByteStringToStringForbidden, d.dm.binaryUnmarshaler, d.dm.textUnmarshaler)
}
if bi.IsUint64() {
return fillPositiveInt(t, bi.Uint64(), v)
@ -1652,7 +1652,7 @@ func (d *decoder) parseToValue(v reflect.Value, tInfo *typeInfo) error { //nolin
return nil
}
if tInfo.nonPtrKind == reflect.Slice || tInfo.nonPtrKind == reflect.Array {
return fillByteString(t, b, !copied, v, ByteStringToStringForbidden, d.dm.binaryUnmarshaler)
return fillByteString(t, b, !copied, v, ByteStringToStringForbidden, d.dm.binaryUnmarshaler, d.dm.textUnmarshaler)
}
if bi.IsInt64() {
return fillNegativeInt(t, bi.Int64(), v)
@ -3180,7 +3180,7 @@ func fillFloat(t cborType, val float64, v reflect.Value) error {
return &UnmarshalTypeError{CBORType: t.String(), GoType: v.Type().String()}
}
func fillByteString(t cborType, val []byte, shared bool, v reflect.Value, bsts ByteStringToStringMode, bum BinaryUnmarshalerMode) error {
func fillByteString(t cborType, val []byte, shared bool, v reflect.Value, bsts ByteStringToStringMode, bum BinaryUnmarshalerMode, tum TextUnmarshalerMode) error {
if bum == BinaryUnmarshalerByteString && reflect.PointerTo(v.Type()).Implements(typeBinaryUnmarshaler) {
if v.CanAddr() {
v = v.Addr()
@ -3193,9 +3193,26 @@ func fillByteString(t cborType, val []byte, shared bool, v reflect.Value, bsts B
}
return errors.New("cbor: cannot set new value for " + v.Type().String())
}
if bsts != ByteStringToStringForbidden && v.Kind() == reflect.String {
v.SetString(string(val))
return nil
if bsts != ByteStringToStringForbidden {
if tum == TextUnmarshalerTextString && reflect.PointerTo(v.Type()).Implements(typeTextUnmarshaler) {
if v.CanAddr() {
v = v.Addr()
if u, ok := v.Interface().(encoding.TextUnmarshaler); ok {
// The contract of TextUnmarshaler forbids retaining the input
// bytes, so no copying is required even if val is shared.
if err := u.UnmarshalText(val); err != nil {
return fmt.Errorf("cbor: cannot unmarshal text for %s: %w", v.Type(), err)
}
return nil
}
}
return errors.New("cbor: cannot set new value for " + v.Type().String())
}
if v.Kind() == reflect.String {
v.SetString(string(val))
return nil
}
}
if v.Kind() == reflect.Slice && v.Type().Elem().Kind() == reflect.Uint8 {
src := val

View File

@ -10779,6 +10779,15 @@ func TestTextUnmarshalerMode(t *testing.T) {
in: []byte("\x65hello"), // "hello"
want: testTextUnmarshaler("hello"),
},
{
name: "UnmarshalText is called for byte string with TextUnmarshalerTextString and ByteStringToStringAllowed",
opts: DecOptions{
TextUnmarshaler: TextUnmarshalerTextString,
ByteStringToString: ByteStringToStringAllowed,
},
in: []byte("\x45hello"), // 'hello'
want: testTextUnmarshaler("UnmarshalText"),
},
} {
t.Run(tc.name, func(t *testing.T) {
dm, err := tc.opts.DecMode()