Skip to content

Commit

Permalink
Merge pull request #2112 from devitocodes/parlang-after-cire
Browse files Browse the repository at this point in the history
compiler: Improve lowering of IndexDerivatives
  • Loading branch information
FabioLuporini committed Jun 5, 2023
2 parents 51e11e5 + 8077f6e commit 5f1ff64
Show file tree
Hide file tree
Showing 14 changed files with 531 additions and 277 deletions.
18 changes: 5 additions & 13 deletions devito/ir/clusters/algorithms.py
Original file line number Diff line number Diff line change
Expand Up @@ -281,7 +281,7 @@ def callback(self, clusters, prefix):
mapper[size][si].add(iaf)

# Construct the ModuloDimensions
mds = OrderedDict()
mds = []
for size, v in mapper.items():
for si, iafs in list(v.items()):
# Offsets are sorted so that the semantic order (t0, t1, t2) follows
Expand All @@ -290,15 +290,10 @@ def callback(self, clusters, prefix):
# sorting offsets {-1, 0, 1} as {0, -1, 1} assigning -inf to 0
siafs = sorted(iafs, key=lambda i: -np.inf if i - si == 0 else (i - si))

# Create the ModuloDimensions. Note that if `size < len(iafs)` then
# the same ModuloDimension may be used for multiple offsets
for iaf in siafs[:size]:
for iaf in siafs:
name = '%s%d' % (si.name, len(mds))
offset = uxreplace(iaf, {si: d.root})
md = ModuloDimension(name, si, offset, size, origin=iaf)

key = lambda i: i.subs(si, 0) % size
mds[md] = [i for i in siafs if key(i) == key(iaf)]
mds.append(ModuloDimension(name, si, offset, size, origin=iaf))

# Replacement rule for ModuloDimensions
def rule(size, e):
Expand All @@ -320,11 +315,8 @@ def rule(size, e):
exprs = c.exprs
groups = as_mapper(mds, lambda d: d.modulo)
for size, v in groups.items():
mapper = {}
for md in v:
mapper.update({i: md for i in mds[md]})

func = partial(xreplace_indices, mapper=mapper, key=partial(rule, size))
subs = {md.origin: md for md in v}
func = partial(xreplace_indices, mapper=subs, key=partial(rule, size))
exprs = [e.apply(func) for e in exprs]

# Augment IterationSpace
Expand Down
13 changes: 11 additions & 2 deletions devito/ir/support/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -411,8 +411,12 @@ def distance(self, other):
# Indexed representing an arbitrary access along `x`, within the `t`
# IterationSpace, while the sink lives within the `tx` IterationSpace
if len(self.itintervals[n:]) != len(other.itintervals[n:]):
ret.append(S.Infinity)
return Vector(*ret)
v = Vector(*ret)
if v != 0:
return v
else:
ret.append(S.Infinity)
return Vector(*ret)

# It still could be an imaginary dependence, e.g. `a[3] -> a[4]` or, more
# nasty, `a[i+1, 3] -> a[i, 4]`
Expand Down Expand Up @@ -562,6 +566,11 @@ def is_lex_equal(self):
"""
return self.source.timestamp == self.sink.timestamp

@cached_property
def is_lex_ne(self):
    """
    Whether the timestamp of the source is not equal to the timestamp
    of the sink.
    """
    source_timestamp = self.source.timestamp
    sink_timestamp = self.sink.timestamp
    return source_timestamp != sink_timestamp

@cached_property
def is_lex_negative(self):
"""
Expand Down
5 changes: 4 additions & 1 deletion devito/operator/operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from devito.mpi import MPI
from devito.parameters import configuration
from devito.passes import (Graph, lower_index_derivatives, generate_implicit,
generate_macros, unevaluate)
generate_macros, minimize_symbols, unevaluate)
from devito.symbolics import estimate_cost
from devito.tools import (DAG, OrderedSet, Signer, ReducerMap, as_tuple, flatten,
filter_sorted, frozendict, is_integer, split, timed_pass,
Expand Down Expand Up @@ -458,6 +458,9 @@ def _lower_iet(cls, uiet, profiler=None, **kwargs):
# Extract the necessary macros from the symbolic objects
generate_macros(graph)

# Target-independent optimizations
minimize_symbols(graph)

return graph.root, graph

# Read-only properties exposed to the outside world
Expand Down
83 changes: 63 additions & 20 deletions devito/passes/clusters/derivatives.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from devito.finite_differences import IndexDerivative
from devito.ir import Interval, IterationSpace
from devito.ir import Interval, IterationSpace, Queue
from devito.passes.clusters.misc import fuse
from devito.symbolics import (retrieve_dimensions, reuse_if_untouched, q_leaf,
uxreplace)
Expand All @@ -11,43 +11,44 @@

@timed_pass()
def lower_index_derivatives(clusters, mode=None, **kwargs):
clusters, weights = _lower_index_derivatives(clusters, **kwargs)
clusters, weights, mapper = _lower_index_derivatives(clusters, **kwargs)

if not weights:
return clusters

if mode != 'noop':
clusters = fuse(clusters, toposort='maximal')

clusters = CDE(mapper).process(clusters)

return clusters


def _lower_index_derivatives(clusters, sregistry=None, **kwargs):
processed = []
weights = {}
processed = []
mapper = {}

def dump(exprs, c):
if exprs:
processed.append(c.rebuild(exprs=exprs))
exprs[:] = []

for c in clusters:

exprs = []
seen = {}
for e in c.exprs:
expr, v = _lower_index_derivatives_core(e, c, weights, seen, sregistry)
expr, v = _core(e, c, weights, mapper, sregistry)
if v:
dump(exprs, c)
processed.extend(v)
exprs.append(expr)
processed.extend(v)

dump(exprs, c)

return processed, weights
return processed, weights, mapper


def _lower_index_derivatives_core(expr, c, weights, seen, sregistry):
def _core(expr, c, weights, mapper, sregistry):
"""
Recursively carry out the core of `lower_index_derivatives`.
"""
Expand All @@ -57,7 +58,7 @@ def _lower_index_derivatives_core(expr, c, weights, seen, sregistry):
args = []
processed = []
for a in expr.args:
e, clusters = _lower_index_derivatives_core(a, c, weights, seen, sregistry)
e, clusters = _core(a, c, weights, mapper, sregistry)
args.append(e)
processed.extend(clusters)

Expand All @@ -76,12 +77,6 @@ def _lower_index_derivatives_core(expr, c, weights, seen, sregistry):
w = weights[k] = w0._rebuild(name=name)
expr = uxreplace(expr, {w0.indexed: w.indexed})

# Have I seen this IndexDerivative already?
try:
return seen[expr], []
except KeyError:
pass

dims = retrieve_dimensions(expr, deep=True)
dims = filter_ordered(d for d in dims if isinstance(d, StencilDimension))

Expand All @@ -91,7 +86,7 @@ def _lower_index_derivatives_core(expr, c, weights, seen, sregistry):
# upper and lower offsets, we honor it
dims = tuple(d for d in dims if d not in c.ispace)

intervals = [Interval(d, 0, 0) for d in dims]
intervals = [Interval(d) for d in dims]
ispace0 = IterationSpace(intervals)

extra = (c.ispace.itdimensions + dims,)
Expand All @@ -103,13 +98,61 @@ def _lower_index_derivatives_core(expr, c, weights, seen, sregistry):
ispace1 = ispace.project(lambda d: d is not dims[-1])
processed.insert(0, c.rebuild(exprs=expr0, ispace=ispace1))

# Track IndexDerivative to avoid intra-Cluster duplicates
seen[expr] = s

# Transform e.g. `w[i0] -> w[i0 + 2]` for alignment with the
# StencilDimensions starting points
subs = {expr.weights: expr.weights.subs(d, d - d._min) for d in dims}
expr1 = Inc(s, uxreplace(expr.expr, subs))
processed.append(c.rebuild(exprs=expr1, ispace=ispace))

# Track lowered IndexDerivative for subsequent optimization by the caller
mapper.setdefault(expr1.rhs, []).append(s)

return s, processed


class CDE(Queue):

    """
    Common derivative elimination.

    Parameters
    ----------
    mapper : dict
        Maps an expression to a sized collection of objects. Only the entries
        whose value holds more than one element are retained in `self.mapper`.
        NOTE(review): presumably the keys are the right-hand sides of the
        lowered IndexDerivatives and the values the symbols they are
        accumulated into -- confirm against the caller.
    """

    def __init__(self, mapper):
        super().__init__()

        # Retain exclusively the entries associated with two or more elements
        retained = {}
        for key, values in mapper.items():
            if len(values) > 1:
                retained[key] = values
        self.mapper = retained

    def process(self, clusters):
        """
        Run the pass over `clusters`, starting from an empty substitution
        mapper (`subs0`) and an empty set of already-visited Clusters (`seen`).
        """
        return self._process_fdta(clusters, 1, subs0={}, seen=set())

    def callback(self, clusters, prefix, subs0=None, seen=None):
        """
        Rebuild each Cluster not yet in `seen`, dropping the expressions that
        duplicate an expression tracked by `self.mapper`.

        `subs0` and `seen` are updated in place. Each expression is assumed to
        have exactly two arguments, unpacked below as `(lhs, rhs)`.
        """
        # Substitutions discovered during this invocation only: `rhs -> lhs`
        local = {}
        output = []
        for cluster in clusters:
            if cluster in seen:
                # Already handled by a previous invocation, keep it untouched
                output.append(cluster)
                continue

            kept = []
            for expr in cluster.exprs:
                lhs, rhs = expr.args

                # `lhs` was already scheduled for replacement, so the
                # expression is dropped
                if lhs in subs0:
                    continue

                # Same `rhs` as an expression retained earlier in this
                # invocation: record `lhs -> earlier lhs` and drop it
                if rhs in local:
                    subs0[lhs] = local[rhs]
                    continue

                if rhs in self.mapper:
                    # First occurrence of a tracked `rhs`: keep it as is
                    local[rhs] = lhs
                    kept.append(expr)
                else:
                    # Any other expression gets all known substitutions applied
                    kept.append(uxreplace(expr, {**subs0, **local}))

            output.append(cluster.rebuild(exprs=kept))

        seen.update(output)

        return output
63 changes: 47 additions & 16 deletions devito/passes/clusters/misc.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from collections import Counter, defaultdict
from itertools import groupby, product

from devito.finite_differences import IndexDerivative
from devito.ir.clusters import Cluster, ClusterGroup, Queue, cluster_pass
from devito.ir.support import (SEQUENTIAL, SEPARABLE, Scope, ReleaseLock,
WaitLock, WithLock, FetchUpdate, PrefetchUpdate)
Expand Down Expand Up @@ -145,19 +146,34 @@ def callback(self, cgroups, prefix):
else:
return [ClusterGroup(processed, prefix)]

def _key(self, c):
# Two Clusters/ClusterGroups are fusion candidates if their key is identical
class Key(tuple):

key = (frozenset(c.ispace.itintervals),)
"""
A fusion Key for a Cluster (ClusterGroup) is a hashable tuple such that
two Clusters (ClusterGroups) are topo-fusible if and only if their Key is
identical.
A Key contains several elements that can logically be split into two
groups -- the `strict` and the `weak` components of the Key.
Two Clusters (ClusterGroups) having same `strict` but different `weak` parts
are, as by definition, not fusible; however, since at least their `strict`
parts match, they can at least be topologically reordered.
"""

# If there are writes to thread-shared object, make it part of the key.
# This will promote fusion of non-adjacent Clusters writing to (some form of)
# shared memory, which in turn will minimize the number of necessary barriers
key += (any(f._mem_shared for f in c.scope.writes),)
# Same story for reads from thread-shared objects
key += (any(f._mem_shared for f in c.scope.reads),)
def __new__(cls, strict, weak):
obj = super().__new__(cls, strict + weak)
obj.strict = tuple(strict)
obj.weak = tuple(weak)

return obj

def _key(self, c):
strict = []

key += (c.guards if any(c.guards) else None,)
strict.extend([
frozenset(c.ispace.itintervals),
c.guards if any(c.guards) else None
])

# We allow fusing Clusters/ClusterGroups even in presence of WaitLocks and
# WithLocks, but not with any other SyncOps
Expand All @@ -180,13 +196,28 @@ def _key(self, c):
mapper[k].add(type(s))
else:
mapper[k].add(s)
mapper[k] = frozenset(mapper[k])
if any(mapper.values()):
mapper = frozendict(mapper)
key += (mapper,)
if k in mapper:
mapper[k] = frozenset(mapper[k])
strict.append(frozendict(mapper))

weak = []

# Clusters representing HaloTouches should get merged, if possible
key += (c.is_halo_touch,)
weak.append(c.is_halo_touch)

# If there are writes to thread-shared object, make it part of the key.
# This will promote fusion of non-adjacent Clusters writing to (some form of)
# shared memory, which in turn will minimize the number of necessary barriers
# Same story for reads from thread-shared objects
weak.extend([
any(f._mem_shared for f in c.scope.writes),
any(f._mem_shared for f in c.scope.reads)
])

# Promoting adjacency of IndexDerivatives will maximize their reuse
weak.append(any(e.find(IndexDerivative) for e in c.exprs))

key = self.Key(strict, weak)

return key

Expand Down Expand Up @@ -236,7 +267,7 @@ def dump():
def _toposort(self, cgroups, prefix):
# Are there any ClusterGroups that could potentially be fused? If
# not, do not waste time computing a new topological ordering
counter = Counter(self._key(cg) for cg in cgroups)
counter = Counter(self._key(cg).strict for cg in cgroups)
if not any(v > 1 for it, v in counter.most_common()):
return ClusterGroup(cgroups, prefix)

Expand Down
2 changes: 1 addition & 1 deletion devito/passes/iet/linearization.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,7 +259,7 @@ def _(f, indexeds, tracker, strides, sregistry):

if len(i.indices) == i.function.ndim:
v = tuple(strides.values())[-n:]
subs[i] = FIndexed(i, pname, strides=v)
subs[i] = FIndexed.from_indexed(i, pname, strides=v)
else:
# Honour custom indexing
subs[i] = i.base[sum(i.indices)]
Expand Down
Loading

0 comments on commit 5f1ff64

Please sign in to comment.