Skip to content

Commit

Permalink
Feature/center potentials (ott-jax#194)
Browse files Browse the repository at this point in the history
* Center potentials for balanced problem

* Add potential centered test

* Fix not using `rtol/atol`
  • Loading branch information
michalk8 authored Dec 5, 2022
1 parent b07f042 commit d510223
Show file tree
Hide file tree
Showing 6 changed files with 30 additions and 14 deletions.
2 changes: 1 addition & 1 deletion src/ott/geometry/low_rank.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,7 +233,7 @@ def finalize(carry):
return max_value + self._bias

def to_LRCGeometry(
self, rank: int, tol: float = 1e-2, seed: int = 0
self, rank: int = 0, tol: float = 1e-2, seed: int = 0
) -> 'LRCGeometry':
"""Return self."""
return self
Expand Down
18 changes: 9 additions & 9 deletions src/ott/solvers/linear/sinkhorn.py
Original file line number Diff line number Diff line change
Expand Up @@ -428,14 +428,9 @@ def __init__(
self.momentum = acceleration.Momentum(
inner_iterations=self.inner_iterations
)
# Use adaptive momentum from 300th iteration. Only do so
# if error is already below threshold below.
else:
self.momentum = acceleration.Momentum(
start=300,
error_threshold=1e-2,
inner_iterations=self.inner_iterations
)
# no momentum
self.momentum = acceleration.Momentum()

self.parallel_dual_updates = parallel_dual_updates
self.initializer = initializer
Expand Down Expand Up @@ -633,8 +628,13 @@ def output_from_state(
geom = ot_prob.geom
f = state.fu if self.lse_mode else geom.potential_from_scaling(state.fu)
g = state.gv if self.lse_mode else geom.potential_from_scaling(state.gv)
errors = state.errors[:, 0]
return SinkhornOutput(f=f, g=g, errors=errors)
if ot_prob.is_balanced:
# center the potentials for numerical stability if the problem is balanced
is_finite = jnp.isfinite(f)
center = jnp.sum(jnp.where(is_finite, f, 0.)) / jnp.sum(is_finite)
f -= center
g += center
return SinkhornOutput(f=f, g=g, errors=state.errors[:, 0])

@property
def norm_error(self) -> Tuple[int, ...]:
Expand Down
6 changes: 3 additions & 3 deletions tests/solvers/linear/sinkhorn_grid_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ def test_grid_vs_euclidean(self, rng: jnp.ndarray, lse_mode: bool):
@pytest.mark.fast.with_args("lse_mode", [False, True], only_fast=1)
def test_apply_transport_grid(self, rng: jnp.ndarray, lse_mode: bool):
grid_size = (5, 6, 7)
keys = jax.random.split(rng, 3)
keys = jax.random.split(rng, 4)
a = jax.random.uniform(keys[0], grid_size)
b = jax.random.uniform(keys[1], grid_size)
a = a.ravel() / jnp.sum(a)
Expand All @@ -91,8 +91,8 @@ def test_apply_transport_grid(self, rng: jnp.ndarray, lse_mode: bool):

batch_a = 3
batch_b = 4
vec_a = jax.random.normal(keys[4], [batch_a, np.prod(np.array(grid_size))])
vec_b = jax.random.normal(keys[4], [batch_b, np.prod(grid_size)])
vec_a = jax.random.normal(keys[2], [batch_a, np.prod(np.array(grid_size))])
vec_b = jax.random.normal(keys[3], [batch_b, np.prod(grid_size)])

vec_a = vec_a / jnp.sum(vec_a, axis=1)[:, jnp.newaxis]
vec_b = vec_b / jnp.sum(vec_b, axis=1)[:, jnp.newaxis]
Expand Down
2 changes: 2 additions & 0 deletions tests/solvers/linear/sinkhorn_lr_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,9 @@ def test_euclidean_point_cloud_lr(
# Ensure cost can still be computed on different geometry.
other_geom = pointcloud.PointCloud(self.x, self.y + 0.3)
cost_other = out.transport_cost_at_geom(other_geom)
cost_other_lr = out.transport_cost_at_geom(other_geom.to_LRCGeometry())
assert cost_other > 0.0
np.testing.assert_allclose(cost_other, cost_other_lr, rtol=1e-6, atol=1e-6)

# Ensure cost is higher when using high entropy.
# (Note that for small entropy regularizers, this can be the opposite
Expand Down
12 changes: 12 additions & 0 deletions tests/solvers/linear/sinkhorn_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -516,3 +516,15 @@ def test_primal_cost_pointcloud(self, cost_fn):
atol=1e-1)
cost = jnp.sum(out.matrix * out.geom.cost_matrix)
np.testing.assert_allclose(cost, out.primal_cost, rtol=1e-5, atol=1e-5)

@pytest.mark.parametrize("lse_mode", [False, True])
def test_f_potential_is_centered(self, lse_mode: bool):
geom = pointcloud.PointCloud(self.x, self.y)
prob = linear_problem.LinearProblem(geom, a=self.a, b=self.b)
assert prob.is_balanced
solver = sinkhorn.Sinkhorn(lse_mode=lse_mode)

f = solver(prob).f
f_mean = jnp.mean(jnp.where(jnp.isfinite(f), f, 0.))

np.testing.assert_allclose(f_mean, 0., rtol=1e-6, atol=1e-6)
4 changes: 3 additions & 1 deletion tests/tools/sinkhorn_divergence_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -302,7 +302,9 @@ def test_segment_sinkhorn_different_segment_sizes(self):
pointcloud.PointCloud, x, y, epsilon=0.01
).divergence for x, y in zip((x1, x2), (y1, y2))
])
np.testing.assert_allclose(segmented_divergences, true_divergences)
np.testing.assert_allclose(
segmented_divergences, true_divergences, rtol=1e-6, atol=1e-6
)

def test_sinkhorn_divergence_segment_custom_padding(self, rng):
rngs = jax.random.split(rng, 4)
Expand Down

0 comments on commit d510223

Please sign in to comment.