Skip to content

Commit

Permalink
Fix key error formatting and move exc code to exc.py (pytorch#92593)
Browse files Browse the repository at this point in the history
Fixes pytorch/torchdynamo#1953 and moves exception formatting code from convert_frame.py to exc.py

Pull Request resolved: pytorch#92593
Approved by: https://github.com/ezyang
  • Loading branch information
mlazos authored and pytorchmergebot committed Jan 19, 2023
1 parent ba68205 commit cac217c
Show file tree
Hide file tree
Showing 3 changed files with 120 additions and 99 deletions.
85 changes: 3 additions & 82 deletions torch/_dynamo/convert_frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,9 @@
import itertools
import logging
import os
import traceback
import types
import weakref
from traceback import FrameSummary
from typing import cast, Dict, List, Optional, Set
from typing import Dict, Optional, Set

import torch
from torch.fx.graph_module import _forward_from_src as original_forward_from_src
Expand All @@ -17,7 +15,9 @@
from .bytecode_transformation import is_generator, transform_code_object
from .eval_frame import always_optimize_code_objects, skip_code, TorchPatcher
from .exc import (
augment_exc_message,
BackendCompilerFailed,
format_error_msg,
InternalTorchDynamoError,
TorchRuntimeError,
unimplemented,
Expand All @@ -32,7 +32,6 @@
CleanupManager,
counters,
dynamo_timed,
filter_stack,
format_bytecode,
gen_record_file_name,
guard_failures,
Expand Down Expand Up @@ -173,84 +172,6 @@ def has_tensor(obj):
return False


def format_error_msg(exc, code, record_filename=None, frame=None):
msg = os.linesep * 2

if config.verbose:
msg = format_bytecode(
"WON'T CONVERT", code.co_name, code.co_filename, code.co_firstlineno, code
)
msg += "=" * 10 + " TorchDynamo Stack Trace " + "=" * 10 + "\n"
msg += traceback.format_exc()
if hasattr(exc, "real_stack"):
msg += (
"\n"
+ "=" * 10
+ " The above exception occurred while processing the following code "
+ "=" * 10
+ "\n\n"
)
stack_above_dynamo = []
if frame is not None:
stack_above_dynamo = filter_stack(traceback.extract_stack(frame))

msg += "".join(
traceback.format_list(
stack_above_dynamo + list(reversed(get_real_stack(exc)))
)
)
msg += "\n"
msg += "=" * 10

else:
msg = f"WON'T CONVERT {code.co_name} {code.co_filename}\
line {code.co_firstlineno} \ndue to: \n{traceback.format_exc(limit=-1)}"

return msg


def get_real_stack(exc) -> List[FrameSummary]:
assert hasattr(exc, "real_stack")
return cast(List[FrameSummary], exc.real_stack)


def augment_exc_message(exc, msg="\n"):
if (
hasattr(exc, "real_stack")
and len(exc.real_stack) > 0
and not (config.verbose and config.suppress_errors)
):
msg += f"\nfrom user code:\n {''.join(traceback.format_list(list(reversed(get_real_stack(exc)[0:2]))))}"

if config.replay_record_enabled and hasattr(exc, "record_filename"):
msg += f"\nLast frame execution written to {exc.record_filename}. To run only this frame while debugging, run\
{config.dynamo_import}.replay('{exc.record_filename}').\n"

if not config.verbose:
msg += (
f"\nSet {config.dynamo_import}.config.verbose=True for more information\n"
)

if hasattr(exc, "inner_exception") and hasattr(
exc.inner_exception, "minifier_path"
):
msg += (
f"\nMinifier script written to {exc.inner_exception.minifier_path}. Run "
"this script to find the smallest traced graph which reproduces this error.\n"
)

if not config.suppress_errors:
msg += (
"\n\n"
"You can suppress this exception and fall back to eager by setting:\n"
" torch._dynamo.config.suppress_errors = True\n"
)

old_msg = "" if len(exc.args) == 0 else exc.args[0]
new_msg = old_msg + msg
exc.args = (new_msg,) + exc.args[1:]


def exception_handler(e, code, frame=None):
record_filename = None
if hasattr(e, "exec_record"):
Expand Down
118 changes: 117 additions & 1 deletion torch/_dynamo/exc.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
import os
import textwrap
from traceback import extract_stack, format_exc, format_list, FrameSummary
from typing import cast, List

from .utils import counters
from . import config

from .utils import counters, format_bytecode


class TorchDynamoException(RuntimeError):
Expand Down Expand Up @@ -70,3 +74,115 @@ def unimplemented(msg: str):
def warning(msg: str):
counters["warnings"][msg] += 1
assert msg != os.environ.get("BREAK", False)


# KeyError has special handling for its args
# see https://github.com/python/cpython/blob/3.11/Objects/exceptions.c#L2534 for details
class KeyErrorMsg:
def __init__(self, value):
self.value = value

def __str__(self):
return str(self.value)

def __repr__(self) -> str:
return self.__str__()


def augment_exc_message(exc, msg="\n"):
import traceback

if (
hasattr(exc, "real_stack")
and len(exc.real_stack) > 0
and not (config.verbose and config.suppress_errors)
):
msg += f"\nfrom user code:\n {''.join(traceback.format_list(list(reversed(get_real_stack(exc)[0:2]))))}"

if config.replay_record_enabled and hasattr(exc, "record_filename"):
msg += f"\nLast frame execution written to {exc.record_filename}. To run only this frame while debugging, run\
{config.dynamo_import}.replay('{exc.record_filename}').\n"

if not config.verbose:
msg += (
f"\nSet {config.dynamo_import}.config.verbose=True for more information\n"
)

if hasattr(exc, "inner_exception") and hasattr(
exc.inner_exception, "minifier_path"
):
msg += (
f"\nMinifier script written to {exc.inner_exception.minifier_path}. Run "
"this script to find the smallest traced graph which reproduces this error.\n"
)

if not config.suppress_errors:
msg += (
"\n\n"
"You can suppress this exception and fall back to eager by setting:\n"
" torch._dynamo.config.suppress_errors = True\n"
)

old_msg = "" if len(exc.args) == 0 else exc.args[0]

if isinstance(exc, KeyError):
exc.args = (KeyErrorMsg(old_msg + msg),) + exc.args[1:]
else:
new_msg = old_msg + msg
exc.args = (new_msg,) + exc.args[1:]


def get_real_stack(exc) -> List[FrameSummary]:
assert hasattr(exc, "real_stack")
return cast(List[FrameSummary], exc.real_stack)


# filter out all frames after entering dynamo
def filter_stack(stack):
user_stack = []
for frame in stack:
if "convert_frame" in frame.filename:
break
if (
"eval_frame" in frame.filename
or f"{config.dynamo_import}.optimize(" in frame.line
):
continue
user_stack.append(frame)

return user_stack


def format_error_msg(exc, code, record_filename=None, frame=None):

msg = os.linesep * 2

if config.verbose:
msg = format_bytecode(
"WON'T CONVERT", code.co_name, code.co_filename, code.co_firstlineno, code
)
msg += "=" * 10 + " TorchDynamo Stack Trace " + "=" * 10 + "\n"
msg += format_exc()
if hasattr(exc, "real_stack"):
msg += (
"\n"
+ "=" * 10
+ " The above exception occurred while processing the following code "
+ "=" * 10
+ "\n\n"
)
stack_above_dynamo = []
if frame is not None:
stack_above_dynamo = filter_stack(extract_stack(frame))

msg += "".join(
format_list(stack_above_dynamo + list(reversed(get_real_stack(exc))))
)
msg += "\n"
msg += "=" * 10

else:
msg = f"WON'T CONVERT {code.co_name} {code.co_filename}\
line {code.co_firstlineno} \ndue to: \n{format_exc(limit=-1)}"

return msg
16 changes: 0 additions & 16 deletions torch/_dynamo/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,22 +183,6 @@ def init_logging():
graph_break_dup_warning_checker.reset()


# filter out all frames after entering dynamo
def filter_stack(stack):
user_stack = []
for frame in stack:
if "convert_frame" in frame.filename:
break
if (
"eval_frame" in frame.filename
or f"{config.dynamo_import}.optimize(" in frame.line
):
continue
user_stack.append(frame)

return user_stack


def format_graph_tabular(graph):
node_specs = [[n.op, n.name, n.target, n.args, n.kwargs] for n in graph.nodes]
return tabulate(node_specs, headers=["opcode", "name", "target", "args", "kwargs"])
Expand Down

0 comments on commit cac217c

Please sign in to comment.