diff --git a/mypy/checker.py b/mypy/checker.py index 522fb1c818e0..a6a50f03de3d 100644 --- a/mypy/checker.py +++ b/mypy/checker.py @@ -6620,16 +6620,11 @@ def narrow_type_by_identity_equality( ...then 'operands' and 'operand_types' would be lists of length 5 and 'chain_indices' would be the list [1, 2, 3]. - The 'narrowable_operand_indices' parameter is the set of all indices we are allowed + The 'narrowable_indices' parameter is the set of all indices we are allowed to refine the types of: that is, all operands that will potentially be a part of the output TypeMaps. """ - # should_narrow_by_identity_equality: - # If operator is "==" or "!=", we cannot narrow if we detect the presence of a user defined - # custom __eq__ or __ne__ method - should_narrow_by_identity_equality: bool - # is_target_for_value_narrowing: # If the operator returns True when compared to this target, do we narrow in else branch? # E.g. if operator is "==", then: @@ -6646,7 +6641,7 @@ def narrow_type_by_identity_equality( if operator in {"is", "is not"}: is_target_for_value_narrowing = is_singleton_identity_type should_coerce_literals = True - should_narrow_by_identity_equality = True + custom_eq_indices = set() enum_comparison_is_ambiguous = False elif operator in {"==", "!="}: @@ -6659,19 +6654,11 @@ def narrow_type_by_identity_equality( should_coerce_literals = True break - expr_types = [operand_types[i] for i in expr_indices] - should_narrow_by_identity_equality = not any(map(has_custom_eq_checks, expr_types)) + custom_eq_indices = {i for i in expr_indices if has_custom_eq_checks(operand_types[i])} enum_comparison_is_ambiguous = True else: raise AssertionError - if not should_narrow_by_identity_equality: - # This is a bit of a legacy code path that might be a little unsound since it ignores - # custom __eq__. We should see if we can get rid of it in favour of `return {}, {}` - return self.refine_away_none_in_comparison( - operands, operand_types, expr_indices, narrowable_indices - ) - value_targets = [] type_targets = [] for i in expr_indices: @@ -6683,6 +6670,10 @@ def narrow_type_by_identity_equality( # `x` to `Literal[Foo.A]` iff `Foo` has exactly one member. # See testMatchEnumSingleChoice expr_type = coerce_to_literal(expr_type) + if i in custom_eq_indices: + # We can't use types with custom __eq__ as targets for narrowing + # E.g. if (x: int | None) == (y: CustomEq | None), we cannot narrow x to None + continue if is_target_for_value_narrowing(get_proper_type(expr_type)): value_targets.append((i, TypeRange(expr_type, is_upper_bound=False))) else: @@ -6694,7 +6685,11 @@ def narrow_type_by_identity_equality( for i in expr_indices: if i not in narrowable_indices: continue - expr_type = coerce_to_literal(operand_types[i]) + if i in custom_eq_indices: + # Handled later + continue + expr_type = operand_types[i] + expr_type = coerce_to_literal(expr_type) expr_type = try_expanding_sum_type_to_union(expr_type, None) expr_enum_keys = ambiguous_enum_equality_keys(expr_type) for j, target in value_targets: @@ -6715,6 +6710,9 @@ def narrow_type_by_identity_equality( for i in expr_indices: if i not in narrowable_indices: continue + if i in custom_eq_indices: + # Handled later + continue expr_type = operand_types[i] for j, target in type_targets: if i == j: @@ -6723,9 +6721,63 @@ def narrow_type_by_identity_equality( operands[i], *conditional_types(expr_type, [target]) ) if if_map: - else_map = {} # this is the big difference compared to the above + # For type_targets, we cannot narrow in the negative case + # e.g. if (x: str | None) != (y: str), we cannot narrow x to None + else_map = {} partial_type_maps.append((if_map, else_map)) + for i in custom_eq_indices: + if i not in narrowable_indices: + continue + union_expr_type = get_proper_type(operand_types[i]) + if not isinstance(union_expr_type, UnionType): + expr_type = operand_types[i] + for j, target in value_targets: + _if_map, else_map = conditional_types_to_typemaps( + operands[i], *conditional_types(expr_type, [target]) + ) + if else_map: + partial_type_maps.append(({}, else_map)) + continue + + or_if_maps: list[TypeMap] = [] + or_else_maps: list[TypeMap] = [] + for expr_type in union_expr_type.items: + if has_custom_eq_checks(expr_type): + or_if_maps.append({operands[i]: expr_type}) + + for j in expr_indices: + if j in custom_eq_indices: + continue + target_type = operand_types[j] + if should_coerce_literals: + target_type = coerce_to_literal(target_type) + target = TypeRange(target_type, is_upper_bound=False) + is_value_target = is_target_for_value_narrowing(get_proper_type(target_type)) + + if is_value_target: + expr_type = coerce_to_literal(expr_type) + expr_type = try_expanding_sum_type_to_union(expr_type, None) + if_map, else_map = conditional_types_to_typemaps( + operands[i], *conditional_types(expr_type, [target], default=expr_type) + ) + or_if_maps.append(if_map) + if is_value_target: + or_else_maps.append(else_map) + + final_if_map: TypeMap = {} + final_else_map: TypeMap = {} + if or_if_maps: + final_if_map = or_if_maps[0] + for if_map in or_if_maps[1:]: + final_if_map = or_conditional_maps(final_if_map, if_map) + if or_else_maps: + final_else_map = or_else_maps[0] + for else_map in or_else_maps[1:]: + final_else_map = or_conditional_maps(final_else_map, else_map) + + partial_type_maps.append((final_if_map, final_else_map)) + for i in expr_indices: type_expr = operands[i] if ( @@ -6943,49 +6995,6 @@ def _propagate_walrus_assignments( return parent_expr return expr - def refine_away_none_in_comparison( - self, - operands: list[Expression], - operand_types: list[Type], - chain_indices: list[int], - narrowable_operand_indices: AbstractSet[int], - ) -> tuple[TypeMap, TypeMap]: - """Produces conditional type maps refining away None in an identity/equality chain. - - For more details about what the different arguments mean, see the - docstring of 'narrow_type_by_identity_equality' up above. - """ - - non_optional_types = [] - for i in chain_indices: - typ = operand_types[i] - if not is_overlapping_none(typ): - non_optional_types.append(typ) - - if_map, else_map = {}, {} - - if not non_optional_types or (len(non_optional_types) != len(chain_indices)): - - # Narrow e.g. `Optional[A] == "x"` or `Optional[A] is "x"` to `A` (which may be - # convenient but is strictly not type-safe): - for i in narrowable_operand_indices: - expr_type = operand_types[i] - if not is_overlapping_none(expr_type): - continue - if any(is_overlapping_erased_types(expr_type, t) for t in non_optional_types): - if_map[operands[i]] = remove_optional(expr_type) - - # Narrow e.g. `Optional[A] != None` to `A` (which is stricter than the above step and - # so type-safe but less convenient, because e.g. `Optional[A] == None` still results - # in `Optional[A]`): - if any(isinstance(get_proper_type(ot), NoneType) for ot in operand_types): - for i in narrowable_operand_indices: - expr_type = operand_types[i] - if is_overlapping_none(expr_type): - else_map[operands[i]] = remove_optional(expr_type) - - return if_map, else_map - def is_len_of_tuple(self, expr: Expression) -> bool: """Is this expression a `len(x)` call where x is a tuple or union of tuples?""" if not isinstance(expr, CallExpr): diff --git a/test-data/unit/check-narrowing.test b/test-data/unit/check-narrowing.test index fb14efea6663..4a694d6e517f 100644 --- a/test-data/unit/check-narrowing.test +++ b/test-data/unit/check-narrowing.test @@ -823,9 +823,8 @@ def bar(x: Union[SingletonFoo, Foo], y: SingletonFoo) -> None: reveal_type(x) # N: Revealed type is "Literal[__main__.SingletonFoo.A]" [builtins fixtures/primitives.pyi] -[case testNarrowingEqualityDisabledForCustomEquality] +[case testNarrowingEqualityCustomEqualityDisabled] from typing import Literal, Union -from enum import Enum class Custom: def __eq__(self, other: object) -> bool: return True @@ -834,15 +833,20 @@ class Default: pass x1: Union[Custom, Literal[1], Literal[2]] if x1 == 1: - reveal_type(x1) # N: Revealed type is "__main__.Custom | Literal[1] | Literal[2]" + reveal_type(x1) # N: Revealed type is "__main__.Custom | Literal[1]" else: - reveal_type(x1) # N: Revealed type is "__main__.Custom | Literal[1] | Literal[2]" + reveal_type(x1) # N: Revealed type is "__main__.Custom | Literal[2]" x2: Union[Default, Literal[1], Literal[2]] if x2 == 1: reveal_type(x2) # N: Revealed type is "Literal[1]" else: reveal_type(x2) # N: Revealed type is "__main__.Default | Literal[2]" +[builtins fixtures/primitives.pyi] + +[case testNarrowingEqualityCustomEqualityEnum] +from typing import Literal, Union +from enum import Enum class CustomEnum(Enum): A = 1 @@ -855,7 +859,7 @@ key: Literal[CustomEnum.A] if x3 == key: reveal_type(x3) # N: Revealed type is "__main__.CustomEnum" else: - reveal_type(x3) # N: Revealed type is "__main__.CustomEnum" + reveal_type(x3) # N: Revealed type is "Literal[__main__.CustomEnum.B]" # For comparison, this narrows since we bypass __eq__ if x3 is key: @@ -864,7 +868,7 @@ else: reveal_type(x3) # N: Revealed type is "Literal[__main__.CustomEnum.B]" [builtins fixtures/primitives.pyi] -[case testNarrowingEqualityDisabledForCustomEqualityChain] +[case testNarrowingEqualityCustomEqualityChainedComparison] # flags: --strict-equality --warn-unreachable from typing import Literal, Union @@ -877,21 +881,13 @@ x: Literal[1, 2, None] y: Custom z: Default -# We could maybe try doing something clever, but for simplicity we -# treat the whole chain as contaminated and mostly disable narrowing. -# -# The only exception is that we do at least strip away the 'None'. We -# (perhaps optimistically) assume no custom class would be pathological -# enough to declare itself to be equal to None and so permit this narrowing, -# since it's often convenient in practice. if 1 == x == y: - reveal_type(x) # N: Revealed type is "Literal[1] | Literal[2]" + reveal_type(x) # N: Revealed type is "Literal[1]" reveal_type(y) # N: Revealed type is "__main__.Custom" else: - reveal_type(x) # N: Revealed type is "Literal[1] | Literal[2] | None" + reveal_type(x) # N: Revealed type is "Literal[2] | None" reveal_type(y) # N: Revealed type is "__main__.Custom" -# No contamination here if 1 == x == z: # E: Non-overlapping equality check (left operand type: "Literal[1, 2] | None", right operand type: "Default") reveal_type(x) # E: Statement is unreachable reveal_type(z) @@ -900,6 +896,101 @@ else: reveal_type(z) # N: Revealed type is "__main__.Default" [builtins fixtures/primitives.pyi] +[case testNarrowingCustomEqualityLiteralElseBranch] +# flags: --strict-equality --warn-unreachable +from __future__ import annotations +from typing import Literal + +class Custom: + def __eq__(self, other: object) -> bool: + raise + +def f(v: Custom | Literal["text"]) -> Custom | None: + if v == "text": + reveal_type(v) # N: Revealed type is "__main__.Custom | Literal['text']" + return None + else: + reveal_type(v) # N: Revealed type is "__main__.Custom" + return v + +def g(v: Custom | Literal["text"]) -> Custom | None: + if v != "text": + reveal_type(v) # N: Revealed type is "__main__.Custom" + return None + else: + reveal_type(v) # N: Revealed type is "__main__.Custom | Literal['text']" + return v # E: Incompatible return value type (got "Custom | Literal['text']", expected "Custom | None") +[builtins fixtures/primitives.pyi] + +[case testNarrowingCustomEqualityUnion] +# flags: --strict-equality --warn-unreachable +from __future__ import annotations +from typing import Any + +def realistic(x: dict[str, Any]): + val = x.get("hey") + if val == 12: + reveal_type(val) # N: Revealed type is "Any | Literal[12]?" + +def f1(x: Any | None): + if x == 12: + reveal_type(x) # N: Revealed type is "Any | Literal[12]?" + +class Custom: + def __eq__(self, other: object) -> bool: + raise + +def f2(x: Custom | None): + if x == 12: + reveal_type(x) # N: Revealed type is "__main__.Custom" + else: + reveal_type(x) # N: Revealed type is "__main__.Custom | None" +[builtins fixtures/dict.pyi] + +[case testNarrowingCustomEqualityUnion2] +# flags: --strict-equality --warn-unreachable +from __future__ import annotations +from typing import Any + +class Custom: + def __eq__(self, other: object) -> bool: + raise + +def f3(x: str | Custom, y: str | int): + if x == y: + reveal_type(x) # N: Revealed type is "builtins.str | __main__.Custom" + reveal_type(y) # N: Revealed type is "builtins.str | builtins.int" + else: + reveal_type(x) # N: Revealed type is "builtins.str | __main__.Custom" + reveal_type(y) # N: Revealed type is "builtins.str | builtins.int" + +def f4(x: str | Any, y: str | int): + if x == y: + reveal_type(x) # N: Revealed type is "builtins.str | Any | builtins.int" + reveal_type(y) # N: Revealed type is "builtins.str | builtins.int" + else: + reveal_type(x) # N: Revealed type is "builtins.str | Any" + reveal_type(y) # N: Revealed type is "builtins.str | builtins.int" +[builtins fixtures/dict.pyi] + +[case testNarrowingCustomEqualityUnion3] +# flags: --strict-equality --warn-unreachable +from __future__ import annotations +from typing import Any + +class Custom: + def __eq__(self, other: object) -> bool: + raise + +def f(x: Custom | None, y: int | None): + if x == y: + reveal_type(x) # N: Revealed type is "__main__.Custom | None" + reveal_type(y) # N: Revealed type is "builtins.int | None" + else: + reveal_type(x) # N: Revealed type is "__main__.Custom | None" + reveal_type(y) # N: Revealed type is "builtins.int | None" +[builtins fixtures/primitives.pyi] + [case testNarrowingUnreachableCases] # flags: --strict-equality --warn-unreachable from typing import Literal, Union @@ -2157,7 +2248,7 @@ def f3(x: object) -> None: def f4(x: int | Any) -> None: if x == IE.X: - reveal_type(x) # N: Revealed type is "builtins.int | Any" + reveal_type(x) # N: Revealed type is "Literal[__main__.IE.X] | Any" else: reveal_type(x) # N: Revealed type is "builtins.int | Any" @@ -2232,9 +2323,9 @@ def f5(x: E | str | int) -> None: def f6(x: IE | Any) -> None: if x == IE.X: - reveal_type(x) # N: Revealed type is "__main__.IE | Any" + reveal_type(x) # N: Revealed type is "Literal[__main__.IE.X] | Any" else: - reveal_type(x) # N: Revealed type is "__main__.IE | Any" + reveal_type(x) # N: Revealed type is "Literal[__main__.IE.Y] | Any" def f7(x: IE | None) -> None: if x == IE.X: @@ -2321,7 +2412,7 @@ def f(x: str | int) -> None: z = y [builtins fixtures/primitives.pyi] -[case testConsistentNarrowingInWithCustomEq] +[case testConsistentNarrowingEqAndInWithCustomEq] # flags: --python-version 3.10 # https://github.com/python/mypy/issues/17864 @@ -2339,11 +2430,17 @@ class C: class D(C): pass -def f(x: C) -> None: +def f1(x: C) -> None: if x in [D(5)]: reveal_type(x) # D # N: Revealed type is "__main__.C" -f(C(5)) +f1(C(5)) + +def f2(x: C) -> None: + if x == D(5): + reveal_type(x) # D # N: Revealed type is "__main__.C" + +f2(C(5)) [builtins fixtures/primitives.pyi] [case testNarrowingTypeVarNone] diff --git a/test-data/unit/check-unreachable-code.test b/test-data/unit/check-unreachable-code.test index c813ea9c1c38..22c4c80916cb 100644 --- a/test-data/unit/check-unreachable-code.test +++ b/test-data/unit/check-unreachable-code.test @@ -1674,3 +1674,21 @@ def x() -> None: main:4: error: Statement is unreachable if 5: ^~~~~ + +[case testReachableEqualityNarrowingAny] +# flags: --warn-unreachable +# https://github.com/python/mypy/issues/20532 +from __future__ import annotations +from typing import Any + +def print(s: str): pass + +def main(contents: Any, commit: str | None) -> None: + if ( + contents.get("commit") == commit + and (commit is not None or print("can_be_reached")) + ): + pass + +main({"commit": None}, None) +[builtins fixtures/tuple.pyi]