Skip to content

Commit 06563a3

Browse files
committed
Support mutual importing modules.
1 parent 75e8f81 commit 06563a3

File tree

4 files changed

+26
-8
lines changed

4 files changed

+26
-8
lines changed

tests/snippets/import.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from import_target import func as aliased_func, other_func as aliased_other_func
44
from import_star import *
55

6+
import import_mutual1
67
assert import_target.X == import_target.func()
78
assert import_target.X == func()
89

tests/snippets/import_mutual1.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
2+
# Mutual recursive import:
3+
import import_mutual2
4+

tests/snippets/import_mutual2.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
2+
# Mutual recursive import:
3+
import import_mutual1

vm/src/import.rs

Lines changed: 18 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -11,17 +11,21 @@ use crate::pyobject::{ItemProtocol, PyResult};
1111
use crate::util;
1212
use crate::vm::VirtualMachine;
1313

14-
fn import_uncached_module(vm: &VirtualMachine, current_path: PathBuf, module: &str) -> PyResult {
14+
fn import_uncached_module(
15+
vm: &VirtualMachine,
16+
current_path: PathBuf,
17+
module_name: &str,
18+
) -> PyResult {
1519
// Check for Rust-native modules
16-
if let Some(module) = vm.stdlib_inits.borrow().get(module) {
20+
if let Some(module) = vm.stdlib_inits.borrow().get(module_name) {
1721
return Ok(module(vm).clone());
1822
}
1923

2024
let notfound_error = vm.context().exceptions.module_not_found_error.clone();
2125
let import_error = vm.context().exceptions.import_error.clone();
2226

2327
// Time to search for module in any place:
24-
let file_path = find_source(vm, current_path, module)
28+
let file_path = find_source(vm, current_path, module_name)
2529
.map_err(|e| vm.new_exception(notfound_error.clone(), e))?;
2630
let source = util::read_file(file_path.as_path())
2731
.map_err(|e| vm.new_exception(import_error.clone(), e.to_string()))?;
@@ -35,19 +39,25 @@ fn import_uncached_module(vm: &VirtualMachine, current_path: PathBuf, module: &s
3539
// trace!("Code object: {:?}", code_obj);
3640

3741
let attrs = vm.ctx.new_dict();
38-
attrs.set_item("__name__", vm.new_str(module.to_string()), vm)?;
39-
vm.run_code_obj(code_obj, Scope::new(None, attrs.clone()))?;
40-
Ok(vm.ctx.new_module(module, attrs))
42+
attrs.set_item("__name__", vm.new_str(module_name.to_string()), vm)?;
43+
let module = vm.ctx.new_module(module_name, attrs.clone());
44+
45+
// Store module in cache to prevent infinite loop with mutual importing libs:
46+
let sys_modules = vm.get_attribute(vm.sys_module.clone(), "modules").unwrap();
47+
sys_modules.set_item(module_name, module.clone(), vm)?;
48+
49+
// Execute main code in module:
50+
vm.run_code_obj(code_obj, Scope::new(None, attrs))?;
51+
Ok(module)
4152
}
4253

4354
pub fn import_module(vm: &VirtualMachine, current_path: PathBuf, module_name: &str) -> PyResult {
4455
// First, see if we already loaded the module:
45-
let sys_modules = vm.get_attribute(vm.sys_module.clone(), "modules")?;
56+
let sys_modules = vm.get_attribute(vm.sys_module.clone(), "modules").unwrap();
4657
if let Ok(module) = sys_modules.get_item(module_name.to_string(), vm) {
4758
return Ok(module);
4859
}
4960
let module = import_uncached_module(vm, current_path, module_name)?;
50-
sys_modules.set_item(module_name, module.clone(), vm)?;
5161
Ok(module)
5262
}
5363

0 commit comments

Comments
 (0)