Skip to content

Commit 239e018

Browse files
authored
Merge pull request RustPython#1764 from IdanDor/feature/mro_resolution
Fixes to MRO usage and issue RustPython#1659
2 parents e07ca66 + 6ea5d58 commit 239e018

File tree

3 files changed

+110
-32
lines changed

3 files changed

+110
-32
lines changed

tests/snippets/class.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -207,3 +207,49 @@ class B:
207207
# assert a == 3
208208
# A.b()
209209
# nested_scope()
210+
211+
212+
213+
# Multiple inheritance and mro tests.
214+
class A():
215+
def f(self):
216+
return 'a'
217+
218+
class B(A):
219+
def f(self):
220+
return 'b' + super().f()
221+
222+
class C(A):
223+
def f(self):
224+
return 'c' + super().f()
225+
226+
class D(B, C):
227+
def f(self):
228+
return 'd' + super().f()
229+
230+
assert D().f() == 'dbca', "Mro resolution using super failed."
231+
232+
233+
234+
class A():
235+
pass
236+
try:
237+
class B(A, A):
238+
pass
239+
except TypeError:
240+
pass
241+
else:
242+
assert False, "Managed to create a class with duplicate base classes."
243+
244+
245+
class A():
246+
pass
247+
class B(A):
248+
pass
249+
try:
250+
class C(A, B):
251+
pass
252+
except TypeError:
253+
pass
254+
else:
255+
assert False, "Managed to create a class without local type precedence."

vm/src/obj/objtype.rs

Lines changed: 63 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
use std::cell::RefCell;
2-
use std::collections::HashMap;
2+
use std::collections::{HashMap, HashSet};
33
use std::fmt;
44

55
use super::objdict::PyDictRef;
@@ -16,6 +16,7 @@ use crate::pyobject::{
1616
};
1717
use crate::slots::{PyClassSlots, PyTpFlags};
1818
use crate::vm::VirtualMachine;
19+
use itertools::Itertools;
1920

2021
/// type(object_or_name, bases, dict)
2122
/// type(object) -> the object's type
@@ -290,10 +291,14 @@ impl PyClassRef {
290291
}
291292
}
292293

293-
let typ = new(metatype, name.as_str(), base.clone(), bases, attributes)?;
294-
typ.slots.borrow_mut().flags = base.slots.borrow().flags;
295-
vm.ctx.add_tp_new_wrapper(&typ);
296-
Ok(typ.into())
294+
match new(metatype, name.as_str(), base.clone(), bases, attributes) {
295+
Ok(typ) => {
296+
typ.slots.borrow_mut().flags = base.slots.borrow().flags;
297+
vm.ctx.add_tp_new_wrapper(&typ);
298+
Ok(typ.into())
299+
}
300+
Err(string) => Err(vm.new_type_error(string)),
301+
}
297302
}
298303

299304
#[pyslot]
@@ -430,43 +435,65 @@ impl PyClassRef {
430435
}
431436
}
432437

433-
fn take_next_base(mut bases: Vec<Vec<PyClassRef>>) -> Option<(PyClassRef, Vec<Vec<PyClassRef>>)> {
434-
let mut next = None;
435-
438+
fn take_next_base(mut bases: Vec<Vec<PyClassRef>>) -> (Option<PyClassRef>, Vec<Vec<PyClassRef>>) {
436439
bases = bases.into_iter().filter(|x| !x.is_empty()).collect();
437440

438441
for base in &bases {
439442
let head = base[0].clone();
440443
if !(&bases).iter().any(|x| x[1..].iter().any(|x| x.is(&head))) {
441-
next = Some(head);
442-
break;
444+
// Remove from other heads.
445+
for item in &mut bases {
446+
if item[0].is(&head) {
447+
item.remove(0);
448+
}
449+
}
450+
451+
return (Some(head), bases);
443452
}
444453
}
445454

446-
if let Some(head) = next {
447-
for item in &mut bases {
448-
if item[0].is(&head) {
449-
item.remove(0);
455+
(None, bases)
456+
}
457+
458+
fn linearise_mro(mut bases: Vec<Vec<PyClassRef>>) -> Result<Vec<PyClassRef>, String> {
459+
vm_trace!("Linearising MRO: {:?}", bases);
460+
// Python requires that the class direct bases are kept in the same order.
461+
// This is called local precedence ordering.
462+
// This means we must verify that for classes A(), B(A) we must reject C(A, B) even though this
463+
// algorithm will allow the mro ordering of [C, B, A, object].
464+
// To verify this, we make sure non of the direct bases are in the mro of bases after them.
465+
for (i, base_mro) in bases.iter().enumerate() {
466+
let base = &base_mro[0]; // Mros cannot be empty.
467+
for later_mro in bases[i + 1..].iter() {
468+
// We start at index 1 to skip direct bases.
469+
// This will not catch duplicate bases, but such a thing is already tested for.
470+
if later_mro[1..].iter().any(|cls| cls.is(base)) {
471+
return Err(
472+
"Unable to find mro order which keeps local precedence ordering".to_owned(),
473+
);
450474
}
451475
}
452-
return Some((head, bases));
453476
}
454-
None
455-
}
456477

457-
fn linearise_mro(mut bases: Vec<Vec<PyClassRef>>) -> Option<Vec<PyClassRef>> {
458-
vm_trace!("Linearising MRO: {:?}", bases);
459478
let mut result = vec![];
460479
loop {
461480
if (&bases).iter().all(Vec::is_empty) {
462481
break;
463482
}
464-
let (head, new_bases) = take_next_base(bases)?;
483+
let (head, new_bases) = take_next_base(bases);
484+
if head.is_none() {
485+
// Take the head class of each class here. Now that we have reached the problematic bases.
486+
// Because this failed, we assume the lists cannot be empty.
487+
return Err(format!(
488+
"Cannot create a consistent method resolution order (MRO) for bases {}",
489+
new_bases.iter().map(|x| x.first().unwrap()).join(", ")
490+
));
491+
}
465492

466-
result.push(head);
493+
result.push(head.unwrap());
467494
bases = new_bases;
468495
}
469-
Some(result)
496+
Ok(result)
470497
}
471498

472499
pub fn new(
@@ -475,12 +502,20 @@ pub fn new(
475502
_base: PyClassRef,
476503
bases: Vec<PyClassRef>,
477504
dict: HashMap<String, PyObjectRef>,
478-
) -> PyResult<PyClassRef> {
505+
) -> Result<PyClassRef, String> {
506+
// Check for duplicates in bases.
507+
let mut unique_bases = HashSet::new();
508+
for base in bases.iter() {
509+
if !unique_bases.insert(base.get_id()) {
510+
return Err(format!("duplicate base class {}", base.name));
511+
}
512+
}
513+
479514
let mros = bases
480515
.iter()
481516
.map(|x| x.iter_mro().cloned().collect())
482517
.collect();
483-
let mro = linearise_mro(mros).unwrap();
518+
let mro = linearise_mro(mros)?;
484519
let new_type = PyObject {
485520
payload: PyClass {
486521
name: String::from(name),
@@ -584,11 +619,8 @@ mod tests {
584619
use super::{linearise_mro, new};
585620
use super::{HashMap, IdProtocol, PyClassRef, PyContext};
586621

587-
fn map_ids(obj: Option<Vec<PyClassRef>>) -> Option<Vec<usize>> {
588-
match obj {
589-
Some(vec) => Some(vec.into_iter().map(|x| x.get_id()).collect()),
590-
None => None,
591-
}
622+
fn map_ids(obj: Result<Vec<PyClassRef>, String>) -> Result<Vec<usize>, String> {
623+
Ok(obj?.into_iter().map(|x| x.get_id()).collect())
592624
}
593625

594626
#[test]
@@ -619,14 +651,14 @@ mod tests {
619651
vec![object.clone()],
620652
vec![object.clone()]
621653
])),
622-
map_ids(Some(vec![object.clone()]))
654+
map_ids(Ok(vec![object.clone()]))
623655
);
624656
assert_eq!(
625657
map_ids(linearise_mro(vec![
626658
vec![a.clone(), object.clone()],
627659
vec![b.clone(), object.clone()],
628660
])),
629-
map_ids(Some(vec![a.clone(), b.clone(), object.clone()]))
661+
map_ids(Ok(vec![a.clone(), b.clone(), object.clone()]))
630662
);
631663
}
632664
}

vm/src/types.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -245,7 +245,7 @@ pub fn create_type(name: &str, type_type: &PyClassRef, base: &PyClassRef) -> PyC
245245
vec![base.clone()],
246246
dict,
247247
)
248-
.unwrap()
248+
.expect("Failed to create a new type in internal code.")
249249
}
250250

251251
/// Paritally initialize a struct, ensuring that all fields are

0 commit comments

Comments
 (0)