diff --git a/mypy/checker.py b/mypy/checker.py index a650bdf2a639..94f0d12b55b2 100644 --- a/mypy/checker.py +++ b/mypy/checker.py @@ -7757,7 +7757,10 @@ def conditional_types( ] ) remaining_type = restrict_subtype_away(current_type, proposed_precise_type) - return proposed_type, remaining_type + proposed_type_with_data = _transfer_type_var_args_from_current_to_proposed( + current_type, proposed_type + ) + return proposed_type_with_data, remaining_type else: # An isinstance check, but we don't understand the type return current_type, default @@ -7781,6 +7784,128 @@ def conditional_types_to_typemaps( return cast(Tuple[TypeMap, TypeMap], tuple(maps)) +def _transfer_type_var_args_from_current_to_proposed(current: Type, proposed: Type) -> Type: + """Check if the current type is among the bases of the proposed type. If so, try to transfer + the type variable arguments of the current type's instance to a copy of the proposed type's + instance. This increases information when narrowing generic classes so that, for example, + Sequence[int] is narrowed to List[int] instead of List[Any].""" + + def _get_instance_path_from_current_to_proposed( + this: Instance, target: TypeInfo + ) -> list[Instance] | None: + """Search for the current type among the bases of the proposed type and return the + "instance path" from the current to proposed type. Or None, if the current type is not a + nominal super type. At most one path is returned, which means there is no special handling + of (inconsistent) multiple inheritance.""" + if target == this.type: + return [this] + for base in this.type.bases: + path = _get_instance_path_from_current_to_proposed(base, target) + if path is not None: + path.append(this) + return path + return None + + # Handle "tuple of Instance" cases, e.g. `isinstance(x, (A, B))`: + proposed = get_proper_type(proposed) + if isinstance(proposed, UnionType): + items = [ + _transfer_type_var_args_from_current_to_proposed(current, item) + for item in flatten_nested_unions(proposed.items) + ] + return make_simplified_union(items) + + # Otherwise handle only Instances: + if not isinstance(proposed, Instance): + return proposed + + # Handle union cases like `a: A[int] | A[str]; isinstance(a, B)`: + current = get_proper_type(current) + if isinstance(current, UnionType): + items = [ + _transfer_type_var_args_from_current_to_proposed(item, proposed) + for item in flatten_nested_unions(current.items) + ] + return make_simplified_union(items) + + # Special handling for trivial "tuple is tuple" cases (handling tuple subclasses seems + # complicated, especially as long as `builtins.tuple` is not variadic): + if isinstance(current, TupleType) and (proposed.type.fullname == "builtins.tuple"): + return current + + # Here comes the main logic: + if isinstance(current, Instance): + + # Only consider nominal subtyping: + instances = _get_instance_path_from_current_to_proposed(proposed, current.type) + if instances is None: + return proposed + assert len(instances) > 0 # shortest case: proposed type is current type + + # Make a list of the proposed type's type variable arguments that allows to replace each + # `Any` with one type variable argument or multiple type variable tuple arguments of the + # current type: + proposed_args: list[Type | tuple[Type, ...]] = list(proposed.args) + + # Try to transfer each type variable argument from the current to the base type separately: + for pos1, typevar1 in enumerate(instances[0].args): + if isinstance(typevar1, UnpackType): + typevar1 = typevar1.type + if (len(instances) > 1) and not isinstance(typevar1, (TypeVarType, TypeVarTupleType)): + continue + # Find the position of the intermediate types' and finally the proposed type's + # related type variable (if not available, `pos2` becomes `None`): + pos2: int | None = pos1 + for instance in instances[1:]: + for pos2, typevar2 in enumerate(instance.type.defn.type_vars): + if typevar1 == typevar2: + if instance.type.has_type_var_tuple_type: + assert (pre := instance.type.type_var_tuple_prefix) is not None + if pos2 > pre: + pos2 += len(instance.args) - len(instance.type.defn.type_vars) + typevar1 = instance.args[pos2] + if isinstance(typevar1, UnpackType): + typevar1 = typevar1.type + break + else: + pos2 = None + break + + # Transfer the current type's type variable argument or type variable tuple arguments: + if pos2 is not None: + proposed_arg = proposed_args[pos2] + assert not isinstance(proposed_arg, tuple) + if isinstance(get_proper_type(proposed_arg), (AnyType, UnpackType)): + if current.type.has_type_var_tuple_type: + assert (pre := current.type.type_var_tuple_prefix) is not None + assert (suf := current.type.type_var_tuple_suffix) is not None + if pos1 < pre: + proposed_args[pos2] = current.args[pos1] + elif pos1 == pre: + proposed_args[pos2] = current.args[pre : len(current.args) - suf] + else: + middle = len(current.args) - pre - suf + proposed_args[pos2] = current.args[pos1 + middle - 1] + else: + proposed_args[pos2] = current.args[pos1] + + # Combine all type variable and type variable tuple arguments to a flat list: + flattened_proposed_args: list[Type] = [] + for arg in proposed_args: + if isinstance(arg, tuple): + flattened_proposed_args.extend(arg) + else: + flattened_proposed_args.append(arg) + # Some later checks seem to expect flattened unions: + for arg_ in flattened_proposed_args: + if isinstance(arg_ := get_proper_type(arg_), UnionType): + arg_.items = flatten_nested_unions(arg_.items) + + return proposed.copy_modified(args=flattened_proposed_args) + + return proposed + + def gen_unique_name(base: str, table: SymbolTable) -> str: """Generate a name that does not appear in table by appending numbers to base.""" if base not in table: diff --git a/mypy/types.py b/mypy/types.py index cc9c65299ee8..cbb18acda1d7 100644 --- a/mypy/types.py +++ b/mypy/types.py @@ -3675,16 +3675,14 @@ def flatten_nested_unions( ) -> list[Type]: """Flatten nested unions in a type list.""" if not isinstance(types, list): - typelist = list(types) - else: - typelist = cast("list[Type]", types) + types = list(types) # Fast path: most of the time there is nothing to flatten - if not any(isinstance(t, (TypeAliasType, UnionType)) for t in typelist): # type: ignore[misc] - return typelist + if not any(isinstance(t, (TypeAliasType, UnionType)) for t in types): # type: ignore[misc] + return types flat_items: list[Type] = [] - for t in typelist: + for t in types: if handle_type_alias_type: if not handle_recursive and isinstance(t, TypeAliasType) and t.is_recursive: tp: Type = t diff --git a/test-data/unit/check-narrowing.test b/test-data/unit/check-narrowing.test index d740708991d0..77401e7d31ed 100644 --- a/test-data/unit/check-narrowing.test +++ b/test-data/unit/check-narrowing.test @@ -2106,6 +2106,172 @@ if isinstance(x, (Z, NoneType)): # E: Subclass of "X" and "Z" cannot exist: "Z" [builtins fixtures/isinstance.pyi] +[case testKeepTypeVarArgsWhenNarrowingGenericsWithIsInstance] +from typing import Generic, Sequence, Tuple, TypeVar, Union + +s: Sequence[str] +if isinstance(s, tuple): + reveal_type(s) # N: Revealed type is "builtins.tuple[builtins.str, ...]" +else: + reveal_type(s) # N: Revealed type is "typing.Sequence[builtins.str]" +if isinstance(s, list): + reveal_type(s) # N: Revealed type is "builtins.list[builtins.str]" +else: + reveal_type(s) # N: Revealed type is "typing.Sequence[builtins.str]" + +t1: Tuple[str, int] +if isinstance(t1, tuple): + reveal_type(t1) # N: Revealed type is "Tuple[builtins.str, builtins.int]" +else: + reveal_type(t1) + +t2: Tuple[str, ...] +if isinstance(t2, tuple): + reveal_type(t2) # N: Revealed type is "builtins.tuple[builtins.str, ...]" +else: + reveal_type(t2) + +T1 = TypeVar("T1") +T2 = TypeVar("T2") +class A(Generic[T1]): ... +class B(A[T1], Generic[T1, T2]):... +a: A[str] +if isinstance(a, B): + reveal_type(a) # N: Revealed type is "__main__.B[builtins.str, Any]" +else: + reveal_type(a) # N: Revealed type is "__main__.A[builtins.str]" +class C(A[str], Generic[T1]):... +if isinstance(a, C): + reveal_type(a) # N: Revealed type is "__main__.C[Any]" +else: + reveal_type(a) # N: Revealed type is "__main__.A[builtins.str]" + +class AA(Generic[T1]): ... +class BB(A[T1], AA[T1], Generic[T1, T2]):... +aa: Union[A[int], Union[AA[str], AA[int]]] +if isinstance(aa, BB): + reveal_type(aa) # N: Revealed type is "Union[__main__.BB[builtins.int, Any], __main__.BB[builtins.str, Any]]" +else: + reveal_type(aa) # N: Revealed type is "Union[__main__.A[builtins.int], __main__.AA[builtins.str], __main__.AA[builtins.int]]" + +T3 = TypeVar("T3") +T4 = TypeVar("T4") +T5 = TypeVar("T5") +T6 = TypeVar("T6") +T7 = TypeVar("T7") +T8 = TypeVar("T8") +T9 = TypeVar("T9") +T10 = TypeVar("T10") +T11 = TypeVar("T11") +class A1(Generic[T1, T2]): ... +class A2(Generic[T3, T4]): ... +class B1(A1[T5, T6]):... +class B2(A2[T7, T8]):... +class C1(B1[T9, T10], B2[T11, T9]):... +a2: A2[str, int] +if isinstance(a2, C1): + reveal_type(a2) # N: Revealed type is "__main__.C1[builtins.int, Any, builtins.str]" +else: + reveal_type(a2) # N: Revealed type is "__main__.A2[builtins.str, builtins.int]" +[builtins fixtures/tuple.pyi] + +[case testKeepTypeVarArgsWhenNarrowingGenericsInUnionsWithIsInstance] +from typing import Generic, TypeVar, Union + +T1 = TypeVar("T1") +T2 = TypeVar("T2") +class A(Generic[T1]): ... +class B(A[T1], Generic[T1, T2]):... +class C(A[T2], Generic[T1, T2]):... +a: Union[A[str], A[int]] +if isinstance(a, (B, C)): + reveal_type(a) # N: Revealed type is "Union[__main__.B[builtins.str, Any], __main__.C[Any, builtins.str], __main__.B[builtins.int, Any], __main__.C[Any, builtins.int]]" +else: + reveal_type(a) # N: Revealed type is "Union[__main__.A[builtins.str], __main__.A[builtins.int]]" +[builtins fixtures/isinstance.pyi] + +[case testKeepTypeVarArgsWhenNarrowingTupleTypeToTuple] +from typing import Sequence, Tuple, Union + +class A: ... +class B: ... +x: Union[Tuple[A], Tuple[A, B], Tuple[B, ...], Sequence[Tuple[A]]] +if isinstance(x, tuple): + reveal_type(x) # N: Revealed type is "Union[Tuple[__main__.A], Tuple[__main__.A, __main__.B], builtins.tuple[__main__.B, ...], builtins.tuple[Tuple[__main__.A], ...]]" +else: + reveal_type(x) # N: Revealed type is "typing.Sequence[Tuple[__main__.A]]" +[builtins fixtures/tuple.pyi] + +[case testKeepTypeVarArgsWhenNarrowingGenericsWithIsSubclass] +from typing import Generic, Sequence, Type, TypeVar + +T1 = TypeVar("T1") +T2 = TypeVar("T2") +class A(Generic[T1]): ... +class B(A[T1], Generic[T1, T2]):... +a: Type[A[str]] +if issubclass(a, B): + reveal_type(a) # N: Revealed type is "Type[__main__.B[builtins.str, Any]]" +else: + reveal_type(a) # N: Revealed type is "Type[__main__.A[builtins.str]]" +class C(A[str], Generic[T1]):... +if issubclass(a, C): + reveal_type(a) # N: Revealed type is "Type[__main__.C[Any]]" +else: + reveal_type(a) # N: Revealed type is "Type[__main__.A[builtins.str]]" +[builtins fixtures/isinstance.pyi] + +[case testKeepTypeVarTupleArgsWhenNarrowingGenericsWithIsInstance] +from typing import Generic, Sequence, Tuple, TypeVar +from typing_extensions import TypeVarTuple, Unpack + +TP = TypeVarTuple("TP") +class A(Generic[Unpack[TP]]): ... +class B(A[Unpack[TP]]): ... +a: A[str, int] +if isinstance(a, B): + reveal_type(a) # N: Revealed type is "__main__.B[builtins.str, builtins.int]" +else: + reveal_type(a) # N: Revealed type is "__main__.A[builtins.str, builtins.int]" + +def f1(a: A[Unpack[Tuple[str, ...]]]): + if isinstance(a, B): + reveal_type(a) # N: Revealed type is "__main__.B[Unpack[builtins.tuple[builtins.str, ...]]]" + +T = TypeVar("T") +def f2(a: A[T, str, T]): + if isinstance(a, B): + reveal_type(a) # N: Revealed type is "__main__.B[T`-1, builtins.str, T`-1]" + +T1 = TypeVar("T1") +T2 = TypeVar("T2") +T3 = TypeVar("T3") +T4 = TypeVar("T4") +T5 = TypeVar("T5") +T6 = TypeVar("T6") +class C(Generic[T1, Unpack[TP], T2]): ... +class D(C[T1, Unpack[TP], T2], Generic[T2, T4, T6, Unpack[TP], T5, T3, T1]): ... +class E(D[T1, T2, float, Unpack[TP], float, T3, T4]): ... +c: C[int, str, int, str] +if isinstance(c, E): + reveal_type(c) # N: Revealed type is "__main__.E[builtins.str, Any, builtins.str, builtins.int, Any, builtins.int]" +else: + reveal_type(c) # N: Revealed type is "__main__.C[builtins.int, builtins.str, builtins.int, builtins.str]" + +class F(E[T1, T2, str, int, T3, T4]): ... +if isinstance(c, F): + reveal_type(c) # N: Revealed type is "__main__.F[builtins.str, Any, Any, builtins.int]" +[builtins fixtures/tuple.pyi] + +[case testKeepTypeVarArgsWhenNarrowingGenericsWithIsInstanceMappingIterableOverlap] +# flags: --python-version 3.12 +# see PR 17099 +from typing import Iterable + +def f(x: dict[str, str] | Iterable[bytes]) -> None: + if isinstance(x, dict): + reveal_type(x) # N: Revealed type is "Union[builtins.dict[builtins.str, builtins.str], builtins.dict[builtins.bytes, Any]]" +[builtins fixtures/dict.pyi] [case testTypeNarrowingReachableNegative] # flags: --warn-unreachable from typing import Literal