Skip to content

Commit

Permalink
checkpointing: switch to noop operators if unavailable
Browse files Browse the repository at this point in the history
  • Loading branch information
mloubout committed May 3, 2023
1 parent eb3ca10 commit a5611fc
Show file tree
Hide file tree
Showing 7 changed files with 66 additions and 51 deletions.
6 changes: 4 additions & 2 deletions devito/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,13 @@
from devito.data.allocators import * # noqa
from devito.logger import error, warning, info, set_log_level # noqa
from devito.mpi import MPI # noqa
from devito.checkpointing import pyrevolve # noqa
try:
from devito.checkpointing import DevitoCheckpoint, CheckpointOperator # noqa
from pyrevolve import Revolver
except ImportError:
pass
from devito.checkpointing import NoopCheckpoint as DevitoCheckpoint # noqa
from devito.checkpointing import NoopCheckpointOperator as CheckpointOperator # noqa
from devito.checkpointing import NoopRevolver as Revolver # noqa

# Imports required to initialize Devito
from devito.arch import compiler_registry, platform_registry
Expand Down
23 changes: 21 additions & 2 deletions devito/checkpointing/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,24 @@
try:
import pyrevolve as pyrevolve
import pyrevolve as pyrevolve # noqa
from .checkpoint import * # noqa
except ImportError:
pyrevolve = None
pass


class Noop(object):
""" Dummy replacement in case pyrevolve isn't available. """

def __init__(self, *args, **kwargs):
raise ImportError("Missing required `pyrevolve`; cannot use checkpointing")


class NoopCheckpointOperator(Noop):
pass


class NoopCheckpoint(Noop):
pass


class NoopRevolver(Noop):
pass
20 changes: 9 additions & 11 deletions examples/seismic/acoustic/wavesolver.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import devito
from devito import Function, TimeFunction, pyrevolve
from devito import Function, TimeFunction, DevitoCheckpoint, CheckpointOperator, Revolver
from devito.tools import memoized_meth
from examples.seismic.acoustic.operators import (
ForwardOperator, AdjointOperator, GradientOperator, BornOperator
Expand Down Expand Up @@ -194,20 +193,19 @@ def jacobian_adjoint(self, rec, u, src=None, v=None, grad=None, model=None,
# Pick vp from model unless explicitly provided
kwargs.update(model.physical_params(**kwargs))

if checkpointing and pyrevolve is not None:
if checkpointing:
u = TimeFunction(name='u', grid=self.model.grid,
time_order=2, space_order=self.space_order)
cp = devito.DevitoCheckpoint([u])
cp = DevitoCheckpoint([u])
n_checkpoints = None
wrap_fw = devito.CheckpointOperator(self.op_fwd(save=False),
src=src or self.geometry.src,
u=u, dt=dt, **kwargs)
wrap_rev = devito.CheckpointOperator(self.op_grad(save=False), u=u, v=v,
rec=rec, dt=dt, grad=grad, **kwargs)
wrap_fw = CheckpointOperator(self.op_fwd(save=False),
src=src or self.geometry.src,
u=u, dt=dt, **kwargs)
wrap_rev = CheckpointOperator(self.op_grad(save=False), u=u, v=v,
rec=rec, dt=dt, grad=grad, **kwargs)

# Run forward
wrp = pyrevolve.Revolver(cp, wrap_fw, wrap_rev, n_checkpoints,
rec.data.shape[0]-2)
wrp = Revolver(cp, wrap_fw, wrap_rev, n_checkpoints, rec.data.shape[0]-2)
wrp.apply_forward()
summary = wrp.apply_reverse()
else:
Expand Down
20 changes: 9 additions & 11 deletions examples/seismic/tti/wavesolver.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# coding: utf-8
import devito
from devito import Function, TimeFunction, warning, pyrevolve
from devito import (Function, TimeFunction, warning,
DevitoCheckpoint, CheckpointOperator, Revolver)
from devito.tools import memoized_meth
from examples.seismic.tti.operators import ForwardOperator, AdjointOperator
from examples.seismic.tti.operators import JacobianOperator, JacobianAdjOperator
Expand Down Expand Up @@ -350,22 +350,20 @@ def jacobian_adjoint(self, rec, u0, v0, du=None, dv=None, dm=None, model=None,
if self.model.dim < 3:
kwargs.pop('phi', None)

if checkpointing and pyrevolve is not None:
if checkpointing:
u0 = TimeFunction(name='u0', grid=self.model.grid,
time_order=2, space_order=self.space_order)
v0 = TimeFunction(name='v0', grid=self.model.grid,
time_order=2, space_order=self.space_order)
cp = devito.DevitoCheckpoint([u0, v0])
cp = DevitoCheckpoint([u0, v0])
n_checkpoints = None
wrap_fw = devito.CheckpointOperator(self.op_fwd(save=False), u=u0, v=v0,
dt=dt, src=self.geometry.src, **kwargs)
wrap_rev = devito. CheckpointOperator(self.op_jacadj(save=False), u0=u0,
v0=v0, du=du, dv=dv, rec=rec, dm=dm,
dt=dt, **kwargs)
wrap_fw = CheckpointOperator(self.op_fwd(save=False), src=self.geometry.src,
u=u0, v=v0, dt=dt, **kwargs)
wrap_rev = CheckpointOperator(self.op_jacadj(save=False), u0=u0, v0=v0,
du=du, dv=dv, rec=rec, dm=dm, dt=dt, **kwargs)

# Run forward
wrp = pyrevolve.Revolver(cp, wrap_fw, wrap_rev, n_checkpoints,
rec.data.shape[0]-2)
wrp = Revolver(cp, wrap_fw, wrap_rev, n_checkpoints, rec.data.shape[0]-2)
wrp.apply_forward()
summary = wrp.apply_reverse()
else:
Expand Down
21 changes: 10 additions & 11 deletions examples/seismic/viscoacoustic/wavesolver.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import devito
from devito import VectorTimeFunction, TimeFunction, Function, NODE, pyrevolve
from devito import (VectorTimeFunction, TimeFunction, Function, NODE,
DevitoCheckpoint, CheckpointOperator, Revolver)
from devito.tools import memoized_meth
from examples.seismic import PointSource
from examples.seismic.viscoacoustic.operators import (
Expand Down Expand Up @@ -262,7 +262,7 @@ def jacobian_adjoint(self, rec, p, pa=None, grad=None, r=None, va=None, model=No
# Pick vp and physical parameters from model unless explicitly provided
kwargs.update(model.physical_params(**kwargs))

if checkpointing and pyrevolve is not None:
if checkpointing:
if self.time_order == 1:
v = VectorTimeFunction(name="v", grid=self.model.grid,
time_order=self.time_order,
Expand All @@ -277,10 +277,10 @@ def jacobian_adjoint(self, rec, p, pa=None, grad=None, r=None, va=None, model=No
space_order=self.space_order, staggered=NODE)

l = [p, r] + v.values() if self.time_order == 1 else [p, r]
cp = devito.DevitoCheckpoint(l)
cp = DevitoCheckpoint(l)
n_checkpoints = None
wrap_fw = devito.CheckpointOperator(self.op_fwd(save=False), p=p, r=r, dt=dt,
src=self.geometry.src, **kwargs)
wrap_fw = CheckpointOperator(self.op_fwd(save=False),
src=self.geometry.src, p=p, r=r, dt=dt, **kwargs)

ra = TimeFunction(name="ra", grid=self.model.grid, time_order=self.time_order,
space_order=self.space_order, staggered=NODE)
Expand All @@ -294,13 +294,12 @@ def jacobian_adjoint(self, rec, p, pa=None, grad=None, r=None, va=None, model=No
kwargs.update({k.name: k for k in va})
kwargs['time_m'] = 0

wrap_rev = devito.CheckpointOperator(self.op_grad(save=False), p=p, pa=pa,
r=ra, rec=rec, dt=dt, grad=grad,
**kwargs)
wrap_rev = CheckpointOperator(self.op_grad(save=False), p=p, pa=pa, r=ra,
rec=rec, dt=dt, grad=grad, **kwargs)

# Run forward
ntchk = rec.data.shape[0] - (1 if self.time_order == 1 else 2)
wrp = pyrevolve.Revolver(cp, wrap_fw, wrap_rev, n_checkpoints, ntchk)
wrp = Revolver(cp, wrap_fw, wrap_rev, n_checkpoints,
rec.data.shape[0] - (1 if self.time_order == 1 else 2))
wrp.apply_forward()
summary = wrp.apply_reverse()
else:
Expand Down
6 changes: 3 additions & 3 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@
import pytest
import sys

from devito import Eq, configuration # noqa
from devito.checkpointing import pyrevolve
from devito import Eq, configuration, Revolver # noqa
from devito.checkpointing import NoopRevolver
from devito.finite_differences.differentiable import EvalDerivative
from devito.arch import Cpu64, Device, sniff_mpi_distro, Arm
from devito.arch.compiler import compiler_registry, IntelCompiler, NvidiaCompiler
Expand Down Expand Up @@ -73,7 +73,7 @@ def skipif(items, whole_module=False):
skipit = "Arm doesn't support x86-specific instructions"
break
# Skip if pyrevolve not installed
if i == 'chkpnt' and pyrevolve is None:
if i == 'chkpnt' and Revolver is NoopRevolver:
skipit = "pyrevolve not installed"
break

Expand Down
21 changes: 10 additions & 11 deletions tests/test_checkpointing.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,8 @@
import numpy as np

from conftest import skipif
import devito
from devito import (Grid, TimeFunction, Operator, Function, Eq, switchconfig, Constant,
pyrevolve)
Revolver, CheckpointOperator, DevitoCheckpoint)
from examples.seismic.acoustic.acoustic_example import acoustic_setup


Expand Down Expand Up @@ -131,12 +130,12 @@ def test_forward_with_breaks(shape, kernel, space_order):
dt = solver.model.critical_dt

u = TimeFunction(name='u', grid=grid, time_order=2, space_order=space_order)
cp = devito.DevitoCheckpoint([u])
wrap_fw = devito.CheckpointOperator(solver.op_fwd(save=False), rec=rec,
src=solver.geometry.src, u=u, dt=dt)
wrap_rev = devito.CheckpointOperator(solver.op_grad(save=False), u=u, dt=dt, rec=rec)
cp = DevitoCheckpoint([u])
wrap_fw = CheckpointOperator(solver.op_fwd(save=False), rec=rec,
src=solver.geometry.src, u=u, dt=dt)
wrap_rev = CheckpointOperator(solver.op_grad(save=False), u=u, dt=dt, rec=rec)

wrp = pyrevolve.Revolver(cp, wrap_fw, wrap_rev, None, rec._time_range.num-time_order)
wrp = Revolver(cp, wrap_fw, wrap_rev, None, rec._time_range.num-time_order)
rec1, u1, summary = solver.forward()

wrp.apply_forward()
Expand Down Expand Up @@ -229,13 +228,13 @@ def test_index_alignment():
# change equations to use new symbols
fwd_eqn_2 = Eq(u_nosave.forward, u_nosave + 1.*const)
fwd_op_2 = Operator(fwd_eqn_2)
cp = devito.DevitoCheckpoint([u_nosave])
wrap_fw = devito.CheckpointOperator(fwd_op_2, constant=1)
cp = DevitoCheckpoint([u_nosave])
wrap_fw = CheckpointOperator(fwd_op_2, constant=1)

prod_eqn_2 = Eq(prod, prod + u_nosave * v)
comb_op_2 = Operator([adj_eqn, prod_eqn_2])
wrap_rev = devito.CheckpointOperator(comb_op_2, constant=1)
wrp = pyrevolve.Revolver(cp, wrap_fw, wrap_rev, None, nt)
wrap_rev = CheckpointOperator(comb_op_2, constant=1)
wrp = Revolver(cp, wrap_fw, wrap_rev, None, nt)

# Invocation 4
wrp.apply_forward()
Expand Down

0 comments on commit a5611fc

Please sign in to comment.