Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

compiler: Improve lowering of IndexDerivatives #2112

Merged
merged 9 commits into from
Jun 5, 2023
Prev Previous commit
Next Next commit
compiler: Improve unexpansion
  • Loading branch information
FabioLuporini committed May 31, 2023
commit 32fe68d183b4e1d05ab4c21c8623026968e4353f
8 changes: 6 additions & 2 deletions devito/ir/support/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -411,8 +411,12 @@ def distance(self, other):
# Indexed representing an arbitrary access along `x`, within the `t`
# IterationSpace, while the sink lives within the `tx` IterationSpace
if len(self.itintervals[n:]) != len(other.itintervals[n:]):
ret.append(S.Infinity)
return Vector(*ret)
v = Vector(*ret)
if v != 0:
return v
else:
ret.append(S.Infinity)
return Vector(*ret)

# It still could be an imaginary dependence, e.g. `a[3] -> a[4]` or, more
# nasty, `a[i+1, 3] -> a[i, 4]`
Expand Down
96 changes: 63 additions & 33 deletions devito/passes/clusters/derivatives.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from devito.finite_differences import IndexDerivative
from devito.ir import Interval, IterationSpace
from devito.ir import Interval, IterationSpace, Queue
from devito.passes.clusters.misc import fuse
from devito.symbolics import (retrieve_dimensions, reuse_if_untouched, q_leaf,
uxreplace)
Expand All @@ -11,52 +11,44 @@

@timed_pass()
def lower_index_derivatives(clusters, mode=None, **kwargs):
clusters, weights = _lower_index_derivatives(clusters, **kwargs)
clusters, weights, mapper = _lower_index_derivatives(clusters, **kwargs)

if not weights:
return clusters

if mode != 'noop':
clusters = fuse(clusters, toposort='maximal')

clusters = CDE(mapper).process(clusters)

return clusters


def _lower_index_derivatives(clusters, sregistry=None, **kwargs):
processed = []
weights = {}
processed = []
mapper = {}

def dump(exprs, c):
if exprs:
processed.append(c.rebuild(exprs=exprs))
exprs[:] = []

for c in clusters:
# Can I reuse common IndexDerivatives popping up in different exprs
# within `c`?
# NOTE: this could be refined to rather identify groups of consecutive
# exprs sharing IndexDerivatives, but it's in practice an overkill, since
# we only end up here in artificial cases
unreusable = any(d.is_indep() and d.is_lex_ne for d in c.scope.d_all_gen())

exprs = []
seen = {}
for e in c.exprs:
expr, v = _lower_index_derivatives_core(e, c, weights, seen, sregistry)
if v and unreusable:
expr, v = _core(e, c, weights, mapper, sregistry)
if v:
dump(exprs, c)
processed.extend(v)
exprs.append(expr)
processed.extend(v)

if unreusable:
seen = {}

dump(exprs, c)

return processed, weights
return processed, weights, mapper


def _lower_index_derivatives_core(expr, c, weights, seen, sregistry):
def _core(expr, c, weights, mapper, sregistry):
"""
Recursively carry out the core of `lower_index_derivatives`.
"""
Expand All @@ -66,7 +58,7 @@ def _lower_index_derivatives_core(expr, c, weights, seen, sregistry):
args = []
processed = []
for a in expr.args:
e, clusters = _lower_index_derivatives_core(a, c, weights, seen, sregistry)
e, clusters = _core(a, c, weights, mapper, sregistry)
args.append(e)
processed.extend(clusters)

Expand All @@ -85,12 +77,6 @@ def _lower_index_derivatives_core(expr, c, weights, seen, sregistry):
w = weights[k] = w0._rebuild(name=name)
expr = uxreplace(expr, {w0.indexed: w.indexed})

# Have I seen this IndexDerivative already?
try:
return seen[expr], []
except (KeyError, TypeError):
pass

dims = retrieve_dimensions(expr, deep=True)
dims = filter_ordered(d for d in dims if isinstance(d, StencilDimension))

Expand All @@ -100,7 +86,7 @@ def _lower_index_derivatives_core(expr, c, weights, seen, sregistry):
# upper and lower offsets, we honor it
dims = tuple(d for d in dims if d not in c.ispace)

intervals = [Interval(d, 0, 0) for d in dims]
intervals = [Interval(d) for d in dims]
ispace0 = IterationSpace(intervals)

extra = (c.ispace.itdimensions + dims,)
Expand All @@ -112,16 +98,60 @@ def _lower_index_derivatives_core(expr, c, weights, seen, sregistry):
ispace1 = ispace.project(lambda d: d is not dims[-1])
processed.insert(0, c.rebuild(exprs=expr0, ispace=ispace1))

# Track IndexDerivative to avoid intra-Cluster duplicates
try:
seen[expr] = s
except TypeError:
pass

# Transform e.g. `w[i0] -> w[i0 + 2]` for alignment with the
# StencilDimensions starting points
subs = {expr.weights: expr.weights.subs(d, d - d._min) for d in dims}
expr1 = Inc(s, uxreplace(expr.expr, subs))
processed.append(c.rebuild(exprs=expr1, ispace=ispace))

# Track lowered IndexDerivative for subsequent optimization by the caller
mapper.setdefault(expr1.rhs, []).append(s)

return s, processed


class CDE(Queue):

    """
    Common derivative elimination.

    Expressions whose second argument duplicates one that has already been
    encountered are dropped, and the symbol they would have defined is
    aliased to the symbol of the first occurrence.
    """

    def __init__(self, mapper):
        """
        Parameters
        ----------
        mapper : dict
            Maps an expression to the list of symbols that were generated
            for it. Only the entries having more than one symbol, i.e. the
            ones that appear multiple times, are retained.
        """
        super().__init__()

        duplicates = {}
        for rhs, symbols in mapper.items():
            if len(symbols) > 1:
                duplicates[rhs] = symbols
        self.mapper = duplicates

    def process(self, clusters):
        # A single substitution table and a single set of visited Clusters
        # are created here and handed over to the traversal
        subs = {}
        seen = set()
        return self._process_fdta(clusters, 1, subs=subs, seen=seen)

    def callback(self, clusters, prefix, subs=None, seen=None):
        """
        Rebuild `clusters` with the duplicated expressions removed.

        Parameters
        ----------
        clusters : iterable
            The Clusters to be processed.
        prefix
            Unused by this callback.
        subs : dict
            Substitution table, updated in place. It maps a duplicated
            expression to the first symbol found computing it, as well as
            each dropped symbol to the symbol replacing it.
        seen : set
            The Clusters already emitted by this callback, updated in place.

        Returns
        -------
        list
            The Clusters in `seen` unchanged, all others rebuilt.
        """
        output = []
        for cluster in clusters:
            # Emitted by an earlier invocation, nothing left to do
            if cluster in seen:
                output.append(cluster)
                continue

            kept = []
            for expr in cluster.exprs:
                # NOTE(review): presumably (lhs, rhs) of an Eq-like object;
                # confirm against the callers
                target, value = expr.args

                # `target` is already a key of `subs`: drop the expression
                if target in subs:
                    continue

                # `value` is already tracked: alias `target` to whatever
                # `value` maps to and drop the expression
                try:
                    alias = subs[value]
                except KeyError:
                    pass
                else:
                    subs[target] = alias
                    continue

                if value in self.mapper:
                    # First occurrence of a duplicated expression: keep it
                    # as is and remember which symbol holds its value
                    subs[value] = target
                    kept.append(expr)
                else:
                    # Any other expression is rewritten through `uxreplace`
                    # with the substitutions accumulated so far
                    kept.append(uxreplace(expr, subs))

            output.append(cluster.rebuild(exprs=kept))

        seen.update(output)

        return output
4 changes: 4 additions & 0 deletions devito/passes/clusters/misc.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from collections import Counter, defaultdict
from itertools import groupby, product

from devito.finite_differences import IndexDerivative
from devito.ir.clusters import Cluster, ClusterGroup, Queue, cluster_pass
from devito.ir.support import (SEQUENTIAL, SEPARABLE, Scope, ReleaseLock,
WaitLock, WithLock, FetchUpdate, PrefetchUpdate)
Expand Down Expand Up @@ -188,6 +189,9 @@ def _key(self, c):
# Clusters representing HaloTouches should get merged, if possible
key += (c.is_halo_touch,)

# Promoting adjacency of IndexDerivatives will maximize their reuse
key += (any(e.find(IndexDerivative) for e in c.exprs),)

return key

def _apply_heuristics(self, clusters):
Expand Down
9 changes: 2 additions & 7 deletions tests/test_unexpansion.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,7 +215,6 @@ def test_v6(self):
grid = Grid(shape=(16, 16))

f = Function(name='f', grid=grid, space_order=4)
g = Function(name='g', grid=grid, space_order=4)
p0 = TimeFunction(name='p0', grid=grid, time_order=2, space_order=4,
save=Buffer(2))
m0 = TimeFunction(name='m0', grid=grid, time_order=2, space_order=4,
Expand All @@ -231,11 +230,7 @@ def test_v6(self):
opt=('advanced', {'expand': False}))

# Check code generation
assert op._profiler._sections['section0'].sops == 183
assert_structure(
op,
['t,x,y', 't,x,y,i1', 't,x,y,i1,i0', 't,x,y,i1', 't,x,y,i1,i0'],
't,x,y,i1,i0,i1,i0'
)
assert op._profiler._sections['section0'].sops == 133
assert_structure(op, ['t,x,y', 't,x,y,i1', 't,x,y,i1,i0'], 't,x,y,i1,i0')

op.cfunction