Skip to content

Commit

Permalink
api: process injected expression dimensions in case it's not the spar…
Browse files Browse the repository at this point in the history
…se function
  • Loading branch information
mloubout committed Sep 15, 2023
1 parent 76a3c4b commit 610eeb2
Show file tree
Hide file tree
Showing 10 changed files with 85 additions and 33 deletions.
7 changes: 6 additions & 1 deletion devito/builtins/initializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,12 @@ def assign(f, rhs=0, options=None, name='assign', assign_halo=False, **kwargs):
symbolic_max=d.symbolic_max + h.right)
eqs = [eq.xreplace(subs) for eq in eqs]

dv.Operator(eqs, name=name, **kwargs)()
op = dv.Operator(eqs, name=name, **kwargs)
try:
op()
except ValueError:
# Corner case such as assign(u, v) with v a Buffered TimeFunction
op(time_M=f._time_size)


def smooth(f, g, axis=None):
Expand Down
23 changes: 16 additions & 7 deletions devito/operations/interpolators.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,11 +169,17 @@ def _rdim(self):

return DimensionTuple(*rdims, getters=self._gdims)

def _augment_implicit_dims(self, implicit_dims):
def _augment_implicit_dims(self, implicit_dims, extras=None):
if extras is not None:
extra = set([i for v in extras for i in v.dimensions]) - set(self._gdims)
extra = tuple(extra)
else:
extra = tuple()

if self.sfunction._sparse_position == -1:
return self.sfunction.dimensions + as_tuple(implicit_dims)
return self.sfunction.dimensions + as_tuple(implicit_dims) + extra
else:
return as_tuple(implicit_dims) + self.sfunction.dimensions
return as_tuple(implicit_dims) + self.sfunction.dimensions + extra

def _coeff_temps(self, implicit_dims):
return []
Expand Down Expand Up @@ -252,8 +258,6 @@ def _interpolate(self, expr, increment=False, self_subs={}, implicit_dims=None):
interpolation expression, but that should be honored when constructing
the operator.
"""
implicit_dims = self._augment_implicit_dims(implicit_dims)

# Derivatives must be evaluated before the introduction of indirect accesses
try:
_expr = expr.evaluate
Expand All @@ -263,6 +267,9 @@ def _interpolate(self, expr, increment=False, self_subs={}, implicit_dims=None):

variables = list(retrieve_function_carriers(_expr))

# Implicit dimensions
implicit_dims = self._augment_implicit_dims(implicit_dims)

# List of indirection indices for all adjacent grid points
idx_subs, temps = self._interp_idx(variables, implicit_dims=implicit_dims)

Expand Down Expand Up @@ -295,8 +302,6 @@ def _inject(self, field, expr, implicit_dims=None):
injection expression, but that should be honored when constructing
the operator.
"""
implicit_dims = self._augment_implicit_dims(implicit_dims)

# Make iterable to support inject((u, v), expr=expr)
# or inject((u, v), expr=(expr1, expr2))
fields, exprs = as_tuple(field), as_tuple(expr)
Expand All @@ -315,6 +320,10 @@ def _inject(self, field, expr, implicit_dims=None):
_exprs = exprs

variables = list(v for e in _exprs for v in retrieve_function_carriers(e))

# Implicit dimensions
implicit_dims = self._augment_implicit_dims(implicit_dims, variables)

variables = variables + list(fields)

# List of indirection indices for all adjacent grid points
Expand Down
8 changes: 8 additions & 0 deletions devito/operator/operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -566,13 +566,18 @@ def _prepare_arguments(self, autotune=None, **kwargs):
"`%s=%s`, while `%s=%s` is expected. Perhaps you "
"forgot to override `%s`?" %
(p, k, v, k, args[k], p))

args = kwargs['args'] = args.reduce_all()

# DiscreteFunctions may be created from CartesianDiscretizations, which in
# turn could be Grids or SubDomains. Both may provide arguments
discretizations = {getattr(kwargs[p.name], 'grid', None) for p in overrides}
discretizations.update({getattr(p, 'grid', None) for p in defaults})
discretizations.discard(None)
# Remove subgrids if multiple grids
if len(discretizations) > 1:
discretizations = {g for g in discretizations
if not any(d.is_Derived for d in g.dimensions)}
for i in discretizations:
args.update(i._arg_values(**kwargs))

Expand All @@ -585,6 +590,9 @@ def _prepare_arguments(self, autotune=None, **kwargs):
if configuration['mpi']:
raise ValueError("Multiple Grids found")
try:
# Take biggest grid, i.e discard grids with subdimensions
grids = {g for g in grids if not any(d.is_Sub for d in g.dimensions)}
# First grid as there is no heuristic on how to choose from the leftovers
grid = grids.pop()
except KeyError:
grid = None
Expand Down
3 changes: 3 additions & 0 deletions devito/tools/data_structures.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,9 @@ def compare_to_first(v):
return candidates[0]
elif all(map(compare_to_first, candidates)):
# return first non-range
for c in candidates:
if not isinstance(c, range):
return c
return candidates[0]
else:
raise ValueError("Unable to find unique value for key %s, candidates: %s"
Expand Down
15 changes: 6 additions & 9 deletions devito/types/dimension.py
Original file line number Diff line number Diff line change
Expand Up @@ -298,14 +298,14 @@ def _arg_values(self, interval, grid=None, args=None, **kwargs):
# may represent sets of legal values. If that's the case, here we just
# pick one. Note that we sort for determinism
try:
loc_minv = loc_minv.start
loc_minv = loc_minv.stop
except AttributeError:
try:
loc_minv = sorted(loc_minv).pop(0)
except TypeError:
pass
try:
loc_maxv = loc_maxv.start
loc_maxv = loc_maxv.stop
except AttributeError:
try:
loc_maxv = sorted(loc_maxv).pop(0)
Expand Down Expand Up @@ -859,7 +859,8 @@ def _arg_defaults(self, _min=None, size=None, alias=None):
factor = defaults[dim._factor.name] = dim._factor.data
except AttributeError:
factor = dim._factor
defaults[dim.parent.max_name] = range(1, factor*(size))

defaults[dim.parent.max_name] = range(1, factor*size - 1)

return defaults

Expand Down Expand Up @@ -983,8 +984,7 @@ def bound_symbols(self):
return set(self.parent.bound_symbols)

def _arg_defaults(self, alias=None, **kwargs):
dim = alias or self
return {dim.parent.size_name: range(self.symbolic_size, np.iinfo(np.int64).max)}
return {}

def _arg_values(self, *args, **kwargs):
return {}
Expand Down Expand Up @@ -1466,10 +1466,7 @@ def _arg_defaults(self, _min=None, size=None, **kwargs):
A SteppingDimension does not know its max point and therefore
does not have a size argument.
"""
args = {self.parent.min_name: _min}
if size:
args[self.parent.size_name] = range(size-1, np.iinfo(np.int32).max)
return args
return {self.parent.min_name: _min}

def _arg_values(self, *args, **kwargs):
"""
Expand Down
2 changes: 1 addition & 1 deletion tests/test_buffering.py
Original file line number Diff line number Diff line change
Expand Up @@ -272,7 +272,7 @@ def test_over_injection():

# Check generated code
assert len(retrieve_iteration_tree(op1)) == \
7 + int(configuration['language'] != 'C')
8 + int(configuration['language'] != 'C')
buffers = [i for i in FindSymbols().visit(op1) if i.is_Array]
assert len(buffers) == 1

Expand Down
3 changes: 1 addition & 2 deletions tests/test_dimension.py
Original file line number Diff line number Diff line change
Expand Up @@ -1515,7 +1515,7 @@ def test_issue_1927(self, factor):

op = Operator(Eq(f, 1))

assert op.arguments()['time_M'] == 4*(save-1) # == min legal endpoint
assert op.arguments()['time_M'] == 4*save-1 # == min legal endpoint

# Also no issues when supplying an override
assert op.arguments(time_M=10)['time_M'] == 10
Expand All @@ -1530,7 +1530,6 @@ def test_issue_1927_v2(self):
i = Dimension(name='i')

ci = ConditionalDimension(name='ci', parent=i, factor=factor)

g = Function(name='g', shape=(size,), dimensions=(i,))
f = Function(name='f', shape=(int(size/factor),), dimensions=(ci,))

Expand Down
11 changes: 8 additions & 3 deletions tests/test_dle.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,9 +187,14 @@ def test_cache_blocking_structure_optrelax():

op = Operator(eqns, opt=('advanced', {'blockrelax': True}))

bns, _ = assert_blocking(op, {'x0_blk0', 'p_src0_blk0'})
bns, _ = assert_blocking(op, {'x0_blk0', 'p_src0_blk0', 'p_src1_blk0'})

iters = FindNodes(Iteration).visit(bns['p_src0_blk0'])
assert len(iters) == 2
assert iters[0].dim.is_Block
assert iters[1].dim.is_Block

iters = FindNodes(Iteration).visit(bns['p_src1_blk0'])
assert len(iters) == 5
assert iters[0].dim.is_Block
assert iters[1].dim.is_Block
Expand Down Expand Up @@ -286,7 +291,7 @@ def test_cache_blocking_structure_optrelax_prec_inject():
'openmp': True,
'par-collapse-ncores': 1}))

assert_structure(op, ['t', 't,p_s0_blk0,p_s,rsx,rsy'],
assert_structure(op, ['t', 't,p_s0_blk0,p_s', 't,p_s0_blk0,p_s,rsx,rsy'],
't,p_s0_blk0,p_s,rsx,rsy')


Expand Down Expand Up @@ -958,7 +963,7 @@ def test_parallel_prec_inject(self):
iterations = FindNodes(Iteration).visit(op0)

assert not iterations[0].pragmas
assert 'omp for collapse(2)' in iterations[1].pragmas[0].value
assert 'omp for' in iterations[1].pragmas[0].value


class TestNestedParallelism(object):
Expand Down
20 changes: 10 additions & 10 deletions tests/test_dse.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,9 +48,9 @@ def test_scheduling_after_rewrite():
trees = retrieve_iteration_tree(op)

# Check loop nest structure
assert all(i.dim is j for i, j in zip(trees[0], grid.dimensions)) # time invariant
assert trees[1].root.dim is grid.time_dim
assert all(trees[1].root.dim is tree.root.dim for tree in trees[1:])
assert all(i.dim is j for i, j in zip(trees[1], grid.dimensions)) # time invariant
assert trees[2].root.dim is grid.time_dim
assert all(trees[2].root.dim is tree.root.dim for tree in trees[2:])


@pytest.mark.parametrize('exprs,expected,min_cost', [
Expand Down Expand Up @@ -1687,7 +1687,7 @@ def test_drop_redundants_after_fusion(self, rotate):
op = Operator(eqns, opt=('advanced', {'cire-rotate': rotate}))

arrays = [i for i in FindSymbols().visit(op) if i.is_Array]
assert len(arrays) == 2
assert len(arrays) == 4
assert all(i._mem_heap and not i._mem_external for i in arrays)

def test_full_shape_big_temporaries(self):
Expand Down Expand Up @@ -2773,8 +2773,8 @@ def test_fullopt(self):
assert np.allclose(self.tti_noopt[1].data, rec.data, atol=10e-1)

# Check expected opcount/oi
assert summary[('section2', None)].ops == 92
assert np.isclose(summary[('section2', None)].oi, 2.074, atol=0.001)
assert summary[('section3', None)].ops == 92
assert np.isclose(summary[('section3', None)].oi, 2.074, atol=0.001)

# With optimizations enabled, there should be exactly four BlockDimensions
op = wavesolver.op_fwd()
Expand All @@ -2792,7 +2792,7 @@ def test_fullopt(self):
# 3 Arrays are defined globally for the sparse positions temporaries
# and two additional bock-sized Arrays are defined locally
arrays = [i for i in FindSymbols().visit(op) if i.is_Array]
extra_arrays = 2+3
extra_arrays = 2+3+3
assert len(arrays) == 4 + extra_arrays
assert all(i._mem_heap and not i._mem_external for i in arrays)
bns, pbs = assert_blocking(op, {'x0_blk0'})
Expand Down Expand Up @@ -2828,7 +2828,7 @@ def test_fullopt_w_mpi(self):
def test_opcounts(self, space_order, expected):
op = self.tti_operator(opt='advanced', space_order=space_order)
sections = list(op.op_fwd()._profiler._sections.values())
assert sections[2].sops == expected
assert sections[3].sops == expected

@switchconfig(profiling='advanced')
@pytest.mark.parametrize('space_order,expected', [
Expand All @@ -2838,8 +2838,8 @@ def test_opcounts_adjoint(self, space_order, expected):
wavesolver = self.tti_operator(opt=('advanced', {'openmp': False}))
op = wavesolver.op_adj()

assert op._profiler._sections['section2'].sops == expected
assert len([i for i in FindSymbols().visit(op) if i.is_Array]) == 7+3
assert op._profiler._sections['section3'].sops == expected
assert len([i for i in FindSymbols().visit(op) if i.is_Array]) == 7+3+3


class TestTTIv2(object):
Expand Down
26 changes: 26 additions & 0 deletions tests/test_interpolation.py
Original file line number Diff line number Diff line change
Expand Up @@ -734,3 +734,29 @@ class SparseFirst(SparseFunction):
op(time_M=10)
expected = 10*11/2 # n (n+1)/2
assert np.allclose(s.data, expected)


def test_inject_function():
nt = 11

grid = Grid(shape=(5, 5))
u = TimeFunction(name="u", grid=grid, time_order=2)
src = SparseTimeFunction(name="src", grid=grid, nt=nt, npoint=1,
coordinates=[[0.5, 0.5]])

nfreq = 5
freq_dim = DefaultDimension(name="freq", default_value=nfreq)
omega = Function(name="omega", dimensions=(freq_dim,), shape=(nfreq,), grid=grid)
omega.data.fill(1.)

inj = src.inject(field=u.forward, expr=omega)

op = Operator([inj])

op(time_M=0)
assert u.data[1, 2, 2] == nfreq
assert np.all(u.data[0] == 0)
assert np.all(u.data[2] == 0)
for i in [0, 1, 3, 4]:
for j in [0, 1, 3, 4]:
assert u.data[1, i, j] == 0

0 comments on commit 610eeb2

Please sign in to comment.