Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
131 changes: 70 additions & 61 deletions mypy/checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -6620,16 +6620,11 @@ def narrow_type_by_identity_equality(
...then 'operands' and 'operand_types' would be lists of length 5 and 'chain_indices'
would be the list [1, 2, 3].

The 'narrowable_operand_indices' parameter is the set of all indices we are allowed
The 'narrowable_indices' parameter is the set of all indices we are allowed
to refine the types of: that is, all operands that will potentially be a part of
the output TypeMaps.

"""
# should_narrow_by_identity_equality:
# If operator is "==" or "!=", we cannot narrow if we detect the presence of a user defined
# custom __eq__ or __ne__ method
should_narrow_by_identity_equality: bool

# is_target_for_value_narrowing:
# If the operator returns True when compared to this target, do we narrow in else branch?
# E.g. if operator is "==", then:
Expand All @@ -6646,7 +6641,7 @@ def narrow_type_by_identity_equality(
if operator in {"is", "is not"}:
is_target_for_value_narrowing = is_singleton_identity_type
should_coerce_literals = True
should_narrow_by_identity_equality = True
custom_eq_indices = set()
enum_comparison_is_ambiguous = False

elif operator in {"==", "!="}:
Expand All @@ -6659,19 +6654,11 @@ def narrow_type_by_identity_equality(
should_coerce_literals = True
break

expr_types = [operand_types[i] for i in expr_indices]
should_narrow_by_identity_equality = not any(map(has_custom_eq_checks, expr_types))
custom_eq_indices = {i for i in expr_indices if has_custom_eq_checks(operand_types[i])}
enum_comparison_is_ambiguous = True
else:
raise AssertionError

if not should_narrow_by_identity_equality:
# This is a bit of a legacy code path that might be a little unsound since it ignores
# custom __eq__. We should see if we can get rid of it in favour of `return {}, {}`
return self.refine_away_none_in_comparison(
operands, operand_types, expr_indices, narrowable_indices
)

value_targets = []
type_targets = []
for i in expr_indices:
Expand All @@ -6683,6 +6670,10 @@ def narrow_type_by_identity_equality(
# `x` to `Literal[Foo.A]` iff `Foo` has exactly one member.
# See testMatchEnumSingleChoice
expr_type = coerce_to_literal(expr_type)
if i in custom_eq_indices:
# We can't use types with custom __eq__ as targets for narrowing
# E.g. if (x: int | None) == (y: CustomEq | None), we cannot narrow x to None
continue
if is_target_for_value_narrowing(get_proper_type(expr_type)):
value_targets.append((i, TypeRange(expr_type, is_upper_bound=False)))
else:
Expand All @@ -6694,7 +6685,11 @@ def narrow_type_by_identity_equality(
for i in expr_indices:
if i not in narrowable_indices:
continue
expr_type = coerce_to_literal(operand_types[i])
if i in custom_eq_indices:
# Handled later
continue
expr_type = operand_types[i]
expr_type = coerce_to_literal(expr_type)
Comment on lines +6691 to +6692
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Huh, why did these lines get added? Maybe they're right (not sure what it's doing...) but seems weird to add code to this existing codepath?

Copy link
Collaborator Author

@hauntsaninja hauntsaninja Jan 24, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There's no semantic change in the expr_type stuff, the important change here is if i in custom_eq_indices: continue

As for why I have a stray line: managing a stack of like twenty commits is a little fiddly

expr_type = try_expanding_sum_type_to_union(expr_type, None)
expr_enum_keys = ambiguous_enum_equality_keys(expr_type)
for j, target in value_targets:
Expand All @@ -6715,6 +6710,9 @@ def narrow_type_by_identity_equality(
for i in expr_indices:
if i not in narrowable_indices:
continue
if i in custom_eq_indices:
# Handled later
continue
expr_type = operand_types[i]
for j, target in type_targets:
if i == j:
Expand All @@ -6723,9 +6721,63 @@ def narrow_type_by_identity_equality(
operands[i], *conditional_types(expr_type, [target])
)
if if_map:
else_map = {} # this is the big difference compared to the above
# For type_targets, we cannot narrow in the negative case
# e.g. if (x: str | None) != (y: str), we cannot narrow x to None
else_map = {}
Comment on lines +6724 to +6726
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This part of the diff also seems unrelated?

partial_type_maps.append((if_map, else_map))

for i in custom_eq_indices:
Copy link
Collaborator Author

@hauntsaninja hauntsaninja Jan 23, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have a later PR that condenses this a little bit / adds more comments

Maybe worth looking at the version on my dev branch: https://github.com/hauntsaninja/mypy/pull/5/files#diff-f96a2d6138bc6cdf2a07c4d37f6071cc25c1631afc107e277a28d5b59fc0ef04R6699

if i not in narrowable_indices:
continue
union_expr_type = get_proper_type(operand_types[i])
if not isinstance(union_expr_type, UnionType):
expr_type = operand_types[i]
for j, target in value_targets:
_if_map, else_map = conditional_types_to_typemaps(
operands[i], *conditional_types(expr_type, [target])
)
if else_map:
partial_type_maps.append(({}, else_map))
continue

or_if_maps: list[TypeMap] = []
or_else_maps: list[TypeMap] = []
for expr_type in union_expr_type.items:
if has_custom_eq_checks(expr_type):
or_if_maps.append({operands[i]: expr_type})
Copy link
Collaborator

@A5rocks A5rocks Jan 24, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

-snip-

Nevermind, I see what this does.


for j in expr_indices:
if j in custom_eq_indices:
continue
target_type = operand_types[j]
if should_coerce_literals:
target_type = coerce_to_literal(target_type)
target = TypeRange(target_type, is_upper_bound=False)
is_value_target = is_target_for_value_narrowing(get_proper_type(target_type))

if is_value_target:
expr_type = coerce_to_literal(expr_type)
expr_type = try_expanding_sum_type_to_union(expr_type, None)
if_map, else_map = conditional_types_to_typemaps(
operands[i], *conditional_types(expr_type, [target], default=expr_type)
)
or_if_maps.append(if_map)
if is_value_target:
or_else_maps.append(else_map)

final_if_map: TypeMap = {}
final_else_map: TypeMap = {}
if or_if_maps:
final_if_map = or_if_maps[0]
for if_map in or_if_maps[1:]:
final_if_map = or_conditional_maps(final_if_map, if_map)
if or_else_maps:
final_else_map = or_else_maps[0]
for else_map in or_else_maps[1:]:
final_else_map = or_conditional_maps(final_else_map, else_map)

partial_type_maps.append((final_if_map, final_else_map))

for i in expr_indices:
type_expr = operands[i]
if (
Expand Down Expand Up @@ -6943,49 +6995,6 @@ def _propagate_walrus_assignments(
return parent_expr
return expr

def refine_away_none_in_comparison(
self,
operands: list[Expression],
operand_types: list[Type],
chain_indices: list[int],
narrowable_operand_indices: AbstractSet[int],
) -> tuple[TypeMap, TypeMap]:
"""Produces conditional type maps refining away None in an identity/equality chain.

For more details about what the different arguments mean, see the
docstring of 'narrow_type_by_identity_equality' up above.
"""

non_optional_types = []
for i in chain_indices:
typ = operand_types[i]
if not is_overlapping_none(typ):
non_optional_types.append(typ)

if_map, else_map = {}, {}

if not non_optional_types or (len(non_optional_types) != len(chain_indices)):

# Narrow e.g. `Optional[A] == "x"` or `Optional[A] is "x"` to `A` (which may be
# convenient but is strictly not type-safe):
for i in narrowable_operand_indices:
expr_type = operand_types[i]
if not is_overlapping_none(expr_type):
continue
if any(is_overlapping_erased_types(expr_type, t) for t in non_optional_types):
if_map[operands[i]] = remove_optional(expr_type)

# Narrow e.g. `Optional[A] != None` to `A` (which is stricter than the above step and
# so type-safe but less convenient, because e.g. `Optional[A] == None` still results
# in `Optional[A]`):
if any(isinstance(get_proper_type(ot), NoneType) for ot in operand_types):
for i in narrowable_operand_indices:
expr_type = operand_types[i]
if is_overlapping_none(expr_type):
else_map[operands[i]] = remove_optional(expr_type)

return if_map, else_map

def is_len_of_tuple(self, expr: Expression) -> bool:
"""Is this expression a `len(x)` call where x is a tuple or union of tuples?"""
if not isinstance(expr, CallExpr):
Expand Down
Loading