Skip to content

Commit

Permalink
feat: allow setting zstd codec options (#485)
Browse files Browse the repository at this point in the history
Co-authored-by: vpapp <[email protected]>
  • Loading branch information
vpapp and vpapp authored Jan 9, 2025
1 parent 53e4ea9 commit 0ebcdff
Show file tree
Hide file tree
Showing 4 changed files with 146 additions and 30 deletions.
22 changes: 16 additions & 6 deletions ocf/codec.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,19 +24,29 @@ const (
ZStandard CodecName = "zstandard"
)

func resolveCodec(name CodecName, lvl int) (Codec, error) {
type codecOptions struct {
DeflateCompressionLevel int
ZStandardOptions zstdOptions
}

type zstdOptions struct {
EOptions []zstd.EOption
DOptions []zstd.DOption
}

func resolveCodec(name CodecName, codecOpts codecOptions) (Codec, error) {
switch name {
case Null, "":
return &NullCodec{}, nil

case Deflate:
return &DeflateCodec{compLvl: lvl}, nil
return &DeflateCodec{compLvl: codecOpts.DeflateCompressionLevel}, nil

case Snappy:
return &SnappyCodec{}, nil

case ZStandard:
return newZStandardCodec(), nil
return newZStandardCodec(codecOpts.ZStandardOptions), nil

default:
return nil, fmt.Errorf("unknown codec %s", name)
Expand Down Expand Up @@ -132,9 +142,9 @@ type ZStandardCodec struct {
encoder *zstd.Encoder
}

func newZStandardCodec() *ZStandardCodec {
decoder, _ := zstd.NewReader(nil)
encoder, _ := zstd.NewWriter(nil)
func newZStandardCodec(opts zstdOptions) *ZStandardCodec {
decoder, _ := zstd.NewReader(nil, opts.DOptions...)
encoder, _ := zstd.NewWriter(nil, opts.EOptions...)
return &ZStandardCodec{
decoder: decoder,
encoder: encoder,
Expand Down
6 changes: 3 additions & 3 deletions ocf/codec_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ func BenchmarkZstdEncodeDecodeLowEntropyLong(b *testing.B) {

input := makeTestData(8762, func() byte { return 'a' })

codec, err := resolveCodec(ZStandard, 0)
codec, err := resolveCodec(ZStandard, codecOptions{})
require.NoError(b, err)

b.ReportAllocs()
Expand All @@ -74,7 +74,7 @@ func BenchmarkZstdEncodeDecodeLowEntropyLong(b *testing.B) {
func BenchmarkZstdEncodeDecodeHighEntropyLong(b *testing.B) {
input := makeTestData(8762, func() byte { return byte(rand.Uint32()) })

codec, err := resolveCodec(ZStandard, 0)
codec, err := resolveCodec(ZStandard, codecOptions{})
require.NoError(b, err)

b.ReportAllocs()
Expand All @@ -87,7 +87,7 @@ func BenchmarkZstdEncodeDecodeHighEntropyLong(b *testing.B) {
}

func verifyZstdEncodeDecode(t *testing.T, input []byte) {
codec, err := resolveCodec(ZStandard, 0)
codec, err := resolveCodec(ZStandard, codecOptions{})
require.NoError(t, err)

compressed := codec.Encode(input)
Expand Down
64 changes: 43 additions & 21 deletions ocf/ocf.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ package ocf

import (
"bytes"
"compress/flate"
"crypto/rand"
"encoding/json"
"errors"
Expand All @@ -14,6 +15,7 @@ import (

"github.com/hamba/avro/v2"
"github.com/hamba/avro/v2/internal/bytesx"
"github.com/klauspost/compress/zstd"
)

const (
Expand Down Expand Up @@ -54,6 +56,7 @@ type Header struct {
type decoderConfig struct {
DecoderConfig avro.API
SchemaCache *avro.SchemaCache
CodecOptions codecOptions
}

// DecoderFunc represents a configuration function for Decoder.
Expand All @@ -74,6 +77,13 @@ func WithDecoderSchemaCache(cache *avro.SchemaCache) DecoderFunc {
}
}

// WithZStandardDecoderOptions sets the options for the ZStandard decoder.
func WithZStandardDecoderOptions(opts ...zstd.DOption) DecoderFunc {
return func(cfg *decoderConfig) {
cfg.CodecOptions.ZStandardOptions.DOptions = append(cfg.CodecOptions.ZStandardOptions.DOptions, opts...)
}
}

// Decoder reads and decodes Avro values from a container file.
type Decoder struct {
reader *avro.Reader
Expand All @@ -93,14 +103,17 @@ func NewDecoder(r io.Reader, opts ...DecoderFunc) (*Decoder, error) {
cfg := decoderConfig{
DecoderConfig: avro.DefaultConfig,
SchemaCache: avro.DefaultSchemaCache,
CodecOptions: codecOptions{
DeflateCompressionLevel: flate.DefaultCompression,
},
}
for _, opt := range opts {
opt(&cfg)
}

reader := avro.NewReader(r, 1024)

h, err := readHeader(reader, cfg.SchemaCache)
h, err := readHeader(reader, cfg.SchemaCache, cfg.CodecOptions)
if err != nil {
return nil, fmt.Errorf("decoder: %w", err)
}
Expand Down Expand Up @@ -197,14 +210,14 @@ func (d *Decoder) readBlock() int64 {
}

type encoderConfig struct {
BlockLength int
CodecName CodecName
CodecCompression int
Metadata map[string][]byte
Sync [16]byte
EncodingConfig avro.API
SchemaCache *avro.SchemaCache
SchemaMarshaler func(avro.Schema) ([]byte, error)
BlockLength int
CodecName CodecName
CodecOptions codecOptions
Metadata map[string][]byte
Sync [16]byte
EncodingConfig avro.API
SchemaCache *avro.SchemaCache
SchemaMarshaler func(avro.Schema) ([]byte, error)
}

// EncoderFunc represents a configuration function for Encoder.
Expand All @@ -229,7 +242,14 @@ func WithCodec(codec CodecName) EncoderFunc {
func WithCompressionLevel(compLvl int) EncoderFunc {
return func(cfg *encoderConfig) {
cfg.CodecName = Deflate
cfg.CodecCompression = compLvl
cfg.CodecOptions.DeflateCompressionLevel = compLvl
}
}

// WithZStandardEncoderOptions sets the options for the ZStandard encoder.
func WithZStandardEncoderOptions(opts ...zstd.EOption) EncoderFunc {
return func(cfg *encoderConfig) {
cfg.CodecOptions.ZStandardOptions.EOptions = append(cfg.CodecOptions.ZStandardOptions.EOptions, opts...)
}
}

Expand Down Expand Up @@ -316,7 +336,7 @@ func newEncoder(schema avro.Schema, w io.Writer, cfg encoderConfig) (*Encoder, e

if info.Size() > 0 {
reader := avro.NewReader(file, 1024)
h, err := readHeader(reader, cfg.SchemaCache)
h, err := readHeader(reader, cfg.SchemaCache, cfg.CodecOptions)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -354,7 +374,7 @@ func newEncoder(schema avro.Schema, w io.Writer, cfg encoderConfig) (*Encoder, e
_, _ = rand.Read(header.Sync[:])
}

codec, err := resolveCodec(cfg.CodecName, cfg.CodecCompression)
codec, err := resolveCodec(cfg.CodecName, cfg.CodecOptions)
if err != nil {
return nil, err
}
Expand All @@ -379,13 +399,15 @@ func newEncoder(schema avro.Schema, w io.Writer, cfg encoderConfig) (*Encoder, e

func computeEncoderConfig(opts []EncoderFunc) encoderConfig {
cfg := encoderConfig{
BlockLength: 100,
CodecName: Null,
CodecCompression: -1,
Metadata: map[string][]byte{},
EncodingConfig: avro.DefaultConfig,
SchemaCache: avro.DefaultSchemaCache,
SchemaMarshaler: DefaultSchemaMarshaler,
BlockLength: 100,
CodecName: Null,
CodecOptions: codecOptions{
DeflateCompressionLevel: flate.DefaultCompression,
},
Metadata: map[string][]byte{},
EncodingConfig: avro.DefaultConfig,
SchemaCache: avro.DefaultSchemaCache,
SchemaMarshaler: DefaultSchemaMarshaler,
}
for _, opt := range opts {
opt(&cfg)
Expand Down Expand Up @@ -469,7 +491,7 @@ type ocfHeader struct {
Sync [16]byte
}

func readHeader(reader *avro.Reader, schemaCache *avro.SchemaCache) (*ocfHeader, error) {
func readHeader(reader *avro.Reader, schemaCache *avro.SchemaCache, codecOpts codecOptions) (*ocfHeader, error) {
var h Header
reader.ReadVal(HeaderSchema, &h)
if reader.Error != nil {
Expand All @@ -484,7 +506,7 @@ func readHeader(reader *avro.Reader, schemaCache *avro.SchemaCache) (*ocfHeader,
return nil, err
}

codec, err := resolveCodec(CodecName(h.Meta[codecKey]), -1)
codec, err := resolveCodec(CodecName(h.Meta[codecKey]), codecOpts)
if err != nil {
return nil, err
}
Expand Down
84 changes: 84 additions & 0 deletions ocf/ocf_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import (

"github.com/hamba/avro/v2"
"github.com/hamba/avro/v2/ocf"
"github.com/klauspost/compress/zstd"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
Expand Down Expand Up @@ -409,6 +410,52 @@ func TestDecoder_WithZStandardHandlesInvalidData(t *testing.T) {
assert.Error(t, dec.Error())
}

func TestDecoder_WithZStandardOptions(t *testing.T) {
unionStr := "union value"
want := FullRecord{
Strings: []string{"string1", "string2", "string3", "string4", "string5"},
Longs: []int64{1, 2, 3, 4, 5},
Enum: "C",
Map: map[string]int{
"ke\xa9\xb1": 1,
"\x00\x00y2": 2,
"key3": 3,
"key4": 4,
"key5": 5,
},
Nullable: &unionStr,
Fixed: [16]byte{0x01, 0x02, 0x03, 0x04, 0x01, 0x02, 0x03, 0x04, 0x01, 0x02, 0x03, 0x04, 0x01, 0x02, 0x03, 0x04},
Record: &TestRecord{
Long: 1925639126735,
String: "I am a test record",
Int: 666,
Float: 7171.17,
Double: 916734926348163.01973408746523,
Bool: true,
},
}

f, err := os.Open("testdata/zstd-invalid-data.avro")
require.NoError(t, err)
t.Cleanup(func() { _ = f.Close() })

dec, err := ocf.NewDecoder(f, ocf.WithZStandardDecoderOptions(zstd.IgnoreChecksum(true)))
require.NoError(t, err)

dec.HasNext()

var got FullRecord
err = dec.Decode(&got)

require.NoError(t, err, "should not cause an error because checksum is ignored")
require.NoError(t, dec.Error(), "should not cause an error because checksum is ignored")
assert.Equal(t, want, got, "should read corrupted data as valid because checksum is ignored")

dec.HasNext()

assert.ErrorContains(t, dec.Error(), "decoder: invalid block", "trailing byte in file should cause error before hitting zstd decoder")
}

func TestDecoder_DecodeAvroError(t *testing.T) {
data := []byte{'O', 'b', 'j', 0x01, 0x01, 0x26, 0x16, 'a', 'v', 'r', 'o', '.', 's', 'c', 'h', 'e', 'm', 'a',
0x0c, '"', 'l', 'o', 'n', 'g', '"', 0x00, 0xfb, 0x2b, 0x0f, 0x1a, 0xdd, 0xfd, 0x90, 0x7d, 0x87, 0x12,
Expand Down Expand Up @@ -878,6 +925,43 @@ func TestEncoder_EncodeCompressesZStandard(t *testing.T) {
assert.Equal(t, 951, buf.Len())
}

func TestEncoder_EncodeCompressesZStandardWithLevel(t *testing.T) {
unionStr := "union value"
record := FullRecord{
Strings: []string{"string1", "string2", "string3", "string4", "string5"},
Longs: []int64{1, 2, 3, 4, 5},
Enum: "C",
Map: map[string]int{
"key1": 1,
"key2": 2,
"key3": 3,
"key4": 4,
"key5": 5,
},
Nullable: &unionStr,
Fixed: [16]byte{0x01, 0x02, 0x03, 0x04, 0x01, 0x02, 0x03, 0x04, 0x01, 0x02, 0x03, 0x04, 0x01, 0x02, 0x03, 0x04},
Record: &TestRecord{
Long: 1925639126735,
String: "I am a test record",
Int: 666,
Float: 7171.17,
Double: 916734926348163.01973408746523,
Bool: true,
},
}

buf := &bytes.Buffer{}
enc, _ := ocf.NewEncoder(schema, buf, ocf.WithCodec(ocf.ZStandard), ocf.WithZStandardEncoderOptions(zstd.WithEncoderLevel(zstd.SpeedBestCompression)))

err := enc.Encode(record)
assert.NoError(t, err)

err = enc.Close()

require.NoError(t, err)
assert.Equal(t, 942, buf.Len())
}

func TestEncoder_EncodeError(t *testing.T) {
buf := &bytes.Buffer{}
enc, err := ocf.NewEncoder(`"long"`, buf)
Expand Down

0 comments on commit 0ebcdff

Please sign in to comment.