Batching rules for: torch.bmm, torch.dot (#43781)
Summary: Pull Request resolved: pytorch/pytorch#43781

Test Plan: - `pytest test/test_vmap.py -v`

Reviewed By: ezyang

Differential Revision: D23400843

Pulled By: zou3519

fbshipit-source-id: a901bba6dc2d8435d314cb4dac85bbd5cd4ee2a5
zou3519 authored and facebook-github-bot committed Sep 1, 2020
1 parent fa12e22 commit dbc4218
Showing 2 changed files with 168 additions and 0 deletions.
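For context, a minimal usage sketch of what these batching rules enable, assuming the prototype torch.vmap API that test/test_vmap.py exercises (the batch size B here is illustrative, not from the commit):

import torch

B = 7
x = torch.randn(B, 5)
y = torch.randn(B, 5)
# Per-example dot products without writing an explicit loop:
out = torch.vmap(torch.dot)(x, y)   # shape [B]
assert out.shape == (B,)

a = torch.randn(B, 2, 3, 5)
b = torch.randn(B, 2, 5, 3)
# Per-example bmm over the extra vmap dimension:
out = torch.vmap(torch.bmm)(a, b)   # shape [B, 2, 3, 3]
assert out.shape == (B, 2, 3, 3)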
78 changes: 78 additions & 0 deletions aten/src/ATen/BatchingRegistrations.cpp
@@ -357,6 +357,81 @@ Tensor mv_batching_rule(const Tensor& self, const Tensor& other) {
  TORCH_INTERNAL_ASSERT(false, "either self or other must be a BatchedTensor");
}

Tensor dot_batching_rule(const Tensor& self, const Tensor& other) {
  auto self_batched = isBatchedTensor(self);
  auto other_batched = isBatchedTensor(other);

  TORCH_CHECK(/*logical*/self.dim() == 1 && /*logical*/other.dim() == 1,
      "dot(self, other): Shape mismatch: expected vector "
      "(got `self` of size ", self.sizes(), ") ",
      "and vector (got `other` of size ", other.sizes(), ")");

  // See Note [Batching rules for matmul-like operators] for why we have cases
  if (self_batched && !other_batched) {
    // self_physical: [..., K], other_physical: [K]
    // View the tensors as [..., 1, K] and [K], perform matmul to get [..., 1],
    // then squeeze the last dim.
    auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self);
    auto result = at::matmul(self_physical.tensor().unsqueeze(-2), other);
    return self_physical.newLogicalFromPhysical(result.squeeze(-1));
  }
  if (!self_batched && other_batched) {
    // self_physical: [K], other_physical: [..., K]
    // View the tensors as [K] and [..., K, 1], perform matmul to get [..., 1],
    // then squeeze the last dim.
    auto other_physical = MultiBatchVmapTransform::logicalToPhysical(other);
    auto result = at::matmul(self, other_physical.tensor().unsqueeze(-1));
    return other_physical.newLogicalFromPhysical(result.squeeze(-1));
  }
  if (self_batched && other_batched) {
    // self_physical: [..., K], other_physical: [..., K]
    // View the tensors as [..., 1, K] and [..., K, 1], perform matmul to get
    // [..., 1, 1], then squeeze the last two dims.
    auto physical_args = MultiBatchVmapTransform::logicalToPhysical({self, other});
    auto result = at::matmul(
        physical_args[0].tensor().unsqueeze(-2),
        physical_args[1].tensor().unsqueeze(-1));
    return physical_args[0].newLogicalFromPhysical(result.squeeze(-1).squeeze(-1));
  }
  TORCH_INTERNAL_ASSERT(false, "either self or other must be a BatchedTensor");
}
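In eager terms, the shape manipulation in the both-batched case amounts to the following Python sketch (illustrative only; B and K are hypothetical sizes, and the physical tensors are assumed to have their batch dims already moved to the front):

import torch

B, K = 7, 5
x = torch.randn(B, K)   # physical `self`
y = torch.randn(B, K)   # physical `other`
# [B, 1, K] @ [B, K, 1] -> [B, 1, 1], then squeeze twice -> [B]
result = torch.matmul(x.unsqueeze(-2), y.unsqueeze(-1)).squeeze(-1).squeeze(-1)
expected = torch.stack([torch.dot(x[i], y[i]) for i in range(B)])
assert torch.allclose(result, expected)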

Tensor bmm_batching_rule(const Tensor& self, const Tensor& other) {
  TORCH_CHECK(/*logical*/self.dim() == 3 && /*logical*/other.dim() == 3,
      "bmm(self, other): Shape mismatch: expected 3D `self` "
      "(got `self` of size ", self.sizes(), ") ",
      "and 3D `other` (got `other` of size ", other.sizes(), ")");

  auto physical_args = BroadcastingVmapTransform::logicalToPhysical({self, other});
  auto result = at::matmul(physical_args[0].tensor(), physical_args[1].tensor());
  return physical_args[0].newLogicalFromPhysical(result);
}
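The rule leans on at::matmul already broadcasting over leading batch dims, so no reshaping is needed once both inputs are physical. A rough eager-mode analogue (a sketch, with an illustrative batch size B):

import torch

B = 7
a = torch.randn(B, 2, 3, 5)  # physical `self`, vmap dim at the front
b = torch.randn(B, 2, 5, 3)  # physical `other`
# matmul batches over all leading dims, so one call covers every example:
result = torch.matmul(a, b)  # shape [B, 2, 3, 3]
expected = torch.stack([torch.bmm(a[i], b[i]) for i in range(B)])
assert torch.allclose(result, expected)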

Tensor mm_batching_rule(const Tensor& self, const Tensor& other) {
  auto self_batched = isBatchedTensor(self);
  auto other_batched = isBatchedTensor(other);

  TORCH_CHECK(/*logical*/self.dim() == 2 && /*logical*/other.dim() == 2,
      "mm(self, other): Shape mismatch: expected matrix "
      "(got `self` of size ", self.sizes(), ") ",
      "and matrix (got `other` of size ", other.sizes(), ")");

  // See Note [Batching rules for matmul-like operators] for why we have cases
  if (self_batched && !other_batched) {
    // self_physical: [..., M, K], other: [K, N]; matmul broadcasts to [..., M, N]
    auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self);
    auto result = at::matmul(self_physical.tensor(), other);
    return self_physical.newLogicalFromPhysical(result);
  }
  if (!self_batched && other_batched) {
    // self: [M, K], other_physical: [..., K, N]; matmul broadcasts to [..., M, N]
    auto other_physical = MultiBatchVmapTransform::logicalToPhysical(other);
    auto result = at::matmul(self, other_physical.tensor());
    return other_physical.newLogicalFromPhysical(result);
  }
  if (self_batched && other_batched) {
    // self_physical: [..., M, K], other_physical: [..., K, N]
    // The matmul result is already [..., M, N]; no squeeze is needed here
    // (squeezing the last dims would be incorrect when M or N is 1).
    auto physical_args = MultiBatchVmapTransform::logicalToPhysical({self, other});
    auto result = at::matmul(physical_args[0].tensor(), physical_args[1].tensor());
    return physical_args[0].newLogicalFromPhysical(result);
  }
  TORCH_INTERNAL_ASSERT(false, "either self or other must be a BatchedTensor");
}
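A quick eager-mode check of the both-batched case, including the N == 1 edge case that a trailing squeeze would mishandle (a sketch with hypothetical sizes, not part of the commit):

import torch

B, M, K, N = 7, 3, 5, 1   # N == 1 exercises the edge case
x = torch.randn(B, M, K)  # physical `self`
y = torch.randn(B, K, N)  # physical `other`
out = torch.matmul(x, y)  # [B, M, N] -- already the per-example mm result
assert out.shape == (B, M, N)
# A trailing .squeeze(-1) would collapse the N == 1 dim and change the result shape:
assert out.squeeze(-1).shape == (B, M)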

// I am quite sad that we need to register operators with exploded TensorOptions,
// even though the native:: implementations can use TensorOptions&.
// This also makes it hard to metaprogram: i.e., we can't use
@@ -508,6 +583,9 @@ TORCH_LIBRARY_IMPL(aten, Batched, m) {

  // matmul-like operators
  m.impl("mv", mv_batching_rule);
  m.impl("dot", dot_batching_rule);
  m.impl("bmm", bmm_batching_rule);
  m.impl("mm", mm_batching_rule);
}

} // namespace at
90 changes: 90 additions & 0 deletions test/test_vmap.py
@@ -876,6 +876,36 @@ def make_case(op, input_getter=TensorFactory.randn):
self._test_unary(lambda t: op(number, t), getter, device='cuda')
self._test_unary(lambda t: op(t, torch.tensor(number)), getter, device='cuda')

    def test_bmm(self):
        op = torch.bmm
        test = self._vmap_test
        B0, B1 = 7, 11

        # shape mismatch
        msg = "Shape mismatch"
        with self.assertRaisesRegex(RuntimeError, msg):
            vmap(op)(torch.randn(B0, 2, 2, 2), torch.randn(B0, 2))
        with self.assertRaisesRegex(RuntimeError, msg):
            vmap(op, in_dims=(0, None))(torch.randn(B0, 3, 3, 2), torch.randn(2, 2))
        with self.assertRaisesRegex(RuntimeError, msg):
            vmap(op, in_dims=(None, 0))(torch.randn(2, 2), torch.randn(B0, 2, 2, 2))

        # left arg is vmapped
        test(op, (torch.rand(B0, 2, 3, 5), torch.rand(2, 5, 3)), in_dims=(0, None))
        test(vmap(op, in_dims=(0, None)), (torch.rand(B1, B0, 2, 3, 5), torch.rand(2, 5, 3)),
             in_dims=(1, None))

        # right arg is vmapped
        test(op, (torch.rand(2, 5, 3), torch.rand(B0, 2, 3, 5)), in_dims=(None, 0))
        test(vmap(op, in_dims=(None, 0)), (torch.rand(2, 5, 3), torch.rand(B1, B0, 2, 3, 5)),
             in_dims=(None, 1))

        # both args are vmapped
        test(op, (torch.rand(B0, 2, 3, 5), torch.rand(B0, 2, 5, 3)))
        test(vmap(op), (torch.rand(B1, B0, 2, 3, 5), torch.rand(B0, B1, 2, 5, 3)), in_dims=(1, 0))
        test(vmap(op, in_dims=(0, None)),
             (torch.rand(B1, 2, 3, 5), torch.rand(B0, 2, 5, 3)), in_dims=(None, 0))

    def test_chunk(self):
        test = self._vmap_view_test
        op = torch.chunk
@@ -901,6 +931,36 @@ def test_diagonal(self):
        test(vmap(vmap(lambda t: op(t, 0, 0, 1), in_dims=1), in_dims=3),
             (tensor,), in_dims=1, out_dims=1)

    def test_dot(self):
        op = torch.dot
        test = self._vmap_test
        B0, B1 = 7, 11

        # shape mismatch
        msg = "Shape mismatch"
        with self.assertRaisesRegex(RuntimeError, msg):
            vmap(op)(torch.randn(B0, 2, 2, 2), torch.randn(B0, 2))
        with self.assertRaisesRegex(RuntimeError, msg):
            vmap(op, in_dims=(0, None))(torch.randn(B0, 2), torch.randn(2, 2))
        with self.assertRaisesRegex(RuntimeError, msg):
            vmap(op, in_dims=(None, 0))(torch.randn(2, 2), torch.randn(B0, 2))

        # left arg is vmapped
        test(op, (torch.rand(B0, 5), torch.rand(5)), in_dims=(0, None))
        test(vmap(op, in_dims=(0, None)), (torch.rand(B1, B0, 5), torch.rand(5)),
             in_dims=(1, None))

        # right arg is vmapped
        test(op, (torch.rand(5), torch.rand(B0, 5)), in_dims=(None, 0))
        test(vmap(op, in_dims=(None, 0)), (torch.rand(5), torch.rand(B1, B0, 5)),
             in_dims=(None, 1))

        # both args are vmapped
        test(op, (torch.rand(B0, 5), torch.rand(B0, 5)))
        test(vmap(op), (torch.rand(B1, B0, 5), torch.rand(B0, B1, 5)), in_dims=(1, 0))
        test(vmap(op, in_dims=(0, None)),
             (torch.rand(B1, 5), torch.rand(B0, 5)), in_dims=(None, 0))

    def test_expand_as(self):
        op = torch.Tensor.expand_as
        test = self._vmap_view_test
@@ -933,6 +993,36 @@ def test_movedim(self):
        test(vmap(vmap(op, in_dims=(2, None, None)), in_dims=(0, None, None)),
             (torch.rand(B1, 2, B0, 5, B2), [0, 1], [1, 0]), in_dims=(2, None, None))

    def test_mm(self):
        op = torch.mm
        test = self._vmap_test
        B0, B1 = 7, 11

        # shape mismatch
        msg = "Shape mismatch"
        with self.assertRaisesRegex(RuntimeError, msg):
            vmap(op)(torch.randn(B0, 2, 2, 2), torch.randn(B0, 2))
        with self.assertRaisesRegex(RuntimeError, msg):
            vmap(op, in_dims=(0, None))(torch.randn(B0, 2), torch.randn(2, 2))
        with self.assertRaisesRegex(RuntimeError, msg):
            vmap(op, in_dims=(None, 0))(torch.randn(2, 2), torch.randn(B0, 2, 2, 2))

        # left arg is vmapped
        test(op, (torch.rand(B0, 2, 5), torch.rand(5, 2)), in_dims=(0, None))
        test(vmap(op, in_dims=(0, None)), (torch.rand(B1, B0, 2, 5), torch.rand(5, 2)),
             in_dims=(1, None))

        # right arg is vmapped
        test(op, (torch.rand(2, 5), torch.rand(B0, 5, 2)), in_dims=(None, 0))
        test(vmap(op, in_dims=(None, 0)), (torch.rand(2, 5), torch.rand(B1, B0, 5, 2)),
             in_dims=(None, 1))

        # both args are vmapped
        test(op, (torch.rand(B0, 2, 5), torch.rand(B0, 5, 2)))
        test(vmap(op), (torch.rand(B1, B0, 2, 5), torch.rand(B0, B1, 5, 2)), in_dims=(1, 0))
        test(vmap(op, in_dims=(0, None)),
             (torch.rand(B1, 2, 5), torch.rand(B0, 5, 2)), in_dims=(None, 0))

    def test_mv(self):
        op = torch.mv
        test = self._vmap_test
