Skip to content

Commit

Permalink
Improve user experience for PadSequence collation (pytorch#361)
Browse files Browse the repository at this point in the history
Summary:
Velox functions today always mark the dtype as nullable;
make the collation support nullable dtypes.

Pull Request resolved: pytorch#361

Reviewed By: bearzx

Differential Revision: D36904409

Pulled By: wenleix

fbshipit-source-id: 42fb9578827f62e11cae44433d3469a6485132cc
  • Loading branch information
wenleix authored and facebook-github-bot committed Jun 3, 2022
1 parent 7f1d4c2 commit e5a4ca4
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 6 deletions.
16 changes: 14 additions & 2 deletions torcharrow/test/test_interop.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,7 @@ def base_test_pad_sequence(self):
{
"int32": [[11, 12, 13, 14], [21, 22], [31], [41, 42, 43]],
"int64": [[11, 12, 13, 14], [21, 22], [31], [41, 42, 43]],
"nullable_int64": [[11, 12, 13, 14], [21, 22], [31], [41, 42, 43]],
"float32": [
[11.5, 12.5, 13.5, 14.5],
[21.5, 22.5],
Expand All @@ -163,6 +164,10 @@ def base_test_pad_sequence(self):
[
dt.Field("int32", dt.List(dt.int32)),
dt.Field("int64", dt.List(dt.int64)),
dt.Field(
"nullable_int64",
dt.List(dt.Int64(nullable=True), nullable=True),
),
dt.Field("float32", dt.List(dt.float32)),
]
),
Expand All @@ -173,13 +178,14 @@ def base_test_pad_sequence(self):
{
"int32": tap.PadSequence(padding_value=-1),
"int64": tap.PadSequence(padding_value=-2),
"nullable_int64": tap.PadSequence(padding_value=-2),
"float32": tap.PadSequence(padding_value=-3),
}
)

# named tuple with 3 fields
# named tuple with 4 fields
self.assertTrue(isinstance(collated_tensors, tuple))
self.assertEquals(len(collated_tensors), 3)
self.assertEquals(len(collated_tensors), 4)

self.assertEquals(collated_tensors.int32.dtype, torch.int32)
self.assertEquals(
Expand All @@ -193,6 +199,12 @@ def base_test_pad_sequence(self):
[[11, 12, 13, 14], [21, 22, -2, -2], [31, -2, -2, -2], [41, 42, 43, -2]],
)

self.assertEquals(collated_tensors.nullable_int64.dtype, torch.int64)
self.assertEquals(
collated_tensors.nullable_int64.tolist(),
[[11, 12, 13, 14], [21, 22, -2, -2], [31, -2, -2, -2], [41, 42, 43, -2]],
)

self.assertEquals(collated_tensors.float32.dtype, torch.float32)
self.assertEquals(
collated_tensors.float32.tolist(),
Expand Down
23 changes: 19 additions & 4 deletions torcharrow/velox_rt/list_column_cpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

import array as ar
import warnings
from typing import Callable, List, Optional
from typing import Callable, List, Optional, Union

import torcharrow as ta
import torcharrow._torcharrow as velox
Expand Down Expand Up @@ -201,16 +201,31 @@ def _to_tensor_pad_sequence(self, batch_first: bool, padding_value):
# TODO: pad_sequence also works for nest numeric list
# pyre-fixme[16]: `DType` has no attribute `item_dtype`.
assert dt.is_numerical(self.dtype.item_dtype)
assert not self.dtype.nullable
assert self.null_count == 0

import torch
from torch.nn.utils.rnn import pad_sequence

packed_list: pytorch.PackedList = self._to_tensor_default()
packed_list: Union[
pytorch.WithPresence, pytorch.PackedList
] = self._to_tensor_default()

if isinstance(packed_list, pytorch.WithPresence):
# presence tensor will be provided if dtype is nullable.
# However, as long as there is no null value, the collation can still be done, and we just need to discard the presence tensor
assert torch.all(packed_list.presence)
packed_list = packed_list.values

flattened_values = packed_list.values
if isinstance(flattened_values, pytorch.WithPresence):
# presence tensor will be provided if item_type is nullable.
# However, as long as there is no null value, the collation can still be done, and we just need to discard the presence tensor
assert torch.all(flattened_values.presence)
flattened_values = flattened_values.values

# pyre-fixme[11]: Annotation `tensor` is not defined as a type.
unpad_tensors: List[torch.tensor] = [
packed_list.values[packed_list.offsets[i] : packed_list.offsets[i + 1]]
flattened_values[packed_list.offsets[i] : packed_list.offsets[i + 1]]
for i in range(len(self))
]

Expand Down

0 comments on commit e5a4ca4

Please sign in to comment.