From cac217c80a596d580e8aaff110599eeb32d44e44 Mon Sep 17 00:00:00 2001
From: Michael Lazos
Date: Thu, 19 Jan 2023 02:54:00 +0000
Subject: [PATCH] Fix key error formatting and move exc code to exc.py (#92593)

Fixes https://github.com/pytorch/torchdynamo/issues/1953 and moves exception formatting code from convert_frame.py to exc.py

Pull Request resolved: https://github.com/pytorch/pytorch/pull/92593
Approved by: https://github.com/ezyang
---
 torch/_dynamo/convert_frame.py |  85 +-----------------------
 torch/_dynamo/exc.py           | 118 ++++++++++++++++++++++++++++++++-
 torch/_dynamo/utils.py         |  16 -----
 3 files changed, 120 insertions(+), 99 deletions(-)

diff --git a/torch/_dynamo/convert_frame.py b/torch/_dynamo/convert_frame.py
index b40074e0ed71a..1d74b5a5c7c13 100644
--- a/torch/_dynamo/convert_frame.py
+++ b/torch/_dynamo/convert_frame.py
@@ -2,11 +2,9 @@
 import itertools
 import logging
 import os
-import traceback
 import types
 import weakref
-from traceback import FrameSummary
-from typing import cast, Dict, List, Optional, Set
+from typing import Dict, Optional, Set
 
 import torch
 from torch.fx.graph_module import _forward_from_src as original_forward_from_src
@@ -17,7 +15,9 @@
 from .bytecode_transformation import is_generator, transform_code_object
 from .eval_frame import always_optimize_code_objects, skip_code, TorchPatcher
 from .exc import (
+    augment_exc_message,
     BackendCompilerFailed,
+    format_error_msg,
     InternalTorchDynamoError,
     TorchRuntimeError,
     unimplemented,
@@ -32,7 +32,6 @@
     CleanupManager,
     counters,
     dynamo_timed,
-    filter_stack,
     format_bytecode,
     gen_record_file_name,
     guard_failures,
@@ -173,84 +172,6 @@ def has_tensor(obj):
     return False
 
 
-def format_error_msg(exc, code, record_filename=None, frame=None):
-    msg = os.linesep * 2
-
-    if config.verbose:
-        msg = format_bytecode(
-            "WON'T CONVERT", code.co_name, code.co_filename, code.co_firstlineno, code
-        )
-        msg += "=" * 10 + " TorchDynamo Stack Trace " + "=" * 10 + "\n"
-        msg += traceback.format_exc()
-        if hasattr(exc, "real_stack"):
-            msg += (
-                "\n"
-                + "=" * 10
-                + " The above exception occurred while processing the following code "
-                + "=" * 10
-                + "\n\n"
-            )
-            stack_above_dynamo = []
-            if frame is not None:
-                stack_above_dynamo = filter_stack(traceback.extract_stack(frame))
-
-            msg += "".join(
-                traceback.format_list(
-                    stack_above_dynamo + list(reversed(get_real_stack(exc)))
-                )
-            )
-            msg += "\n"
-            msg += "=" * 10
-
-    else:
-        msg = f"WON'T CONVERT {code.co_name} {code.co_filename}\
- line {code.co_firstlineno} \ndue to: \n{traceback.format_exc(limit=-1)}"
-
-    return msg
-
-
-def get_real_stack(exc) -> List[FrameSummary]:
-    assert hasattr(exc, "real_stack")
-    return cast(List[FrameSummary], exc.real_stack)
-
-
-def augment_exc_message(exc, msg="\n"):
-    if (
-        hasattr(exc, "real_stack")
-        and len(exc.real_stack) > 0
-        and not (config.verbose and config.suppress_errors)
-    ):
-        msg += f"\nfrom user code:\n {''.join(traceback.format_list(list(reversed(get_real_stack(exc)[0:2]))))}"
-
-    if config.replay_record_enabled and hasattr(exc, "record_filename"):
-        msg += f"\nLast frame execution written to {exc.record_filename}. To run only this frame while debugging, run\
- {config.dynamo_import}.replay('{exc.record_filename}').\n"
-
-    if not config.verbose:
-        msg += (
-            f"\nSet {config.dynamo_import}.config.verbose=True for more information\n"
-        )
-
-    if hasattr(exc, "inner_exception") and hasattr(
-        exc.inner_exception, "minifier_path"
-    ):
-        msg += (
-            f"\nMinifier script written to {exc.inner_exception.minifier_path}. Run "
-            "this script to find the smallest traced graph which reproduces this error.\n"
-        )
-
-    if not config.suppress_errors:
-        msg += (
-            "\n\n"
-            "You can suppress this exception and fall back to eager by setting:\n"
-            "    torch._dynamo.config.suppress_errors = True\n"
-        )
-
-    old_msg = "" if len(exc.args) == 0 else exc.args[0]
-    new_msg = old_msg + msg
-    exc.args = (new_msg,) + exc.args[1:]
-
-
 def exception_handler(e, code, frame=None):
     record_filename = None
     if hasattr(e, "exec_record"):
diff --git a/torch/_dynamo/exc.py b/torch/_dynamo/exc.py
index 41a9f68351aa9..349438def9e05 100644
--- a/torch/_dynamo/exc.py
+++ b/torch/_dynamo/exc.py
@@ -1,7 +1,11 @@
 import os
 import textwrap
+from traceback import extract_stack, format_exc, format_list, FrameSummary
+from typing import cast, List
 
-from .utils import counters
+from . import config
+
+from .utils import counters, format_bytecode
 
 
 class TorchDynamoException(RuntimeError):
@@ -70,3 +74,115 @@ def unimplemented(msg: str):
 def warning(msg: str):
     counters["warnings"][msg] += 1
     assert msg != os.environ.get("BREAK", False)
+
+
+# KeyError has special handling for its args
+# see https://github.com/python/cpython/blob/3.11/Objects/exceptions.c#L2534 for details
+class KeyErrorMsg:
+    def __init__(self, value):
+        self.value = value
+
+    def __str__(self):
+        return str(self.value)
+
+    def __repr__(self) -> str:
+        return self.__str__()
+
+
+def augment_exc_message(exc, msg="\n"):
+    import traceback
+
+    if (
+        hasattr(exc, "real_stack")
+        and len(exc.real_stack) > 0
+        and not (config.verbose and config.suppress_errors)
+    ):
+        msg += f"\nfrom user code:\n {''.join(traceback.format_list(list(reversed(get_real_stack(exc)[0:2]))))}"
+
+    if config.replay_record_enabled and hasattr(exc, "record_filename"):
+        msg += f"\nLast frame execution written to {exc.record_filename}. To run only this frame while debugging, run\
+ {config.dynamo_import}.replay('{exc.record_filename}').\n"
+
+    if not config.verbose:
+        msg += (
+            f"\nSet {config.dynamo_import}.config.verbose=True for more information\n"
+        )
+
+    if hasattr(exc, "inner_exception") and hasattr(
+        exc.inner_exception, "minifier_path"
+    ):
+        msg += (
+            f"\nMinifier script written to {exc.inner_exception.minifier_path}. Run "
+            "this script to find the smallest traced graph which reproduces this error.\n"
+        )
+
+    if not config.suppress_errors:
+        msg += (
+            "\n\n"
+            "You can suppress this exception and fall back to eager by setting:\n"
+            "    torch._dynamo.config.suppress_errors = True\n"
+        )
+
+    old_msg = "" if len(exc.args) == 0 else exc.args[0]
+
+    if isinstance(exc, KeyError):
+        exc.args = (KeyErrorMsg(old_msg + msg),) + exc.args[1:]
+    else:
+        new_msg = old_msg + msg
+        exc.args = (new_msg,) + exc.args[1:]
+
+
+def get_real_stack(exc) -> List[FrameSummary]:
+    assert hasattr(exc, "real_stack")
+    return cast(List[FrameSummary], exc.real_stack)
+
+
+# filter out all frames after entering dynamo
+def filter_stack(stack):
+    user_stack = []
+    for frame in stack:
+        if "convert_frame" in frame.filename:
+            break
+        if (
+            "eval_frame" in frame.filename
+            or f"{config.dynamo_import}.optimize(" in frame.line
+        ):
+            continue
+        user_stack.append(frame)
+
+    return user_stack
+
+
+def format_error_msg(exc, code, record_filename=None, frame=None):
+
+    msg = os.linesep * 2
+
+    if config.verbose:
+        msg = format_bytecode(
+            "WON'T CONVERT", code.co_name, code.co_filename, code.co_firstlineno, code
+        )
+        msg += "=" * 10 + " TorchDynamo Stack Trace " + "=" * 10 + "\n"
+        msg += format_exc()
+        if hasattr(exc, "real_stack"):
+            msg += (
+                "\n"
+                + "=" * 10
+                + " The above exception occurred while processing the following code "
+                + "=" * 10
+                + "\n\n"
+            )
+            stack_above_dynamo = []
+            if frame is not None:
+                stack_above_dynamo = filter_stack(extract_stack(frame))
+
+            msg += "".join(
+                format_list(stack_above_dynamo + list(reversed(get_real_stack(exc))))
+            )
+            msg += "\n"
+            msg += "=" * 10
+
+    else:
+        msg = f"WON'T CONVERT {code.co_name} {code.co_filename}\
+ line {code.co_firstlineno} \ndue to: \n{format_exc(limit=-1)}"
+
+    return msg
diff --git a/torch/_dynamo/utils.py b/torch/_dynamo/utils.py
index c9ba52183de96..47a50eb3d4eab 100644
--- a/torch/_dynamo/utils.py
+++ b/torch/_dynamo/utils.py
@@ -183,22 +183,6 @@ def init_logging():
     graph_break_dup_warning_checker.reset()
 
 
-# filter out all frames after entering dynamo
-def filter_stack(stack):
-    user_stack = []
-    for frame in stack:
-        if "convert_frame" in frame.filename:
-            break
-        if (
-            "eval_frame" in frame.filename
-            or f"{config.dynamo_import}.optimize(" in frame.line
-        ):
-            continue
-        user_stack.append(frame)
-
-    return user_stack
-
-
 def format_graph_tabular(graph):
     node_specs = [[n.op, n.name, n.target, n.args, n.kwargs] for n in graph.nodes]
     return tabulate(node_specs, headers=["opcode", "name", "target", "args", "kwargs"])