From 5dfcdb0d49bf2d10b9db1b6214f6de16978bec55 Mon Sep 17 00:00:00 2001 From: Shantanu Jain Date: Sat, 10 Jan 2026 22:48:27 -0800 Subject: [PATCH] Narrowing for comparisons against x.__class__ --- mypy/checker.py | 2 ++ test-data/unit/check-narrowing.test | 28 +++++++++++++++++++++++++++ test-data/unit/fixtures/dict-full.pyi | 4 ++-- 3 files changed, 32 insertions(+), 2 deletions(-) diff --git a/mypy/checker.py b/mypy/checker.py index 522fb1c818e0..38b44c153e1f 100644 --- a/mypy/checker.py +++ b/mypy/checker.py @@ -6734,6 +6734,8 @@ def narrow_type_by_identity_equality( and len(type_expr.args) == 1 ): expr_in_type_expr = type_expr.args[0] + elif isinstance(type_expr, MemberExpr) and type_expr.name == "__class__": + expr_in_type_expr = type_expr.expr else: continue for j in expr_indices: diff --git a/test-data/unit/check-narrowing.test b/test-data/unit/check-narrowing.test index fb14efea6663..5eb50d4af0af 100644 --- a/test-data/unit/check-narrowing.test +++ b/test-data/unit/check-narrowing.test @@ -3076,3 +3076,31 @@ if type(x) is not C: reveal_type(x) # N: Revealed type is "__main__.D" else: reveal_type(x) # N: Revealed type is "__main__.C" + +[case testDunderClassNarrowing] +# flags: --warn-unreachable +from typing import Any + +def foo(y: object): + if y.__class__ == int: + reveal_type(y) # N: Revealed type is "builtins.int" + else: + reveal_type(y) # N: Revealed type is "builtins.object" + + if y.__class__ is int: + reveal_type(y) # N: Revealed type is "builtins.int" + else: + reveal_type(y) # N: Revealed type is "builtins.object" + + +def bar(y: Any): + if y.__class__ == int: + reveal_type(y) # N: Revealed type is "Any" + else: + reveal_type(y) # N: Revealed type is "Any" + + if y.__class__ is int: + reveal_type(y) # N: Revealed type is "builtins.int" + else: + reveal_type(y) # N: Revealed type is "Any" +[builtins fixtures/dict-full.pyi] diff --git a/test-data/unit/fixtures/dict-full.pyi b/test-data/unit/fixtures/dict-full.pyi index f20369ce9332..88cd3260c702 100644 --- a/test-data/unit/fixtures/dict-full.pyi +++ b/test-data/unit/fixtures/dict-full.pyi @@ -13,6 +13,7 @@ KT = TypeVar('KT') VT = TypeVar('VT') class object: + __class__: object def __init__(self) -> None: pass def __init_subclass__(cls) -> None: pass def __eq__(self, other: object) -> bool: pass @@ -75,8 +76,7 @@ class float: pass class complex: pass class bool(int): pass -class ellipsis: - __class__: object +class ellipsis: pass def isinstance(x: object, t: Union[type, Tuple[type, ...]]) -> bool: pass class BaseException: pass