Skip to content

Commit

Permalink
Support matrix multiplication operator (@) (#2287)
Browse files Browse the repository at this point in the history
Closes #705.
  • Loading branch information
elazarg authored and gvanrossum committed Oct 20, 2016
1 parent 9a6cce0 commit 48fa2ef
Show file tree
Hide file tree
Showing 7 changed files with 43 additions and 7 deletions.
2 changes: 0 additions & 2 deletions mypy/fastparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,8 +157,6 @@ def from_operator(self, op: ast35.operator) -> str:
op_name = ASTConverter.op_map.get(type(op))
if op_name is None:
raise RuntimeError('Unknown operator ' + str(type(op)))
elif op_name == '@':
raise RuntimeError('mypy does not support the MatMult operator')
else:
return op_name

Expand Down
4 changes: 2 additions & 2 deletions mypy/lex.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,13 +195,13 @@ def lex(string: Union[str, bytes], first_line: int = 1,

# List of regular expressions that match non-alphabetical operators
operators = [re.compile('[-+*/<>.%&|^~]'),
re.compile('==|!=|<=|>=|\\*\\*|//|<<|>>|<>')]
re.compile('==|!=|<=|>=|\\*\\*|@|//|<<|>>|<>')]

# List of regular expressions that match punctuator tokens
punctuators = [re.compile('[=,()@`]|(->)'),
re.compile('\\['),
re.compile(']'),
re.compile('([-+*/%&|^]|\\*\\*|//|<<|>>)=')]
re.compile('([-+*/%@&|^]|\\*\\*|//|<<|>>)=')]


# Map single-character string escape sequences to corresponding characters.
Expand Down
4 changes: 3 additions & 1 deletion mypy/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -1328,6 +1328,7 @@ def accept(self, visitor: NodeVisitor[T]) -> T:
'%': '__mod__',
'//': '__floordiv__',
'**': '__pow__',
'@': '__matmul__',
'&': '__and__',
'|': '__or__',
'^': '__xor__',
Expand All @@ -1349,7 +1350,7 @@ def accept(self, visitor: NodeVisitor[T]) -> T:


ops_with_inplace_method = {
'+', '-', '*', '/', '%', '//', '**', '&', '|', '^', '<<', '>>'}
'+', '-', '*', '/', '%', '//', '**', '@', '&', '|', '^', '<<', '>>'}

inplace_operator_methods = set(
'__i' + op_methods[op][2:] for op in ops_with_inplace_method)
Expand All @@ -1362,6 +1363,7 @@ def accept(self, visitor: NodeVisitor[T]) -> T:
'__mod__': '__rmod__',
'__floordiv__': '__rfloordiv__',
'__pow__': '__rpow__',
'__matmul__': '__rmatmul__',
'__and__': '__rand__',
'__or__': '__ror__',
'__xor__': '__rxor__',
Expand Down
4 changes: 2 additions & 2 deletions mypy/parse.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@
'**': 16,
'-u': 15, '+u': 15, '~': 15, # unary operators (-, + and ~)
'<cast>': 14,
'*': 13, '/': 13, '//': 13, '%': 13,
'*': 13, '/': 13, '//': 13, '%': 13, '@': 13,
'+': 12, '-': 12,
'>>': 11, '<<': 11,
'&': 10,
Expand All @@ -61,7 +61,7 @@


op_assign = set([
'+=', '-=', '*=', '/=', '//=', '%=', '**=', '|=', '&=', '^=', '>>=',
'+=', '-=', '*=', '/=', '//=', '%=', '**=', '@=', '|=', '&=', '^=', '>>=',
'<<='])

op_comp = set([
Expand Down
15 changes: 15 additions & 0 deletions test-data/unit/check-expressions.test
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,21 @@ main:3: error: Unsupported operand types for * ("A" and "C")
main:4: error: Incompatible types in assignment (expression has type "C", variable has type "A")
main:5: error: Unsupported left operand type for * ("B")

[case testMatMul]
a, b, c = None, None, None # type: (A, B, C)
c = a @ c # E: Unsupported operand types for @ ("A" and "C")
a = a @ b # E: Incompatible types in assignment (expression has type "C", variable has type "A")
c = b @ a # E: Unsupported left operand type for @ ("B")
c = a @ b

class A:
def __matmul__(self, x: 'B') -> 'C':
pass
class B:
pass
class C:
pass

[case testDiv]

a, b, c = None, None, None # type: (A, B, C)
Expand Down
7 changes: 7 additions & 0 deletions test-data/unit/check-fastparse.test
Original file line number Diff line number Diff line change
Expand Up @@ -179,3 +179,10 @@ def f(a):
pass
[out]
main:3: error: invalid type comment

[case testFastParseMatMul]
# flags: --fast-parser
from typing import Any
x = None # type: Any
x @ 1
x @= 1
14 changes: 14 additions & 0 deletions test-data/unit/check-statements.test
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,17 @@ class C: pass
main:3: error: Unsupported operand types for * ("A" and "A")
main:4: error: Unsupported left operand type for * ("C")

[case testMatMulAssign]
a, c = None, None # type: (A, C)
a @= a # E: Unsupported operand types for @ ("A" and "A")
c @= a # E: Unsupported left operand type for @ ("C")
a @= c

class A:
def __matmul__(self, x: 'C') -> 'A': pass

class C: pass

[case testDivAssign]

a, c = None, None # type: (A, C)
Expand Down Expand Up @@ -295,11 +306,14 @@ import typing
class A:
def __iadd__(self, x: int) -> 'A': pass
def __imul__(self, x: str) -> 'A': pass
def __imatmul__(self, x: str) -> 'A': pass
a = A()
a += 1
a *= ''
a @= ''
a += '' # E: Argument 1 to "__iadd__" of "A" has incompatible type "str"; expected "int"
a *= 1 # E: Argument 1 to "__imul__" of "A" has incompatible type "int"; expected "str"
a @= 1 # E: Argument 1 to "__imatmul__" of "A" has incompatible type "int"; expected "str"

[case testInplaceSetitem]
class A(object):
Expand Down

0 comments on commit 48fa2ef

Please sign in to comment.