Skip to content

Commit

Permalink
Docs/notebooks myst (ott-jax#233)
Browse files Browse the repository at this point in the history
* Remove `sinkhorn function`

* Fix `sinkhorn_divergence` test

* Remove `gromov_wasserstein` function

* Remove `make` functions

* Fix `soft_sort` and Jacobian tests

* Remove `Transport` interface

* Fix Jacobian test

* Fix `soft_sort` and tests

* Clean up some tests

* Fix wrong `value_and_grad` usage

* Update notebooks, isort and pre-commit

* [ci skip] Fix rendering in `Sinkhorn`

* Handle TODOs, clean initializer tests

* Add `sinkhorn.solve` utility

* Re-add `gromov_wasserstein.solve`, polish docs

* Remove redundant line from `pyproject.toml`

* Polish quad docs

* Add rank to `sinkhorn.solve`

* Add `rank` to `sinkhorn.solve`

* Start with `myst-nb`

* Update rest of the notebooks

* Fix remaining `<data-cite>`

* Fix pepy badge link

* Update references in `gmm_pair_demo.ipynb`

* First tutorial structure

* Update structure

* Remove old requirements

* Address comments

* Fix notebook tests
  • Loading branch information
michalk8 authored Jan 9, 2023
1 parent 7b14b52 commit e23bf1c
Show file tree
Hide file tree
Showing 39 changed files with 407 additions and 320 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
<img src="https://raw.githubusercontent.com/ott-jax/ott/main/docs/_static/images/logoOTT.png" width="10%" alt="logo">

# Optimal Transport Tools (OTT)
[![Downloads](https://pepy.tech/badge/ott-jax)](https://pypi.org/project/ott-jax/)
[![Downloads](https://static.pepy.tech/badge/ott-jax)](https://pypi.org/project/ott-jax/)
[![Tests](https://img.shields.io/github/actions/workflow/status/ott-jax/ott/tests.yml?branch=main)](https://github.com/ott-jax/ott/actions/workflows/tests.yml)
[![Docs](https://img.shields.io/readthedocs/ott-jax/latest)](https://ott-jax.readthedocs.io/en/latest/)
[![Coverage](https://img.shields.io/codecov/c/github/ott-jax/ott/main)](https://app.codecov.io/gh/ott-jax/ott)
Expand Down
47 changes: 29 additions & 18 deletions docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,10 +52,9 @@
'sphinx.ext.viewcode',
'sphinxcontrib.bibtex',
'sphinx_copybutton',
'nbsphinx',
'myst_nb',
'IPython.sphinxext.ipython_console_highlighting',
'sphinx_autodoc_typehints',
'recommonmark',
]

intersphinx_mapping = {
Expand All @@ -64,15 +63,28 @@
"flax": ("https://flax.readthedocs.io/en/latest/", None),
"scikit-sparse": ("https://scikit-sparse.readthedocs.io/en/latest/", None),
"scipy": ("https://docs.scipy.org/doc/scipy/reference/", None),
"pot": ("https://pythonot.github.io/", None),
}

master_doc = 'index'
source_suffix = ['.rst']
source_suffix = {
'.rst': 'restructuredtext',
'.ipynb': 'myst-nb',
}
todo_include_todos = False

autosummary_generate = True

autodoc_typehints = 'description'

# myst-nb
myst_heading_anchors = 2
nb_execution_mode = "off"
myst_enable_extensions = [
'amsmath',
'colon_fence',
'dollarmath',
]

# bibliography
bibtex_bibfiles = ["references.bib"]
bibtex_reference_style = "author_year"
Expand All @@ -99,17 +111,16 @@
# relative to this directory. They are copied after the builtin static files,
# so a file named "default.css" will overwrite the builtin "default.css".
html_static_path = ['_static']

nbsphinx_codecell_lexer = "ipython3"
nbsphinx_execute = 'never'
nbsphinx_prolog = r"""
{% set docname = 'docs/' + env.doc2path(env.docname, base=None) %}
.. raw:: html
<div class="docutils container">
<a class="reference external"
href="https://colab.research.google.com/github/ott-jax/ott/blob/main/{{ docname|e }}">
<img alt="Open in Colab" src="../_static/images/colab-badge.svg" width="125px">
</a>
</div>
"""
html_theme_options = {
'repository_url': 'https://github.com/ott-jax/ott',
'repository_branch': 'main',
'path_to_docs': 'docs/',
'use_repository_button': True,
'use_fullscreen_button': False,
'logo_only': True,
'launch_buttons': {
'colab_url': 'https://colab.research.google.com',
'binderhub_url': 'https://mybinder.org',
'notebook_interface': 'jupyterlab',
},
}
40 changes: 6 additions & 34 deletions docs/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -74,40 +74,14 @@ Packages

.. toctree::
:maxdepth: 1
:caption: Tutorials:
:caption: Examples

notebooks/point_clouds.ipynb
notebooks/introduction_grid.ipynb
Getting Started <tutorials/notebooks/point_clouds>
tutorials/index

.. toctree::
:maxdepth: 1
:caption: Benchmarks:

notebooks/OTT_&_POT.ipynb
notebooks/One_Sinkhorn.ipynb
notebooks/LRSinkhorn.ipynb

.. toctree::
:maxdepth: 1
:caption: Advanced Applications:

notebooks/Sinkhorn_Barycenters.ipynb
notebooks/gromov_wasserstein.ipynb
notebooks/GWLRSinkhorn.ipynb
notebooks/Hessians.ipynb
notebooks/soft_sort.ipynb
notebooks/application_biology.ipynb
notebooks/gromov_wasserstein_multiomics.ipynb
notebooks/fairness.ipynb
notebooks/neural_dual.ipynb
notebooks/icnn_inits.ipynb
notebooks/wasserstein_barycenters_gmms.ipynb
notebooks/gmm_pair_demo.ipynb
notebooks/MetaOT.ipynb

.. toctree::
:maxdepth: 1
:caption: Public API: ott packages
:caption: API

geometry
problems/index
Expand All @@ -118,13 +92,11 @@ Packages

.. toctree::
:maxdepth: 1
:caption: References:
:caption: References

GitHub <https://github.com/ott-jax/ott>
references


.. |Downloads| image:: https://pepy.tech/badge/ott-jax
.. |Downloads| image:: https://static.pepy.tech/badge/ott-jax
:target: https://pypi.org/project/ott-jax/
:alt: Documentation

Expand Down
54 changes: 46 additions & 8 deletions docs/references.bib
Original file line number Diff line number Diff line change
Expand Up @@ -155,14 +155,17 @@ @Article{vayer:20
doi = {10.3390/a13090212}
}

@Article{demetci:20,
author = {Demetci, Pinar and Santorella, Rebecca and Sandstede, Bj{\"o}rn and Noble, William Stafford and Singh, Ritambhara},
title = {Gromov-Wasserstein optimal transport to align single-cell multi-omics data},
elocation-id = {2020.04.28.066787},
year = {2020},
publisher = {Cold Spring Harbor Laboratory},
URL = {https://www.biorxiv.org/content/early/2020/11/11/2020.04.28.066787},
journal = {bioRxiv}
@Article{demetci:22,
author = {Demetci, Pinar and Santorella, Rebecca and Sandstede, Bj\"{o}rn and Noble, William Stafford and
Singh, Ritambhara},
title = {SCOT: Single-Cell Multi-Omics Alignment with Optimal Transport},
journal = {Journal of Computational Biology},
volume = {29},
number = {1},
pages = {3-18},
year = {2022},
doi = {10.1089/cmb.2021.0446},
note = {PMID: 35050714},
}

@Article{chen:19,
Expand Down Expand Up @@ -665,3 +668,38 @@ @InProceedings{sejourne:22
publisher = {PMLR},
url = {https://proceedings.mlr.press/v151/sejourne22a/sejourne22a.pdf},
}

@Article{thibault:21,
author = {Thibault, Alexis and Chizat, Lénaïc and Dossal, Charles and Papadakis, Nicolas},
title = {Overrelaxed Sinkhorn–Knopp Algorithm for Regularized Optimal Transport},
journal = {Algorithms},
volume = {14},
year = {2021},
number = {5},
article-number = {143},
issn = {1999-4893},
doi = {10.3390/a14050143}
}

@InProceedings{chen:16,
author="Chen, Yukun and Ye, Jianbo and Li, Jia",
editor="Leibe, Bastian and Matas, Jiri and Sebe, Nicu and Welling, Max",
title="A Distance for HMMs Based on Aggregated Wasserstein Metric and State Registration",
booktitle="Computer Vision -- ECCV 2016",
year="2016",
publisher="Springer International Publishing",
address="Cham",
pages="451--466",
isbn="978-3-319-46466-4"
}

@ARTICLE{chen:20,
author={Chen, Yukun and Ye, Jianbo and Li, Jia},
journal={IEEE Transactions on Pattern Analysis and Machine Intelligence},
title={Aggregated Wasserstein Distance and State Registration for Hidden Markov Models},
year={2020},
volume={42},
number={9},
pages={2133-2147},
doi={10.1109/TPAMI.2019.2908635}
}
19 changes: 19 additions & 0 deletions docs/tools.rst
Original file line number Diff line number Diff line change
Expand Up @@ -43,3 +43,22 @@ Clustering

k_means.k_means
k_means.KMeansOutput

ott.tools.gaussian_mixture package
----------------------------------
.. currentmodule:: ott.tools.gaussian_mixture
.. automodule:: ott.tools.gaussian_mixture

.. TODO(cuturi): add a description
Gaussian Mixtures
^^^^^^^^^^^^^^^^^
.. autosummary::
:toctree: _autosummary

gaussian.Gaussian
gaussian_mixture.GaussianMixture
gaussian_mixture_pair.GaussianMixturePair
fit_gmm.initialize
fit_gmm.fit_model_em
fit_gmm_pair.get_fit_model_em_fn
55 changes: 55 additions & 0 deletions docs/tutorials/index.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
Tutorials
=========

Geometry
--------
.. toctree::
:maxdepth: 1

notebooks/introduction_grid

Linear Optimal Transport
------------------------
.. toctree::
:maxdepth: 1

notebooks/One_Sinkhorn
notebooks/OTT_&_POT
notebooks/Hessians
notebooks/LRSinkhorn

Barycenters
^^^^^^^^^^^
.. toctree::
:maxdepth: 1

notebooks/Sinkhorn_Barycenters
notebooks/gmm_pair_demo
notebooks/wasserstein_barycenters_gmms

Miscellaneous
^^^^^^^^^^^^^
.. toctree::
:maxdepth: 1

notebooks/soft_sort
notebooks/fairness
notebooks/application_biology

Quadratic Optimal Transport
---------------------------
.. toctree::
:maxdepth: 1

notebooks/gromov_wasserstein
notebooks/GWLRSinkhorn
notebooks/gromov_wasserstein_multiomics

Neural Optimal Transport
------------------------
.. toctree::
:maxdepth: 1

notebooks/neural_dual
notebooks/icnn_inits
notebooks/MetaOT
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
"source": [
"# Low-Rank Gromov-Wasserstein\n",
"\n",
"We try in this colab the low-rank (LR) Gromov-Wasserstein Solver, proposed by <cite data-cite=\"scetbon:22\">Scetbon et al.</cite>, a follow up to the LR Sinkhorn solver in <cite data-cite=\"scetbon:21\">Scetbon et al.</cite>."
"We try in this colab the low-rank (LR) Gromov-Wasserstein Solver, proposed by {cite}`scetbon:22`, a follow up to the LR Sinkhorn solver in {cite}`scetbon:21`."
]
},
{
Expand Down Expand Up @@ -81,7 +81,7 @@
"id": "y4aQGprB_oeW"
},
"source": [
"Create two toy point clouds of heterogeneous size, and add a third geometry to provide a fused problem (see <cite data-cite=\"vayer:20\">Vayer et al.</cite>).\n"
"Create two toy point clouds of heterogeneous size, and add a third geometry to provide a fused problem {cite}`vayer:20`."
]
},
{
Expand Down Expand Up @@ -121,7 +121,7 @@
"id": "dS49krqd_weJ"
},
"source": [
"Solve the problem using the Low-Rank Sinkhorn solver."
"Solve the problem using the Solve the problem using the low-rank solver, utilizing {class}`~ott.solvers.linear.sinkhorn_lr.LRSinkhorn` solver under the hood."
]
},
{
Expand Down Expand Up @@ -153,7 +153,7 @@
"id": "vxDoBrusUHmq"
},
"source": [
"Run it with entropic-GW for the sake of comparison"
"Run it with entropic {class}`~ott.solvers.quadratic.gromov_wasserstein.GromovWasserstein` solver for the sake of comparison."
]
},
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,9 @@
"id": "jzzs0FmbPpvY"
},
"source": [
"## Samples two point clouds, computes their ``sinkhorn_divergence``\n",
"## Samples two point clouds, computes their Sinkhorn divergence\n",
"\n",
"We show in colab how OTT and JAX can be used to compute automatically the Hessian of the Sinkhorn divergence w.r.t. input variables, such as weights ``a`` or locations ``x``."
"We show in colab how OTT and JAX can be used to compute automatically the Hessian of the {func}`~ott.tools.sinkhorn_divergence.sinkhorn_divergence` w.r.t. the input variables, such as weights ``a`` or locations ``x``."
]
},
{
Expand Down Expand Up @@ -93,9 +93,9 @@
"id": "kOtW4xTTSJhg"
},
"source": [
"As usual in JAX, we define a custom loss that outputs the quantity of interest, and is defined using relevant inputs as arguments, i.e. parameters against which we may want to differentiate. We add to `a` and `x` the ``implicit`` auxiliary flag which will be used to switch between unrolling and implicit differentiation of the Sinkhorn algorithm (see this excellent [tutorial](http://implicit-layers-tutorial.org/implicit_functions/) for a deep dive on their differences!)\n",
"As usual in JAX, we define a custom loss that outputs the quantity of interest, and is defined using relevant inputs as arguments, i.e. parameters against which we may want to differentiate. We add to `a` and `x` the ``implicit`` auxiliary flag which will be used to switch between unrolling and implicit differentiation of the {class}`~ott.solvers.linear.sinkhorn.Sinkhorn` algorithm (see this excellent [tutorial](http://implicit-layers-tutorial.org/implicit_functions/) for a deep dive on their differences).\n",
"\n",
"The loss outputs the Sinkhorn Divergence between two point clouds."
"The loss outputs the Sinkhorn divergence between two point clouds."
]
},
{
Expand Down Expand Up @@ -128,12 +128,13 @@
"id": "tnrx9dMnVDxD"
},
"source": [
"Let's parse the three lines in the call to ``sinkhorn_divergence`` above:\n",
"- The first one defines the point cloud geometry between ``x`` and ``y`` that will define the cost matrix. Here we could have added details on ``epsilon`` regularization (or scheduler), as well as alternative definitions of the cost function (here assumed by default to be squared Euclidean distance). We stick to the default setting.\n",
"Let's parse the three lines in the call to {func}`~ott.tools.sinkhorn_divergence.sinkhorn_divergence` above:\n",
"\n",
"- The second one sets the respective weight vectors `a` and `b`. Those are simply two histograms of size ``n`` and ``m``, both sum to 1, in the so-called balanced setting.\n",
"- The first one defines the point cloud geometry between `x` and `y` that will define the cost matrix. Here we could have added details on `epsilon` regularization (or scheduler), as well as alternative definitions of the cost function (here assumed by default to be squared Euclidean distance). We stick to the default setting.\n",
"\n",
"- The third one passes on arguments to the three ``sinkhorn`` solvers that will be called, to compare ``x`` with ``y``, ``x`` with ``x`` and ``y`` with ``y`` with their respective weights ``a`` and ``b``. Rather than focusing on the several numerical options available to parmeterize ``sinkhorn``'s behavior, we instruct JAX on how it should differentiate the outputs of the sinkhorn algorithm. The ``use_danskin`` flag specifies whether the outputted potentials should be freezed when differentiating. Since we aim for 2nd order differentiation here, we must set this to ``False`` (if we wanted to compute gradients, ``True`` would have resulted in faster yet almost equivalent computations)."
"- The second one sets the respective weight vectors `a` and `b`. Those are simply two histograms of size ``n`` and `m`, both sum to 1, in the so-called balanced setting.\n",
"\n",
"- The third one passes on arguments to the three {class}`~ott.solvers.linear.sinkhorn.Sinkhorn` solvers that will be called, to compare ``x`` with `y`, `x` with `x` and `y` with `y` with their respective weights `a` and `b`. Rather than focusing on the several numerical options available to parameterize {class}`~ott.solvers.linear.sinkhorn.Sinkhorn`'s behavior, we instruct JAX on how it should differentiate the outputs of the sinkhorn algorithm. The `use_danskin` flag specifies whether the outputted potentials should be freezed when differentiating. Since we aim for 2nd order differentiation here, we must set this to ``False`` (if we wanted to compute gradients, ``True`` would have resulted in faster yet almost equivalent computations)."
]
},
{
Expand All @@ -151,13 +152,12 @@
"id": "StMRwYUJVuOY"
},
"source": [
"Let's now plot Hessians of this output w.r.t. either ``a`` or ``x``. \n",
"\n",
"- The Hessian w.r.t. ``a`` will be a $n \\times n$ matrix, with the convention that ``a`` has size $n$. \n",
"Let's now plot Hessians of this output w.r.t. either `a` or `x`. \n",
"\n",
"- Because ``x`` is itself a matrix of 3D coordinates, the Hessian w.r.t. ``x`` will be a 4D tensor of size $n \\times 3 \\times n \\times 3$.\n",
"- The Hessian w.r.t. `a` will be a $n \\times n$ matrix, with the convention that `a` has size $n$. \n",
"- Because `x` is itself a matrix of 3D coordinates, the Hessian w.r.t. `x` will be a 4D tensor of size $n \\times 3 \\times n \\times 3$.\n",
"\n",
"To plot both Hessians, we loop on arg 0 or 1 of ``loss``, and plot all (or part for ``x``) of those Hessians, to check they match:"
"To plot both Hessians, we loop on arg 0 or 1 of `loss`, and plot all (or part for `x`) of those Hessians, to check they match:"
]
},
{
Expand Down
Loading

0 comments on commit e23bf1c

Please sign in to comment.