Skip to content

Commit 3f665ce

Browse files
committed
Replace PyObjectPayload::Module with PyModule.
1 parent 0febcb9 commit 3f665ce

File tree

3 files changed

+65
-73
lines changed

3 files changed

+65
-73
lines changed

vm/src/obj/objmodule.rs

Lines changed: 27 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,30 +1,36 @@
11
use crate::frame::ScopeRef;
2-
use crate::pyobject::{
3-
DictProtocol, PyContext, PyFuncArgs, PyObjectPayload, PyObjectRef, PyResult, TypeProtocol,
4-
};
2+
use crate::function::PyRef;
3+
use crate::pyobject::{DictProtocol, PyContext, PyObjectPayload2, PyObjectRef, PyResult};
54
use crate::vm::VirtualMachine;
65

7-
fn module_dir(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
8-
arg_check!(vm, args, required = [(obj, Some(vm.ctx.module_type()))]);
9-
let scope = get_scope(obj);
10-
let keys = scope
11-
.locals
12-
.get_key_value_pairs()
13-
.iter()
14-
.map(|(k, _v)| k.clone())
15-
.collect();
16-
Ok(vm.ctx.new_list(keys))
6+
#[derive(Clone, Debug)]
7+
pub struct PyModule {
8+
pub name: String,
9+
pub scope: ScopeRef,
1710
}
11+
pub type PyModuleRef = PyRef<PyModule>;
1812

19-
pub fn init(context: &PyContext) {
20-
let module_type = &context.module_type;
21-
context.set_attr(&module_type, "__dir__", context.new_rustfunc(module_dir));
13+
impl PyObjectPayload2 for PyModule {
14+
fn required_type(ctx: &PyContext) -> PyObjectRef {
15+
ctx.module_type()
16+
}
2217
}
2318

24-
fn get_scope(obj: &PyObjectRef) -> &ScopeRef {
25-
if let PyObjectPayload::Module { ref scope, .. } = &obj.payload {
26-
&scope
27-
} else {
28-
panic!("Can't get scope from non-module.")
19+
impl PyModuleRef {
20+
fn dir(self, vm: &mut VirtualMachine) -> PyResult {
21+
let keys = self
22+
.scope
23+
.locals
24+
.get_key_value_pairs()
25+
.iter()
26+
.map(|(k, _v)| k.clone())
27+
.collect();
28+
Ok(vm.ctx.new_list(keys))
2929
}
3030
}
31+
32+
pub fn init(context: &PyContext) {
33+
extend_class!(&context, &context.module_type, {
34+
"__dir__" => context.new_rustfunc(PyModuleRef::dir)
35+
});
36+
}

vm/src/pyobject.rs

Lines changed: 35 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ use crate::obj::objiter;
3131
use crate::obj::objlist::{self, PyList};
3232
use crate::obj::objmap;
3333
use crate::obj::objmemory;
34-
use crate::obj::objmodule;
34+
use crate::obj::objmodule::{self, PyModule};
3535
use crate::obj::objnone;
3636
use crate::obj::objobject;
3737
use crate::obj::objproperty;
@@ -92,10 +92,10 @@ impl fmt::Display for PyObject {
9292
}
9393
}
9494

95-
match &self.payload {
96-
PyObjectPayload::Module { name, .. } => write!(f, "module '{}'", name),
97-
_ => write!(f, "'{}' object", objtype::get_type_name(&self.typ())),
95+
if let Some(PyModule { ref name, .. }) = self.payload::<PyModule>() {
96+
return write!(f, "module '{}'", name);
9897
}
98+
write!(f, "'{}' object", objtype::get_type_name(&self.typ()))
9999
}
100100
}
101101

@@ -600,9 +600,11 @@ impl PyContext {
600600

601601
pub fn new_module(&self, name: &str, scope: ScopeRef) -> PyObjectRef {
602602
PyObject::new(
603-
PyObjectPayload::Module {
604-
name: name.to_string(),
605-
scope,
603+
PyObjectPayload::AnyRustValue {
604+
value: Box::new(PyModule {
605+
name: name.to_string(),
606+
scope,
607+
}),
606608
},
607609
self.module_type.clone(),
608610
)
@@ -730,7 +732,7 @@ impl PyContext {
730732
}
731733

732734
pub fn set_attr(&self, obj: &PyObjectRef, attr_name: &str, value: PyObjectRef) {
733-
if let PyObjectPayload::Module { ref scope, .. } = obj.payload {
735+
if let Some(PyModule { ref scope, .. }) = obj.payload::<PyModule>() {
734736
scope.locals.set_item(self, attr_name, value)
735737
} else if let Some(ref dict) = obj.dict {
736738
dict.borrow_mut().insert(attr_name.to_string(), value);
@@ -850,15 +852,14 @@ impl AttributeProtocol for PyObjectRef {
850852
return None;
851853
}
852854

853-
match self.payload {
854-
PyObjectPayload::Module { ref scope, .. } => scope.locals.get_item(attr_name),
855-
_ => {
856-
if let Some(ref dict) = self.dict {
857-
dict.borrow().get(attr_name).cloned()
858-
} else {
859-
None
860-
}
861-
}
855+
if let Some(PyModule { ref scope, .. }) = self.payload::<PyModule>() {
856+
return scope.locals.get_item(attr_name);
857+
}
858+
859+
if let Some(ref dict) = self.dict {
860+
dict.borrow().get(attr_name).cloned()
861+
} else {
862+
None
862863
}
863864
}
864865

@@ -868,15 +869,14 @@ impl AttributeProtocol for PyObjectRef {
868869
|| mro.iter().any(|d| class_has_item(d, attr_name));
869870
}
870871

871-
match self.payload {
872-
PyObjectPayload::Module { ref scope, .. } => scope.locals.contains_key(attr_name),
873-
_ => {
874-
if let Some(ref dict) = self.dict {
875-
dict.borrow().contains_key(attr_name)
876-
} else {
877-
false
878-
}
879-
}
872+
if let Some(PyModule { ref scope, .. }) = self.payload::<PyModule>() {
873+
return scope.locals.contains_key(attr_name);
874+
}
875+
876+
if let Some(ref dict) = self.dict {
877+
dict.borrow().contains_key(attr_name)
878+
} else {
879+
false
880880
}
881881
}
882882
}
@@ -900,22 +900,20 @@ impl DictProtocol for PyObjectRef {
900900
fn get_item(&self, k: &str) -> Option<PyObjectRef> {
901901
if let Some(dict) = self.payload::<PyDict>() {
902902
objdict::content_get_key_str(&dict.entries.borrow(), k)
903+
} else if let Some(PyModule { ref scope, .. }) = self.payload::<PyModule>() {
904+
scope.locals.get_item(k)
903905
} else {
904-
match self.payload {
905-
PyObjectPayload::Module { ref scope, .. } => scope.locals.get_item(k),
906-
ref k => panic!("TODO {:?}", k),
907-
}
906+
panic!("TODO {:?}", k)
908907
}
909908
}
910909

911910
fn get_key_value_pairs(&self) -> Vec<(PyObjectRef, PyObjectRef)> {
912911
if let Some(_) = self.payload::<PyDict>() {
913912
objdict::get_key_value_pairs(self)
913+
} else if let Some(PyModule { ref scope, .. }) = self.payload::<PyModule>() {
914+
scope.locals.get_key_value_pairs()
914915
} else {
915-
match self.payload {
916-
PyObjectPayload::Module { ref scope, .. } => scope.locals.get_key_value_pairs(),
917-
_ => panic!("TODO"),
918-
}
916+
panic!("TODO")
919917
}
920918
}
921919

@@ -924,13 +922,10 @@ impl DictProtocol for PyObjectRef {
924922
if let Some(dict) = self.payload::<PyDict>() {
925923
let key = ctx.new_str(key.to_string());
926924
objdict::set_item_in_content(&mut dict.entries.borrow_mut(), &key, &v);
925+
} else if let Some(PyModule { ref scope, .. }) = self.payload::<PyModule>() {
926+
scope.locals.set_item(ctx, key, v);
927927
} else {
928-
match &self.payload {
929-
PyObjectPayload::Module { scope, .. } => {
930-
scope.locals.set_item(ctx, key, v);
931-
}
932-
ref k => panic!("TODO {:?}", k),
933-
};
928+
panic!("TODO {:?}", self);
934929
}
935930
}
936931
}
@@ -1542,10 +1537,6 @@ pub enum PyObjectPayload {
15421537
function: PyObjectRef,
15431538
object: PyObjectRef,
15441539
},
1545-
Module {
1546-
name: String,
1547-
scope: ScopeRef,
1548-
},
15491540
WeakRef {
15501541
referent: PyObjectWeakRef,
15511542
},
@@ -1582,7 +1573,6 @@ impl fmt::Debug for PyObjectPayload {
15821573
ref function,
15831574
ref object,
15841575
} => write!(f, "bound-method: {:?} of {:?}", function, object),
1585-
PyObjectPayload::Module { .. } => write!(f, "module"),
15861576
PyObjectPayload::RustFunction { .. } => write!(f, "rust function"),
15871577
PyObjectPayload::Frame { .. } => write!(f, "frame"),
15881578
PyObjectPayload::AnyRustValue { value } => value.fmt(f),

vm/src/vm.rs

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ use crate::obj::objframe;
2121
use crate::obj::objgenerator;
2222
use crate::obj::objiter;
2323
use crate::obj::objlist::PyList;
24+
use crate::obj::objmodule::PyModule;
2425
use crate::obj::objsequence;
2526
use crate::obj::objstr;
2627
use crate::obj::objtuple::PyTuple;
@@ -212,13 +213,8 @@ impl VirtualMachine {
212213
}
213214

214215
pub fn get_builtin_scope(&self) -> ScopeRef {
215-
let a2 = &*self.builtins;
216-
match a2.payload {
217-
PyObjectPayload::Module { ref scope, .. } => scope.clone(),
218-
_ => {
219-
panic!("OMG");
220-
}
221-
}
216+
let PyModule { ref scope, .. } = self.builtins.payload::<PyModule>().unwrap();
217+
scope.clone()
222218
}
223219

224220
// Container of the virtual machine state:

0 commit comments

Comments
 (0)