Support indexing of the underlying tensors for nested tensors (pytorc…
YifanShenSZ authored and pytorchmergebot committed Jun 8, 2022
1 parent e85f3b5 commit 6ad51c9
Showing 5 changed files with 144 additions and 31 deletions.
66 changes: 40 additions & 26 deletions aten/src/ATen/TensorIndexing.h
@@ -198,16 +198,19 @@ static inline Tensor applySlice(
int64_t step,
bool disable_slice_optimization,
const at::Device& self_device,
const IntArrayRef& self_sizes) {
const c10::optional<IntArrayRef> & self_sizes) {
// TODO: implement negative step
TORCH_CHECK_VALUE(step > 0, "step must be greater than zero");

// Skip this optimization if we are tracing, as the trace may be polymorphic
// over the shape of the `self` tensor, and we still want to record
// the slice.
int64_t length = (self_device == at::kCPU || self_device == at::kCUDA) ? self_sizes[dim] : self.size(dim);
if (!disable_slice_optimization && start == 0 && stop == length && step == 1) {
return self;
// See NOTE [nested tensor size for indexing]
if (self_sizes.has_value()) {
// Skip this optimization if we are tracing, as the trace may be polymorphic
// over the shape of the `self` tensor, and we still want to record
// the slice.
int64_t length = (self_device == at::kCPU || self_device == at::kCUDA) ? (*self_sizes)[dim] : self.size(dim);
if (!disable_slice_optimization && start == 0 && stop == length && step == 1) {
return self;
}
}
return self.slice(dim, start, stop, step);
}
@@ -218,16 +221,19 @@ static inline Tensor applySelect(
int64_t index,
int64_t real_dim,
const at::Device& /*self_device*/,
const IntArrayRef& self_sizes) {
TORCH_CHECK_INDEX(
!(index == 0 && dim == 0 && self_sizes.size() == 0),
"invalid index of a 0-dim tensor. ",
"Use `tensor.item()` in Python or `tensor.item<T>()` in C++ to convert a 0-dim tensor to a number");

int64_t size = self_sizes[dim];
TORCH_CHECK_INDEX(
index >= -size && index < size,
"index ", index, " is out of bounds for dimension ", real_dim, " with size ", size);
const c10::optional<IntArrayRef> & self_sizes) {
// See NOTE [nested tensor size for indexing]
if (self_sizes.has_value()) {
TORCH_CHECK_INDEX(
!(index == 0 && dim == 0 && self_sizes->size() == 0),
"invalid index of a 0-dim tensor. ",
"Use `tensor.item()` in Python or `tensor.item<T>()` in C++ to convert a 0-dim tensor to a number");

int64_t size = (*self_sizes)[dim];
TORCH_CHECK_INDEX(
index >= -size && index < size,
"index ", index, " is out of bounds for dimension ", real_dim, " with size ", size);
}

// if the index is negative, do not normalize it because that would fix the index
// on the current tensor size in the tracer.
@@ -373,7 +379,7 @@ static inline Tensor handleDimInMultiDimIndexing(
std::vector<Tensor>& outIndices,
bool disable_slice_optimization,
const at::Device& original_tensor_device,
const IntArrayRef& prev_dim_result_sizes) {
const c10::optional<IntArrayRef> & prev_dim_result_sizes) {
if (index.is_integer()) {
return impl::applySelect(prev_dim_result, *dim_ptr, index.integer(), real_dim, original_tensor_device, prev_dim_result_sizes);
} else if (index.is_slice()) {
@@ -431,17 +437,22 @@ static inline Tensor applySlicing(
std::vector<Tensor>& outIndices,
bool disable_slice_optimization,
const at::Device& self_device,
const IntArrayRef& self_sizes) {
const c10::optional<IntArrayRef> & self_sizes) {
int64_t dim = 0;
int64_t specified_dims = impl::count_specified_dimensions(indices);

TORCH_CHECK_INDEX(
specified_dims <= (int64_t)self_sizes.size(),
"too many indices for tensor of dimension ", (int)self_sizes.size());
// See NOTE [nested tensor size for indexing]
if (self_sizes.has_value()) {
TORCH_CHECK_INDEX(
specified_dims <= (int64_t)self_sizes->size(),
"too many indices for tensor of dimension ", (int)self_sizes->size());
}

Tensor result = self;
for (const auto i : c10::irange(indices.size())) {
auto& obj = indices[i];
// See NOTE [nested tensor size for indexing]
c10::optional<IntArrayRef> result_sizes = result.is_nested() ? c10::optional<IntArrayRef>(c10::nullopt) : c10::optional<IntArrayRef>(result.sizes());
result = handleDimInMultiDimIndexing(
/*prev_dim_result=*/result,
/*original_tensor=*/self,
@@ -452,7 +463,7 @@ static inline Tensor applySlicing(
/*outIndices=*/outIndices,
/*disable_slice_optimization=*/disable_slice_optimization,
/*original_tensor_device=*/self_device,
/*prev_dim_result_sizes=*/result.sizes());
/*prev_dim_result_sizes=*/result_sizes);
}
return result;
}
@@ -495,7 +506,10 @@ static inline Tensor dispatch_index_put_(Tensor& self, std::vector<Tensor>&& ind
// See NOTE [ Setting `disable_slice_optimization` when calling C++ tensor indexing functions from Python ]
static inline Tensor get_item(const Tensor& self, const ArrayRef<TensorIndex>& indices, bool disable_slice_optimization = false) {
at::Device self_device = self.device();
IntArrayRef self_sizes = self.sizes();
// NOTE [nested tensor size for indexing]
// nested tensor does not have a size (yet) so for now we represent its size as null
// may need to be changed after we reach a better solution for nested tensor size
c10::optional<IntArrayRef> self_sizes = self.is_nested() ? c10::optional<IntArrayRef>(c10::nullopt) : c10::optional<IntArrayRef>(self.sizes());

// handle simple types: integers, slices, none, ellipsis, bool
if (indices.size() == 1) {
@@ -511,7 +525,7 @@ static inline Tensor get_item(const Tensor& self, const ArrayRef<TensorIndex>& i
index.slice().step(),
/*disable_slice_optimization=*/true,
self_device,
self_sizes);
*self_sizes);
} else if (index.is_none()) {
return self.unsqueeze(0);
} else if (index.is_ellipsis()) {
@@ -526,7 +540,7 @@ static inline Tensor get_item(const Tensor& self, const ArrayRef<TensorIndex>& i
}

std::vector<Tensor> tensorIndices;
Tensor sliced = impl::applySlicing(self, indices, tensorIndices, disable_slice_optimization, self_device, self_sizes);
Tensor sliced = impl::applySlicing(self, indices, tensorIndices, disable_slice_optimization, self_device, *self_sizes);
if (tensorIndices.empty()) {
if (sliced.is_same(self)) {
// ensure we return a shallow copy for things like x[...]
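
The net effect of the optional-size plumbing above can be summarized at the Python level. The sketch below is illustrative only (sketch_get_item is a hypothetical helper, not part of PyTorch): when a tensor is nested its sizes are treated as unknown, size-based checks are skipped, and an integer index still lowers to select.

import torch

# Illustrative sketch only -- mirrors the control flow of get_item/applySelect
# with an optional size, not the actual C++ implementation (0-dim handling omitted).
def sketch_get_item(t, index):
    # NOTE [nested tensor size for indexing]: a nested tensor has no regular
    # size yet, so treat its sizes as unknown and skip size-based checks.
    sizes = None if t.is_nested else tuple(t.size())
    if isinstance(index, int):
        if sizes is not None and not (-sizes[0] <= index < sizes[0]):
            raise IndexError("index out of bounds")  # bounds check needs a known size
        return t.select(0, index)  # nested tensors dispatch to select_nested
    raise NotImplementedError("only integer indexing is sketched here")

x = torch.randn(3, 4)
assert torch.equal(sketch_get_item(x, 1), x[1])
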
1 change: 1 addition & 0 deletions aten/src/ATen/native/native_functions.yaml
@@ -4069,6 +4069,7 @@
dispatch:
CompositeExplicitAutograd: select
SparseCsrCPU, SparseCsrCUDA: select_sparse_csr
NestedTensorCPU, NestedTensorCUDA: select_nested

- func: select_backward(Tensor grad_output, int[] input_sizes, int dim, int index) -> Tensor
variants: function
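
Registering the NestedTensorCPU and NestedTensorCUDA dispatch keys for select is what routes both Tensor.select and integer indexing on a nested tensor to select_nested. A quick Python illustration (the shapes are arbitrary examples):

import torch

x0 = torch.randn(2, 5)
x1 = torch.randn(3, 4)
nt = torch.nested_tensor([x0, x1])

# Both of these now reach select_nested and return the underlying constituent.
assert torch.equal(nt.select(0, 1), x1)
assert torch.equal(nt[1], x1)  # integer indexing lowers to select
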
52 changes: 52 additions & 0 deletions aten/src/ATen/native/nested/NestedTensorMath.cpp
@@ -557,6 +557,58 @@ Tensor& NestedTensor_mul__Tensor(Tensor& self, const Tensor& other) {
});
}

Tensor select_nested(const Tensor& self, int64_t dim, int64_t index) {
TORCH_CHECK(
dim == 0,
"NestedTensor can only be selected along dimension 0 ",
"got dimension ", dim, " instead."
);
auto self_ptr = get_nested_tensor_impl(self);
// buffer contains the underlying data in a contiguous vector
const at::Tensor & buffer = self_ptr->get_buffer();
int64_t numel = buffer.numel();
TORCH_CHECK(
numel > 0,
"cannot index an empty nested tensor."
);
// nested_tensor[i] = i-th original tensor
int64_t ntensors = *(self_ptr->opt_size(0));
int64_t positive_index = at::maybe_wrap_dim(index, ntensors);
// determine the memory segment of the i-th original tensor
Tensor sizemat = get_nested_size_tensor(self);
int64_t original_dim = sizemat.size(1);
const int64_t * sizemat_ptr = sizemat.data_ptr<int64_t>();
// start of the segment
int64_t start = 0, sizemat_offset = 0;
for (int64_t i = 0; i < positive_index; i++) {
int64_t row_product = sizemat_ptr[sizemat_offset];
sizemat_offset++;
for (int64_t j = 1; j < original_dim; j++) {
row_product *= sizemat_ptr[sizemat_offset];
sizemat_offset++;
}
start += row_product;
}
// btw determine the shape of the i-th original tensor
IntArrayRef shape(sizemat_ptr + sizemat_offset, sizemat_ptr + sizemat_offset + original_dim);
// stop of the segment
int64_t stop;
if (positive_index == ntensors - 1) {
stop = numel;
}
else {
int64_t row_product = sizemat_ptr[sizemat_offset];
sizemat_offset++;
for (int64_t j = 1; j < original_dim; j++) {
row_product *= sizemat_ptr[sizemat_offset];
sizemat_offset++;
}
stop = start + row_product;
}
// extract the memory segment then reshape to the original shape
return buffer.slice(0, start, stop).view(shape);
}

Tensor clone_nested(
const Tensor& self,
c10::optional<c10::MemoryFormat> optional_memory_format) {
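
For intuition, the start/stop arithmetic in select_nested can be mimicked in a few lines of Python. This is a rough analogue only, assuming the back-to-back contiguous buffer layout described in the comments above; select_from_buffer is a made-up helper, not part of the PyTorch API.

import math
import torch

def select_from_buffer(buffer, sizes, index):
    # buffer: 1-D tensor holding all constituent tensors back to back
    # sizes:  one shape tuple per constituent (the rows of the size matrix)
    index = index % len(sizes)                    # wrap negative indices
    numels = [math.prod(s) for s in sizes]        # row products of the size matrix
    start = sum(numels[:index])                   # start of the selected segment
    stop = start + numels[index]                  # stop of the selected segment
    return buffer[start:stop].view(sizes[index])  # reshape to the original shape

x0, x1 = torch.randn(2, 5), torch.randn(3, 4)
buffer = torch.cat([x0.reshape(-1), x1.reshape(-1)])
assert torch.equal(select_from_buffer(buffer, [(2, 5), (3, 4)], 1), x1)
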
36 changes: 36 additions & 0 deletions test/test_nestedtensor.py
@@ -355,6 +355,42 @@ def test_device_checks(self, device):
is_cuda = 'cuda' in str(device)
self.assertEqual(nt.is_cuda, is_cuda)

@dtypes(torch.float, torch.float16, torch.double)
def test_nested_tensor_indexing(self, device, dtype):
# edge case: empty nested tensor
nt0 = torch.nested_tensor([])
self.assertRaisesRegex(
RuntimeError,
"cannot index an empty nested tensor",
lambda: nt0[0]
)
# normal case
x0 = torch.randn((2, 5), device=device, dtype=dtype)
x1 = torch.randn((3, 4), device=device, dtype=dtype)
nt = torch.nested_tensor([x0, x1])
# single index: only support integer in the batch dimension
self.assertEqual(nt[0], x0)
self.assertEqual(nt[-1], x1)
self.assertRaises(IndexError, lambda: nt[2])
self.assertRaises(IndexError, lambda: nt[-3])
self.assertRaises(NotImplementedError, lambda: nt[:])
self.assertRaises(NotImplementedError, lambda: nt[None])
self.assertRaises(NotImplementedError, lambda: nt[...])
# tuple of indices: only support integer in the batch dimension
# + all possible indexing in the original tensor dimensions
self.assertEqual(nt[0, 0, 0], x0[0, 0])
self.assertEqual(nt[0, 1, :], x0[1, :])
self.assertEqual(nt[1, ...], x1)
self.assertRaises(IndexError, lambda: nt[1, 4, 2])
self.assertRaises(NotImplementedError, lambda: nt[:, 1, 1])
# make sure indexing returns a view
nt[0].fill_(100.0)
answer = torch.tensor(100.0, device=device, dtype=dtype).expand((2, 5))
self.assertEqual(nt[0], answer)
nt[1, 1, :].fill_(200.0)
answer = torch.tensor(200.0, device=device, dtype=dtype).expand(4)
self.assertEqual(nt[1, 1, :], answer)

# Helper functions for testing elementwise ops
def random_nt(self, device, dtype, num_tensors, max_dims, min_dims=None):
if min_dims is None:
20 changes: 15 additions & 5 deletions torch/csrc/autograd/python_variable_indexing.cpp
@@ -137,18 +137,25 @@ static inline Variable applySlicing(
variable_list& outIndices,
bool is_tracing,
const at::Device& self_device,
const IntArrayRef& self_sizes,
const c10::optional<IntArrayRef> & self_sizes,
int64_t specified_dims) {
int64_t size = PyTuple_GET_SIZE(index); // NOLINT(cppcoreguidelines-pro-type-cstyle-cast)
int64_t dim = 0;

if (specified_dims > (int64_t)self_sizes.size()) {
throw IndexError("too many indices for tensor of dimension %d", (int)(self_sizes.size()));
// See NOTE [nested tensor size for indexing]
if (self_sizes.has_value()) {
TORCH_CHECK_INDEX(
specified_dims <= (int64_t)self_sizes->size(),
"too many indices for tensor of dimension ", (int)self_sizes->size());
}

Variable result = self;
for(const auto i : c10::irange(size)) {
PyObject* obj = PyTuple_GET_ITEM(index, i); // NOLINT(cppcoreguidelines-pro-type-cstyle-cast)
// NOTE [nested tensor size for indexing]
// nested tensor does not have a size (yet) so for now we represent its size as null
// may need to be changed after we reach a better solution for nested tensor size
c10::optional<IntArrayRef> result_sizes = result.is_nested() ? c10::optional<IntArrayRef>(c10::nullopt) : c10::optional<IntArrayRef>(result.sizes());
result = at::indexing::handleDimInMultiDimIndexing(
/*prev_dim_result=*/result,
/*original_tensor=*/self,
@@ -202,7 +209,7 @@ static inline Variable applySlicing(
// See NOTE [ Setting `disable_slice_optimization` when calling C++ tensor indexing functions from Python ]
/*disable_slice_optimization=*/is_tracing,
/*original_tensor_device=*/self_device,
/*prev_dim_result_sizes=*/result.sizes());
/*prev_dim_result_sizes=*/result_sizes);
}
return result;
}
@@ -320,8 +327,11 @@ PyObject* THPVariable_getitem(PyObject* self, PyObject* index) {
if (specified_dims == -1) {
return handle_torch_function_indexing(self, holder.get());
}
// See NOTE [nested tensor size for indexing]
c10::optional<IntArrayRef> self_sizes = c10::nullopt;
if (! self_.is_nested()) self_sizes = self_.sizes();
Variable sliced = applySlicing(
self_, holder.get(), variableIndices, /*is_tracing=*/is_tracing, self_.device(), self_.sizes(), specified_dims);
self_, holder.get(), variableIndices, /*is_tracing=*/is_tracing, self_.device(), self_sizes, specified_dims);
if (variableIndices.empty()) {
if (sliced.is_same(self_)) {
// ensure we return a shallow copy for things like x[...]
