-
-
Notifications
You must be signed in to change notification settings - Fork 3.1k
Better narrowing with custom equality #20643
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -6620,16 +6620,11 @@ def narrow_type_by_identity_equality( | |
| ...then 'operands' and 'operand_types' would be lists of length 5 and 'chain_indices' | ||
| would be the list [1, 2, 3]. | ||
|
|
||
| The 'narrowable_operand_indices' parameter is the set of all indices we are allowed | ||
| The 'narrowable_indices' parameter is the set of all indices we are allowed | ||
| to refine the types of: that is, all operands that will potentially be a part of | ||
| the output TypeMaps. | ||
|
|
||
| """ | ||
| # should_narrow_by_identity_equality: | ||
| # If operator is "==" or "!=", we cannot narrow if we detect the presence of a user defined | ||
| # custom __eq__ or __ne__ method | ||
| should_narrow_by_identity_equality: bool | ||
|
|
||
| # is_target_for_value_narrowing: | ||
| # If the operator returns True when compared to this target, do we narrow in else branch? | ||
| # E.g. if operator is "==", then: | ||
|
|
@@ -6646,7 +6641,7 @@ def narrow_type_by_identity_equality( | |
| if operator in {"is", "is not"}: | ||
| is_target_for_value_narrowing = is_singleton_identity_type | ||
| should_coerce_literals = True | ||
| should_narrow_by_identity_equality = True | ||
| custom_eq_indices = set() | ||
| enum_comparison_is_ambiguous = False | ||
|
|
||
| elif operator in {"==", "!="}: | ||
|
|
@@ -6659,19 +6654,11 @@ def narrow_type_by_identity_equality( | |
| should_coerce_literals = True | ||
| break | ||
|
|
||
| expr_types = [operand_types[i] for i in expr_indices] | ||
| should_narrow_by_identity_equality = not any(map(has_custom_eq_checks, expr_types)) | ||
| custom_eq_indices = {i for i in expr_indices if has_custom_eq_checks(operand_types[i])} | ||
| enum_comparison_is_ambiguous = True | ||
| else: | ||
| raise AssertionError | ||
|
|
||
| if not should_narrow_by_identity_equality: | ||
| # This is a bit of a legacy code path that might be a little unsound since it ignores | ||
| # custom __eq__. We should see if we can get rid of it in favour of `return {}, {}` | ||
| return self.refine_away_none_in_comparison( | ||
| operands, operand_types, expr_indices, narrowable_indices | ||
| ) | ||
|
|
||
| value_targets = [] | ||
| type_targets = [] | ||
| for i in expr_indices: | ||
|
|
@@ -6683,6 +6670,10 @@ def narrow_type_by_identity_equality( | |
| # `x` to `Literal[Foo.A]` iff `Foo` has exactly one member. | ||
| # See testMatchEnumSingleChoice | ||
| expr_type = coerce_to_literal(expr_type) | ||
| if i in custom_eq_indices: | ||
| # We can't use types with custom __eq__ as targets for narrowing | ||
| # E.g. if (x: int | None) == (y: CustomEq | None), we cannot narrow x to None | ||
| continue | ||
| if is_target_for_value_narrowing(get_proper_type(expr_type)): | ||
| value_targets.append((i, TypeRange(expr_type, is_upper_bound=False))) | ||
| else: | ||
|
|
@@ -6694,7 +6685,11 @@ def narrow_type_by_identity_equality( | |
| for i in expr_indices: | ||
| if i not in narrowable_indices: | ||
| continue | ||
| expr_type = coerce_to_literal(operand_types[i]) | ||
| if i in custom_eq_indices: | ||
| # Handled later | ||
| continue | ||
| expr_type = operand_types[i] | ||
| expr_type = coerce_to_literal(expr_type) | ||
| expr_type = try_expanding_sum_type_to_union(expr_type, None) | ||
| expr_enum_keys = ambiguous_enum_equality_keys(expr_type) | ||
| for j, target in value_targets: | ||
|
|
@@ -6715,6 +6710,9 @@ def narrow_type_by_identity_equality( | |
| for i in expr_indices: | ||
| if i not in narrowable_indices: | ||
| continue | ||
| if i in custom_eq_indices: | ||
| # Handled later | ||
| continue | ||
| expr_type = operand_types[i] | ||
| for j, target in type_targets: | ||
| if i == j: | ||
|
|
@@ -6723,9 +6721,63 @@ def narrow_type_by_identity_equality( | |
| operands[i], *conditional_types(expr_type, [target]) | ||
| ) | ||
| if if_map: | ||
| else_map = {} # this is the big difference compared to the above | ||
| # For type_targets, we cannot narrow in the negative case | ||
| # e.g. if (x: str | None) != (y: str), we cannot narrow x to None | ||
| else_map = {} | ||
|
Comment on lines
+6724
to
+6726
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This part of the diff also seems unrelated? |
||
| partial_type_maps.append((if_map, else_map)) | ||
|
|
||
| for i in custom_eq_indices: | ||
|
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I have a later PR that condenses this a little bit / adds more comments Maybe worth looking at the version on my dev branch: https://github.com/hauntsaninja/mypy/pull/5/files#diff-f96a2d6138bc6cdf2a07c4d37f6071cc25c1631afc107e277a28d5b59fc0ef04R6699 |
||
| if i not in narrowable_indices: | ||
| continue | ||
| union_expr_type = get_proper_type(operand_types[i]) | ||
| if not isinstance(union_expr_type, UnionType): | ||
| expr_type = operand_types[i] | ||
| for j, target in value_targets: | ||
| _if_map, else_map = conditional_types_to_typemaps( | ||
| operands[i], *conditional_types(expr_type, [target]) | ||
| ) | ||
| if else_map: | ||
| partial_type_maps.append(({}, else_map)) | ||
| continue | ||
|
|
||
| or_if_maps: list[TypeMap] = [] | ||
| or_else_maps: list[TypeMap] = [] | ||
| for expr_type in union_expr_type.items: | ||
| if has_custom_eq_checks(expr_type): | ||
| or_if_maps.append({operands[i]: expr_type}) | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. -snip- Nevermind, I see what this does. |
||
|
|
||
| for j in expr_indices: | ||
| if j in custom_eq_indices: | ||
| continue | ||
| target_type = operand_types[j] | ||
| if should_coerce_literals: | ||
| target_type = coerce_to_literal(target_type) | ||
| target = TypeRange(target_type, is_upper_bound=False) | ||
| is_value_target = is_target_for_value_narrowing(get_proper_type(target_type)) | ||
|
|
||
| if is_value_target: | ||
| expr_type = coerce_to_literal(expr_type) | ||
| expr_type = try_expanding_sum_type_to_union(expr_type, None) | ||
| if_map, else_map = conditional_types_to_typemaps( | ||
| operands[i], *conditional_types(expr_type, [target], default=expr_type) | ||
| ) | ||
| or_if_maps.append(if_map) | ||
| if is_value_target: | ||
| or_else_maps.append(else_map) | ||
|
|
||
| final_if_map: TypeMap = {} | ||
| final_else_map: TypeMap = {} | ||
| if or_if_maps: | ||
| final_if_map = or_if_maps[0] | ||
| for if_map in or_if_maps[1:]: | ||
| final_if_map = or_conditional_maps(final_if_map, if_map) | ||
| if or_else_maps: | ||
| final_else_map = or_else_maps[0] | ||
| for else_map in or_else_maps[1:]: | ||
| final_else_map = or_conditional_maps(final_else_map, else_map) | ||
|
|
||
| partial_type_maps.append((final_if_map, final_else_map)) | ||
|
|
||
| for i in expr_indices: | ||
| type_expr = operands[i] | ||
| if ( | ||
|
|
@@ -6943,49 +6995,6 @@ def _propagate_walrus_assignments( | |
| return parent_expr | ||
| return expr | ||
|
|
||
| def refine_away_none_in_comparison( | ||
| self, | ||
| operands: list[Expression], | ||
| operand_types: list[Type], | ||
| chain_indices: list[int], | ||
| narrowable_operand_indices: AbstractSet[int], | ||
| ) -> tuple[TypeMap, TypeMap]: | ||
| """Produces conditional type maps refining away None in an identity/equality chain. | ||
|
|
||
| For more details about what the different arguments mean, see the | ||
| docstring of 'narrow_type_by_identity_equality' up above. | ||
| """ | ||
|
|
||
| non_optional_types = [] | ||
| for i in chain_indices: | ||
| typ = operand_types[i] | ||
| if not is_overlapping_none(typ): | ||
| non_optional_types.append(typ) | ||
|
|
||
| if_map, else_map = {}, {} | ||
|
|
||
| if not non_optional_types or (len(non_optional_types) != len(chain_indices)): | ||
|
|
||
| # Narrow e.g. `Optional[A] == "x"` or `Optional[A] is "x"` to `A` (which may be | ||
| # convenient but is strictly not type-safe): | ||
| for i in narrowable_operand_indices: | ||
| expr_type = operand_types[i] | ||
| if not is_overlapping_none(expr_type): | ||
| continue | ||
| if any(is_overlapping_erased_types(expr_type, t) for t in non_optional_types): | ||
| if_map[operands[i]] = remove_optional(expr_type) | ||
|
|
||
| # Narrow e.g. `Optional[A] != None` to `A` (which is stricter than the above step and | ||
| # so type-safe but less convenient, because e.g. `Optional[A] == None` still results | ||
| # in `Optional[A]`): | ||
| if any(isinstance(get_proper_type(ot), NoneType) for ot in operand_types): | ||
| for i in narrowable_operand_indices: | ||
| expr_type = operand_types[i] | ||
| if is_overlapping_none(expr_type): | ||
| else_map[operands[i]] = remove_optional(expr_type) | ||
|
|
||
| return if_map, else_map | ||
|
|
||
| def is_len_of_tuple(self, expr: Expression) -> bool: | ||
| """Is this expression a `len(x)` call where x is a tuple or union of tuples?""" | ||
| if not isinstance(expr, CallExpr): | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Huh, why did these lines get added? Maybe they're right (not sure what it's doing...) but seems weird to add code to this existing codepath?
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There's no semantic change in the expr_type stuff, the important change here is
if i in custom_eq_indices: continueAs for why I have a stray line: managing a stack of like twenty commits is a little fiddly