Skip to content

Commit

Permalink
Merge pull request #98 from dbarnett/typechecked_return
Browse files Browse the repository at this point in the history
Change typechecked to consistently return exact input type
  • Loading branch information
Stewori authored Aug 6, 2020
2 parents befcedd + 0d74f44 commit 4004925
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 6 deletions.
12 changes: 8 additions & 4 deletions pytypes/typechecker.py
Original file line number Diff line number Diff line change
Expand Up @@ -964,15 +964,19 @@ def typechecked_module(md, force_recursive = False):
"""
if not pytypes.checking_enabled:
return md
# Save input to return original string if input was a string.
md_arg = md
if isinstance(md, str):
if md in sys.modules:
md = sys.modules[md]
if md is None:
return md
return md_arg
elif md in _pending_modules:
# if import is pending, we just store this call for later
_pending_modules[md].append(lambda t: typechecked_module(t, True))
return md
return md_arg
else:
raise KeyError('Found no module {!r} to typecheck'.format(md))
assert(ismodule(md))
if md.__name__ in _pending_modules:
# if import is pending, we just store this call for later
Expand All @@ -981,7 +985,7 @@ def typechecked_module(md, force_recursive = False):
# todo: Issue warning here that not the whole module might be covered yet
if md.__name__ in _fully_typechecked_modules and \
_fully_typechecked_modules[md.__name__] == len(md.__dict__):
return md
return md_arg
# To play it safe we avoid to modify the dict while iterating over it,
# so we previously cache keys.
# For this we don't use keys() because of Python 3.
Expand All @@ -997,7 +1001,7 @@ def typechecked_module(md, force_recursive = False):
typechecked_class(memb, force_recursive, force_recursive)
if not md.__name__ in _pending_modules:
_fully_typechecked_modules[md.__name__] = len(md.__dict__)
return md
return md_arg


def typechecked(memb):
Expand Down
11 changes: 9 additions & 2 deletions tests/test_typechecker.py
Original file line number Diff line number Diff line change
Expand Up @@ -2651,8 +2651,14 @@ class TestTypecheck_module(unittest.TestCase):
def test_function_py2(self):
from testhelpers import modulewide_typecheck_testhelper_py2 as mth
self.assertEqual(mth.testfunc(3, 2.5, 'abcd'), (9, 7.5))
with self.assertRaises(KeyError):
pytypes.typechecked_module('nonexistent123')
self.assertEqual(mth.testfunc(3, 2.5, 7), (9, 7.5)) # would normally fail
pytypes.typechecked_module(mth)
module_name = 'testhelpers.modulewide_typecheck_testhelper_py2'
returned_mth = pytypes.typechecked_module(module_name)
self.assertEqual(returned_mth, module_name)
returned_mth = pytypes.typechecked_module(mth)
self.assertEqual(returned_mth, mth)
self.assertEqual(mth.testfunc(3, 2.5, 'abcd'), (9, 7.5))
self.assertRaises(InputTypeError, lambda: mth.testfunc(3, 2.5, 7))

Expand All @@ -2662,7 +2668,8 @@ def test_function_py3(self):
from testhelpers import modulewide_typecheck_testhelper as mth
self.assertEqual(mth.testfunc(3, 2.5, 'abcd'), (9, 7.5))
self.assertEqual(mth.testfunc(3, 2.5, 7), (9, 7.5)) # would normally fail
pytypes.typechecked_module(mth)
returned_mth = pytypes.typechecked_module(mth)
self.assertEqual(returned_mth, mth)
self.assertEqual(mth.testfunc(3, 2.5, 'abcd'), (9, 7.5))
self.assertRaises(InputTypeError, lambda: mth.testfunc(3, 2.5, 7))

Expand Down

0 comments on commit 4004925

Please sign in to comment.