forked from NVIDIA/Fuser
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathpytest_utils.py
155 lines (123 loc) · 3.99 KB
/
pytest_utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
# SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES.
# All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
# Owner(s): ["module: nvfuser"]
import torch
import jax.numpy as jnp
from torch.testing import make_tensor
from typing import Optional
from enum import Enum, auto
class ArgumentType(Enum):
    """How an argument is supplied to a fusion under test.

    Members:
        Symbolic: a symbolic value; it requires an input argument during
            kernel execution.
        ConstantScalar: a scalar with a constant value.
        Constant: a plain python number (int, float, complex, bool).
    """

    Symbolic = auto()
    ConstantScalar = auto()
    Constant = auto()
# Base dtype families.
bool_dtypes = (torch.bool,)
int_dtypes = (torch.int32, torch.int64)
half_precision_float_dtypes = (torch.bfloat16, torch.float16)
full_precision_float_dtypes = (torch.float32, torch.float64)
complex_dtypes = (torch.complex64, torch.complex128)

# Combinations that keep every member of their families.
bool_int_dtypes = bool_dtypes + int_dtypes
float_dtypes = half_precision_float_dtypes + full_precision_float_dtypes

# Half-precision float dtypes bf16, fp16 are skipped in the groups below
# because nvfuser upcasts those dtypes to fp32 but does not return the
# original type.
int_float_dtypes = int_dtypes + full_precision_float_dtypes
float_complex_dtypes = full_precision_float_dtypes + complex_dtypes
all_dtypes_except_reduced = int_float_dtypes + complex_dtypes

# Supersets that re-append the half-precision floats and then bool.
all_dtypes_except_bool = all_dtypes_except_reduced + half_precision_float_dtypes
all_dtypes = all_dtypes_except_bool + bool_dtypes
# Map each supported torch dtype to its short name. The names below are both
# the torch attribute names and the desired string values, so one ordered
# tuple drives the whole table.
map_dtype_to_str = {
    getattr(torch, dtype_name): dtype_name
    for dtype_name in (
        "bool",
        "uint8",
        "int8",
        "int16",
        "int32",
        "int64",
        "bfloat16",
        "float16",
        "float32",
        "float64",
        "complex64",
        "complex128",
    )
}
# Map torch dtypes to the equivalent jax.numpy scalar types. torch and
# jax.numpy share dtype names except for bool, which jax.numpy spells
# ``bool_``.
torch_to_jax_dtype_map = {
    getattr(torch, dtype_name): getattr(jnp, "bool_" if dtype_name == "bool" else dtype_name)
    for dtype_name in (
        "bool",
        "uint8",
        "int8",
        "int16",
        "int32",
        "int64",
        "bfloat16",
        "float16",
        "float32",
        "float64",
        "complex64",
        "complex128",
    )
}
# Map torch dtypes to the builtin python type that ``Tensor.item()``-style
# scalars of that dtype belong to, grouped by category.
torch_to_python_dtype_map = {
    torch.bool: bool,
    **dict.fromkeys(
        (torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64), int
    ),
    **dict.fromkeys(
        (torch.bfloat16, torch.float16, torch.float32, torch.float64), float
    ),
    **dict.fromkeys((torch.complex64, torch.complex128), complex),
}
def make_tensor_like(a: torch.Tensor) -> torch.Tensor:
    """Return a new random tensor matching ``a``'s shape, device, dtype and grad flag.

    Values are sampled by ``make_tensor`` (``torch.testing.make_tensor``).
    Only shape, device, dtype and ``requires_grad`` are taken from ``a``;
    nothing else (e.g. strides or memory format) is forwarded.

    Args:
        a (torch.Tensor): The tensor to copy properties from.

    Returns:
        torch.Tensor: A freshly sampled tensor with the same shape, device,
        dtype and ``requires_grad`` flag as :attr:`a`.
    """
    # Use the module-level ``make_tensor`` import, matching make_number.
    return make_tensor(
        a.shape, device=a.device, dtype=a.dtype, requires_grad=a.requires_grad
    )
def make_number(
    dtype: torch.dtype, low: Optional[float] = None, high: Optional[float] = None
):
    """Draw a single random python scalar of the given dtype.

    Args:
        dtype (torch.dtype): Dtype the value is sampled as.
        low (Optional[Number]): Lower limit of the range (inclusive).
        high (Optional[Number]): Upper limit of the range (exclusive).

    Returns:
        (Scalar): The sampled value converted to a python scalar.
    """
    # Reuse make_tensor's dtype-aware sampling by drawing a one-element CPU
    # tensor, then unpack its only element as a python scalar.
    (value,) = make_tensor(
        (1,), dtype=dtype, device="cpu", low=low, high=high
    ).tolist()
    return value
def find_nonmatching_dtype(dtype: torch.dtype):
    """Return a dtype from a different category than ``dtype``, or None.

    - bool maps to float32;
    - a dtype in ``complex_dtypes`` maps to float64;
    - a dtype in ``int_float_dtypes`` maps to complex128;
    - anything else (for example bfloat16 or float16, which are not in
      ``int_float_dtypes``) yields None.
    """
    # The three categories are disjoint, so the check order does not matter.
    if dtype is torch.bool:
        return torch.float32
    if dtype in complex_dtypes:
        return torch.float64
    if dtype in int_float_dtypes:
        return torch.complex128
    return None
def is_complex_dtype(dtype: torch.dtype) -> bool:
    """Return True when ``dtype`` is listed in ``complex_dtypes``."""
    return dtype in complex_dtypes
def is_floating_dtype(dtype: torch.dtype) -> bool:
    """Return True when ``dtype`` is listed in ``float_dtypes``.

    ``float_dtypes`` covers both the half-precision and the full-precision
    float families.
    """
    return dtype in float_dtypes
def is_integer_dtype(dtype: torch.dtype) -> bool:
    """Return True when ``dtype`` is listed in ``int_dtypes``.

    ``torch.bool`` is not in ``int_dtypes`` and is therefore not reported as
    an integer dtype.
    """
    # NOTE(review): uint8/int8/int16 appear in the dtype maps of this module
    # but not in int_dtypes, so this returns False for them -- confirm that
    # is intended.
    return dtype in int_dtypes
def is_tensor(a) -> bool:
    """Return True if ``a`` is a ``torch.Tensor`` (including subclasses)."""
    return isinstance(a, torch.Tensor)