Remove jit from solvers (ott-jax#192)
* Remove `jit` from `solvers`

* Remove traces of `jit` in notebooks, fix the discrete barycenter

* Fix not passing correct dim in `NeuralDual`

* Re-introduce jitting in tests

* Fix typo

* Fix grad jitting tests

* Update docstrings
michalk8 authored Dec 7, 2022
1 parent 8cf4de8 commit 70297ea
Showing 31 changed files with 203 additions and 220 deletions.
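
For users, the upshot is that solvers no longer take a `jit` flag; jitting now happens at the call site, as the updated notebooks below show. A minimal before/after sketch (the toy point clouds and epsilon are illustrative, not taken from this diff):

import jax

from ott.geometry import pointcloud
from ott.solvers.linear import sinkhorn

# Toy inputs, for illustration only.
x = jax.random.normal(jax.random.PRNGKey(0), (13, 2))
y = jax.random.normal(jax.random.PRNGKey(1), (17, 2))
geom = pointcloud.PointCloud(x, y, epsilon=1e-1)

# Before this commit: out = sinkhorn.sinkhorn(geom, jit=True)
# After: wrap the solver yourself.
solve_fn = jax.jit(sinkhorn.sinkhorn)
out = solve_fn(geom)
print(out.reg_ot_cost)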
1 change: 0 additions & 1 deletion docs/notebooks/MetaOT.ipynb
@@ -681,7 +681,6 @@
 " b = eval_batch.b[i]\n",
 "\n",
 " sink_kwargs = {\n",
-" \"jit\": True,\n",
 " \"threshold\": -1,\n",
 " \"inner_iterations\": 1,\n",
 " \"max_iterations\": 26,\n",
1 change: 0 additions & 1 deletion docs/notebooks/OTT_&_POT.ipynb
@@ -168,7 +168,6 @@
 " b,\n",
 " threshold=threshold,\n",
 " lse_mode=True,\n",
-" jit=False,\n",
 " max_iterations=1000,\n",
 " )\n",
 " f, g = out.f, out.g\n",
5 changes: 2 additions & 3 deletions docs/notebooks/One_Sinkhorn.ipynb
@@ -380,11 +380,10 @@
 "import functools\n",
 "\n",
 "my_sinkorn = functools.partial(\n",
-" sinkhorn,\n",
+" sinkhorn.sinkhorn,\n",
 " inner_iterations=1, # recomputing error every iteration for plots.\n",
 " max_iterations=10000, # more iterations than the default setting to see full curves.\n",
-" jit=True,\n",
-") # force jit"
+")"
 ]
 },
 {
5 changes: 3 additions & 2 deletions docs/notebooks/Sinkhorn_Barycenters.ipynb
@@ -293,8 +293,9 @@
 ],
 "source": [
 "%%time\n",
+"discrete_barycenter_fn = jax.jit(discrete_barycenter.discrete_barycenter)\n",
 "g_grid._epsilon.target = 1\n",
-"barycenter = discrete_barycenter.discrete_barycenter(g_grid, a)"
+"barycenter = discrete_barycenter_fn(g_grid, a)"
 ]
 },
 {
@@ -320,7 +321,7 @@
 "source": [
 "%%time\n",
 "g_grid._epsilon.target = 1e-4\n",
-"barycenter = discrete_barycenter.discrete_barycenter(g_grid, a)"
+"barycenter = discrete_barycenter_fn(g_grid, a)"
 ]
 },
 {
4 changes: 2 additions & 2 deletions docs/notebooks/fairness.ipynb
@@ -587,7 +587,7 @@
 "toc_visible": true
 },
 "kernelspec": {
-"display_name": "Python 3",
+"display_name": "Python 3 (ipykernel)",
 "language": "python",
 "name": "python3"
 },
@@ -601,7 +601,7 @@
 "name": "python",
 "nbconvert_exporter": "python",
 "pygments_lexer": "ipython3",
-"version": "3.9.7"
+"version": "3.10.6"
 }
 },
 "nbformat": 4,
6 changes: 4 additions & 2 deletions docs/notebooks/gromov_wasserstein.ipynb
@@ -255,7 +255,10 @@
 "geom_xx = pointcloud.PointCloud(x=spiral, y=spiral)\n",
 "geom_yy = pointcloud.PointCloud(x=swiss_roll, y=swiss_roll)\n",
 "out = gw.gromov_wasserstein(\n",
-" geom_xx=geom_xx, geom_yy=geom_yy, epsilon=100.0, max_iterations=20, jit=True\n",
+" geom_xx=geom_xx,\n",
+" geom_yy=geom_yy,\n",
+" epsilon=100.0,\n",
+" max_iterations=20,\n",
 ")\n",
 "n_outer_iterations = jnp.sum(out.costs != -1)\n",
 "has_converged = bool(out.linear_convergence[n_outer_iterations - 1])\n",
@@ -8106,7 +8109,6 @@
 " geom_yy=geom_yy,\n",
 " epsilon=i,\n",
 " max_iterations=40,\n",
-" jit=True,\n",
 " sinkhorn_kwargs=config,\n",
 " )\n",
 " im.set_array(out.matrix)\n",
1 change: 0 additions & 1 deletion docs/notebooks/gromov_wasserstein_multiomics.ipynb
@@ -175,7 +175,6 @@
 " threshold=1e-9,\n",
 " a=self.p,\n",
 " b=self.q,\n",
-" jit=True,\n",
 " ).matrix\n",
 "\n",
 " constC, hC1, hC2 = init_matrix(\n",
2 changes: 1 addition & 1 deletion docs/notebooks/wasserstein_barycenters_gmms.ipynb
@@ -425,7 +425,7 @@
 "outputs": [],
 "source": [
 "# create a Wasserstein barycenter solver.\n",
-"solver = continuous_barycenter.WassersteinBarycenter(lse_mode=True, jit=True)"
+"solver = continuous_barycenter.WassersteinBarycenter(lse_mode=True)"
 ]
 },
 {
3 changes: 0 additions & 3 deletions src/ott/initializers/linear/initializers_lr.py
@@ -140,7 +140,6 @@ def from_solver(
     sinkhorn_kwargs = {
         "norm_error": lin_sol._norm_error,
         "lse_mode": lin_sol.lse_mode,
-        "jit": lin_sol.jit,
         "implicit_diff": lin_sol.implicit_diff,
         "use_danskin": lin_sol.use_danskin
     }
@@ -387,14 +386,12 @@ def _compute_factor(
     from ott.tools import k_means

     del kwargs
-    jit = self._sinkhorn_kwargs.get("jit", True)
     fn = functools.partial(
         k_means.k_means,
         min_iterations=self._min_iter,
         max_iterations=self._max_iter,
         **self._kwargs
     )
-    fn = jax.jit(fn, static_argnames="k") if jit else fn

     if isinstance(ot_prob, quadratic_problem.QuadraticProblem):
       geom = ot_prob.geom_xx if which == "q" else ot_prob.geom_yy
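
With these lines gone, the low-rank initializer no longer jits its internal k-means call; a caller who wants the old behavior can recreate the deleted wrapping outside the solver. A sketch mirroring the removed line:

import jax
from ott.tools import k_means

# Same wrapping the initializer used to apply internally; `k` is kept
# static because it determines output shapes.
k_means_fn = jax.jit(k_means.k_means, static_argnames="k")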
5 changes: 2 additions & 3 deletions src/ott/solvers/linear/continuous_barycenter.py
@@ -127,9 +127,8 @@ def __call__(
       x_init: Optional[jnp.ndarray] = None,
       rng: int = 0
   ) -> BarycenterState:
-    bar_fn = jax.jit(iterations, static_argnums=1) if self.jit else iterations
-    out = bar_fn(self, bar_size, bar_prob, x_init, rng)
-    return out
+    # TODO(michalk8): no reason for iterations to be outside this class
+    return iterations(self, bar_size, bar_prob, x_init, rng)

   def init_state(
       self,
23 changes: 12 additions & 11 deletions src/ott/solvers/linear/discrete_barycenter.py
@@ -50,17 +50,18 @@ def discrete_barycenter(
   """Compute discrete barycenter :cite:`janati:20a`.

   Args:
-    geom: a Cost object able to apply kernels with a certain epsilon.
-    a: jnp.ndarray<float>[batch, geom.num_a]: batch of histograms.
-    weights: jnp.ndarray of weights in the probability simplex
-    dual_initialization: jnp.ndarray, size [batch, num_b] initialization for g_v
-    threshold: (float) tolerance to monitor convergence.
-    norm_error: int, power used to define p-norm of error for marginal/target.
-    inner_iterations: (int32) the Sinkhorn error is not recomputed at each
+    geom: geometry object.
+    a: batch of histograms of shape ``[batch, num_a]``.
+    weights: positive weights in the probability simplex.
+    dual_initialization: array of shape ``[batch, num_b]`` for the
+      initialization of `g_v`.
+    threshold: tolerance to monitor convergence.
+    norm_error: power used to define p-norm of error for marginal/target.
+    inner_iterations: the Sinkhorn error is not recomputed at each
       iteration but every inner_num_iter instead to avoid computational overhead.
-    min_iterations: (int32) the minimum number of Sinkhorn iterations carried
+    min_iterations: the minimum number of Sinkhorn iterations carried
       out before the error is computed and monitored.
-    max_iterations: (int32) the maximum number of Sinkhorn iterations.
+    max_iterations: the maximum number of Sinkhorn iterations.
     lse_mode: True for log-sum-exp computations, False for kernel multiply.
     debiased: whether to run the debiased version of the Sinkhorn divergence.

@@ -75,8 +76,8 @@

   if weights is None:
     weights = jnp.ones((batch_size,)) / batch_size
-  if not jnp.alltrue(weights > 0) or weights.shape[0] != batch_size:
-    raise ValueError(f'weights must have positive values and size {batch_size}')
+  if weights.shape[0] != batch_size:
+    raise ValueError(f'weights must have size `{batch_size}`.')

   if dual_initialization is None:
     # initialization strategy from https://arxiv.org/pdf/1503.02533.pdf, (3.6)
18 changes: 3 additions & 15 deletions src/ott/solvers/linear/sinkhorn.py
@@ -263,7 +263,8 @@ def transport_cost_at_geom(
     we resort to instantiating both transport matrix and cost matrix.

     Args:
-      other_geom: geometry whose cost matrix is used to evaluate tranposrtation.
+      other_geom: geometry whose cost matrix is used to evaluate the transport
+        cost.

     Returns:
       the transportation cost at :math:`C`, i.e. :math:`\langle P, C \rangle`.
@@ -383,10 +384,6 @@ class Sinkhorn:
       when the algorithm has converged with a low tolerance.
     initializer: how to compute the initial potentials/scalings.
     kwargs_init: keyword arguments when creating the initializer.
-    jit: if True, automatically jits the function upon first call.
-      Should be set to False when used in a function that is jitted by the user,
-      or when computing gradients (in which case the gradient function
-      should be jitted by the user)
   """

   def __init__(
@@ -406,7 +403,6 @@ def __init__(
       initializer: Union[Literal["default", "gaussian", "sorting"],
                          init_lib.SinkhornInitializer] = "default",
       kwargs_init: Optional[Mapping[str, Any]] = None,
-      jit: bool = True
   ):
     self.lse_mode = lse_mode
     self.threshold = threshold
@@ -435,7 +431,6 @@
     self.parallel_dual_updates = parallel_dual_updates
     self.initializer = initializer
     self.kwargs_init = {} if kwargs_init is None else kwargs_init
-    self.jit = jit

     # Force implicit_differentiation to True when using Anderson acceleration,
     # Reset all momentum parameters to default (i.e. no momentum)
@@ -472,8 +467,7 @@ def __call__(
     init_dual_a, init_dual_b = initializer(
         ot_prob, *init, lse_mode=self.lse_mode
     )
-    run_fn = jax.jit(run) if self.jit else run
-    return run_fn(ot_prob, self, (init_dual_a, init_dual_b))
+    return run(ot_prob, self, (init_dual_a, init_dual_b))

   def lse_step(
       self, ot_prob: linear_problem.LinearProblem, state: SinkhornState,
@@ -781,7 +775,6 @@ def make(
     parallel_dual_updates: bool = False,
     use_danskin: bool = None,
     initializer: init_lib.SinkhornInitializer = init_lib.DefaultInitializer(),
-    jit: bool = False
 ) -> Sinkhorn:
   """For backward compatibility."""
   del tau_a, tau_b
@@ -825,7 +818,6 @@
       parallel_dual_updates=parallel_dual_updates,
       use_danskin=use_danskin,
       initializer=initializer,
-      jit=jit
   )


@@ -1132,10 +1124,6 @@ def sinkhorn(
     gradients have been stopped. This is useful when carrying out first order
     differentiation, and is only valid (as with ``implicit_differentiation``)
     when the algorithm has converged with a low tolerance.
-    jit: if True, automatically jits the function upon first call.
-      Should be set to False when used in a function that is jitted by the user,
-      or when computing gradients (in which case the gradient function
-      should be jitted by the user).
     kwargs: Additional keyword arguments (see above).

   Returns:
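
The deleted docstring advice survives as a usage pattern: when differentiating through Sinkhorn, jit the gradient function rather than the solver itself. A sketch under those assumptions (toy shapes and epsilon are illustrative, not from this diff):

import jax
import jax.numpy as jnp

from ott.geometry import pointcloud
from ott.solvers.linear import sinkhorn

def reg_ot_cost(x: jnp.ndarray, y: jnp.ndarray) -> jnp.ndarray:
  # Differentiable map from input points to the regularized OT cost.
  geom = pointcloud.PointCloud(x, y, epsilon=1e-1)
  return sinkhorn.sinkhorn(geom).reg_ot_cost

# Jit the gradient function, as the removed docstring recommended.
grad_fn = jax.jit(jax.grad(reg_ot_cost))
x = jax.random.normal(jax.random.PRNGKey(0), (5, 2))
y = jax.random.normal(jax.random.PRNGKey(1), (6, 2))
dx = grad_fn(x, y)  # gradient of the cost w.r.t. the source points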
5 changes: 1 addition & 4 deletions src/ott/solvers/linear/sinkhorn_lr.py
@@ -321,8 +321,7 @@ def __call__(
     assert ot_prob.is_balanced, "Unbalanced case is not implemented."
     initializer = self.create_initializer(ot_prob)
     init = initializer(ot_prob, *init, key=key, **kwargs)
-    run_fn = jax.jit(run) if self.jit else run
-    return run_fn(ot_prob, self, init)
+    return run(ot_prob, self, init)

   def _lr_costs(
       self,
@@ -677,7 +676,6 @@ def make(
     max_iterations: int = 2000,
     use_danskin: bool = True,
     implicit_diff: bool = False,
-    jit: bool = True,
     kwargs_dys: Optional[Mapping[str, Any]] = None
 ) -> LRSinkhorn:
   return LRSinkhorn(
@@ -693,6 +691,5 @@
       max_iterations=max_iterations,
       use_danskin=use_danskin,
       implicit_diff=implicit_diff,
-      jit=jit,
       kwargs_dys=kwargs_dys
   )
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ class ICNN(nn.Module):
:cite:`amos:17` with initialization schemes proposed by :cite:`bunne:22`.

Args:
dim_data: data dimensionality.
dim_hidden: sequence specifying size of hidden dimensions. The
output dimension of the last layer is 1 by default.
init_std: value of standard deviation of weight initialization method.
Expand All @@ -48,18 +49,17 @@ class ICNN(nn.Module):
act_fn: choice of activation function used in network architecture
(needs to be convex, default: `nn.relu`).
pos_weights: choice to enforce positivity of weight or use regularizer.
dim_data: data dimensionality (default: 2).
gaussian_map: data inputs of source and target measures for
initialization scheme based on Gaussian approximation of input and
target measure (if None, identity initialization is used).
"""

dim_data: int
dim_hidden: Sequence[int]
init_std: float = 1e-1
init_fn: Callable = jax.nn.initializers.normal
act_fn: Callable = nn.relu
pos_weights: bool = True
dim_data: int = 2
gaussian_map: Tuple[jnp.ndarray, jnp.ndarray] = None

def setup(self) -> None:
Expand Down
4 changes: 2 additions & 2 deletions src/ott/solvers/nn/neuraldual.py
@@ -92,9 +92,9 @@ def __init__(

     # set default neural architectures
     if neural_f is None:
-      neural_f = icnn.ICNN(dim_hidden=[64, 64, 64, 64])
+      neural_f = icnn.ICNN(dim_data=input_dim, dim_hidden=[64, 64, 64, 64])
     if neural_g is None:
-      neural_g = icnn.ICNN(dim_hidden=[64, 64, 64, 64])
+      neural_g = icnn.ICNN(dim_data=input_dim, dim_hidden=[64, 64, 64, 64])

     # set optimizer and networks
     self.setup(rng, neural_f, neural_g, input_dim, optimizer_f, optimizer_g)
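
Since `dim_data` is now a required leading field of `ICNN` (see the diff above), the default potentials can only be built once the input dimension is known, hence the `NeuralDual` fix forwarding `input_dim`. A minimal sketch of the resulting constructor call (`dim_data=2` is an arbitrary example, hidden sizes mirror the defaults above):

from ott.solvers.nn import icnn

# `dim_data` must be passed explicitly; it no longer defaults to 2.
potential_f = icnn.ICNN(dim_data=2, dim_hidden=[64, 64, 64, 64])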
12 changes: 4 additions & 8 deletions src/ott/solvers/quadratic/gromov_wasserstein.py
@@ -214,8 +214,7 @@ def __call__(
     initializer = self.create_initializer(prob)
     init = initializer(prob, epsilon=self.epsilon, key=key1, **kwargs)

-    gromov_fn = jax.jit(iterations) if self.jit else iterations
-    out = gromov_fn(self, prob, init, key2)
+    out = iterations(self, prob, init, key2)
     # TODO(lpapaxanthos): remove stop_gradient when using backprop
     if self.is_low_rank:
       linearization = prob.update_lr_linearization(
@@ -395,7 +394,6 @@ def make(
     epsilon: Union[epsilon_scheduler.Epsilon, float] = 1.,
     rank: int = -1,
     max_iterations: int = 50,
-    jit: bool = False,
     warm_start: Optional[bool] = None,
     store_inner_errors: bool = False,
     linear_ot_solver_kwargs: Optional[Mapping[str, Any]] = None,
@@ -410,7 +408,6 @@
     rank: integer used to constrain the rank of GW solutions if >0.
     max_iterations: the maximum number of outer iterations for
       Gromov Wasserstein.
-    jit: bool, if True, jits the function.
     warm_start: Whether to initialize (low-rank) Sinkhorn calls using values
       from the previous iteration. If `None`, it's enabled when using low-rank.
     store_inner_errors: whether or not to return all the errors of the inner
@@ -449,7 +446,6 @@
       threshold=threshold,
       min_iterations=min_iterations,
       max_iterations=max_iterations,
-      jit=jit,
       store_inner_errors=store_inner_errors,
       warm_start=warm_start,
       **kwargs
@@ -478,9 +474,9 @@ def gromov_wasserstein(
   if the problem is fused) and calls a solver to output a solution.

   Args:
-    geom_xx: a Geometry object for the first view.
-    geom_yy: a second Geometry object for the second view.
-    geom_xy: a Geometry object representing the linear cost in FGW.
+    geom_xx: Geometry for the first view.
+    geom_yy: Geometry for the second view.
+    geom_xy: Geometry representing the linear cost in FGW.
     fused_penalty: multiplier of the linear term in Fused Gromov Wasserstein,
       i.e. loss = quadratic_loss + fused_penalty * linear_loss.
       Ignored if ``geom_xy`` is not specified.
6 changes: 1 addition & 5 deletions src/ott/solvers/quadratic/gw_barycenter.py
@@ -52,7 +52,6 @@ class GromovWassersteinBarycenter(was_solver.WassersteinSolver):
     min_iterations: Minimum number of iterations.
     max_iterations: Maximum number of outermost iterations.
     threshold: Convergence threshold.
-    jit: Whether to jit the iteration loop.
     store_inner_errors: Whether to store the errors of the GW solver, as well
       as its linear solver, at each iteration for each measure.
     quad_solver: The GW solver.
@@ -67,7 +66,6 @@ def __init__(
       min_iterations: int = 5,
      max_iterations: int = 50,
       threshold: float = 1e-3,
-      jit: bool = True,
       store_inner_errors: bool = False,
       quad_solver: Optional[gromov_wasserstein.GromovWasserstein] = None,
       # TODO(michalk8): maintain the API compatibility with `was_solver`
@@ -81,7 +79,6 @@
         min_iterations=min_iterations,
         max_iterations=max_iterations,
         threshold=threshold,
-        jit=jit,
         store_inner_errors=store_inner_errors
     )
     self._quad_solver = quad_solver
@@ -105,9 +102,8 @@
     Returns:
       The solution.
     """
-    bar_fn = jax.jit(iterations, static_argnums=1) if self.jit else iterations
     state = self.init_state(problem, bar_size, **kwargs)
-    state = bar_fn(solver=self, problem=problem, init_state=state)
+    state = iterations(solver=self, problem=problem, init_state=state)
     return self.output_from_state(state)

   def init_state(