From e350d2ff01977073ae3bbb9c04dd704831c9c1c6 Mon Sep 17 00:00:00 2001 From: Nicholas Wiersma Date: Thu, 2 Jul 2020 17:49:26 +0200 Subject: [PATCH] feature: allow "reversed" nullable schemas (#49) --- README.md | 4 ++-- codec_union.go | 14 ++++++++++---- decoder_union_test.go | 29 +++++++++++++++++++++++++++++ encoder_union_test.go | 30 ++++++++++++++++++++++++++++++ schema.go | 15 ++++++++++++++- schema_test.go | 35 +++++++++++++++++++++++++++++++++++ 6 files changed, 120 insertions(+), 7 deletions(-) diff --git a/README.md b/README.md index 8d35e492..d4fcf34b 100644 --- a/README.md +++ b/README.md @@ -94,8 +94,8 @@ The following union types are accepted: `map[string]interface{}`, `*T` and `inte When a non-`nil` union value is encountered, a single key is en/decoded. The key is the avro type name, or scheam full name in the case of a named schema (enum, fixed or record). * ***T:** This is allowed in a "nullable" union. A nullable union is defined as a two schema union, -with the first being `null` (ie. `["null", "string"]`), in this case a `*T` is allowed, -with `T` matching the conversion table above. +with one of the types being `null` (ie. `["null", "string"]` or `["string", "null"]`), in this case +a `*T` is allowed, with `T` matching the conversion table above. * **interface{}:** An `interface` can be provided and the type or name resolved. Primitive types are pre-registered, but named types, maps and slices will need to be registered with the `Register` function. In the case of arrays and maps the enclosed schema type or name is postfix to the type diff --git a/codec_union.go b/codec_union.go index 4147537a..d4384ae8 100644 --- a/codec_union.go +++ b/codec_union.go @@ -145,9 +145,10 @@ func (e *mapUnionEncoder) Encode(ptr unsafe.Pointer, w *Writer) { func decoderOfPtrUnion(cfg *frozenConfig, schema Schema, typ reflect2.Type) ValDecoder { union := schema.(*UnionSchema) + _, typeIdx := union.Indices() ptrType := typ.(*reflect2.UnsafePtrType) elemType := ptrType.Elem() - decoder := decoderOfType(cfg, union.Types()[1], elemType) + decoder := decoderOfType(cfg, union.Types()[typeIdx], elemType) return &unionPtrDecoder{ schema: union, @@ -187,27 +188,32 @@ func (d *unionPtrDecoder) Decode(ptr unsafe.Pointer, r *Reader) { func encoderOfPtrUnion(cfg *frozenConfig, schema Schema, typ reflect2.Type) ValEncoder { union := schema.(*UnionSchema) + nullIdx, typeIdx := union.Indices() ptrType := typ.(*reflect2.UnsafePtrType) - encoder := encoderOfType(cfg, union.Types()[1], ptrType.Elem()) + encoder := encoderOfType(cfg, union.Types()[typeIdx], ptrType.Elem()) return &unionPtrEncoder{ schema: union, encoder: encoder, + nullIdx: int64(nullIdx), + typeIdx: int64(typeIdx), } } type unionPtrEncoder struct { schema *UnionSchema encoder ValEncoder + nullIdx int64 + typeIdx int64 } func (e *unionPtrEncoder) Encode(ptr unsafe.Pointer, w *Writer) { if *((*unsafe.Pointer)(ptr)) == nil { - w.WriteLong(0) + w.WriteLong(e.nullIdx) return } - w.WriteLong(1) + w.WriteLong(e.typeIdx) e.encoder.Encode(*((*unsafe.Pointer)(ptr)), w) } diff --git a/decoder_union_test.go b/decoder_union_test.go index cfb4e464..400b2c7e 100644 --- a/decoder_union_test.go +++ b/decoder_union_test.go @@ -105,6 +105,21 @@ func TestDecoder_UnionPtr(t *testing.T) { assert.Equal(t, &want, got) } +func TestDecoder_UnionPtrReversed(t *testing.T) { + defer ConfigTeardown() + + data := []byte{0x00, 0x06, 0x66, 0x6F, 0x6F} + schema := `["string", "null"]` + dec, _ := avro.NewDecoder(schema, bytes.NewReader(data)) + + var got *string + err := dec.Decode(&got) + + want := "foo" + assert.NoError(t, err) + assert.Equal(t, &want, got) +} + func TestDecoder_UnionPtrReuseInstance(t *testing.T) { defer ConfigTeardown() @@ -137,6 +152,20 @@ func TestDecoder_UnionPtrNull(t *testing.T) { assert.Nil(t, got) } +func TestDecoder_UnionPtrReversedNull(t *testing.T) { + defer ConfigTeardown() + + data := []byte{0x02} + schema := `["string", "null"]` + dec, _ := avro.NewDecoder(schema, bytes.NewReader(data)) + + var got *string + err := dec.Decode(&got) + + assert.NoError(t, err) + assert.Nil(t, got) +} + func TestDecoder_UnionPtrInvalidSchema(t *testing.T) { defer ConfigTeardown() diff --git a/encoder_union_test.go b/encoder_union_test.go index b5412298..69cc0836 100644 --- a/encoder_union_test.go +++ b/encoder_union_test.go @@ -105,6 +105,21 @@ func TestEncoder_UnionPtr(t *testing.T) { assert.Equal(t, []byte{0x02, 0x06, 0x66, 0x6F, 0x6F}, buf.Bytes()) } +func TestEncoder_UnionPtrReversed(t *testing.T) { + defer ConfigTeardown() + + schema := `["string", "null"]` + buf := bytes.NewBuffer([]byte{}) + enc, err := avro.NewEncoder(schema, buf) + assert.NoError(t, err) + + str := "foo" + err = enc.Encode(&str) + + assert.NoError(t, err) + assert.Equal(t, []byte{0x00, 0x06, 0x66, 0x6F, 0x6F}, buf.Bytes()) +} + func TestEncoder_UnionPtrNull(t *testing.T) { defer ConfigTeardown() @@ -120,6 +135,21 @@ func TestEncoder_UnionPtrNull(t *testing.T) { assert.Equal(t, []byte{0x00}, buf.Bytes()) } +func TestEncoder_UnionPtrReversedNull(t *testing.T) { + defer ConfigTeardown() + + schema := `["string", "null"]` + buf := bytes.NewBuffer([]byte{}) + enc, err := avro.NewEncoder(schema, buf) + assert.NoError(t, err) + + var str *string + err = enc.Encode(str) + + assert.NoError(t, err) + assert.Equal(t, []byte{0x02}, buf.Bytes()) +} + func TestEncoder_UnionPtrNotNullable(t *testing.T) { defer ConfigTeardown() diff --git a/schema.go b/schema.go index e2f55381..02012952 100644 --- a/schema.go +++ b/schema.go @@ -656,13 +656,26 @@ func (s *UnionSchema) Types() Schemas { // Nullable returns the Schema if the union is nullable, otherwise nil. func (s *UnionSchema) Nullable() bool { - if len(s.types) != 2 || s.types[0].Type() != Null { + if len(s.types) != 2 || s.types[0].Type() != Null && s.types[1].Type() != Null { return false } return true } +// Indices returns the index of the null and type schemas for a +// nullable schema. For non-nullable schemas 0 is returned for +// both. +func (s *UnionSchema) Indices() (null int, typ int) { + if !s.Nullable() { + return 0, 0 + } + if s.types[0].Type() == Null { + return 0, 1 + } + return 1, 0 +} + // String returns the canonical form of the schema. func (s *UnionSchema) String() string { types := "" diff --git a/schema_test.go b/schema_test.go index 752b469f..4f6bfada 100644 --- a/schema_test.go +++ b/schema_test.go @@ -694,6 +694,41 @@ func TestUnionSchema(t *testing.T) { } } +func TestUnionSchema_Indices(t *testing.T) { + tests := []struct { + name string + schema string + want [2]int + }{ + { + name: "Null First", + schema: `["null", "string"]`, + want: [2]int{0, 1}, + }, + { + name: "Null Second", + schema: `["string", "null"]`, + want: [2]int{1, 0}, + }, + { + name: "Not Nullable", + schema: `["null", "string", "int"]`, + want: [2]int{0, 0}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + s, err := avro.Parse(tt.schema) + + assert.NoError(t, err) + null, typ := s.(*avro.UnionSchema).Indices() + assert.Equal(t, tt.want[0], null) + assert.Equal(t, tt.want[1], typ) + }) + } +} + func TestFixedSchema(t *testing.T) { tests := []struct { name string