diff --git a/Lib/test/test_typing.py b/Lib/test/test_typing.py index 1e84d289ec..adff90456e 100644 --- a/Lib/test/test_typing.py +++ b/Lib/test/test_typing.py @@ -2883,8 +2883,6 @@ def method(self) -> int: ... self.assertNotIsSubclass(NotImpl, Foo) - # TODO: RUSTPYTHON - @unittest.expectedFailure def test_pep695_generics_can_be_runtime_checkable(self): @runtime_checkable class HasX(Protocol): @@ -3332,8 +3330,6 @@ class D(PNonCall): ... with self.assertRaisesRegex(TypeError, non_callable_members_illegal): issubclass(D, PNonCall) - # TODO: RUSTPYTHON - @unittest.expectedFailure def test_no_weird_caching_with_issubclass_after_isinstance(self): @runtime_checkable class Spam(Protocol): @@ -3400,8 +3396,6 @@ def __getattr__(self, attr): ): issubclass(Eggs, Spam) - # TODO: RUSTPYTHON - @unittest.expectedFailure def test_no_weird_caching_with_issubclass_after_isinstance_pep695(self): @runtime_checkable class Spam[T](Protocol): @@ -8552,8 +8546,6 @@ def test_valid_uses(self): self.assertEqual(C4.__args__, (P, T)) self.assertEqual(C4.__parameters__, (P, T)) - # TODO: RUSTPYTHON - @unittest.expectedFailure def test_args_kwargs(self): P = ParamSpec('P') P_2 = ParamSpec('P_2') diff --git a/vm/src/builtins/type.rs b/vm/src/builtins/type.rs index 498330fa97..41ed70d939 100644 --- a/vm/src/builtins/type.rs +++ b/vm/src/builtins/type.rs @@ -1019,26 +1019,7 @@ impl Constructor for PyType { attributes.insert(identifier!(vm, __hash__), vm.ctx.none.clone().into()); } - // All *classes* should have a dict. Exceptions are *instances* of - // classes that define __slots__ and instances of built-in classes - // (with exceptions, e.g function) - // Also, type subclasses don't need their own __dict__ descriptor - // since they inherit it from type - if !base_is_type { - let __dict__ = identifier!(vm, __dict__); - attributes.entry(__dict__).or_insert_with(|| { - vm.ctx - .new_static_getset( - "__dict__", - vm.ctx.types.type_type, - subtype_get_dict, - subtype_set_dict, - ) - .into() - }); - } - - let heaptype_slots: Option>> = + let (heaptype_slots, add_dict): (Option>>, bool) = if let Some(x) = attributes.get(identifier!(vm, __slots__)) { let slots = if x.class().is(vm.ctx.types.str_type) { let x = unsafe { x.downcast_unchecked_ref::() }; @@ -1055,9 +1036,26 @@ impl Constructor for PyType { let tuple = elements.into_pytuple(vm); tuple.try_into_typed(vm)? }; - Some(slots) + + // Check if __dict__ is in slots + let dict_name = "__dict__"; + let has_dict = slots.iter().any(|s| s.as_str() == dict_name); + + // Filter out __dict__ from slots + let filtered_slots = if has_dict { + let filtered: Vec = slots + .iter() + .filter(|s| s.as_str() != dict_name) + .cloned() + .collect(); + PyTuple::new_ref_typed(filtered, &vm.ctx) + } else { + slots + }; + + (Some(filtered_slots), has_dict) } else { - None + (None, false) }; // FIXME: this is a temporary fix. multi bases with multiple slots will break object @@ -1070,8 +1068,10 @@ impl Constructor for PyType { let member_count: usize = base_member_count + heaptype_member_count; let mut flags = PyTypeFlags::heap_type_flags(); - // Only add HAS_DICT and MANAGED_DICT if __slots__ is not defined. - if heaptype_slots.is_none() { + // Add HAS_DICT and MANAGED_DICT if: + // 1. __slots__ is not defined, OR + // 2. __dict__ is in __slots__ + if heaptype_slots.is_none() || add_dict { flags |= PyTypeFlags::HAS_DICT | PyTypeFlags::MANAGED_DICT; } @@ -1141,6 +1141,25 @@ impl Constructor for PyType { cell.set(Some(typ.clone().into())); }; + // All *classes* should have a dict. Exceptions are *instances* of + // classes that define __slots__ and instances of built-in classes + // (with exceptions, e.g function) + // Also, type subclasses don't need their own __dict__ descriptor + // since they inherit it from type + + // Add __dict__ descriptor after type creation to ensure correct __objclass__ + if !base_is_type { + let __dict__ = identifier!(vm, __dict__); + if !typ.attributes.read().contains_key(&__dict__) { + unsafe { + let descriptor = + vm.ctx + .new_getset("__dict__", &typ, subtype_get_dict, subtype_set_dict); + typ.attributes.write().insert(__dict__, descriptor.into()); + } + } + } + // avoid deadlock let attributes = typ .attributes @@ -1446,51 +1465,77 @@ impl Representable for PyType { } } -fn find_base_dict_descr(cls: &Py, vm: &VirtualMachine) -> Option { - cls.iter_base_chain().skip(1).find_map(|cls| { - // TODO: should actually be some translation of: - // cls.slot_dictoffset != 0 && !cls.flags.contains(HEAPTYPE) - if cls.is(vm.ctx.types.type_type) { - cls.get_attr(identifier!(vm, __dict__)) - } else { - None +// = get_builtin_base_with_dict +fn get_builtin_base_with_dict(typ: &Py, vm: &VirtualMachine) -> Option { + let mut current = Some(typ.to_owned()); + while let Some(t) = current { + // In CPython: type->tp_dictoffset != 0 && !(type->tp_flags & Py_TPFLAGS_HEAPTYPE) + // Special case: type itself is a builtin with dict support + if t.is(vm.ctx.types.type_type) { + return Some(t); + } + // We check HAS_DICT flag (equivalent to tp_dictoffset != 0) and HEAPTYPE + if t.slots.flags.contains(PyTypeFlags::HAS_DICT) + && !t.slots.flags.contains(PyTypeFlags::HEAPTYPE) + { + return Some(t); } - }) + current = t.__base__(); + } + None +} + +// = get_dict_descriptor +fn get_dict_descriptor(base: &Py, vm: &VirtualMachine) -> Option { + let dict_attr = identifier!(vm, __dict__); + // Use _PyType_Lookup (which is lookup_ref in RustPython) + base.lookup_ref(dict_attr, vm) +} + +// = raise_dict_descr_error +fn raise_dict_descriptor_error(obj: &PyObject, vm: &VirtualMachine) -> PyBaseExceptionRef { + vm.new_type_error(format!( + "this __dict__ descriptor does not support '{}' objects", + obj.class().name() + )) } fn subtype_get_dict(obj: PyObjectRef, vm: &VirtualMachine) -> PyResult { - // TODO: obj.class().as_pyref() need to be supported - let ret = match find_base_dict_descr(obj.class(), vm) { - Some(descr) => vm.call_get_descriptor(&descr, obj).unwrap_or_else(|| { - Err(vm.new_type_error(format!( - "this __dict__ descriptor does not support '{}' objects", - descr.class() - ))) - })?, - None => object::object_get_dict(obj, vm)?.into(), - }; - Ok(ret) + let base = get_builtin_base_with_dict(obj.class(), vm); + + if let Some(base_type) = base { + if let Some(descr) = get_dict_descriptor(&base_type, vm) { + // Call the descriptor's tp_descr_get + vm.call_get_descriptor(&descr, obj.clone()) + .unwrap_or_else(|| Err(raise_dict_descriptor_error(&obj, vm))) + } else { + Err(raise_dict_descriptor_error(&obj, vm)) + } + } else { + // PyObject_GenericGetDict + object::object_get_dict(obj, vm).map(Into::into) + } } +// = subtype_setdict fn subtype_set_dict(obj: PyObjectRef, value: PyObjectRef, vm: &VirtualMachine) -> PyResult<()> { - let cls = obj.class(); - match find_base_dict_descr(cls, vm) { - Some(descr) => { + let base = get_builtin_base_with_dict(obj.class(), vm); + + if let Some(base_type) = base { + if let Some(descr) = get_dict_descriptor(&base_type, vm) { + // Call the descriptor's tp_descr_set let descr_set = descr .class() .mro_find_map(|cls| cls.slots.descr_set.load()) - .ok_or_else(|| { - vm.new_type_error(format!( - "this __dict__ descriptor does not support '{}' objects", - cls.name() - )) - })?; + .ok_or_else(|| raise_dict_descriptor_error(&obj, vm))?; descr_set(&descr, obj, PySetterValue::Assign(value), vm) + } else { + Err(raise_dict_descriptor_error(&obj, vm)) } - None => { - object::object_set_dict(obj, value.try_into_value(vm)?, vm)?; - Ok(()) - } + } else { + // PyObject_GenericSetDict + object::object_set_dict(obj, value.try_into_value(vm)?, vm)?; + Ok(()) } } diff --git a/vm/src/stdlib/typevar.rs b/vm/src/stdlib/typevar.rs index a8962acc90..11c20ba787 100644 --- a/vm/src/stdlib/typevar.rs +++ b/vm/src/stdlib/typevar.rs @@ -833,11 +833,12 @@ impl Comparable for ParamSpecArgs { fn eq( zelf: &crate::Py, other: PyObjectRef, - vm: &VirtualMachine, + _vm: &VirtualMachine, ) -> PyResult { - // Check if other has __origin__ attribute - if let Ok(other_origin) = other.get_attr("__origin__", vm) { - return Ok(zelf.__origin__.is(&other_origin)); + // First check if other is also ParamSpecArgs + if let Ok(other_args) = other.downcast::() { + // Check if they have the same origin + return Ok(zelf.__origin__.is(&other_args.__origin__)); } Ok(false) } @@ -911,11 +912,12 @@ impl Comparable for ParamSpecKwargs { fn eq( zelf: &crate::Py, other: PyObjectRef, - vm: &VirtualMachine, + _vm: &VirtualMachine, ) -> PyResult { - // Check if other has __origin__ attribute - if let Ok(other_origin) = other.get_attr("__origin__", vm) { - return Ok(zelf.__origin__.is(&other_origin)); + // First check if other is also ParamSpecKwargs + if let Ok(other_kwargs) = other.downcast::() { + // Check if they have the same origin + return Ok(zelf.__origin__.is(&other_kwargs.__origin__)); } Ok(false) }