Skip to content

Commit

Permalink
Merge pull request #2300 from devitocodes/sycl-moar
Browse files Browse the repository at this point in the history
compiler: Misc codegen enhancements
  • Loading branch information
FabioLuporini committed Jan 31, 2024
2 parents 436ffa2 + 446c97d commit 0ccf5fd
Show file tree
Hide file tree
Showing 7 changed files with 102 additions and 24 deletions.
13 changes: 9 additions & 4 deletions devito/arch/archinfo.py
Original file line number Diff line number Diff line change
Expand Up @@ -808,6 +808,11 @@ class IntelDevice(Device):

max_mem_trans_nbytes = 64

def __init__(self, *args, sub_group_size=32, **kwargs):
super().__init__(*args, **kwargs)

self.sub_group_size = sub_group_size

@property
def march(self):
return ''
Expand Down Expand Up @@ -894,10 +899,10 @@ def march(cls):
AMDGPUX = AmdDevice('amdgpuX')

INTELGPUX = IntelDevice('intelgpuX')
PVC = IntelDevice('pvc', max_threads_per_block=4096) # Legacy codename for MAX GPUs
INTELGPUMAX = IntelDevice('intelgpuMAX', max_threads_per_block=4096)
MAX1100 = IntelDevice('max1100', max_threads_per_block=4096)
MAX1550 = IntelDevice('max1550', max_threads_per_block=4096)
PVC = IntelDevice('pvc') # Legacy codename for MAX GPUs
INTELGPUMAX = IntelDevice('intelgpuMAX')
MAX1100 = IntelDevice('max1100')
MAX1550 = IntelDevice('max1550')

platform_registry = Platform.registry
platform_registry['cpu64'] = get_platform # Autodetection
Expand Down
50 changes: 40 additions & 10 deletions devito/ir/iet/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,14 @@
Symbol)
from devito.types.object import AbstractObject, LocalObject

__all__ = ['Node', 'Block', 'Expression', 'Callable', 'Call', 'ExprStmt',
'Conditional', 'Iteration', 'List', 'Section', 'TimedList', 'Prodder',
'MetaCall', 'PointerCast', 'HaloSpot', 'Definition', 'ExpressionBundle',
'AugmentedExpression', 'Increment', 'Return', 'While', 'ListMajor',
'ParallelIteration', 'ParallelBlock', 'Dereference', 'Lambda',
'SyncSpot', 'Pragma', 'DummyExpr', 'BlankLine', 'ParallelTree',
'BusyWait', 'CallableBody', 'Transfer']
__all__ = ['Node', 'MultiTraversable', 'Block', 'Expression', 'Callable',
'Call', 'ExprStmt', 'Conditional', 'Iteration', 'List', 'Section',
'TimedList', 'Prodder', 'MetaCall', 'PointerCast', 'HaloSpot',
'Definition', 'ExpressionBundle', 'AugmentedExpression',
'Increment', 'Return', 'While', 'ListMajor', 'ParallelIteration',
'ParallelBlock', 'Dereference', 'Lambda', 'SyncSpot', 'Pragma',
'DummyExpr', 'BlankLine', 'ParallelTree', 'BusyWait', 'UsingNamespace',
'CallableBody', 'Transfer']

# First-class IET nodes

Expand Down Expand Up @@ -175,6 +176,15 @@ class ExprStmt(object):
pass


class MultiTraversable(Node):

"""
An abstract base class for Nodes comprising more than one traversable children.
"""

pass


class List(Node):

"""A sequence of Nodes."""
Expand Down Expand Up @@ -740,7 +750,7 @@ def defines(self):
return self.all_parameters


class CallableBody(Node):
class CallableBody(MultiTraversable):

"""
The immediate child of a Callable.
Expand Down Expand Up @@ -1057,7 +1067,7 @@ class Lambda(Node):
A callable C++ lambda function. Several syntaxes are possible; here we
implement one of the common ones:
[captures](parameters){body}
[captures](parameters){body} SPECIAL [[attributes]]
For more info about C++ lambda functions:
Expand All @@ -1071,14 +1081,21 @@ class Lambda(Node):
The captures of the lambda function.
parameters : list of Basic or expr-like, optional
The objects in input to the lambda function.
special : list of Basic, optional
Placeholder for custom lambdas, to add in e.g. macros.
attributes : list of str, optional
The attributes of the lambda function.
"""

_traversable = ['body']

def __init__(self, body, captures=None, parameters=None):
def __init__(self, body, captures=None, parameters=None, special=None,
attributes=None):
self.body = as_tuple(body)
self.captures = as_tuple(captures)
self.parameters = as_tuple(parameters)
self.special = as_tuple(special)
self.attributes = as_tuple(attributes)

def __repr__(self):
return "Lambda[%s](%s)" % (self.captures, self.parameters)
Expand Down Expand Up @@ -1178,6 +1195,19 @@ def periodic(self):
return self._periodic


class UsingNamespace(Node):

"""
A C++ using namespace directive.
"""

def __init__(self, namespace):
self.namespace = namespace

def __repr__(self):
return "<UsingNamespace(%s)>" % self.namespace


class Pragma(Node):

"""
Expand Down
17 changes: 14 additions & 3 deletions devito/ir/iet/visitors.py
Original file line number Diff line number Diff line change
Expand Up @@ -573,7 +573,7 @@ def visit_Callable(self, o):
signature = self._gen_signature(o)
return c.FunctionBody(signature, c.Block(body))

def visit_CallableBody(self, o):
def visit_MultiTraversable(self, o):
body = []
prev = None
for i in o.children:
Expand All @@ -585,6 +585,9 @@ def visit_CallableBody(self, o):
body.extend(as_tuple(v))
return c.Collection(body)

def visit_UsingNamespace(self, o):
return c.Statement('using namespace %s' % ccode(o.namespace))

def visit_Lambda(self, o):
body = []
for i in o.children:
Expand All @@ -595,7 +598,15 @@ def visit_Lambda(self, o):
body.extend(as_tuple(v))
captures = [str(i) for i in o.captures]
decls = [i.inline() for i in self._args_decl(o.parameters)]
top = c.Line('[%s](%s)' % (', '.join(captures), ', '.join(decls)))
extra = []
if o.special:
extra.append(' ')
extra.append(' '.join(str(i) for i in o.special))
if o.attributes:
extra.append(' ')
extra.append(' '.join('[[%s]]' % i for i in o.attributes))
top = c.Line('[%s](%s)%s' %
(', '.join(captures), ', '.join(decls), ''.join(extra)))
return LambdaCollection([top, c.Block(body)])

def visit_HaloSpot(self, o):
Expand Down Expand Up @@ -677,7 +688,7 @@ def visit_Operator(self, o, mode='all'):
includes = self._operator_includes(o) + [blankline]

# Namespaces
namespaces = [c.Statement('using namespace %s' % i) for i in o._namespaces]
namespaces = [self._visit(i) for i in o._namespaces]
if namespaces:
namespaces.append(blankline)

Expand Down
17 changes: 11 additions & 6 deletions devito/symbolics/extended_sympy.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,10 @@
__all__ = ['CondEq', 'CondNe', 'IntDiv', 'CallFromPointer', # noqa
'CallFromComposite', 'FieldFromPointer', 'FieldFromComposite',
'ListInitializer', 'Byref', 'IndexedPointer', 'Cast', 'DefFunction',
'InlineIf', 'ReservedWord', 'Keyword', 'String', 'Macro', 'Class',
'MacroArgument', 'CustomType', 'Deref', 'Namespace', 'Rvalue',
'INT', 'FLOAT', 'DOUBLE', 'VOID', 'Null', 'SizeOf', 'rfunc',
'cast_mapper', 'BasicWrapperMixin']
'MathFunction', 'InlineIf', 'ReservedWord', 'Keyword', 'String',
'Macro', 'Class', 'MacroArgument', 'CustomType', 'Deref',
'Namespace', 'Rvalue', 'INT', 'FLOAT', 'DOUBLE', 'VOID', 'Null',
'SizeOf', 'rfunc', 'cast_mapper', 'BasicWrapperMixin']


class CondEq(sympy.Eq):
Expand Down Expand Up @@ -541,8 +541,7 @@ class DefFunction(Function, Pickable):
https://github.com/sympy/sympy/issues/4297
"""

__rargs__ = ('name', 'arguments')
__rkwargs__ = ('template',)
__rargs__ = ('name', 'arguments', 'template')

def __new__(cls, name, arguments=None, template=None, **kwargs):
if isinstance(name, str):
Expand Down Expand Up @@ -609,6 +608,12 @@ def _sympystr(self, printer):
__reduce_ex__ = Pickable.__reduce_ex__


class MathFunction(DefFunction):

# Supposed to involve real operands
is_commutative = True


class InlineIf(sympy.Expr, Pickable):

"""
Expand Down
7 changes: 6 additions & 1 deletion devito/symbolics/manipulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from collections.abc import Iterable
from functools import singledispatch

from sympy import Pow, Add, Mul, Min, Max, SympifyError, Tuple, sympify
from sympy import Pow, Add, Mul, Min, Max, S, SympifyError, Tuple, sympify
from sympy.core.add import _addsort
from sympy.core.mul import _mulsort

Expand Down Expand Up @@ -146,6 +146,11 @@ def _(expr, args, kwargs):

@_uxreplace_handle.register(Mul)
def _(expr, args, kwargs):
# Perform some basic simplifications at least
args = [i for i in args if i != 1]
if any(i == 0 for i in args):
return S.Zero

if all(i.is_commutative for i in args):
_mulsort(args)
_eval_numbers(expr, args)
Expand Down
2 changes: 2 additions & 0 deletions devito/symbolics/printer.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,6 +241,8 @@ def _print_DefFunction(self, expr):
template = ''
return "%s%s(%s)" % (expr.name, template, ','.join(arguments))

_print_MathFunction = _print_DefFunction

def _print_Fallback(self, expr):
return expr.__str__()

Expand Down
20 changes: 20 additions & 0 deletions tests/test_iet.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,26 @@ def test_list_denesting():
assert str(l3) == str(l2)


def test_lambda():
grid = Grid(shape=(4, 4, 4))
x, y, z = grid.dimensions

u = Function(name='u', grid=grid)

e0 = DummyExpr(u.indexed, 1)
e1 = DummyExpr(u.indexed, 2)

body = List(body=[e0, e1])
lam = Lambda(body, ['='], [u.indexed], attributes=['my_attr'])

assert str(lam) == """\
[=](float *restrict u) [[my_attr]]
{
u = 1;
u = 2;
}"""


def test_make_cpp_parfor():
"""
Test construction of a C++ parallel for. This excites the IET construction
Expand Down

0 comments on commit 0ccf5fd

Please sign in to comment.