diff --git a/mypy/checker.py b/mypy/checker.py index 9a826cd41496..4b98ed2937ed 100644 --- a/mypy/checker.py +++ b/mypy/checker.py @@ -2813,7 +2813,8 @@ def infer_partial_type(self, name: Var, lvalue: Lvalue, init_type: Type) -> bool partial_type = PartialType(None, name) elif isinstance(init_type, Instance): fullname = init_type.type.fullname - if (isinstance(lvalue, (NameExpr, MemberExpr)) and + is_ref = isinstance(lvalue, RefExpr) + if (is_ref and (fullname == 'builtins.list' or fullname == 'builtins.set' or fullname == 'builtins.dict' or @@ -2821,6 +2822,17 @@ def infer_partial_type(self, name: Var, lvalue: Lvalue, init_type: Type) -> bool all(isinstance(t, (NoneType, UninhabitedType)) for t in get_proper_types(init_type.args))): partial_type = PartialType(init_type.type, name) + elif is_ref and fullname == 'collections.defaultdict': + arg0 = get_proper_type(init_type.args[0]) + arg1 = get_proper_type(init_type.args[1]) + if (isinstance(arg0, (NoneType, UninhabitedType)) and + isinstance(arg1, Instance) and + self.is_valid_defaultdict_partial_value_type(arg1)): + # Erase type argument, if one exists (this fills in Anys) + arg1 = self.named_type(arg1.type.fullname) + partial_type = PartialType(init_type.type, name, arg1) + else: + return False else: return False else: @@ -2829,6 +2841,28 @@ def infer_partial_type(self, name: Var, lvalue: Lvalue, init_type: Type) -> bool self.partial_types[-1].map[name] = lvalue return True + def is_valid_defaultdict_partial_value_type(self, t: Instance) -> bool: + """Check if t can be used as the basis for a partial defaultddict value type. + + Examples: + + * t is 'int' --> True + * t is 'list[]' --> True + * t is 'dict[...]' --> False (only generic types with a single type + argument supported) + """ + if len(t.args) == 0: + return True + if len(t.args) == 1: + arg = get_proper_type(t.args[0]) + # TODO: This is too permissive -- we only allow TypeVarType since + # they leak in cases like defaultdict(list) due to a bug. + # This can result in incorrect types being inferred, but only + # in rare cases. + if isinstance(arg, (TypeVarType, UninhabitedType, NoneType)): + return True + return False + def set_inferred_type(self, var: Var, lvalue: Lvalue, type: Type) -> None: """Store inferred variable type. @@ -3018,16 +3052,21 @@ def try_infer_partial_type_from_indexed_assignment( if partial_types is None: return typename = type_type.fullname - if typename == 'builtins.dict' or typename == 'collections.OrderedDict': + if (typename == 'builtins.dict' + or typename == 'collections.OrderedDict' + or typename == 'collections.defaultdict'): # TODO: Don't infer things twice. key_type = self.expr_checker.accept(lvalue.index) value_type = self.expr_checker.accept(rvalue) if (is_valid_inferred_type(key_type) and - is_valid_inferred_type(value_type)): - if not self.current_node_deferred: - var.type = self.named_generic_type(typename, - [key_type, value_type]) - del partial_types[var] + is_valid_inferred_type(value_type) and + not self.current_node_deferred and + not (typename == 'collections.defaultdict' and + var.type.value_type is not None and + not is_equivalent(value_type, var.type.value_type))): + var.type = self.named_generic_type(typename, + [key_type, value_type]) + del partial_types[var] def visit_expression_stmt(self, s: ExpressionStmt) -> None: self.expr_checker.accept(s.expr, allow_none_return=True, always_allow_any=True) diff --git a/mypy/checkexpr.py b/mypy/checkexpr.py index 35c58478ce1e..0e5b42abde0a 100644 --- a/mypy/checkexpr.py +++ b/mypy/checkexpr.py @@ -567,42 +567,91 @@ def get_partial_self_var(self, expr: MemberExpr) -> Optional[Var]: } # type: ClassVar[Dict[str, Dict[str, List[str]]]] def try_infer_partial_type(self, e: CallExpr) -> None: - if isinstance(e.callee, MemberExpr) and isinstance(e.callee.expr, RefExpr): - var = e.callee.expr.node - if var is None and isinstance(e.callee.expr, MemberExpr): - var = self.get_partial_self_var(e.callee.expr) - if not isinstance(var, Var): + """Try to make partial type precise from a call.""" + if not isinstance(e.callee, MemberExpr): + return + callee = e.callee + if isinstance(callee.expr, RefExpr): + # Call a method with a RefExpr callee, such as 'x.method(...)'. + ret = self.get_partial_var(callee.expr) + if ret is None: return - partial_types = self.chk.find_partial_types(var) - if partial_types is not None and not self.chk.current_node_deferred: - partial_type = var.type - if (partial_type is None or - not isinstance(partial_type, PartialType) or - partial_type.type is None): - # A partial None type -> can't infer anything. - return - typename = partial_type.type.fullname - methodname = e.callee.name - # Sometimes we can infer a full type for a partial List, Dict or Set type. - # TODO: Don't infer argument expression twice. - if (typename in self.item_args and methodname in self.item_args[typename] - and e.arg_kinds == [ARG_POS]): - item_type = self.accept(e.args[0]) - if mypy.checker.is_valid_inferred_type(item_type): - var.type = self.chk.named_generic_type(typename, [item_type]) - del partial_types[var] - elif (typename in self.container_args - and methodname in self.container_args[typename] - and e.arg_kinds == [ARG_POS]): - arg_type = get_proper_type(self.accept(e.args[0])) - if isinstance(arg_type, Instance): - arg_typename = arg_type.type.fullname - if arg_typename in self.container_args[typename][methodname]: - if all(mypy.checker.is_valid_inferred_type(item_type) - for item_type in arg_type.args): - var.type = self.chk.named_generic_type(typename, - list(arg_type.args)) - del partial_types[var] + var, partial_types = ret + typ = self.try_infer_partial_value_type_from_call(e, callee.name, var) + if typ is not None: + var.type = typ + del partial_types[var] + elif isinstance(callee.expr, IndexExpr) and isinstance(callee.expr.base, RefExpr): + # Call 'x[y].method(...)'; may infer type of 'x' if it's a partial defaultdict. + if callee.expr.analyzed is not None: + return # A special form + base = callee.expr.base + index = callee.expr.index + ret = self.get_partial_var(base) + if ret is None: + return + var, partial_types = ret + partial_type = get_partial_instance_type(var.type) + if partial_type is None or partial_type.value_type is None: + return + value_type = self.try_infer_partial_value_type_from_call(e, callee.name, var) + if value_type is not None: + # Infer key type. + key_type = self.accept(index) + if mypy.checker.is_valid_inferred_type(key_type): + # Store inferred partial type. + assert partial_type.type is not None + typename = partial_type.type.fullname + var.type = self.chk.named_generic_type(typename, + [key_type, value_type]) + del partial_types[var] + + def get_partial_var(self, ref: RefExpr) -> Optional[Tuple[Var, Dict[Var, Context]]]: + var = ref.node + if var is None and isinstance(ref, MemberExpr): + var = self.get_partial_self_var(ref) + if not isinstance(var, Var): + return None + partial_types = self.chk.find_partial_types(var) + if partial_types is None: + return None + return var, partial_types + + def try_infer_partial_value_type_from_call( + self, + e: CallExpr, + methodname: str, + var: Var) -> Optional[Instance]: + """Try to make partial type precise from a call such as 'x.append(y)'.""" + if self.chk.current_node_deferred: + return None + partial_type = get_partial_instance_type(var.type) + if partial_type is None: + return None + if partial_type.value_type: + typename = partial_type.value_type.type.fullname + else: + assert partial_type.type is not None + typename = partial_type.type.fullname + # Sometimes we can infer a full type for a partial List, Dict or Set type. + # TODO: Don't infer argument expression twice. + if (typename in self.item_args and methodname in self.item_args[typename] + and e.arg_kinds == [ARG_POS]): + item_type = self.accept(e.args[0]) + if mypy.checker.is_valid_inferred_type(item_type): + return self.chk.named_generic_type(typename, [item_type]) + elif (typename in self.container_args + and methodname in self.container_args[typename] + and e.arg_kinds == [ARG_POS]): + arg_type = get_proper_type(self.accept(e.args[0])) + if isinstance(arg_type, Instance): + arg_typename = arg_type.type.fullname + if arg_typename in self.container_args[typename][methodname]: + if all(mypy.checker.is_valid_inferred_type(item_type) + for item_type in arg_type.args): + return self.chk.named_generic_type(typename, + list(arg_type.args)) + return None def apply_function_plugin(self, callee: CallableType, @@ -4299,3 +4348,9 @@ def is_operator_method(fullname: Optional[str]) -> bool: short_name in nodes.op_methods.values() or short_name in nodes.reverse_op_methods.values() or short_name in nodes.unary_op_methods.values()) + + +def get_partial_instance_type(t: Optional[Type]) -> Optional[PartialType]: + if t is None or not isinstance(t, PartialType) or t.type is None: + return None + return t diff --git a/mypy/types.py b/mypy/types.py index ae678acedb3a..40b8d311d5cd 100644 --- a/mypy/types.py +++ b/mypy/types.py @@ -1763,13 +1763,18 @@ class PartialType(ProperType): # None for the 'None' partial type; otherwise a generic class type = None # type: Optional[mypy.nodes.TypeInfo] var = None # type: mypy.nodes.Var + # For partial defaultdict[K, V], the type V (K is unknown). If V is generic, + # the type argument is Any and will be replaced later. + value_type = None # type: Optional[Instance] def __init__(self, type: 'Optional[mypy.nodes.TypeInfo]', - var: 'mypy.nodes.Var') -> None: + var: 'mypy.nodes.Var', + value_type: 'Optional[Instance]' = None) -> None: super().__init__() self.type = type self.var = var + self.value_type = value_type def accept(self, visitor: 'TypeVisitor[T]') -> T: return visitor.visit_partial_type(self) diff --git a/mypyc/test-data/fixtures/ir.py b/mypyc/test-data/fixtures/ir.py index c7a1b35c7cbe..323800429522 100644 --- a/mypyc/test-data/fixtures/ir.py +++ b/mypyc/test-data/fixtures/ir.py @@ -26,6 +26,9 @@ class ellipsis: pass # Primitive types are special in generated code. class int: + @overload + def __init__(self) -> None: pass + @overload def __init__(self, x: object, base: int = 10) -> None: pass def __add__(self, n: int) -> int: pass def __sub__(self, n: int) -> int: pass diff --git a/test-data/unit/check-inference.test b/test-data/unit/check-inference.test index 91b3a93506f5..19d1554c5ef6 100644 --- a/test-data/unit/check-inference.test +++ b/test-data/unit/check-inference.test @@ -2976,3 +2976,97 @@ x: Optional[str] y = filter(None, [x]) reveal_type(y) # N: Revealed type is 'builtins.list[builtins.str*]' [builtins fixtures/list.pyi] + +[case testPartialDefaultDict] +from collections import defaultdict +x = defaultdict(int) +x[''] = 1 +reveal_type(x) # N: Revealed type is 'collections.defaultdict[builtins.str, builtins.int]' + +y = defaultdict(int) # E: Need type annotation for 'y' + +z = defaultdict(int) # E: Need type annotation for 'z' +z[''] = '' +reveal_type(z) # N: Revealed type is 'collections.defaultdict[Any, Any]' +[builtins fixtures/dict.pyi] + +[case testPartialDefaultDictInconsistentValueTypes] +from collections import defaultdict +a = defaultdict(int) # E: Need type annotation for 'a' +a[''] = '' +a[''] = 1 +reveal_type(a) # N: Revealed type is 'collections.defaultdict[builtins.str, builtins.int]' +[builtins fixtures/dict.pyi] + +[case testPartialDefaultDictListValue] +# flags: --no-strict-optional +from collections import defaultdict +a = defaultdict(list) +a['x'].append(1) +reveal_type(a) # N: Revealed type is 'collections.defaultdict[builtins.str, builtins.list[builtins.int]]' + +b = defaultdict(lambda: []) +b[1].append('x') +reveal_type(b) # N: Revealed type is 'collections.defaultdict[builtins.int, builtins.list[builtins.str]]' +[builtins fixtures/dict.pyi] + +[case testPartialDefaultDictListValueStrictOptional] +# flags: --strict-optional +from collections import defaultdict +a = defaultdict(list) +a['x'].append(1) +reveal_type(a) # N: Revealed type is 'collections.defaultdict[builtins.str, builtins.list[builtins.int]]' + +b = defaultdict(lambda: []) +b[1].append('x') +reveal_type(b) # N: Revealed type is 'collections.defaultdict[builtins.int, builtins.list[builtins.str]]' +[builtins fixtures/dict.pyi] + +[case testPartialDefaultDictSpecialCases] +from collections import defaultdict +class A: + def f(self) -> None: + self.x = defaultdict(list) + self.x['x'].append(1) + reveal_type(self.x) # N: Revealed type is 'collections.defaultdict[builtins.str, builtins.list[builtins.int]]' + self.y = defaultdict(list) # E: Need type annotation for 'y' + s = self + s.y['x'].append(1) + +x = {} # E: Need type annotation for 'x' (hint: "x: Dict[, ] = ...") +x['x'].append(1) + +y = defaultdict(list) # E: Need type annotation for 'y' +y[[]].append(1) +[builtins fixtures/dict.pyi] + +[case testPartialDefaultDictSpecialCases2] +from collections import defaultdict + +x = defaultdict(lambda: [1]) # E: Need type annotation for 'x' +x[1].append('') # E: Argument 1 to "append" of "list" has incompatible type "str"; expected "int" +reveal_type(x) # N: Revealed type is 'collections.defaultdict[Any, builtins.list[builtins.int]]' + +xx = defaultdict(lambda: {'x': 1}) # E: Need type annotation for 'xx' +xx[1]['z'] = 3 +reveal_type(xx) # N: Revealed type is 'collections.defaultdict[Any, builtins.dict[builtins.str, builtins.int]]' + +y = defaultdict(dict) # E: Need type annotation for 'y' +y['x'][1] = [3] + +z = defaultdict(int) # E: Need type annotation for 'z' +z[1].append('') +reveal_type(z) # N: Revealed type is 'collections.defaultdict[Any, Any]' +[builtins fixtures/dict.pyi] + +[case testPartialDefaultDictSpecialCase3] +from collections import defaultdict + +x = defaultdict(list) +x['a'] = [1, 2, 3] +reveal_type(x) # N: Revealed type is 'collections.defaultdict[builtins.str, builtins.list[builtins.int*]]' + +y = defaultdict(list) # E: Need type annotation for 'y' +y['a'] = [] +reveal_type(y) # N: Revealed type is 'collections.defaultdict[Any, Any]' +[builtins fixtures/dict.pyi] diff --git a/test-data/unit/fixtures/dict.pyi b/test-data/unit/fixtures/dict.pyi index 9e7970b34705..99c950d8fc9f 100644 --- a/test-data/unit/fixtures/dict.pyi +++ b/test-data/unit/fixtures/dict.pyi @@ -42,6 +42,7 @@ class list(Sequence[T]): # needed by some test cases def __iter__(self) -> Iterator[T]: pass def __mul__(self, x: int) -> list[T]: pass def __contains__(self, item: object) -> bool: pass + def append(self, item: T) -> None: pass class tuple(Generic[T]): pass class function: pass diff --git a/test-data/unit/lib-stub/collections.pyi b/test-data/unit/lib-stub/collections.pyi index c93fea198ebf..c5b5ef0504e6 100644 --- a/test-data/unit/lib-stub/collections.pyi +++ b/test-data/unit/lib-stub/collections.pyi @@ -1,4 +1,4 @@ -from typing import Any, Iterable, Union, Optional, Dict, TypeVar +from typing import Any, Iterable, Union, Optional, Dict, TypeVar, overload, Optional, Callable def namedtuple( typename: str, @@ -10,8 +10,10 @@ def namedtuple( defaults: Optional[Iterable[Any]] = ... ) -> Any: ... -K = TypeVar('K') -V = TypeVar('V') +KT = TypeVar('KT') +VT = TypeVar('VT') -class OrderedDict(Dict[K, V]): - def __setitem__(self, k: K, v: V) -> None: ... +class OrderedDict(Dict[KT, VT]): ... + +class defaultdict(Dict[KT, VT]): + def __init__(self, default_factory: Optional[Callable[[], VT]]) -> None: ... diff --git a/test-data/unit/python2eval.test b/test-data/unit/python2eval.test index 2267cadb1a08..93fe668a8b81 100644 --- a/test-data/unit/python2eval.test +++ b/test-data/unit/python2eval.test @@ -420,11 +420,11 @@ if MYPY: x = b'abc' [out] -[case testNestedGenericFailedInference] +[case testDefaultDictInference] from collections import defaultdict def foo() -> None: - x = defaultdict(list) # type: ignore + x = defaultdict(list) x['lol'].append(10) reveal_type(x) [out] -_testNestedGenericFailedInference.py:5: note: Revealed type is 'collections.defaultdict[Any, builtins.list[Any]]' +_testDefaultDictInference.py:5: note: Revealed type is 'collections.defaultdict[builtins.str, builtins.list[builtins.int]]'