Skip to content

Commit

Permalink
rerun_py: Improvements to dataframe and simple unit-test framework (rerun-io#7600)
Browse files Browse the repository at this point in the history

### What
- Allow `select()` to support `*args`
- Return a `RecordBatchReader` instead of `list[RecordBatch]`
- Add a unit-test that creates a recording, opens it, and does some
stuff.

We still aren't able to iterate over the result-set without collecting
it, as we need to fix:

https://github.com/rerun-io/rerun/blob/b5aa6e95b68e931d981af52abb706422f46bdfa7/crates/store/re_dataframe2/src/engine.rs#L31-L34

Future work:
- Lots of fun stuff to do here with the zoo
  • Loading branch information
jleibs authored Oct 7, 2024
1 parent 2ce1c82 commit 42cdfef
Show file tree
Hide file tree
Showing 4 changed files with 146 additions and 8 deletions.
2 changes: 1 addition & 1 deletion rerun_py/rerun_bindings/rerun_bindings.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ class RecordingView:
"""Filter the view to only include data between the given index time values."""
...

def select(self, columns: Sequence[AnyColumn]) -> list[pa.RecordBatch]: ...
def select(self, *args: AnyColumn, columns: Optional[Sequence[AnyColumn]] = None) -> pa.RecordBatchReader: ...

class Recording:
"""A single recording."""
Expand Down
9 changes: 9 additions & 0 deletions rerun_py/rerun_sdk/rerun/_baseclasses.py
Original file line number Diff line number Diff line change
Expand Up @@ -348,6 +348,15 @@ class ComponentMixin(ComponentBatchLike):
The class using the mixin must define the `_BATCH_TYPE` field, which should be a subclass of `BaseBatch`.
"""

    @classmethod
    def arrow_type(cls) -> pa.DataType:
        """
        The pyarrow storage type of this batch.

        Part of the `ComponentBatchLike` logging interface.
        """
        # `_BATCH_TYPE` is supplied by the concrete subclass using this mixin
        # (see the class docstring); its `_ARROW_TYPE` is an extension type,
        # so we return the underlying storage type for plain-pyarrow use.
        return cls._BATCH_TYPE._ARROW_TYPE.storage_type  # type: ignore[attr-defined, no-any-return]

def component_name(self) -> str:
"""
The name of the component.
Expand Down
52 changes: 45 additions & 7 deletions rerun_py/src/dataframe.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,14 @@

use std::collections::{BTreeMap, BTreeSet};

use arrow::{
    array::{RecordBatchIterator, RecordBatchReader},
    pyarrow::PyArrowType,
};
use pyo3::{
    exceptions::{PyRuntimeError, PyTypeError, PyValueError},
    prelude::*,
    types::{PyDict, PyTuple},
};

use re_chunk_store::{
Expand Down Expand Up @@ -349,28 +352,63 @@ pub struct PyRecordingView {
/// increasing when data is sent from a single process.
#[pymethods]
impl PyRecordingView {
#[pyo3(signature = (
*args,
columns = None
))]
fn select(
&self,
py: Python<'_>,
args: &Bound<'_, PyTuple>,
columns: Option<Vec<AnyColumn>>,
) -> PyResult<PyArrowType<Vec<RecordBatch>>> {
) -> PyResult<PyArrowType<Box<dyn RecordBatchReader + Send>>> {
let borrowed = self.recording.borrow(py);
let engine = borrowed.engine();

let mut query_expression = self.query_expression.clone();

// Coerce the arguments into a list of `ColumnSelector`s
let args: Vec<AnyColumn> = args
.iter()
.map(|arg| arg.extract::<AnyColumn>())
.collect::<PyResult<_>>()?;

if columns.is_some() && !args.is_empty() {
return Err(PyValueError::new_err(
"Cannot specify both `columns` and `args` in `select`.",
));
}

let columns = columns.or_else(|| if !args.is_empty() { Some(args) } else { None });

query_expression.selection =
columns.map(|cols| cols.into_iter().map(|col| col.into_selector()).collect());

let query_handle = engine.query(query_expression);

let batches: Result<Vec<_>, _> = query_handle
let schema = query_handle.schema();
let fields: Vec<arrow::datatypes::Field> =
schema.fields.iter().map(|f| f.clone().into()).collect();
let metadata = schema.metadata.clone().into_iter().collect();
let schema = arrow::datatypes::Schema::new(fields).with_metadata(metadata);

// TODO(jleibs): Need to keep the engine alive
/*
let reader = RecordBatchIterator::new(
query_handle
.into_batch_iter()
.map(|batch| batch.try_to_arrow_record_batch()),
std::sync::Arc::new(schema),
);
*/
let batches = query_handle
.into_batch_iter()
.map(|batch| batch.try_to_arrow_record_batch())
.collect();
.collect::<Vec<_>>();

let batches = batches.map_err(|err| PyRuntimeError::new_err(err.to_string()))?;
let reader = RecordBatchIterator::new(batches.into_iter(), std::sync::Arc::new(schema));

Ok(PyArrowType(batches))
Ok(PyArrowType(Box::new(reader)))
}

fn filter_range_sequence(&self, start: i64, end: i64) -> PyResult<Self> {
Expand Down
91 changes: 91 additions & 0 deletions rerun_py/tests/unit/test_dataframe.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
from __future__ import annotations

import tempfile

import pyarrow as pa
import rerun as rr


class TestDataframe:
    # Build a small recording (two point logs), save it to a temporary .rrd
    # file, and load it back as a dataframe recording before each test.
    def setup_method(self) -> None:
        rr.init("rerun_example_test_recording")

        rr.log("points", rr.Points3D([[1, 2, 3], [4, 5, 6], [7, 8, 9]]))
        rr.log("points", rr.Points3D([[10, 11, 12]], colors=[[255, 0, 0]]))

        with tempfile.TemporaryDirectory() as tmpdir:
            rrd_path = tmpdir + "/tmp.rrd"
            rr.save(rrd_path)
            self.recording = rr.dataframe.load_recording(rrd_path)

    def test_full_view(self) -> None:
        # Selecting with no arguments returns every column in the view.
        view = self.recording.view(index="log_time", contents="points")

        reader = view.select()
        result = pa.Table.from_batches(reader, reader.schema)

        # row, log_time, log_tick, indicator, points, colors
        assert result.num_columns == 6
        assert result.num_rows == 2

    def test_select_column(self) -> None:
        # A positional column selector narrows the result to that one column.
        view = self.recording.view(index="log_time", contents="points")
        pos = rr.dataframe.ComponentColumnSelector("points", rr.components.Position3D)

        reader = view.select(pos)
        result = pa.Table.from_batches(reader, reader.schema)

        # points
        assert result.num_columns == 1
        assert result.num_rows == 2

        expected_rows = [
            pa.array(
                [
                    [1, 2, 3],
                    [4, 5, 6],
                    [7, 8, 9],
                ],
                type=rr.components.Position3D.arrow_type(),
            ),
            pa.array(
                [
                    [10, 11, 12],
                ],
                type=rr.components.Position3D.arrow_type(),
            ),
        ]

        for row, expected in enumerate(expected_rows):
            assert result.column(0)[row].values.equals(expected)

    def test_view_syntax(self) -> None:
        # Each of these spellings should resolve to the same single component.
        for contents in (
            {"points": rr.components.Position3D},
            {"points": [rr.components.Position3D]},
            {"points": "rerun.components.Position3D"},
            {"points/**": "rerun.components.Position3D"},
        ):
            view = self.recording.view(index="log_time", contents=contents)
            reader = view.select()
            result = pa.Table.from_batches(reader, reader.schema)

            # row, log_time, log_tick, points
            assert result.num_columns == 4
            assert result.num_rows == 2

        # Mismatched component or entity path yields only the index columns
        # and no rows.
        for contents in (
            {"points": rr.components.Position2D},
            {"point": [rr.components.Position3D]},
        ):
            view = self.recording.view(index="log_time", contents=contents)
            reader = view.select()

            # row, log_time, log_tick
            result = pa.Table.from_batches(reader, reader.schema)
            assert result.num_columns == 3
            assert result.num_rows == 0

0 comments on commit 42cdfef

Please sign in to comment.