Skip to content

Commit

Permalink
fix: allow for ptr and non-ptr enum text marshalers (hamba#208)
Browse files Browse the repository at this point in the history
fix: allow for ptr and non-ptr enum text marahslers
  • Loading branch information
nrwiersma authored Nov 25, 2022
1 parent 45ed2d0 commit 23a0bda
Show file tree
Hide file tree
Showing 3 changed files with 85 additions and 6 deletions.
19 changes: 17 additions & 2 deletions codec_enum.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ func createDecoderOfEnum(schema Schema, typ reflect2.Type) ValDecoder {
return &enumCodec{symbols: schema.(*EnumSchema).Symbols()}
case typ.Implements(textUnmarshalerType):
return &enumTextMarshalerCodec{typ: typ, symbols: schema.(*EnumSchema).Symbols()}
case reflect2.PtrTo(typ).Implements(textUnmarshalerType):
return &enumTextMarshalerCodec{typ: typ, symbols: schema.(*EnumSchema).Symbols(), ptr: true}
}

return &errorDecoder{err: fmt.Errorf("avro: %s is unsupported for Avro %s", typ.String(), schema.Type())}
Expand All @@ -27,6 +29,8 @@ func createEncoderOfEnum(schema Schema, typ reflect2.Type) ValEncoder {
return &enumCodec{symbols: schema.(*EnumSchema).Symbols()}
case typ.Implements(textMarshalerType):
return &enumTextMarshalerCodec{typ: typ, symbols: schema.(*EnumSchema).Symbols()}
case reflect2.PtrTo(typ).Implements(textMarshalerType):
return &enumTextMarshalerCodec{typ: typ, symbols: schema.(*EnumSchema).Symbols(), ptr: true}
}

return &errorEncoder{err: fmt.Errorf("avro: %s is unsupported for Avro %s", typ.String(), schema.Type())}
Expand Down Expand Up @@ -64,6 +68,7 @@ func (c *enumCodec) Encode(ptr unsafe.Pointer, w *Writer) {
type enumTextMarshalerCodec struct {
typ reflect2.Type
symbols []string
ptr bool
}

func (c *enumTextMarshalerCodec) Decode(ptr unsafe.Pointer, r *Reader) {
Expand All @@ -74,7 +79,12 @@ func (c *enumTextMarshalerCodec) Decode(ptr unsafe.Pointer, r *Reader) {
return
}

obj := c.typ.UnsafeIndirect(ptr)
var obj interface{}
if c.ptr {
obj = c.typ.PackEFace(ptr)
} else {
obj = c.typ.UnsafeIndirect(ptr)
}
if reflect2.IsNil(obj) {
ptrType := c.typ.(*reflect2.UnsafePtrType)
newPtr := ptrType.Elem().UnsafeNew()
Expand All @@ -88,7 +98,12 @@ func (c *enumTextMarshalerCodec) Decode(ptr unsafe.Pointer, r *Reader) {
}

func (c *enumTextMarshalerCodec) Encode(ptr unsafe.Pointer, w *Writer) {
obj := c.typ.UnsafeIndirect(ptr)
var obj interface{}
if c.ptr {
obj = c.typ.PackEFace(ptr)
} else {
obj = c.typ.UnsafeIndirect(ptr)
}
if c.typ.IsNullable() && reflect2.IsNil(obj) {
w.Error = errors.New("encoding nil enum text marshaler")
return
Expand Down
40 changes: 40 additions & 0 deletions decoder_enum_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,42 @@ func TestDecoder_EnumTextUnmarshaler(t *testing.T) {
assert.Equal(t, testEnumTextUnmarshaler(1), *got)
}

func TestDecoder_EnumTextUnmarshalerNonPtr(t *testing.T) {
defer ConfigTeardown()

data := []byte{0x02}
schema := `{"type":"enum", "name": "test", "symbols": ["foo", "bar"]}`
dec, _ := avro.NewDecoder(schema, bytes.NewReader(data))

var got testEnumTextUnmarshaler
err := dec.Decode(&got)

require.NoError(t, err)
require.NotNil(t, got)
assert.Equal(t, testEnumTextUnmarshaler(1), got)
}

func TestDecoder_EnumTextUnmarshalerObj(t *testing.T) {
defer ConfigTeardown()

data := []byte{0x02}
schema := `{
"type": "record",
"name": "test",
"fields" : [
{"name": "a", "type": {"type":"enum", "name": "test", "symbols": ["foo", "bar"]}}
]
}`
dec, _ := avro.NewDecoder(schema, bytes.NewReader(data))

var got testEnumUnmarshalerObj
err := dec.Decode(&got)

require.NoError(t, err)
require.NotNil(t, got)
assert.Equal(t, testEnumUnmarshalerObj{A: testEnumTextUnmarshaler(1)}, got)
}

func TestDecoder_EnumTextUnmarshalerInvalidSymbol(t *testing.T) {
defer ConfigTeardown()

Expand Down Expand Up @@ -121,6 +157,10 @@ func TestDecoder_EnumTextUnmarshalerError(t *testing.T) {
assert.Error(t, err)
}

type testEnumUnmarshalerObj struct {
A testEnumTextUnmarshaler `avro:"a"`
}

type testEnumTextUnmarshaler int

func (m *testEnumTextUnmarshaler) UnmarshalText(data []byte) error {
Expand Down
32 changes: 28 additions & 4 deletions encoder_enum_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -74,8 +74,28 @@ func TestEncoder_EnumTextMarshalerPtr(t *testing.T) {
require.NoError(t, err)

m := testEnumTextMarshaler(1)
ptr := &m
err = enc.Encode(ptr)
err = enc.Encode(&m)

require.NoError(t, err)
assert.Equal(t, []byte{0x02}, buf.Bytes())
}

func TestEncoder_EnumTextMarshalerObj(t *testing.T) {
defer ConfigTeardown()

schema := `{
"type": "record",
"name": "test",
"fields" : [
{"name": "a", "type": {"type":"enum", "name": "test", "symbols": ["foo", "bar"]}}
]
}`
buf := bytes.NewBuffer([]byte{})
enc, err := avro.NewEncoder(schema, buf)
require.NoError(t, err)

m := testEnumMarshlaerObj{A: testEnumTextMarshaler(1)}
err = enc.Encode(&m)

require.NoError(t, err)
assert.Equal(t, []byte{0x02}, buf.Bytes())
Expand Down Expand Up @@ -123,10 +143,14 @@ func TestEncoder_EnumTextMarshalerError(t *testing.T) {
assert.Error(t, err)
}

type testEnumMarshlaerObj struct {
A testEnumTextMarshaler `avro:"a"`
}

type testEnumTextMarshaler int

func (m testEnumTextMarshaler) MarshalText() ([]byte, error) {
switch m {
func (m *testEnumTextMarshaler) MarshalText() ([]byte, error) {
switch *m {
case 0:
return []byte("foo"), nil
case 1:
Expand Down

0 comments on commit 23a0bda

Please sign in to comment.