Skip to content

Commit

Permalink
Support type inference for defaultdict() (python#8167)
Browse files Browse the repository at this point in the history
This allows inferring type of `x`, for example:

```
from collections import defaultdict

x = defaultdict(list)  # defaultdict[str, List[int]]
x['foo'].append(1)
```

The implemention is not pretty and we have probably
reached about the maximum reasonable level of special
casing in type inference now.

There is a hack to work around the problem with leaking
type variable types in nested generics calls (I think).
This will break some (likely very rare) use cases.
  • Loading branch information
JukkaL authored Dec 18, 2019
1 parent 5f16416 commit c1cd529
Show file tree
Hide file tree
Showing 8 changed files with 250 additions and 51 deletions.
53 changes: 46 additions & 7 deletions mypy/checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -2813,14 +2813,26 @@ def infer_partial_type(self, name: Var, lvalue: Lvalue, init_type: Type) -> bool
partial_type = PartialType(None, name)
elif isinstance(init_type, Instance):
fullname = init_type.type.fullname
if (isinstance(lvalue, (NameExpr, MemberExpr)) and
is_ref = isinstance(lvalue, RefExpr)
if (is_ref and
(fullname == 'builtins.list' or
fullname == 'builtins.set' or
fullname == 'builtins.dict' or
fullname == 'collections.OrderedDict') and
all(isinstance(t, (NoneType, UninhabitedType))
for t in get_proper_types(init_type.args))):
partial_type = PartialType(init_type.type, name)
elif is_ref and fullname == 'collections.defaultdict':
arg0 = get_proper_type(init_type.args[0])
arg1 = get_proper_type(init_type.args[1])
if (isinstance(arg0, (NoneType, UninhabitedType)) and
isinstance(arg1, Instance) and
self.is_valid_defaultdict_partial_value_type(arg1)):
# Erase type argument, if one exists (this fills in Anys)
arg1 = self.named_type(arg1.type.fullname)
partial_type = PartialType(init_type.type, name, arg1)
else:
return False
else:
return False
else:
Expand All @@ -2829,6 +2841,28 @@ def infer_partial_type(self, name: Var, lvalue: Lvalue, init_type: Type) -> bool
self.partial_types[-1].map[name] = lvalue
return True

def is_valid_defaultdict_partial_value_type(self, t: Instance) -> bool:
"""Check if t can be used as the basis for a partial defaultddict value type.
Examples:
* t is 'int' --> True
* t is 'list[<nothing>]' --> True
* t is 'dict[...]' --> False (only generic types with a single type
argument supported)
"""
if len(t.args) == 0:
return True
if len(t.args) == 1:
arg = get_proper_type(t.args[0])
# TODO: This is too permissive -- we only allow TypeVarType since
# they leak in cases like defaultdict(list) due to a bug.
# This can result in incorrect types being inferred, but only
# in rare cases.
if isinstance(arg, (TypeVarType, UninhabitedType, NoneType)):
return True
return False

def set_inferred_type(self, var: Var, lvalue: Lvalue, type: Type) -> None:
"""Store inferred variable type.
Expand Down Expand Up @@ -3018,16 +3052,21 @@ def try_infer_partial_type_from_indexed_assignment(
if partial_types is None:
return
typename = type_type.fullname
if typename == 'builtins.dict' or typename == 'collections.OrderedDict':
if (typename == 'builtins.dict'
or typename == 'collections.OrderedDict'
or typename == 'collections.defaultdict'):
# TODO: Don't infer things twice.
key_type = self.expr_checker.accept(lvalue.index)
value_type = self.expr_checker.accept(rvalue)
if (is_valid_inferred_type(key_type) and
is_valid_inferred_type(value_type)):
if not self.current_node_deferred:
var.type = self.named_generic_type(typename,
[key_type, value_type])
del partial_types[var]
is_valid_inferred_type(value_type) and
not self.current_node_deferred and
not (typename == 'collections.defaultdict' and
var.type.value_type is not None and
not is_equivalent(value_type, var.type.value_type))):
var.type = self.named_generic_type(typename,
[key_type, value_type])
del partial_types[var]

def visit_expression_stmt(self, s: ExpressionStmt) -> None:
self.expr_checker.accept(s.expr, allow_none_return=True, always_allow_any=True)
Expand Down
125 changes: 90 additions & 35 deletions mypy/checkexpr.py
Original file line number Diff line number Diff line change
Expand Up @@ -567,42 +567,91 @@ def get_partial_self_var(self, expr: MemberExpr) -> Optional[Var]:
} # type: ClassVar[Dict[str, Dict[str, List[str]]]]

def try_infer_partial_type(self, e: CallExpr) -> None:
if isinstance(e.callee, MemberExpr) and isinstance(e.callee.expr, RefExpr):
var = e.callee.expr.node
if var is None and isinstance(e.callee.expr, MemberExpr):
var = self.get_partial_self_var(e.callee.expr)
if not isinstance(var, Var):
"""Try to make partial type precise from a call."""
if not isinstance(e.callee, MemberExpr):
return
callee = e.callee
if isinstance(callee.expr, RefExpr):
# Call a method with a RefExpr callee, such as 'x.method(...)'.
ret = self.get_partial_var(callee.expr)
if ret is None:
return
partial_types = self.chk.find_partial_types(var)
if partial_types is not None and not self.chk.current_node_deferred:
partial_type = var.type
if (partial_type is None or
not isinstance(partial_type, PartialType) or
partial_type.type is None):
# A partial None type -> can't infer anything.
return
typename = partial_type.type.fullname
methodname = e.callee.name
# Sometimes we can infer a full type for a partial List, Dict or Set type.
# TODO: Don't infer argument expression twice.
if (typename in self.item_args and methodname in self.item_args[typename]
and e.arg_kinds == [ARG_POS]):
item_type = self.accept(e.args[0])
if mypy.checker.is_valid_inferred_type(item_type):
var.type = self.chk.named_generic_type(typename, [item_type])
del partial_types[var]
elif (typename in self.container_args
and methodname in self.container_args[typename]
and e.arg_kinds == [ARG_POS]):
arg_type = get_proper_type(self.accept(e.args[0]))
if isinstance(arg_type, Instance):
arg_typename = arg_type.type.fullname
if arg_typename in self.container_args[typename][methodname]:
if all(mypy.checker.is_valid_inferred_type(item_type)
for item_type in arg_type.args):
var.type = self.chk.named_generic_type(typename,
list(arg_type.args))
del partial_types[var]
var, partial_types = ret
typ = self.try_infer_partial_value_type_from_call(e, callee.name, var)
if typ is not None:
var.type = typ
del partial_types[var]
elif isinstance(callee.expr, IndexExpr) and isinstance(callee.expr.base, RefExpr):
# Call 'x[y].method(...)'; may infer type of 'x' if it's a partial defaultdict.
if callee.expr.analyzed is not None:
return # A special form
base = callee.expr.base
index = callee.expr.index
ret = self.get_partial_var(base)
if ret is None:
return
var, partial_types = ret
partial_type = get_partial_instance_type(var.type)
if partial_type is None or partial_type.value_type is None:
return
value_type = self.try_infer_partial_value_type_from_call(e, callee.name, var)
if value_type is not None:
# Infer key type.
key_type = self.accept(index)
if mypy.checker.is_valid_inferred_type(key_type):
# Store inferred partial type.
assert partial_type.type is not None
typename = partial_type.type.fullname
var.type = self.chk.named_generic_type(typename,
[key_type, value_type])
del partial_types[var]

def get_partial_var(self, ref: RefExpr) -> Optional[Tuple[Var, Dict[Var, Context]]]:
var = ref.node
if var is None and isinstance(ref, MemberExpr):
var = self.get_partial_self_var(ref)
if not isinstance(var, Var):
return None
partial_types = self.chk.find_partial_types(var)
if partial_types is None:
return None
return var, partial_types

def try_infer_partial_value_type_from_call(
self,
e: CallExpr,
methodname: str,
var: Var) -> Optional[Instance]:
"""Try to make partial type precise from a call such as 'x.append(y)'."""
if self.chk.current_node_deferred:
return None
partial_type = get_partial_instance_type(var.type)
if partial_type is None:
return None
if partial_type.value_type:
typename = partial_type.value_type.type.fullname
else:
assert partial_type.type is not None
typename = partial_type.type.fullname
# Sometimes we can infer a full type for a partial List, Dict or Set type.
# TODO: Don't infer argument expression twice.
if (typename in self.item_args and methodname in self.item_args[typename]
and e.arg_kinds == [ARG_POS]):
item_type = self.accept(e.args[0])
if mypy.checker.is_valid_inferred_type(item_type):
return self.chk.named_generic_type(typename, [item_type])
elif (typename in self.container_args
and methodname in self.container_args[typename]
and e.arg_kinds == [ARG_POS]):
arg_type = get_proper_type(self.accept(e.args[0]))
if isinstance(arg_type, Instance):
arg_typename = arg_type.type.fullname
if arg_typename in self.container_args[typename][methodname]:
if all(mypy.checker.is_valid_inferred_type(item_type)
for item_type in arg_type.args):
return self.chk.named_generic_type(typename,
list(arg_type.args))
return None

def apply_function_plugin(self,
callee: CallableType,
Expand Down Expand Up @@ -4299,3 +4348,9 @@ def is_operator_method(fullname: Optional[str]) -> bool:
short_name in nodes.op_methods.values() or
short_name in nodes.reverse_op_methods.values() or
short_name in nodes.unary_op_methods.values())


def get_partial_instance_type(t: Optional[Type]) -> Optional[PartialType]:
if t is None or not isinstance(t, PartialType) or t.type is None:
return None
return t
7 changes: 6 additions & 1 deletion mypy/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -1763,13 +1763,18 @@ class PartialType(ProperType):
# None for the 'None' partial type; otherwise a generic class
type = None # type: Optional[mypy.nodes.TypeInfo]
var = None # type: mypy.nodes.Var
# For partial defaultdict[K, V], the type V (K is unknown). If V is generic,
# the type argument is Any and will be replaced later.
value_type = None # type: Optional[Instance]

def __init__(self,
type: 'Optional[mypy.nodes.TypeInfo]',
var: 'mypy.nodes.Var') -> None:
var: 'mypy.nodes.Var',
value_type: 'Optional[Instance]' = None) -> None:
super().__init__()
self.type = type
self.var = var
self.value_type = value_type

def accept(self, visitor: 'TypeVisitor[T]') -> T:
return visitor.visit_partial_type(self)
Expand Down
3 changes: 3 additions & 0 deletions mypyc/test-data/fixtures/ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,9 @@ class ellipsis: pass
# Primitive types are special in generated code.

class int:
@overload
def __init__(self) -> None: pass
@overload
def __init__(self, x: object, base: int = 10) -> None: pass
def __add__(self, n: int) -> int: pass
def __sub__(self, n: int) -> int: pass
Expand Down
94 changes: 94 additions & 0 deletions test-data/unit/check-inference.test
Original file line number Diff line number Diff line change
Expand Up @@ -2976,3 +2976,97 @@ x: Optional[str]
y = filter(None, [x])
reveal_type(y) # N: Revealed type is 'builtins.list[builtins.str*]'
[builtins fixtures/list.pyi]

[case testPartialDefaultDict]
from collections import defaultdict
x = defaultdict(int)
x[''] = 1
reveal_type(x) # N: Revealed type is 'collections.defaultdict[builtins.str, builtins.int]'

y = defaultdict(int) # E: Need type annotation for 'y'

z = defaultdict(int) # E: Need type annotation for 'z'
z[''] = ''
reveal_type(z) # N: Revealed type is 'collections.defaultdict[Any, Any]'
[builtins fixtures/dict.pyi]

[case testPartialDefaultDictInconsistentValueTypes]
from collections import defaultdict
a = defaultdict(int) # E: Need type annotation for 'a'
a[''] = ''
a[''] = 1
reveal_type(a) # N: Revealed type is 'collections.defaultdict[builtins.str, builtins.int]'
[builtins fixtures/dict.pyi]

[case testPartialDefaultDictListValue]
# flags: --no-strict-optional
from collections import defaultdict
a = defaultdict(list)
a['x'].append(1)
reveal_type(a) # N: Revealed type is 'collections.defaultdict[builtins.str, builtins.list[builtins.int]]'

b = defaultdict(lambda: [])
b[1].append('x')
reveal_type(b) # N: Revealed type is 'collections.defaultdict[builtins.int, builtins.list[builtins.str]]'
[builtins fixtures/dict.pyi]

[case testPartialDefaultDictListValueStrictOptional]
# flags: --strict-optional
from collections import defaultdict
a = defaultdict(list)
a['x'].append(1)
reveal_type(a) # N: Revealed type is 'collections.defaultdict[builtins.str, builtins.list[builtins.int]]'

b = defaultdict(lambda: [])
b[1].append('x')
reveal_type(b) # N: Revealed type is 'collections.defaultdict[builtins.int, builtins.list[builtins.str]]'
[builtins fixtures/dict.pyi]

[case testPartialDefaultDictSpecialCases]
from collections import defaultdict
class A:
def f(self) -> None:
self.x = defaultdict(list)
self.x['x'].append(1)
reveal_type(self.x) # N: Revealed type is 'collections.defaultdict[builtins.str, builtins.list[builtins.int]]'
self.y = defaultdict(list) # E: Need type annotation for 'y'
s = self
s.y['x'].append(1)

x = {} # E: Need type annotation for 'x' (hint: "x: Dict[<type>, <type>] = ...")
x['x'].append(1)

y = defaultdict(list) # E: Need type annotation for 'y'
y[[]].append(1)
[builtins fixtures/dict.pyi]

[case testPartialDefaultDictSpecialCases2]
from collections import defaultdict

x = defaultdict(lambda: [1]) # E: Need type annotation for 'x'
x[1].append('') # E: Argument 1 to "append" of "list" has incompatible type "str"; expected "int"
reveal_type(x) # N: Revealed type is 'collections.defaultdict[Any, builtins.list[builtins.int]]'

xx = defaultdict(lambda: {'x': 1}) # E: Need type annotation for 'xx'
xx[1]['z'] = 3
reveal_type(xx) # N: Revealed type is 'collections.defaultdict[Any, builtins.dict[builtins.str, builtins.int]]'

y = defaultdict(dict) # E: Need type annotation for 'y'
y['x'][1] = [3]

z = defaultdict(int) # E: Need type annotation for 'z'
z[1].append('')
reveal_type(z) # N: Revealed type is 'collections.defaultdict[Any, Any]'
[builtins fixtures/dict.pyi]

[case testPartialDefaultDictSpecialCase3]
from collections import defaultdict

x = defaultdict(list)
x['a'] = [1, 2, 3]
reveal_type(x) # N: Revealed type is 'collections.defaultdict[builtins.str, builtins.list[builtins.int*]]'

y = defaultdict(list) # E: Need type annotation for 'y'
y['a'] = []
reveal_type(y) # N: Revealed type is 'collections.defaultdict[Any, Any]'
[builtins fixtures/dict.pyi]
1 change: 1 addition & 0 deletions test-data/unit/fixtures/dict.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ class list(Sequence[T]): # needed by some test cases
def __iter__(self) -> Iterator[T]: pass
def __mul__(self, x: int) -> list[T]: pass
def __contains__(self, item: object) -> bool: pass
def append(self, item: T) -> None: pass

class tuple(Generic[T]): pass
class function: pass
Expand Down
12 changes: 7 additions & 5 deletions test-data/unit/lib-stub/collections.pyi
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any, Iterable, Union, Optional, Dict, TypeVar
from typing import Any, Iterable, Union, Optional, Dict, TypeVar, overload, Optional, Callable

def namedtuple(
typename: str,
Expand All @@ -10,8 +10,10 @@ def namedtuple(
defaults: Optional[Iterable[Any]] = ...
) -> Any: ...

K = TypeVar('K')
V = TypeVar('V')
KT = TypeVar('KT')
VT = TypeVar('VT')

class OrderedDict(Dict[K, V]):
def __setitem__(self, k: K, v: V) -> None: ...
class OrderedDict(Dict[KT, VT]): ...

class defaultdict(Dict[KT, VT]):
def __init__(self, default_factory: Optional[Callable[[], VT]]) -> None: ...
6 changes: 3 additions & 3 deletions test-data/unit/python2eval.test
Original file line number Diff line number Diff line change
Expand Up @@ -420,11 +420,11 @@ if MYPY:
x = b'abc'
[out]

[case testNestedGenericFailedInference]
[case testDefaultDictInference]
from collections import defaultdict
def foo() -> None:
x = defaultdict(list) # type: ignore
x = defaultdict(list)
x['lol'].append(10)
reveal_type(x)
[out]
_testNestedGenericFailedInference.py:5: note: Revealed type is 'collections.defaultdict[Any, builtins.list[Any]]'
_testDefaultDictInference.py:5: note: Revealed type is 'collections.defaultdict[builtins.str, builtins.list[builtins.int]]'

0 comments on commit c1cd529

Please sign in to comment.