Skip to content

Commit

Permalink
Add in-place ArrowBlockAccessor::random_shuffle (ray-project#50594)
Browse files Browse the repository at this point in the history
## Why are these changes needed?
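Rework `ArrowBlockAccessor.random_shuffle` to build a row-index array, shuffle it in place with `RandomState.shuffle`, and return early for empty blocks.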
Addresses ray-project#42146

---------

Signed-off-by: Srinath Krishnamachari <[email protected]>
Signed-off-by: srinathk10 <[email protected]>
  • Loading branch information
srinathk10 authored Feb 15, 2025
1 parent f0942dd commit e18a59a
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 5 deletions.
12 changes: 7 additions & 5 deletions python/ray/data/_internal/arrow_block.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,12 +218,14 @@ def slice(self, start: int, end: int, copy: bool = False) -> "pyarrow.Table":
return view

def random_shuffle(self, random_seed: Optional[int]) -> "pyarrow.Table":
    """Return the rows of this block in a random order.

    Args:
        random_seed: Seed handed to ``np.random.RandomState``. The same seed
            yields the same ordering for a block of a given size.

    Returns:
        The result of ``self.take`` on the shuffled row indices, or
        ``pyarrow.table([])`` when the block has no rows.
    """
    num_rows = self.num_rows()
    if num_rows == 0:
        # NOTE(review): this empty table is built from scratch rather than
        # from the underlying table, so it presumably does not carry this
        # block's schema -- confirm callers do not rely on it.
        return pyarrow.table([])
    random = np.random.RandomState(random_seed)
    shuffled_indices = np.arange(num_rows)
    # Shuffle the index array in place, then select rows by index.
    random.shuffle(shuffled_indices)
    return self.take(pyarrow.array(shuffled_indices))

def schema(self) -> "pyarrow.lib.Schema":
    """Return the ``schema`` attribute of the table held in ``self._table``."""
    table = self._table
    return table.schema
Expand Down
25 changes: 25 additions & 0 deletions python/ray/data/tests/test_arrow_block.py
Original file line number Diff line number Diff line change
Expand Up @@ -275,6 +275,31 @@ def test_append_column(ray_start_regular_shared):
assert actual_block.equals(expected_block)


def test_random_shuffle(ray_start_regular_shared):
    """Check that an unseeded shuffle reorders rows without losing any."""
    TOTAL_ROWS = 10000
    expected_ids = list(range(TOTAL_ROWS))
    source_table = pa.table({"id": pa.array(range(TOTAL_ROWS))})

    # Shuffle through the accessor and confirm the row count is preserved.
    shuffled_table = ArrowBlockAccessor(source_table).random_shuffle(
        random_seed=None
    )
    assert shuffled_table.num_rows == TOTAL_ROWS

    # Read the shuffled ids back out through a fresh accessor.
    shuffled_ids = ArrowBlockAccessor(shuffled_table).to_pandas()["id"].tolist()

    # The order must have changed ...
    assert (
        shuffled_ids != expected_ids
    ), "Shuffling should result in a different order"

    # ... while the set of values stays exactly the same.
    assert (
        sorted(shuffled_ids) == expected_ids
    ), "The shuffled data should contain all the original values"


def test_register_arrow_types(tmp_path):
# Test that our custom arrow extension types are registered on initialization.
ds = ray.data.from_items(np.zeros((8, 8, 8), dtype=np.int64))
Expand Down

0 comments on commit e18a59a

Please sign in to comment.