diff --git a/mypy/checker.py b/mypy/checker.py index 36ceb26d1cfc..f0242230db88 100644 --- a/mypy/checker.py +++ b/mypy/checker.py @@ -32,7 +32,12 @@ ) from mypy.checkpattern import PatternChecker from mypy.constraints import SUPERTYPE_OF -from mypy.erasetype import erase_type, erase_typevars, remove_instance_last_known_values +from mypy.erasetype import ( + erase_type, + erase_typevars, + remove_instance_last_known_values, + shallow_erase_type_for_equality, +) from mypy.errorcodes import TYPE_VAR, UNUSED_AWAITABLE, UNUSED_COROUTINE, ErrorCode from mypy.errors import ( ErrorInfo, @@ -45,7 +50,7 @@ from mypy.expandtype import expand_type from mypy.literals import Key, extract_var_from_literal_hash, literal, literal_hash from mypy.maptype import map_instance_to_supertype -from mypy.meet import is_overlapping_erased_types, is_overlapping_types, meet_types +from mypy.meet import is_overlapping_types, meet_types from mypy.message_registry import ErrorMessage from mypy.messages import ( SUGGESTED_TEST_FIXTURES, @@ -6540,19 +6545,6 @@ def comparison_type_narrowing_helper(self, node: ComparisonExpr) -> tuple[TypeMa narrowable_indices={0}, ) - # We only try and narrow away 'None' for now - if ( - not is_unreachable_map(if_map) - and is_overlapping_none(item_type) - and not is_overlapping_none(collection_item_type) - and not ( - isinstance(collection_item_type, Instance) - and collection_item_type.type.fullname == "builtins.object" - ) - and is_overlapping_erased_types(item_type, collection_item_type) - ): - if_map[operands[left_index]] = remove_optional(item_type) - if right_index in narrowable_operand_index_to_hash: if_type, else_type = self.conditional_types_for_iterable( item_type, iterable_type @@ -6676,6 +6668,9 @@ def narrow_type_by_identity_equality( target_type = operand_types[j] if should_coerce_literals: target_type = coerce_to_literal(target_type) + # Type A[T1] could compare equal to A[T2] even if T1 is disjoint from T2 + # e.g. cast(list[int], []) == cast(list[str], []) + target_type = shallow_erase_type_for_equality(target_type) if ( # See comments in ambiguous_enum_equality_keys @@ -6689,7 +6684,7 @@ def narrow_type_by_identity_equality( if_map, else_map = conditional_types_to_typemaps( operands[i], *conditional_types(expr_type, [target]) ) - if is_target_for_value_narrowing(get_proper_type(target_type)): + if is_target_for_value_narrowing(target_type): all_if_maps.append(if_map) all_else_maps.append(else_map) else: @@ -6758,13 +6753,15 @@ def narrow_type_by_identity_equality( target_type = operand_types[j] if should_coerce_literals: target_type = coerce_to_literal(target_type) + target_type = shallow_erase_type_for_equality(target_type) + target = TypeRange(target_type, is_upper_bound=False) if_map, else_map = conditional_types_to_typemaps( operands[i], *conditional_types(expr_type, [target], default=expr_type) ) or_if_maps.append(if_map) - if is_target_for_value_narrowing(get_proper_type(target_type)): + if is_target_for_value_narrowing(target_type): or_else_maps.append(else_map) all_if_maps.append(reduce_or_conditional_type_maps(or_if_maps)) @@ -8609,13 +8606,7 @@ def reduce_and_conditional_type_maps(ms: list[TypeMap], *, use_meet: bool) -> Ty return result -BUILTINS_CUSTOM_EQ_CHECKS: Final = { - "builtins.bytearray", - "builtins.memoryview", - "builtins.list", - "builtins.dict", - "builtins.set", -} +BUILTINS_CUSTOM_EQ_CHECKS: Final = {"builtins.bytearray", "builtins.memoryview"} def has_custom_eq_checks(t: Type) -> bool: diff --git a/mypy/erasetype.py b/mypy/erasetype.py index f2912fe22a9e..93034e0689ea 100644 --- a/mypy/erasetype.py +++ b/mypy/erasetype.py @@ -285,3 +285,17 @@ def visit_union_type(self, t: UnionType) -> Type: merged.append(orig_item) return UnionType.make_union(merged) return new + + +def shallow_erase_type_for_equality(typ: Type) -> ProperType: + """Erase type variables from Instance's inside a type.""" + p_typ = get_proper_type(typ) + if isinstance(p_typ, Instance): + if not p_typ.args: + return p_typ + args = erased_vars(p_typ.type.defn.type_vars, TypeOfAny.special_form) + return Instance(p_typ.type, args, p_typ.line) + if isinstance(p_typ, UnionType): + items = [shallow_erase_type_for_equality(item) for item in p_typ.items] + return UnionType.make_union(items) + return p_typ diff --git a/mypy/meet.py b/mypy/meet.py index 365544d4584f..a029ec52f7e0 100644 --- a/mypy/meet.py +++ b/mypy/meet.py @@ -3,7 +3,6 @@ from collections.abc import Callable from mypy import join -from mypy.erasetype import erase_type from mypy.maptype import map_instance_to_supertype from mypy.state import state from mypy.subtypes import ( @@ -657,18 +656,6 @@ def _type_object_overlap(left: Type, right: Type) -> bool: return False -def is_overlapping_erased_types( - left: Type, right: Type, *, ignore_promotions: bool = False -) -> bool: - """The same as 'is_overlapping_erased_types', except the types are erased first.""" - return is_overlapping_types( - erase_type(left), - erase_type(right), - ignore_promotions=ignore_promotions, - prohibit_none_typevar_overlap=True, - ) - - def are_typed_dicts_overlapping( left: TypedDictType, right: TypedDictType, is_overlapping: Callable[[Type, Type], bool] ) -> bool: diff --git a/test-data/unit/check-narrowing.test b/test-data/unit/check-narrowing.test index 4c3ee5dde206..43ab5b917a7e 100644 --- a/test-data/unit/check-narrowing.test +++ b/test-data/unit/check-narrowing.test @@ -1065,6 +1065,92 @@ def f(x: Custom, y: CustomSub): reveal_type(y) # N: Revealed type is "__main__.CustomSub" [builtins fixtures/tuple.pyi] +[case testNarrowingCustomEqualityGeneric] +# flags: --strict-equality --warn-unreachable +from __future__ import annotations +from typing import Union + +class Custom: + def __eq__(self, other: object) -> bool: + raise + +class Default: ... + +def f1(x: list[Custom] | Default, y: list[int]): + if x == y: # E: Non-overlapping equality check (left operand type: "list[Custom] | Default", right operand type: "list[int]") + reveal_type(x) # N: Revealed type is "builtins.list[__main__.Custom]" + reveal_type(y) # N: Revealed type is "builtins.list[builtins.int]" + else: + reveal_type(x) # N: Revealed type is "builtins.list[__main__.Custom] | __main__.Default" + reveal_type(y) # N: Revealed type is "builtins.list[builtins.int]" + +f1([], []) + +def f2(x: list[Custom] | Default, y: list[int] | list[Default]): + if x == y: # E: Non-overlapping equality check (left operand type: "list[Custom] | Default", right operand type: "list[int] | list[Default]") + reveal_type(x) # N: Revealed type is "builtins.list[__main__.Custom]" + reveal_type(y) # N: Revealed type is "builtins.list[builtins.int] | builtins.list[__main__.Default]" + else: + reveal_type(x) # N: Revealed type is "builtins.list[__main__.Custom] | __main__.Default" + reveal_type(y) # N: Revealed type is "builtins.list[builtins.int] | builtins.list[__main__.Default]" + +listcustom_or_default = Union[list[Custom], Default] +listint_or_default = Union[list[int], list[Default]] + +def f2_with_alias(x: listcustom_or_default, y: listint_or_default): + if x == y: # E: Non-overlapping equality check (left operand type: "list[Custom] | Default", right operand type: "list[int] | list[Default]") + reveal_type(x) # N: Revealed type is "builtins.list[__main__.Custom]" + reveal_type(y) # N: Revealed type is "builtins.list[builtins.int] | builtins.list[__main__.Default]" + else: + reveal_type(x) # N: Revealed type is "builtins.list[__main__.Custom] | __main__.Default" + reveal_type(y) # N: Revealed type is "builtins.list[builtins.int] | builtins.list[__main__.Default]" + +def f3(x: Custom | dict[str, str], y: dict[int, int]): + if x == y: + reveal_type(x) # N: Revealed type is "__main__.Custom | builtins.dict[builtins.str, builtins.str]" + reveal_type(y) # N: Revealed type is "builtins.dict[builtins.int, builtins.int]" + else: + reveal_type(x) # N: Revealed type is "__main__.Custom | builtins.dict[builtins.str, builtins.str]" + reveal_type(y) # N: Revealed type is "builtins.dict[builtins.int, builtins.int]" +[builtins fixtures/primitives.pyi] + +[case testNarrowingRecursiveCallable] +# flags: --strict-equality --warn-unreachable +from __future__ import annotations +from typing import Callable + +class A: ... +class B: ... + +T = Callable[[A], "S"] +S = Callable[[B], "T"] + +def f(x: S, y: T): + if x == y: # E: Unsupported left operand type for == ("Callable[[B], T]") + reveal_type(x) # N: Revealed type is "def (__main__.B) -> def (__main__.A) -> ..." + reveal_type(y) # N: Revealed type is "def (__main__.A) -> def (__main__.B) -> ..." + else: + reveal_type(x) # N: Revealed type is "def (__main__.B) -> def (__main__.A) -> ..." + reveal_type(y) # N: Revealed type is "def (__main__.A) -> def (__main__.B) -> ..." +[builtins fixtures/tuple.pyi] + +[case testNarrowingRecursiveUnion] +# flags: --strict-equality --warn-unreachable +from __future__ import annotations +from typing import Union + +class A: ... +class B: ... + +T = Union[A, "S"] +S = Union[B, "T"] # E: Invalid recursive alias: a union item of itself + +def f(x: S, y: T): + if x == y: + reveal_type(x) # N: Revealed type is "Any" + reveal_type(y) # N: Revealed type is "__main__.A | Any" +[builtins fixtures/tuple.pyi] + [case testNarrowingUnreachableCases] # flags: --strict-equality --warn-unreachable from typing import Literal, Union diff --git a/test-data/unit/check-tuples.test b/test-data/unit/check-tuples.test index 5dcc34027c70..d517c8cede96 100644 --- a/test-data/unit/check-tuples.test +++ b/test-data/unit/check-tuples.test @@ -1540,7 +1540,9 @@ class B: pass def f1(possibles: Tuple[int, Tuple[A]], x: Optional[Tuple[B]]): if x in possibles: - reveal_type(x) # N: Revealed type is "tuple[__main__.B]" + # TODO: this branch is actually unreachable + # This is an easy fix: https://github.com/python/mypy/pull/20660 + reveal_type(x) # N: Revealed type is "tuple[__main__.B] | None" else: reveal_type(x) # N: Revealed type is "tuple[__main__.B] | None"