forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
(2/2) Make TorchScript Preserve Fully Qualified Class Name for Python…
… Exceptions: frontend change (pytorch#72899) Summary: Pull Request resolved: pytorch#72899 Reland D33282878 (pytorch@911d527). This is the frontend change. ghstack-source-id: 149204031 Test Plan: Refer to D33282878 (pytorch@911d527). Also check CI Reviewed By: gmagogsfm Differential Revision: D34252127 fbshipit-source-id: 27b17ddd4d05d904eb91fd9ee094d9121f00e388 (cherry picked from commit 1d276ba)
- Loading branch information
1 parent
f8a2efc
commit 763ad1b
Showing
10 changed files
with
388 additions
and
159 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,159 @@ | ||
/* | ||
* We have a python unit test for exceptions in test/jit/test_exception.py . | ||
* Add a CPP version here to verify that excepted exception types thrown from | ||
* C++. This is hard to test in python code since C++ exceptions will be | ||
* translated to python exceptions. | ||
*/ | ||
#include <gtest/gtest.h> | ||
#include <pybind11/embed.h> | ||
#include <torch/csrc/jit/frontend/parser.h> | ||
#include <torch/csrc/jit/frontend/resolver.h> | ||
#include <torch/csrc/jit/runtime/jit_exception.h> | ||
#include <torch/jit.h> | ||
#include <iostream> | ||
#include <stdexcept> | ||
|
||
namespace torch { | ||
namespace jit { | ||
|
||
namespace py = pybind11; | ||
|
||
TEST(TestException, TestAssertion) { | ||
std::string pythonCode = R"PY( | ||
def foo(): | ||
raise AssertionError("An assertion failed") | ||
)PY"; | ||
auto cu_ptr = torch::jit::compile(pythonCode); | ||
torch::jit::GraphFunction* gf = | ||
(torch::jit::GraphFunction*)&cu_ptr->get_function("foo"); | ||
std::cerr << "Graph is\n" << *gf->graph() << std::endl; | ||
|
||
bool is_jit_exception = false; | ||
std::string message; | ||
c10::optional<std::string> exception_class; | ||
try { | ||
cu_ptr->run_method("foo"); | ||
} catch (JITException& e) { | ||
is_jit_exception = true; | ||
message = e.what(); | ||
exception_class = e.getPythonClassName(); | ||
} | ||
EXPECT_TRUE(is_jit_exception); | ||
EXPECT_FALSE(exception_class); | ||
EXPECT_TRUE( | ||
message.find("RuntimeError: AssertionError: An assertion failed") != | ||
std::string::npos); | ||
} | ||
|
||
struct MyPythonExceptionValue : public torch::jit::SugaredValue { | ||
explicit MyPythonExceptionValue(const py::object& exception_class) { | ||
qualified_name_ = | ||
(py::str(py::getattr(exception_class, "__module__", py::str(""))) + | ||
py::str(".") + | ||
py::str(py::getattr(exception_class, "__name__", py::str("")))) | ||
.cast<std::string>(); | ||
} | ||
|
||
std::string kind() const override { | ||
return "My Python exception"; | ||
} | ||
|
||
// Simplified from PythonExceptionValue::call | ||
std::shared_ptr<torch::jit::SugaredValue> call( | ||
const torch::jit::SourceRange& loc, | ||
torch::jit::GraphFunction& caller, | ||
at::ArrayRef<torch::jit::NamedValue> args, | ||
at::ArrayRef<torch::jit::NamedValue> kwargs, | ||
size_t n_binders) override { | ||
TORCH_CHECK(args.size() == 1); | ||
Value* error_message = args.at(0).value(*caller.graph()); | ||
Value* qualified_class_name = | ||
insertConstant(*caller.graph(), qualified_name_, loc); | ||
return std::make_shared<ExceptionMessageValue>( | ||
error_message, qualified_class_name); | ||
} | ||
|
||
private: | ||
std::string qualified_name_; | ||
}; | ||
|
||
class SimpleResolver : public torch::jit::Resolver { | ||
public: | ||
explicit SimpleResolver() {} | ||
|
||
std::shared_ptr<torch::jit::SugaredValue> resolveValue( | ||
const std::string& name, | ||
torch::jit::GraphFunction& m, | ||
const torch::jit::SourceRange& loc) override { | ||
// follows toSugaredValue (toSugaredValue is defined in caffe2:_C which is | ||
// a python extension. We can not add that as a cpp_binary's dep) | ||
if (name == "SimpleValueError") { | ||
py::object obj = py::globals()["SimpleValueError"]; | ||
return std::make_shared<MyPythonExceptionValue>(obj); | ||
} | ||
TORCH_CHECK(false, "resolveValue: can not resolve '", name, "{}'"); | ||
} | ||
|
||
torch::jit::TypePtr resolveType( | ||
const std::string& name, | ||
const torch::jit::SourceRange& loc) override { | ||
return nullptr; | ||
} | ||
}; | ||
|
||
/* | ||
* - The python source code parsing for TorchScript here is learned from | ||
* torch::jit::compile. | ||
* - The code only parses one Def. If there are multiple in the code, those | ||
* except the first one are skipped. | ||
*/ | ||
TEST(TestException, TestCustomException) { | ||
py::scoped_interpreter guard{}; | ||
py::exec(R"PY( | ||
class SimpleValueError(ValueError): | ||
def __init__(self, message): | ||
super(SimpleValueError, self).__init__(message) | ||
)PY"); | ||
|
||
std::string pythonCode = R"PY( | ||
def foo(): | ||
raise SimpleValueError("An assertion failed") | ||
)PY"; | ||
|
||
torch::jit::Parser p( | ||
std::make_shared<torch::jit::Source>(pythonCode, "<string>", 1)); | ||
auto def = torch::jit::Def(p.parseFunction(/*is_method=*/false)); | ||
std::cerr << "Def is:\n" << def << std::endl; | ||
auto cu = std::make_shared<torch::jit::CompilationUnit>(); | ||
(void)cu->define( | ||
c10::nullopt, | ||
{}, | ||
{}, | ||
{def}, | ||
// class PythonResolver is defined in | ||
// torch/csrc/jit/python/script_init.cpp. It's not in a header file so I | ||
// can not use it. Create a SimpleResolver insteand | ||
{std::make_shared<SimpleResolver>()}, | ||
nullptr); | ||
torch::jit::GraphFunction* gf = | ||
(torch::jit::GraphFunction*)&cu->get_function("foo"); | ||
std::cerr << "Graph is\n" << *gf->graph() << std::endl; | ||
bool is_jit_exception = false; | ||
c10::optional<std::string> exception_class; | ||
std::string message; | ||
try { | ||
cu->run_method("foo"); | ||
} catch (JITException& e) { | ||
is_jit_exception = true; | ||
exception_class = e.getPythonClassName(); | ||
message = e.what(); | ||
} | ||
EXPECT_TRUE(is_jit_exception); | ||
EXPECT_EQ("__main__.SimpleValueError", *exception_class); | ||
EXPECT_TRUE( | ||
message.find("__main__.SimpleValueError: An assertion failed") != | ||
std::string::npos); | ||
} | ||
|
||
} // namespace jit | ||
} // namespace torch |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,8 @@ | ||
r""" | ||
Define exceptions used in test_exception.py. We define them in a | ||
separate file on purpose to make sure the fully qualified exception class name | ||
is captured correctly in suce cases. | ||
""" | ||
class MyKeyError(KeyError): | ||
def __init__(self, msg): | ||
super(KeyError, self).__init__(msg) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,176 @@ | ||
# Owner(s): ["oncall: jit"] | ||
from torch.testing._internal.common_utils import TestCase | ||
import torch | ||
from torch import nn | ||
|
||
r""" | ||
Test TorchScript exception handling. | ||
""" | ||
class TestException(TestCase): | ||
def test_pyop_exception_message(self): | ||
class Foo(torch.jit.ScriptModule): | ||
def __init__(self): | ||
super(Foo, self).__init__() | ||
self.conv = nn.Conv2d(1, 10, kernel_size=5) | ||
|
||
@torch.jit.script_method | ||
def forward(self, x): | ||
return self.conv(x) | ||
foo = Foo() | ||
# testing that the correct error message propagates | ||
with self.assertRaisesRegex(RuntimeError, r"Expected 3D \(unbatched\) or 4D \(batched\) input to conv2d"): | ||
foo(torch.ones([123])) # wrong size | ||
|
||
def test_builtin_error_messsage(self): | ||
with self.assertRaisesRegex(RuntimeError, "Arguments for call are not valid"): | ||
@torch.jit.script | ||
def close_match(x): | ||
return x.masked_fill(True) | ||
|
||
with self.assertRaisesRegex(RuntimeError, "This op may not exist or may not be currently " | ||
"supported in TorchScript"): | ||
@torch.jit.script | ||
def unknown_op(x): | ||
torch.set_anomaly_enabled(True) | ||
return x | ||
|
||
def test_exceptions(self): | ||
cu = torch.jit.CompilationUnit(''' | ||
def foo(cond): | ||
if bool(cond): | ||
raise ValueError(3) | ||
return 1 | ||
''') | ||
|
||
cu.foo(torch.tensor(0)) | ||
with self.assertRaisesRegex(torch.jit.Error, "3"): | ||
cu.foo(torch.tensor(1)) | ||
|
||
def foo(cond): | ||
a = 3 | ||
if bool(cond): | ||
raise ArbitraryError(a, "hi") | ||
if 1 == 2: | ||
raise ArbitraryError | ||
return a | ||
|
||
with self.assertRaisesRegex(RuntimeError, "undefined value ArbitraryError"): | ||
torch.jit.script(foo) | ||
|
||
def exception_as_value(): | ||
a = Exception() | ||
print(a) | ||
|
||
with self.assertRaisesRegex(RuntimeError, "cannot be used as a value"): | ||
torch.jit.script(exception_as_value) | ||
|
||
@torch.jit.script | ||
def foo_no_decl_always_throws(): | ||
raise RuntimeError("Hi") | ||
|
||
# function that has no declared type but always throws set to None | ||
output_type = next(foo_no_decl_always_throws.graph.outputs()).type() | ||
self.assertTrue(str(output_type) == "NoneType") | ||
|
||
@torch.jit.script | ||
def foo_decl_always_throws(): | ||
# type: () -> Tensor | ||
raise Exception("Hi") | ||
|
||
output_type = next(foo_decl_always_throws.graph.outputs()).type() | ||
self.assertTrue(str(output_type) == "Tensor") | ||
|
||
def foo(): | ||
raise 3 + 4 | ||
|
||
with self.assertRaisesRegex(RuntimeError, "must derive from BaseException"): | ||
torch.jit.script(foo) | ||
|
||
# a escapes scope | ||
@torch.jit.script | ||
def foo(): | ||
if 1 == 1: | ||
a = 1 | ||
else: | ||
if 1 == 1: | ||
raise Exception("Hi") | ||
else: | ||
raise Exception("Hi") | ||
return a | ||
self.assertEqual(foo(), 1) | ||
|
||
@torch.jit.script | ||
def tuple_fn(): | ||
raise RuntimeError("hello", "goodbye") | ||
|
||
with self.assertRaisesRegex(torch.jit.Error, "hello, goodbye"): | ||
tuple_fn() | ||
|
||
@torch.jit.script | ||
def no_message(): | ||
raise RuntimeError | ||
|
||
with self.assertRaisesRegex(torch.jit.Error, "RuntimeError"): | ||
no_message() | ||
|
||
def test_assertions(self): | ||
cu = torch.jit.CompilationUnit(''' | ||
def foo(cond): | ||
assert bool(cond), "hi" | ||
return 0 | ||
''') | ||
|
||
cu.foo(torch.tensor(1)) | ||
with self.assertRaisesRegex(torch.jit.Error, "AssertionError: hi"): | ||
cu.foo(torch.tensor(0)) | ||
|
||
@torch.jit.script | ||
def foo(cond): | ||
assert bool(cond), "hi" | ||
|
||
foo(torch.tensor(1)) | ||
# we don't currently validate the name of the exception | ||
with self.assertRaisesRegex(torch.jit.Error, "AssertionError: hi"): | ||
foo(torch.tensor(0)) | ||
|
||
def test_python_op_exception(self): | ||
@torch.jit.ignore | ||
def python_op(x): | ||
raise Exception("bad!") | ||
|
||
@torch.jit.script | ||
def fn(x): | ||
return python_op(x) | ||
|
||
with self.assertRaisesRegex(RuntimeError, "operation failed in the TorchScript interpreter"): | ||
fn(torch.tensor(4)) | ||
|
||
def test_dict_expansion_raises_error(self): | ||
def fn(self): | ||
d = {"foo": 1, "bar": 2, "baz": 3} | ||
return {**d} | ||
|
||
with self.assertRaisesRegex(torch.jit.frontend.NotSupportedError, | ||
"Dict expansion "): | ||
torch.jit.script(fn) | ||
|
||
def test_custom_python_exception(self): | ||
class MyValueError(ValueError): | ||
def __init__(self, msg): | ||
super(MyValueError, self).__init__(msg) | ||
|
||
@torch.jit.script | ||
def fn(): | ||
raise MyValueError("test custom exception") | ||
|
||
with self.assertRaisesRegex(torch.jit.Error, "jit.test_exception.MyValueError: test custom exception"): | ||
fn() | ||
|
||
def test_custom_python_exception_defined_elsewhere(self): | ||
from jit.myexception import MyKeyError | ||
|
||
@torch.jit.script | ||
def fn(): | ||
raise MyKeyError("This is a user defined key error") | ||
with self.assertRaisesRegex(torch.jit.Error, "jit.myexception.MyKeyError: This is a user defined key error"): | ||
fn() |
Oops, something went wrong.