
Commit c1cc74c

lezcano authored and pytorchmergebot committed
Enable a number of inductor tests on CPU (pytorch#107465)
There were many tests whose `_cuda` variants were not running on CUDA. I fixed a few of these, but I'm sure there are plenty more. It would be great to have a way to test that we're indeed compiling something in these tests, but I don't know how to do this off the top of my head.

Pull Request resolved: pytorch#107465
Approved by: https://github.com/ezyang
1 parent 71632d4 commit c1cc74c
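
[Editor's sketch] The open question in the commit message, how to assert that these tests really compile something, has at least one plausible answer: check Inductor's kernel-generation metrics. This is a minimal sketch, assuming `torch._inductor.metrics.generated_kernel_count` is reset and incremented by Inductor codegen the way it is elsewhere in this suite; it is not part of the commit.

import torch
import torch._inductor.metrics as metrics

def fn(x):
    return (x + 1).relu()

metrics.reset()  # zero the counters before compiling
compiled = torch.compile(fn, backend="inductor")
compiled(torch.rand(8, 8))
# If Inductor codegen actually ran, at least one kernel was emitted.
assert metrics.generated_kernel_count > 0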

File tree

2 files changed: +28 -37


test/inductor/test_torchinductor.py

+26 -37
@@ -6024,19 +6024,12 @@ def fn0(i0, i1):
         def fn1(i0, i1):
             return torch.lerp(i1, i0, 70000)
 
-        def compare(fn, inputs):
-            compiled = torch._dynamo.optimize("inductor")(fn)
-            expected = fn(*inputs)
-            actual = compiled(*inputs)
-            self.assertEqual(expected, actual)
-            self.assertEqual(expected.stride(), actual.stride())
-
-        compare(fn0, [torch.rand(10, 3, 10), torch.rand(3, 10, 10)])
-        compare(fn1, [torch.rand(3, 10, 10), torch.rand(3, 10, 10)])
+        self.common(fn0, [torch.rand(10, 3, 10), torch.rand(3, 10, 10)])
+        self.common(fn1, [torch.rand(3, 10, 10), torch.rand(3, 10, 10)])
 
     def test_unspec_inputs(self):
         if self.device == "cpu":
-            raise unittest.SkipTest("segfault with CPU backend")
+            raise unittest.SkipTest("Testing mixed devices")
 
         def fn(x, y):
             return x + y, x * y, x / y
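
[Editor's sketch] The recurring change in this diff swaps hand-rolled compile-and-compare code for `self.common`, the suite's shared checker (in this file it is an alias of `check_model` / `check_model_cuda`). In rough outline it does the following; this is a simplified sketch, and the real helper also exercises dtypes, strides, gradients, and more.

import torch
from torch.testing._internal.common_utils import TestCase

class CommonSketch(TestCase):
    def common(self, fn, example_inputs, atol=None, rtol=None):
        # Eager reference vs. Inductor-compiled result.
        compiled = torch.compile(fn, backend="inductor")
        expected = fn(*example_inputs)
        actual = compiled(*example_inputs)
        self.assertEqual(expected, actual, atol=atol, rtol=rtol)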
@@ -6138,9 +6131,7 @@ def fn(x):
             return attn.softmax(dim=-1)
 
         x = torch.rand(128, 32, 63)
-        res_ref = fn(x)
-        res = torch._dynamo.optimize("inductor")(fn)(x)
-        self.assertEqual(res, res_ref)
+        self.common(fn, (x,))
 
     def test_kwargs(self):
         if self.device == "cuda":
@@ -6242,9 +6233,6 @@ def fn(a, b):
         )
 
     def test_index_dynamic_shapes(self):
-        if self.device == "cuda":
-            raise unittest.SkipTest("index dynamic shapes only supports cpu")
-
         # Repro from vision_maskrcnn
         def fn(arg0_1):
             unsqueeze = arg0_1.unsqueeze(0)
@@ -6255,7 +6243,7 @@ def fn(arg0_1):
                 start=0,
                 step=1,
                 dtype=torch.int64,
-                device="cpu",
+                device=arg0_1.device,
                 requires_grad=False,
             )
             convert_element_type_1 = iota.to(torch.float32)
@@ -6267,7 +6255,7 @@ def fn(arg0_1):
                 start=0,
                 step=1,
                 dtype=torch.int64,
-                device="cpu",
+                device=arg0_1.device,
                 requires_grad=False,
             )
             convert_element_type_3 = iota_1.to(torch.float32)
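
[Editor's sketch] The two `device="cpu"` to `device=arg0_1.device` fixes above follow the standard device-agnostic pattern: any tensor a test function materializes internally should inherit its device from an input rather than hard-code one. A minimal illustration of the pattern, with a hypothetical helper name:

import torch

def arange_like(x, n):
    # Hypothetical helper: inherit the device from an input tensor so the
    # same graph traces correctly under both CPU and CUDA test variants.
    return torch.arange(n, dtype=torch.int64, device=x.device)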
@@ -6507,9 +6495,9 @@ def fn(a):
             return a[out_features.index(in_feature)]
 
         x = [
-            torch.rand([1, 256, 100, 152]),
-            torch.rand([1, 256, 50, 76]),
-            torch.rand([1, 256, 25, 38]),
+            torch.rand([1, 256, 100, 152], device=self.device),
+            torch.rand([1, 256, 50, 76], device=self.device),
+            torch.rand([1, 256, 25, 38], device=self.device),
         ]
         opt_fn = torch._dynamo.optimize("inductor")(fn)
         same(fn(x), opt_fn(x))
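
[Editor's sketch] `self.device` works here because the suite is written once against a device-generic template and then stamped out per backend (the real file uses a `CommonTemplate` class and a `copy_tests` helper; the sketch below simplifies that with plain inheritance and hypothetical class names):

import torch
import unittest

class CommonTemplateSketch:
    # Written once; reads the target backend from the subclass.
    def test_mul_two(self):
        x = torch.rand(4, device=self.device)
        compiled = torch.compile(lambda t: t * 2, backend="inductor")
        torch.testing.assert_close(compiled(x), x * 2)

class CpuTestsSketch(unittest.TestCase, CommonTemplateSketch):
    device = "cpu"

class CudaTestsSketch(unittest.TestCase, CommonTemplateSketch):
    device = "cuda"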
@@ -6521,8 +6509,7 @@ def fn(a):
             return y
 
         x = torch.rand(48, 3, 512, 512)
-        opt_fn = torch._dynamo.optimize("inductor")(fn)
-        same(fn(x), opt_fn(x))
+        self.common(fn, (x,))
 
     @unittest.skipIf(not HAS_CPU, "requires C++ compiler")
     def test_data_type_propogation(self):
@@ -6636,6 +6623,10 @@ def func(arg0_1):
             elif node.target == "output":
                 self.assertEqual(get_data_type(node), torch.bfloat16)
 
+    # Calling div on only torch.SymInt arguments is not yet supported.
+    # To support this behavior, we need to allow const-propping tensors that store symint data.
+    # For now, dynamo will explicitly graph break when it encounters user code with this behavior.
+    @expectedFailureCodegenDynamic
     def test_AllenaiLongformerBase_repro(self):
         def fn(query, scores, window_overlap):
             batch_size, seq_len, num_heads, _ = query.size()
@@ -6661,12 +6652,12 @@ def fn(query, scores, window_overlap):
             return input_tensor
 
         args = [
-            ((4, 1024, 12, 64), (768, 3072, 64, 1), torch.float32, "cpu"),
-            ((48, 3, 512, 513), (787968, 262656, 513, 1), torch.float32, "cpu"),
+            ((4, 1024, 12, 64), (768, 3072, 64, 1)),
+            ((48, 3, 512, 513), (787968, 262656, 513, 1)),
         ]
-        args = [rand_strided(sh, st, dt, dev) for (sh, st, dt, dev) in args]
-        opt_fn = torch._dynamo.optimize("inductor")(fn)
-        same(fn(*args, 256), opt_fn(*args, 256))
+        args = [rand_strided(sh, st) for (sh, st) in args]
+        args.append(256)
+        self.common(fn, args)
 
     def test_cumsum_pattern_matcher_issue(self):
         def fn(input_ids) -> torch.Tensor:
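
[Editor's sketch] `rand_strided` (from `torch._dynamo.testing`) builds a random tensor over an explicit size/stride layout; dropping the hard-coded dtype and `"cpu"` device above lets `self.common` place the inputs on whichever backend is under test. A minimal stand-in, assuming non-overlapping strides and a floating-point dtype (the real helper handles more cases):

import torch

def rand_strided_sketch(size, stride, dtype=torch.float32, device="cpu"):
    # Allocate storage with the requested layout, then fill it in place.
    t = torch.empty_strided(size, stride, dtype=dtype, device=device)
    return t.normal_()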
@@ -6675,25 +6666,23 @@ def fn(input_ids) -> torch.Tensor:
             batch_size, seq_length = input_shape
             past_key_values_length = 0
             mask_seq_length = past_key_values_length + seq_length
-            attention_mask = torch.ones(batch_size, mask_seq_length)
+            attention_mask = torch.ones(
+                batch_size, mask_seq_length, device=input_ids.device
+            )
             attention_mask = attention_mask.long()
             return torch.cumsum(attention_mask, dim=1)
 
-        torch._dynamo.reset()
         x = torch.randn(2, 2)
-        opt = torch._dynamo.optimize("inductor")(fn)
-        res = opt(x)
-        ref = fn(x)
-        self.assertEqual(res, ref, atol=0, rtol=0)
+        self.common(fn, (x,), atol=0, rtol=0)
 
+    # It's a view, so it doesn't generate a kernel
+    @expectedFailureCodegenDynamic
     def test_slice(self):
         def fn(a, b):
             return torch.ops.aten.slice.Tensor(a, 0, 0, -b)
 
-        torch._dynamo.reset()
         x = torch.rand(48, 3, 512, 512)
-        opt_fn = torch._dynamo.optimize("inductor")(fn)
-        same(fn(x, 2), opt_fn(x, 2))
+        self.common(fn, (x, 2))
 
     def test_inplace_resize_as(self):
         def fn(x, y):

test/inductor/test_torchinductor_dynamic_shapes.py

+2
@@ -46,6 +46,8 @@
 # xfail by default, set is_skip=True to skip
 test_failures = {
     "test_kwargs_dynamic_shapes": TestFailure(("cpu",)),
+    # calling div on only symint args
+    "test_AllenaiLongformerBase_repro_dynamic_shapes": TestFailure(("cpu", "cuda")),
 }
 
 if TEST_WITH_ROCM:
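
[Editor's sketch] The `test_failures` map marks the dynamic-shapes variant as an expected failure (xfail) rather than a skip, so it starts signalling as soon as the symint-div gap is fixed. Roughly how such a map gets consumed when the tests are re-instantiated per device; the helper name is hypothetical, and the `suffixes`/`is_skip` attributes of `TestFailure` are assumptions of this sketch:

import unittest

def apply_test_failures(cls, test_failures, device):
    # Wrap each listed test so it xfails (or skips) for a matching device.
    for name, failure in test_failures.items():
        if device not in failure.suffixes:
            continue
        fn = getattr(cls, name, None)
        if fn is None:
            continue
        if failure.is_skip:
            setattr(cls, name, unittest.skip("known failure")(fn))
        else:
            setattr(cls, name, unittest.expectedFailure(fn))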
