Reorganize Parquet code (facebookresearch#561)
cbalioglu authored Jul 19, 2024
1 parent 6d07467 commit 06e2690
Showing 5 changed files with 351 additions and 346 deletions.
2 changes: 2 additions & 0 deletions .github/CODEOWNERS
@@ -1 +1,3 @@
 * @cbalioglu
+src/fairseq2/data/parquet/ @artemru
+tests/integration/parquet/ @artemru
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -35,7 +35,7 @@ implicit_reexport = true

# TODO: fix!
[[tool.mypy.overrides]]
module = "fairseq2.data.parquet_tools.*,tests.integration.parquet.*,recipes.parquet.*"
module = "fairseq2.data.parquet.tools,tests.integration.parquet.*"
ignore_errors = true

[tool.pytest.ini_options]
2 changes: 1 addition & 1 deletion src/fairseq2/data/parquet/dataloader.py
@@ -13,7 +13,7 @@
import pyarrow.parquet as pq

from fairseq2.data.data_pipeline import DataPipeline, DataPipelineBuilder
-from fairseq2.data.parquet_tools import (
+from fairseq2.data.parquet.tools import (
BatchOutputType,
_TableWrapper,
_to_real_object,
344 changes: 344 additions & 0 deletions src/fairseq2/data/parquet/tools.py
@@ -0,0 +1,344 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from contextlib import contextmanager
from typing import Dict, Generator, List, Optional, Union

import numpy as np
import pandas as pd
import pyarrow as pa
import pyarrow.parquet as pq
import torch
from numpy.typing import NDArray
from pyarrow.dataset import get_partition_keys # requires pyarrow >= 13

from fairseq2.data.data_pipeline import DataPipeline, DataPipelineBuilder, read_sequence


@contextmanager
def pyarrow_cpu(nb_cpu: int) -> Generator[None, None, None]:
nb_cpu_old = pa.cpu_count()
nb_io_cpu_old = pa.io_thread_count()
pa.set_cpu_count(nb_cpu)
pa.set_io_thread_count(nb_cpu)
try:
yield
finally:
pa.set_cpu_count(nb_cpu_old)
pa.set_io_thread_count(nb_io_cpu_old)


@contextmanager
def torch_random_seed(seed: Optional[int] = None) -> Generator[None, None, None]:
if seed is not None:
torch.manual_seed(seed)
yield
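
# Usage sketch (illustrative, hypothetical path): the two context managers above
# can be combined to cap the number of pyarrow threads and make a loading block
# reproducible; the previous thread counts are restored on exit.
#
#     with pyarrow_cpu(4), torch_random_seed(123):
#         table = pq.read_table("some/dataset.parquet")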


NestedDict = Dict[str, "NestedDictValue"]
NestedDictValue = Union[torch.Tensor, List[str], pd.Series, NestedDict]
BatchOutputType = Union[pa.Table, pd.DataFrame, NestedDict]


def from_pyarrow_to_torch_tensor(
arr: Union[pa.Array, pa.ChunkedArray], strict: bool = True
) -> NestedDictValue:
"""
struct_array = pa.Array.from_pandas([{"x": 4, "y": "RR"}] * 10)
nest_array = pa.Array.from_pandas([[{'a': 1}, {'a': 2}]])
"""
# for future ideas https://arrow.apache.org/docs/python/generated/pyarrow.Tensor.html
# for sparse matrix support https://github.com/apache/arrow/blob/main/python/pyarrow/tests/test_sparse_tensor.py

if arr.null_count != 0:
raise ValueError("to torch conversion does not support null values")

if isinstance(arr, pa.ChunkedArray):
arr = arr.chunks[0] if arr.num_chunks == 1 else arr.combine_chunks()

arr_type = arr.type
if pa.types.is_primitive(arr_type):
return torch.from_numpy(arr.to_numpy(zero_copy_only=True))

try:
return torch.from_numpy(arr.to_numpy(zero_copy_only=True))
except pa.ArrowInvalid:
pass

if pa.types.is_dictionary(arr_type):
return from_pyarrow_to_torch_tensor(arr.dictionary_decode())

if pa.types.is_string(arr_type):
return list(map(str, arr.to_pandas()))

if (
pa.types.is_list(arr_type) or pa.types.is_large_list(arr_type)
) and pa.types.is_primitive(arr_type.value_type):
return torch.nested.as_nested_tensor(
list(map(torch.from_numpy, arr.to_pandas()))
)

if pa.types.is_fixed_size_list(arr_type) and pa.types.is_primitive(
arr_type.value_type
):
return torch.from_numpy(np.reshape(arr.values, (-1, arr_type.list_size)))

if pa.types.is_struct(arr_type):
return {
arr_type.field(i).name: from_pyarrow_to_torch_tensor(arr.field(i))
for i in range(arr_type.num_fields)
}

if pa.types.is_nested(arr_type):
# TODO: deal with arr = [[{'a': 1}, {'a': 2}]]
pass

if strict:
raise NotImplementedError(f"{arr_type} cannot be converted to torch.Tensor")
else:
return arr


def pyarrow_table_to_torch_dict(tt: pa.Table, strict: bool = True) -> NestedDict:
return {
col: from_pyarrow_to_torch_tensor(tt[col], strict) for col in tt.column_names
}
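
# Illustrative sketch of the converters above, using only in-memory data:
#
#     table = pa.table(
#         {
#             "score": [1, 2, 3],
#             "text": ["a", "bc", "def"],
#             "tokens": [[1, 2], [3], [4, 5, 6]],
#         }
#     )
#     out = pyarrow_table_to_torch_dict(table)
#     # out["score"]  -> torch.Tensor of shape (3,)
#     # out["text"]   -> ["a", "bc", "def"]
#     # out["tokens"] -> nested (ragged) torch tensor with rows of length 2, 1, 3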


def init_parquet_dataset(
parquet_path: str,
filters: Optional[pa.dataset.Expression] = None,
filesystem: Optional[pa.fs.FileSystem] = None,
) -> pq.ParquetDataset:
return pq.ParquetDataset(parquet_path, filters=filters, filesystem=filesystem)


def get_dataset_fragments(
dataset: pq.ParquetDataset, filters: pa.dataset.Expression
) -> List[pa.dataset.Fragment]:
"""
This could be simplified once `split_row_groups=True` is implemented at `pq.ParquetDataset`.
We could also return a generator instead of list (when getting full infos from S3 may be slow)
"""
return list(dataset._dataset.get_fragments(filters))


def split_fragment_in_row_groups(
fragment: pa.dataset.Fragment,
) -> List[pa.dataset.Fragment]:
return list(fragment.split_by_row_group())


def add_partitioning_values(
table: pa.Table, fragment: pa.dataset.Fragment, columns: Optional[List[str]]
) -> pa.Table:
"""
When loading a single fragment, pyarrow does not add the partitioning columns,
so we need to do it manually.
"""
for key, val in get_partition_keys(fragment.partition_expression).items():
if columns is None or key in columns:
values = pa.DictionaryArray.from_arrays(
np.zeros(len(table), dtype=np.int32), [val]
)
table = table.append_column(key, values)
return table


def load_one_fragment(
fragment: pa.dataset.Fragment, columns: Optional[List[str]] = None
) -> pa.Table:
fragment_columns = columns
if fragment_columns is not None:
fragment_columns = [
col for col in fragment_columns if col in fragment.physical_schema.names
]
fragment_table = fragment.to_table(columns=fragment_columns, use_threads=False)
fragment_table = add_partitioning_values(fragment_table, fragment, columns)
return fragment_table


def apply_filter(
table: pa.Table,
filters: Optional[pa.dataset.Expression] = None,
drop_null: bool = True,
) -> pa.Table:
if drop_null:
table = table.drop_null()
if filters is not None:
table = table.filter(filters)
return table


def concat_table(tables: List[pa.Table], combine: bool = True) -> pa.Table:
result = pa.concat_tables(
tables,
promote_options="permissive",  # needed to deal with empty segments
)
if combine:
result = result.combine_chunks()
return result


def compute_length_splits(
length_col: NDArray[np.int32], max_tokens: int
) -> List[NDArray[np.int32]]:
"""split sequence of length_col in the chunks such that total length is ~ max_tokens
countint the padding to max length of elements in a chunk
Args:
length_col (np.ndarray):
max_tokens (int):
Returns:
List[np.ndarray]: splits that contain indices over the original length_col
"""
argsort_ind = np.argsort(length_col)
# TODO: remove 0 lengths
sorted_length_col = length_col[argsort_ind]

splits = []
ptr = 0
for i, length in enumerate(sorted_length_col):
if length * (i - ptr) > max_tokens:
splits.append(argsort_ind[ptr : (i - 1)])
ptr = i - 1
if (
length <= max_tokens
):  # drop the trailing chunk if even its longest element exceeds max_tokens
splits.append(argsort_ind[ptr:])
return splits
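
# Worked example (illustrative): with lengths [6, 2, 4, 3] and max_tokens=12,
# the lengths are sorted to [2, 3, 4, 6] and cut whenever padding every element
# of the current chunk to the incoming length would exceed the budget:
#
#     compute_length_splits(np.array([6, 2, 4, 3], dtype=np.int32), max_tokens=12)
#     # -> [array([1, 3]), array([2, 0])]
#     # i.e. the rows with lengths {2, 3} form one chunk and {4, 6} the other
#     # (2 elements padded to length 6 = 12 tokens).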


def compute_rows_length(pa_array: pa.Array) -> NDArray[np.int32]:
type_ = pa_array.type
if pa.types.is_list(type_) or pa.types.is_large_list(type_):
length_col = pa.compute.list_value_length(pa_array).to_numpy()
elif pa.types.is_string(type_):
length_col = pa.compute.utf8_length(pa_array).to_numpy()
else:
length_col = np.asarray(pa_array.to_pandas().apply(len))

length_col = length_col.copy()
length_col[np.isnan(length_col)] = 0
return np.asarray(length_col, dtype=np.int32)


class _TableWrapper:
"""
class to avoid fairseq2 casting pa.Table to iterable objects
which currently fails
"""

def __init__(self, table: pa.Table) -> None:
self.table: pa.Table = table


def _to_real_object(x: Union[_TableWrapper, NestedDict]) -> BatchOutputType:
if isinstance(x, _TableWrapper):
return x.table
elif isinstance(x, list):
return [_to_real_object(e) for e in x]
elif isinstance(x, tuple):
return tuple(_to_real_object(e) for e in x)
else:
return x


def table_func_wrap(func): # type: ignore
def inner(*args): # type: ignore
fixed_args = [_to_real_object(x) for x in args]
result = func(*fixed_args)
if isinstance(result, (pa.Table, pd.DataFrame)):
result = _TableWrapper(result)
return result

return inner
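
# Illustrative note: any function that takes or returns a pa.Table should be
# wrapped with table_func_wrap before being passed to DataPipelineBuilder.map,
# so that tables travel through the pipeline as _TableWrapper objects:
#
#     builder = builder.map(table_func_wrap(lambda table: table.drop_null()))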


def list_parquet_fragments(
parquet_path: str,
filters: Optional[pa.dataset.Expression] = None,
columns: Optional[List[str]] = None,
split_to_row_groups: bool = True,
filesystem: Optional[pa.fs.FileSystem] = None,
shuffle_window: Optional[int] = None,
seed: int = 2,
) -> DataPipelineBuilder:
dataset = init_parquet_dataset(parquet_path, filters=filters, filesystem=filesystem)
columns = columns or dataset.schema.names
if not set(columns).issubset(set(dataset.schema.names)):
raise ValueError(
f"columns {sorted(set(columns) - set(dataset.schema.names))} are not found in the dataset schema"
)

pipeline_builder = read_sequence(get_dataset_fragments(dataset, filters))

if shuffle_window is not None:
# shuffle all fragments in memory (a window of 0 means a full shuffle) since they are already listed
pipeline_builder = pipeline_builder.shuffle(shuffle_window=0, seed=seed)

if split_to_row_groups:
pipeline_builder = pipeline_builder.yield_from(
lambda fragment: read_sequence(
split_fragment_in_row_groups(fragment)
).and_return()
)
if shuffle_window is not None:
pipeline_builder = pipeline_builder.shuffle(
shuffle_window=shuffle_window, seed=seed + 1
)

return pipeline_builder
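
# Usage sketch (illustrative, hypothetical path and column names): list the
# row-group fragments of a dataset, shuffled within a window of 100, and load
# them one by one:
#
#     pipeline = list_parquet_fragments(
#         "path/to/dataset",
#         columns=["text", "lang"],
#         split_to_row_groups=True,
#         shuffle_window=100,
#     ).and_return()
#     for fragment in pipeline:
#         table = load_one_fragment(fragment, columns=["text", "lang"])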


def build_iterator_over_one_table(
table: pa.Table,
order_by_length: Optional[str] = None,
batch_size: Optional[int] = None,
max_tokens: Optional[int] = None,
shuffle: bool = True,
seed: Optional[int] = None,
num_parallel_calls: int = 8,
) -> DataPipeline:
random_state = np.random.RandomState(seed)
if order_by_length is not None:
length_col = compute_rows_length(table[order_by_length])
# add a small perturbation so the same samples do not appear together across epochs
if shuffle:
perturbation = random_state.randint(
0,
np.quantile(length_col, 0.001).astype(np.int32) + 2,
len(length_col),
)
length_col += np.asarray(perturbation, dtype=np.int32)
else:
if shuffle:
length_col = random_state.randint(0, 2**23, len(table))
else:
length_col = np.zeros(len(table), dtype=np.int32)

if batch_size is not None:
order_tt = pa.Table.from_arrays(
[pa.array(np.argsort(length_col, kind="stable"))], ["order"]
)
batches = [ind["order"] for ind in order_tt.to_batches(batch_size)]
elif max_tokens is not None:
batches = compute_length_splits(length_col, max_tokens)
else:
raise ValueError("unknown batching method")

if shuffle:
batches = [batches[i] for i in random_state.permutation(len(batches))]

return (
read_sequence(batches)
.map(
table_func_wrap(lambda ind: table.take(ind).combine_chunks()),
num_parallel_calls=num_parallel_calls,
)
.and_return(max_num_warnings=4)
)
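
# Usage sketch (illustrative): batch a small in-memory table by a token budget,
# ordering rows by the length of the "tokens" column; batches come back wrapped
# in _TableWrapper and can be unwrapped with _to_real_object:
#
#     table = pa.table({"tokens": [[1, 2], [3], [4, 5, 6], [7, 8]]})
#     pipeline = build_iterator_over_one_table(
#         table, order_by_length="tokens", max_tokens=6, shuffle=False, seed=0
#     )
#     for batch in pipeline:
#         batch_table = _to_real_object(batch)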