
Commit 2a6ef9b

lezcano authored and pytorchmergebot committed
[dynamo] Avoid recompilation when the PyTorch function accepts scalars (pytorch#108162)
Before, it would create a 0D tensor from the scalar input, which would incur a guard and a specialisation. It's not clear whether guarding and specialising is the right behaviour when we create 0D tensors, but that's a story for another day.

Pull Request resolved: pytorch#108162
Approved by: https://github.com/ev-br, https://github.com/peterbell10
1 parent 591cb77 commit 2a6ef9b
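
For context, a minimal sketch of the behaviour this commit targets. It mirrors the new test added below; the scalar values 3 and 4 are purely illustrative:

import numpy as np
import torch
import torch._dynamo.testing

def fn(x, a):
    # `a` is a Python scalar; previously it was wrapped into a 0D tensor,
    # whose value dynamo then guarded on and specialised.
    return np.where(x < 0.5, a, x)

cnts = torch._dynamo.testing.CompileCounter()
opt_fn = torch._dynamo.optimize(cnts, dynamic=True)(fn)

x = np.random.randn(8)
opt_fn(x, 3)
opt_fn(x, 4)  # a different scalar value
print(cnts.frame_count)  # 1 after this change; a recompile would push it to 2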

File tree

6 files changed: +62 -18 lines changed

test/dynamo/test_misc.py

+18

@@ -1366,6 +1366,24 @@ def fn(x, y):
         self.assertEqual(ref, res)
         self.assertEqual(cnts.frame_count, 2)
 
+    def test_numpy_recompilation_scalar(self):
+        def fn(x, a):
+            return np.where(x < 0.5, a, x)
+
+        x = np.random.randn(8)
+        cnts = torch._dynamo.testing.CompileCounter()
+        opt_fn = torch._dynamo.optimize(cnts, dynamic=True)(fn)
+
+        ref = fn(x, 3)
+        res = opt_fn(x, 3)
+        self.assertEqual(ref, res)
+
+        ref = fn(x, 4)
+        res = opt_fn(x, 4)
+        self.assertEqual(ref, res)
+
+        self.assertEqual(cnts.frame_count, 1)
+
     def test_tensor_interacts_with_numpy_ndarray(self):
         def fn(x, y):
             a = x.numpy()

torch/_numpy/_dtypes_impl.py

+22 -1

@@ -80,31 +80,52 @@ def python_type_for_torch(dtyp):
 
 # ### NEP 50 helpers ###
 
-SCALAR_TYPES = {int, bool, float, complex}
+_SCALAR_TYPES = (int, bool, float, complex)
+
+_SCALAR_AND_SYMBOLIC_TYPES = (
+    *_SCALAR_TYPES,
+    torch.SymInt,
+    torch.SymFloat,
+    torch.SymBool,
+)
+
+
+def is_scalar(x):
+    return isinstance(x, _SCALAR_TYPES)
+
+
+def is_scalar_or_symbolic(x):
+    return isinstance(x, _SCALAR_AND_SYMBOLIC_TYPES)
 
 
 def _dtype_for_scalar(py_type):
     return {
         bool: torch.bool,
+        torch.SymBool: torch.bool,
         int: torch.int64,
+        torch.SymInt: torch.int64,
         float: torch.float64,
+        torch.SymFloat: torch.float64,
         complex: torch.complex128,
     }[py_type]
 
 
 def _category(dtype):
     return {
         torch.bool: 0,
+        torch.SymBool: 0,
         # int
         torch.uint8: 1,
         torch.int8: 1,
         torch.int16: 1,
         torch.int32: 1,
         torch.int64: 1,
+        torch.SymInt: 1,
         # float
         torch.float16: 2,
         torch.float32: 2,
         torch.float64: 2,
+        torch.SymFloat: 2,
         # complex
         torch.complex64: 3,
         torch.complex128: 3,
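
A quick illustration of what the new helpers accept (a sketch only; torch._numpy is a private module, so the import path is shown purely for demonstration):

import torch
from torch._numpy import _dtypes_impl

_dtypes_impl.is_scalar(3)                  # True: plain Python scalars
_dtypes_impl.is_scalar(torch.tensor(3))    # False: tensors are handled elsewhere
_dtypes_impl.is_scalar_or_symbolic(3)      # True; SymInt/SymFloat/SymBool values
                                           # produced under dynamo also qualify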

torch/_numpy/_funcs_impl.py

+6 -8

@@ -17,6 +17,7 @@
 from . import _dtypes_impl, _util
 from ._normalizations import (
     ArrayLike,
+    ArrayLikeOrScalar,
     CastingModes,
     DTypeLike,
     NDArray,
@@ -626,8 +627,8 @@ def bincount(x: ArrayLike, /, weights: Optional[ArrayLike] = None, minlength=0):
 
 def where(
     condition: ArrayLike,
-    x: Optional[ArrayLike] = None,
-    y: Optional[ArrayLike] = None,
+    x: Optional[ArrayLikeOrScalar] = None,
+    y: Optional[ArrayLikeOrScalar] = None,
     /,
 ):
     if (x is None) != (y is None):
@@ -984,8 +985,7 @@ def clip(
     return torch.clamp(a, min, max)
 
 
-def repeat(a: ArrayLike, repeats: ArrayLike, axis=None):
-    # XXX: scalar repeats; ArrayLikeOrScalar ?
+def repeat(a: ArrayLike, repeats: ArrayLikeOrScalar, axis=None):
     return torch.repeat_interleave(a, repeats, axis)
 
 
@@ -1553,9 +1553,7 @@ def gradient(f: ArrayLike, *varargs, axis=None, edge_order=1):
     if n == 0:
         # no spacing argument - use 1 in all axes
         dx = [1.0] * len_axes
-    elif n == 1 and (
-        type(varargs[0]) in _dtypes_impl.SCALAR_TYPES or varargs[0].ndim == 0
-    ):
+    elif n == 1 and (_dtypes_impl.is_scalar(varargs[0]) or varargs[0].ndim == 0):
         # single scalar or 0D tensor for all axes (np.ndim(varargs[0]) == 0)
         dx = varargs * len_axes
     elif n == len_axes:
@@ -1616,7 +1614,7 @@ def gradient(f: ArrayLike, *varargs, axis=None, edge_order=1):
         out = torch.empty_like(f, dtype=otype)
 
         # spacing for the current axis (NB: np.ndim(ax_dx) == 0)
-        uniform_spacing = type(ax_dx) in _dtypes_impl.SCALAR_TYPES or ax_dx.ndim == 0
+        uniform_spacing = _dtypes_impl.is_scalar(ax_dx) or ax_dx.ndim == 0
 
         # Numerical differentiation: 2nd order interior
         slice1[axis] = slice(1, -1)
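
The widened signatures mean calls like the following hit the scalar paths directly instead of first building a 0D array. A sketch using plain NumPy; under torch.compile the same calls are routed through torch._numpy:

import numpy as np

x = np.linspace(0.0, 1.0, 8)
np.where(x < 0.5, 3, x)   # scalar branch value (Optional[ArrayLikeOrScalar])
np.repeat(x, 2)           # scalar repeats (ArrayLikeOrScalar)
np.gradient(x, 0.5)       # scalar spacing, checked via the is_scalar() helper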

torch/_numpy/_ndarray.py

+1 -1

@@ -452,7 +452,7 @@ def __setitem__(self, index, value):
         index = _util.ndarrays_to_tensors(index)
         index = _upcast_int_indices(index)
 
-        if type(value) not in _dtypes_impl.SCALAR_TYPES:
+        if not _dtypes_impl.is_scalar(value):
             value = normalize_array_like(value)
         value = _util.cast_if_needed(value, self.tensor.dtype)
 

torch/_numpy/_normalizations.py

+9 -2

@@ -47,11 +47,17 @@ def normalize_array_like(x, parm=None):
 
 
 def normalize_array_like_or_scalar(x, parm=None):
-    if type(x) in _dtypes_impl.SCALAR_TYPES:
+    if _dtypes_impl.is_scalar_or_symbolic(x):
         return x
     return normalize_array_like(x, parm)
 
 
+def normalize_optional_array_like_or_scalar(x, parm=None):
+    if x is None:
+        return None
+    return normalize_array_like_or_scalar(x, parm)
+
+
 def normalize_optional_array_like(x, parm=None):
     # This explicit normalizer is needed because otherwise normalize_array_like
     # does not run for a parameter annotated as Optional[ArrayLike]
@@ -118,9 +124,10 @@ def normalize_casting(arg, parm=None):
 
 normalizers = {
     "ArrayLike": normalize_array_like,
-    "Union[ArrayLike, Scalar]": normalize_array_like_or_scalar,
+    "ArrayLikeOrScalar": normalize_array_like_or_scalar,
     "Optional[ArrayLike]": normalize_optional_array_like,
     "Sequence[ArrayLike]": normalize_seq_array_like,
+    "Optional[ArrayLikeOrScalar]": normalize_optional_array_like_or_scalar,
     "Optional[NDArray]": normalize_ndarray,
     "Optional[OutArray]": normalize_outarray,
     "NDArray": normalize_ndarray,

torch/_numpy/_ufuncs.py

+6 -6

@@ -1,18 +1,18 @@
 from __future__ import annotations
 
-from typing import Optional, Union
+from typing import Optional
 
 import torch
 
 from . import _binary_ufuncs_impl, _dtypes_impl, _unary_ufuncs_impl, _util
 from ._normalizations import (
     ArrayLike,
+    ArrayLikeOrScalar,
     CastingModes,
     DTypeLike,
     normalizer,
     NotImplementedType,
     OutArray,
-    Scalar,
 )
 
 
@@ -71,8 +71,8 @@ def deco_binary_ufunc(torch_func):
 
     @normalizer
     def wrapped(
-        x1: Union[ArrayLike, Scalar],
-        x2: Union[ArrayLike, Scalar],
+        x1: ArrayLikeOrScalar,
+        x2: ArrayLikeOrScalar,
         /,
         out: Optional[OutArray] = None,
         *,
@@ -145,8 +145,8 @@ def matmul(
 # ldexp casting is special : the dtype of the result == dtype of the 1st arg
 @normalizer
 def ldexp(
-    x1: Union[ArrayLike, Scalar],
-    x2: Union[ArrayLike, Scalar],
+    x1: ArrayLikeOrScalar,
+    x2: ArrayLikeOrScalar,
     /,
     out: Optional[OutArray] = None,
     *,
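
With x1/x2 typed as ArrayLikeOrScalar, a Python-scalar operand reaches a binary ufunc as-is and its dtype is derived with _dtype_for_scalar only when needed. For example, a sketch using the private torch._numpy layer directly:

import torch._numpy as tnp

a = tnp.asarray([1.0, 2.0, 3.0])
tnp.add(a, 3)        # scalar x2: no 0D-tensor wrapping, so no value guard
tnp.multiply(2, a)   # scalar x1 works the same way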

0 commit comments