Skip to content

Commit

Permalink
[Internal][Executor] Add unit tests for exception_utils (microsoft#595)
Browse files Browse the repository at this point in the history
# Description

Add unit tests for exception_utils

# All Promptflow Contribution checklist:
- [X] **The pull request does not introduce [breaking changes]**
- [ ] **CHANGELOG is updated for new features, bug fixes or other
significant changes.**
- [X] **I have read the [contribution guidelines](../CONTRIBUTING.md).**

## General Guidelines and Best Practices
- [X] Title of the pull request is clear and informative.
- [X] There are a small number of commits, each of which have an
informative message. This means that previously merged commits do not
appear in the history of the PR. For more information on cleaning up the
commits in your PR, [see this
page](https://github.com/Azure/azure-powershell/blob/master/documentation/development-docs/cleaning-up-commits.md).

### Testing Guidelines
- [X] Pull request includes test coverage for the included changes.

---------

Co-authored-by: Peiwen Gao <[email protected]>
  • Loading branch information
PeiwenGaoMS and GPW9795 authored Oct 8, 2023
1 parent 7a6a487 commit 16cf176
Showing 1 changed file with 177 additions and 29 deletions.
206 changes: 177 additions & 29 deletions src/promptflow/tests/executor/unittests/_utils/test_exception_utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import json
import re
from traceback import TracebackException

import pytest

Expand All @@ -9,9 +10,17 @@
ErrorResponse,
ExceptionPresenter,
JsonSerializedPromptflowException,
get_tb_next,
infer_error_code_from_class,
last_frame_info,
)
from promptflow.exceptions import (
ErrorTarget,
PromptflowException,
SystemErrorException,
UserErrorException,
ValidationException,
)
from promptflow.exceptions import ErrorTarget, PromptflowException, SystemErrorException, UserErrorException


def set_inner_exception_by_parameter():
Expand Down Expand Up @@ -44,6 +53,23 @@ def raise_user_error():
raise UserErrorException("run failed", target=ErrorTarget.TOOL) from e


def raise_context_exception():
try:
code_with_bug()
except Exception as e:
raise CustomizedContextException(e)


class CustomizedContextException(Exception):
def __init__(self, inner_exception):
self.inner_exception = inner_exception

@property
def message(self):
code_with_bug()
return "context exception"


class CustomizedException(Exception):
pass

Expand Down Expand Up @@ -78,11 +104,17 @@ def raise_promptflow_exception_without_inner_exception():
raise PromptflowException("Promptflow exception")


TOOL_EXECUTION_ERROR_TRACEBACK = r"""Traceback \(most recent call last\):
File ".*test_exception_utils.py", line .*, in code_with_bug
1 / 0
ZeroDivisionError: division by zero
"""

TOOL_EXCEPTION_TRACEBACK = r"""
The above exception was the direct cause of the following exception:
Traceback \(most recent call last\):
File ".*test_exception_utils.py", line .*, in test_debug_info
File ".*test_exception_utils.py", line .*, in test_.*
raise_tool_execution_error\(\)
File ".*test_exception_utils.py", line .*, in raise_tool_execution_error
raise ToolExecutionError\(node_name="MyTool"\) from e
Expand Down Expand Up @@ -112,19 +144,55 @@ def raise_promptflow_exception_without_inner_exception():
1 / 0
"""

CONTEXT_EXCEPTION_TRACEBACK = r"""
During handling of the above exception, another exception occurred:
Traceback \(most recent call last\):
File ".*test_exception_utils.py", line .*, in test_debug_info_for_context_exception
raise_context_exception\(\)
File ".*test_exception_utils.py", line .*, in raise_context_exception
raise CustomizedContextException\(e\)
"""

CONTEXT_EXCEPTION_INNER_TRACEBACK = r"""Traceback \(most recent call last\):
File ".*test_exception_utils.py", line .*, in raise_context_exception
code_with_bug\(\)
File ".*test_exception_utils.py", line .*, in code_with_bug
1 / 0
"""


@pytest.mark.unittest
@pytest.mark.parametrize(
"clz, expected",
[
(UserErrorException, "UserError"),
(SystemErrorException, "SystemError"),
(ToolExecutionError, "ToolExecutionError"),
(ValueError, "ValueError"),
],
)
def test_infer_error_code_from_class(clz, expected):
assert infer_error_code_from_class(clz) == expected
class TestExceptionUtilsCommonMethod:
def test_get_tb_next(self):
with pytest.raises(ToolExecutionError) as e:
raise_tool_execution_error()
tb_next = get_tb_next(e.value.__traceback__, 3)
te = TracebackException(type(e.value), e.value, tb_next)
formatted_tb = "".join(te.format())
assert re.match(TOOL_EXCEPTION_INNER_TRACEBACK, formatted_tb)

def test_last_frame_info(self):
with pytest.raises(ToolExecutionError) as e:
raise_tool_execution_error()
frame_info = last_frame_info(e.value)
assert "test_exception_utils.py" in frame_info.get("filename")
assert frame_info.get("lineno") > 0
assert frame_info.get("name") == "raise_tool_execution_error"
assert last_frame_info(None) == {}

@pytest.mark.parametrize(
"error_class, expected_error_code",
[
(UserErrorException, "UserError"),
(SystemErrorException, "SystemError"),
(ValidationException, "ValidationError"),
(ToolExecutionError, "ToolExecutionError"),
(ValueError, "ValueError"),
],
)
def test_infer_error_code_from_class(self, error_class, expected_error_code):
assert infer_error_code_from_class(error_class) == expected_error_code


@pytest.mark.unittest
Expand All @@ -143,6 +211,18 @@ def test_debug_info(self):
assert inner_exception["type"] == "ZeroDivisionError"
assert re.match(TOOL_EXCEPTION_INNER_TRACEBACK, inner_exception["stackTrace"])

def test_debug_info_for_context_exception(self):
with pytest.raises(CustomizedContextException) as e:
raise_context_exception()
presenter = ExceptionPresenter.create(e.value)
debug_info = presenter.debug_info
assert debug_info["type"] == "CustomizedContextException"
assert re.match(CONTEXT_EXCEPTION_TRACEBACK, debug_info["stackTrace"])

inner_exception = debug_info["innerException"]
assert inner_exception["type"] == "ZeroDivisionError"
assert re.match(CONTEXT_EXCEPTION_INNER_TRACEBACK, inner_exception["stackTrace"])

def test_debug_info_for_general_exception(self):
# Test General Exception
with pytest.raises(CustomizedException) as e:
Expand All @@ -162,7 +242,9 @@ def test_to_dict_for_general_exception(self):
raise_general_exception()

presenter = ExceptionPresenter.create(e.value)
dct = presenter.to_dict(include_debug_info=False)
dct = presenter.to_dict(include_debug_info=True)
assert "debugInfo" in dct
dct.pop("debugInfo")
assert dct == {
"code": "SystemError",
"message": "General exception",
Expand Down Expand Up @@ -212,6 +294,8 @@ def test_to_dict_for_tool_execution_error(self):
raise_tool_execution_error()

presenter = ExceptionPresenter.create(e.value)
assert re.search(TOOL_EXCEPTION_INNER_TRACEBACK, presenter.formatted_traceback)
assert re.search(TOOL_EXCEPTION_TRACEBACK, presenter.formatted_traceback)
dct = presenter.to_dict(include_debug_info=False)
assert dct.pop("additionalInfo") is not None
assert dct == {
Expand Down Expand Up @@ -255,6 +339,8 @@ def test_from_error_dict(self):
}
response = ErrorResponse.from_error_dict(error_dict)
assert response.response_code == "400"
assert response.error_codes == ["UserError"]
assert response.message == "Flow run failed."
response_dct = response.to_dict()
assert response_dct["time"] is not None
response_dct.pop("time")
Expand All @@ -271,6 +357,17 @@ def test_from_error_dict(self):
"location": None,
}

def test_to_simplied_dict(self):
with pytest.raises(CustomizedException) as e:
raise_general_exception()
error_response = ErrorResponse.from_exception(e.value)
assert error_response.to_simplified_dict() == {
"error": {
"code": "SystemError",
"message": "General exception",
}
}

def test_from_exception(self):
with pytest.raises(CustomizedException) as e:
raise_general_exception()
Expand Down Expand Up @@ -334,6 +431,65 @@ def test_innermost_error_code_with_code(self, error_dict, expected_innermost_err

assert inner_error_code == expected_innermost_error_code

@pytest.mark.parametrize(
"error_dict, expected_additional_info",
[
({"code": "UserError"}, {}),
(
{
"code": "UserError",
"additionalInfo": [
{
"type": "test_additional_info",
"info": "This is additional info for testing.",
},
"not_dict",
{
"type": "empty_info",
},
{
"info": "Empty type",
},
{
"test": "Invalid additional info",
},
],
},
{"test_additional_info": "This is additional info for testing."},
),
],
)
def test_additional_info(self, error_dict, expected_additional_info):
error_response = ErrorResponse.from_error_dict(error_dict)
assert error_response.additional_info == expected_additional_info
assert all(error_response.get_additional_info(key) == value for key, value in expected_additional_info.items())

@pytest.mark.parametrize(
"raise_exception_func, error_class",
[
(raise_general_exception, CustomizedException),
(raise_tool_execution_error, ToolExecutionError),
],
)
def test_get_user_execution_error_info(self, raise_exception_func, error_class):
with pytest.raises(error_class) as e:
raise_exception_func()

error_repsonse = ErrorResponse.from_exception(e.value)
actual_error_info = error_repsonse.get_user_execution_error_info()
self.assert_user_execution_error_info(e.value, actual_error_info)

def assert_user_execution_error_info(self, exception, error_info):
if isinstance(exception, ToolExecutionError):
assert error_info["type"] == "ZeroDivisionError"
assert error_info["message"] == "division by zero"
assert error_info["filename"].endswith("test_exception_utils.py")
assert error_info["lineno"] > 0
assert error_info["name"] == "code_with_bug"
assert re.match(TOOL_EXECUTION_ERROR_TRACEBACK, error_info["traceback"])
else:
assert error_info == {}


@pytest.mark.unittest
class TestExceptions:
Expand Down Expand Up @@ -495,14 +651,7 @@ def test_tool_execution_error(self):
assert last_frame_info.get("lineno") > 0
assert last_frame_info.get("name") == "code_with_bug"

assert re.match(
r"Traceback \(most recent call last\):\n"
r' File ".*test_exception_utils.py", line .*, in code_with_bug\n'
r" 1 / 0\n"
r"(.*\n)?" # Python >= 3.11 add extra line here like a pointer.
r"ZeroDivisionError: division by zero\n",
e.value.tool_traceback,
)
assert re.match(TOOL_EXECUTION_ERROR_TRACEBACK, e.value.tool_traceback)

def test_code_hierarchy(self):
with pytest.raises(ToolExecutionError) as e:
Expand Down Expand Up @@ -547,13 +696,7 @@ def test_additional_info(self):
assert re.match(r".*test_exception_utils.py", info_0_value["filename"])
assert info_0_value.get("lineno") > 0
assert info_0_value.get("name") == "code_with_bug"
assert re.match(
r"Traceback \(most recent call last\):\n"
r' File ".*test_exception_utils.py", line .*, in code_with_bug\n'
r" 1 / 0\n"
r"ZeroDivisionError: division by zero\n",
info_0_value.get("traceback"),
)
assert re.match(TOOL_EXECUTION_ERROR_TRACEBACK, info_0_value.get("traceback"))

def test_additional_info_for_empty_inner_error(self):
ex = ToolExecutionError(node_name="Node1")
Expand Down Expand Up @@ -634,7 +777,12 @@ def test_to_dict_for_JsonSerializedPromptflowException(self, include_debug_info)
exception_dict = ExceptionPresenter.create(e.value).to_dict(include_debug_info=True)
message = json.dumps(exception_dict)
exception = JsonSerializedPromptflowException(message=message)
assert str(exception) == message
json_serialized_exception_dict = ExceptionPresenter.create(exception).to_dict(
include_debug_info=include_debug_info
)
error_dict = exception.to_dict(include_debug_info=include_debug_info)
assert error_dict == json_serialized_exception_dict

if include_debug_info:
assert "debugInfo" in error_dict
Expand Down

0 comments on commit 16cf176

Please sign in to comment.