Skip to content

Commit ea4f2ce

Browse files
committed
Add a store of Rust classes in PyContext
1 parent 19be5c9 commit ea4f2ce

File tree

1 file changed

+41
-2
lines changed

1 file changed

+41
-2
lines changed

vm/src/pyobject.rs

Lines changed: 41 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
use std::any::Any;
2+
use std::any::TypeId;
23
use std::cell::Cell;
34
use std::cell::RefCell;
4-
use std::collections::HashMap;
5+
use std::collections::hash_map::{Entry, HashMap};
56
use std::fmt;
67
use std::marker::PhantomData;
78
use std::mem;
@@ -162,6 +163,7 @@ pub struct PyContext {
162163
pub weakproxy_type: PyClassRef,
163164
pub object: PyClassRef,
164165
pub exceptions: exceptions::ExceptionZoo,
166+
rust_classes: RefCell<HashMap<TypeId, PyClassRef>>,
165167
}
166168

167169
pub fn create_type(name: &str, type_type: &PyClassRef, base: &PyClassRef) -> PyClassRef {
@@ -367,6 +369,7 @@ impl PyContext {
367369
weakproxy_type,
368370
type_type,
369371
exceptions,
372+
rust_classes: RefCell::default(),
370373
};
371374
objtype::init(&context);
372375
objlist::init(&context);
@@ -407,6 +410,36 @@ impl PyContext {
407410
context
408411
}
409412

413+
pub fn _add_class(&self, typeid: TypeId, class: PyClassRef) {
414+
let classes = &mut self.rust_classes.borrow_mut();
415+
match classes.entry(typeid) {
416+
Entry::Occupied(o) => panic!(
417+
"Attempted to add rust type twice to the same PyContext, previously held class \
418+
was {:?}",
419+
o.get()
420+
),
421+
Entry::Vacant(v) => {
422+
v.insert(class);
423+
}
424+
}
425+
}
426+
427+
#[inline]
428+
pub fn add_class<T: PyClassImpl + 'static>(&self) -> PyClassRef {
429+
let class = T::make_class(self);
430+
self._add_class(TypeId::of::<T>(), class.clone());
431+
class
432+
}
433+
434+
pub fn _get_class(&self, typeid: &TypeId) -> Option<PyClassRef> {
435+
self.rust_classes.borrow_mut().get(typeid).cloned()
436+
}
437+
438+
#[inline]
439+
pub fn get_class<T: 'static>(&self) -> Option<PyClassRef> {
440+
self._get_class(&TypeId::of::<T>())
441+
}
442+
410443
pub fn bytearray_type(&self) -> PyClassRef {
411444
self.bytearray_type.clone()
412445
}
@@ -1202,7 +1235,13 @@ impl PyObject<dyn PyObjectPayload> {
12021235
pub trait PyValue: fmt::Debug + Sized + 'static {
12031236
const HAVE_DICT: bool = false;
12041237

1205-
fn class(vm: &VirtualMachine) -> PyClassRef;
1238+
fn class(vm: &VirtualMachine) -> PyClassRef {
1239+
vm.ctx.get_class::<Self>().expect(
1240+
"Default .class implementation cannot find your class in PyContext.rust_classes. \
1241+
Is there some special way to access your class, or did you forget to call \
1242+
`ctx.add_class::<T>()`?",
1243+
)
1244+
}
12061245

12071246
fn into_ref(self, vm: &VirtualMachine) -> PyRef<Self> {
12081247
PyRef {

0 commit comments

Comments
 (0)