Skip to content

Commit

Permalink
(2/2) Make TorchScript Preserve Fully Qualified Class Name for Python…
Browse files Browse the repository at this point in the history
… Exceptions: frontend change (pytorch#72899)

Summary:
Pull Request resolved: pytorch#72899

Reland D33282878 (pytorch@911d527). This is the frontend change.
ghstack-source-id: 149204031

Test Plan: Refer to D33282878 (pytorch@911d527). Also check CI

Reviewed By: gmagogsfm

Differential Revision: D34252127

fbshipit-source-id: 27b17ddd4d05d904eb91fd9ee094d9121f00e388
(cherry picked from commit 1d276ba)
  • Loading branch information
shunting314 authored and pytorchmergebot committed Feb 16, 2022
1 parent f8a2efc commit 763ad1b
Show file tree
Hide file tree
Showing 10 changed files with 388 additions and 159 deletions.
3 changes: 3 additions & 0 deletions BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -1880,6 +1880,9 @@ cc_test(
"test/cpp/jit/*.h",
"test/cpp/tensorexpr/*.cpp",
"test/cpp/tensorexpr/*.h",
], exclude=[
# skip this since <pybind11/embed.h> is not found in OSS build
"test/cpp/jit/test_exception.cpp",
]),
linkstatic = True,
tags = [
Expand Down
159 changes: 159 additions & 0 deletions test/cpp/jit/test_exception.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,159 @@
/*
* We have a python unit test for exceptions in test/jit/test_exception.py .
* Add a CPP version here to verify that excepted exception types thrown from
* C++. This is hard to test in python code since C++ exceptions will be
* translated to python exceptions.
*/
#include <gtest/gtest.h>
#include <pybind11/embed.h>
#include <torch/csrc/jit/frontend/parser.h>
#include <torch/csrc/jit/frontend/resolver.h>
#include <torch/csrc/jit/runtime/jit_exception.h>
#include <torch/jit.h>
#include <iostream>
#include <stdexcept>

namespace torch {
namespace jit {

namespace py = pybind11;

TEST(TestException, TestAssertion) {
std::string pythonCode = R"PY(
def foo():
raise AssertionError("An assertion failed")
)PY";
auto cu_ptr = torch::jit::compile(pythonCode);
torch::jit::GraphFunction* gf =
(torch::jit::GraphFunction*)&cu_ptr->get_function("foo");
std::cerr << "Graph is\n" << *gf->graph() << std::endl;

bool is_jit_exception = false;
std::string message;
c10::optional<std::string> exception_class;
try {
cu_ptr->run_method("foo");
} catch (JITException& e) {
is_jit_exception = true;
message = e.what();
exception_class = e.getPythonClassName();
}
EXPECT_TRUE(is_jit_exception);
EXPECT_FALSE(exception_class);
EXPECT_TRUE(
message.find("RuntimeError: AssertionError: An assertion failed") !=
std::string::npos);
}

struct MyPythonExceptionValue : public torch::jit::SugaredValue {
explicit MyPythonExceptionValue(const py::object& exception_class) {
qualified_name_ =
(py::str(py::getattr(exception_class, "__module__", py::str(""))) +
py::str(".") +
py::str(py::getattr(exception_class, "__name__", py::str(""))))
.cast<std::string>();
}

std::string kind() const override {
return "My Python exception";
}

// Simplified from PythonExceptionValue::call
std::shared_ptr<torch::jit::SugaredValue> call(
const torch::jit::SourceRange& loc,
torch::jit::GraphFunction& caller,
at::ArrayRef<torch::jit::NamedValue> args,
at::ArrayRef<torch::jit::NamedValue> kwargs,
size_t n_binders) override {
TORCH_CHECK(args.size() == 1);
Value* error_message = args.at(0).value(*caller.graph());
Value* qualified_class_name =
insertConstant(*caller.graph(), qualified_name_, loc);
return std::make_shared<ExceptionMessageValue>(
error_message, qualified_class_name);
}

private:
std::string qualified_name_;
};

class SimpleResolver : public torch::jit::Resolver {
public:
explicit SimpleResolver() {}

std::shared_ptr<torch::jit::SugaredValue> resolveValue(
const std::string& name,
torch::jit::GraphFunction& m,
const torch::jit::SourceRange& loc) override {
// follows toSugaredValue (toSugaredValue is defined in caffe2:_C which is
// a python extension. We can not add that as a cpp_binary's dep)
if (name == "SimpleValueError") {
py::object obj = py::globals()["SimpleValueError"];
return std::make_shared<MyPythonExceptionValue>(obj);
}
TORCH_CHECK(false, "resolveValue: can not resolve '", name, "{}'");
}

torch::jit::TypePtr resolveType(
const std::string& name,
const torch::jit::SourceRange& loc) override {
return nullptr;
}
};

/*
* - The python source code parsing for TorchScript here is learned from
* torch::jit::compile.
* - The code only parses one Def. If there are multiple in the code, those
* except the first one are skipped.
*/
TEST(TestException, TestCustomException) {
py::scoped_interpreter guard{};
py::exec(R"PY(
class SimpleValueError(ValueError):
def __init__(self, message):
super(SimpleValueError, self).__init__(message)
)PY");

std::string pythonCode = R"PY(
def foo():
raise SimpleValueError("An assertion failed")
)PY";

torch::jit::Parser p(
std::make_shared<torch::jit::Source>(pythonCode, "<string>", 1));
auto def = torch::jit::Def(p.parseFunction(/*is_method=*/false));
std::cerr << "Def is:\n" << def << std::endl;
auto cu = std::make_shared<torch::jit::CompilationUnit>();
(void)cu->define(
c10::nullopt,
{},
{},
{def},
// class PythonResolver is defined in
// torch/csrc/jit/python/script_init.cpp. It's not in a header file so I
// can not use it. Create a SimpleResolver insteand
{std::make_shared<SimpleResolver>()},
nullptr);
torch::jit::GraphFunction* gf =
(torch::jit::GraphFunction*)&cu->get_function("foo");
std::cerr << "Graph is\n" << *gf->graph() << std::endl;
bool is_jit_exception = false;
c10::optional<std::string> exception_class;
std::string message;
try {
cu->run_method("foo");
} catch (JITException& e) {
is_jit_exception = true;
exception_class = e.getPythonClassName();
message = e.what();
}
EXPECT_TRUE(is_jit_exception);
EXPECT_EQ("__main__.SimpleValueError", *exception_class);
EXPECT_TRUE(
message.find("__main__.SimpleValueError: An assertion failed") !=
std::string::npos);
}

} // namespace jit
} // namespace torch
8 changes: 8 additions & 0 deletions test/jit/myexception.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
r"""
Define exceptions used in test_exception.py. We define them in a
separate file on purpose to make sure the fully qualified exception class name
is captured correctly in suce cases.
"""
class MyKeyError(KeyError):
def __init__(self, msg):
super(KeyError, self).__init__(msg)
176 changes: 176 additions & 0 deletions test/jit/test_exception.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,176 @@
# Owner(s): ["oncall: jit"]
from torch.testing._internal.common_utils import TestCase
import torch
from torch import nn

r"""
Test TorchScript exception handling.
"""
class TestException(TestCase):
def test_pyop_exception_message(self):
class Foo(torch.jit.ScriptModule):
def __init__(self):
super(Foo, self).__init__()
self.conv = nn.Conv2d(1, 10, kernel_size=5)

@torch.jit.script_method
def forward(self, x):
return self.conv(x)
foo = Foo()
# testing that the correct error message propagates
with self.assertRaisesRegex(RuntimeError, r"Expected 3D \(unbatched\) or 4D \(batched\) input to conv2d"):
foo(torch.ones([123])) # wrong size

def test_builtin_error_messsage(self):
with self.assertRaisesRegex(RuntimeError, "Arguments for call are not valid"):
@torch.jit.script
def close_match(x):
return x.masked_fill(True)

with self.assertRaisesRegex(RuntimeError, "This op may not exist or may not be currently "
"supported in TorchScript"):
@torch.jit.script
def unknown_op(x):
torch.set_anomaly_enabled(True)
return x

def test_exceptions(self):
cu = torch.jit.CompilationUnit('''
def foo(cond):
if bool(cond):
raise ValueError(3)
return 1
''')

cu.foo(torch.tensor(0))
with self.assertRaisesRegex(torch.jit.Error, "3"):
cu.foo(torch.tensor(1))

def foo(cond):
a = 3
if bool(cond):
raise ArbitraryError(a, "hi")
if 1 == 2:
raise ArbitraryError
return a

with self.assertRaisesRegex(RuntimeError, "undefined value ArbitraryError"):
torch.jit.script(foo)

def exception_as_value():
a = Exception()
print(a)

with self.assertRaisesRegex(RuntimeError, "cannot be used as a value"):
torch.jit.script(exception_as_value)

@torch.jit.script
def foo_no_decl_always_throws():
raise RuntimeError("Hi")

# function that has no declared type but always throws set to None
output_type = next(foo_no_decl_always_throws.graph.outputs()).type()
self.assertTrue(str(output_type) == "NoneType")

@torch.jit.script
def foo_decl_always_throws():
# type: () -> Tensor
raise Exception("Hi")

output_type = next(foo_decl_always_throws.graph.outputs()).type()
self.assertTrue(str(output_type) == "Tensor")

def foo():
raise 3 + 4

with self.assertRaisesRegex(RuntimeError, "must derive from BaseException"):
torch.jit.script(foo)

# a escapes scope
@torch.jit.script
def foo():
if 1 == 1:
a = 1
else:
if 1 == 1:
raise Exception("Hi")
else:
raise Exception("Hi")
return a
self.assertEqual(foo(), 1)

@torch.jit.script
def tuple_fn():
raise RuntimeError("hello", "goodbye")

with self.assertRaisesRegex(torch.jit.Error, "hello, goodbye"):
tuple_fn()

@torch.jit.script
def no_message():
raise RuntimeError

with self.assertRaisesRegex(torch.jit.Error, "RuntimeError"):
no_message()

def test_assertions(self):
cu = torch.jit.CompilationUnit('''
def foo(cond):
assert bool(cond), "hi"
return 0
''')

cu.foo(torch.tensor(1))
with self.assertRaisesRegex(torch.jit.Error, "AssertionError: hi"):
cu.foo(torch.tensor(0))

@torch.jit.script
def foo(cond):
assert bool(cond), "hi"

foo(torch.tensor(1))
# we don't currently validate the name of the exception
with self.assertRaisesRegex(torch.jit.Error, "AssertionError: hi"):
foo(torch.tensor(0))

def test_python_op_exception(self):
@torch.jit.ignore
def python_op(x):
raise Exception("bad!")

@torch.jit.script
def fn(x):
return python_op(x)

with self.assertRaisesRegex(RuntimeError, "operation failed in the TorchScript interpreter"):
fn(torch.tensor(4))

def test_dict_expansion_raises_error(self):
def fn(self):
d = {"foo": 1, "bar": 2, "baz": 3}
return {**d}

with self.assertRaisesRegex(torch.jit.frontend.NotSupportedError,
"Dict expansion "):
torch.jit.script(fn)

def test_custom_python_exception(self):
class MyValueError(ValueError):
def __init__(self, msg):
super(MyValueError, self).__init__(msg)

@torch.jit.script
def fn():
raise MyValueError("test custom exception")

with self.assertRaisesRegex(torch.jit.Error, "jit.test_exception.MyValueError: test custom exception"):
fn()

def test_custom_python_exception_defined_elsewhere(self):
from jit.myexception import MyKeyError

@torch.jit.script
def fn():
raise MyKeyError("This is a user defined key error")
with self.assertRaisesRegex(torch.jit.Error, "jit.myexception.MyKeyError: This is a user defined key error"):
fn()
Loading

0 comments on commit 763ad1b

Please sign in to comment.