Skip to content

Commit 95a947d

Browse files
committed
__class_getitem__ isn't a special method
1 parent 4f0feef commit 95a947d

File tree

2 files changed

+50
-18
lines changed

2 files changed

+50
-18
lines changed

vm/src/builtins/make_module.rs

Lines changed: 42 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ mod decl {
2020
use crate::builtins::pybool::IntoPyBool;
2121
use crate::builtins::pystr::{PyStr, PyStrRef};
2222
use crate::builtins::pytype::PyTypeRef;
23-
use crate::builtins::{PyByteArray, PyBytes};
23+
use crate::builtins::{PyByteArray, PyBytes, PyTupleRef};
2424
use crate::byteslike::ArgBytesLike;
2525
use crate::common::{hash::PyHash, str::to_ascii};
2626
#[cfg(feature = "rustpython-compiler")]
@@ -832,7 +832,7 @@ mod decl {
832832
pub fn __build_class__(
833833
function: PyFunctionRef,
834834
qualified_name: PyStrRef,
835-
bases: Args<PyTypeRef>,
835+
bases: Args,
836836
mut kwargs: KwArgs,
837837
vm: &VirtualMachine,
838838
) -> PyResult {
@@ -845,7 +845,41 @@ mod decl {
845845
vm.ctx.types.type_type.clone()
846846
};
847847

848-
for base in bases.clone() {
848+
let mut new_bases: Option<Vec<PyObjectRef>> = None;
849+
850+
let bases = PyTupleRef::with_elements(bases.into_vec(), &vm.ctx);
851+
852+
for (i, base) in bases.as_slice().iter().enumerate() {
853+
if base.isinstance(&vm.ctx.types.type_type) {
854+
if let Some(bases) = &mut new_bases {
855+
bases.push(base.clone());
856+
}
857+
continue;
858+
}
859+
let mro_entries = vm.get_attribute_opt(base.clone(), "__mro_entries__")?;
860+
let entries = match mro_entries {
861+
Some(meth) => vm.invoke(&meth, (bases.clone(),))?,
862+
None => {
863+
if let Some(bases) = &mut new_bases {
864+
bases.push(base.clone());
865+
}
866+
continue;
867+
}
868+
};
869+
let entries: PyTupleRef = entries
870+
.downcast()
871+
.map_err(|_| vm.new_type_error("__mro_entries__ must return a tuple".to_owned()))?;
872+
let new_bases = new_bases.get_or_insert_with(|| bases.as_slice()[..i].to_vec());
873+
new_bases.extend_from_slice(entries.as_slice());
874+
}
875+
876+
let new_bases = new_bases.map(|v| PyTupleRef::with_elements(v, &vm.ctx));
877+
let (orig_bases, bases) = match new_bases {
878+
Some(new) => (Some(bases), new),
879+
None => (None, bases),
880+
};
881+
882+
for base in bases.as_slice().iter() {
849883
let base_class = base.class();
850884
if base_class.issubclass(&metaclass) {
851885
metaclass = base.clone_class();
@@ -858,7 +892,7 @@ mod decl {
858892
}
859893
}
860894

861-
let bases = bases.into_tuple(vm);
895+
let bases = bases.into_object();
862896

863897
// Prepare uses full __getattribute__ resolution chain.
864898
let prepare = vm.get_attribute(metaclass.clone().into_object(), "__prepare__")?;
@@ -872,6 +906,10 @@ mod decl {
872906
let classcell = function.invoke_with_locals(().into(), Some(namespace.clone()), vm)?;
873907
let classcell = <Option<PyCellRef>>::try_from_object(vm, classcell)?;
874908

909+
if let Some(orig_bases) = orig_bases {
910+
namespace.set_item("__orig_bases__", orig_bases.into_object(), vm)?;
911+
}
912+
875913
let class = vm.invoke(
876914
metaclass.as_object(),
877915
FuncArgs::new(vec![name_obj, bases, namespace.into_object()], kwargs),

vm/src/pyobject.rs

Lines changed: 8 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -686,25 +686,19 @@ where
686686
{
687687
fn get_item(&self, key: T, vm: &VirtualMachine) -> PyResult {
688688
match vm.get_special_method(self.clone(), "__getitem__")? {
689-
Ok(special_method) => special_method.invoke((key,), vm),
689+
Ok(special_method) => return special_method.invoke((key,), vm),
690690
Err(obj) => {
691691
if obj.isinstance(&vm.ctx.types.type_type) {
692-
vm.get_special_method(obj, "__class_getitem__")?
693-
.map_err(|obj2| {
694-
vm.new_type_error(format!(
695-
"'{}' object is not subscriptable",
696-
obj2.class().name
697-
))
698-
})?
699-
.invoke((key,), vm)
700-
} else {
701-
Err(vm.new_type_error(format!(
702-
"'{}' object is not subscriptable",
703-
obj.class().name
704-
)))
692+
if let Some(class_getitem) = vm.get_attribute_opt(obj, "__class_getitem__")? {
693+
return vm.invoke(&class_getitem, (key,));
694+
}
705695
}
706696
}
707697
}
698+
Err(vm.new_type_error(format!(
699+
"'{}' object is not subscriptable",
700+
self.class().name
701+
)))
708702
}
709703

710704
fn set_item(&self, key: T, value: PyObjectRef, vm: &VirtualMachine) -> PyResult<()> {

0 commit comments

Comments
 (0)