diff --git a/Lib/test/test_sqlite3/test_regression.py b/Lib/test/test_sqlite3/test_regression.py index 44cfa8aa14..a658ff1f3c 100644 --- a/Lib/test/test_sqlite3/test_regression.py +++ b/Lib/test/test_sqlite3/test_regression.py @@ -150,8 +150,6 @@ def __conform__(self, protocol): with self.assertRaises(IndexError): con.execute("insert into foo(bar, baz) values (?, ?)", parameters) - # TODO: RUSTPYTHON - @unittest.expectedFailure def test_error_msg_decode_error(self): # When porting the module to Python 3.0, the error message about # decoding errors disappeared. This verifies they're back again. @@ -265,8 +263,6 @@ def test_connection_call(self): """ self.assertRaises(TypeError, self.con, b"select 1") - # TODO: RUSTPYTHON - @unittest.expectedFailure def test_collation(self): def collation_cb(a, b): return 1 diff --git a/Lib/test/test_typing.py b/Lib/test/test_typing.py index adff90456e..318c088fbc 100644 --- a/Lib/test/test_typing.py +++ b/Lib/test/test_typing.py @@ -25,6 +25,7 @@ from typing import assert_type, cast, runtime_checkable from typing import get_type_hints from typing import get_origin, get_args, get_protocol_members +from typing import override from typing import is_typeddict, is_protocol from typing import reveal_type from typing import dataclass_transform @@ -619,6 +620,7 @@ def test_constructor(self): self.assertIs(T.__contravariant__, False) self.assertIs(T.__infer_variance__, True) + class TypeParameterDefaultsTests(BaseTestCase): def test_typevar(self): T = TypeVar('T', default=int) @@ -804,13 +806,12 @@ def test_pickle(self): self.assertEqual(z.__default__, typevar.__default__) - def template_replace(templates: list[str], replacements: dict[str, list[str]]) -> list[tuple[str]]: """Renders templates with possible combinations of replacements. Example 1: Suppose that: templates = ["dog_breed are awesome", "dog_breed are cool"] - replacements = ["dog_breed": ["Huskies", "Beagles"]] + replacements = {"dog_breed": ["Huskies", "Beagles"]} Then we would return: [ ("Huskies are awesome", "Huskies are cool"), @@ -906,7 +907,6 @@ def test_no_duplicates_if_replacement_not_in_templates(self): self.assertEqual(actual, expected) - class GenericAliasSubstitutionTests(BaseTestCase): """Tests for type variable substitution in generic aliases. @@ -1072,6 +1072,8 @@ def test_variadic_parameters(self): T2 = TypeVar('T2') Ts = TypeVarTuple('Ts') + class C(Generic[*Ts]): pass + generics = ['C', 'tuple', 'Tuple'] tuple_types = ['tuple', 'Tuple'] @@ -1169,12 +1171,10 @@ def test_variadic_parameters(self): - - class UnpackTests(BaseTestCase): def test_accepts_single_type(self): - # (*tuple[int],) + (*tuple[int],) Unpack[Tuple[int]] def test_dir(self): @@ -1239,6 +1239,7 @@ class Gen[*Ts]: ... with self.assertRaisesRegex(TypeError, bad_unpack_param): PartGen[Unpack[List[int]]] + class TypeVarTupleTests(BaseTestCase): def test_name(self): @@ -1625,7 +1626,6 @@ def func3(*args: '*CustomVariadic[int, str]'): pass self.assertEqual(gth(func3, localns={'CustomVariadic': CustomVariadic}), {'args': Unpack[CustomVariadic[int, str]]}) - def test_tuple_args_are_correct(self): Ts = TypeVarTuple('Ts') @@ -1805,7 +1805,6 @@ def g(*args: *Ts): pass self.assertEqual(f.__annotations__, {'args': Unpack[Ts]}) self.assertEqual(g.__annotations__, {'args': (*Ts,)[0]}) - def test_variadic_args_with_ellipsis_annotations_are_correct(self): def a(*args: *tuple[int, ...]): pass self.assertEqual(a.__annotations__, @@ -1815,7 +1814,6 @@ def b(*args: Unpack[Tuple[int, ...]]): pass self.assertEqual(b.__annotations__, {'args': Unpack[Tuple[int, ...]]}) - def test_concatenation_in_variadic_args_annotations_are_correct(self): Ts = TypeVarTuple('Ts') @@ -1946,13 +1944,13 @@ class D(Generic[Unpack[Ts]]): pass self.assertNotEqual(C[*Ts1], C[*Ts2]) self.assertNotEqual(D[Unpack[Ts1]], D[Unpack[Ts2]]) + class TypeVarTuplePicklingTests(BaseTestCase): # These are slightly awkward tests to run, because TypeVarTuples are only # picklable if defined in the global scope. We therefore need to push # various things defined in these tests into the global scope with `global` # statements at the start of each test. - # TODO: RUSTPYTHON @all_pickle_protocols def test_pickling_then_unpickling_results_in_same_identity(self, proto): global global_Ts1 # See explanation at start of class. @@ -1996,7 +1994,6 @@ def test_pickling_then_unpickling_tuple_with_typevartuple_equality( self.assertEqual(t, t2) - class UnionTests(BaseTestCase): def test_basics(self): @@ -2614,7 +2611,6 @@ def test_errors(self): with self.assertRaisesRegex(TypeError, "few arguments for"): C1[int] - class TypingCallableTests(BaseCallableTests, BaseTestCase): Callable = typing.Callable @@ -2901,7 +2897,6 @@ def __init__(self, y): self.assertIsInstance(Bar(1), HasX) self.assertNotIsInstance(Capybara('a'), HasX) - def test_everything_implements_empty_protocol(self): @runtime_checkable class Empty(Protocol): @@ -2946,6 +2941,23 @@ class E(C, BP): pass self.assertNotIsInstance(D(), E) self.assertNotIsInstance(E(), D) + def test_inheritance_from_object(self): + # Inheritance from object is specifically allowed, unlike other nominal classes + class P(Protocol, object): + x: int + + self.assertEqual(typing.get_protocol_members(P), {'x'}) + + class OldGeneric(Protocol, Generic[T], object): + y: T + + self.assertEqual(typing.get_protocol_members(OldGeneric), {'y'}) + + class NewGeneric[T](Protocol, object): + z: T + + self.assertEqual(typing.get_protocol_members(NewGeneric), {'z'}) + def test_no_instantiation(self): class P(Protocol): pass @@ -2989,7 +3001,6 @@ class C: pass P.__init__(c, 1) self.assertEqual(c.x, 1) - def test_concrete_class_inheriting_init_from_protocol(self): class P(Protocol): x: int @@ -3146,7 +3157,6 @@ def x(self): ... with self.assertRaisesRegex(TypeError, only_classes_allowed): issubclass(1, BadPG) - def test_implicit_issubclass_between_two_protocols(self): @runtime_checkable class CallableMembersProto(Protocol): @@ -3188,237 +3198,1352 @@ def meth(self): ... # These two shouldn't be considered subclasses of CallableMembersProto, however, # since they don't have the `meth` protocol member - class EmptyProtocol(Protocol): ... - class UnrelatedProtocol(Protocol): - def wut(self): ... + class EmptyProtocol(Protocol): ... + class UnrelatedProtocol(Protocol): + def wut(self): ... + + self.assertNotIsSubclass(EmptyProtocol, CallableMembersProto) + self.assertNotIsSubclass(UnrelatedProtocol, CallableMembersProto) + + # These aren't protocols at all (despite having annotations), + # so they should only be considered subclasses of CallableMembersProto + # if they *actually have an attribute* matching the `meth` member + # (just having an annotation is insufficient) + + class AnnotatedButNotAProtocol: + meth: Callable[[], None] + + class NotAProtocolButAnImplicitSubclass: + def meth(self): pass + + class NotAProtocolButAnImplicitSubclass2: + meth: Callable[[], None] + def meth(self): pass + + class NotAProtocolButAnImplicitSubclass3: + meth: Callable[[], None] + meth2: Callable[[int, str], bool] + def meth(self): pass + def meth2(self, x, y): return True + + self.assertNotIsSubclass(AnnotatedButNotAProtocol, CallableMembersProto) + self.assertIsSubclass(NotAProtocolButAnImplicitSubclass, CallableMembersProto) + self.assertIsSubclass(NotAProtocolButAnImplicitSubclass2, CallableMembersProto) + self.assertIsSubclass(NotAProtocolButAnImplicitSubclass3, CallableMembersProto) + + # TODO: RUSTPYTHON + @unittest.skip("TODO: RUSTPYTHON (no gc)") + def test_isinstance_checks_not_at_whim_of_gc(self): + self.addCleanup(gc.enable) + gc.disable() + + with self.assertRaisesRegex( + TypeError, + "Protocols can only inherit from other protocols" + ): + class Foo(collections.abc.Mapping, Protocol): + pass + + self.assertNotIsInstance([], collections.abc.Mapping) + + def test_issubclass_and_isinstance_on_Protocol_itself(self): + class C: + def x(self): pass + + self.assertNotIsSubclass(object, Protocol) + self.assertNotIsInstance(object(), Protocol) + + self.assertNotIsSubclass(str, Protocol) + self.assertNotIsInstance('foo', Protocol) + + self.assertNotIsSubclass(C, Protocol) + self.assertNotIsInstance(C(), Protocol) + + only_classes_allowed = r"issubclass\(\) arg 1 must be a class" + + with self.assertRaisesRegex(TypeError, only_classes_allowed): + issubclass(1, Protocol) + with self.assertRaisesRegex(TypeError, only_classes_allowed): + issubclass('foo', Protocol) + with self.assertRaisesRegex(TypeError, only_classes_allowed): + issubclass(C(), Protocol) + + T = TypeVar('T') + + @runtime_checkable + class EmptyProtocol(Protocol): pass + + @runtime_checkable + class SupportsStartsWith(Protocol): + def startswith(self, x: str) -> bool: ... + + @runtime_checkable + class SupportsX(Protocol[T]): + def x(self): ... + + for proto in EmptyProtocol, SupportsStartsWith, SupportsX: + with self.subTest(proto=proto.__name__): + self.assertIsSubclass(proto, Protocol) + + # gh-105237 / PR #105239: + # check that the presence of Protocol subclasses + # where `issubclass(X, )` evaluates to True + # doesn't influence the result of `issubclass(X, Protocol)` + + self.assertIsSubclass(object, EmptyProtocol) + self.assertIsInstance(object(), EmptyProtocol) + self.assertNotIsSubclass(object, Protocol) + self.assertNotIsInstance(object(), Protocol) + + self.assertIsSubclass(str, SupportsStartsWith) + self.assertIsInstance('foo', SupportsStartsWith) + self.assertNotIsSubclass(str, Protocol) + self.assertNotIsInstance('foo', Protocol) + + self.assertIsSubclass(C, SupportsX) + self.assertIsInstance(C(), SupportsX) + self.assertNotIsSubclass(C, Protocol) + self.assertNotIsInstance(C(), Protocol) + + def test_protocols_issubclass_non_callable(self): + class C: + x = 1 + + @runtime_checkable + class PNonCall(Protocol): + x = 1 + + non_callable_members_illegal = ( + "Protocols with non-method members don't support issubclass()" + ) + + with self.assertRaisesRegex(TypeError, non_callable_members_illegal): + issubclass(C, PNonCall) + + self.assertIsInstance(C(), PNonCall) + PNonCall.register(C) + + with self.assertRaisesRegex(TypeError, non_callable_members_illegal): + issubclass(C, PNonCall) + + self.assertIsInstance(C(), PNonCall) + + # check that non-protocol subclasses are not affected + class D(PNonCall): ... + + self.assertNotIsSubclass(C, D) + self.assertNotIsInstance(C(), D) + D.register(C) + self.assertIsSubclass(C, D) + self.assertIsInstance(C(), D) + + with self.assertRaisesRegex(TypeError, non_callable_members_illegal): + issubclass(D, PNonCall) + + def test_no_weird_caching_with_issubclass_after_isinstance(self): + @runtime_checkable + class Spam(Protocol): + x: int + + class Eggs: + def __init__(self) -> None: + self.x = 42 + + self.assertIsInstance(Eggs(), Spam) + + # gh-104555: If we didn't override ABCMeta.__subclasscheck__ in _ProtocolMeta, + # TypeError wouldn't be raised here, + # as the cached result of the isinstance() check immediately above + # would mean the issubclass() call would short-circuit + # before we got to the "raise TypeError" line + with self.assertRaisesRegex( + TypeError, + "Protocols with non-method members don't support issubclass()" + ): + issubclass(Eggs, Spam) + + def test_no_weird_caching_with_issubclass_after_isinstance_2(self): + @runtime_checkable + class Spam(Protocol): + x: int + + class Eggs: ... + + self.assertNotIsInstance(Eggs(), Spam) + + # gh-104555: If we didn't override ABCMeta.__subclasscheck__ in _ProtocolMeta, + # TypeError wouldn't be raised here, + # as the cached result of the isinstance() check immediately above + # would mean the issubclass() call would short-circuit + # before we got to the "raise TypeError" line + with self.assertRaisesRegex( + TypeError, + "Protocols with non-method members don't support issubclass()" + ): + issubclass(Eggs, Spam) + + def test_no_weird_caching_with_issubclass_after_isinstance_3(self): + @runtime_checkable + class Spam(Protocol): + x: int + + class Eggs: + def __getattr__(self, attr): + if attr == "x": + return 42 + raise AttributeError(attr) + + self.assertNotIsInstance(Eggs(), Spam) + + # gh-104555: If we didn't override ABCMeta.__subclasscheck__ in _ProtocolMeta, + # TypeError wouldn't be raised here, + # as the cached result of the isinstance() check immediately above + # would mean the issubclass() call would short-circuit + # before we got to the "raise TypeError" line + with self.assertRaisesRegex( + TypeError, + "Protocols with non-method members don't support issubclass()" + ): + issubclass(Eggs, Spam) + + def test_no_weird_caching_with_issubclass_after_isinstance_pep695(self): + @runtime_checkable + class Spam[T](Protocol): + x: T + + class Eggs[T]: + def __init__(self, x: T) -> None: + self.x = x + + self.assertIsInstance(Eggs(42), Spam) + + # gh-104555: If we didn't override ABCMeta.__subclasscheck__ in _ProtocolMeta, + # TypeError wouldn't be raised here, + # as the cached result of the isinstance() check immediately above + # would mean the issubclass() call would short-circuit + # before we got to the "raise TypeError" line + with self.assertRaisesRegex( + TypeError, + "Protocols with non-method members don't support issubclass()" + ): + issubclass(Eggs, Spam) + + def test_protocols_isinstance(self): + T = TypeVar('T') + + @runtime_checkable + class P(Protocol): + def meth(x): ... + + @runtime_checkable + class PG(Protocol[T]): + def meth(x): ... + + @runtime_checkable + class WeirdProto(Protocol): + meth = str.maketrans + + @runtime_checkable + class WeirdProto2(Protocol): + meth = lambda *args, **kwargs: None + + class CustomCallable: + def __call__(self, *args, **kwargs): + pass + + @runtime_checkable + class WeirderProto(Protocol): + meth = CustomCallable() + + class BadP(Protocol): + def meth(x): ... + + class BadPG(Protocol[T]): + def meth(x): ... + + class C: + def meth(x): ... + + class C2: + def __init__(self): + self.meth = lambda: None + + for klass in C, C2: + for proto in P, PG, WeirdProto, WeirdProto2, WeirderProto: + with self.subTest(klass=klass.__name__, proto=proto.__name__): + self.assertIsInstance(klass(), proto) + + no_subscripted_generics = "Subscripted generics cannot be used with class and instance checks" + + with self.assertRaisesRegex(TypeError, no_subscripted_generics): + isinstance(C(), PG[T]) + with self.assertRaisesRegex(TypeError, no_subscripted_generics): + isinstance(C(), PG[C]) + + only_runtime_checkable_msg = ( + "Instance and class checks can only be used " + "with @runtime_checkable protocols" + ) + + with self.assertRaisesRegex(TypeError, only_runtime_checkable_msg): + isinstance(C(), BadP) + with self.assertRaisesRegex(TypeError, only_runtime_checkable_msg): + isinstance(C(), BadPG) + + def test_protocols_isinstance_properties_and_descriptors(self): + class C: + @property + def attr(self): + return 42 + + class CustomDescriptor: + def __get__(self, obj, objtype=None): + return 42 + + class D: + attr = CustomDescriptor() + + # Check that properties set on superclasses + # are still found by the isinstance() logic + class E(C): ... + class F(D): ... + + class Empty: ... + + T = TypeVar('T') + + @runtime_checkable + class P(Protocol): + @property + def attr(self): ... + + @runtime_checkable + class P1(Protocol): + attr: int + + @runtime_checkable + class PG(Protocol[T]): + @property + def attr(self): ... + + @runtime_checkable + class PG1(Protocol[T]): + attr: T + + @runtime_checkable + class MethodP(Protocol): + def attr(self): ... + + @runtime_checkable + class MethodPG(Protocol[T]): + def attr(self) -> T: ... + + for protocol_class in P, P1, PG, PG1, MethodP, MethodPG: + for klass in C, D, E, F: + with self.subTest( + klass=klass.__name__, + protocol_class=protocol_class.__name__ + ): + self.assertIsInstance(klass(), protocol_class) + + with self.subTest(klass="Empty", protocol_class=protocol_class.__name__): + self.assertNotIsInstance(Empty(), protocol_class) + + class BadP(Protocol): + @property + def attr(self): ... + + class BadP1(Protocol): + attr: int + + class BadPG(Protocol[T]): + @property + def attr(self): ... + + class BadPG1(Protocol[T]): + attr: T + + cases = ( + PG[T], PG[C], PG1[T], PG1[C], MethodPG[T], + MethodPG[C], BadP, BadP1, BadPG, BadPG1 + ) + + for obj in cases: + for klass in C, D, E, F, Empty: + with self.subTest(klass=klass.__name__, obj=obj): + with self.assertRaises(TypeError): + isinstance(klass(), obj) + + def test_protocols_isinstance_not_fooled_by_custom_dir(self): + @runtime_checkable + class HasX(Protocol): + x: int + + class CustomDirWithX: + x = 10 + def __dir__(self): + return [] + + class CustomDirWithoutX: + def __dir__(self): + return ["x"] + + self.assertIsInstance(CustomDirWithX(), HasX) + self.assertNotIsInstance(CustomDirWithoutX(), HasX) + + def test_protocols_isinstance_attribute_access_with_side_effects(self): + class C: + @property + def attr(self): + raise AttributeError('no') + + class CustomDescriptor: + def __get__(self, obj, objtype=None): + raise RuntimeError("NO") + + class D: + attr = CustomDescriptor() + + # Check that properties set on superclasses + # are still found by the isinstance() logic + class E(C): ... + class F(D): ... + + class WhyWouldYouDoThis: + def __getattr__(self, name): + raise RuntimeError("wut") + + T = TypeVar('T') + + @runtime_checkable + class P(Protocol): + @property + def attr(self): ... + + @runtime_checkable + class P1(Protocol): + attr: int + + @runtime_checkable + class PG(Protocol[T]): + @property + def attr(self): ... + + @runtime_checkable + class PG1(Protocol[T]): + attr: T + + @runtime_checkable + class MethodP(Protocol): + def attr(self): ... + + @runtime_checkable + class MethodPG(Protocol[T]): + def attr(self) -> T: ... + + for protocol_class in P, P1, PG, PG1, MethodP, MethodPG: + for klass in C, D, E, F: + with self.subTest( + klass=klass.__name__, + protocol_class=protocol_class.__name__ + ): + self.assertIsInstance(klass(), protocol_class) + + with self.subTest( + klass="WhyWouldYouDoThis", + protocol_class=protocol_class.__name__ + ): + self.assertNotIsInstance(WhyWouldYouDoThis(), protocol_class) + + def test_protocols_isinstance___slots__(self): + # As per the consensus in https://github.com/python/typing/issues/1367, + # this is desirable behaviour + @runtime_checkable + class HasX(Protocol): + x: int + + class HasNothingButSlots: + __slots__ = ("x",) + + self.assertIsInstance(HasNothingButSlots(), HasX) + + def test_protocols_isinstance_py36(self): + class APoint: + def __init__(self, x, y, label): + self.x = x + self.y = y + self.label = label + + class BPoint: + label = 'B' + + def __init__(self, x, y): + self.x = x + self.y = y + + class C: + def __init__(self, attr): + self.attr = attr + + def meth(self, arg): + return 0 + + class Bad: pass + + self.assertIsInstance(APoint(1, 2, 'A'), Point) + self.assertIsInstance(BPoint(1, 2), Point) + self.assertNotIsInstance(MyPoint(), Point) + self.assertIsInstance(BPoint(1, 2), Position) + self.assertIsInstance(Other(), Proto) + self.assertIsInstance(Concrete(), Proto) + self.assertIsInstance(C(42), Proto) + self.assertNotIsInstance(Bad(), Proto) + self.assertNotIsInstance(Bad(), Point) + self.assertNotIsInstance(Bad(), Position) + self.assertNotIsInstance(Bad(), Concrete) + self.assertNotIsInstance(Other(), Concrete) + self.assertIsInstance(NT(1, 2), Position) + + def test_protocols_isinstance_init(self): + T = TypeVar('T') + + @runtime_checkable + class P(Protocol): + x = 1 + + @runtime_checkable + class PG(Protocol[T]): + x = 1 + + class C: + def __init__(self, x): + self.x = x + + self.assertIsInstance(C(1), P) + self.assertIsInstance(C(1), PG) + + def test_protocols_isinstance_monkeypatching(self): + @runtime_checkable + class HasX(Protocol): + x: int + + class Foo: ... + + f = Foo() + self.assertNotIsInstance(f, HasX) + f.x = 42 + self.assertIsInstance(f, HasX) + del f.x + self.assertNotIsInstance(f, HasX) + + def test_protocol_checks_after_subscript(self): + class P(Protocol[T]): pass + class C(P[T]): pass + class Other1: pass + class Other2: pass + CA = C[Any] + + self.assertNotIsInstance(Other1(), C) + self.assertNotIsSubclass(Other2, C) + + class D1(C[Any]): pass + class D2(C[Any]): pass + CI = C[int] + + self.assertIsInstance(D1(), C) + self.assertIsSubclass(D2, C) + + def test_protocols_support_register(self): + @runtime_checkable + class P(Protocol): + x = 1 + + class PM(Protocol): + def meth(self): pass + + class D(PM): pass + + class C: pass + + D.register(C) + P.register(C) + self.assertIsInstance(C(), P) + self.assertIsInstance(C(), D) + + def test_none_on_non_callable_doesnt_block_implementation(self): + @runtime_checkable + class P(Protocol): + x = 1 + + class A: + x = 1 + + class B(A): + x = None + + class C: + def __init__(self): + self.x = None + + self.assertIsInstance(B(), P) + self.assertIsInstance(C(), P) + + def test_none_on_callable_blocks_implementation(self): + @runtime_checkable + class P(Protocol): + def x(self): ... + + class A: + def x(self): ... + + class B(A): + x = None + + class C: + def __init__(self): + self.x = None + + self.assertNotIsInstance(B(), P) + self.assertNotIsInstance(C(), P) + + def test_non_protocol_subclasses(self): + class P(Protocol): + x = 1 + + @runtime_checkable + class PR(Protocol): + def meth(self): pass + + class NonP(P): + x = 1 + + class NonPR(PR): pass + + class C(metaclass=abc.ABCMeta): + x = 1 + + class D(metaclass=abc.ABCMeta): + def meth(self): pass + + self.assertNotIsInstance(C(), NonP) + self.assertNotIsInstance(D(), NonPR) + self.assertNotIsSubclass(C, NonP) + self.assertNotIsSubclass(D, NonPR) + self.assertIsInstance(NonPR(), PR) + self.assertIsSubclass(NonPR, PR) + + self.assertNotIn("__protocol_attrs__", vars(NonP)) + self.assertNotIn("__protocol_attrs__", vars(NonPR)) + self.assertNotIn("__non_callable_proto_members__", vars(NonP)) + self.assertNotIn("__non_callable_proto_members__", vars(NonPR)) + + self.assertEqual(get_protocol_members(P), {"x"}) + self.assertEqual(get_protocol_members(PR), {"meth"}) + + # the returned object should be immutable, + # and should be a different object to the original attribute + # to prevent users from (accidentally or deliberately) + # mutating the attribute on the original class + self.assertIsInstance(get_protocol_members(P), frozenset) + self.assertIsNot(get_protocol_members(P), P.__protocol_attrs__) + self.assertIsInstance(get_protocol_members(PR), frozenset) + self.assertIsNot(get_protocol_members(PR), P.__protocol_attrs__) + + acceptable_extra_attrs = { + '_is_protocol', '_is_runtime_protocol', '__parameters__', + '__init__', '__annotations__', '__subclasshook__', + } + self.assertLessEqual(vars(NonP).keys(), vars(C).keys() | acceptable_extra_attrs) + self.assertLessEqual( + vars(NonPR).keys(), vars(D).keys() | acceptable_extra_attrs + ) + + def test_custom_subclasshook(self): + class P(Protocol): + x = 1 + + class OKClass: pass + + class BadClass: + x = 1 + + class C(P): + @classmethod + def __subclasshook__(cls, other): + return other.__name__.startswith("OK") + + self.assertIsInstance(OKClass(), C) + self.assertNotIsInstance(BadClass(), C) + self.assertIsSubclass(OKClass, C) + self.assertNotIsSubclass(BadClass, C) + + def test_custom_subclasshook_2(self): + @runtime_checkable + class HasX(Protocol): + # The presence of a non-callable member + # would mean issubclass() checks would fail with TypeError + # if it weren't for the custom `__subclasshook__` method + x = 1 + + @classmethod + def __subclasshook__(cls, other): + return hasattr(other, 'x') + + class Empty: pass + + class ImplementsHasX: + x = 1 + + self.assertIsInstance(ImplementsHasX(), HasX) + self.assertNotIsInstance(Empty(), HasX) + self.assertIsSubclass(ImplementsHasX, HasX) + self.assertNotIsSubclass(Empty, HasX) + + # isinstance() and issubclass() checks against this still raise TypeError, + # despite the presence of the custom __subclasshook__ method, + # as it's not decorated with @runtime_checkable + class NotRuntimeCheckable(Protocol): + @classmethod + def __subclasshook__(cls, other): + return hasattr(other, 'x') + + must_be_runtime_checkable = ( + "Instance and class checks can only be used " + "with @runtime_checkable protocols" + ) + + with self.assertRaisesRegex(TypeError, must_be_runtime_checkable): + issubclass(object, NotRuntimeCheckable) + with self.assertRaisesRegex(TypeError, must_be_runtime_checkable): + isinstance(object(), NotRuntimeCheckable) + + def test_issubclass_fails_correctly(self): + @runtime_checkable + class NonCallableMembers(Protocol): + x = 1 + + class NotRuntimeCheckable(Protocol): + def callable_member(self) -> int: ... + + @runtime_checkable + class RuntimeCheckable(Protocol): + def callable_member(self) -> int: ... + + class C: pass + + # These three all exercise different code paths, + # but should result in the same error message: + for protocol in NonCallableMembers, NotRuntimeCheckable, RuntimeCheckable: + with self.subTest(proto_name=protocol.__name__): + with self.assertRaisesRegex( + TypeError, r"issubclass\(\) arg 1 must be a class" + ): + issubclass(C(), protocol) + + def test_defining_generic_protocols(self): + T = TypeVar('T') + S = TypeVar('S') + + @runtime_checkable + class PR(Protocol[T, S]): + def meth(self): pass + + class P(PR[int, T], Protocol[T]): + y = 1 + + with self.assertRaises(TypeError): + PR[int] + with self.assertRaises(TypeError): + P[int, str] + + class C(PR[int, T]): pass + + self.assertIsInstance(C[str](), C) + + def test_defining_generic_protocols_old_style(self): + T = TypeVar('T') + S = TypeVar('S') + + @runtime_checkable + class PR(Protocol, Generic[T, S]): + def meth(self): pass + + class P(PR[int, str], Protocol): + y = 1 + + with self.assertRaises(TypeError): + issubclass(PR[int, str], PR) + self.assertIsSubclass(P, PR) + with self.assertRaises(TypeError): + PR[int] + + class P1(Protocol, Generic[T]): + def bar(self, x: T) -> str: ... + + class P2(Generic[T], Protocol): + def bar(self, x: T) -> str: ... + + @runtime_checkable + class PSub(P1[str], Protocol): + x = 1 + + class Test: + x = 1 + + def bar(self, x: str) -> str: + return x + + self.assertIsInstance(Test(), PSub) + + def test_pep695_generic_protocol_callable_members(self): + @runtime_checkable + class Foo[T](Protocol): + def meth(self, x: T) -> None: ... + + class Bar[T]: + def meth(self, x: T) -> None: ... + + self.assertIsInstance(Bar(), Foo) + self.assertIsSubclass(Bar, Foo) + + @runtime_checkable + class SupportsTrunc[T](Protocol): + def __trunc__(self) -> T: ... + + self.assertIsInstance(0.0, SupportsTrunc) + self.assertIsSubclass(float, SupportsTrunc) + + def test_init_called(self): + T = TypeVar('T') + + class P(Protocol[T]): pass + + class C(P[T]): + def __init__(self): + self.test = 'OK' + + self.assertEqual(C[int]().test, 'OK') + + class B: + def __init__(self): + self.test = 'OK' + + class D1(B, P[T]): + pass + + self.assertEqual(D1[int]().test, 'OK') + + class D2(P[T], B): + pass + + self.assertEqual(D2[int]().test, 'OK') + + def test_new_called(self): + T = TypeVar('T') + + class P(Protocol[T]): pass + + class C(P[T]): + def __new__(cls, *args): + self = super().__new__(cls, *args) + self.test = 'OK' + return self + + self.assertEqual(C[int]().test, 'OK') + with self.assertRaises(TypeError): + C[int](42) + with self.assertRaises(TypeError): + C[int](a=42) + + def test_protocols_bad_subscripts(self): + T = TypeVar('T') + S = TypeVar('S') + with self.assertRaises(TypeError): + class P(Protocol[T, T]): pass + with self.assertRaises(TypeError): + class Q(Protocol[int]): pass + with self.assertRaises(TypeError): + class R(Protocol[T], Protocol[S]): pass + with self.assertRaises(TypeError): + class S(typing.Mapping[T, S], Protocol[T]): pass + + def test_generic_protocols_repr(self): + T = TypeVar('T') + S = TypeVar('S') + + class P(Protocol[T, S]): pass + + self.assertTrue(repr(P[T, S]).endswith('P[~T, ~S]')) + self.assertTrue(repr(P[int, str]).endswith('P[int, str]')) + + def test_generic_protocols_eq(self): + T = TypeVar('T') + S = TypeVar('S') + + class P(Protocol[T, S]): pass + + self.assertEqual(P, P) + self.assertEqual(P[int, T], P[int, T]) + self.assertEqual(P[T, T][Tuple[T, S]][int, str], + P[Tuple[int, str], Tuple[int, str]]) + + def test_generic_protocols_special_from_generic(self): + T = TypeVar('T') + + class P(Protocol[T]): pass + + self.assertEqual(P.__parameters__, (T,)) + self.assertEqual(P[int].__parameters__, ()) + self.assertEqual(P[int].__args__, (int,)) + self.assertIs(P[int].__origin__, P) + + def test_generic_protocols_special_from_protocol(self): + @runtime_checkable + class PR(Protocol): + x = 1 + + class P(Protocol): + def meth(self): + pass + + T = TypeVar('T') + + class PG(Protocol[T]): + x = 1 + + def meth(self): + pass + + self.assertTrue(P._is_protocol) + self.assertTrue(PR._is_protocol) + self.assertTrue(PG._is_protocol) + self.assertFalse(P._is_runtime_protocol) + self.assertTrue(PR._is_runtime_protocol) + self.assertTrue(PG[int]._is_protocol) + self.assertEqual(typing._get_protocol_attrs(P), {'meth'}) + self.assertEqual(typing._get_protocol_attrs(PR), {'x'}) + self.assertEqual(frozenset(typing._get_protocol_attrs(PG)), + frozenset({'x', 'meth'})) + + def test_no_runtime_deco_on_nominal(self): + with self.assertRaises(TypeError): + @runtime_checkable + class C: pass + + class Proto(Protocol): + x = 1 + + with self.assertRaises(TypeError): + @runtime_checkable + class Concrete(Proto): + pass + + def test_none_treated_correctly(self): + @runtime_checkable + class P(Protocol): + x = None # type: int + + class B(object): pass + + self.assertNotIsInstance(B(), P) + + class C: + x = 1 + + class D: + x = None + + self.assertIsInstance(C(), P) + self.assertIsInstance(D(), P) + + class CI: + def __init__(self): + self.x = 1 + + class DI: + def __init__(self): + self.x = None + + self.assertIsInstance(CI(), P) + self.assertIsInstance(DI(), P) + + def test_protocols_in_unions(self): + class P(Protocol): + x = None # type: int + + Alias = typing.Union[typing.Iterable, P] + Alias2 = typing.Union[P, typing.Iterable] + self.assertEqual(Alias, Alias2) + + # TODO: RUSTPYTHON + @unittest.expectedFailure + def test_protocols_pickleable(self): + global P, CP # pickle wants to reference the class by name + T = TypeVar('T') + + @runtime_checkable + class P(Protocol[T]): + x = 1 + + class CP(P[int]): + pass + + c = CP() + c.foo = 42 + c.bar = 'abc' + for proto in range(pickle.HIGHEST_PROTOCOL + 1): + z = pickle.dumps(c, proto) + x = pickle.loads(z) + self.assertEqual(x.foo, 42) + self.assertEqual(x.bar, 'abc') + self.assertEqual(x.x, 1) + self.assertEqual(x.__dict__, {'foo': 42, 'bar': 'abc'}) + s = pickle.dumps(P, proto) + D = pickle.loads(s) + + class E: + x = 1 + + self.assertIsInstance(E(), D) + + def test_runtime_checkable_with_match_args(self): + @runtime_checkable + class P_regular(Protocol): + x: int + y: int + + @runtime_checkable + class P_match(Protocol): + __match_args__ = ('x', 'y') + x: int + y: int + + class Regular: + def __init__(self, x: int, y: int): + self.x = x + self.y = y + + class WithMatch: + __match_args__ = ('x', 'y', 'z') + def __init__(self, x: int, y: int, z: int): + self.x = x + self.y = y + self.z = z + + class Nope: ... + + self.assertIsInstance(Regular(1, 2), P_regular) + self.assertIsInstance(Regular(1, 2), P_match) + self.assertIsInstance(WithMatch(1, 2, 3), P_regular) + self.assertIsInstance(WithMatch(1, 2, 3), P_match) + self.assertNotIsInstance(Nope(), P_regular) + self.assertNotIsInstance(Nope(), P_match) + + def test_supports_int(self): + self.assertIsSubclass(int, typing.SupportsInt) + self.assertNotIsSubclass(str, typing.SupportsInt) + + def test_supports_float(self): + self.assertIsSubclass(float, typing.SupportsFloat) + self.assertNotIsSubclass(str, typing.SupportsFloat) + + def test_supports_complex(self): + + class C: + def __complex__(self): + return 0j + + self.assertIsSubclass(complex, typing.SupportsComplex) + self.assertIsSubclass(C, typing.SupportsComplex) + self.assertNotIsSubclass(str, typing.SupportsComplex) - self.assertNotIsSubclass(EmptyProtocol, CallableMembersProto) - self.assertNotIsSubclass(UnrelatedProtocol, CallableMembersProto) + def test_supports_bytes(self): - # These aren't protocols at all (despite having annotations), - # so they should only be considered subclasses of CallableMembersProto - # if they *actually have an attribute* matching the `meth` member - # (just having an annotation is insufficient) + class B: + def __bytes__(self): + return b'' + + self.assertIsSubclass(bytes, typing.SupportsBytes) + self.assertIsSubclass(B, typing.SupportsBytes) + self.assertNotIsSubclass(str, typing.SupportsBytes) + + def test_supports_abs(self): + self.assertIsSubclass(float, typing.SupportsAbs) + self.assertIsSubclass(int, typing.SupportsAbs) + self.assertNotIsSubclass(str, typing.SupportsAbs) + + def test_supports_round(self): + issubclass(float, typing.SupportsRound) + self.assertIsSubclass(float, typing.SupportsRound) + self.assertIsSubclass(int, typing.SupportsRound) + self.assertNotIsSubclass(str, typing.SupportsRound) + + def test_reversible(self): + self.assertIsSubclass(list, typing.Reversible) + self.assertNotIsSubclass(int, typing.Reversible) + + def test_supports_index(self): + self.assertIsSubclass(int, typing.SupportsIndex) + self.assertNotIsSubclass(str, typing.SupportsIndex) + + def test_bundled_protocol_instance_works(self): + self.assertIsInstance(0, typing.SupportsAbs) + class C1(typing.SupportsInt): + def __int__(self) -> int: + return 42 + class C2(C1): + pass + c = C2() + self.assertIsInstance(c, C1) - class AnnotatedButNotAProtocol: - meth: Callable[[], None] + def test_collections_protocols_allowed(self): + @runtime_checkable + class Custom(collections.abc.Iterable, Protocol): + def close(self): ... - class NotAProtocolButAnImplicitSubclass: - def meth(self): pass + class A: pass + class B: + def __iter__(self): + return [] + def close(self): + return 0 - class NotAProtocolButAnImplicitSubclass2: - meth: Callable[[], None] - def meth(self): pass + self.assertIsSubclass(B, Custom) + self.assertNotIsSubclass(A, Custom) - class NotAProtocolButAnImplicitSubclass3: - meth: Callable[[], None] - meth2: Callable[[int, str], bool] - def meth(self): pass - def meth2(self, x, y): return True + @runtime_checkable + class ReleasableBuffer(collections.abc.Buffer, Protocol): + def __release_buffer__(self, mv: memoryview) -> None: ... - self.assertNotIsSubclass(AnnotatedButNotAProtocol, CallableMembersProto) - self.assertIsSubclass(NotAProtocolButAnImplicitSubclass, CallableMembersProto) - self.assertIsSubclass(NotAProtocolButAnImplicitSubclass2, CallableMembersProto) - self.assertIsSubclass(NotAProtocolButAnImplicitSubclass3, CallableMembersProto) + class C: pass + class D: + def __buffer__(self, flags: int) -> memoryview: + return memoryview(b'') + def __release_buffer__(self, mv: memoryview) -> None: + pass - # TODO: RUSTPYTHON - @unittest.skip("TODO: RUSTPYTHON (no gc)") - def test_isinstance_checks_not_at_whim_of_gc(self): - self.addCleanup(gc.enable) - gc.disable() + self.assertIsSubclass(D, ReleasableBuffer) + self.assertIsInstance(D(), ReleasableBuffer) + self.assertNotIsSubclass(C, ReleasableBuffer) + self.assertNotIsInstance(C(), ReleasableBuffer) - with self.assertRaisesRegex( - TypeError, - "Protocols can only inherit from other protocols" - ): - class Foo(collections.abc.Mapping, Protocol): + def test_builtin_protocol_allowlist(self): + with self.assertRaises(TypeError): + class CustomProtocol(TestCase, Protocol): pass - self.assertNotIsInstance([], collections.abc.Mapping) + class CustomContextManager(typing.ContextManager, Protocol): + pass - def test_issubclass_and_isinstance_on_Protocol_itself(self): - class C: - def x(self): pass + class CustomAsyncIterator(typing.AsyncIterator, Protocol): + pass - self.assertNotIsSubclass(object, Protocol) - self.assertNotIsInstance(object(), Protocol) + def test_non_runtime_protocol_isinstance_check(self): + class P(Protocol): + x: int - self.assertNotIsSubclass(str, Protocol) - self.assertNotIsInstance('foo', Protocol) + with self.assertRaisesRegex(TypeError, "@runtime_checkable"): + isinstance(1, P) - self.assertNotIsSubclass(C, Protocol) - self.assertNotIsInstance(C(), Protocol) + def test_super_call_init(self): + class P(Protocol): + x: int - only_classes_allowed = r"issubclass\(\) arg 1 must be a class" + class Foo(P): + def __init__(self): + super().__init__() - with self.assertRaisesRegex(TypeError, only_classes_allowed): - issubclass(1, Protocol) - with self.assertRaisesRegex(TypeError, only_classes_allowed): - issubclass('foo', Protocol) - with self.assertRaisesRegex(TypeError, only_classes_allowed): - issubclass(C(), Protocol) + Foo() # Previously triggered RecursionError - T = TypeVar('T') + def test_get_protocol_members(self): + with self.assertRaisesRegex(TypeError, "not a Protocol"): + get_protocol_members(object) + with self.assertRaisesRegex(TypeError, "not a Protocol"): + get_protocol_members(object()) + with self.assertRaisesRegex(TypeError, "not a Protocol"): + get_protocol_members(Protocol) + with self.assertRaisesRegex(TypeError, "not a Protocol"): + get_protocol_members(Generic) - @runtime_checkable - class EmptyProtocol(Protocol): pass + class P(Protocol): + a: int + def b(self) -> str: ... + @property + def c(self) -> int: ... - @runtime_checkable - class SupportsStartsWith(Protocol): - def startswith(self, x: str) -> bool: ... + self.assertEqual(get_protocol_members(P), {'a', 'b', 'c'}) + self.assertIsInstance(get_protocol_members(P), frozenset) + self.assertIsNot(get_protocol_members(P), P.__protocol_attrs__) - @runtime_checkable - class SupportsX(Protocol[T]): - def x(self): ... + class Concrete: + a: int + def b(self) -> str: return "capybara" + @property + def c(self) -> int: return 5 - for proto in EmptyProtocol, SupportsStartsWith, SupportsX: - with self.subTest(proto=proto.__name__): - self.assertIsSubclass(proto, Protocol) + with self.assertRaisesRegex(TypeError, "not a Protocol"): + get_protocol_members(Concrete) + with self.assertRaisesRegex(TypeError, "not a Protocol"): + get_protocol_members(Concrete()) - # gh-105237 / PR #105239: - # check that the presence of Protocol subclasses - # where `issubclass(X, )` evaluates to True - # doesn't influence the result of `issubclass(X, Protocol)` + class ConcreteInherit(P): + a: int = 42 + def b(self) -> str: return "capybara" + @property + def c(self) -> int: return 5 - self.assertIsSubclass(object, EmptyProtocol) - self.assertIsInstance(object(), EmptyProtocol) - self.assertNotIsSubclass(object, Protocol) - self.assertNotIsInstance(object(), Protocol) + with self.assertRaisesRegex(TypeError, "not a Protocol"): + get_protocol_members(ConcreteInherit) + with self.assertRaisesRegex(TypeError, "not a Protocol"): + get_protocol_members(ConcreteInherit()) - self.assertIsSubclass(str, SupportsStartsWith) - self.assertIsInstance('foo', SupportsStartsWith) - self.assertNotIsSubclass(str, Protocol) - self.assertNotIsInstance('foo', Protocol) + def test_is_protocol(self): + self.assertTrue(is_protocol(Proto)) + self.assertTrue(is_protocol(Point)) + self.assertFalse(is_protocol(Concrete)) + self.assertFalse(is_protocol(Concrete())) + self.assertFalse(is_protocol(Generic)) + self.assertFalse(is_protocol(object)) - self.assertIsSubclass(C, SupportsX) - self.assertIsInstance(C(), SupportsX) - self.assertNotIsSubclass(C, Protocol) - self.assertNotIsInstance(C(), Protocol) + # Protocol is not itself a protocol + self.assertFalse(is_protocol(Protocol)) - def test_protocols_issubclass_non_callable(self): - class C: - x = 1 + def test_interaction_with_isinstance_checks_on_superclasses_with_ABCMeta(self): + # Ensure the cache is empty, or this test won't work correctly + collections.abc.Sized._abc_registry_clear() - @runtime_checkable - class PNonCall(Protocol): - x = 1 + class Foo(collections.abc.Sized, Protocol): pass - non_callable_members_illegal = ( - "Protocols with non-method members don't support issubclass()" - ) + # gh-105144: this previously raised TypeError + # if a Protocol subclass of Sized had been created + # before any isinstance() checks against Sized + self.assertNotIsInstance(1, collections.abc.Sized) - with self.assertRaisesRegex(TypeError, non_callable_members_illegal): - issubclass(C, PNonCall) + def test_interaction_with_isinstance_checks_on_superclasses_with_ABCMeta_2(self): + # Ensure the cache is empty, or this test won't work correctly + collections.abc.Sized._abc_registry_clear() - self.assertIsInstance(C(), PNonCall) - PNonCall.register(C) + class Foo(typing.Sized, Protocol): pass - with self.assertRaisesRegex(TypeError, non_callable_members_illegal): - issubclass(C, PNonCall) + # gh-105144: this previously raised TypeError + # if a Protocol subclass of Sized had been created + # before any isinstance() checks against Sized + self.assertNotIsInstance(1, typing.Sized) - self.assertIsInstance(C(), PNonCall) + def test_empty_protocol_decorated_with_final(self): + @final + @runtime_checkable + class EmptyProtocol(Protocol): ... - # check that non-protocol subclasses are not affected - class D(PNonCall): ... + self.assertIsSubclass(object, EmptyProtocol) + self.assertIsInstance(object(), EmptyProtocol) - self.assertNotIsSubclass(C, D) - self.assertNotIsInstance(C(), D) - D.register(C) - self.assertIsSubclass(C, D) - self.assertIsInstance(C(), D) + def test_protocol_decorated_with_final_callable_members(self): + @final + @runtime_checkable + class ProtocolWithMethod(Protocol): + def startswith(self, string: str) -> bool: ... - with self.assertRaisesRegex(TypeError, non_callable_members_illegal): - issubclass(D, PNonCall) + self.assertIsSubclass(str, ProtocolWithMethod) + self.assertNotIsSubclass(int, ProtocolWithMethod) + self.assertIsInstance('foo', ProtocolWithMethod) + self.assertNotIsInstance(42, ProtocolWithMethod) - def test_no_weird_caching_with_issubclass_after_isinstance(self): + def test_protocol_decorated_with_final_noncallable_members(self): + @final @runtime_checkable - class Spam(Protocol): + class ProtocolWithNonCallableMember(Protocol): x: int - class Eggs: - def __init__(self) -> None: - self.x = 42 + class Foo: + x = 42 - self.assertIsInstance(Eggs(), Spam) + only_callable_members_please = ( + r"Protocols with non-method members don't support issubclass()" + ) - # gh-104555: If we didn't override ABCMeta.__subclasscheck__ in _ProtocolMeta, - # TypeError wouldn't be raised here, - # as the cached result of the isinstance() check immediately above - # would mean the issubclass() call would short-circuit - # before we got to the "raise TypeError" line - with self.assertRaisesRegex( - TypeError, - "Protocols with non-method members don't support issubclass()" - ): - issubclass(Eggs, Spam) + with self.assertRaisesRegex(TypeError, only_callable_members_please): + issubclass(Foo, ProtocolWithNonCallableMember) - def test_no_weird_caching_with_issubclass_after_isinstance_2(self): + with self.assertRaisesRegex(TypeError, only_callable_members_please): + issubclass(int, ProtocolWithNonCallableMember) + + self.assertIsInstance(Foo(), ProtocolWithNonCallableMember) + self.assertNotIsInstance(42, ProtocolWithNonCallableMember) + + def test_protocol_decorated_with_final_mixed_members(self): + @final @runtime_checkable - class Spam(Protocol): + class ProtocolWithMixedMembers(Protocol): x: int + def method(self) -> None: ... - class Eggs: ... + class Foo: + x = 42 + def method(self) -> None: ... - self.assertNotIsInstance(Eggs(), Spam) + only_callable_members_please = ( + r"Protocols with non-method members don't support issubclass()" + ) - # gh-104555: If we didn't override ABCMeta.__subclasscheck__ in _ProtocolMeta, - # TypeError wouldn't be raised here, - # as the cached result of the isinstance() check immediately above - # would mean the issubclass() call would short-circuit - # before we got to the "raise TypeError" line - with self.assertRaisesRegex( - TypeError, - "Protocols with non-method members don't support issubclass()" - ): - issubclass(Eggs, Spam) + with self.assertRaisesRegex(TypeError, only_callable_members_please): + issubclass(Foo, ProtocolWithMixedMembers) - def test_no_weird_caching_with_issubclass_after_isinstance_3(self): + with self.assertRaisesRegex(TypeError, only_callable_members_please): + issubclass(int, ProtocolWithMixedMembers) + + self.assertIsInstance(Foo(), ProtocolWithMixedMembers) + self.assertNotIsInstance(42, ProtocolWithMixedMembers) + + def test_protocol_issubclass_error_message(self): @runtime_checkable - class Spam(Protocol): - x: int + class Vec2D(Protocol): + x: float + y: float - class Eggs: - def __getattr__(self, attr): - if attr == "x": - return 42 - raise AttributeError(attr) + def square_norm(self) -> float: + return self.x ** 2 + self.y ** 2 - self.assertNotIsInstance(Eggs(), Spam) + self.assertEqual(Vec2D.__protocol_attrs__, {'x', 'y', 'square_norm'}) + expected_error_message = ( + "Protocols with non-method members don't support issubclass()." + " Non-method members: 'x', 'y'." + ) + with self.assertRaisesRegex(TypeError, re.escape(expected_error_message)): + issubclass(int, Vec2D) - # gh-104555: If we didn't override ABCMeta.__subclasscheck__ in _ProtocolMeta, - # TypeError wouldn't be raised here, - # as the cached result of the isinstance() check immediately above - # would mean the issubclass() call would short-circuit - # before we got to the "raise TypeError" line - with self.assertRaisesRegex( - TypeError, - "Protocols with non-method members don't support issubclass()" - ): - issubclass(Eggs, Spam) + def test_nonruntime_protocol_interaction_with_evil_classproperty(self): + class classproperty: + def __get__(self, instance, type): + raise RuntimeError("NO") - def test_no_weird_caching_with_issubclass_after_isinstance_pep695(self): - @runtime_checkable - class Spam[T](Protocol): - x: T + class Commentable(Protocol): + evil = classproperty() - class Eggs[T]: - def __init__(self, x: T) -> None: - self.x = x + # recognised as a protocol attr, + # but not actually accessed by the protocol metaclass + # (which would raise RuntimeError) for non-runtime protocols. + # See gh-113320 + self.assertEqual(get_protocol_members(Commentable), {"evil"}) - self.assertIsInstance(Eggs(42), Spam) + def test_runtime_protocol_interaction_with_evil_classproperty(self): + class CustomError(Exception): pass - # gh-104555: If we didn't override ABCMeta.__subclasscheck__ in _ProtocolMeta, - # TypeError wouldn't be raised here, - # as the cached result of the isinstance() check immediately above - # would mean the issubclass() call would short-circuit - # before we got to the "raise TypeError" line - with self.assertRaisesRegex( - TypeError, - "Protocols with non-method members don't support issubclass()" - ): - issubclass(Eggs, Spam) - - # FIXME(arihant2math): start more porting from test_protocols_isinstance + class classproperty: + def __get__(self, instance, type): + raise CustomError + + with self.assertRaises(TypeError) as cm: + @runtime_checkable + class Commentable(Protocol): + evil = classproperty() + + exc = cm.exception + self.assertEqual( + exc.args[0], + "Failed to determine whether protocol member 'evil' is a method member" + ) + self.assertIs(type(exc.__cause__), CustomError) class GenericTests(BaseTestCase): @@ -4643,7 +5768,6 @@ def test_no_isinstance(self): with self.assertRaises(TypeError): issubclass(int, ClassVar) - class FinalTests(BaseTestCase): def test_basics(self): @@ -4772,8 +5896,6 @@ def cached(self): ... class OverrideDecoratorTests(BaseTestCase): - # TODO: RUSTPYTHON - @unittest.expectedFailure def test_override(self): class Base: def normal_method(self): ... @@ -4832,8 +5954,6 @@ def static_method_bad_order(): self.assertIs(False, hasattr(Base.static_method_good_order, "__override__")) self.assertIs(False, hasattr(Base.static_method_bad_order, "__override__")) - # TODO: RUSTPYTHON - @unittest.expectedFailure def test_property(self): class Base: @property @@ -4860,8 +5980,6 @@ def wrong(self) -> int: self.assertFalse(hasattr(Child.wrong, "__override__")) self.assertFalse(hasattr(Child.wrong.fset, "__override__")) - # TODO: RUSTPYTHON - @unittest.expectedFailure def test_silent_failure(self): class CustomProp: __slots__ = ('fget',) @@ -4879,8 +5997,6 @@ def some(self): self.assertEqual(WithOverride.some, 1) self.assertFalse(hasattr(WithOverride.some, "__override__")) - # TODO: RUSTPYTHON - @unittest.expectedFailure def test_multiple_decorators(self): def with_wraps(f): # similar to `lru_cache` definition @wraps(f) @@ -4939,7 +6055,6 @@ def test_errors(self): self.assertIs(assert_type(arg, 'hello'), arg) - # We need this to make sure that `@no_type_check` respects `__module__` attr: from test.typinganndata import ann_module8 @@ -4953,6 +6068,7 @@ class NoTypeCheck_WithFunction: class ForwardRefTests(BaseTestCase): + def test_basics(self): class Node(Generic[T]): @@ -5561,11 +6677,11 @@ def test_overload_registry_repeated(self): self.assertEqual(list(get_overloads(impl)), overloads) + from test.typinganndata import ( ann_module, ann_module2, ann_module3, ann_module5, ann_module6, ) - T_a = TypeVar('T_a') class AwaitableWrapper(typing.Awaitable[T_a]): @@ -5598,7 +6714,6 @@ async def __aenter__(self) -> int: async def __aexit__(self, etype, eval, tb): return None - class A: y: float class B(A): @@ -5692,10 +6807,9 @@ class WeirdlyQuotedMovie(TypedDict): title: Annotated['Annotated[Required[str], "foobar"]', "another level"] year: NotRequired['Annotated[int, 2000]'] -# TODO: RUSTPYTHON -# class HasForeignBaseClass(mod_generics_cache.A): -# some_xrepr: 'XRepr' -# other_a: 'mod_generics_cache.A' +class HasForeignBaseClass(mod_generics_cache.A): + some_xrepr: 'XRepr' + other_a: 'mod_generics_cache.A' async def g_with(am: typing.AsyncContextManager[int]): x: int @@ -5709,7 +6823,6 @@ async def g_with(am: typing.AsyncContextManager[int]): gth = get_type_hints - class ForRefExample: @ann_module.dec def func(self: 'ForRefExample'): @@ -5748,8 +6861,6 @@ def test_get_type_hints_modules_forwardref(self): 'default_b': Optional[mod_generics_cache.B]} self.assertEqual(gth(mod_generics_cache), mgc_hints) - # TODO: RUSTPYTHON - @unittest.expectedFailure def test_get_type_hints_classes(self): self.assertEqual(gth(ann_module.C), # gth will find the right globalns {'y': Optional[ann_module.C]}) @@ -6040,7 +7151,6 @@ def h(x: collections.abc.Callable[P, int]): ... self.assertEqual(get_type_hints(h), {'x': collections.abc.Callable[P, int]}) - class GetUtilitiesTestCase(TestCase): def test_get_origin(self): T = TypeVar('T') @@ -6142,7 +7252,6 @@ class C(Generic[T]): pass self.assertEqual(get_args(Unpack[tuple[Unpack[Ts]]]), (tuple[Unpack[Ts]],)) - class CollectionsAbcTests(BaseTestCase): def test_hashable(self): @@ -6800,7 +7909,6 @@ class ProUserId(UserId): ... - class NamedTupleTests(BaseTestCase): class NestedEmployee(NamedTuple): name: str @@ -7291,6 +8399,22 @@ class TD2(TD1): self.assertIs(TD2.__total__, True) + def test_total_with_assigned_value(self): + class TD(TypedDict): + __total__ = "some_value" + + self.assertIs(TD.__total__, True) + + class TD2(TypedDict, total=True): + __total__ = "some_value" + + self.assertIs(TD2.__total__, True) + + class TD3(TypedDict, total=False): + __total__ = "some value" + + self.assertIs(TD3.__total__, False) + def test_optional_keys(self): class Point2Dor3D(Point2D, total=False): z: int @@ -8003,7 +9127,6 @@ class B(typing.Pattern): pass - class AnnotatedTests(BaseTestCase): def test_new(self): @@ -8511,7 +9634,6 @@ def test_cannot_subscript(self): TypeAlias[int] - class ParamSpecTests(BaseTestCase): def test_basic_plain(self): @@ -8565,7 +9687,6 @@ def test_args_kwargs(self): self.assertEqual(repr(P.args), "P.args") self.assertEqual(repr(P.kwargs), "P.kwargs") - def test_stringized(self): P = ParamSpec('P') class C(Generic[P]): @@ -8907,7 +10028,6 @@ class G(P.args): pass class H(P.kwargs): pass - class ConcatenateTests(BaseTestCase): def test_basics(self): P = ParamSpec('P') @@ -8974,7 +10094,6 @@ def test_var_substitution(self): self.assertEqual(C[Concatenate[str, P2]], Concatenate[int, str, P2]) self.assertEqual(C[...], Concatenate[int, ...]) - class TypeGuardTests(BaseTestCase): def test_basics(self): TypeGuard[int] # OK @@ -9080,6 +10199,7 @@ def test_no_isinstance(self): class SpecialAttrsTests(BaseTestCase): + def test_special_attrs(self): cls_to_check = { # ABC classes diff --git a/stdlib/src/sqlite.rs b/stdlib/src/sqlite.rs index ad07d3df72..af23005a67 100644 --- a/stdlib/src/sqlite.rs +++ b/stdlib/src/sqlite.rs @@ -1163,6 +1163,7 @@ mod _sqlite { callable: PyObjectRef, vm: &VirtualMachine, ) -> PyResult<()> { + name.ensure_valid_utf8(vm)?; let name = name.to_cstring(vm)?; let db = self.db_lock(vm)?; let Some(data) = CallbackData::new(callable.clone(), vm) else { @@ -1832,8 +1833,15 @@ mod _sqlite { let text_factory = zelf.connection.text_factory.to_owned(); if text_factory.is(PyStr::class(&vm.ctx)) { - let text = String::from_utf8(text).map_err(|_| { - new_operational_error(vm, "not valid UTF-8".to_owned()) + let text = String::from_utf8(text).map_err(|err| { + let col_name = st.column_name(i); + let col_name_str = ptr_to_str(col_name, vm).unwrap_or("?"); + let valid_up_to = err.utf8_error().valid_up_to(); + let text_prefix = String::from_utf8_lossy(&err.as_bytes()[..valid_up_to]); + let msg = format!( + "Could not decode to UTF-8 column '{col_name_str}' with text '{text_prefix}'" + ); + new_operational_error(vm, msg) })?; vm.ctx.new_str(text).into() } else if text_factory.is(PyBytes::class(&vm.ctx)) { diff --git a/vm/src/builtins/str.rs b/vm/src/builtins/str.rs index fe7ad6a98a..1c38314afe 100644 --- a/vm/src/builtins/str.rs +++ b/vm/src/builtins/str.rs @@ -457,7 +457,7 @@ impl PyStr { self.data.as_str() } - fn ensure_valid_utf8(&self, vm: &VirtualMachine) -> PyResult<()> { + pub fn ensure_valid_utf8(&self, vm: &VirtualMachine) -> PyResult<()> { if self.is_utf8() { Ok(()) } else {