Skip to content

Commit

Permalink
compiler: Add support for C-level MPI_Allreduce
Browse files Browse the repository at this point in the history
  • Loading branch information
FabioLuporini committed Apr 4, 2024
1 parent 23ff475 commit ef35a51
Show file tree
Hide file tree
Showing 13 changed files with 164 additions and 178 deletions.
60 changes: 33 additions & 27 deletions devito/builtins/arithmetic.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import numpy as np

import devito as dv
from devito.builtins.utils import MPIReduction
from devito.builtins.utils import make_retval


__all__ = ['norm', 'sumall', 'sum', 'inner', 'mmin', 'mmax']
Expand Down Expand Up @@ -44,15 +44,15 @@ def norm(f, order=2):
p, eqns = f.guard() if f.is_SparseFunction else (f, [])

dtype = accumulator_mapper[f.dtype]
n = make_retval(f.grid, dtype)
s = dv.types.Symbol(name='sum', dtype=dtype)

with MPIReduction(f, dtype=dtype) as mr:
op = dv.Operator([dv.Eq(s, 0.0)] + eqns +
[dv.Inc(s, dv.Abs(Pow(p, order))), dv.Eq(mr.n[0], s)],
name='norm%d' % order)
op.apply(**kwargs)
op = dv.Operator([dv.Eq(s, 0.0)] + eqns +
[dv.Inc(s, dv.Abs(Pow(p, order))), dv.Eq(n[0], s)],
name='norm%d' % order)
op.apply(**kwargs)

v = np.power(mr.v, 1/order)
v = np.power(n.data[0], 1/order)

return f.dtype(v)

Expand Down Expand Up @@ -129,15 +129,15 @@ def sumall(f):
p, eqns = f.guard() if f.is_SparseFunction else (f, [])

dtype = accumulator_mapper[f.dtype]
n = make_retval(f.grid, dtype)
s = dv.types.Symbol(name='sum', dtype=dtype)

with MPIReduction(f, dtype=dtype) as mr:
op = dv.Operator([dv.Eq(s, 0.0)] + eqns +
[dv.Inc(s, p), dv.Eq(mr.n[0], s)],
name='sum')
op.apply(**kwargs)
op = dv.Operator([dv.Eq(s, 0.0)] + eqns +
[dv.Inc(s, p), dv.Eq(n[0], s)],
name='sum')
op.apply(**kwargs)

return f.dtype(mr.v)
return f.dtype(n.data[0])


@dv.switchconfig(log_level='ERROR')
Expand Down Expand Up @@ -184,15 +184,15 @@ def inner(f, g):
rhs, eqns = f.guard(f*g) if f.is_SparseFunction else (f*g, [])

dtype = accumulator_mapper[f.dtype]
n = make_retval(f.grid or g.grid, dtype)
s = dv.types.Symbol(name='sum', dtype=dtype)

with MPIReduction(f, g, dtype=dtype) as mr:
op = dv.Operator([dv.Eq(s, 0.0)] + eqns +
[dv.Inc(s, rhs), dv.Eq(mr.n[0], s)],
name='inner')
op.apply(**kwargs)
op = dv.Operator([dv.Eq(s, 0.0)] + eqns +
[dv.Inc(s, rhs), dv.Eq(n[0], s)],
name='inner')
op.apply(**kwargs)

return f.dtype(mr.v)
return f.dtype(n.data[0])


@dv.switchconfig(log_level='ERROR')
Expand All @@ -208,11 +208,14 @@ def mmin(f):
if isinstance(f, dv.Constant):
return f.data
elif isinstance(f, dv.types.dense.DiscreteFunction):
with MPIReduction(f, op=dv.mpi.MPI.MIN) as mr:
mr.n.data[0] = np.min(f.data_ro_domain).item()
return mr.v.item()
v = np.min(f.data_ro_domain)
if f.grid is None or not dv.configuration['mpi']:
return v.item()
else:
comm = f.grid.distributor.comm
return comm.allreduce(v, dv.mpi.MPI.MIN).item()
else:
raise ValueError("Expected Function, not `%s`" % type(f))
raise ValueError("Expected Function, got `%s`" % type(f))


@dv.switchconfig(log_level='ERROR')
Expand All @@ -228,8 +231,11 @@ def mmax(f):
if isinstance(f, dv.Constant):
return f.data
elif isinstance(f, dv.types.dense.DiscreteFunction):
with MPIReduction(f, op=dv.mpi.MPI.MAX) as mr:
mr.n.data[0] = np.max(f.data_ro_domain).item()
return mr.v.item()
v = np.max(f.data_ro_domain)
if f.grid is None or not dv.configuration['mpi']:
return v.item()
else:
comm = f.grid.distributor.comm
return comm.allreduce(v, dv.mpi.MPI.MAX).item()
else:
raise ValueError("Expected Function, not `%s`" % type(f))
raise ValueError("Expected Function, got `%s`" % type(f))
52 changes: 13 additions & 39 deletions devito/builtins/utils.py
Original file line number Diff line number Diff line change
@@ -1,52 +1,26 @@
from functools import wraps

import numpy as np

import devito as dv
from devito.symbolics import uxreplace
from devito.tools import as_tuple

__all__ = ['MPIReduction', 'nbl_to_padsize', 'pad_outhalo', 'abstract_args']
__all__ = ['make_retval', 'nbl_to_padsize', 'pad_outhalo', 'abstract_args']


class MPIReduction(object):
def make_retval(grid, dtype):
"""
A context manager to build MPI-aware reduction Operators.
Devito does not support passing values by reference. This function
creates a dummy Function of size 1, defined over `grid` and with data
type `dtype`, to store the return value of a builtin.
"""

def __init__(self, *functions, op=dv.mpi.MPI.SUM, dtype=None):
grids = {f.grid for f in functions}
if len(grids) == 0:
self.grid = None
elif len(grids) == 1:
self.grid = grids.pop()
else:
raise ValueError("Multiple Grids found")
if dtype is not None:
self.dtype = dtype
else:
dtype = {f.dtype for f in functions}
if len(dtype) == 1:
self.dtype = np.result_type(dtype.pop(), np.float32).type
else:
raise ValueError("Illegal mixed data types")
self.v = None
self.op = op

def __enter__(self):
i = dv.Dimension(name='mri',)
self.n = dv.Function(name='n', shape=(1,), dimensions=(i,),
grid=self.grid, dtype=self.dtype, space='host')
self.n.data[:] = 0
return self

def __exit__(self, exc_type, exc_value, traceback):
if self.grid is None or not dv.configuration['mpi']:
assert self.n.data.size == 1
self.v = self.n.data[0]
else:
comm = self.grid.distributor.comm
self.v = comm.allreduce(np.asarray(self.n.data), self.op)[0]
if grid is None:
raise ValueError("Expected Grid, got None")

i = dv.Dimension(name='mri',)
n = dv.Function(name='n', shape=(1,), dimensions=(i,), grid=grid,
dtype=dtype, space='host')
n.data[:] = 0
return n


def nbl_to_padsize(nbl, ndim):
Expand Down
19 changes: 10 additions & 9 deletions devito/core/gpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,19 +116,20 @@ def _normalize_gpu_fit(cls, oo, **kwargs):
return as_tuple(cls.GPU_FIT)

@classmethod
def _rcompile_wrapper(cls, **kwargs0):
options = kwargs0['options']
def _rcompile_wrapper(cls, **kwargs):
def wrapper(expressions, mode='default', **options):

def wrapper(expressions, mode='default', **kwargs1):
if mode == 'host':
kwargs = {**{
par_disabled = kwargs['options']['par-disabled']
target = {
'platform': 'cpu64',
'language': 'C' if options['par-disabled'] else 'openmp',
'compiler': 'custom',
}, **kwargs1}
'language': 'C' if par_disabled else 'openmp',
'compiler': 'custom'
}
else:
kwargs = {**kwargs0, **kwargs1}
return rcompile(expressions, kwargs)
target = None

return rcompile(expressions, kwargs, options, target=target)

return wrapper

Expand Down
75 changes: 29 additions & 46 deletions devito/ir/clusters/algorithms.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,11 @@
from devito.ir.clusters.cluster import Cluster, ClusterGroup
from devito.ir.clusters.visitors import Queue, QueueStateful, cluster_pass
from devito.mpi.halo_scheme import HaloScheme, HaloTouch
from devito.mpi.reduction_scheme import DistributedReduction
from devito.mpi.reduction_scheme import DistReduce
from devito.symbolics import (limits_mapper, retrieve_indexed, uxreplace,
xreplace_indices)
from devito.tools import (DefaultOrderedDict, Stamp, as_mapper, flatten,
is_integer, timed_pass, toposort)
is_integer, split, timed_pass, toposort)
from devito.types import Array, Eq, Symbol
from devito.types.dimension import BOTTOM, ModuloDimension

Expand Down Expand Up @@ -378,26 +378,17 @@ def communications(clusters):
return clusters


class Comms(Queue):
#TODO: MAYBE DROP ME
class HaloComms(Queue):

"""
Abstract base class for injecting Clusters representing communications
for distributed-memory parallelism.
Inject Clusters representing halo exchanges for distributed-memory parallelism.
"""

_q_guards_in_key = True
_q_properties_in_key = True

B = Symbol(name='⊥')


class HaloComms(Comms):

"""
A specialization of Comms to handle halo exchanges.
"""

def process(self, clusters):
return self._process_fatd(clusters, 1, seen=set())

Expand Down Expand Up @@ -451,49 +442,41 @@ def callback(self, clusters, prefix, seen=None):


def reduction_comms(clusters):
# Detect the underlying Grid
#TODO: pretty rudimentary, but it's a start
for c in clusters:
try:
grid = c.grid
break
except ValueError:
continue
else:
return clusters

# Detect global reductions along the distributed Dimensions
found = {}
processed = []
fifo = []
for c in clusters:
if not any(grid.is_distributed(d) for d in c.ispace.itdims):
continue

# Schedule the global reductions encountered before `c`, if the
# IterationSpace of `c` is such that the reduction can be carried out
found, fifo = split(fifo, lambda dr: dr.ispace.is_subset(c.ispace))
if found:
exprs = [Eq(dr.var, dr) for dr in found]
processed.append(c.rebuild(exprs=exprs))

# Detect the global reductions in `c`
for e in c.exprs:
op = e.operation
if op is None:
if op is None or c.is_sparse:
continue
elif found.get(e.lhs, op) != op:
raise ValueError("Inconsistent reduction operations")
else:
found[e.lhs] = e.operation

# Place global reductions right before they're required
processed = []
for c in clusters:
for var, op in list(found.items()):
if var in c.scope.read_only:
expr = Eq(var, DistributedReduction(var, op=op, grid=grid))
processed.append(c.rebuild(exprs=expr))
var = e.lhs
grid = c.grid
if grid is None:
continue

# The IterationSpace within which the global reduction is carried out
ispace = c.ispace.project(lambda d: d in var.free_symbols)
if ispace.itdims == c.ispace.itdims:
# Inc/Max/Min/... being used for a non-reduction operation
continue

found.pop(var)
fifo.append(DistReduce(var, op=op, grid=grid, ispace=ispace))

processed.append(c)

# Leftover reductions are placed at the very end
while found:
var, op = found.popitem()
expr = Eq(var, DistributedReduction(var, op=op, grid=grid))
processed.append(Cluster(exprs=[expr], ispace=null_ispace))
if fifo:
exprs = [Eq(dr.var, dr) for dr in fifo]
processed.append(Cluster(exprs=exprs, ispace=null_ispace))

return processed

Expand Down
14 changes: 8 additions & 6 deletions devito/ir/clusters/cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
normalize_properties, normalize_syncs, minimum,
maximum, null_ispace)
from devito.mpi.halo_scheme import HaloScheme, HaloTouch
from devito.mpi.reduction_scheme import DistributedReduction
from devito.mpi.reduction_scheme import DistReduce
from devito.symbolics import estimate_cost
from devito.tools import as_tuple, flatten, frozendict, infer_dtype
from devito.types import WeakFence, CriticalRegion
Expand Down Expand Up @@ -181,8 +181,11 @@ def has_increments(self):

@cached_property
def grid(self):
grids = set(f.grid for f in self.functions if f.is_DiscreteFunction) - {None}
if len(grids) == 1:
grids = set(f.grid for f in self.functions if f.is_AbstractFunction)
grids.discard(None)
if len(grids) == 0:
return None
elif len(grids) == 1:
return grids.pop()
else:
raise ValueError("Cluster has no unique Grid")
Expand Down Expand Up @@ -211,7 +214,7 @@ def is_dense(self):
dims = {d for d in self.properties if d._defines & target}
if any(pset & self.properties[d] for d in dims):
return True
except ValueError:
except (AttributeError, ValueError):
pass

# Fallback to legacy is_dense checks
Expand Down Expand Up @@ -241,8 +244,7 @@ def is_halo_touch(self):

@property
def is_dist_reduce(self):
return self.exprs and all(isinstance(e.rhs, DistributedReduction)
for e in self.exprs)
return self.exprs and all(isinstance(e.rhs, DistReduce) for e in self.exprs)

@property
def is_fence(self):
Expand Down
Loading

0 comments on commit ef35a51

Please sign in to comment.