Skip to content

Commit

Permalink
Remove matplotlib dependency (#304)
Browse files Browse the repository at this point in the history
  • Loading branch information
michalk8 authored Feb 20, 2023
1 parent ab9d89f commit c465cf1
Show file tree
Hide file tree
Showing 10 changed files with 60 additions and 41 deletions.
3 changes: 1 addition & 2 deletions docs/tutorials/notebooks/MetaOT.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -59,11 +59,10 @@
"source": [
"from collections import namedtuple\n",
"\n",
"import torchvision\n",
"\n",
"import jax\n",
"import jax.numpy as jnp\n",
"import numpy as np\n",
"import torchvision\n",
"\n",
"import matplotlib.pyplot as plt\n",
"from matplotlib import cm\n",
Expand Down
3 changes: 1 addition & 2 deletions docs/tutorials/notebooks/OTT_&_POT.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -45,11 +45,10 @@
"source": [
"import timeit\n",
"\n",
"import ot\n",
"\n",
"import jax\n",
"import jax.numpy as jnp\n",
"import numpy as np\n",
"import ot\n",
"\n",
"import matplotlib.pyplot as plt\n",
"import mpl_toolkits.axes_grid1\n",
Expand Down
3 changes: 1 addition & 2 deletions docs/tutorials/notebooks/gromov_wasserstein.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -36,14 +36,13 @@
"metadata": {},
"outputs": [],
"source": [
"from IPython import display\n",
"\n",
"import jax\n",
"import jax.numpy as jnp\n",
"import numpy as np\n",
"\n",
"import matplotlib.pyplot as plt\n",
"import mpl_toolkits.mplot3d.axes3d as p3\n",
"from IPython import display\n",
"from matplotlib import animation, cm\n",
"\n",
"from ott.geometry import pointcloud\n",
Expand Down
4 changes: 2 additions & 2 deletions docs/tutorials/notebooks/gromov_wasserstein_multiomics.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -73,17 +73,17 @@
"source": [
"import time\n",
"\n",
"from IPython import display\n",
"from ot.gromov import gwloss, init_matrix\n",
"from SCOT.src import evals\n",
"from SCOT.src.scot import SCOT\n",
"\n",
"import numpy as np\n",
"import pandas as pd\n",
"from ot.gromov import gwloss, init_matrix\n",
"from sklearn.decomposition import PCA\n",
"\n",
"import matplotlib.pyplot as plt\n",
"import seaborn as sn\n",
"from IPython import display\n",
"from matplotlib import animation\n",
"\n",
"import ott\n",
Expand Down
3 changes: 1 addition & 2 deletions docs/tutorials/notebooks/neural_dual.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -51,15 +51,14 @@
"from dataclasses import dataclass\n",
"from functools import partial\n",
"\n",
"from IPython.display import clear_output, display\n",
"\n",
"import jax\n",
"import jax.numpy as jnp\n",
"import numpy as np\n",
"import optax\n",
"from torch.utils.data import DataLoader, IterableDataset\n",
"\n",
"import matplotlib.pyplot as plt\n",
"from IPython.display import clear_output, display\n",
"\n",
"from ott.geometry import pointcloud\n",
"from ott.problems.linear import potentials\n",
Expand Down
3 changes: 1 addition & 2 deletions docs/tutorials/notebooks/point_clouds.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -35,12 +35,11 @@
},
"outputs": [],
"source": [
"from IPython import display\n",
"\n",
"import jax\n",
"import jax.numpy as jnp\n",
"\n",
"import matplotlib.pyplot as plt\n",
"from IPython import display\n",
"\n",
"import ott\n",
"from ott.geometry import costs, pointcloud\n",
Expand Down
3 changes: 2 additions & 1 deletion environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,5 @@ channels:
- conda-forge
dependencies:
- python>=3.8
- ott-jax>=0.3.1
- ott-jax>=0.4
- matplotlib-base>=3.0.0
10 changes: 7 additions & 3 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@ dependencies = [
"jaxopt>=0.5.5",
# https://github.com/google/jax/discussions/9951#discussioncomment-3017784
"numpy>=1.18.4, !=1.23.0",
"matplotlib>=3.0.0",
"flax>=0.5.2",
"optax>=0.1.1",
"scipy>=1.7.0",
Expand Down Expand Up @@ -83,6 +82,7 @@ docs = [
"sphinxcontrib-bibtex>=2.5.0",
"sphinxcontrib-spelling>=7.7.0",
"myst-nb>=0.17.1",
"matplotlib>=3.0.0",
]

[tool.setuptools]
Expand All @@ -101,8 +101,9 @@ profile = "black"
include_trailing_comma = true
multi_line_output = 3
sections = ["FUTURE", "STDLIB", "THIRDPARTY", "NUMERIC", "PLOTTING", "FIRSTPARTY", "LOCALFOLDER"]
known_numeric = ["numpy", "scipy", "pandas", "sklearn", "jax", "flax", "optax", "torch"]
known_plotting = ["matplotlib", "mpl_toolkits", "seaborn"]
# also contains what we import in notebooks
known_numeric = ["numpy", "scipy", "jax", "flax", "optax", "jaxopt", "torch", "ot", "torchvision", "pandas", "sklearn"]
known_plotting = ["IPython", "matplotlib", "mpl_toolkits", "seaborn"]

[tool.pytest.ini_options]
minversion = "6.0"
Expand Down Expand Up @@ -190,6 +191,7 @@ legacy_tox_ini = """
[testenv:lint-docs]
description = Lint the documentation.
deps =
extras = docs
allowlist_externals =
rm
Expand All @@ -207,6 +209,7 @@ legacy_tox_ini = """
[testenv:build-docs]
description = Build the documentation.
use_develop = true
deps =
extras = docs
allowlist_externals = sphinx-build
commands =
Expand All @@ -216,6 +219,7 @@ legacy_tox_ini = """
[testenv:clean-docs]
description = Remove the documentation.
deps =
skip_install = true
changedir = {tox_root}/docs
allowlist_externals = make
Expand Down
25 changes: 17 additions & 8 deletions src/ott/problems/linear/potentials.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,15 +26,19 @@
import jax.numpy as jnp
import jax.scipy as jsp
import jax.tree_util as jtu
import matplotlib
import matplotlib.pyplot as plt
import numpy as np

from ott.problems.linear import linear_problem

if TYPE_CHECKING:
from ott.geometry import costs

try:
import matplotlib as mpl
import matplotlib.pyplot as plt
except ImportError:
mpl = plt = None

__all__ = ["DualPotentials", "EntropicPotentials"]
Potential_t = Callable[[jnp.ndarray], float]

Expand Down Expand Up @@ -178,10 +182,10 @@ def plot_ot_map(
source: jnp.ndarray,
target: jnp.ndarray,
forward: bool = True,
ax: Optional[matplotlib.axes.Axes] = None,
ax: Optional["plt.Axes"] = None,
legend_kwargs: Optional[Dict[str, Any]] = None,
scatter_kwargs: Optional[Dict[str, Any]] = None,
) -> Tuple[matplotlib.figure.Figure, matplotlib.axes.Axes]:
) -> Tuple["plt.Figure", "plt.Axes"]:
"""Plot data and learned optimal transport map.
Args:
Expand All @@ -190,12 +194,17 @@ def plot_ot_map(
forward: use the forward map from the potentials
if ``True``, otherwise use the inverse map
ax: axis to add the plot to
scatter_kwargs: additional kwargs passed into :meth:`~matplotlib.axes.Axes.scatter`
legend_kwargs: additional kwargs passed into :meth:`~matplotlib.axes.Axes.legend`
scatter_kwargs: additional kwargs passed into
:meth:`~matplotlib.axes.Axes.scatter`
legend_kwargs: additional kwargs passed into
:meth:`~matplotlib.axes.Axes.legend`
Returns:
matplotlib figure and axis with the plots
"""
if mpl is None:
raise RuntimeError("Please install `matplotlib` first.")

if scatter_kwargs is None:
scatter_kwargs = {'alpha': 0.5}
if legend_kwargs is None:
Expand Down Expand Up @@ -263,12 +272,12 @@ def plot_potential(
self,
forward: bool = True,
quantile: float = 0.05,
ax: Optional[matplotlib.axes.Axes] = None,
ax: Optional["mpl.axes.Axes"] = None,
x_bounds: Tuple[float, float] = (-6, 6),
y_bounds: Tuple[float, float] = (-6, 6),
num_grid: int = 50,
contourf_kwargs: Optional[Dict[str, Any]] = None,
) -> Tuple[matplotlib.figure.Figure, matplotlib.axes.Axes]:
) -> Tuple["mpl.figure.Figure", "mpl.axes.Axes"]:
"""Plot the potential.
Args:
Expand Down
44 changes: 27 additions & 17 deletions src/ott/tools/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,24 +16,28 @@
from typing import List, Optional, Sequence, Tuple, Union

import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
import scipy
from matplotlib import animation

from ott import utils
from ott.geometry import pointcloud
from ott.solvers.linear import sinkhorn, sinkhorn_lr
from ott.solvers.quadratic import gromov_wasserstein

try:
import matplotlib.pyplot as plt
from matplotlib import animation
except ImportError:
plt = animation = None

# TODO(michalk8): make sure all outputs conform to a unified transport interface
Transport = Union[sinkhorn.SinkhornOutput, sinkhorn_lr.LRSinkhornOutput,
gromov_wasserstein.GWOutput]


def bidimensional(x: jnp.ndarray,
y: jnp.ndarray) -> Tuple[jnp.ndarray, jnp.ndarray]:
"""Apply PCA to reduce to bimensional data."""
"""Apply PCA to reduce to bi-dimensional data."""
if x.shape[1] < 3:
return x, y

Expand All @@ -44,25 +48,31 @@ def bidimensional(x: jnp.ndarray,


class Plot:
"""Plot an optimal transport map between two point clouds.
"""Plot an optimal transport map between two \
:class:`PointClouds <ott.geometry.pointcloud.PointCloud>`.
It enables to either plot or update a plot in a single object, offering the
possibilities to create animations as matplotlib.animation.FuncAnimation,
which can in turned be saved to disk at will. There are two design principles
here: 1) we do not rely on saving to/loading from disk to create animations
2) we try as much as possible to disentangle the transport problem(s) from
its visualization(s).
possibilities to create animations as a
:class:`~matplotlib.animation.FuncAnimation`, which can in turned be saved to
disk at will. There are two design principles here:
#. we do not rely on saving to/loading from disk to create animations
#. we try as much as possible to disentangle the transport problem from
its visualization.
"""

def __init__(
self,
fig: Optional[plt.Figure] = None,
ax: Optional[plt.Axes] = None,
fig: Optional["plt.Figure"] = None,
ax: Optional["plt.Axes"] = None,
cost_threshold: float = -1.0, # should be negative for animations.
scale: int = 200,
show_lines: bool = True,
cmap: str = 'cool'
):
if plt is None:
raise RuntimeError("Please install `matplotlib` first.")

if ax is None and fig is None:
fig, ax = plt.subplots()
elif fig is None:
Expand Down Expand Up @@ -102,7 +112,7 @@ def _mapping(self, x: jnp.ndarray, y: jnp.ndarray, matrix: jnp.ndarray):
result.append((xy[i, [0, 2]], xy[i, [1, 3]], strength))
return result

def __call__(self, ot: Transport) -> List[plt.Artist]:
def __call__(self, ot: Transport) -> List["plt.Artist"]:
"""Plot 2-D couplings. Projects via PCA if data is higher dimensional."""
x, y, sx, sy = self._scatter(ot)
self._points_x = self.ax.scatter(
Expand Down Expand Up @@ -130,7 +140,7 @@ def __call__(self, ot: Transport) -> List[plt.Artist]:
self._lines.append(line)
return [self._points_x, self._points_y] + self._lines

def update(self, ot: Transport) -> List[plt.Artist]:
def update(self, ot: Transport) -> List["plt.Artist"]:
"""Update a plot with a transport instance."""
x, y, _, _ = self._scatter(ot)
self._points_x.set_offsets(x)
Expand Down Expand Up @@ -168,7 +178,7 @@ def animate(
self,
transports: Sequence[Transport],
frame_rate: float = 10.0
) -> animation.FuncAnimation:
) -> "animation.FuncAnimation":
"""Make an animation from several transports."""
_ = self(transports[0])
return animation.FuncAnimation(
Expand All @@ -182,13 +192,13 @@ def animate(


def _barycenters(
ax: plt.Axes,
ax: "plt.Axes",
y: jnp.ndarray,
a: jnp.ndarray,
b: jnp.ndarray,
matrix: jnp.ndarray,
scale: int = 200
):
) -> None:
"""Plot 2-D sinkhorn barycenters."""
sa, sb = jnp.min(a) / scale, jnp.min(b) / scale
ax.scatter(*y.T, s=b / sb, edgecolors='k', marker='X', label='y')
Expand All @@ -202,7 +212,7 @@ def barycentric_projections(
a: jnp.ndarray = None,
b: jnp.ndarray = None,
matrix: jnp.ndarray = None,
ax: Optional[plt.Axes] = None,
ax: Optional["plt.Axes"] = None,
**kwargs
):
"""Plot the barycenters, from the Transport object or from arguments."""
Expand Down

0 comments on commit c465cf1

Please sign in to comment.