diff --git a/docs/notebooks/MetaOT.ipynb b/docs/notebooks/MetaOT.ipynb index ebe8bc21d..2eb4b7492 100644 --- a/docs/notebooks/MetaOT.ipynb +++ b/docs/notebooks/MetaOT.ipynb @@ -681,7 +681,6 @@ " b = eval_batch.b[i]\n", "\n", " sink_kwargs = {\n", - " \"jit\": True,\n", " \"threshold\": -1,\n", " \"inner_iterations\": 1,\n", " \"max_iterations\": 26,\n", diff --git a/docs/notebooks/OTT_&_POT.ipynb b/docs/notebooks/OTT_&_POT.ipynb index 5d4ca6815..bafc12230 100644 --- a/docs/notebooks/OTT_&_POT.ipynb +++ b/docs/notebooks/OTT_&_POT.ipynb @@ -168,7 +168,6 @@ " b,\n", " threshold=threshold,\n", " lse_mode=True,\n", - " jit=False,\n", " max_iterations=1000,\n", " )\n", " f, g = out.f, out.g\n", diff --git a/docs/notebooks/One_Sinkhorn.ipynb b/docs/notebooks/One_Sinkhorn.ipynb index 01aefa7ec..9acc50c39 100644 --- a/docs/notebooks/One_Sinkhorn.ipynb +++ b/docs/notebooks/One_Sinkhorn.ipynb @@ -380,11 +380,10 @@ "import functools\n", "\n", "my_sinkorn = functools.partial(\n", - " sinkhorn,\n", + " sinkhorn.sinkhorn,\n", " inner_iterations=1, # recomputing error every iteration for plots.\n", " max_iterations=10000, # more iterations than the default setting to see full curves.\n", - " jit=True,\n", - ") # force jit" + ")" ] }, { diff --git a/docs/notebooks/Sinkhorn_Barycenters.ipynb b/docs/notebooks/Sinkhorn_Barycenters.ipynb index 02e388fba..e21f40af3 100644 --- a/docs/notebooks/Sinkhorn_Barycenters.ipynb +++ b/docs/notebooks/Sinkhorn_Barycenters.ipynb @@ -293,8 +293,9 @@ ], "source": [ "%%time\n", + "discrete_barycenter_fn = jax.jit(discrete_barycenter.discrete_barycenter)\n", "g_grid._epsilon.target = 1\n", - "barycenter = discrete_barycenter.discrete_barycenter(g_grid, a)" + "barycenter = discrete_barycenter_fn(g_grid, a)" ] }, { @@ -320,7 +321,7 @@ "source": [ "%%time\n", "g_grid._epsilon.target = 1e-4\n", - "barycenter = discrete_barycenter.discrete_barycenter(g_grid, a)" + "barycenter = discrete_barycenter_fn(g_grid, a)" ] }, { diff --git a/docs/notebooks/fairness.ipynb b/docs/notebooks/fairness.ipynb index 33f4da429..f70388873 100644 --- a/docs/notebooks/fairness.ipynb +++ b/docs/notebooks/fairness.ipynb @@ -587,7 +587,7 @@ "toc_visible": true }, "kernelspec": { - "display_name": "Python 3", + "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, @@ -601,7 +601,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.9.7" + "version": "3.10.6" } }, "nbformat": 4, diff --git a/docs/notebooks/gromov_wasserstein.ipynb b/docs/notebooks/gromov_wasserstein.ipynb index 23c2db96b..ed35afe75 100644 --- a/docs/notebooks/gromov_wasserstein.ipynb +++ b/docs/notebooks/gromov_wasserstein.ipynb @@ -255,7 +255,10 @@ "geom_xx = pointcloud.PointCloud(x=spiral, y=spiral)\n", "geom_yy = pointcloud.PointCloud(x=swiss_roll, y=swiss_roll)\n", "out = gw.gromov_wasserstein(\n", - " geom_xx=geom_xx, geom_yy=geom_yy, epsilon=100.0, max_iterations=20, jit=True\n", + " geom_xx=geom_xx,\n", + " geom_yy=geom_yy,\n", + " epsilon=100.0,\n", + " max_iterations=20,\n", ")\n", "n_outer_iterations = jnp.sum(out.costs != -1)\n", "has_converged = bool(out.linear_convergence[n_outer_iterations - 1])\n", @@ -8106,7 +8109,6 @@ " geom_yy=geom_yy,\n", " epsilon=i,\n", " max_iterations=40,\n", - " jit=True,\n", " sinkhorn_kwargs=config,\n", " )\n", " im.set_array(out.matrix)\n", diff --git a/docs/notebooks/gromov_wasserstein_multiomics.ipynb b/docs/notebooks/gromov_wasserstein_multiomics.ipynb index 778b97cc1..c0434acaf 100644 --- 
a/docs/notebooks/gromov_wasserstein_multiomics.ipynb +++ b/docs/notebooks/gromov_wasserstein_multiomics.ipynb @@ -175,7 +175,6 @@ "     threshold=1e-9,\n", "     a=self.p,\n", "     b=self.q,\n", - "     jit=True,\n", " ).matrix\n", "\n", " constC, hC1, hC2 = init_matrix(\n", diff --git a/docs/notebooks/wasserstein_barycenters_gmms.ipynb b/docs/notebooks/wasserstein_barycenters_gmms.ipynb index 1703efe78..a5db475b8 100644 --- a/docs/notebooks/wasserstein_barycenters_gmms.ipynb +++ b/docs/notebooks/wasserstein_barycenters_gmms.ipynb @@ -425,7 +425,7 @@ "outputs": [], "source": [ "# create a Wasserstein barycenter solver.\n", - "solver = continuous_barycenter.WassersteinBarycenter(lse_mode=True, jit=True)" + "solver = continuous_barycenter.WassersteinBarycenter(lse_mode=True)" ] }, { diff --git a/src/ott/initializers/linear/initializers_lr.py b/src/ott/initializers/linear/initializers_lr.py index 8e60ba51a..5f1859b08 100644 --- a/src/ott/initializers/linear/initializers_lr.py +++ b/src/ott/initializers/linear/initializers_lr.py @@ -140,7 +140,6 @@ def from_solver( sinkhorn_kwargs = { "norm_error": lin_sol._norm_error, "lse_mode": lin_sol.lse_mode, - "jit": lin_sol.jit, "implicit_diff": lin_sol.implicit_diff, "use_danskin": lin_sol.use_danskin } @@ -387,14 +386,12 @@ def _compute_factor( from ott.tools import k_means del kwargs - jit = self._sinkhorn_kwargs.get("jit", True) fn = functools.partial( k_means.k_means, min_iterations=self._min_iter, max_iterations=self._max_iter, **self._kwargs ) - fn = jax.jit(fn, static_argnames="k") if jit else fn if isinstance(ot_prob, quadratic_problem.QuadraticProblem): geom = ot_prob.geom_xx if which == "q" else ot_prob.geom_yy diff --git a/src/ott/solvers/linear/continuous_barycenter.py b/src/ott/solvers/linear/continuous_barycenter.py index c4d557bf9..45553fcce 100644 --- a/src/ott/solvers/linear/continuous_barycenter.py +++ b/src/ott/solvers/linear/continuous_barycenter.py @@ -127,9 +127,8 @@ def __call__( x_init: Optional[jnp.ndarray] = None, rng: int = 0 ) -> BarycenterState: - bar_fn = jax.jit(iterations, static_argnums=1) if self.jit else iterations - out = bar_fn(self, bar_size, bar_prob, x_init, rng) - return out + # TODO(michalk8): no reason for iterations to be outside this class + return iterations(self, bar_size, bar_prob, x_init, rng) def init_state( self, diff --git a/src/ott/solvers/linear/discrete_barycenter.py b/src/ott/solvers/linear/discrete_barycenter.py index 19dcbb506..a181d6d12 100644 --- a/src/ott/solvers/linear/discrete_barycenter.py +++ b/src/ott/solvers/linear/discrete_barycenter.py @@ -50,17 +50,18 @@ def discrete_barycenter( """Compute discrete barycenter :cite:`janati:20a`. Args: - geom: a Cost object able to apply kernels with a certain epsilon. - a: jnp.ndarray[batch, geom.num_a]: batch of histograms. - weights: jnp.ndarray of weights in the probability simplex - dual_initialization: jnp.ndarray, size [batch, num_b] initialization for g_v - threshold: (float) tolerance to monitor convergence. - norm_error: int, power used to define p-norm of error for marginal/target. - inner_iterations: (int32) the Sinkhorn error is not recomputed at each + geom: geometry object able to apply kernels with a certain epsilon. + a: batch of histograms of shape ``[batch, num_a]``. + weights: positive weights in the probability simplex. + dual_initialization: array of shape ``[batch, num_b]`` for the + initialization of ``g_v``. + threshold: tolerance to monitor convergence. + norm_error: power used to define p-norm of error for marginal/target. 
+ inner_iterations: the Sinkhorn error is not recomputed at each - iteration but every inner_num_iter instead to avoid computational overhead. + iteration but every ``inner_iterations`` instead to avoid computational overhead. - min_iterations: (int32) the minimum number of Sinkhorn iterations carried + min_iterations: the minimum number of Sinkhorn iterations carried out before the error is computed and monitored. - max_iterations: (int32) the maximum number of Sinkhorn iterations. + max_iterations: the maximum number of Sinkhorn iterations. lse_mode: True for log-sum-exp computations, False for kernel multiply. debiased: whether to run the debiased version of the Sinkhorn divergence. @@ -75,8 +76,8 @@ if weights is None: weights = jnp.ones((batch_size,)) / batch_size - if not jnp.alltrue(weights > 0) or weights.shape[0] != batch_size: - raise ValueError(f'weights must have positive values and size {batch_size}') + if weights.shape[0] != batch_size: - raise ValueError(f'weights must have size `{batch_size}`.') if dual_initialization is None: # initialization strategy from https://arxiv.org/pdf/1503.02533.pdf, (3.6) diff --git a/src/ott/solvers/linear/sinkhorn.py b/src/ott/solvers/linear/sinkhorn.py index 411297631..20787e019 100644 --- a/src/ott/solvers/linear/sinkhorn.py +++ b/src/ott/solvers/linear/sinkhorn.py @@ -263,7 +263,8 @@ def transport_cost_at_geom( we resort to instantiating both transport matrix and cost matrix. Args: - other_geom: geometry whose cost matrix is used to evaluate tranposrtation. + other_geom: geometry whose cost matrix is used to evaluate the transport + cost. Returns: the transportation cost at :math:`C`, i.e. :math:`\langle P, C \rangle`. @@ -383,10 +384,6 @@ class Sinkhorn: when the algorithm has converged with a low tolerance. initializer: how to compute the initial potentials/scalings. kwargs_init: keyword arguments when creating the initializer. - jit: if True, automatically jits the function upon first call. - Should be set to False when used in a function that is jitted by the user, - or when computing gradients (in which case the gradient function - should be jitted by the user) """ def __init__( @@ -406,7 +403,6 @@ def __init__( initializer: Union[Literal["default", "gaussian", "sorting"], init_lib.SinkhornInitializer] = "default", kwargs_init: Optional[Mapping[str, Any]] = None, - jit: bool = True ): self.lse_mode = lse_mode self.threshold = threshold @@ -435,7 +431,6 @@ def __init__( self.parallel_dual_updates = parallel_dual_updates self.initializer = initializer self.kwargs_init = {} if kwargs_init is None else kwargs_init - self.jit = jit # Force implicit_differentiation to True when using Anderson acceleration, # Reset all momentum parameters to default (i.e. no momentum) @@ -472,8 +467,7 @@ def __call__( init_dual_a, init_dual_b = initializer( ot_prob, *init, lse_mode=self.lse_mode ) - run_fn = jax.jit(run) if self.jit else run - return run_fn(ot_prob, self, (init_dual_a, init_dual_b)) + return run(ot_prob, self, (init_dual_a, init_dual_b)) def lse_step( self, ot_prob: linear_problem.LinearProblem, state: SinkhornState, @@ -781,7 +775,6 @@ def make( parallel_dual_updates: bool = False, use_danskin: bool = None, initializer: init_lib.SinkhornInitializer = init_lib.DefaultInitializer(), - jit: bool = False ) -> Sinkhorn: """For backward compatibility.""" del tau_a, tau_b @@ -825,7 +818,6 @@ def make( parallel_dual_updates=parallel_dual_updates, use_danskin=use_danskin, initializer=initializer, - jit=jit ) @@ -1132,10 +1124,6 @@ def sinkhorn( gradients have been stopped. 
This is useful when carrying out first order differentiation, and is only valid (as with ``implicit_differentiation``) when the algorithm has converged with a low tolerance. - jit: if True, automatically jits the function upon first call. - Should be set to False when used in a function that is jitted by the user, - or when computing gradients (in which case the gradient function - should be jitted by the user). kwargs: Additional keyword arguments (see above). Returns: diff --git a/src/ott/solvers/linear/sinkhorn_lr.py b/src/ott/solvers/linear/sinkhorn_lr.py index d402e1e80..1edc9c938 100644 --- a/src/ott/solvers/linear/sinkhorn_lr.py +++ b/src/ott/solvers/linear/sinkhorn_lr.py @@ -321,8 +321,7 @@ def __call__( assert ot_prob.is_balanced, "Unbalanced case is not implemented." initializer = self.create_initializer(ot_prob) init = initializer(ot_prob, *init, key=key, **kwargs) - run_fn = jax.jit(run) if self.jit else run - return run_fn(ot_prob, self, init) + return run(ot_prob, self, init) def _lr_costs( self, @@ -677,7 +676,6 @@ def make( max_iterations: int = 2000, use_danskin: bool = True, implicit_diff: bool = False, - jit: bool = True, kwargs_dys: Optional[Mapping[str, Any]] = None ) -> LRSinkhorn: return LRSinkhorn( @@ -693,6 +691,5 @@ def make( max_iterations=max_iterations, use_danskin=use_danskin, implicit_diff=implicit_diff, - jit=jit, kwargs_dys=kwargs_dys ) diff --git a/src/ott/solvers/nn/icnn.py b/src/ott/solvers/nn/icnn.py index dc121ca90..bbb608433 100644 --- a/src/ott/solvers/nn/icnn.py +++ b/src/ott/solvers/nn/icnn.py @@ -40,6 +40,7 @@ class ICNN(nn.Module): :cite:`amos:17` with initialization schemes proposed by :cite:`bunne:22`. Args: + dim_data: data dimensionality. dim_hidden: sequence specifying size of hidden dimensions. The output dimension of the last layer is 1 by default. init_std: value of standard deviation of weight initialization method. @@ -48,18 +49,17 @@ class ICNN(nn.Module): act_fn: choice of activation function used in network architecture (needs to be convex, default: `nn.relu`). pos_weights: choice to enforce positivity of weight or use regularizer. - dim_data: data dimensionality (default: 2). gaussian_map: data inputs of source and target measures for initialization scheme based on Gaussian approximation of input and target measure (if None, identity initialization is used). 
""" + dim_data: int dim_hidden: Sequence[int] init_std: float = 1e-1 init_fn: Callable = jax.nn.initializers.normal act_fn: Callable = nn.relu pos_weights: bool = True - dim_data: int = 2 gaussian_map: Tuple[jnp.ndarray, jnp.ndarray] = None def setup(self) -> None: diff --git a/src/ott/solvers/nn/neuraldual.py b/src/ott/solvers/nn/neuraldual.py index aa8eb18a9..6a47e591a 100644 --- a/src/ott/solvers/nn/neuraldual.py +++ b/src/ott/solvers/nn/neuraldual.py @@ -92,9 +92,9 @@ def __init__( # set default neural architectures if neural_f is None: - neural_f = icnn.ICNN(dim_hidden=[64, 64, 64, 64]) + neural_f = icnn.ICNN(dim_data=input_dim, dim_hidden=[64, 64, 64, 64]) if neural_g is None: - neural_g = icnn.ICNN(dim_hidden=[64, 64, 64, 64]) + neural_g = icnn.ICNN(dim_data=input_dim, dim_hidden=[64, 64, 64, 64]) # set optimizer and networks self.setup(rng, neural_f, neural_g, input_dim, optimizer_f, optimizer_g) diff --git a/src/ott/solvers/quadratic/gromov_wasserstein.py b/src/ott/solvers/quadratic/gromov_wasserstein.py index 59b1994c6..cb2a2864c 100644 --- a/src/ott/solvers/quadratic/gromov_wasserstein.py +++ b/src/ott/solvers/quadratic/gromov_wasserstein.py @@ -214,8 +214,7 @@ def __call__( initializer = self.create_initializer(prob) init = initializer(prob, epsilon=self.epsilon, key=key1, **kwargs) - gromov_fn = jax.jit(iterations) if self.jit else iterations - out = gromov_fn(self, prob, init, key2) + out = iterations(self, prob, init, key2) # TODO(lpapaxanthos): remove stop_gradient when using backprop if self.is_low_rank: linearization = prob.update_lr_linearization( @@ -395,7 +394,6 @@ def make( epsilon: Union[epsilon_scheduler.Epsilon, float] = 1., rank: int = -1, max_iterations: int = 50, - jit: bool = False, warm_start: Optional[bool] = None, store_inner_errors: bool = False, linear_ot_solver_kwargs: Optional[Mapping[str, Any]] = None, @@ -410,7 +408,6 @@ def make( rank: integer used to constrain the rank of GW solutions if >0. max_iterations: the maximum number of outer iterations for Gromov Wasserstein. - jit: bool, if True, jits the function. warm_start: Whether to initialize (low-rank) Sinkhorn calls using values from the previous iteration. If `None`, it's enabled when using low-rank. store_inner_errors: whether or not to return all the errors of the inner @@ -449,7 +446,6 @@ def make( threshold=threshold, min_iterations=min_iterations, max_iterations=max_iterations, - jit=jit, store_inner_errors=store_inner_errors, warm_start=warm_start, **kwargs @@ -478,9 +474,9 @@ def gromov_wasserstein( if the problem is fused) and calls a solver to output a solution. Args: - geom_xx: a Geometry object for the first view. - geom_yy: a second Geometry object for the second view. - geom_xy: a Geometry object representing the linear cost in FGW. + geom_xx: Geometry for the first view. + geom_yy: Geometry for the second view. + geom_xy: Geometry representing the linear cost in FGW. fused_penalty: multiplier of the linear term in Fused Gromov Wasserstein, i.e. loss = quadratic_loss + fused_penalty * linear_loss. Ignored if ``geom_xy`` is not specified. diff --git a/src/ott/solvers/quadratic/gw_barycenter.py b/src/ott/solvers/quadratic/gw_barycenter.py index f88e175cd..4360b3f6f 100644 --- a/src/ott/solvers/quadratic/gw_barycenter.py +++ b/src/ott/solvers/quadratic/gw_barycenter.py @@ -52,7 +52,6 @@ class GromovWassersteinBarycenter(was_solver.WassersteinSolver): min_iterations: Minimum number of iterations. max_iterations: Maximum number of outermost iterations. threshold: Convergence threshold. 
- jit: Whether to jit the iteration loop. store_inner_errors: Whether to store the errors of the GW solver, as well as its linear solver, at each iteration for each measure. quad_solver: The GW solver. @@ -67,7 +66,6 @@ def __init__( min_iterations: int = 5, max_iterations: int = 50, threshold: float = 1e-3, - jit: bool = True, store_inner_errors: bool = False, quad_solver: Optional[gromov_wasserstein.GromovWasserstein] = None, # TODO(michalk8): maintain the API compatibility with `was_solver` @@ -81,7 +79,6 @@ def __init__( min_iterations=min_iterations, max_iterations=max_iterations, threshold=threshold, - jit=jit, store_inner_errors=store_inner_errors ) self._quad_solver = quad_solver @@ -105,9 +102,8 @@ def __call__( Returns: The solution. """ - bar_fn = jax.jit(iterations, static_argnums=1) if self.jit else iterations state = self.init_state(problem, bar_size, **kwargs) - state = bar_fn(solver=self, problem=problem, init_state=state) + state = iterations(solver=self, problem=problem, init_state=state) return self.output_from_state(state) def init_state( diff --git a/src/ott/solvers/was_solver.py b/src/ott/solvers/was_solver.py index 5d9fd4913..38b5fa7a7 100644 --- a/src/ott/solvers/was_solver.py +++ b/src/ott/solvers/was_solver.py @@ -40,7 +40,6 @@ def __init__( min_iterations: int = 5, max_iterations: int = 50, threshold: float = 1e-3, - jit: bool = True, store_inner_errors: bool = False, **kwargs: Any, ): @@ -58,7 +57,7 @@ def __init__( if epsilon is None: # Use default entropic regularization in LRSinkhorn if None was passed self.linear_ot_solver = sinkhorn_lr.LRSinkhorn( - rank=self.rank, jit=False, **kwargs + rank=self.rank, **kwargs ) else: # If epsilon is passed, use it to replace the default LRSinkhorn value @@ -67,14 +66,13 @@ def __init__( ) else: # When using Entropic GW, epsilon is not handled inside Sinkhorn, - # but rather added back to the Geometry object reinstantiated - # when linearizing the problem. Therefore no need to pass it to solver. + # but rather added back to the Geometry object re-instantiated + # when linearizing the problem. Therefore, no need to pass it to solver. 
self.linear_ot_solver = sinkhorn.Sinkhorn(**kwargs) self.min_iterations = min_iterations self.max_iterations = max_iterations self.threshold = threshold - self.jit = jit self.store_inner_errors = store_inner_errors self._kwargs = kwargs @@ -87,7 +85,6 @@ def tree_flatten(self) -> Tuple[Sequence[Any], Dict[str, Any]]: return ([self.epsilon, self.linear_ot_solver, self.threshold], { "min_iterations": self.min_iterations, "max_iterations": self.max_iterations, - "jit": self.jit, "rank": self.rank, "store_inner_errors": self.store_inner_errors, **self._kwargs diff --git a/tests/geometry/subsetting_test.py b/tests/geometry/subsetting_test.py index 2369e4b28..fe102d353 100644 --- a/tests/geometry/subsetting_test.py +++ b/tests/geometry/subsetting_test.py @@ -112,10 +112,15 @@ def test_masked_summary( ): geom, masked = geom_masked if stat == "mean": - np.testing.assert_allclose(geom.mean_cost_matrix, masked.mean_cost_matrix) + np.testing.assert_allclose( + geom.mean_cost_matrix, masked.mean_cost_matrix, rtol=1e-6, atol=1e-6 + ) else: np.testing.assert_allclose( - geom.median_cost_matrix, masked.median_cost_matrix + geom.median_cost_matrix, + masked.median_cost_matrix, + rtol=1e-6, + atol=1e-6, ) def test_mask_permutation( @@ -173,9 +178,14 @@ def test_subset_mask( assert geom.src_mask.shape == (geom.shape[0],) assert geom.tgt_mask.shape == (geom.shape[1],) - np.testing.assert_allclose(geom.mean_cost_matrix, masked.mean_cost_matrix) np.testing.assert_allclose( - geom.median_cost_matrix, masked.median_cost_matrix + geom.mean_cost_matrix, masked.mean_cost_matrix, rtol=1e-6, atol=1e-6 + ) + np.testing.assert_allclose( + geom.median_cost_matrix, + masked.median_cost_matrix, + rtol=1e-6, + atol=1e-6 ) np.testing.assert_allclose( geom.cost_matrix, masked.cost_matrix, rtol=1e-6, atol=1e-6 diff --git a/tests/initializers/linear/sinkhorn_init_test.py b/tests/initializers/linear/sinkhorn_init_test.py index 9b073a7ff..4741f4d9b 100644 --- a/tests/initializers/linear/sinkhorn_init_test.py +++ b/tests/initializers/linear/sinkhorn_init_test.py @@ -81,7 +81,7 @@ def run_sinkhorn_sort_init( geom = pointcloud.PointCloud(x, y, epsilon=epsilon) sort_init = lin_init.SortingInitializer(vectorized_update=vector_min) out = sinkhorn.sinkhorn( - geom, a=a, b=b, jit=True, initializer=sort_init, lse_mode=lse_mode + geom, a=a, b=b, initializer=sort_init, lse_mode=lse_mode ) return out @@ -89,7 +89,7 @@ def run_sinkhorn_sort_init( @functools.partial(jax.jit, static_argnames=['lse_mode']) def run_sinkhorn(x, y, a=None, b=None, epsilon=0.01, lse_mode=True): geom = pointcloud.PointCloud(x, y, epsilon=epsilon) - out = sinkhorn.sinkhorn(geom, a=a, b=b, jit=True, lse_mode=lse_mode) + out = sinkhorn.sinkhorn(geom, a=a, b=b, lse_mode=lse_mode) return out @@ -100,7 +100,6 @@ def run_sinkhorn_gaus_init(x, y, a=None, b=None, epsilon=0.01, lse_mode=True): geom, a=a, b=b, - jit=True, initializer=lin_init.GaussianInitializer(), lse_mode=lse_mode ) @@ -307,7 +306,7 @@ def test_meta_initializer(self, lse_mode, rng: jnp.ndarray): ) base_num_iter = jnp.sum(sink_out.errors > -1) - # Overfit the initializer to the problem. + # overfit the initializer to the problem. 
meta_initializer = ott.initializers.nn.initializers.MetaInitializer(geom) for _ in range(100): _, _, meta_initializer.state = meta_initializer.update( @@ -315,12 +314,7 @@ def test_meta_initializer(self, lse_mode, rng: jnp.ndarray): ) sink_out = sinkhorn.sinkhorn( - geom, - a=a, - b=b, - jit=True, - initializer=meta_initializer, - lse_mode=lse_mode + geom, a=a, b=b, initializer=meta_initializer, lse_mode=lse_mode ) meta_num_iter = jnp.sum(sink_out.errors > -1) diff --git a/tests/solvers/linear/continuous_barycenter_test.py b/tests/solvers/linear/continuous_barycenter_test.py index 20f2ff4cd..c16e7ff13 100644 --- a/tests/solvers/linear/continuous_barycenter_test.py +++ b/tests/solvers/linear/continuous_barycenter_test.py @@ -39,18 +39,18 @@ class TestBarycenter: @pytest.mark.fast.with_args( rank=[-1, 6], epsilon=[1e-1, 1e-2], - jit=[True, False], init_random=[True, False], + jit=[False, True], only_fast={ "rank": -1, "epsilon": 1e-1, - "jit": True, - "init_random": False + "init_random": False, + "jit": False, }, ) def test_euclidean_barycenter( - self, rng: jnp.ndarray, rank: int, epsilon: float, jit: bool, - init_random: bool + self, rng: jnp.ndarray, rank: int, epsilon: float, init_random: bool, + jit: bool ): rngs = jax.random.split(rng, 20) # Sample 2 point clouds, each of size 113, the first around [0,1]^4, @@ -83,7 +83,9 @@ def test_euclidean_barycenter( # Define solver threshold = 1e-3 - solver = cb.WassersteinBarycenter(rank=rank, threshold=threshold, jit=jit) + solver = cb.WassersteinBarycenter(rank=rank, threshold=threshold) + if jit: + solver = jax.jit(solver, static_argnames="bar_size") # Set barycenter size to 31. bar_size = 31 @@ -170,12 +172,16 @@ def barycenter( jit=[False, True], only_fast={ "lse_mode": True, + "jit": False, "epsilon": 1e-1, - "jit": False } ) def test_bures_barycenter( - self, rng: jnp.ndarray, lse_mode: bool, epsilon: float, jit: bool + self, + rng: jnp.ndarray, + lse_mode: bool, + epsilon: float, + jit: bool, ): num_measures = 2 num_components = 2 @@ -230,27 +236,23 @@ def test_bures_barycenter( assert bar_p.max_measure_size == seg_y.shape[1] assert bar_p.ndim == seg_y.shape[2] - solver = cb.WassersteinBarycenter(lse_mode=lse_mode, jit=jit) + solver = cb.WassersteinBarycenter(lse_mode=lse_mode) + if jit: + solver = jax.jit(solver, static_argnames="bar_size") out = solver(bar_p, bar_size=bar_size, x_init=x_init) barycenter = out.x means_bary, covs_bary = costs.x_to_means_and_covs(barycenter, dimension) - assert jnp.logical_or( - jnp.allclose( - means_bary, - jnp.array([[0., 1.], [0., -1.]]), - rtol=1e-02, - atol=1e-02 - ), - jnp.allclose( - means_bary, - jnp.array([[0., -1.], [0., 1.]]), - rtol=1e-02, - atol=1e-02 - ) - ) + try: + np.testing.assert_allclose( + means_bary, jnp.array([[0., 1.], [0., -1.]]), rtol=1e-02, atol=1e-02 + ) + except AssertionError: + np.testing.assert_allclose( + means_bary, jnp.array([[0., -1.], [0., 1.]]), rtol=1e-02, atol=1e-02 + ) np.testing.assert_allclose( covs_bary, @@ -262,17 +264,22 @@ def test_bures_barycenter( @pytest.mark.fast.with_args( alpha=[50., 1.], epsilon=[1e-2, 1e-1], - jit=[False, True], dim=[4, 10], + jit=[False, True], only_fast={ "alpha": 50, "epsilon": 1e-1, - "jit": False, - "dim": 4 + "dim": 4, + "jit": False } ) def test_bures_barycenter_different_number_of_components( - self, rng: jnp.ndarray, dim: int, alpha: float, epsilon: float, jit: bool + self, + rng: jnp.ndarray, + alpha: float, + epsilon: float, + dim: int, + jit: bool, ): n_components = jnp.array([3, 4]) # the number of components of the GMMs 
num_measures = n_components.size @@ -351,7 +358,9 @@ def test_bures_barycenter_different_number_of_components( assert bar_p.num_measures == num_measures assert bar_p.ndim == ys.shape[-1] - solver = cb.WassersteinBarycenter(lse_mode=True, jit=jit) + solver = cb.WassersteinBarycenter(lse_mode=True) + if jit: + solver = jax.jit(solver, static_argnames="bar_size") # Compute the barycenter. out = solver(bar_p, bar_size=bar_size, x_init=x_init) diff --git a/tests/solvers/linear/sinkhorn_diff_test.py b/tests/solvers/linear/sinkhorn_diff_test.py index c4b1dd2d1..4c17c3be8 100644 --- a/tests/solvers/linear/sinkhorn_diff_test.py +++ b/tests/solvers/linear/sinkhorn_diff_test.py @@ -159,7 +159,7 @@ def test_autograd_sinkhorn( a = a / jnp.sum(a) b = b / jnp.sum(b) - def reg_ot(a, b): + def reg_ot(a: jnp.ndarray, b: jnp.ndarray) -> float: return sinkhorn.sinkhorn( pointcloud.PointCloud(x, y, epsilon=0.1), a=a, b=b, lse_mode=lse_mode ).reg_ot_cost @@ -267,7 +267,8 @@ def test_gradient_sinkhorn_euclidean( # Adding some near-zero distances to test proper handling with p_norm=1. y = y.at[0].set(x[0, :] + 1e-3) - def loss_fn(x, y): + def loss_fn(x: jnp.ndarray, + y: jnp.ndarray) -> Tuple[float, sinkhorn.SinkhornOutput]: geom = pointcloud.PointCloud(x, y, epsilon=epsilon, cost_fn=cost_fn) out = sinkhorn.sinkhorn( geom, @@ -277,7 +278,6 @@ def loss_fn(x, y): implicit_differentiation=implicit_differentiation, min_iterations=min_iter, max_iterations=max_iter, - jit=False ) return out.reg_ot_cost, out @@ -332,13 +332,13 @@ def reg_ot_cost(c: jnp.ndarray) -> float: @pytest.mark.fast def test_differentiability_with_jit(self, rng: jnp.ndarray): - cost = jax.random.uniform(rng, (15, 17)) def reg_ot_cost(c: jnp.ndarray) -> float: geom = geometry.Geometry(c, epsilon=1e-2) - return sinkhorn.sinkhorn(geom, jit=True).reg_ot_cost + return sinkhorn.sinkhorn(geom).reg_ot_cost - gradient = jax.grad(reg_ot_cost)(cost) + cost = jax.random.uniform(rng, (15, 17)) + gradient = jax.jit(jax.grad(reg_ot_cost))(cost) np.testing.assert_array_equal(jnp.isnan(gradient), False) @pytest.mark.fast.with_args( diff --git a/tests/solvers/linear/sinkhorn_grid_test.py b/tests/solvers/linear/sinkhorn_grid_test.py index dd5488d02..4684fe2eb 100644 --- a/tests/solvers/linear/sinkhorn_grid_test.py +++ b/tests/solvers/linear/sinkhorn_grid_test.py @@ -40,7 +40,11 @@ def test_separable_grid(self, rng: jnp.ndarray, lse_mode: bool): threshold = 0.01 geom = grid.Grid(grid_size=grid_size, epsilon=0.1) errors = sinkhorn.sinkhorn( - geom, a=a, b=b, threshold=threshold, lse_mode=lse_mode, jit=False + geom, + a=a, + b=b, + threshold=threshold, + lse_mode=lse_mode, ).errors err = errors[jnp.isfinite(errors)][-1] assert threshold > err @@ -63,7 +67,10 @@ def test_grid_vs_euclidean(self, rng: jnp.ndarray, lse_mode: bool): ]).transpose() geometry_mat = pointcloud.PointCloud(xyz, xyz, epsilon=epsilon) out_mat = sinkhorn.sinkhorn( - geometry_mat, a=a, b=b, lse_mode=lse_mode, jit=False + geometry_mat, + a=a, + b=b, + lse_mode=lse_mode, ) out_grid = sinkhorn.sinkhorn(geometry_grid, a=a, b=b, lse_mode=lse_mode) np.testing.assert_allclose( diff --git a/tests/solvers/linear/sinkhorn_misc_test.py b/tests/solvers/linear/sinkhorn_misc_test.py index d93e985ce..24d5dd9b6 100644 --- a/tests/solvers/linear/sinkhorn_misc_test.py +++ b/tests/solvers/linear/sinkhorn_misc_test.py @@ -12,8 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
"""Tests Anderson acceleration for Sinkhorn.""" -import functools -from typing import Callable, Tuple +from typing import Tuple import chex import jax @@ -24,8 +23,6 @@ from ott.geometry import costs, geometry, pointcloud from ott.solvers.linear import sinkhorn -non_jitted_sinkhorn = functools.partial(sinkhorn.sinkhorn, jit=False) - class TestSinkhornAnderson: """Tests for Anderson acceleration.""" @@ -227,8 +224,8 @@ def test_online_matches_offline_size(self, batch_size: int): sol_online.b, sol_offline.b, rtol=rtol, atol=atol ) - @pytest.mark.parametrize("outer_jit", [False, True]) - def test_online_sinkhorn_jit(self, outer_jit: bool): + @pytest.mark.parametrize("jit", [False, True]) + def test_online_sinkhorn_jit(self, jit: bool): def callback(epsilon: float, batch_size: int) -> sinkhorn.SinkhornOutput: geom = pointcloud.PointCloud( @@ -239,13 +236,12 @@ def callback(epsilon: float, batch_size: int) -> sinkhorn.SinkhornOutput: a=self.a, b=self.b, threshold=threshold, - jit=True, lse_mode=True, implicit_differentiation=True ) threshold = 1e-1 - fun = jax.jit(callback, static_argnums=(1,)) if outer_jit else callback + fun = jax.jit(callback, static_argnums=(1,)) if jit else callback errors = fun(epsilon=1.0, batch_size=42).errors err = errors[errors > -1][-1] @@ -322,39 +318,31 @@ def initialize(self, rng: jnp.ndarray): def test_jit_vs_non_jit_fwd(self): def assert_output_close(x: jnp.ndarray, y: jnp.ndarray) -> None: - """Asserst SinkhornOutputs are close.""" + """Assert SinkhornOutputs are close.""" x = tuple(a for a in x if (a is not None and isinstance(a, jnp.ndarray))) y = tuple(a for a in y if (a is not None and isinstance(a, jnp.ndarray))) return chex.assert_tree_all_close(x, y, atol=1e-6, rtol=0) - def f( - g: geometry.Geometry, a: jnp.ndarray, b: jnp.ndarray - ) -> sinkhorn.SinkhornOutput: - return non_jitted_sinkhorn(g, a, b) - - jitted_result = sinkhorn.sinkhorn(self.geometry, self.a, self.b) - non_jitted_result = non_jitted_sinkhorn(self.geometry, self.a, self.b) - user_jitted_result = jax.jit(f)(self.geometry, self.a, self.b) + jitted_result = jax.jit(sinkhorn.sinkhorn)(self.geometry, self.a, self.b) + non_jitted_result = sinkhorn.sinkhorn(self.geometry, self.a, self.b) - assert_output_close(jitted_result, non_jitted_result) - assert_output_close(jitted_result, user_jitted_result) + assert_output_close(non_jitted_result, jitted_result) @pytest.mark.parametrize("implicit", [False, True]) def test_jit_vs_non_jit_bwd(self, implicit: bool): - def loss( - a: jnp.ndarray, x: jnp.ndarray, fun: Callable[..., - sinkhorn.SinkhornOutput] - ): - out = fun( - geometry.Geometry( - cost_matrix=( - jnp.sum(x ** 2, axis=1)[:, jnp.newaxis] + - jnp.sum(self.y ** 2, axis=1)[jnp.newaxis, :] - - 2 * jnp.dot(x, self.y.T) - ), - epsilon=self.epsilon + @jax.value_and_grad + def val_grad(a: jnp.ndarray, x: jnp.ndarray): + geom = geometry.Geometry( + cost_matrix=( + jnp.sum(x ** 2, axis=1)[:, jnp.newaxis] + + jnp.sum(self.y ** 2, axis=1)[jnp.newaxis, :] - + 2 * jnp.dot(x, self.y.T) ), + epsilon=self.epsilon + ) + out = sinkhorn.sinkhorn( + geom, a=a, b=self.b, tau_a=0.94, @@ -365,18 +353,8 @@ def loss( ) return out.reg_ot_cost - def value_and_grad(a: jnp.ndarray, x: jnp.ndarray): - return jax.value_and_grad(loss)(a, x, non_jitted_sinkhorn) - - jitted_loss, jitted_grad = jax.value_and_grad(loss)( - self.a, self.x, sinkhorn.sinkhorn - ) - non_jitted_loss, non_jitted_grad = jax.value_and_grad(loss)( - self.a, self.x, non_jitted_sinkhorn - ) + jitted_loss, jitted_grad = jax.jit(val_grad)(self.a, self.x) + 
non_jitted_loss, non_jitted_grad = val_grad(self.a, self.x) - user_jitted_loss, user_jitted_grad = jax.jit(value_and_grad)(self.a, self.x) - chex.assert_tree_all_close(jitted_loss, non_jitted_loss, atol=1e-6, rtol=0) - chex.assert_tree_all_close(jitted_grad, non_jitted_grad, atol=1e-6, rtol=0) - chex.assert_tree_all_close(user_jitted_loss, jitted_loss, atol=1e-6, rtol=0) - chex.assert_tree_all_close(user_jitted_grad, jitted_grad, atol=1e-6, rtol=0) + chex.assert_tree_all_close(jitted_loss, non_jitted_loss, atol=1e-6, rtol=0.) + chex.assert_tree_all_close(jitted_grad, non_jitted_grad, atol=1e-6, rtol=0.) diff --git a/tests/solvers/linear/sinkhorn_test.py b/tests/solvers/linear/sinkhorn_test.py index 162feb84d..f5e362a46 100644 --- a/tests/solvers/linear/sinkhorn_test.py +++ b/tests/solvers/linear/sinkhorn_test.py @@ -85,14 +85,18 @@ def test_autoepsilon(self): # First geom specifies explicitly relative_epsilon to be True. This is not # needed in principle, but introduced here to test logic. geom_1 = pointcloud.PointCloud(self.x, self.y, relative_epsilon=True) - # jit first with jit inside sinkhorn call. + # not jitting f_1 = sinkhorn.sinkhorn( - geom_1, a=self.a, b=self.b, tau_a=.99, tau_b=.97, jit=True + geom_1, + a=self.a, + b=self.b, + tau_a=.99, + tau_b=.97, ).f # Second geom does not provide whether epsilon is relative. geom_2 = pointcloud.PointCloud(scale * self.x, scale * self.y) - # jit now with jit outside sinkhorn call. + # jitting compute_f = jax.jit( lambda g, a, b: sinkhorn.sinkhorn(g, a, b, tau_a=.99, tau_b=.97).f ) @@ -125,38 +129,33 @@ def test_autoepsilon_with_decay( tau_b: float ): """Check that variations in init/decay work, and result in same solution.""" - geom = pointcloud.PointCloud(self.x, self.y, init=init, decay=decay) - out_1 = sinkhorn.sinkhorn( - geom, - a=self.a, - b=self.b, - tau_a=tau_a, - tau_b=tau_b, - jit=True, - lse_mode=lse_mode, - threshold=1e-5 - ) - geom = pointcloud.PointCloud(self.x, self.y) - out_2 = sinkhorn.sinkhorn( - geom, - a=self.a, - b=self.b, - tau_a=tau_a, - tau_b=tau_b, - jit=True, - lse_mode=lse_mode, - threshold=1e-5 - ) + @jax.jit + def run_sinkhorn(geom: pointcloud.PointCloud) -> sinkhorn.SinkhornOutput: + return sinkhorn.sinkhorn( + geom, + a=self.a, + b=self.b, + tau_a=tau_a, + tau_b=tau_b, + lse_mode=lse_mode, + threshold=1e-5 + ) + + geom1 = pointcloud.PointCloud(self.x, self.y, init=init, decay=decay) + geom2 = pointcloud.PointCloud(self.x, self.y) + out_1 = run_sinkhorn(geom1) + out_2 = run_sinkhorn(geom2) # recenter if problem is balanced, since in that case solution is only # valid up to additive constant. 
- unb = (tau_a < 1.0 or tau_b < 1.0) - np.testing.assert_allclose( - out_1.f if unb else out_1.f - jnp.mean(out_1.f[jnp.isfinite(out_1.f)]), - out_2.f if unb else out_2.f - jnp.mean(out_2.f[jnp.isfinite(out_2.f)]), - rtol=1e-4, - atol=1e-4 - ) + if out_1.ot_prob.is_balanced: + # TODO(michalk8): remove after https://github.com/ott-jax/ott/pull/194 + f_1 = out_1.f - jnp.mean(out_1.f[jnp.isfinite(out_1.f)]) + f_2 = out_2.f - jnp.mean(out_2.f[jnp.isfinite(out_2.f)]) + else: + f_1, f_2 = out_1.f, out_2.f + + np.testing.assert_allclose(f_1, f_2, rtol=1e-4, atol=1e-4) @pytest.mark.fast def test_euclidean_point_cloud_min_iter(self): diff --git a/tests/solvers/nn/icnn_test.py b/tests/solvers/nn/icnn_test.py index f7f64127a..d3b0af292 100644 --- a/tests/solvers/nn/icnn_test.py +++ b/tests/solvers/nn/icnn_test.py @@ -30,7 +30,7 @@ def test_icnn_convexity(self, rng: jnp.ndarray): dim_hidden = (64, 64) # define icnn model - model = icnn.ICNN(dim_hidden) + model = icnn.ICNN(n_features, dim_hidden=dim_hidden) # initialize model key1, key2, key3 = jax.random.split(rng, 3) @@ -54,16 +54,16 @@ def test_icnn_hessian(self, rng: jnp.ndarray): """Tests if Hessian of ICNN is positive-semidefinite.""" # define icnn model - n_samples = 2 + n_features = 2 dim_hidden = (64, 64) - model = icnn.ICNN(dim_hidden) + model = icnn.ICNN(n_features, dim_hidden=dim_hidden) # initialize model key1, key2 = jax.random.split(rng) - params = model.init(key1, jnp.ones(n_samples))['params'] + params = model.init(key1, jnp.ones(n_features))['params'] # check if Hessian is positive-semidefinite via eigenvalues - data = jax.random.normal(key2, (n_samples,)) + data = jax.random.normal(key2, (n_features,)) # compute Hessian hessian = jax.jacfwd(jax.jacrev(model.apply, argnums=1), argnums=1) diff --git a/tests/solvers/quadratic/fgw_barycenter_test.py b/tests/solvers/quadratic/fgw_barycenter_test.py index d3dca9ad6..2c9cc4268 100644 --- a/tests/solvers/quadratic/fgw_barycenter_test.py +++ b/tests/solvers/quadratic/fgw_barycenter_test.py @@ -45,7 +45,7 @@ def barycenter( assert prob.ndim_fused == self.ndim_f solver = gwb_solver.GromovWassersteinBarycenter( - jit=False, store_inner_errors=True, epsilon=epsilon + store_inner_errors=True, epsilon=epsilon ) x_init = jax.random.normal(rng, (bar_size, self.ndim_f)) diff --git a/tests/solvers/quadratic/fgw_test.py b/tests/solvers/quadratic/fgw_test.py index 5256613c9..8d1d33f03 100644 --- a/tests/solvers/quadratic/fgw_test.py +++ b/tests/solvers/quadratic/fgw_test.py @@ -84,7 +84,7 @@ def test_fgw_flag_store_errors_fused(self): assert threshold_sinkhorn > last_errors[last_errors > -1][-1] assert out.ndim == 2 - @pytest.mark.fast.with_args(jit=[False, True], only_fast=1) + @pytest.mark.fast.with_args("jit", [False, True], only_fast=0) def test_gradient_marginals_fgw_solver(self, jit: bool): """Test gradient w.r.t. 
probability weights.""" geom_x = pointcloud.PointCloud(self.x) @@ -92,7 +92,7 @@ def test_gradient_marginals_fgw_solver(self, jit: bool): geom_xy = pointcloud.PointCloud(self.x_2, self.y_2) fused_penalty = self.fused_penalty - def reg_gw(a, b, implicit): + def reg_gw(a: jnp.ndarray, b: jnp.ndarray, implicit: bool): sinkhorn_kwargs = { 'implicit_differentiation': implicit, 'max_iterations': 1001 @@ -107,7 +107,6 @@ def reg_gw(a, b, implicit): epsilon=1.0, loss='sqeucl', max_iterations=10, - jit=jit, sinkhorn_kwargs=sinkhorn_kwargs ) return out.reg_gw_cost, (out.linear_state.f, out.linear_state.g) @@ -115,6 +114,9 @@ def reg_gw(a, b, implicit): grad_matrices = [None, None] for i, implicit in enumerate([True, False]): reg_gw_and_grad = jax.value_and_grad(reg_gw, has_aux=True, argnums=(0, 1)) + if jit: + reg_gw_and_grad = jax.jit(reg_gw_and_grad, static_argnames="implicit") + (_, aux), grad_reg_gw = reg_gw_and_grad(self.a, self.b, implicit) grad_matrices[i] = grad_reg_gw grad_manual_a = aux[0] - jnp.log(self.a) @@ -365,8 +367,15 @@ def test_fgw_lr_memory(self, rng: jnp.ndarray, jit: bool): geom_y = pointcloud.PointCloud(y) geom_xy = pointcloud.PointCloud(xx, yy) - ot_gwlr = gw_solver.gromov_wasserstein( - geom_x, geom_y, geom_xy, rank=5, jit=jit + solver = gw_solver.gromov_wasserstein + if jit: + solver = jax.jit(solver, static_argnames="rank") + + ot_gwlr = solver( + geom_x, + geom_y, + geom_xy, + rank=5, ) res0 = ot_gwlr.apply(x.T, axis=0) res1 = ot_gwlr.apply(y.T, axis=1) diff --git a/tests/solvers/quadratic/gw_barycenter_test.py b/tests/solvers/quadratic/gw_barycenter_test.py index f824955c8..d6b991a4b 100644 --- a/tests/solvers/quadratic/gw_barycenter_test.py +++ b/tests/solvers/quadratic/gw_barycenter_test.py @@ -38,7 +38,7 @@ def random_pc( key1, key2 = jax.random.split(rng, 2) x = jax.random.normal(key1, (n, d)) y = x if m is None else jax.random.normal(key2, (m, d)) - return pointcloud.PointCloud(x, y, batch_size=None, **kwargs) + return pointcloud.PointCloud(x, y, **kwargs) @staticmethod def pad_cost_matrices( @@ -100,7 +100,9 @@ def test_gw_barycenter( assert problem_pc.ndim == self.ndim assert problem_cost.ndim is None - solver = gwb_solver.GromovWassersteinBarycenter(jit=True) + solver = jax.jit( + gwb_solver.GromovWassersteinBarycenter(), static_argnames="bar_size" + ) out_pc = solver(problem_pc, bar_size=bar_size) out_cost = solver(problem_cost, bar_size=bar_size) diff --git a/tests/solvers/quadratic/gw_test.py b/tests/solvers/quadratic/gw_test.py index 5ec47d2e5..1aa5609ca 100644 --- a/tests/solvers/quadratic/gw_test.py +++ b/tests/solvers/quadratic/gw_test.py @@ -152,10 +152,8 @@ def test_flag_store_errors(self): @pytest.mark.parametrize("jit", [False, True]) def test_gradient_marginals_gw(self, jit: bool): """Test gradient w.r.t. 
probability weights.""" - geom_x = pointcloud.PointCloud(self.x) - geom_y = pointcloud.PointCloud(self.y) - def reg_gw(a, b, implicit): + def reg_gw(a: jnp.ndarray, b: jnp.ndarray, implicit: bool): sinkhorn_kwargs = { 'implicit_differentiation': implicit, 'max_iterations': 1001 @@ -168,14 +166,19 @@ def reg_gw(a, b, implicit): epsilon=1.0, loss='sqeucl', max_iterations=10, - jit=jit, sinkhorn_kwargs=sinkhorn_kwargs ) return out.reg_gw_cost, (out.linear_state.f, out.linear_state.g) + geom_x = pointcloud.PointCloud(self.x) + geom_y = pointcloud.PointCloud(self.y) + grad_matrices = [None, None] for i, implicit in enumerate([True, False]): reg_gw_and_grad = jax.value_and_grad(reg_gw, has_aux=True, argnums=(0, 1)) + if jit: + reg_gw_and_grad = jax.jit(reg_gw_and_grad, static_argnames="implicit") + (_, aux), grad_reg_gw = reg_gw_and_grad(self.a, self.b, implicit) grad_matrices[i] = grad_reg_gw grad_manual_a = aux[0] - jnp.log(self.a) diff --git a/tests/tools/gaussian_mixture/probabilities_test.py b/tests/tools/gaussian_mixture/probabilities_test.py index 18a0e2467..5ea9d2fbd 100644 --- a/tests/tools/gaussian_mixture/probabilities_test.py +++ b/tests/tools/gaussian_mixture/probabilities_test.py @@ -28,7 +28,7 @@ def test_probs(self): pp = probabilities.Probabilities(jnp.array([1., 2.])) probs = pp.probs() np.testing.assert_array_equal(probs.shape, (3,)) - np.testing.assert_allclose(jnp.sum(probs), 1.0) + np.testing.assert_allclose(jnp.sum(probs), 1.0, rtol=1e-6, atol=1e-6) np.testing.assert_array_equal(probs > 0., True) def test_log_probs(self): @@ -38,7 +38,7 @@ def test_log_probs(self): np.testing.assert_array_equal(log_probs.shape, (3,)) np.testing.assert_array_equal(probs.shape, (3,)) - np.testing.assert_allclose(jnp.sum(probs), 1.0) + np.testing.assert_allclose(jnp.sum(probs), 1.0, rtol=1e-6, atol=1e-6) np.testing.assert_array_equal(probs > 0., True) def test_from_random(self): @@ -52,7 +52,7 @@ def test_from_random(self): def test_from_probs(self): probs = jnp.array([0.1, 0.2, 0.3, 0.4]) pp = probabilities.Probabilities.from_probs(probs) - np.testing.assert_allclose(probs, pp.probs()) + np.testing.assert_allclose(probs, pp.probs(), rtol=1e-6, atol=1e-6) def test_sample(self): p = 0.4 @@ -74,4 +74,6 @@ def test_pytree_mapping(self): probs = jnp.array([0.1, 0.2, 0.3, 0.4]) pp = probabilities.Probabilities.from_probs(probs) pp_x_2 = jax.tree_map(lambda x: 2 * x, pp) - np.testing.assert_allclose(2. * pp.params, pp_x_2.params) + np.testing.assert_allclose( + 2. * pp.params, pp_x_2.params, rtol=1e-6, atol=1e-6 + )