Skip to content

Commit

Permalink
fix: decimal decoding into *big.Rat (hamba#425)
Browse files Browse the repository at this point in the history
  • Loading branch information
nrwiersma authored Jul 31, 2024
1 parent 3fc81b6 commit 4aff30f
Show file tree
Hide file tree
Showing 6 changed files with 29 additions and 31 deletions.
6 changes: 3 additions & 3 deletions codec_default_internal_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -741,15 +741,15 @@ func TestDecoder_DefaultFixed(t *testing.T) {
schema.(*RecordSchema).Fields()[1].action = FieldSetDefault

type TestRecord struct {
A string `avro:"a"`
B big.Rat `avro:"b"`
A string `avro:"a"`
B *big.Rat `avro:"b"`
}

var got TestRecord
err := NewDecoderForSchema(schema, bytes.NewReader(data)).Decode(&got)

require.NoError(t, err)
assert.Equal(t, big.NewRat(1734, 5), &got.B)
assert.Equal(t, big.NewRat(1734, 5), got.B)
assert.Equal(t, "foo", got.A)
})
}
25 changes: 15 additions & 10 deletions codec_fixed.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,27 +18,34 @@ func createDecoderOfFixed(fixed *FixedSchema, typ reflect2.Type) ValDecoder {
break
}
return &fixedCodec{arrayType: typ.(*reflect2.UnsafeArrayType)}

case reflect.Uint64:
if fixed.Size() != 8 {
break
}

return &fixedUint64Codec{}
case reflect.Ptr:
ptrType := typ.(*reflect2.UnsafePtrType)
elemType := ptrType.Elem()

ls := fixed.Logical()
tpy1 := elemType.Type1()
if elemType.Kind() != reflect.Struct || !tpy1.ConvertibleTo(ratType) || ls == nil ||
ls.Type() != Decimal {
break
}
dec := ls.(*DecimalLogicalSchema)
return &fixedDecimalCodec{prec: dec.Precision(), scale: dec.Scale(), size: fixed.Size()}
case reflect.Struct:
ls := fixed.Logical()
if ls == nil {
break
}
typ1 := typ.Type1()
switch {
case typ1.ConvertibleTo(durType) && ls.Type() == Duration:
return &fixedDurationCodec{}
case typ1.ConvertibleTo(ratType) && ls.Type() == Decimal:
dec := ls.(*DecimalLogicalSchema)
return &fixedDecimalCodec{prec: dec.Precision(), scale: dec.Scale(), size: fixed.Size()}
if !typ1.ConvertibleTo(durType) || ls.Type() != Duration {
break
}
return &fixedDurationCodec{}
}

return &errorDecoder{
Expand All @@ -54,14 +61,12 @@ func createEncoderOfFixed(fixed *FixedSchema, typ reflect2.Type) ValEncoder {
break
}
return &fixedCodec{arrayType: typ.(*reflect2.UnsafeArrayType)}

case reflect.Uint64:
if fixed.Size() != 8 {
break
}

return &fixedUint64Codec{}

case reflect.Ptr:
ptrType := typ.(*reflect2.UnsafePtrType)
elemType := ptrType.Elem()
Expand Down Expand Up @@ -131,7 +136,7 @@ type fixedDecimalCodec struct {
func (c *fixedDecimalCodec) Decode(ptr unsafe.Pointer, r *Reader) {
b := make([]byte, c.size)
r.Read(b)
*((*big.Rat)(ptr)) = *ratFromBytes(b, c.scale)
*((**big.Rat)(ptr)) = ratFromBytes(b, c.scale)
}

func (c *fixedDecimalCodec) Encode(ptr unsafe.Pointer, w *Writer) {
Expand Down
9 changes: 1 addition & 8 deletions codec_generic.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,6 @@ func genericDecode(typ reflect2.Type, dec ValDecoder, r *Reader) any {
if reflect2.IsNil(obj) {
return nil
}

// Generic reader returns a different result from the
// codec in the case of a big.Rat. Handle this.
if typ.Type1() == ratType {
dec := obj.(big.Rat)
return &dec
}
return obj
}

Expand Down Expand Up @@ -125,7 +118,7 @@ func genericReceiver(schema Schema) (reflect2.Type, error) {
var v LogicalDuration
return reflect2.TypeOf(v), nil
case Decimal:
var v big.Rat
var v *big.Rat
return reflect2.TypeOf(v), nil
}
}
Expand Down
2 changes: 1 addition & 1 deletion codec_native.go
Original file line number Diff line number Diff line change
Expand Up @@ -592,7 +592,7 @@ func (c *bytesDecimalCodec) Decode(ptr unsafe.Pointer, r *Reader) {
if i := (&big.Int{}).SetBytes(b); len(b) > 0 && b[0]&0x80 > 0 {
i.Sub(i, new(big.Int).Lsh(one, uint(len(b))*8))
}
*((*big.Rat)(ptr)) = *ratFromBytes(b, c.scale)
*((**big.Rat)(ptr)) = ratFromBytes(b, c.scale)
}

func ratFromBytes(b []byte, scale int) *big.Rat {
Expand Down
8 changes: 4 additions & 4 deletions decoder_fixed_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ func TestDecoder_FixedRat_Positive(t *testing.T) {
require.NoError(t, err)

got := &big.Rat{}
err = dec.Decode(got)
err = dec.Decode(&got)

require.NoError(t, err)
assert.Equal(t, big.NewRat(1734, 5), got)
Expand All @@ -64,7 +64,7 @@ func TestDecoder_FixedRat_Negative(t *testing.T) {
require.NoError(t, err)

got := &big.Rat{}
err = dec.Decode(got)
err = dec.Decode(&got)

require.NoError(t, err)
assert.Equal(t, big.NewRat(-1734, 5), got)
Expand All @@ -79,7 +79,7 @@ func TestDecoder_FixedRat_Zero(t *testing.T) {
require.NoError(t, err)

got := &big.Rat{}
err = dec.Decode(got)
err = dec.Decode(&got)

require.NoError(t, err)
assert.Equal(t, big.NewRat(0, 1), got)
Expand All @@ -94,7 +94,7 @@ func TestDecoder_FixedRatInvalidLogicalSchema(t *testing.T) {
require.NoError(t, err)

got := &big.Rat{}
err = dec.Decode(got)
err = dec.Decode(&got)

assert.Error(t, err)
}
Expand Down
10 changes: 5 additions & 5 deletions decoder_native_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -768,7 +768,7 @@ func TestDecoder_BytesRat_Positive(t *testing.T) {
require.NoError(t, err)

got := &big.Rat{}
err = dec.Decode(got)
err = dec.Decode(&got)

require.NoError(t, err)
assert.Equal(t, big.NewRat(1734, 5), got)
Expand All @@ -783,7 +783,7 @@ func TestDecoder_BytesRat_Negative(t *testing.T) {
require.NoError(t, err)

got := &big.Rat{}
err = dec.Decode(got)
err = dec.Decode(&got)

require.NoError(t, err)
assert.Equal(t, big.NewRat(-1734, 5), got)
Expand All @@ -798,7 +798,7 @@ func TestDecoder_BytesRat_Zero(t *testing.T) {
require.NoError(t, err)

got := &big.Rat{}
err = dec.Decode(got)
err = dec.Decode(&got)

require.NoError(t, err)
assert.Equal(t, big.NewRat(0, 1), got)
Expand All @@ -813,7 +813,7 @@ func TestDecoder_BytesRatInvalidSchema(t *testing.T) {
require.NoError(t, err)

got := &big.Rat{}
err = dec.Decode(got)
err = dec.Decode(&got)

assert.Error(t, err)
}
Expand All @@ -827,7 +827,7 @@ func TestDecoder_BytesRatInvalidLogicalSchema(t *testing.T) {
require.NoError(t, err)

got := &big.Rat{}
err = dec.Decode(got)
err = dec.Decode(&got)

assert.Error(t, err)
}

0 comments on commit 4aff30f

Please sign in to comment.