From 0ebcdff2ba1904f60e60f26cdf551d787dafe8de Mon Sep 17 00:00:00 2001 From: vpapp <103230317+vpapp@users.noreply.github.com> Date: Thu, 9 Jan 2025 03:14:28 -0500 Subject: [PATCH] feat: allow setting zstd codec options (#485) Co-authored-by: vpapp --- ocf/codec.go | 22 +++++++++---- ocf/codec_test.go | 6 ++-- ocf/ocf.go | 64 ++++++++++++++++++++++++------------ ocf/ocf_test.go | 84 +++++++++++++++++++++++++++++++++++++++++++++++ 4 files changed, 146 insertions(+), 30 deletions(-) diff --git a/ocf/codec.go b/ocf/codec.go index 4811bf0..8d0cab2 100644 --- a/ocf/codec.go +++ b/ocf/codec.go @@ -24,19 +24,29 @@ const ( ZStandard CodecName = "zstandard" ) -func resolveCodec(name CodecName, lvl int) (Codec, error) { +type codecOptions struct { + DeflateCompressionLevel int + ZStandardOptions zstdOptions +} + +type zstdOptions struct { + EOptions []zstd.EOption + DOptions []zstd.DOption +} + +func resolveCodec(name CodecName, codecOpts codecOptions) (Codec, error) { switch name { case Null, "": return &NullCodec{}, nil case Deflate: - return &DeflateCodec{compLvl: lvl}, nil + return &DeflateCodec{compLvl: codecOpts.DeflateCompressionLevel}, nil case Snappy: return &SnappyCodec{}, nil case ZStandard: - return newZStandardCodec(), nil + return newZStandardCodec(codecOpts.ZStandardOptions), nil default: return nil, fmt.Errorf("unknown codec %s", name) @@ -132,9 +142,9 @@ type ZStandardCodec struct { encoder *zstd.Encoder } -func newZStandardCodec() *ZStandardCodec { - decoder, _ := zstd.NewReader(nil) - encoder, _ := zstd.NewWriter(nil) +func newZStandardCodec(opts zstdOptions) *ZStandardCodec { + decoder, _ := zstd.NewReader(nil, opts.DOptions...) + encoder, _ := zstd.NewWriter(nil, opts.EOptions...) return &ZStandardCodec{ decoder: decoder, encoder: encoder, diff --git a/ocf/codec_test.go b/ocf/codec_test.go index 2fb21f2..761437f 100644 --- a/ocf/codec_test.go +++ b/ocf/codec_test.go @@ -59,7 +59,7 @@ func BenchmarkZstdEncodeDecodeLowEntropyLong(b *testing.B) { input := makeTestData(8762, func() byte { return 'a' }) - codec, err := resolveCodec(ZStandard, 0) + codec, err := resolveCodec(ZStandard, codecOptions{}) require.NoError(b, err) b.ReportAllocs() @@ -74,7 +74,7 @@ func BenchmarkZstdEncodeDecodeLowEntropyLong(b *testing.B) { func BenchmarkZstdEncodeDecodeHighEntropyLong(b *testing.B) { input := makeTestData(8762, func() byte { return byte(rand.Uint32()) }) - codec, err := resolveCodec(ZStandard, 0) + codec, err := resolveCodec(ZStandard, codecOptions{}) require.NoError(b, err) b.ReportAllocs() @@ -87,7 +87,7 @@ func BenchmarkZstdEncodeDecodeHighEntropyLong(b *testing.B) { } func verifyZstdEncodeDecode(t *testing.T, input []byte) { - codec, err := resolveCodec(ZStandard, 0) + codec, err := resolveCodec(ZStandard, codecOptions{}) require.NoError(t, err) compressed := codec.Encode(input) diff --git a/ocf/ocf.go b/ocf/ocf.go index d980e4f..7c2a05c 100644 --- a/ocf/ocf.go +++ b/ocf/ocf.go @@ -5,6 +5,7 @@ package ocf import ( "bytes" + "compress/flate" "crypto/rand" "encoding/json" "errors" @@ -14,6 +15,7 @@ import ( "github.com/hamba/avro/v2" "github.com/hamba/avro/v2/internal/bytesx" + "github.com/klauspost/compress/zstd" ) const ( @@ -54,6 +56,7 @@ type Header struct { type decoderConfig struct { DecoderConfig avro.API SchemaCache *avro.SchemaCache + CodecOptions codecOptions } // DecoderFunc represents a configuration function for Decoder. @@ -74,6 +77,13 @@ func WithDecoderSchemaCache(cache *avro.SchemaCache) DecoderFunc { } } +// WithZStandardDecoderOptions sets the options for the ZStandard decoder. +func WithZStandardDecoderOptions(opts ...zstd.DOption) DecoderFunc { + return func(cfg *decoderConfig) { + cfg.CodecOptions.ZStandardOptions.DOptions = append(cfg.CodecOptions.ZStandardOptions.DOptions, opts...) + } +} + // Decoder reads and decodes Avro values from a container file. type Decoder struct { reader *avro.Reader @@ -93,6 +103,9 @@ func NewDecoder(r io.Reader, opts ...DecoderFunc) (*Decoder, error) { cfg := decoderConfig{ DecoderConfig: avro.DefaultConfig, SchemaCache: avro.DefaultSchemaCache, + CodecOptions: codecOptions{ + DeflateCompressionLevel: flate.DefaultCompression, + }, } for _, opt := range opts { opt(&cfg) @@ -100,7 +113,7 @@ func NewDecoder(r io.Reader, opts ...DecoderFunc) (*Decoder, error) { reader := avro.NewReader(r, 1024) - h, err := readHeader(reader, cfg.SchemaCache) + h, err := readHeader(reader, cfg.SchemaCache, cfg.CodecOptions) if err != nil { return nil, fmt.Errorf("decoder: %w", err) } @@ -197,14 +210,14 @@ func (d *Decoder) readBlock() int64 { } type encoderConfig struct { - BlockLength int - CodecName CodecName - CodecCompression int - Metadata map[string][]byte - Sync [16]byte - EncodingConfig avro.API - SchemaCache *avro.SchemaCache - SchemaMarshaler func(avro.Schema) ([]byte, error) + BlockLength int + CodecName CodecName + CodecOptions codecOptions + Metadata map[string][]byte + Sync [16]byte + EncodingConfig avro.API + SchemaCache *avro.SchemaCache + SchemaMarshaler func(avro.Schema) ([]byte, error) } // EncoderFunc represents a configuration function for Encoder. @@ -229,7 +242,14 @@ func WithCodec(codec CodecName) EncoderFunc { func WithCompressionLevel(compLvl int) EncoderFunc { return func(cfg *encoderConfig) { cfg.CodecName = Deflate - cfg.CodecCompression = compLvl + cfg.CodecOptions.DeflateCompressionLevel = compLvl + } +} + +// WithZStandardEncoderOptions sets the options for the ZStandard encoder. +func WithZStandardEncoderOptions(opts ...zstd.EOption) EncoderFunc { + return func(cfg *encoderConfig) { + cfg.CodecOptions.ZStandardOptions.EOptions = append(cfg.CodecOptions.ZStandardOptions.EOptions, opts...) } } @@ -316,7 +336,7 @@ func newEncoder(schema avro.Schema, w io.Writer, cfg encoderConfig) (*Encoder, e if info.Size() > 0 { reader := avro.NewReader(file, 1024) - h, err := readHeader(reader, cfg.SchemaCache) + h, err := readHeader(reader, cfg.SchemaCache, cfg.CodecOptions) if err != nil { return nil, err } @@ -354,7 +374,7 @@ func newEncoder(schema avro.Schema, w io.Writer, cfg encoderConfig) (*Encoder, e _, _ = rand.Read(header.Sync[:]) } - codec, err := resolveCodec(cfg.CodecName, cfg.CodecCompression) + codec, err := resolveCodec(cfg.CodecName, cfg.CodecOptions) if err != nil { return nil, err } @@ -379,13 +399,15 @@ func newEncoder(schema avro.Schema, w io.Writer, cfg encoderConfig) (*Encoder, e func computeEncoderConfig(opts []EncoderFunc) encoderConfig { cfg := encoderConfig{ - BlockLength: 100, - CodecName: Null, - CodecCompression: -1, - Metadata: map[string][]byte{}, - EncodingConfig: avro.DefaultConfig, - SchemaCache: avro.DefaultSchemaCache, - SchemaMarshaler: DefaultSchemaMarshaler, + BlockLength: 100, + CodecName: Null, + CodecOptions: codecOptions{ + DeflateCompressionLevel: flate.DefaultCompression, + }, + Metadata: map[string][]byte{}, + EncodingConfig: avro.DefaultConfig, + SchemaCache: avro.DefaultSchemaCache, + SchemaMarshaler: DefaultSchemaMarshaler, } for _, opt := range opts { opt(&cfg) @@ -469,7 +491,7 @@ type ocfHeader struct { Sync [16]byte } -func readHeader(reader *avro.Reader, schemaCache *avro.SchemaCache) (*ocfHeader, error) { +func readHeader(reader *avro.Reader, schemaCache *avro.SchemaCache, codecOpts codecOptions) (*ocfHeader, error) { var h Header reader.ReadVal(HeaderSchema, &h) if reader.Error != nil { @@ -484,7 +506,7 @@ func readHeader(reader *avro.Reader, schemaCache *avro.SchemaCache) (*ocfHeader, return nil, err } - codec, err := resolveCodec(CodecName(h.Meta[codecKey]), -1) + codec, err := resolveCodec(CodecName(h.Meta[codecKey]), codecOpts) if err != nil { return nil, err } diff --git a/ocf/ocf_test.go b/ocf/ocf_test.go index 200a779..3f5eb6a 100644 --- a/ocf/ocf_test.go +++ b/ocf/ocf_test.go @@ -13,6 +13,7 @@ import ( "github.com/hamba/avro/v2" "github.com/hamba/avro/v2/ocf" + "github.com/klauspost/compress/zstd" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -409,6 +410,52 @@ func TestDecoder_WithZStandardHandlesInvalidData(t *testing.T) { assert.Error(t, dec.Error()) } +func TestDecoder_WithZStandardOptions(t *testing.T) { + unionStr := "union value" + want := FullRecord{ + Strings: []string{"string1", "string2", "string3", "string4", "string5"}, + Longs: []int64{1, 2, 3, 4, 5}, + Enum: "C", + Map: map[string]int{ + "ke\xa9\xb1": 1, + "\x00\x00y2": 2, + "key3": 3, + "key4": 4, + "key5": 5, + }, + Nullable: &unionStr, + Fixed: [16]byte{0x01, 0x02, 0x03, 0x04, 0x01, 0x02, 0x03, 0x04, 0x01, 0x02, 0x03, 0x04, 0x01, 0x02, 0x03, 0x04}, + Record: &TestRecord{ + Long: 1925639126735, + String: "I am a test record", + Int: 666, + Float: 7171.17, + Double: 916734926348163.01973408746523, + Bool: true, + }, + } + + f, err := os.Open("testdata/zstd-invalid-data.avro") + require.NoError(t, err) + t.Cleanup(func() { _ = f.Close() }) + + dec, err := ocf.NewDecoder(f, ocf.WithZStandardDecoderOptions(zstd.IgnoreChecksum(true))) + require.NoError(t, err) + + dec.HasNext() + + var got FullRecord + err = dec.Decode(&got) + + require.NoError(t, err, "should not cause an error because checksum is ignored") + require.NoError(t, dec.Error(), "should not cause an error because checksum is ignored") + assert.Equal(t, want, got, "should read corrupted data as valid because checksum is ignored") + + dec.HasNext() + + assert.ErrorContains(t, dec.Error(), "decoder: invalid block", "trailing byte in file should cause error before hitting zstd decoder") +} + func TestDecoder_DecodeAvroError(t *testing.T) { data := []byte{'O', 'b', 'j', 0x01, 0x01, 0x26, 0x16, 'a', 'v', 'r', 'o', '.', 's', 'c', 'h', 'e', 'm', 'a', 0x0c, '"', 'l', 'o', 'n', 'g', '"', 0x00, 0xfb, 0x2b, 0x0f, 0x1a, 0xdd, 0xfd, 0x90, 0x7d, 0x87, 0x12, @@ -878,6 +925,43 @@ func TestEncoder_EncodeCompressesZStandard(t *testing.T) { assert.Equal(t, 951, buf.Len()) } +func TestEncoder_EncodeCompressesZStandardWithLevel(t *testing.T) { + unionStr := "union value" + record := FullRecord{ + Strings: []string{"string1", "string2", "string3", "string4", "string5"}, + Longs: []int64{1, 2, 3, 4, 5}, + Enum: "C", + Map: map[string]int{ + "key1": 1, + "key2": 2, + "key3": 3, + "key4": 4, + "key5": 5, + }, + Nullable: &unionStr, + Fixed: [16]byte{0x01, 0x02, 0x03, 0x04, 0x01, 0x02, 0x03, 0x04, 0x01, 0x02, 0x03, 0x04, 0x01, 0x02, 0x03, 0x04}, + Record: &TestRecord{ + Long: 1925639126735, + String: "I am a test record", + Int: 666, + Float: 7171.17, + Double: 916734926348163.01973408746523, + Bool: true, + }, + } + + buf := &bytes.Buffer{} + enc, _ := ocf.NewEncoder(schema, buf, ocf.WithCodec(ocf.ZStandard), ocf.WithZStandardEncoderOptions(zstd.WithEncoderLevel(zstd.SpeedBestCompression))) + + err := enc.Encode(record) + assert.NoError(t, err) + + err = enc.Close() + + require.NoError(t, err) + assert.Equal(t, 942, buf.Len()) +} + func TestEncoder_EncodeError(t *testing.T) { buf := &bytes.Buffer{} enc, err := ocf.NewEncoder(`"long"`, buf)