1
1
use std:: cell:: RefCell ;
2
- use std:: collections:: HashMap ;
2
+ use std:: collections:: { HashMap , HashSet } ;
3
3
use std:: fmt;
4
4
5
5
use super :: objdict:: PyDictRef ;
@@ -16,6 +16,7 @@ use crate::pyobject::{
16
16
} ;
17
17
use crate :: slots:: { PyClassSlots , PyTpFlags } ;
18
18
use crate :: vm:: VirtualMachine ;
19
+ use itertools:: Itertools ;
19
20
20
21
/// type(object_or_name, bases, dict)
21
22
/// type(object) -> the object's type
@@ -290,10 +291,14 @@ impl PyClassRef {
290
291
}
291
292
}
292
293
293
- let typ = new ( metatype, name. as_str ( ) , base. clone ( ) , bases, attributes) ?;
294
- typ. slots . borrow_mut ( ) . flags = base. slots . borrow ( ) . flags ;
295
- vm. ctx . add_tp_new_wrapper ( & typ) ;
296
- Ok ( typ. into ( ) )
294
+ match new ( metatype, name. as_str ( ) , base. clone ( ) , bases, attributes) {
295
+ Ok ( typ) => {
296
+ typ. slots . borrow_mut ( ) . flags = base. slots . borrow ( ) . flags ;
297
+ vm. ctx . add_tp_new_wrapper ( & typ) ;
298
+ Ok ( typ. into ( ) )
299
+ }
300
+ Err ( string) => Err ( vm. new_type_error ( string) ) ,
301
+ }
297
302
}
298
303
299
304
#[ pyslot]
@@ -430,43 +435,65 @@ impl PyClassRef {
430
435
}
431
436
}
432
437
433
- fn take_next_base ( mut bases : Vec < Vec < PyClassRef > > ) -> Option < ( PyClassRef , Vec < Vec < PyClassRef > > ) > {
434
- let mut next = None ;
435
-
438
+ fn take_next_base ( mut bases : Vec < Vec < PyClassRef > > ) -> ( Option < PyClassRef > , Vec < Vec < PyClassRef > > ) {
436
439
bases = bases. into_iter ( ) . filter ( |x| !x. is_empty ( ) ) . collect ( ) ;
437
440
438
441
for base in & bases {
439
442
let head = base[ 0 ] . clone ( ) ;
440
443
if !( & bases) . iter ( ) . any ( |x| x[ 1 ..] . iter ( ) . any ( |x| x. is ( & head) ) ) {
441
- next = Some ( head) ;
442
- break ;
444
+ // Remove from other heads.
445
+ for item in & mut bases {
446
+ if item[ 0 ] . is ( & head) {
447
+ item. remove ( 0 ) ;
448
+ }
449
+ }
450
+
451
+ return ( Some ( head) , bases) ;
443
452
}
444
453
}
445
454
446
- if let Some ( head) = next {
447
- for item in & mut bases {
448
- if item[ 0 ] . is ( & head) {
449
- item. remove ( 0 ) ;
455
+ ( None , bases)
456
+ }
457
+
458
+ fn linearise_mro ( mut bases : Vec < Vec < PyClassRef > > ) -> Result < Vec < PyClassRef > , String > {
459
+ vm_trace ! ( "Linearising MRO: {:?}" , bases) ;
460
+ // Python requires that the class direct bases are kept in the same order.
461
+ // This is called local precedence ordering.
462
+ // This means we must verify that for classes A(), B(A) we must reject C(A, B) even though this
463
+ // algorithm will allow the mro ordering of [C, B, A, object].
464
+ // To verify this, we make sure non of the direct bases are in the mro of bases after them.
465
+ for ( i, base_mro) in bases. iter ( ) . enumerate ( ) {
466
+ let base = & base_mro[ 0 ] ; // Mros cannot be empty.
467
+ for later_mro in bases[ i + 1 ..] . iter ( ) {
468
+ // We start at index 1 to skip direct bases.
469
+ // This will not catch duplicate bases, but such a thing is already tested for.
470
+ if later_mro[ 1 ..] . iter ( ) . any ( |cls| cls. is ( base) ) {
471
+ return Err (
472
+ "Unable to find mro order which keeps local precedence ordering" . to_owned ( ) ,
473
+ ) ;
450
474
}
451
475
}
452
- return Some ( ( head, bases) ) ;
453
476
}
454
- None
455
- }
456
477
457
- fn linearise_mro ( mut bases : Vec < Vec < PyClassRef > > ) -> Option < Vec < PyClassRef > > {
458
- vm_trace ! ( "Linearising MRO: {:?}" , bases) ;
459
478
let mut result = vec ! [ ] ;
460
479
loop {
461
480
if ( & bases) . iter ( ) . all ( Vec :: is_empty) {
462
481
break ;
463
482
}
464
- let ( head, new_bases) = take_next_base ( bases) ?;
483
+ let ( head, new_bases) = take_next_base ( bases) ;
484
+ if head. is_none ( ) {
485
+ // Take the head class of each class here. Now that we have reached the problematic bases.
486
+ // Because this failed, we assume the lists cannot be empty.
487
+ return Err ( format ! (
488
+ "Cannot create a consistent method resolution order (MRO) for bases {}" ,
489
+ new_bases. iter( ) . map( |x| x. first( ) . unwrap( ) ) . join( ", " )
490
+ ) ) ;
491
+ }
465
492
466
- result. push ( head) ;
493
+ result. push ( head. unwrap ( ) ) ;
467
494
bases = new_bases;
468
495
}
469
- Some ( result)
496
+ Ok ( result)
470
497
}
471
498
472
499
pub fn new (
@@ -475,12 +502,20 @@ pub fn new(
475
502
_base : PyClassRef ,
476
503
bases : Vec < PyClassRef > ,
477
504
dict : HashMap < String , PyObjectRef > ,
478
- ) -> PyResult < PyClassRef > {
505
+ ) -> Result < PyClassRef , String > {
506
+ // Check for duplicates in bases.
507
+ let mut unique_bases = HashSet :: new ( ) ;
508
+ for base in bases. iter ( ) {
509
+ if !unique_bases. insert ( base. get_id ( ) ) {
510
+ return Err ( format ! ( "duplicate base class {}" , base. name) ) ;
511
+ }
512
+ }
513
+
479
514
let mros = bases
480
515
. iter ( )
481
516
. map ( |x| x. iter_mro ( ) . cloned ( ) . collect ( ) )
482
517
. collect ( ) ;
483
- let mro = linearise_mro ( mros) . unwrap ( ) ;
518
+ let mro = linearise_mro ( mros) ? ;
484
519
let new_type = PyObject {
485
520
payload : PyClass {
486
521
name : String :: from ( name) ,
@@ -584,11 +619,8 @@ mod tests {
584
619
use super :: { linearise_mro, new} ;
585
620
use super :: { HashMap , IdProtocol , PyClassRef , PyContext } ;
586
621
587
- fn map_ids ( obj : Option < Vec < PyClassRef > > ) -> Option < Vec < usize > > {
588
- match obj {
589
- Some ( vec) => Some ( vec. into_iter ( ) . map ( |x| x. get_id ( ) ) . collect ( ) ) ,
590
- None => None ,
591
- }
622
+ fn map_ids ( obj : Result < Vec < PyClassRef > , String > ) -> Result < Vec < usize > , String > {
623
+ Ok ( obj?. into_iter ( ) . map ( |x| x. get_id ( ) ) . collect ( ) )
592
624
}
593
625
594
626
#[ test]
@@ -619,14 +651,14 @@ mod tests {
619
651
vec![ object. clone( ) ] ,
620
652
vec![ object. clone( ) ]
621
653
] ) ) ,
622
- map_ids( Some ( vec![ object. clone( ) ] ) )
654
+ map_ids( Ok ( vec![ object. clone( ) ] ) )
623
655
) ;
624
656
assert_eq ! (
625
657
map_ids( linearise_mro( vec![
626
658
vec![ a. clone( ) , object. clone( ) ] ,
627
659
vec![ b. clone( ) , object. clone( ) ] ,
628
660
] ) ) ,
629
- map_ids( Some ( vec![ a. clone( ) , b. clone( ) , object. clone( ) ] ) )
661
+ map_ids( Ok ( vec![ a. clone( ) , b. clone( ) , object. clone( ) ] ) )
630
662
) ;
631
663
}
632
664
}
0 commit comments