Skip to content

Commit 615a121

Browse files
authored
Merge pull request RustPython#841 from RustPython/mutual-import
Support mutual importing modules.
2 parents 240c1e4 + 2a8b586 commit 615a121

File tree

4 files changed

+44
-33
lines changed

4 files changed

+44
-33
lines changed

tests/snippets/import.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from import_target import func as aliased_func, other_func as aliased_other_func
44
from import_star import *
55

6+
import import_mutual1
67
assert import_target.X == import_target.func()
78
assert import_target.X == func()
89

tests/snippets/import_mutual1.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
2+
# Mutual recursive import:
3+
import import_mutual2
4+

tests/snippets/import_mutual2.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
2+
# Mutual recursive import:
3+
import import_mutual1

vm/src/import.rs

Lines changed: 36 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -11,44 +11,47 @@ use crate::pyobject::{ItemProtocol, PyResult};
1111
use crate::util;
1212
use crate::vm::VirtualMachine;
1313

14-
fn import_uncached_module(vm: &VirtualMachine, current_path: PathBuf, module: &str) -> PyResult {
15-
// Check for Rust-native modules
16-
if let Some(module) = vm.stdlib_inits.borrow().get(module) {
17-
return Ok(module(vm).clone());
18-
}
14+
pub fn import_module(vm: &VirtualMachine, current_path: PathBuf, module_name: &str) -> PyResult {
15+
// Cached modules:
16+
let sys_modules = vm.get_attribute(vm.sys_module.clone(), "modules").unwrap();
1917

20-
let notfound_error = vm.context().exceptions.module_not_found_error.clone();
21-
let import_error = vm.context().exceptions.import_error.clone();
18+
// First, see if we already loaded the module:
19+
if let Ok(module) = sys_modules.get_item(module_name.to_string(), vm) {
20+
Ok(module)
21+
} else if let Some(make_module_func) = vm.stdlib_inits.borrow().get(module_name) {
22+
let module = make_module_func(vm);
23+
sys_modules.set_item(module_name, module.clone(), vm)?;
24+
Ok(module)
25+
} else {
26+
let notfound_error = vm.context().exceptions.module_not_found_error.clone();
27+
let import_error = vm.context().exceptions.import_error.clone();
2228

23-
// Time to search for module in any place:
24-
let file_path = find_source(vm, current_path, module)
25-
.map_err(|e| vm.new_exception(notfound_error.clone(), e))?;
26-
let source = util::read_file(file_path.as_path())
27-
.map_err(|e| vm.new_exception(import_error.clone(), e.to_string()))?;
28-
let code_obj = compile::compile(
29-
vm,
30-
&source,
31-
&compile::Mode::Exec,
32-
file_path.to_str().unwrap().to_string(),
33-
)
34-
.map_err(|err| vm.new_syntax_error(&err))?;
35-
// trace!("Code object: {:?}", code_obj);
29+
// Time to search for module in any place:
30+
let file_path = find_source(vm, current_path, module_name)
31+
.map_err(|e| vm.new_exception(notfound_error.clone(), e))?;
32+
let source = util::read_file(file_path.as_path())
33+
.map_err(|e| vm.new_exception(import_error.clone(), e.to_string()))?;
34+
let code_obj = compile::compile(
35+
vm,
36+
&source,
37+
&compile::Mode::Exec,
38+
file_path.to_str().unwrap().to_string(),
39+
)
40+
.map_err(|err| vm.new_syntax_error(&err))?;
41+
// trace!("Code object: {:?}", code_obj);
3642

37-
let attrs = vm.ctx.new_dict();
38-
attrs.set_item("__name__", vm.new_str(module.to_string()), vm)?;
39-
vm.run_code_obj(code_obj, Scope::new(None, attrs.clone()))?;
40-
Ok(vm.ctx.new_module(module, attrs))
41-
}
43+
let attrs = vm.ctx.new_dict();
44+
attrs.set_item("__name__", vm.new_str(module_name.to_string()), vm)?;
45+
let module = vm.ctx.new_module(module_name, attrs.clone());
4246

43-
pub fn import_module(vm: &VirtualMachine, current_path: PathBuf, module_name: &str) -> PyResult {
44-
// First, see if we already loaded the module:
45-
let sys_modules = vm.get_attribute(vm.sys_module.clone(), "modules")?;
46-
if let Ok(module) = sys_modules.get_item(module_name.to_string(), vm) {
47-
return Ok(module);
47+
// Store module in cache to prevent infinite loop with mutual importing libs:
48+
sys_modules.set_item(module_name, module.clone(), vm)?;
49+
50+
// Execute main code in module:
51+
vm.run_code_obj(code_obj, Scope::new(None, attrs))?;
52+
53+
Ok(module)
4854
}
49-
let module = import_uncached_module(vm, current_path, module_name)?;
50-
sys_modules.set_item(module_name, module.clone(), vm)?;
51-
Ok(module)
5255
}
5356

5457
fn find_source(vm: &VirtualMachine, current_path: PathBuf, name: &str) -> Result<PathBuf, String> {

0 commit comments

Comments
 (0)