Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 0 additions & 8 deletions Lib/test/test_typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -2883,8 +2883,6 @@ def method(self) -> int: ...

self.assertNotIsSubclass(NotImpl, Foo)

# TODO: RUSTPYTHON
@unittest.expectedFailure
def test_pep695_generics_can_be_runtime_checkable(self):
@runtime_checkable
class HasX(Protocol):
Expand Down Expand Up @@ -3332,8 +3330,6 @@ class D(PNonCall): ...
with self.assertRaisesRegex(TypeError, non_callable_members_illegal):
issubclass(D, PNonCall)

# TODO: RUSTPYTHON
@unittest.expectedFailure
def test_no_weird_caching_with_issubclass_after_isinstance(self):
@runtime_checkable
class Spam(Protocol):
Expand Down Expand Up @@ -3400,8 +3396,6 @@ def __getattr__(self, attr):
):
issubclass(Eggs, Spam)

# TODO: RUSTPYTHON
@unittest.expectedFailure
def test_no_weird_caching_with_issubclass_after_isinstance_pep695(self):
@runtime_checkable
class Spam[T](Protocol):
Expand Down Expand Up @@ -8552,8 +8546,6 @@ def test_valid_uses(self):
self.assertEqual(C4.__args__, (P, T))
self.assertEqual(C4.__parameters__, (P, T))

# TODO: RUSTPYTHON
@unittest.expectedFailure
def test_args_kwargs(self):
P = ParamSpec('P')
P_2 = ParamSpec('P_2')
Expand Down
159 changes: 102 additions & 57 deletions vm/src/builtins/type.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1019,26 +1019,7 @@ impl Constructor for PyType {
attributes.insert(identifier!(vm, __hash__), vm.ctx.none.clone().into());
}

// All *classes* should have a dict. Exceptions are *instances* of
// classes that define __slots__ and instances of built-in classes
// (with exceptions, e.g function)
// Also, type subclasses don't need their own __dict__ descriptor
// since they inherit it from type
if !base_is_type {
let __dict__ = identifier!(vm, __dict__);
attributes.entry(__dict__).or_insert_with(|| {
vm.ctx
.new_static_getset(
"__dict__",
vm.ctx.types.type_type,
subtype_get_dict,
subtype_set_dict,
)
.into()
});
}

let heaptype_slots: Option<PyRef<PyTuple<PyStrRef>>> =
let (heaptype_slots, add_dict): (Option<PyRef<PyTuple<PyStrRef>>>, bool) =
if let Some(x) = attributes.get(identifier!(vm, __slots__)) {
let slots = if x.class().is(vm.ctx.types.str_type) {
let x = unsafe { x.downcast_unchecked_ref::<PyStr>() };
Expand All @@ -1055,9 +1036,26 @@ impl Constructor for PyType {
let tuple = elements.into_pytuple(vm);
tuple.try_into_typed(vm)?
};
Some(slots)

// Check if __dict__ is in slots
let dict_name = "__dict__";
let has_dict = slots.iter().any(|s| s.as_str() == dict_name);

// Filter out __dict__ from slots
let filtered_slots = if has_dict {
let filtered: Vec<PyStrRef> = slots
.iter()
.filter(|s| s.as_str() != dict_name)
.cloned()
.collect();
PyTuple::new_ref_typed(filtered, &vm.ctx)
} else {
slots
};

(Some(filtered_slots), has_dict)
} else {
None
(None, false)
};

// FIXME: this is a temporary fix. multi bases with multiple slots will break object
Expand All @@ -1070,8 +1068,10 @@ impl Constructor for PyType {
let member_count: usize = base_member_count + heaptype_member_count;

let mut flags = PyTypeFlags::heap_type_flags();
// Only add HAS_DICT and MANAGED_DICT if __slots__ is not defined.
if heaptype_slots.is_none() {
// Add HAS_DICT and MANAGED_DICT if:
// 1. __slots__ is not defined, OR
// 2. __dict__ is in __slots__
if heaptype_slots.is_none() || add_dict {
flags |= PyTypeFlags::HAS_DICT | PyTypeFlags::MANAGED_DICT;
}

Expand Down Expand Up @@ -1141,6 +1141,25 @@ impl Constructor for PyType {
cell.set(Some(typ.clone().into()));
};

// All *classes* should have a dict. Exceptions are *instances* of
// classes that define __slots__ and instances of built-in classes
// (with exceptions, e.g function)
// Also, type subclasses don't need their own __dict__ descriptor
// since they inherit it from type

// Add __dict__ descriptor after type creation to ensure correct __objclass__
if !base_is_type {
let __dict__ = identifier!(vm, __dict__);
if !typ.attributes.read().contains_key(&__dict__) {
unsafe {
let descriptor =
vm.ctx
.new_getset("__dict__", &typ, subtype_get_dict, subtype_set_dict);
typ.attributes.write().insert(__dict__, descriptor.into());
}
}
}

// avoid deadlock
let attributes = typ
.attributes
Expand Down Expand Up @@ -1446,51 +1465,77 @@ impl Representable for PyType {
}
}

fn find_base_dict_descr(cls: &Py<PyType>, vm: &VirtualMachine) -> Option<PyObjectRef> {
cls.iter_base_chain().skip(1).find_map(|cls| {
// TODO: should actually be some translation of:
// cls.slot_dictoffset != 0 && !cls.flags.contains(HEAPTYPE)
if cls.is(vm.ctx.types.type_type) {
cls.get_attr(identifier!(vm, __dict__))
} else {
None
// = get_builtin_base_with_dict
fn get_builtin_base_with_dict(typ: &Py<PyType>, vm: &VirtualMachine) -> Option<PyTypeRef> {
let mut current = Some(typ.to_owned());
while let Some(t) = current {
// In CPython: type->tp_dictoffset != 0 && !(type->tp_flags & Py_TPFLAGS_HEAPTYPE)
// Special case: type itself is a builtin with dict support
if t.is(vm.ctx.types.type_type) {
return Some(t);
}
// We check HAS_DICT flag (equivalent to tp_dictoffset != 0) and HEAPTYPE
if t.slots.flags.contains(PyTypeFlags::HAS_DICT)
&& !t.slots.flags.contains(PyTypeFlags::HEAPTYPE)
{
return Some(t);
}
})
current = t.__base__();
}
None
}

// = get_dict_descriptor
fn get_dict_descriptor(base: &Py<PyType>, vm: &VirtualMachine) -> Option<PyObjectRef> {
let dict_attr = identifier!(vm, __dict__);
// Use _PyType_Lookup (which is lookup_ref in RustPython)
base.lookup_ref(dict_attr, vm)
}

// = raise_dict_descr_error
fn raise_dict_descriptor_error(obj: &PyObject, vm: &VirtualMachine) -> PyBaseExceptionRef {
vm.new_type_error(format!(
"this __dict__ descriptor does not support '{}' objects",
obj.class().name()
))
}

fn subtype_get_dict(obj: PyObjectRef, vm: &VirtualMachine) -> PyResult {
// TODO: obj.class().as_pyref() need to be supported
let ret = match find_base_dict_descr(obj.class(), vm) {
Some(descr) => vm.call_get_descriptor(&descr, obj).unwrap_or_else(|| {
Err(vm.new_type_error(format!(
"this __dict__ descriptor does not support '{}' objects",
descr.class()
)))
})?,
None => object::object_get_dict(obj, vm)?.into(),
};
Ok(ret)
let base = get_builtin_base_with_dict(obj.class(), vm);

if let Some(base_type) = base {
if let Some(descr) = get_dict_descriptor(&base_type, vm) {
// Call the descriptor's tp_descr_get
vm.call_get_descriptor(&descr, obj.clone())
.unwrap_or_else(|| Err(raise_dict_descriptor_error(&obj, vm)))
} else {
Err(raise_dict_descriptor_error(&obj, vm))
}
} else {
// PyObject_GenericGetDict
object::object_get_dict(obj, vm).map(Into::into)
}
}

// = subtype_setdict
fn subtype_set_dict(obj: PyObjectRef, value: PyObjectRef, vm: &VirtualMachine) -> PyResult<()> {
let cls = obj.class();
match find_base_dict_descr(cls, vm) {
Some(descr) => {
let base = get_builtin_base_with_dict(obj.class(), vm);

if let Some(base_type) = base {
if let Some(descr) = get_dict_descriptor(&base_type, vm) {
// Call the descriptor's tp_descr_set
let descr_set = descr
.class()
.mro_find_map(|cls| cls.slots.descr_set.load())
.ok_or_else(|| {
vm.new_type_error(format!(
"this __dict__ descriptor does not support '{}' objects",
cls.name()
))
})?;
.ok_or_else(|| raise_dict_descriptor_error(&obj, vm))?;
descr_set(&descr, obj, PySetterValue::Assign(value), vm)
} else {
Err(raise_dict_descriptor_error(&obj, vm))
}
None => {
object::object_set_dict(obj, value.try_into_value(vm)?, vm)?;
Ok(())
}
} else {
// PyObject_GenericSetDict
object::object_set_dict(obj, value.try_into_value(vm)?, vm)?;
Ok(())
}
}

Expand Down
18 changes: 10 additions & 8 deletions vm/src/stdlib/typevar.rs
Original file line number Diff line number Diff line change
Expand Up @@ -833,11 +833,12 @@ impl Comparable for ParamSpecArgs {
fn eq(
zelf: &crate::Py<ParamSpecArgs>,
other: PyObjectRef,
vm: &VirtualMachine,
_vm: &VirtualMachine,
) -> PyResult<bool> {
// Check if other has __origin__ attribute
if let Ok(other_origin) = other.get_attr("__origin__", vm) {
return Ok(zelf.__origin__.is(&other_origin));
// First check if other is also ParamSpecArgs
if let Ok(other_args) = other.downcast::<ParamSpecArgs>() {
// Check if they have the same origin
return Ok(zelf.__origin__.is(&other_args.__origin__));
}
Ok(false)
}
Expand Down Expand Up @@ -911,11 +912,12 @@ impl Comparable for ParamSpecKwargs {
fn eq(
zelf: &crate::Py<ParamSpecKwargs>,
other: PyObjectRef,
vm: &VirtualMachine,
_vm: &VirtualMachine,
) -> PyResult<bool> {
// Check if other has __origin__ attribute
if let Ok(other_origin) = other.get_attr("__origin__", vm) {
return Ok(zelf.__origin__.is(&other_origin));
// First check if other is also ParamSpecKwargs
if let Ok(other_kwargs) = other.downcast::<ParamSpecKwargs>() {
// Check if they have the same origin
return Ok(zelf.__origin__.is(&other_kwargs.__origin__));
}
Ok(false)
}
Expand Down