Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enables the narrowing of variable types when checking a variable is "in" a collection. #17344

Merged
merged 7 commits into from
Jul 24, 2024
Merged
15 changes: 10 additions & 5 deletions mypy/checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -5948,11 +5948,16 @@ def has_no_custom_eq_checks(t: Type) -> bool:
if_map, else_map = {}, {}

if left_index in narrowable_operand_index_to_hash:
# We only try and narrow away 'None' for now
if is_overlapping_none(item_type):
collection_item_type = get_proper_type(
builtin_item_type(iterable_type)
)
collection_item_type = get_proper_type(builtin_item_type(iterable_type))
# Narrow if the collection is a subtype
if (
collection_item_type is not None
and collection_item_type != item_type
and is_subtype(collection_item_type, item_type)
):
if_map[operands[left_index]] = collection_item_type
# Try and narrow away 'None'
elif is_overlapping_none(item_type):
if (
collection_item_type is not None
and not is_overlapping_none(collection_item_type)
Expand Down
112 changes: 110 additions & 2 deletions test-data/unit/check-narrowing.test
Original file line number Diff line number Diff line change
Expand Up @@ -1376,13 +1376,13 @@ else:
reveal_type(val) # N: Revealed type is "Union[__main__.A, None]"

if val in (None,):
reveal_type(val) # N: Revealed type is "Union[__main__.A, None]"
reveal_type(val) # N: Revealed type is "None"
else:
reveal_type(val) # N: Revealed type is "Union[__main__.A, None]"
if val not in (None,):
reveal_type(val) # N: Revealed type is "Union[__main__.A, None]"
else:
reveal_type(val) # N: Revealed type is "Union[__main__.A, None]"
reveal_type(val) # N: Revealed type is "None"
[builtins fixtures/primitives.pyi]

[case testNarrowingWithTupleOfTypes]
Expand Down Expand Up @@ -2114,3 +2114,111 @@ else:

[typing fixtures/typing-medium.pyi]
[builtins fixtures/ops.pyi]


[case testTypeNarrowingStringInLiteralUnion]
from typing import Literal, Tuple
typ: Tuple[Literal['a', 'b'], ...] = ('a', 'b')
x: str = "hi!"
if x in typ:
reveal_type(x) # N: Revealed type is "Union[Literal['a'], Literal['b']]"
else:
reveal_type(x) # N: Revealed type is "builtins.str"
[builtins fixtures/tuple.pyi]
[typing fixtures/typing-medium.pyi]

[case testTypeNarrowingStringInLiteralUnionSubset]
from typing import Literal, Tuple
typeAlpha: Tuple[Literal['a', 'b', 'c'], ...] = ('a', 'b')
strIn: str = "b"
strOut: str = "c"
if strIn in typeAlpha:
reveal_type(strIn) # N: Revealed type is "Union[Literal['a'], Literal['b'], Literal['c']]"
else:
reveal_type(strIn) # N: Revealed type is "builtins.str"
if strOut in typeAlpha:
reveal_type(strOut) # N: Revealed type is "Union[Literal['a'], Literal['b'], Literal['c']]"
else:
reveal_type(strOut) # N: Revealed type is "builtins.str"
[builtins fixtures/primitives.pyi]
[typing fixtures/typing-medium.pyi]

[case testNarrowingStringNotInLiteralUnion]
from typing import Literal, Tuple
typeAlpha: Tuple[Literal['a', 'b', 'c'],...] = ('a', 'b', 'c')
strIn: str = "c"
strOut: str = "d"
if strIn not in typeAlpha:
reveal_type(strIn) # N: Revealed type is "builtins.str"
else:
reveal_type(strIn) # N: Revealed type is "Union[Literal['a'], Literal['b'], Literal['c']]"
if strOut in typeAlpha:
reveal_type(strOut) # N: Revealed type is "Union[Literal['a'], Literal['b'], Literal['c']]"
else:
reveal_type(strOut) # N: Revealed type is "builtins.str"
[builtins fixtures/primitives.pyi]
[typing fixtures/typing-medium.pyi]

[case testNarrowingStringInLiteralUnionDontExpand]
from typing import Literal, Tuple
typeAlpha: Tuple[Literal['a', 'b', 'c'], ...] = ('a', 'b', 'c')
strIn: Literal['c'] = "c"
reveal_type(strIn) # N: Revealed type is "Literal['c']"
#Check we don't expand a Literal into the Union type
if strIn not in typeAlpha:
reveal_type(strIn) # N: Revealed type is "Literal['c']"
else:
reveal_type(strIn) # N: Revealed type is "Literal['c']"
[builtins fixtures/primitives.pyi]
[typing fixtures/typing-medium.pyi]

[case testTypeNarrowingStringInMixedUnion]
from typing import Literal, Tuple
typ: Tuple[Literal['a', 'b'], ...] = ('a', 'b')
x: str = "hi!"
if x in typ:
reveal_type(x) # N: Revealed type is "Union[Literal['a'], Literal['b']]"
else:
reveal_type(x) # N: Revealed type is "builtins.str"
[builtins fixtures/tuple.pyi]
[typing fixtures/typing-medium.pyi]

[case testTypeNarrowingStringInSet]
from typing import Literal, Set
typ: Set[Literal['a', 'b']] = {'a', 'b'}
x: str = "hi!"
if x in typ:
reveal_type(x) # N: Revealed type is "Union[Literal['a'], Literal['b']]"
else:
reveal_type(x) # N: Revealed type is "builtins.str"
if x not in typ:
reveal_type(x) # N: Revealed type is "builtins.str"
else:
reveal_type(x) # N: Revealed type is "Union[Literal['a'], Literal['b']]"
[builtins fixtures/narrowing.pyi]
[typing fixtures/typing-medium.pyi]

[case testTypeNarrowingStringInList]
from typing import Literal, List
typ: List[Literal['a', 'b']] = ['a', 'b']
x: str = "hi!"
if x in typ:
reveal_type(x) # N: Revealed type is "Union[Literal['a'], Literal['b']]"
else:
reveal_type(x) # N: Revealed type is "builtins.str"
if x not in typ:
reveal_type(x) # N: Revealed type is "builtins.str"
else:
reveal_type(x) # N: Revealed type is "Union[Literal['a'], Literal['b']]"
[builtins fixtures/narrowing.pyi]
[typing fixtures/typing-medium.pyi]

[case testTypeNarrowingUnionStringFloat]
from typing import Union
def foobar(foo: Union[str, float]):
if foo in ['a', 'b']:
reveal_type(foo) # N: Revealed type is "builtins.str"
else:
reveal_type(foo) # N: Revealed type is "Union[builtins.str, builtins.float]"
[builtins fixtures/primitives.pyi]
[typing fixtures/typing-medium.pyi]
9 changes: 8 additions & 1 deletion test-data/unit/fixtures/narrowing.pyi
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# Builtins stub used in check-narrowing test cases.
from typing import Generic, Sequence, Tuple, Type, TypeVar, Union
from typing import Generic, Sequence, Tuple, Type, TypeVar, Union, Iterable


Tco = TypeVar('Tco', covariant=True)
Expand All @@ -15,6 +15,13 @@ class function: pass
class ellipsis: pass
class int: pass
class str: pass
class float: pass
class dict(Generic[KT, VT]): pass

def isinstance(x: object, t: Union[Type[object], Tuple[Type[object], ...]]) -> bool: pass

class list(Sequence[Tco]):
def __contains__(self, other: object) -> bool: pass
class set(Iterable[Tco], Generic[Tco]):
def __init__(self, iterable: Iterable[Tco] = ...) -> None: ...
def __contains__(self, item: object) -> bool: pass
Loading