Skip to content

Commit

Permalink
Add the correct path when it's a package
Browse files Browse the repository at this point in the history
  • Loading branch information
Michael Fraenkel committed May 3, 2023
1 parent 0fa82cf commit 75206df
Show file tree
Hide file tree
Showing 4 changed files with 14 additions and 1 deletion.
6 changes: 5 additions & 1 deletion integration/_support/package/tasks/module.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
from invoke import task
from . import pytest as pt
import pytest

pytest.__version__


@task
def mytask(c):
print("hi!")
print(pt.hi)
1 change: 1 addition & 0 deletions integration/_support/package/tasks/pytest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
hi = "hi!"
2 changes: 2 additions & 0 deletions invoke/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,8 @@ def load(self, name: Optional[str] = None) -> Tuple[ModuleType, str]:
# being imported is trying to load local-to-it names.
if os.path.isfile(spec.origin):
path = os.path.dirname(spec.origin)
if spec.origin.endswith("__init__.py"):
path = os.path.dirname(path)
if path not in sys.path:
sys.path.insert(0, path)
# Actual import
Expand Down
6 changes: 6 additions & 0 deletions tests/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,12 @@ def adds_module_parent_dir_to_sys_path(self):
# Crummy doesn't-explode test.
_BasicLoader().load("namespacing")

def adds_package_dir_to_sys_path(self):
config = Config({"tasks": {"collection_name": "module"}})
_BasicLoader(config).load("package")
package = Path(support) / "package"
assert str(package) not in sys.path

def doesnt_duplicate_parent_dir_addition(self):
_BasicLoader().load("namespacing")
_BasicLoader().load("namespacing")
Expand Down

0 comments on commit 75206df

Please sign in to comment.