Skip to content

Commit c359c43

Browse files
committed
Add set_file_attr to import_codeobj
1 parent aefbae4 commit c359c43

File tree

3 files changed

+39
-33
lines changed

3 files changed

+39
-33
lines changed

derive/src/compile_bytecode.rs

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,8 @@
88
//!
99
//! // the mode to compile the code in
1010
//! mode = "exec", // or "eval" or "single"
11-
//! // the path put into the CodeObject, defaults to `None`
12-
//! source_path = "frozen",
11+
//! // the path put into the CodeObject, defaults to "frozen"
12+
//! module_name = "frozen",
1313
//! )
1414
//! ```
1515
@@ -35,9 +35,9 @@ struct CompilationSource {
3535
}
3636

3737
impl CompilationSource {
38-
fn compile(self, mode: &compile::Mode, source_path: String) -> Result<CodeObject, Diagnostic> {
38+
fn compile(self, mode: &compile::Mode, module_name: String) -> Result<CodeObject, Diagnostic> {
3939
let compile = |source| {
40-
compile::compile(source, mode, source_path).map_err(|err| {
40+
compile::compile(source, mode, module_name).map_err(|err| {
4141
Diagnostic::spans_error(self.span, format!("Compile error: {}", err))
4242
})
4343
};
@@ -69,7 +69,7 @@ struct PyCompileInput {
6969

7070
impl PyCompileInput {
7171
fn compile(&self) -> Result<CodeObject, Diagnostic> {
72-
let mut source_path = None;
72+
let mut module_name = None;
7373
let mut mode = None;
7474
let mut source: Option<CompilationSource> = None;
7575

@@ -97,10 +97,10 @@ impl PyCompileInput {
9797
},
9898
_ => bail_span!(name_value.lit, "mode must be a string"),
9999
})
100-
} else if name_value.ident == "source_path" {
101-
source_path = Some(match &name_value.lit {
100+
} else if name_value.ident == "module_name" {
101+
module_name = Some(match &name_value.lit {
102102
Lit::Str(s) => s.value(),
103-
_ => bail_span!(name_value.lit, "source_path must be string"),
103+
_ => bail_span!(name_value.lit, "module_name must be string"),
104104
})
105105
} else if name_value.ident == "source" {
106106
assert_source_empty(&source)?;
@@ -137,7 +137,7 @@ impl PyCompileInput {
137137
})?
138138
.compile(
139139
&mode.unwrap_or(compile::Mode::Exec),
140-
source_path.unwrap_or_else(|| "frozen".to_string()),
140+
module_name.unwrap_or_else(|| "frozen".to_string()),
141141
)
142142
}
143143
}

vm/src/frozen.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,15 @@ pub fn get_module_inits() -> HashMap<String, CodeObject> {
55
hashmap! {
66
"__hello__".into() => py_compile_bytecode!(
77
source = "initialized = True; print(\"Hello world!\")\n",
8+
module_name = "__hello__",
89
),
910
"_frozen_importlib".into() => py_compile_bytecode!(
1011
file = "../Lib/importlib/_bootstrap.py",
12+
module_name = "_frozen_importlib",
1113
),
1214
"_frozen_importlib_external".into() => py_compile_bytecode!(
1315
file = "../Lib/importlib/_bootstrap_external.py",
16+
module_name = "_frozen_importlib_external",
1417
),
1518
}
1619
}

vm/src/import.rs

Lines changed: 27 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -25,24 +25,24 @@ pub fn init_importlib(vm: &VirtualMachine) -> PyResult {
2525
}
2626

2727
pub fn import_frozen(vm: &VirtualMachine, module_name: &str) -> PyResult {
28-
if let Some(frozen) = vm.frozen.borrow().get(module_name) {
29-
let mut frozen = frozen.clone();
30-
frozen.source_path = format!("frozen {}", module_name);
31-
import_codeobj(vm, module_name, frozen)
32-
} else {
33-
Err(vm.new_import_error(format!("Cannot import frozen module {}", module_name)))
34-
}
28+
vm.frozen
29+
.borrow()
30+
.get(module_name)
31+
.ok_or_else(|| vm.new_import_error(format!("Cannot import frozen module {}", module_name)))
32+
.and_then(|frozen| import_codeobj(vm, module_name, frozen.clone(), false))
3533
}
3634

3735
pub fn import_builtin(vm: &VirtualMachine, module_name: &str) -> PyResult {
38-
let sys_modules = vm.get_attribute(vm.sys_module.clone(), "modules").unwrap();
39-
if let Some(make_module_func) = vm.stdlib_inits.borrow().get(module_name) {
40-
let module = make_module_func(vm);
41-
sys_modules.set_item(module_name, module.clone(), vm)?;
42-
Ok(module)
43-
} else {
44-
Err(vm.new_import_error(format!("Cannot import bultin module {}", module_name)))
45-
}
36+
vm.stdlib_inits
37+
.borrow()
38+
.get(module_name)
39+
.ok_or_else(|| vm.new_import_error(format!("Cannot import bultin module {}", module_name)))
40+
.and_then(|make_module_func| {
41+
let module = make_module_func(vm);
42+
let sys_modules = vm.get_attribute(vm.sys_module.clone(), "modules")?;
43+
sys_modules.set_item(module_name, module.clone(), vm)?;
44+
Ok(module)
45+
})
4646
}
4747

4848
pub fn import_module(vm: &VirtualMachine, current_path: PathBuf, module_name: &str) -> PyResult {
@@ -57,8 +57,8 @@ pub fn import_module(vm: &VirtualMachine, current_path: PathBuf, module_name: &s
5757
} else if vm.stdlib_inits.borrow().contains_key(module_name) {
5858
import_builtin(vm, module_name)
5959
} else {
60-
let notfound_error = vm.context().exceptions.module_not_found_error.clone();
61-
let import_error = vm.context().exceptions.import_error.clone();
60+
let notfound_error = &vm.ctx.exceptions.module_not_found_error;
61+
let import_error = &vm.ctx.exceptions.import_error;
6262

6363
// Time to search for module in any place:
6464
let file_path = find_source(vm, current_path, module_name)
@@ -83,21 +83,24 @@ pub fn import_file(
8383
) -> PyResult {
8484
let code_obj = compile::compile(&content, &compile::Mode::Exec, file_path)
8585
.map_err(|err| vm.new_syntax_error(&err))?;
86-
import_codeobj(vm, module_name, code_obj)
86+
import_codeobj(vm, module_name, code_obj, true)
8787
}
8888

89-
pub fn import_codeobj(vm: &VirtualMachine, module_name: &str, code_obj: CodeObject) -> PyResult {
89+
pub fn import_codeobj(
90+
vm: &VirtualMachine,
91+
module_name: &str,
92+
code_obj: CodeObject,
93+
set_file_attr: bool,
94+
) -> PyResult {
9095
let attrs = vm.ctx.new_dict();
9196
attrs.set_item("__name__", vm.new_str(module_name.to_string()), vm)?;
92-
let file_path = &code_obj.source_path;
93-
if !file_path.starts_with("frozen") {
94-
// TODO: Should be less hacky, not depend on source_path
95-
attrs.set_item("__file__", vm.new_str(file_path.to_owned()), vm)?;
97+
if set_file_attr {
98+
attrs.set_item("__file__", vm.new_str(code_obj.source_path.to_owned()), vm)?;
9699
}
97100
let module = vm.ctx.new_module(module_name, attrs.clone());
98101

99102
// Store module in cache to prevent infinite loop with mutual importing libs:
100-
let sys_modules = vm.get_attribute(vm.sys_module.clone(), "modules").unwrap();
103+
let sys_modules = vm.get_attribute(vm.sys_module.clone(), "modules")?;
101104
sys_modules.set_item(module_name, module.clone(), vm)?;
102105

103106
// Execute main code in module:

0 commit comments

Comments
 (0)