Skip to content

Commit

Permalink
feature: allow "reversed" nullable schemas (hamba#49)
Browse files Browse the repository at this point in the history
  • Loading branch information
nrwiersma authored Jul 2, 2020
1 parent 3f5e4ef commit e350d2f
Show file tree
Hide file tree
Showing 6 changed files with 120 additions and 7 deletions.
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -94,8 +94,8 @@ The following union types are accepted: `map[string]interface{}`, `*T` and `inte
When a non-`nil` union value is encountered, a single key is en/decoded. The key is the avro
type name, or schema full name in the case of a named schema (enum, fixed or record).
* ***T:** This is allowed in a "nullable" union. A nullable union is defined as a two schema union,
with the first being `null` (ie. `["null", "string"]`), in this case a `*T` is allowed,
with `T` matching the conversion table above.
with one of the types being `null` (ie. `["null", "string"]` or `["string", "null"]`), in this case
a `*T` is allowed, with `T` matching the conversion table above.
* **interface{}:** An `interface` can be provided and the type or name resolved. Primitive types
are pre-registered, but named types, maps and slices will need to be registered with the `Register` function. In the
case of arrays and maps the enclosed schema type or name is postfix to the type
Expand Down
14 changes: 10 additions & 4 deletions codec_union.go
Original file line number Diff line number Diff line change
Expand Up @@ -145,9 +145,10 @@ func (e *mapUnionEncoder) Encode(ptr unsafe.Pointer, w *Writer) {

func decoderOfPtrUnion(cfg *frozenConfig, schema Schema, typ reflect2.Type) ValDecoder {
union := schema.(*UnionSchema)
_, typeIdx := union.Indices()
ptrType := typ.(*reflect2.UnsafePtrType)
elemType := ptrType.Elem()
decoder := decoderOfType(cfg, union.Types()[1], elemType)
decoder := decoderOfType(cfg, union.Types()[typeIdx], elemType)

return &unionPtrDecoder{
schema: union,
Expand Down Expand Up @@ -187,27 +188,32 @@ func (d *unionPtrDecoder) Decode(ptr unsafe.Pointer, r *Reader) {

// encoderOfPtrUnion builds the ValEncoder used for a pointer value whose
// schema is a union. It asks the union for the positions of its null and
// value branches, creates an encoder for the pointed-to type against the
// value branch, and returns a unionPtrEncoder carrying both indices.
//
// schema must be a *UnionSchema and typ a *reflect2.UnsafePtrType; the type
// assertions below are unchecked and panic otherwise.
//
// NOTE(review): this listing looks like a rendered diff without +/- markers,
// so the two consecutive `encoder :=` lines are presumably the removed and
// the added version of the same statement — confirm against the repository.
func encoderOfPtrUnion(cfg *frozenConfig, schema Schema, typ reflect2.Type) ValEncoder {
union := schema.(*UnionSchema)
nullIdx, typeIdx := union.Indices()
ptrType := typ.(*reflect2.UnsafePtrType)
encoder := encoderOfType(cfg, union.Types()[1], ptrType.Elem())
// Pick the value branch by its reported index rather than a fixed position.
encoder := encoderOfType(cfg, union.Types()[typeIdx], ptrType.Elem())

return &unionPtrEncoder{
schema: union,
encoder: encoder,
nullIdx: int64(nullIdx),
typeIdx: int64(typeIdx),
}
}

// unionPtrEncoder encodes a pointer value against a union schema, delegating
// the pointed-to value to an inner encoder.
type unionPtrEncoder struct {
	// schema is the union schema this encoder was built for.
	schema *UnionSchema
	// encoder encodes the pointed-to value when the pointer is non-nil.
	encoder ValEncoder

	// nullIdx is the branch index written for a nil pointer and typeIdx the
	// branch index written before a non-nil value.
	nullIdx, typeIdx int64
}

// Encode writes a union branch index to w and, when the pointer stored at
// ptr is non-nil, the pointed-to value using the wrapped encoder.
//
// ptr is the address of the pointer being encoded; it is dereferenced once to
// test for nil and once more to hand the pointee to the inner encoder.
//
// NOTE(review): this listing looks like a rendered diff without +/- markers.
// The literal WriteLong(0) / WriteLong(1) calls are presumably the removed
// lines, replaced by the e.nullIdx / e.typeIdx writes that follow them —
// confirm against the repository.
func (e *unionPtrEncoder) Encode(ptr unsafe.Pointer, w *Writer) {
// A nil pointee is encoded as the index alone, with no payload.
if *((*unsafe.Pointer)(ptr)) == nil {
w.WriteLong(0)
w.WriteLong(e.nullIdx)
return
}

w.WriteLong(1)
w.WriteLong(e.typeIdx)
e.encoder.Encode(*((*unsafe.Pointer)(ptr)), w)
}

Expand Down
29 changes: 29 additions & 0 deletions decoder_union_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,21 @@ func TestDecoder_UnionPtr(t *testing.T) {
assert.Equal(t, &want, got)
}

// TestDecoder_UnionPtrReversed checks that a non-null value of a "reversed"
// nullable union (`["string", "null"]`) decodes into a *string.
func TestDecoder_UnionPtrReversed(t *testing.T) {
	defer ConfigTeardown()

	// 0x00 selects union branch 0 (string); 0x06 is the zig-zag encoded
	// length 3, followed by the bytes of "foo".
	data := []byte{0x00, 0x06, 0x66, 0x6F, 0x6F}
	schema := `["string", "null"]`
	dec, err := avro.NewDecoder(schema, bytes.NewReader(data))
	// Stop here on a construction error: dec would be unusable and the
	// Decode call below would panic instead of reporting a test failure.
	if !assert.NoError(t, err) {
		return
	}

	var got *string
	err = dec.Decode(&got)

	want := "foo"
	assert.NoError(t, err)
	assert.Equal(t, &want, got)
}

func TestDecoder_UnionPtrReuseInstance(t *testing.T) {
defer ConfigTeardown()

Expand Down Expand Up @@ -137,6 +152,20 @@ func TestDecoder_UnionPtrNull(t *testing.T) {
assert.Nil(t, got)
}

// TestDecoder_UnionPtrReversedNull checks that the null branch of a
// "reversed" nullable union (`["string", "null"]`) decodes into a nil *string.
func TestDecoder_UnionPtrReversedNull(t *testing.T) {
	defer ConfigTeardown()

	// 0x02 is the zig-zag encoding of branch index 1, the null branch here.
	data := []byte{0x02}
	schema := `["string", "null"]`
	dec, err := avro.NewDecoder(schema, bytes.NewReader(data))
	// Stop here on a construction error: dec would be unusable and the
	// Decode call below would panic instead of reporting a test failure.
	if !assert.NoError(t, err) {
		return
	}

	var got *string
	err = dec.Decode(&got)

	assert.NoError(t, err)
	assert.Nil(t, got)
}

func TestDecoder_UnionPtrInvalidSchema(t *testing.T) {
defer ConfigTeardown()

Expand Down
30 changes: 30 additions & 0 deletions encoder_union_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,21 @@ func TestEncoder_UnionPtr(t *testing.T) {
assert.Equal(t, []byte{0x02, 0x06, 0x66, 0x6F, 0x6F}, buf.Bytes())
}

// TestEncoder_UnionPtrReversed checks that a non-nil *string encoded against
// the "reversed" nullable union `["string", "null"]` is written with branch
// index 0 followed by the string payload.
func TestEncoder_UnionPtrReversed(t *testing.T) {
	defer ConfigTeardown()

	const schema = `["string", "null"]`
	out := new(bytes.Buffer)
	enc, err := avro.NewEncoder(schema, out)
	assert.NoError(t, err)

	val := "foo"
	err = enc.Encode(&val)
	assert.NoError(t, err)

	// Branch index 0, length 3, then "foo".
	want := []byte{0x00, 0x06, 0x66, 0x6F, 0x6F}
	assert.Equal(t, want, out.Bytes())
}

func TestEncoder_UnionPtrNull(t *testing.T) {
defer ConfigTeardown()

Expand All @@ -120,6 +135,21 @@ func TestEncoder_UnionPtrNull(t *testing.T) {
assert.Equal(t, []byte{0x00}, buf.Bytes())
}

// TestEncoder_UnionPtrReversedNull checks that a nil *string encoded against
// the "reversed" nullable union `["string", "null"]` is written as the null
// branch index alone.
func TestEncoder_UnionPtrReversedNull(t *testing.T) {
	defer ConfigTeardown()

	const schema = `["string", "null"]`
	out := new(bytes.Buffer)
	enc, err := avro.NewEncoder(schema, out)
	assert.NoError(t, err)

	var val *string
	err = enc.Encode(val)
	assert.NoError(t, err)

	// 0x02 is the zig-zag encoding of branch index 1 (null).
	want := []byte{0x02}
	assert.Equal(t, want, out.Bytes())
}

func TestEncoder_UnionPtrNotNullable(t *testing.T) {
defer ConfigTeardown()

Expand Down
15 changes: 14 additions & 1 deletion schema.go
Original file line number Diff line number Diff line change
Expand Up @@ -656,13 +656,26 @@ func (s *UnionSchema) Types() Schemas {

// Nullable reports whether the union is nullable, that is, whether it has
// exactly two types and one of them is null.
//
// NOTE(review): this listing looks like a rendered diff without +/- markers.
// The two consecutive `if` lines are presumably the removed condition (null
// must be the first type) and the added one (null may be either type) —
// confirm against the repository.
func (s *UnionSchema) Nullable() bool {
if len(s.types) != 2 || s.types[0].Type() != Null {
// && binds tighter than ||: reject when the union does not have exactly two
// types, or when neither of the two types is null.
if len(s.types) != 2 || s.types[0].Type() != Null && s.types[1].Type() != Null {
return false
}

return true
}

// Indices reports where the null schema and the value schema sit inside a
// nullable union. The first result is the position of null, the second the
// position of the other type. A union that is not nullable yields 0 for both.
func (s *UnionSchema) Indices() (null int, typ int) {
	switch {
	case !s.Nullable():
		return 0, 0
	case s.types[0].Type() == Null:
		// Conventional ordering: null first, value second.
		return 0, 1
	default:
		// Reversed ordering: value first, null second.
		return 1, 0
	}
}

// String returns the canonical form of the schema.
func (s *UnionSchema) String() string {
types := ""
Expand Down
35 changes: 35 additions & 0 deletions schema_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -694,6 +694,41 @@ func TestUnionSchema(t *testing.T) {
}
}

// TestUnionSchema_Indices checks the null and value positions reported by
// UnionSchema.Indices for null-first, null-second and non-nullable unions.
func TestUnionSchema_Indices(t *testing.T) {
	// indices pairs the expected position of null with that of the value type.
	type indices struct {
		null int
		typ  int
	}

	cases := []struct {
		name   string
		schema string
		want   indices
	}{
		{
			name:   "Null First",
			schema: `["null", "string"]`,
			want:   indices{null: 0, typ: 1},
		},
		{
			name:   "Null Second",
			schema: `["string", "null"]`,
			want:   indices{null: 1, typ: 0},
		},
		{
			name:   "Not Nullable",
			schema: `["null", "string", "int"]`,
			want:   indices{null: 0, typ: 0},
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			parsed, err := avro.Parse(tc.schema)
			assert.NoError(t, err)

			gotNull, gotTyp := parsed.(*avro.UnionSchema).Indices()
			assert.Equal(t, tc.want.null, gotNull)
			assert.Equal(t, tc.want.typ, gotTyp)
		})
	}
}

func TestFixedSchema(t *testing.T) {
tests := []struct {
name string
Expand Down

0 comments on commit e350d2f

Please sign in to comment.