Add GPU test runner #242

Merged
merged 27 commits into from
Feb 8, 2023
27 commits
c0d07aa
Add `rstcheck` and `doc8`
michalk8 Feb 2, 2023
2b3021d
Pass `CUDA_VISIBLE_DEVICES` to `tox`
michalk8 Feb 2, 2023
6fb3e9d
Add GPU CI runner
michalk8 Feb 3, 2023
d63441d
Merge branch 'main' into feature/improve-pre-commits
michalk8 Feb 3, 2023
af3cc26
Fix Python version for GPU tests
michalk8 Feb 3, 2023
4572dd4
Try running without `tox`
michalk8 Feb 3, 2023
f6c9ea1
Fix not installing jax[cuda]
michalk8 Feb 3, 2023
6010ebd
Use different Docker image
michalk8 Feb 6, 2023
5244413
Fix escpape
michalk8 Feb 6, 2023
7aef30d
Use apt-get
michalk8 Feb 6, 2023
270b86a
Do not use `{}`
michalk8 Feb 6, 2023
68650e5
Fix not installing `git`
michalk8 Feb 6, 2023
7cb3d10
Use personal Docker image
michalk8 Feb 7, 2023
6b34346
Pin `jax[cuda]` version
michalk8 Feb 7, 2023
ab18f39
Mark grad(sqrtm) as CPU only test
michalk8 Feb 7, 2023
b682b64
Fix ICNN hessian test on GPU
michalk8 Feb 7, 2023
1236aa6
Use `eigvalsh` to check for positive-semidefinite
michalk8 Feb 7, 2023
c60cf50
Adjust tolerance in a test
michalk8 Feb 7, 2023
9991e15
Mark Sinkhorn online as CPU
michalk8 Feb 7, 2023
7c2aa8d
Run all tests on GPU
michalk8 Feb 7, 2023
2634a58
Skip more tests on GPU
michalk8 Feb 8, 2023
1fe92af
Update tolerances on k-means test
michalk8 Feb 8, 2023
669dec5
Always jit in online Sinkhorn test
michalk8 Feb 8, 2023
7dcb6b3
Use simple comparison
michalk8 Feb 8, 2023
2863be1
Only run fast GPU tests, try other GPU
michalk8 Feb 8, 2023
d6cfb78
Use previous GPU
michalk8 Feb 8, 2023
d7210f7
[ci skip] Fix test
michalk8 Feb 8, 2023
2 changes: 1 addition & 1 deletion .editorconfig
@@ -5,7 +5,7 @@ end_of_line = lf
 insert_final_newline = true
 charset = utf-8

-[*py]
+[{*py,*.rst}]
 indent_size = 2
 indent_style = space
 max_line_length = 80
25 changes: 24 additions & 1 deletion .github/workflows/tests.yml
@@ -36,6 +36,29 @@ jobs:
       run: |
         tox -e py39 --skip-pkg-install -- -m fast --memray -n auto -vv

+  gpu-tests:
+    name: Fast GPU tests Python 3.8 on ubuntu-20.04
+    runs-on: [self-hosted, ott-gpu]
+    container:
+      image: docker://michalk8/cuda:11.3.0-ubuntu20.04
+      options: --gpus="device=12"
+    steps:
+    - uses: actions/checkout@v3
+    - name: Install dependencies
+      # `jax[cuda]<0.4` because of: https://github.com/google/jax/issues/13758
+      run: |
+        python3 -m pip install --upgrade pip
+        python3 -m pip install -e".[test]"
+        python3 -m pip install "jax[cuda]<0.4" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
+
+    - name: Nvidia SMI
+      run: |
+        nvidia-smi
+
+    - name: Run tests
+      run: |
+        python3 -m pytest -m "fast and not cpu" --memray --durations 10 -vv
+
   tests:
     name: Python ${{ matrix.python-version }} on ${{ matrix.os }}
     runs-on: ${{ matrix.os }}
@@ -68,7 +91,7 @@ jobs:
       run: |
         tox -e py${{ matrix.python-version }} --skip-pkg-install
       env:
-        PYTEST_ADDOPTS: --memray --durations 10 -vv
+        PYTEST_ADDOPTS: --memray -vv

     - name: Upload coverage
       uses: codecov/codecov-action@v3
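The GPU job above selects tests with the marker expression `-m "fast and not cpu"`. As a rough illustration of how that selection interacts with test markers, here is a minimal sketch. The marker names `fast` and `cpu` come from the workflow; the test functions and the `marker_names` helper are hypothetical and are not taken from the repository.

```python
import pytest


@pytest.mark.fast
def test_runs_everywhere():
    # Marked `fast` only, so `-m "fast and not cpu"` selects it.
    assert sum([1, 2, 3]) == 6


@pytest.mark.fast
@pytest.mark.cpu
def test_cpu_only():
    # Also carries the `cpu` marker, so `-m "fast and not cpu"` deselects it.
    assert max([1, 2, 3]) == 3


def marker_names(test_fn):
    """Return the set of marker names attached to a test function."""
    # pytest stores applied marks in the function's `pytestmark` attribute.
    return {mark.name for mark in getattr(test_fn, "pytestmark", [])}
```

Under this sketch, a test that is known to misbehave on GPU would only need the extra `cpu` marker to be excluded from the GPU job while still running in the regular CPU matrix.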
11 changes: 11 additions & 0 deletions .pre-commit-config.yaml
@@ -72,3 +72,14 @@ repos:
   hooks:
   - id: pyupgrade
     args: [--py38-plus, --keep-runtime-typing]
+- repo: https://github.com/rstcheck/rstcheck
+  rev: v6.1.1
+  hooks:
+  - id: rstcheck
+    additional_dependencies: [tomli]
+    args: [--config=pyproject.toml]
+- repo: https://github.com/PyCQA/doc8
+  rev: v1.1.1
+  hooks:
+  - id: doc8
+    args: [--config=pyproject.toml]
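The `.rst` reflows in the rest of this diff are consistent with the 80-column limit that the `.editorconfig` hunk now applies to `*.rst` files. The following is a minimal sketch of such a line-length check, assuming an 80-column limit. It is not how `doc8` is implemented, and the `overlong_lines` helper is hypothetical.

```python
def overlong_lines(text, max_length=80):
    """Return (line_number, length) pairs for lines longer than max_length."""
    return [
        (number, len(line))
        for number, line in enumerate(text.splitlines(), start=1)
        if len(line) > max_length
    ]


# First line of docs/geometry.rst before the reflow (taken from the diff).
before = (
    "This package implements several classes to define a geometry, "
    "arguably the most influential"
)

# The same text after the reflow, split across two lines.
after = (
    "This package implements several classes to define a geometry, "
    "arguably the most\n"
    "influential ingredient of optimal transport problem. "
    "In its full generality, a"
)

print(overlong_lines(before))
print(overlong_lines(after))
```

The old line is reported as too long, while the reflowed text passes. This is the kind of violation a documentation linter with a line-length rule would be expected to flag.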
49 changes: 26 additions & 23 deletions docs/geometry.rst
@@ -1,36 +1,39 @@
-.. _geometry:
-
-ott.geometry package
-====================
+ott.geometry
+============
 .. currentmodule:: ott.geometry
 .. automodule:: ott.geometry

-This package implements several classes to define a geometry, arguably the most influential
-ingredient of optimal transport problem. In its full generality, a :class:`~ott.geometry.geometry.Geometry`
-defines source points (input measure), target points (target measure) and a ground cost function
-(resp. a positive kernel function) that quantifies how expensive (resp. easy) it is to displace
-a unit of mass from any of the input points to the target points.
+This package implements several classes to define a geometry, arguably the most
+influential ingredient of optimal transport problem. In its full generality, a
+:class:`~ott.geometry.geometry.Geometry` defines source points (input measure),
+target points (target measure) and a ground cost function (resp. a positive
+kernel function) that quantifies how expensive (resp. easy) it is to displace a
+unit of mass from any of the input points to the target points.

 The geometry package proposes a few simple geometries. The simplest of all would
-be that for which input and target points coincide, and the geometry between them
-simplifies to a symmetric cost or kernel matrix. In the very particular case
-where these points happen to lie on grid (a cartesian product in full generality,
-e.g. 2 or 3D grids), the :class:`~ott.geometry.grid.Grid` geometry will prove useful.
+be that for which input and target points coincide, and the geometry between
+them simplifies to a symmetric cost or kernel matrix. In the very particular
+case where these points happen to lie on grid (a cartesian product in full
+generality, e.g. 2 or 3D grids), the :class:`~ott.geometry.grid.Grid`
+geometry will prove useful.

 For more general settings where input/target points do not coincide, one can
-alternatively instantiate a :class:`~ott.geometry.geometry.Geometry` through a rectangular cost matrix.
+alternatively instantiate a :class:`~ott.geometry.geometry.Geometry` through a
+rectangular cost matrix.

-However, it is often preferable in applications to define ground costs "symbolically",
-by listing instead points in the input/target point clouds, to specify directly
-a cost *function* between them. Such functions should follow the :class:`~ott.geometry.costs.CostFn`
-class description. We provide a few standard cost functions that are meaningful in an
-OT context, notably the (unbalanced, regularized) Bures distances between
-Gaussians :cite:`janati:20`. That cost can be used for instance to compute a distance between
-Gaussian mixtures, as proposed in :cite:`chen:19a` and revisited in :cite:`delon:20`.
+However, it is often preferable in applications to define ground costs
+"symbolically", by listing instead points in the input/target point clouds, to
+specify directly a cost *function* between them. Such functions should follow
+the :class:`~ott.geometry.costs.CostFn` class description. We provide a few
+standard cost functions that are meaningful in an OT context, notably the
+(unbalanced, regularized) Bures distances between Gaussians :cite:`janati:20`.
+That cost can be used for instance to compute a distance between Gaussian
+mixtures, as proposed in :cite:`chen:19a` and revisited in :cite:`delon:20`.

 To be useful with Sinkhorn solvers, ``Geometries`` typically need to provide an
-``epsilon`` regularization parameter. We propose either to set that value once for
-all, or implement an annealing :class:`~ott.geometry.epsilon_scheduler.Epsilon` scheduler.
+``epsilon`` regularization parameter. We propose either to set that value once
+for all, or implement an annealing
+:class:`~ott.geometry.epsilon_scheduler.Epsilon` scheduler.

 Geometries
 ----------
64 changes: 35 additions & 29 deletions docs/index.rst
@@ -5,19 +5,23 @@ Optimal Transport Tools (OTT)

 Introduction
 ------------
-``OTT`` is a `JAX <https://jax.readthedocs.io/en/latest/>`_ package that bundles a few utilities to compute,
-and differentiate as needed, the solution to optimal transport (OT) problems, taken in a fairly wide sense.
-For instance, ``OTT`` can of course compute Wasserstein (or Gromov-Wasserstein) distances between
-weighted clouds of points (or histograms) in a wide variety of scenarios,
-but also estimate Monge maps, Wasserstein barycenters, and help with simpler tasks
-such as differentiable approximations to ranking or even clustering.
+``OTT`` is a `JAX <https://jax.readthedocs.io/en/latest/>`_ package that bundles
+a few utilities to compute, and differentiate as needed, the solution to optimal
+transport (OT) problems, taken in a fairly wide sense. For instance, ``OTT`` can
+of course compute Wasserstein (or Gromov-Wasserstein) distances between weighted
+clouds of points (or histograms) in a wide variety of scenarios, but also
+estimate Monge maps, Wasserstein barycenters, and help with simpler tasks such
+as differentiable approximations to ranking or even clustering.

 To achieve this, ``OTT`` rests on two families of tools:
-The first family consists in *discrete* solvers computing transport between point clouds,
-using the Sinkhorn :cite:`cuturi:13` and low-rank Sinkhorn :cite:`scetbon:21` algorithms,
-and moving up towards Gromov-Wasserstein :cite:`memoli:11,peyre:16`;
-the second family consists in *continuous* solvers, using suitable neural architectures :cite:`amos:17` coupled
-with SGD type estimators :cite:`makkuva:20,korotin:21`.
+
+- the first family consists in *discrete* solvers computing transport between
+  point clouds, using the Sinkhorn :cite:`cuturi:13` and low-rank Sinkhorn
+  :cite:`scetbon:21` algorithms, and moving up towards Gromov-Wasserstein
+  :cite:`memoli:11,peyre:16`;
+- the second family consists in *continuous* solvers, using suitable neural
+  architectures :cite:`amos:17` coupled with SGD type estimators
+  :cite:`makkuva:20,korotin:21`.

 Installation
 ------------
@@ -27,7 +31,7 @@ Install ``OTT`` from `PyPI <https://pypi.org/project/ott-jax/>`_ as:

     pip install ott-jax

-or with ``conda`` via `conda-forge <https://anaconda.org/conda-forge/ott-jax>`_ as:
+or with ``conda`` via `conda-forge`_ as:

 .. code-block:: bash

@@ -37,40 +41,41 @@ Design Choices
 --------------
 ``OTT`` is designed with the following choices:

-- Take advantage whenever possible of JAX features, such as `Just-in-time (JIT) compilation`_,
-  `auto-vectorization (VMAP)`_ and both `automatic`_ but most importantly `implicit`_ differentiation.
+- Take advantage whenever possible of JAX features, such as
+  `Just-in-time (JIT) compilation`_, `auto-vectorization (VMAP)`_ and both
+  `automatic`_ but most importantly `implicit`_ differentiation.
 - Split geometry from OT solvers in the discrete case: We argue that there
   should be one, and one implementation only, of every major OT algorithm
   (Sinkhorn, Gromov-Wasserstein, barycenters, etc...), regardless of the
   geometric setup that is considered. To give a concrete example, any
-  speedups one may benefit from by using a specific cost
-  (e.g. Sinkhorn being faster when run on a separable cost on histograms supported
-  on a separable grid :cite:`solomon:15`) should not require a separate
-  reimplementation of a Sinkhorn routine.
+  speedups one may benefit from by using a specific cost (e.g. Sinkhorn being
+  faster when run on a separable cost on histograms supported on a separable
+  grid :cite:`solomon:15`) should not require a separate reimplementation
+  of a Sinkhorn routine.
 - As a consequence, and to minimize code copy/pasting, use as often as possible
   object hierarchies, and interleave outer solvers (such as quadratic,
   aka Gromov-Wasserstein solvers) with inner solvers (e.g. Low-Rank Sinkhorn).
   This choice ensures that speedups achieved at lower computation levels
-  (e.g. low-rank factorization of squared Euclidean distances) propagate seamlessly and
-  automatically in higher level calls (e.g. updates in Gromov-Wasserstein),
-  without requiring any attention from the user.
+  (e.g. low-rank factorization of squared Euclidean distances) propagate
+  seamlessly and automatically in higher level calls (e.g. updates in
+  Gromov-Wasserstein), without requiring any attention from the user.

 .. TODO(marcocuturi): add missing package descriptions below

 Packages
 --------
-- :ref:`geometry` contains classes to instantiate objects that describe
+- :doc:`geometry` contains classes to instantiate objects that describe
   *two point clouds* paired with a *cost* function. Geometry objects are used to
-  describe OT problems, handled by solvers in the :ref:`solvers`.
-- :ref:`problems`
-- :ref:`solvers`
-- :ref:`initializers`
-- :ref:`tools` provides an interface to exploit OT solutions, as produced by
-  solvers in the :ref:`solvers`. Such tasks include computing approximations
+  describe OT problems, handled by solvers in the solvers.
+- :doc:`problems/index`
+- :doc:`solvers/index`
+- :doc:`initializers/index`
+- :doc:`tools` provides an interface to exploit OT solutions, as produced by
+  solvers in the solvers. Such tasks include computing approximations
   to Wasserstein distances :cite:`genevay:18,sejourne:19`, approximating OT
   between GMMs, or computing differentiable sort and quantile operations
   :cite:`cuturi:19`.
-- :ref:`math`
+- :doc:`math`

 .. toctree::
     :maxdepth: 1
@@ -116,3 +121,4 @@ Packages
 .. _auto-vectorization (VMAP): https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap
 .. _automatic: https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation
 .. _implicit: https://jax.readthedocs.io/en/latest/_autosummary/jax.custom_jvp.html#jax.custom_jvp
+.. _conda-forge: https://anaconda.org/conda-forge/ott-jax
6 changes: 2 additions & 4 deletions docs/initializers/index.rst
@@ -1,7 +1,5 @@
-.. _initializers:
-
-ott.initializers package
-========================
+ott.initializers
+================

 .. TODO(cuturi): add some nice text here please

4 changes: 2 additions & 2 deletions docs/initializers/linear.rst
@@ -1,5 +1,5 @@
-ott.initializers.linear package
-===============================
+ott.initializers.linear
+=======================
 .. currentmodule:: ott.initializers.linear
 .. automodule:: ott.initializers.linear

4 changes: 2 additions & 2 deletions docs/initializers/nn.rst
@@ -1,5 +1,5 @@
-ott.initializers.nn package
-===========================
+ott.initializers.nn
+===================
 .. currentmodule:: ott.initializers.nn
 .. automodule:: ott.initializers.nn

4 changes: 2 additions & 2 deletions docs/initializers/quadratic.rst
@@ -1,5 +1,5 @@
-ott.initializers.quadratic package
-==================================
+ott.initializers.quadratic
+==========================
 .. currentmodule:: ott.initializers.quadratic
 .. automodule:: ott.initializers.quadratic

6 changes: 2 additions & 4 deletions docs/math.rst
@@ -1,7 +1,5 @@
-.. _math:
-
-ott.math package
-================
+ott.math
+========
 .. currentmodule:: ott.math
 .. automodule:: ott.math

6 changes: 2 additions & 4 deletions docs/problems/index.rst
@@ -1,7 +1,5 @@
-.. _problems:
-
-ott.problems package
-====================
+ott.problems
+============

 .. TODO(marcocuturi): add some nice text here please

4 changes: 2 additions & 2 deletions docs/problems/linear.rst
@@ -1,5 +1,5 @@
-ott.problems.linear package
-===========================
+ott.problems.linear
+===================
 .. currentmodule:: ott.problems.linear
 .. automodule:: ott.problems.linear

4 changes: 2 additions & 2 deletions docs/problems/quadratic.rst
@@ -1,5 +1,5 @@
-ott.problems.quadratic package
-==============================
+ott.problems.quadratic
+======================
 .. currentmodule:: ott.problems.quadratic
 .. automodule:: ott.problems.quadratic

6 changes: 2 additions & 4 deletions docs/solvers/index.rst
@@ -1,7 +1,5 @@
-.. _solvers:
-
-ott.solvers package
-===================
+ott.solvers
+===========

 .. TODO(marcocuturi): add some nice text here please

4 changes: 2 additions & 2 deletions docs/solvers/linear.rst
@@ -1,5 +1,5 @@
-ott.solvers.linear package
-==========================
+ott.solvers.linear
+==================
 .. currentmodule:: ott.solvers.linear
 .. automodule:: ott.solvers.linear

4 changes: 2 additions & 2 deletions docs/solvers/nn.rst
@@ -1,5 +1,5 @@
-ott.solvers.nn package
-======================
+ott.solvers.nn
+==============
 .. currentmodule:: ott.solvers.nn
 .. automodule:: ott.solvers.nn

4 changes: 2 additions & 2 deletions docs/solvers/quadratic.rst
@@ -1,5 +1,5 @@
-ott.solvers.quadratic package
-=============================
+ott.solvers.quadratic
+=====================
 .. currentmodule:: ott.solvers.quadratic
 .. automodule:: ott.solvers.quadratic

13 changes: 6 additions & 7 deletions docs/tools.rst
@@ -1,13 +1,12 @@
-.. _tools:
-
-ott.tools package
-=================
+ott.tools
+=========
 .. currentmodule:: ott.tools
 .. automodule:: ott.tools

-The tools package contains high level functions that build on outputs produced by core functions.
-They can be used to compute Sinkhorn divergences :cite:`sejourne:19`, instantiate transport matrices,
-provide differentiable approximations to ranks and quantile functions :cite:`cuturi:19`, etc.
+The tools package contains high level functions that build on outputs produced
+by core functions. They can be used to compute Sinkhorn divergences
+:cite:`sejourne:19`, instantiate transport matrices, provide differentiable
+approximations to ranks and quantile functions :cite:`cuturi:19`, etc.

 Segmented Sinkhorn
 ------------------