From 95f6ec7f32ee32a5e9016dcbcb0845e192e6aa41 Mon Sep 17 00:00:00 2001 From: Achille Date: Tue, 6 Dec 2022 20:29:48 -0800 Subject: [PATCH] Optimize parquet.MergeRowGroups (#431) * add parquet.MultiRowWriter * return io.ErrShortWrite if a writer does not write all the rows * add parquet.RowBuffer[T] * benchmark + optimize sort of parquet.RowBuffer * add tests for ordering of parquet.RowBuffer * rename: firstIndexOfRepeatedColumn => direct * fix memory leak occurring when reading rows in chunks * add parquet.SortingWriter * use WriteRowGroups instead of CopyRows * refactor sorting configuration * add dedupe reader and writer * implement drop of duplicated rows in sorting writer * fix gofmt * fix Go 1.17 compilation * add merge benchmark based on parquet.RowBuffer to reduce noise * rewrite merge row groups implementation to use row comparator function * remove unused parquet.SortFunc APIs * fix column path * fix git merge conflict resolution --- column.go | 6 +- compare.go | 112 +++++++------- config.go | 5 +- file.go | 15 ++ merge.go | 356 ++++++++++++++++++++------------------- merge_test.go | 28 ++-- row_buffer_test.go | 57 ++++++++ row_group.go | 19 +-- row_group_test.go | 26 +++- sort.go | 181 ----------------------- sort_test.go | 82 ----------- 11 files changed, 327 insertions(+), 560 deletions(-) delete mode 100644 sort.go delete mode 100644 sort_test.go diff --git a/column.go b/column.go index c030df4..1e25258 100644 --- a/column.go +++ b/column.go @@ -71,7 +71,7 @@ func (c *Column) Encoding() encoding.Encoding { return c.encoding } func (c *Column) Compression() compress.Codec { return c.compression } // Path of the column in the parquet schema. -func (c *Column) Path() []string { return c.path } +func (c *Column) Path() []string { return c.path[1:] } // Name returns the column name. 
func (c *Column) Name() string { return c.schema.Name } @@ -273,7 +273,7 @@ func (cl *columnLoader) open(file *File, path []string) (*Column, error) { file: file, schema: &file.metadata.Schema[cl.schemaIndex], } - c.path = c.path.append(c.schema.Name) + c.path = columnPath(path).append(c.schema.Name) cl.schemaIndex++ numChildren := int(c.schema.NumChildren) @@ -356,7 +356,7 @@ func (cl *columnLoader) open(file *File, path []string) (*Column, error) { } var err error - c.columns[i], err = cl.open(file, path) + c.columns[i], err = cl.open(file, c.path) if err != nil { return nil, fmt.Errorf("%s: %w", c.schema.Name, err) } diff --git a/compare.go b/compare.go index 8f6f33c..8101cb6 100644 --- a/compare.go +++ b/compare.go @@ -178,75 +178,81 @@ func lessBE128(v1, v2 *[16]byte) bool { } func compareRowsFuncOf(schema *Schema, sortingColumns []SortingColumn) func(Row, Row) int { - compareFuncs := make([]func(Row, Row) int, 0, len(sortingColumns)) + compareFuncs := make([]func(Row, Row) int, len(sortingColumns)) direct := true - for _, column := range schema.Columns() { - leaf, _ := schema.Lookup(column...) - if leaf.MaxRepetitionLevel > 0 { + forEachLeafColumnOf(schema, func(leaf leafColumn) { + if leaf.maxRepetitionLevel > 0 { direct = false } - for _, sortingColumn := range sortingColumns { - path1 := columnPath(column) - path2 := columnPath(sortingColumn.Path()) - - if path1.equal(path2) { - descending := sortingColumn.Descending() - optional := leaf.MaxDefinitionLevel > 0 - sortFunc := (func(Row, Row) int)(nil) - - if direct && !optional { - // This is an optimization for the common case where rows - // are sorted by non-optional, non-repeated columns. - // - // The sort function can make the assumption that it will - // find the column value at the current column index, and - // does not need to scan the rows looking for values with - // a matching column index. 
- // - // A second optimization consists in passing the column type - // directly to the sort function instead of an intermediary - // closure, which removes an indirection layer and improves - // throughput by ~20% in BenchmarkSortRowBuffer. - typ := leaf.Node.Type() - if descending { - sortFunc = compareRowsFuncOfIndexDescending(leaf.ColumnIndex, typ) - } else { - sortFunc = compareRowsFuncOfIndexAscending(leaf.ColumnIndex, typ) - } + if sortingIndex := searchSortingColumn(sortingColumns, leaf.path); sortingIndex < len(sortingColumns) { + sortingColumn := sortingColumns[sortingIndex] + descending := sortingColumn.Descending() + optional := leaf.maxDefinitionLevel > 0 + sortFunc := (func(Row, Row) int)(nil) + + if direct && !optional { + // This is an optimization for the common case where rows + // are sorted by non-optional, non-repeated columns. + // + // The sort function can make the assumption that it will + // find the column value at the current column index, and + // does not need to scan the rows looking for values with + // a matching column index. + // + // A second optimization consists in passing the column type + // directly to the sort function instead of an intermediary + // closure, which removes an indirection layer and improves + // throughput by ~20% in BenchmarkSortRowBuffer. 
+ typ := leaf.node.Type() + if descending { + sortFunc = compareRowsFuncOfIndexDescending(leaf.columnIndex, typ) } else { - compare := leaf.Node.Type().Compare + sortFunc = compareRowsFuncOfIndexAscending(leaf.columnIndex, typ) + } + } else { + compare := leaf.node.Type().Compare - if descending { - compare = CompareDescending(compare) - } + if descending { + compare = CompareDescending(compare) + } - if optional { - if sortingColumn.NullsFirst() { - compare = CompareNullsFirst(compare) - } else { - compare = CompareNullsLast(compare) - } + if optional { + if sortingColumn.NullsFirst() { + compare = CompareNullsFirst(compare) + } else { + compare = CompareNullsLast(compare) } - - sortFunc = compareRowsFuncOfScan(leaf.ColumnIndex, compare) } - compareFuncs = append(compareFuncs, sortFunc) + sortFunc = compareRowsFuncOfScan(leaf.columnIndex, compare) } + + compareFuncs[sortingIndex] = sortFunc + } + }) + + // When some sorting columns were not found on the schema it is possible for + // the list of compare functions to still contain nil values; we compact it + // here to keep only the columns that we found comparators for. + n := 0 + for _, f := range compareFuncs { + if f != nil { + compareFuncs[n] = f + n++ } } // For the common case where rows are sorted by a single column, we can skip // looping over the list of sort functions. 
- switch len(compareFuncs) { + switch n { case 0: return compareRowsUnordered case 1: return compareFuncs[0] default: - return compareRowsFuncOfColumns(compareFuncs) + return compareRowsFuncOfColumns(compareFuncs[:n]) } } @@ -265,28 +271,28 @@ func compareRowsFuncOfColumns(compareFuncs []func(Row, Row) int) func(Row, Row) } //go:noinline -func compareRowsFuncOfIndexAscending(columnIndex int, typ Type) func(Row, Row) int { +func compareRowsFuncOfIndexAscending(columnIndex int16, typ Type) func(Row, Row) int { return func(row1, row2 Row) int { return typ.Compare(row1[columnIndex], row2[columnIndex]) } } //go:noinline -func compareRowsFuncOfIndexDescending(columnIndex int, typ Type) func(Row, Row) int { +func compareRowsFuncOfIndexDescending(columnIndex int16, typ Type) func(Row, Row) int { return func(row1, row2 Row) int { return -typ.Compare(row1[columnIndex], row2[columnIndex]) } } //go:noinline -func compareRowsFuncOfScan(columnIndex int, compare func(Value, Value) int) func(Row, Row) int { - columnIndexOfValues := ^int16(columnIndex) +func compareRowsFuncOfScan(columnIndex int16, compare func(Value, Value) int) func(Row, Row) int { + columnIndex = ^columnIndex return func(row1, row2 Row) int { i1 := 0 i2 := 0 for { - for i1 < len(row1) && row1[i1].columnIndex != columnIndexOfValues { + for i1 < len(row1) && row1[i1].columnIndex != columnIndex { i1++ } - for i2 < len(row2) && row2[i2].columnIndex != columnIndexOfValues { + for i2 < len(row2) && row2[i2].columnIndex != columnIndex { i2++ } diff --git a/config.go b/config.go index a6d547e..d84aac0 100644 --- a/config.go +++ b/config.go @@ -385,10 +385,7 @@ func (c *SortingConfig) Apply(options ...SortingOption) { } func (c *SortingConfig) ConfigureSorting(config *SortingConfig) { - *config = SortingConfig{ - SortingColumns: coalesceSortingColumns(c.SortingColumns, config.SortingColumns), - DropDuplicatedRows: c.DropDuplicatedRows, - } + *config = coalesceSortingConfig(*c, *config) } // FileOption is an interface 
implemented by types that carry configuration diff --git a/file.go b/file.go index 8b6c078..2381d9f 100644 --- a/file.go +++ b/file.go @@ -7,6 +7,7 @@ import ( "hash/crc32" "io" "sort" + "strings" "sync" "github.com/segmentio/encoding/thrift" @@ -415,6 +416,20 @@ type fileSortingColumn struct { func (s *fileSortingColumn) Path() []string { return s.column.Path() } func (s *fileSortingColumn) Descending() bool { return s.descending } func (s *fileSortingColumn) NullsFirst() bool { return s.nullsFirst } +func (s *fileSortingColumn) String() string { + b := new(strings.Builder) + if s.nullsFirst { + b.WriteString("nulls_first+") + } + if s.descending { + b.WriteString("descending(") + } else { + b.WriteString("ascending(") + } + b.WriteString(columnPath(s.Path()).String()) + b.WriteString(")") + return b.String() +} type fileColumnChunk struct { file *File diff --git a/merge.go b/merge.go index 6e6b8dd..114fc1c 100644 --- a/merge.go +++ b/merge.go @@ -8,8 +8,8 @@ import ( type mergedRowGroup struct { multiRowGroup - sorting []SortingColumn - sortFuncs []columnSortFunc + sorting []SortingColumn + compare func(Row, Row) int } func (m *mergedRowGroup) SortingColumns() []SortingColumn { @@ -19,273 +19,225 @@ func (m *mergedRowGroup) SortingColumns() []SortingColumn { func (m *mergedRowGroup) Rows() Rows { // The row group needs to respect a sorting order; the merged row reader // uses a heap to merge rows from the row groups. 
- return &mergedRowGroupRows{rowGroup: m, schema: m.schema} + rows := make([]Rows, len(m.rowGroups)) + for i := range rows { + rows[i] = m.rowGroups[i].Rows() + } + return &mergedRowGroupRows{ + merge: mergedRowReader{ + compare: m.compare, + readers: makeBufferedRowReaders(len(rows), func(i int) RowReader { return rows[i] }), + }, + rows: rows, + schema: m.schema, + } } type mergedRowGroupRows struct { - rowGroup *mergedRowGroup - schema *Schema - sorting []columnSortFunc - cursors []rowGroupCursor - values1 []Value - values2 []Value - seek int64 - index int64 - err error + merge mergedRowReader + rowIndex int64 + seekToRow int64 + rows []Rows + schema *Schema } -func (r *mergedRowGroupRows) init(m *mergedRowGroup) { - if r.schema != nil { - numColumns := numLeafColumnsOf(r.schema) - cursors := make([]bufferedRowGroupCursor, len(m.rowGroups)) - buffers := make([][]Value, int(numColumns)*len(m.rowGroups)) +func (r *mergedRowGroupRows) Close() (lastErr error) { + r.merge.close() + r.rowIndex = 0 + r.seekToRow = 0 - for i, rowGroup := range m.rowGroups { - cursors[i].reader = rowGroup.Rows() - cursors[i].columns, buffers = buffers[:numColumns:numColumns], buffers[numColumns:] + for _, rows := range r.rows { + if err := rows.Close(); err != nil { + lastErr = err } + } - r.cursors = make([]rowGroupCursor, 0, len(cursors)) - r.sorting = m.sortFuncs - - for i := range cursors { - c := rowGroupCursor(&cursors[i]) - // TODO: this is a bit of a weak model, it only works with types - // declared in this package; we may want to define an API to allow - // applications to participate in it. - if rd, ok := cursors[i].reader.(*rowGroupRows); ok { - rd.init() - // TODO: this optimization is disabled for now, there are - // remaining blockers: - // - // * The optimized merge of pages for non-overlapping ranges is - // not yet implemented in the mergedRowGroupRows type. 
- // - // * Using pages min/max to determine overlapping ranges does - // not work for repeated columns; sorting by repeated columns - // seems to be an edge case so probably not worth optimizing, - // we still need to find a way to disable the optimization in - // those cases. - // - //c = optimizedRowGroupCursor{rd} - } - - if err := c.readNext(); err != nil { - if err == io.EOF { - continue - } - r.err = err - return - } + return lastErr +} - r.cursors = append(r.cursors, c) +func (r *mergedRowGroupRows) ReadRows(rows []Row) (int, error) { + for r.rowIndex < r.seekToRow { + n := int(r.seekToRow - r.rowIndex) + if n > len(rows) { + n = len(rows) } - - heap.Init(r) + n, err := r.merge.ReadRows(rows[:n]) + if err != nil { + return 0, err + } + r.rowIndex += int64(n) } + + return r.merge.ReadRows(rows) } func (r *mergedRowGroupRows) SeekToRow(rowIndex int64) error { - if rowIndex >= r.index { - r.seek = rowIndex + if rowIndex >= r.rowIndex { + r.seekToRow = rowIndex return nil } - return fmt.Errorf("SeekToRow: merged row reader cannot seek backward from row %d to %d", r.index, rowIndex) + return fmt.Errorf("SeekToRow: merged row reader cannot seek backward from row %d to %d", r.rowIndex, rowIndex) +} + +func (r *mergedRowGroupRows) Schema() *Schema { + return r.schema } -func (r *mergedRowGroupRows) ReadRows(rows []Row) (n int, err error) { - if r.rowGroup != nil { - r.init(r.rowGroup) - r.rowGroup = nil +func MergeRowReaders(readers []RowReader, compare func(Row, Row) int) RowReader { + return &mergedRowReader{ + compare: compare, + readers: makeBufferedRowReaders(len(readers), func(i int) RowReader { return readers[i] }), } - if r.err != nil { - return 0, r.err +} + +func makeBufferedRowReaders(numReaders int, readerAt func(int) RowReader) []*bufferedRowReader { + buffers := make([]bufferedRowReader, numReaders) + readers := make([]*bufferedRowReader, numReaders) + + for i := range readers { + buffers[i].rows = readerAt(i) + readers[i] = &buffers[i] } - for n < 
len(rows) && len(r.cursors) > 0 { - min := r.cursors[0] - r.values1, err = min.readRow(r.values1[:0]) - if err != nil { - return n, err - } + return readers +} - if r.index >= r.seek { - rows[n] = append(rows[n][:0], r.values1...) - n++ +type mergedRowReader struct { + compare func(Row, Row) int + readers []*bufferedRowReader + initialized bool +} + +func (m *mergedRowReader) initialize() error { + for i, r := range m.readers { + switch err := r.read(); err { + case nil: + case io.EOF: + m.readers[i] = nil + default: + m.readers = nil + return err } - r.index++ + } - if err := min.readNext(); err != nil { - if err != io.EOF { - r.err = err - return n, err - } - c := r.cursors[0] - heap.Pop(r) - if err := c.close(); err != nil { - r.err = err - return n, err - } - } else { - heap.Fix(r, 0) + n := 0 + for _, r := range m.readers { + if r != nil { + m.readers[n] = r + n++ } } - if n < len(rows) { - err = io.EOF + clear := m.readers[n:] + for i := range clear { + clear[i] = nil } - return n, err + m.readers = m.readers[:n] + heap.Init(m) + return nil } -func (r *mergedRowGroupRows) Close() error { - var lastErr error - for i := range r.cursors { - if err := r.cursors[i].close(); err != nil { - lastErr = err - } +func (m *mergedRowReader) close() { + for _, r := range m.readers { + r.close() } - return lastErr + m.readers = nil } -// func (r *mergedRowGroupRows) WriteRowsTo(w RowWriter) (int64, error) { -// if r.rowGroup != nil { -// defer func() { r.rowGroup = nil }() -// switch dst := w.(type) { -// case RowGroupWriter: -// return dst.WriteRowGroup(r.rowGroup) -// case pageAndValueWriter: -// r.init(r.rowGroup) -// return r.writeRowsTo(dst) -// } -// } -// return CopyRows(w, struct{ RowReaderWithSchema }{r}) -// } - -func (r *mergedRowGroupRows) writeRowsTo(w pageAndValueWriter) (numRows int64, err error) { - // TODO: the intent of this method is to optimize the merge of rows by - // copying entire pages instead of individual rows when we detect ranges - // that don't 
overlap. - return -} +func (m *mergedRowReader) ReadRows(rows []Row) (n int, err error) { + if !m.initialized { + m.initialized = true -func (r *mergedRowGroupRows) Schema() *Schema { - return r.schema -} + if err := m.initialize(); err != nil { + return 0, err + } + } -func (r *mergedRowGroupRows) Len() int { - return len(r.cursors) -} + for n < len(rows) && len(m.readers) != 0 { + r := m.readers[0] -func (r *mergedRowGroupRows) Less(i, j int) bool { - cursor1 := r.cursors[i] - cursor2 := r.cursors[j] - - for _, sorting := range r.sorting { - r.values1 = cursor1.nextRowValuesOf(r.values1[:0], sorting.columnIndex) - r.values2 = cursor2.nextRowValuesOf(r.values2[:0], sorting.columnIndex) - comp := sorting.compare(r.values1, r.values2) - switch { - case comp < 0: - return true - case comp > 0: - return false + rows[n] = append(rows[n][:0], r.head()...) + n++ + + if err := r.next(); err != nil { + if err != io.EOF { + return n, err + } + heap.Pop(m) + } else { + heap.Fix(m, 0) } } - return false -} + if len(m.readers) == 0 { + err = io.EOF + } -func (r *mergedRowGroupRows) Swap(i, j int) { - r.cursors[i], r.cursors[j] = r.cursors[j], r.cursors[i] + return n, err } -func (r *mergedRowGroupRows) Push(interface{}) { - panic("BUG: unreachable") +func (m *mergedRowReader) Less(i, j int) bool { + return m.compare(m.readers[i].head(), m.readers[j].head()) < 0 } -func (r *mergedRowGroupRows) Pop() interface{} { - n := len(r.cursors) - 1 - c := r.cursors[n] - r.cursors = r.cursors[:n] - return c +func (m *mergedRowReader) Len() int { + return len(m.readers) } -type rowGroupCursor interface { - close() error - readRow(Row) (Row, error) - readNext() error - nextRowValuesOf([]Value, int16) []Value +func (m *mergedRowReader) Swap(i, j int) { + m.readers[i], m.readers[j] = m.readers[j], m.readers[i] } -type columnSortFunc struct { - columnIndex int16 - compare SortFunc +func (m *mergedRowReader) Push(x interface{}) { + panic("NOT IMPLEMENTED") } -type bufferedRowGroupCursor struct 
{ - reader Rows - rowbuf [1]Row - columns [][]Value +func (m *mergedRowReader) Pop() interface{} { + i := len(m.readers) - 1 + r := m.readers[i] + m.readers = m.readers[:i] + return r } -func (cur *bufferedRowGroupCursor) close() error { - return cur.reader.Close() +type bufferedRowReader struct { + rows RowReader + off int32 + end int32 + buf [10]Row } -func (cur *bufferedRowGroupCursor) readRow(row Row) (Row, error) { - return append(row, cur.rowbuf[0]...), nil +func (r *bufferedRowReader) head() Row { + return r.buf[r.off] } -func (cur *bufferedRowGroupCursor) readNext() error { - _, err := cur.reader.ReadRows(cur.rowbuf[:]) - if err != nil { - return err - } - for i, c := range cur.columns { - cur.columns[i] = c[:0] - } - for _, v := range cur.rowbuf[0] { - columnIndex := v.Column() - cur.columns[columnIndex] = append(cur.columns[columnIndex], v) +func (r *bufferedRowReader) next() error { + if r.off++; r.off == r.end { + r.off = 0 + r.end = 0 + return r.read() } return nil } -func (cur *bufferedRowGroupCursor) nextRowValuesOf(values []Value, columnIndex int16) []Value { - return append(values, cur.columns[columnIndex]...) 
-} - -/* -type optimizedRowGroupCursor struct{ *rowGroupRows } - -func (cur optimizedRowGroupCursor) readRow(row Row) (Row, error) { return cur.ReadRow(row) } - -func (cur optimizedRowGroupCursor) readNext() error { - for i := range cur.columns { - c := &cur.columns[i] - if c.buffered() == 0 { - if err := c.readPage(); err != nil { - return err - } - } +func (r *bufferedRowReader) read() error { + if r.rows == nil { + return io.EOF + } + n, err := r.rows.ReadRows(r.buf[r.end:]) + if err != nil && n == 0 { + return err } + r.end += int32(n) return nil } -func (cur optimizedRowGroupCursor) nextRowValuesOf(values []Value, columnIndex int16) []Value { - col := &cur.columns[columnIndex] - err := col.readValues() - if err != nil { - values = append(values, Value{}) - } else { - values = append(values, col.buffer[col.offset]) - } - return values +func (r *bufferedRowReader) close() { + r.rows = nil + r.off = 0 + r.end = 0 } -*/ var ( _ RowReaderWithSchema = (*mergedRowGroupRows)(nil) - //_ RowWriterTo = (*mergedRowGroupRows)(nil) ) diff --git a/merge_test.go b/merge_test.go index 4a832d7..ab8cbe5 100644 --- a/merge_test.go +++ b/merge_test.go @@ -37,7 +37,6 @@ func (r *wrappedRows) Close() error { } func TestMergeRowGroupsCursorsAreClosed(t *testing.T) { - type model struct { A int } @@ -110,7 +109,7 @@ func BenchmarkMergeRowGroups(b *testing.B) { for n := 1; n <= numRowGroups; n++ { b.Run(fmt.Sprintf("groups=%d,rows=%d", n, n*rowsPerGroup), func(b *testing.B) { - mergedRowGroup, err := parquet.MergeRowGroups(rowGroups[:n]) + mergedRowGroup, err := parquet.MergeRowGroups(rowGroups[:n], options...) 
if err != nil { b.Fatal(err) } @@ -143,14 +142,20 @@ func BenchmarkMergeFiles(b *testing.B) { b.Run(test.scenario, func(b *testing.B) { schema := parquet.SchemaOf(test.model) - buffer := parquet.NewBuffer( + sortingOptions := []parquet.SortingOption{ + parquet.SortingColumns( + parquet.Ascending(schema.Columns()[0]...), + ), + } + + options := []parquet.RowGroupOption{ schema, parquet.SortingRowGroupConfig( - parquet.SortingColumns( - parquet.Ascending(schema.Columns()[0]...), - ), + sortingOptions..., ), - ) + } + + buffer := parquet.NewBuffer(options...) prng := rand.New(rand.NewSource(0)) files := make([]*parquet.File, numRowGroups) @@ -162,7 +167,12 @@ func BenchmarkMergeFiles(b *testing.B) { } sort.Sort(buffer) rowGroupBuffers[i].Reset() - writer := parquet.NewWriter(&rowGroupBuffers[i]) + writer := parquet.NewWriter(&rowGroupBuffers[i], + schema, + parquet.SortingWriterConfig( + sortingOptions..., + ), + ) _, err := copyRowsAndClose(writer, buffer.Rows()) if err != nil { b.Fatal(err) @@ -180,7 +190,7 @@ func BenchmarkMergeFiles(b *testing.B) { for n := 1; n <= numRowGroups; n++ { b.Run(fmt.Sprintf("groups=%d,rows=%d", n, n*rowsPerGroup), func(b *testing.B) { - mergedRowGroup, err := parquet.MergeRowGroups(rowGroups[:n]) + mergedRowGroup, err := parquet.MergeRowGroups(rowGroups[:n], options...) 
if err != nil { b.Fatal(err) } diff --git a/row_buffer_test.go b/row_buffer_test.go index b8184e6..e093085 100644 --- a/row_buffer_test.go +++ b/row_buffer_test.go @@ -300,3 +300,60 @@ func BenchmarkSortRowBuffer(b *testing.B) { sort.Sort(buf) } } + +func BenchmarkMergeRowBuffers(b *testing.B) { + type Row struct { + ID int64 `parquet:"id"` + } + + const ( + numBuffers = 100 + numRowsPerBuffer = 10e3 + ) + + rows := [numBuffers][numRowsPerBuffer]Row{} + nextID := int64(0) + for i := 0; i < numRowsPerBuffer; i++ { + for j := 0; j < numBuffers; j++ { + rows[j][i].ID = nextID + nextID++ + } + } + + options := []parquet.RowGroupOption{ + parquet.SortingRowGroupConfig( + parquet.SortingColumns( + parquet.Ascending("id"), + ), + ), + } + + rowGroups := make([]parquet.RowGroup, numBuffers) + for i := range rowGroups { + buffer := parquet.NewRowBuffer[Row](options...) + buffer.Write(rows[i][:]) + rowGroups[i] = buffer + } + + merge, err := parquet.MergeRowGroups(rowGroups, options...) + if err != nil { + b.Fatal(err) + } + + b.ResetTimer() + + for i := 0; i < b.N; i++ { + rows := merge.Rows() + _, err := parquet.CopyRows(discardRows{}, rows) + rows.Close() + if err != nil { + b.Fatal(err) + } + } +} + +type discardRows struct{} + +func (discardRows) WriteRows(rows []parquet.Row) (int, error) { + return len(rows), nil +} diff --git a/row_group.go b/row_group.go index b2cb141..d86e12b 100644 --- a/row_group.go +++ b/row_group.go @@ -213,24 +213,7 @@ func MergeRowGroups(rowGroups []RowGroup, options ...RowGroupOption) (RowGroup, } } - m.sortFuncs = make([]columnSortFunc, len(m.sorting)) - forEachLeafColumnOf(schema, func(leaf leafColumn) { - if sortingIndex := searchSortingColumn(m.sorting, leaf.path); sortingIndex < len(m.sorting) { - m.sortFuncs[sortingIndex] = columnSortFunc{ - columnIndex: leaf.columnIndex, - compare: sortFuncOf( - leaf.node.Type(), - &SortConfig{ - MaxRepetitionLevel: int(leaf.maxRepetitionLevel), - MaxDefinitionLevel: int(leaf.maxDefinitionLevel), - 
Descending: m.sorting[sortingIndex].Descending(), - NullsFirst: m.sorting[sortingIndex].NullsFirst(), - }, - ), - } - } - }) - + m.compare = compareRowsFuncOf(schema, m.sorting) return m, nil } diff --git a/row_group_test.go b/row_group_test.go index 1c3756c..cf3940e 100644 --- a/row_group_test.go +++ b/row_group_test.go @@ -420,21 +420,31 @@ func TestMergeRowGroups(t *testing.T) { for { _, err1 := expectedRows.ReadRows(row1) - _, err2 := mergedRows.ReadRows(row2) + n, err2 := mergedRows.ReadRows(row2) if err1 != err2 { - t.Fatalf("errors mismatched while comparing row %d/%d: want=%v got=%v", numRows, totalRows, err1, err2) + // ReadRows may or may not return io.EOF + // when it reads the last row, so we test + // that the reference RowReader has also + // reached the end. + if err1 == nil && err2 == io.EOF { + _, err1 = expectedRows.ReadRows(row1[:0]) + } + if err1 != io.EOF { + t.Fatalf("errors mismatched while comparing row %d/%d: want=%v got=%v", numRows, totalRows, err1, err2) + } } - if err1 != nil { - break + if n != 0 { + if !row1[0].Equal(row2[0]) { + t.Errorf("row at index %d/%d mismatch: want=%+v got=%+v", numRows, totalRows, row1[0], row2[0]) + } + numRows++ } - if !row1[0].Equal(row2[0]) { - t.Errorf("row at index %d/%d mismatch: want=%+v got=%+v", numRows, totalRows, row1[0], row2[0]) + if err1 != nil { + break } - - numRows++ } if numRows != totalRows { diff --git a/sort.go b/sort.go deleted file mode 100644 index 465d1ca..0000000 --- a/sort.go +++ /dev/null @@ -1,181 +0,0 @@ -package parquet - -// The SortConfig type carries configuration options used to generate sorting -// functions. 
-// -// SortConfig implements the SortOption interface so it can be used directly as -// argument to the SortFuncOf function, for example: -// -// sortFunc := parquet.SortFuncOf(columnType, &parquet.SortConfig{ -// Descending: true, -// NullsFirst: true, -// }) -type SortConfig struct { - MaxRepetitionLevel int - MaxDefinitionLevel int - Descending bool - NullsFirst bool -} - -// Apply applies options to c. -func (c *SortConfig) Apply(options ...SortOption) { - for _, opt := range options { - opt.ConfigureSort(c) - } -} - -// ConfigureSort satisfies the SortOption interface. -func (c *SortConfig) ConfigureSort(config *SortConfig) { - *c = *config -} - -// SortMaxRepetitionLevel constructs a configuration option which sets the -// maximum repetition level known to a sorting function. -// -// Defaults to zero, which represents a non-repeated column. -func SortMaxRepetitionLevel(level int) SortOption { - return sortOption(func(c *SortConfig) { c.MaxRepetitionLevel = level }) -} - -// SortMaxDefinitionLevel constructs a configuration option which sets the -// maximum definition level known to a sorting function. -// -// Defaults to zero, which represents a non-nullable column. -func SortMaxDefinitionLevel(level int) SortOption { - return sortOption(func(c *SortConfig) { c.MaxDefinitionLevel = level }) -} - -// SortDescending constructs a configuration option which inverts the order of a -// sorting function. -// -// Defaults to false, which means values are sorted in ascending order. -func SortDescending(descending bool) SortOption { - return sortOption(func(c *SortConfig) { c.Descending = descending }) -} - -// SortNullsFirst constructs a configuration option which places the null values -// first or last. -// -// Defaults to false, which means null values are placed last. 
-func SortNullsFirst(nullsFirst bool) SortOption { - return sortOption(func(c *SortConfig) { c.NullsFirst = nullsFirst }) -} - -// SortOption is an interface implemented by types that carry configuration -// options for sorting functions. -type SortOption interface { - ConfigureSort(*SortConfig) -} - -type sortOption func(*SortConfig) - -func (f sortOption) ConfigureSort(c *SortConfig) { f(c) } - -// SortFunc is a function type which compares two sets of column values. -// -// Slices with exactly one value must be passed to the function when comparing -// values of non-repeated columns. For repeated columns, there may be zero or -// more values in each slice, and the parameters may have different lengths. -// -// SortFunc is a low-level API which is usually useful to construct customize -// implementations of the RowGroup interface. -type SortFunc func(a, b []Value) int - -// SortFuncOf constructs a sorting function for values of the given type. -// -// The list of options contains the configuration used to construct the sorting -// function. -func SortFuncOf(t Type, options ...SortOption) SortFunc { - config := new(SortConfig) - config.Apply(options...) 
- return sortFuncOf(t, config) -} - -func sortFuncOf(t Type, config *SortConfig) (sort SortFunc) { - sort = sortFuncOfRequired(t) - - if config.Descending { - sort = sortFuncOfDescending(sort) - } - - switch { - case makeRepetitionLevel(config.MaxRepetitionLevel) > 0: - sort = sortFuncOfRepeated(sort, config) - case makeDefinitionLevel(config.MaxDefinitionLevel) > 0: - sort = sortFuncOfOptional(sort, config) - } - - return sort -} - -//go:noinline -func sortFuncOfDescending(sort SortFunc) SortFunc { - return func(a, b []Value) int { return -sort(a, b) } -} - -func sortFuncOfOptional(sort SortFunc, config *SortConfig) SortFunc { - if config.NullsFirst { - return sortFuncOfOptionalNullsFirst(sort) - } else { - return sortFuncOfOptionalNullsLast(sort) - } -} - -//go:noinline -func sortFuncOfOptionalNullsFirst(sort SortFunc) SortFunc { - return func(a, b []Value) int { - switch { - case a[0].IsNull(): - if b[0].IsNull() { - return 0 - } - return -1 - case b[0].IsNull(): - return +1 - default: - return sort(a, b) - } - } -} - -//go:noinline -func sortFuncOfOptionalNullsLast(sort SortFunc) SortFunc { - return func(a, b []Value) int { - switch { - case a[0].IsNull(): - if b[0].IsNull() { - return 0 - } - return +1 - case b[0].IsNull(): - return -1 - default: - return sort(a, b) - } - } -} - -//go:noinline -func sortFuncOfRepeated(sort SortFunc, config *SortConfig) SortFunc { - sort = sortFuncOfOptional(sort, config) - return func(a, b []Value) int { - n := len(a) - if n > len(b) { - n = len(b) - } - - for i := 0; i < n; i++ { - k := sort(a[i:i+1], b[i:i+1]) - if k != 0 { - return k - } - } - - return len(a) - len(b) - } -} - -//go:noinline -func sortFuncOfRequired(t Type) SortFunc { - return func(a, b []Value) int { return t.Compare(a[0], b[0]) } -} diff --git a/sort_test.go b/sort_test.go deleted file mode 100644 index 36f8be0..0000000 --- a/sort_test.go +++ /dev/null @@ -1,82 +0,0 @@ -package parquet_test - -import ( - "sort" - "testing" - - "github.com/google/uuid" - 
"github.com/segmentio/parquet-go" -) - -func TestSortFunc(t *testing.T) { - sortFunc := parquet.SortFuncOf(parquet.String().Type(), - parquet.SortMaxDefinitionLevel(1), - parquet.SortDescending(true), - parquet.SortNullsFirst(true), - ) - - values := [][]parquet.Value{ - {parquet.ValueOf("A")}, - {parquet.ValueOf(nil)}, - {parquet.ValueOf(nil)}, - {parquet.ValueOf("C")}, - {parquet.ValueOf("B")}, - {parquet.ValueOf(nil)}, - } - - expect := [][]parquet.Value{ - {parquet.ValueOf(nil)}, - {parquet.ValueOf(nil)}, - {parquet.ValueOf(nil)}, - {parquet.ValueOf("C")}, - {parquet.ValueOf("B")}, - {parquet.ValueOf("A")}, - } - - sort.Slice(values, func(i, j int) bool { - return sortFunc(values[i], values[j]) < 0 - }) - - for i := range values { - if !parquet.Equal(values[i][0], expect[i][0]) { - t.Errorf("value at index %d mismatch: got=%+v want=%+v\n%+v\n%+v", i, expect[i], values[i], expect, values) - break - } - } -} - -func TestRepeatedUUIDSortFunc(t *testing.T) { - type testStruct struct { - List []uuid.UUID `parquet:"list"` - } - - s := parquet.SchemaOf(&testStruct{}) - - a := s.Deconstruct(nil, &testStruct{ - List: []uuid.UUID{ - uuid.MustParse("00000000-0000-0000-0000-000000000001"), - uuid.MustParse("00000000-0000-0000-0000-000000000002"), - }, - }) - - b := s.Deconstruct(nil, &testStruct{ - List: []uuid.UUID{ - uuid.MustParse("00000000-0000-0000-0000-000000000001"), - uuid.MustParse("00000000-0000-0000-0000-000000000002"), - uuid.MustParse("00000000-0000-0000-0000-000000000003"), - }, - }) - - // a and b are equal up until the third element, then a ends so a < b. - f := parquet.SortFuncOf( - s.Fields()[0].Type(), - parquet.SortDescending(false), - parquet.SortNullsFirst(true), - parquet.SortMaxDefinitionLevel(1), - parquet.SortMaxRepetitionLevel(1), - ) - cmp := f(a, b) - if cmp >= 0 { - t.Fatal("expected a < b, got compare value", cmp) - } -}