Feature/graph normalized laplacian (ott-jax#191)
* Add option to normalize the Laplacian

* Add test

* [ci skip] Update docstring of `normalize`
michalk8 authored Dec 5, 2022
1 parent 5a0078f commit e6dd151
Showing 3 changed files with 78 additions and 25 deletions.
64 changes: 47 additions & 17 deletions src/ott/geometry/graph.py
@@ -40,6 +40,10 @@ class Graph(geometry.Geometry):
directed: Whether the ``graph`` is directed. If not, it will be made
undirected as :math:`G + G^T`. This parameter is ignored when directly
passing the Laplacian, which is assumed to be symmetric.
normalize: Whether to normalize the Laplacian as
:math:`L^{sym} = \left(D^+\right)^{\frac{1}{2}} L
\left(D^+\right)^{\frac{1}{2}}`, where :math:`L` is the
unnormalized Laplacian and :math:`D` the degree matrix.
tol: Relative tolerance with respect to the Hilbert metric, see
:cite:`peyre:19`, Remark 4.12. Used when iteratively updating scalings.
If negative, this option is ignored and only ``n_steps`` is used.
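For reference, a minimal dense sketch of the normalization described above; `adj` is a hypothetical symmetric adjacency matrix, and the pseudo-inverse keeps the rows of isolated nodes at zero:

```python
import jax.numpy as jnp

def sym_normalized_laplacian(adj: jnp.ndarray) -> jnp.ndarray:
    """L_sym = (D^+)^{1/2} (D - A) (D^+)^{1/2} for a symmetric adjacency A."""
    deg = adj.sum(1)
    lap = jnp.diag(deg) - adj
    # (D^+)^{1/2}: pseudo-inverse square root, 0 for isolated nodes
    inv_sqrt = jnp.where(deg > 0., 1. / jnp.sqrt(deg), 0.)
    return inv_sqrt[:, None] * lap * inv_sqrt[None, :]
```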
@@ -55,6 +59,7 @@ def __init__(
numerical_scheme: Literal["backward_euler",
"crank_nicolson"] = "backward_euler",
directed: bool = False,
normalize: bool = False,
tol: float = -1.,
**kwargs: Any
):
@@ -64,13 +69,14 @@ def __init__(
# arbitrary epsilon; can't use `None` as `mean_cost_matrix` would be used
super().__init__(epsilon=1., **kwargs)
self._graph = graph
self._laplacian = laplacian
self._lap = laplacian
self._solver: Optional[decomposition.CholeskySolver] = None

self._t = t
self.n_steps = n_steps
self.numerical_scheme = numerical_scheme
self.directed = directed
self.normalize = normalize
self._tol = tol

def apply_kernel(
@@ -179,21 +185,44 @@ def cost_matrix(self) -> jnp.ndarray:

@property
def laplacian(self) -> Union[jnp.ndarray, Sparse_t]:
"""The graph Laplacian."""
if self._laplacian is not None:
return self._laplacian
"""The (normalized) graph Laplacian."""
return self._norm_laplacian if self.normalize else self._laplacian

if self.is_sparse:
n, _ = self.shape
D, ixs = self.graph.sum(1).todense(), jnp.arange(n)
D = jesp.BCOO((D, jnp.c_[ixs, ixs]), shape=(n, n))
else:
D = jnp.diag(self.graph.sum(1))
def _degree_matrix(self,
*,
inv_sqrt: bool = False) -> Union[jnp.ndarray, Sparse_t]:
if not self.is_sparse:
data = self.graph.sum(1)
if inv_sqrt:
data = jnp.where(data > 0., 1. / jnp.sqrt(data), 0.)
return jnp.diag(data)

n, _ = self.shape
data, ixs = self.graph.sum(1).todense(), jnp.arange(n)
if inv_sqrt:
data = jnp.where(data > 0., 1. / jnp.sqrt(data), 0.)
return jesp.BCOO((data, jnp.c_[ixs, ixs]), shape=(n, n))
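A side note on the sparse branch above: the degree matrix is a BCOO diagonal built from the degree vector and paired `(i, i)` indices. A self-contained toy check (hypothetical values):

```python
import jax.numpy as jnp
import jax.experimental.sparse as jesp

adj = jnp.array([[0., 1., 0.],
                 [1., 0., 2.],
                 [0., 2., 0.]])
n = adj.shape[0]
data, ixs = adj.sum(1), jnp.arange(n)  # degrees and diagonal indices
deg = jesp.BCOO((data, jnp.c_[ixs, ixs]), shape=(n, n))
assert jnp.allclose(deg.todense(), jnp.diag(data))
```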

@property
def _laplacian(self) -> Union[jnp.ndarray, Sparse_t]:
if self._lap is not None:
return self._lap
# in the sparse case, we don't sum duplicates here because we need to
# know `nnz` a priori for JIT (could be exposed in `__init__`); instead,
# `ott.math.decomposition._jax_sparse_to_scipy` handles it on host
return D - self.graph
return self._degree_matrix() - self.graph
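The comment refers to a JAX constraint: `BCOO.sum_duplicates` yields a data-dependent number of stored elements unless `nse` is fixed, which breaks shape inference under `jit`. A toy illustration, not the library's code path:

```python
import jax.numpy as jnp
import jax.experimental.sparse as jesp

# two entries share index (0, 0), so nnz shrinks from 3 to 2 once summed
mat = jesp.BCOO(
    (jnp.array([1., 2., 3.]), jnp.array([[0, 0], [0, 0], [1, 1]])),
    shape=(2, 2),
)
summed = mat.sum_duplicates()        # output size depends on data: not jittable
fixed = mat.sum_duplicates(nse=2)    # fixed `nse` keeps the shape static
```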

@property
def _norm_laplacian(self) -> Union[jnp.ndarray, Sparse_t]:
# assumes symmetric Laplacian, as mentioned in `__init__`
lap = self._laplacian
inv_sqrt_deg = self._degree_matrix(inv_sqrt=True)
if not self.is_sparse:
return inv_sqrt_deg @ lap @ inv_sqrt_deg

inv_sqrt_deg = inv_sqrt_deg.data # (n,)
# much faster than doing sparse MM
return inv_sqrt_deg[:, None] * lap * inv_sqrt_deg[None, :]
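The elementwise rescaling in the sparse branch uses the identity `(diag(v) @ L @ diag(v))[i, j] == v[i] * L[i, j] * v[j]`, avoiding two sparse matrix products. A quick dense check with arbitrary values:

```python
import jax
import jax.numpy as jnp

lap = jax.random.normal(jax.random.PRNGKey(0), (5, 5))
v = jnp.abs(jax.random.normal(jax.random.PRNGKey(1), (5,)))

explicit = jnp.diag(v) @ lap @ jnp.diag(v)  # two matrix products
rescaled = v[:, None] * lap * v[None, :]    # elementwise row/column rescaling
assert jnp.allclose(explicit, rescaled, atol=1e-5)
```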

@property
def t(self) -> float:
@@ -208,7 +237,7 @@ def t(self) -> float:

@property
def _scale(self) -> float:
"""Constant to scale the Laplacian with."""
"""Constant used to scale the Laplacian."""
if self.numerical_scheme == "backward_euler":
return self.t / (4. * self.n_steps)
if self.numerical_scheme == "crank_nicolson":
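For intuition on how this constant enters: the heat kernel `e^{-tL}` is approximated by repeated implicit solves. A rough dense sketch, assuming each backward-Euler step solves `(I + scale * L) x = b`; the class itself relies on a cached Cholesky solver, and the Crank-Nicolson branch is collapsed in this diff:

```python
import jax.numpy as jnp

def heat_apply(lap: jnp.ndarray, vec: jnp.ndarray,
               t: float = 1e-2, n_steps: int = 10) -> jnp.ndarray:
    """Approximate e^{-tL} @ vec with `n_steps` backward-Euler solves."""
    scale = t / (4. * n_steps)            # the backward-Euler constant above
    mat = jnp.eye(lap.shape[0]) + scale * lap
    for _ in range(n_steps):
        vec = jnp.linalg.solve(mat, vec)  # dense stand-in for Cholesky
    return vec
```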
@@ -247,14 +276,14 @@ def solver(self) -> decomposition.CholeskySolver:

@property
def shape(self) -> Tuple[int, int]:
arr = self._graph if self._graph is not None else self._laplacian
arr = self._graph if self._graph is not None else self._lap
return arr.shape

@property
def is_sparse(self) -> bool:
"""Whether :attr:`graph` or :attr:`laplacian` is sparse."""
if self._laplacian is not None:
return isinstance(self.laplacian, Sparse_t.__args__)
if self._lap is not None:
return isinstance(self._lap, Sparse_t.__args__)
if isinstance(self._graph, (jesp.CSR, jesp.CSC, jesp.COO)):
raise NotImplementedError("Graph must be specified in `BCOO` format.")
return isinstance(self._graph, jesp.BCOO)
@@ -268,7 +297,7 @@ def graph(self) -> Optional[Union[jnp.ndarray, jesp.BCOO]]:

@property
def is_symmetric(self) -> bool:
# there are some numerical imprecisions, but it should be symmetric
# there may be some numerical imprecisions, but it should be symmetric
return True

@property
@@ -303,11 +332,12 @@ def marginal_from_potentials(
raise ValueError("Not implemented.")

def tree_flatten(self) -> Tuple[Sequence[Any], Dict[str, Any]]:
return [self._graph, self._laplacian, self.solver], {
return [self._graph, self._lap, self.solver], {
"t": self._t,
"n_steps": self.n_steps,
"numerical_scheme": self.numerical_scheme,
"directed": self.directed,
"normalize": self.normalize,
"tol": self._tol,
**self._kwargs,
}
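`tree_flatten` implements JAX's pytree protocol: traced leaves (graph, Laplacian, solver) become children, while static configuration such as the new `normalize` flag lands in the auxiliary dict so it stays constant under `jit`. A toy analogue, not the library's actual registration:

```python
import jax
import jax.numpy as jnp

@jax.tree_util.register_pytree_node_class
class Toy:

    def __init__(self, data: jnp.ndarray, normalize: bool = False):
        self.data, self.normalize = data, normalize

    def tree_flatten(self):
        return [self.data], {"normalize": self.normalize}  # children, aux

    @classmethod
    def tree_unflatten(cls, aux, children):
        return cls(*children, **aux)

leaves, treedef = jax.tree_util.tree_flatten(Toy(jnp.ones(3), normalize=True))
```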
2 changes: 1 addition & 1 deletion src/ott/geometry/grid.py
@@ -52,7 +52,7 @@ class Grid(geometry.Geometry):
convolutions one dimension at a time.
Args:
x : list of arrays of varying sizes, describing the locations of the grid.
x: list of arrays of varying sizes, describing the locations of the grid.
Locations are provided as a list of arrays, that is :math:`d`
vectors of (possibly varying) size :math:`n_i`. The resulting grid
is the Cartesian product of these vectors.
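A minimal illustration of the `x` argument fixed above: two vectors of sizes 3 and 2 imply a Cartesian grid of 6 points in 2-D (the geometry itself never materializes these points):

```python
import jax.numpy as jnp

x = [jnp.array([0., 0.5, 1.]), jnp.array([0., 1.])]  # n_1 = 3, n_2 = 2
# the implied 2-D grid: 3 * 2 = 6 points
pts = jnp.stack(jnp.meshgrid(*x, indexing="ij"), axis=-1).reshape(-1, 2)
assert pts.shape == (6, 2)
```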
37 changes: 30 additions & 7 deletions tests/geometry/graph_test.py
@@ -95,7 +95,7 @@ def test_init_graph(self, fmt: Optional[str]):

assert geom.shape == (n, n)
assert geom.graph is G
assert geom._laplacian is None
assert geom._lap is None
# compute the laplacian on the fly
assert isinstance(geom.laplacian, type(geom.graph))

@@ -273,8 +273,8 @@ def test_crank_nicolson_sparse_matches_dense(self, eps: float):
atol=eps * 1e2,
)

@pytest.mark.parametrize("jit", [False, True])
def test_directed_graph(self, jit: bool):
@pytest.mark.parametrize("jit,normalize", [(False, True), (True, False)])
def test_directed_graph(self, jit: bool, normalize: bool):

def callback(geom: graph.Graph,
laplacian: bool) -> Union[jnp.ndarray, jesp.BCOO]:
@@ -283,16 +283,39 @@ def callback(geom: graph.Graph,
G = random_graph(16, p=0.25, directed=True)
fn = jax.jit(callback, static_argnums=1) if jit else callback

geom = graph.Graph(G, directed=True)
geom = graph.Graph(G, directed=True, normalize=normalize)

with pytest.raises(AssertionError):
np.testing.assert_array_equal(G, G.T)
np.testing.assert_allclose(G, G.T)

G = fn(geom, laplacian=False)
L = fn(geom, laplacian=True)

np.testing.assert_array_equal(G, G.T)
np.testing.assert_array_equal(L, L.T)
np.testing.assert_allclose(G, G.T, rtol=1e-6, atol=1e-6)
np.testing.assert_allclose(L, L.T, rtol=1e-6, atol=1e-6)

@pytest.mark.parametrize("fmt", [None, "coo"])
@pytest.mark.parametrize("normalize", [False, True])
def test_normalize_laplacian(self, fmt: Optional[str], normalize: bool):

def laplacian(geom: graph.Graph) -> jnp.ndarray:
graph = geom.graph.todense() if geom.is_sparse else geom.graph
data = G.sum(1)
deg = jnp.diag(data)
lap = deg - graph
if not normalize:
return lap
inv_sqrt_deg = jnp.diag(jnp.where(data > 0., 1. / jnp.sqrt(data), 0.))
return inv_sqrt_deg @ lap @ inv_sqrt_deg

directed = False
G = random_graph(51, p=0.35, directed=directed)
geom = graph.Graph(G, directed=directed, normalize=normalize)

expected = laplacian(geom)
actual = geom.laplacian

np.testing.assert_allclose(actual, expected, rtol=1e-6, atol=1e-6)

@pytest.mark.fast
def test_factor_cache_works(self, rng: jnp.ndarray):
