Skip to content

Commit

Permalink
Merge pull request #2276 from devitocodes/recipes-tweaks
Browse files Browse the repository at this point in the history
api: Minor fixes to arithmetic operations with scalar and tensors
  • Loading branch information
FabioLuporini committed Dec 7, 2023
2 parents 1bf0991 + 582c39c commit 310226e
Show file tree
Hide file tree
Showing 4 changed files with 18 additions and 3 deletions.
2 changes: 1 addition & 1 deletion devito/ir/support/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -267,7 +267,7 @@ def pull_dims(exprs, flag=True):
"""
dims = set()
for e in as_tuple(exprs):
dims.update({i for i in e.free_symbols if i.is_Dimension})
dims.update({i for i in e.free_symbols if isinstance(i, Dimension)})
if flag:
return set().union(*[d._defines for d in dims])
else:
Expand Down
2 changes: 1 addition & 1 deletion devito/passes/iet/misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,7 +219,7 @@ def abridge_dim_names(iet):
# Find SubDimensions or SubDimension-derived dimensions used as indices in
# the expression in the innermost loop
indexeds = FindSymbols('indexeds').visit(tree.inner)
dims = set().union(*[pull_dims(i, flag=False) for i in indexeds])
dims = pull_dims(indexeds, flag=False)
dims = [d for d in dims if any([dim.is_Sub for dim in d._defines])]
dims = [d for d in dims if not d.is_SubIterator]
names = [d.root.name for d in dims]
Expand Down
10 changes: 10 additions & 0 deletions devito/types/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import sympy

from sympy.core.assumptions import _assume_rules
from sympy.core.decorators import call_highest_priority
from cached_property import cached_property

from devito.data import default_allocator
Expand Down Expand Up @@ -704,6 +705,15 @@ def adjoint(self, inner=True):
# Real valued adjoint is transpose
return self.transpose(inner=inner)

@call_highest_priority('__radd__')
def __add__(self, other):
try:
# Most case support sympy add
return super().__add__(other)
except TypeError:
# Sympy doesn't support add with scalars
return self.applyfunc(lambda x: x + other)

def _eval_matrix_mul(self, other):
"""
Copy paste from sympy to avoid explicit call to sympy.Add
Expand Down
7 changes: 6 additions & 1 deletion devito/types/equation.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,12 @@ def _flatten(self):
"""
if self.lhs.is_Matrix:
# Maps the Equations to retrieve the rhs from relevant lhs
eqs = dict(zip(as_tuple(self.lhs), as_tuple(self.rhs)))
try:
eqs = dict(zip(self.lhs, self.rhs))
except TypeError:
# Same rhs for all lhs
assert not self.rhs.is_Matrix
eqs = {i: self.rhs for i in self.lhs}
# Get the relevant equations from the lhs structure. .values removes
# the symmetric duplicates and off-diagonal zeros.
lhss = self.lhs.values()
Expand Down

0 comments on commit 310226e

Please sign in to comment.