Skip to content

Commit

Permalink
enable null check when constructing columns (pytorch#129)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: pytorch#129

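With this fix, constructing a column from a Python list that contains None while the given dtype is non-nullable now raises a ValueError: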

https://pxl.cl/1W1lQ

Reviewed By: wenleix

Differential Revision: D33281970

fbshipit-source-id: 9db44b61c75d3a6279c05e09c7d9ce8a2c672c5a
  • Loading branch information
Qian Xu authored and wenleix committed Jan 14, 2022
1 parent 931040d commit e8ff0d9
Show file tree
Hide file tree
Showing 6 changed files with 39 additions and 4 deletions.
9 changes: 8 additions & 1 deletion torcharrow/scope.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,14 @@ def _Column(
raise ValueError("Column cannot infer type from data")
if dt.contains_tuple(dtype):
raise TypeError("Cannot infer type from Python tuple")
return Scope._FromPyList(data, dtype, device)

result = Scope._FromPyList(data, dtype, device)
# since dtype is known, check the nullability
if not dtype.nullable and result.null_count != 0:
raise ValueError(
f"None found in the list for non-nullable type: {dtype}"
)
return result

if Scope._is_column(data):
dtype = dtype or data.dtype
Expand Down
19 changes: 18 additions & 1 deletion torcharrow/test/test_list_column.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,28 @@ def base_test_nonempty(self):
self.assertEqual(c[i], lst)
self.assertIsNone(c[4])

c2 = ta.Column([None, None, [1, 2, 3]], dt.List(dt.int64), device=self.device)
c2 = ta.Column(
[None, None, [1, 2, 3]],
dt.List(dt.int64, nullable=True),
device=self.device,
)
self.assertIsNone(c2[0])
self.assertIsNone(c2[1])
self.assertEqual(c2[2], [1, 2, 3])

def base_test_list_with_none(self):
    """Check that None entries are rejected for a non-nullable list dtype.

    Builds a column from ``[None, None, [1, 2, 3]]`` with
    ``dt.List(dt.int64)`` (not declared nullable) and expects a
    ``ValueError`` whose message names the offending dtype.
    """
    with self.assertRaises(ValueError) as ex:
        ta.Column(
            [None, None, [1, 2, 3]],
            dt.List(dt.int64),
            device=self.device,
        )
    # Inspect the captured exception only after the ``with`` block exits.
    # assertIn reports both operands on failure, which a bare
    # assertTrue(... in ...) does not.
    self.assertIn(
        "None found in the list for non-nullable type: List(int64)",
        str(ex.exception),
        f"Exception message is not as expected: {ex.exception}",
    )

def base_test_append_concat(self):
base_list = [["hello", "world"], ["how", "are", "you"]]
sf1 = ta.Column(base_list, dtype=dt.List(dt.string), device=self.device)
Expand Down
3 changes: 3 additions & 0 deletions torcharrow/test/test_list_column_cpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@ def test_empty(self):
def test_nonempty(self):
    """Run the shared ``base_test_nonempty`` checks for this test class."""
    self.base_test_nonempty()

def test_list_with_none(self):
    """Run the shared ``base_test_list_with_none`` checks for this test class."""
    self.base_test_list_with_none()

def test_append_concat(self):
    """Run the shared ``base_test_append_concat`` checks for this test class."""
    # Test methods should not return a value: unittest deprecates non-None
    # returns from tests (Python 3.11+), and the sibling wrappers in this
    # class simply call their helper without returning.
    self.base_test_append_concat()

Expand Down
4 changes: 3 additions & 1 deletion torcharrow/test/test_map_column.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,9 @@ def base_test_map(self):
self.assertIsNone(c[2])

c2 = ta.Column(
[None, None, {"foo": 123}], dt.Map(dt.string, dt.int64), device=self.device
[None, None, {"foo": 123}],
dt.Map(dt.string, dt.int64, nullable=True),
device=self.device,
)
self.assertIsNone(c2[0])
self.assertIsNone(c2[1])
Expand Down
3 changes: 3 additions & 0 deletions torcharrow/test/test_string_column.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ def base_test_string_categorization_methods(self):
list(
ta.Column(
["", "abc", "XYZ", "123", "XYZ123", "äöå", ",.!", None],
dt.String(True),
device=self.device,
).str.isalpha()
),
Expand All @@ -51,6 +52,7 @@ def base_test_string_categorization_methods(self):
list(
ta.Column(
["", "abc", "XYZ", "123", "XYZ123", "äöå", ",.!", None],
dt.String(True),
device=self.device,
).str.isalnum()
),
Expand All @@ -61,6 +63,7 @@ def base_test_string_categorization_methods(self):
list(
ta.Column(
["", "abc", "XYZ", "123", "XYZ123", "äöå", ",.!", "\u00B2", None],
dt.String(True),
device=self.device,
).str.isdecimal()
),
Expand Down
5 changes: 4 additions & 1 deletion torcharrow/test/transformation/test_list.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,10 @@ def setUpClass(cls):
"struct_list": ta.Column(
[[(1, "a"), (2, "b")], [(3, "c")], None],
dtype=dt.List(
dt.Struct([dt.Field("f1", dt.int64), dt.Field("f2", dt.string)])
dt.Struct(
[dt.Field("f1", dt.int64), dt.Field("f2", dt.string)]
),
nullable=True,
),
),
}
Expand Down

0 comments on commit e8ff0d9

Please sign in to comment.