Skip to content

Commit

Permalink
[Inductor] Fix test_conv2d_unary_cpu_cpp_wrapper failure (pytorch#137158
Browse files Browse the repository at this point in the history
)

Summary: test_conv2d_unary_cpu_cpp_wrapper is failing on ciflow/slow because of mis-handling of inf. This PR fixes that.

Pull Request resolved: pytorch#137158
Approved by: https://github.com/chenyang78
  • Loading branch information
desertfire authored and pytorchmergebot committed Oct 2, 2024
1 parent d117ec1 commit 5c2c3ca
Showing 1 changed file with 15 additions and 7 deletions.
22 changes: 15 additions & 7 deletions torch/_inductor/codegen/cpp_wrapper_cpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -2182,6 +2182,17 @@ def load_custom_op_wrapper(self):

self.custom_op_wrapper_loaded = True

def generate_float_value(self, val):
assert isinstance(val, float)
if val == float("inf"):
return "std::numeric_limits<float>::infinity()"
elif val == float("-inf"):
return "-std::numeric_limits<float>::infinity()"
elif val == float("nan"):
return "std::numeric_limits<float>::quiet_NaN()"
else:
return f"{val}"

def generate_py_arg(self, py_args_var, idx, raw_arg, arg_type):
def generate_py_arg_inner(lines, raw_arg, arg_type):
if raw_arg is None:
Expand Down Expand Up @@ -2211,7 +2222,7 @@ def generate_py_arg_inner(lines, raw_arg, arg_type):
)
return f"PyLong_FromLongLong({self.expr_printer(expr)})"
elif isinstance(arg_type, torch.FloatType):
return f"PyFloat_FromDouble({raw_arg})"
return f"PyFloat_FromDouble({self.generate_float_value(raw_arg)})"
elif isinstance(arg_type, torch.BoolType):
return f"PyBool_FromLong({1 if raw_arg else 0})"
elif isinstance(arg_type, torch.StringType):
Expand All @@ -2222,7 +2233,7 @@ def generate_py_arg_inner(lines, raw_arg, arg_type):
if isinstance(raw_arg, int):
return f"PyLong_FromLongLong({raw_arg})"
elif isinstance(raw_arg, float):
return f"PyFloat_FromDouble({raw_arg})"
return f"PyFloat_FromDouble({self.generate_float_value(raw_arg)})"
elif isinstance(raw_arg, bool):
return f"PyBool_FromLong({1 if raw_arg else 0})"
elif isinstance(raw_arg, complex):
Expand Down Expand Up @@ -2441,11 +2452,8 @@ def val_to_arg_str_for_prim_type(self, val, type_) -> str:
return self.codegen_device(val)
elif isinstance(val, torch.dtype):
return self.codegen_dtype(val)
elif isinstance(val, float) and val in [float("inf"), float("-inf")]:
if val == float("inf"):
return "std::numeric_limits<float>::infinity()"
else:
return "-std::numeric_limits<float>::infinity()"
elif isinstance(val, float):
return self.generate_float_value(val)
elif isinstance(val, (list, tuple)):
# FIXME: This happens because type_ is not always properly set to torch.ListType
return f"{{{', '.join(self.val_to_arg_str(x, None) for x in val)}}}"
Expand Down

0 comments on commit 5c2c3ca

Please sign in to comment.