Skip to content

Commit

Permalink
Delegate IColumn.fill_null/drop_null to Arrow (pytorch#96)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: pytorch#96

As title

NOTE: the changes of diff only delegate IColumn.fill_null to Arrow since TorchArrow is using Apache Arrow 2.0.0, which doesn't support `drop_null` until Apache Arrow 6: https://arrow.apache.org/docs/python/api/compute.html#selections

Reviewed By: wenleix, OswinC

Differential Revision: D32770009

fbshipit-source-id: 14ad956d406f4125b176eb5cbb65eb16492c1cd1
  • Loading branch information
Bo Huang authored and wenleix committed Dec 13, 2021
1 parent 98c07fe commit 3d06b72
Show file tree
Hide file tree
Showing 4 changed files with 14 additions and 40 deletions.
17 changes: 6 additions & 11 deletions torcharrow/icolumn.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

from .dispatcher import Device
from .expression import expression
from .interop import from_arrow
from .scope import Scope
from .trace import trace, traceproperty

Expand Down Expand Up @@ -1129,18 +1130,12 @@ def fill_null(self, fill_value: ty.Union[dt.ScalarTypes, ty.Dict]):
dtype: int64, length: 4, null_count: 0
"""
self._prototype_support_warning("fill_null")

if not isinstance(fill_value, IColumn._scalar_types):
raise TypeError(f"fill_null with {type(fill_value)} is not supported")
if isinstance(fill_value, IColumn._scalar_types):
res = Scope._EmptyColumn(self.dtype.constructor(nullable=False))
for m, i in self._items():
if not m:
res._append_value(i)
else:
res._append_value(fill_value)
return res._finalize()
import pyarrow.compute as pc

arr = pc.fill_null(self.to_arrow(), fill_value)
arr_dtype = self.dtype.with_null(nullable=False)
return from_arrow(arr, dtype=arr_dtype, device=self.device)
else:
raise TypeError(f"fill_null with {type(fill_value)} is not supported")

Expand Down
8 changes: 4 additions & 4 deletions torcharrow/test/test_dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -537,13 +537,13 @@ def base_test_python_comparison_ops(self):
assert c == c.append([None])

def base_test_na_handling(self):
c = ta.DataFrame({"a": [None, 2, 17.0]}, device=self.device)
c = ta.DataFrame({"a": [None, 2, 17]}, device=self.device)

self.assertEqual(list(c.fill_null(99.0)), [(i,) for i in [99.0, 2, 17.0]])
self.assertEqual(list(c.drop_null()), [(i,) for i in [2, 17.0]])
self.assertEqual(list(c.fill_null(99)), [(i,) for i in [99, 2, 17]])
self.assertEqual(list(c.drop_null()), [(i,) for i in [2, 17]])

c = c.append([(2,)])
self.assertEqual(list(c.drop_duplicates()), [(i,) for i in [None, 2, 17.0]])
self.assertEqual(list(c.drop_duplicates()), [(i,) for i in [None, 2, 17]])

# duplicates with subset
d = ta.DataFrame(
Expand Down
8 changes: 4 additions & 4 deletions torcharrow/test/test_numerical_column.py
Original file line number Diff line number Diff line change
Expand Up @@ -381,13 +381,13 @@ def base_test_operators(self):
# TODO Test type promotion rules

def base_test_na_handling(self):
c = ta.Column([None, 2, 17.0], device=self.device)
c = ta.Column([None, 2, 17], device=self.device)

self.assertEqual(list(c.fill_null(99.0)), [99.0, 2, 17.0])
self.assertEqual(list(c.drop_null()), [2.0, 17.0])
self.assertEqual(list(c.fill_null(99)), [99, 2, 17])
self.assertEqual(list(c.drop_null()), [2, 17])

c = c.append([2])
self.assertEqual(set(c.drop_duplicates()), {None, 2, 17.0})
self.assertEqual(set(c.drop_duplicates()), {None, 2, 17})

def base_test_agg_handling(self):
import functools
Expand Down
21 changes: 0 additions & 21 deletions torcharrow/velox_rt/numerical_column_cpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -631,27 +631,6 @@ def round(self, decimals=0):

# data cleaning -----------------------------------------------------------

@trace
@expression
def fill_null(self, fill_value: Union[dt.ScalarTypes, Dict]):
self._prototype_support_warning("fill_null")

if not isinstance(fill_value, IColumn._scalar_types):
raise TypeError(f"fill_null with {type(fill_value)} is not supported")
if not self.is_nullable:
return self
else:
col = velox.Column(get_velox_type(self.dtype))
for i in range(len(self)):
if self._getmask(i):
if isinstance(fill_value, Dict):
raise NotImplementedError()
else:
col.append(fill_value)
else:
col.append(self._getdata(i))
return ColumnFromVelox._from_velox(self.device, self.dtype, col, True)

@trace
@expression
def drop_null(self, how="any"):
Expand Down

0 comments on commit 3d06b72

Please sign in to comment.