Skip to content

Commit

Permalink
compiler: Refine topo-fusion
Browse files Browse the repository at this point in the history
  • Loading branch information
FabioLuporini committed Apr 18, 2023
1 parent 73fa6c8 commit 959cbb4
Showing 1 changed file with 42 additions and 16 deletions.
58 changes: 42 additions & 16 deletions devito/passes/clusters/misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,19 +146,34 @@ def callback(self, cgroups, prefix):
else:
return [ClusterGroup(processed, prefix)]

def _key(self, c):
# Two Clusters/ClusterGroups are fusion candidates if their key is identical
class Key(tuple):

key = (frozenset(c.ispace.itintervals),)
"""
A fusion Key for a Cluster (ClusterGroup) is a hashable tuple such that
two Clusters (ClusterGroups) are topo-fusible if and only if their Key is
identical.
A Key contains several elements that can logically be split into two
groups -- the `strict` and the `weak` components of the Key.
Two Clusters (ClusterGroups) having the same `strict` but different `weak`
parts are, by definition, not fusible; however, since their `strict` parts
match, they can at least be topologically reordered.
"""

# If there are writes to thread-shared object, make it part of the key.
# This will promote fusion of non-adjacent Clusters writing to (some form of)
# shared memory, which in turn will minimize the number of necessary barriers
key += (any(f._mem_shared for f in c.scope.writes),)
# Same story for reads from thread-shared objects
key += (any(f._mem_shared for f in c.scope.reads),)
def __new__(cls, strict, weak):
obj = super().__new__(cls, strict + weak)
obj.strict = tuple(strict)
obj.weak = tuple(weak)

key += (c.guards if any(c.guards) else None,)
return obj

def _key(self, c):
strict = []

strict.extend([
frozenset(c.ispace.itintervals),
c.guards if any(c.guards) else None
])

# We allow fusing Clusters/ClusterGroups even in presence of WaitLocks and
# WithLocks, but not with any other SyncOps
Expand All @@ -182,15 +197,26 @@ def _key(self, c):
else:
mapper[k].add(s)
mapper[k] = frozenset(mapper[k])
if any(mapper.values()):
mapper = frozendict(mapper)
key += (mapper,)
strict.append(frozendict(mapper))

weak = []

# Clusters representing HaloTouches should get merged, if possible
key += (c.is_halo_touch,)
weak.append(c.is_halo_touch)

# If there are writes to thread-shared objects, make that part of the key.
# This will promote fusion of non-adjacent Clusters writing to (some form of)
# shared memory, which in turn will minimize the number of necessary barriers
# Same story for reads from thread-shared objects
weak.extend([
any(f._mem_shared for f in c.scope.writes),
any(f._mem_shared for f in c.scope.reads)
])

# Promoting adjacency of IndexDerivatives will maximize their reuse
key += (any(e.find(IndexDerivative) for e in c.exprs),)
weak.append(any(e.find(IndexDerivative) for e in c.exprs))

key = self.Key(strict, weak)

return key

Expand Down Expand Up @@ -240,7 +266,7 @@ def dump():
def _toposort(self, cgroups, prefix):
# Are there any ClusterGroups that could potentially be fused? If
# not, do not waste time computing a new topological ordering
counter = Counter(self._key(cg) for cg in cgroups)
counter = Counter(self._key(cg).strict for cg in cgroups)
if not any(v > 1 for it, v in counter.most_common()):
return ClusterGroup(cgroups, prefix)

Expand Down

0 comments on commit 959cbb4

Please sign in to comment.