Skip to content

Commit

Permalink
compiler: Revamp FIndexed for correct reconstruction
Browse files Browse the repository at this point in the history
  • Loading branch information
FabioLuporini committed Apr 19, 2023
1 parent 8dc6527 commit 62414bb
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 10 deletions.
2 changes: 1 addition & 1 deletion devito/passes/iet/linearization.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,7 +256,7 @@ def _(f, indexeds, tracker, strides, sregistry):

if len(i.indices) == i.function.ndim:
v = tuple(strides.values())[-n:]
subs[i] = FIndexed(i, pname, strides=v)
subs[i] = FIndexed.from_indexed(i, pname, strides=v)
else:
# Honour custom indexing
subs[i] = i.base[sum(i.indices)]
Expand Down
23 changes: 15 additions & 8 deletions devito/types/misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,20 +68,21 @@ class FIndexed(Indexed, Pickable):
`uX[x*ny + y]`, where `X` is a string provided by the caller.
"""

__rargs__ = ('indexed', 'pname')
__rargs__ = ('base', '*indices')
__rkwargs__ = ('strides',)

def __new__(cls, indexed, pname, strides=None):
plabel = Symbol(name=pname, dtype=indexed.dtype)
base = IndexedData(plabel, None, function=indexed.function)
obj = super().__new__(cls, base, *indexed.indices)

obj.indexed = indexed
obj.pname = pname
def __new__(cls, base, *args, strides=None):
obj = super().__new__(cls, base, *args)
obj.strides = as_tuple(strides)

return obj

@classmethod
def from_indexed(cls, indexed, pname, strides=None):
label = Symbol(name=pname, dtype=indexed.dtype)
base = IndexedData(label, None, function=indexed.function)
return FIndexed(base, *indexed.indices, strides=strides)

def __repr__(self):
return "%s(%s)" % (self.name, ", ".join(str(i) for i in self.indices))

Expand All @@ -90,10 +91,16 @@ def __repr__(self):
def _hashable_content(self):
return super()._hashable_content() + (self.strides,)

func = Pickable._rebuild

@property
def name(self):
return self.function.name

@property
def pname(self):
return self.base.name

@property
def free_symbols(self):
# The functional representation of the FIndexed "hides" the strides, which
Expand Down
3 changes: 2 additions & 1 deletion tests/test_symbolics.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from devito.symbolics import (retrieve_functions, retrieve_indexed, evalrel, # noqa
CallFromPointer, Cast, FieldFromPointer,
FieldFromComposite, IntDiv, ccode, uxreplace)
from devito.types import Array, LocalObject, Object, Symbol as dSymbol
from devito.types import Array, FIndexed, LocalObject, Object, Symbol as dSymbol # noqa


def test_float_indices():
Expand Down Expand Up @@ -333,6 +333,7 @@ def test_solve_time():
('f[x, y+1]', '{f.indexed: g.indexed}', 'g[x, y+1]'),
('cos(f)', '{cos: sin}', 'sin(f)'),
('cos(f + sin(g))', '{cos: sin, sin: cos}', 'sin(f + cos(g))'),
('FIndexed(f.indexed, x, y)', '{x: 0}', 'FIndexed(f.indexed, 0, y)'),
])
def test_uxreplace(expr, subs, expected):
grid = Grid(shape=(4, 4))
Expand Down

0 comments on commit 62414bb

Please sign in to comment.