Skip to content

Commit

Permalink
Add regularized costs (ott-jax#244)
Browse files Browse the repository at this point in the history
* Add new cost functions

* Polish docstrings

* Remove `jaxopt` dependency

* Add test

* Polish `ElasticSqKOverlap`

* Fix indexing error

* Remove k-overlap from Legendre test

* Test correlation for `ElasticSqKOverlap`

* Test sparse displacements

* Add sparsity test

* Fix docs

* Simplify `reg` in `ElasticSqKOverlap`

* Rename `ElasticNet` -> `ElasticL1`

* Change `frac_sparse` in test

* Try removing jax pin in CI

* Try pinning jax/flax in GPU tests

* Fix pins in CI

* Fix indentation in docs
  • Loading branch information
michalk8 authored Feb 9, 2023
1 parent eee4ac1 commit a836f4b
Show file tree
Hide file tree
Showing 7 changed files with 378 additions and 54 deletions.
4 changes: 3 additions & 1 deletion .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -45,10 +45,12 @@ jobs:
steps:
- uses: actions/checkout@v3
- name: Install dependencies
# `jax[cuda]<0.4` because of: https://github.com/google/jax/issues/13758
# `jax[cuda]<0.4` because of Docker issues: https://github.com/google/jax/issues/13758
# `flax<0.6.5` because it requires `jax>=0.4.2`
run: |
python3 -m pip install --upgrade pip
python3 -m pip install -e".[test]"
python3 -m pip install "flax<0.6.5"
python3 -m pip install "jax[cuda]<0.4" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
- name: Nvidia SMI
Expand Down
5 changes: 3 additions & 2 deletions docs/geometry.rst
Original file line number Diff line number Diff line change
Expand Up @@ -52,15 +52,16 @@ Cost Functions
.. autosummary::
:toctree: _autosummary

costs.CostFn
costs.TICost
costs.SqPNorm
costs.PNormP
costs.SqEuclidean
costs.Euclidean
costs.Cosine
costs.Bures
costs.UnbalancedBures
costs.ElasticL1
costs.ElasticSTVS
costs.ElasticSqKOverlap

Utilities
---------
Expand Down
14 changes: 7 additions & 7 deletions docs/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,13 @@ as differentiable approximations to ranking or even clustering.

To achieve this, ``OTT`` rests on two families of tools:

- the first family consists in *discrete* solvers computing transport between
point clouds, using the Sinkhorn :cite:`cuturi:13` and low-rank Sinkhorn
:cite:`scetbon:21` algorithms, and moving up towards Gromov-Wasserstein
:cite:`memoli:11,peyre:16`;
- the second family consists in *continuous* solvers, using suitable neural
architectures :cite:`amos:17` coupled with SGD type estimators
:cite:`makkuva:20,korotin:21`.
- the first family consists in *discrete* solvers computing transport between
point clouds, using the Sinkhorn :cite:`cuturi:13` and low-rank Sinkhorn
:cite:`scetbon:21` algorithms, and moving up towards Gromov-Wasserstein
:cite:`memoli:11,peyre:16`;
- the second family consists in *continuous* solvers, using suitable neural
architectures :cite:`amos:17` coupled with SGD type estimators
:cite:`makkuva:20,korotin:21`.

Installation
------------
Expand Down
37 changes: 37 additions & 0 deletions docs/references.bib
Original file line number Diff line number Diff line change
Expand Up @@ -703,3 +703,40 @@ @ARTICLE{chen:20
pages={2133-2147},
doi={10.1109/TPAMI.2019.2908635}
}

@ARTICLE{schreck:15,
author={Schreck, Amandine and Fort, Gersende and Le Corff, Sylvain and Moulines, Eric},
journal={IEEE Journal of Selected Topics in Signal Processing},
title={A Shrinkage-Thresholding Metropolis Adjusted Langevin Algorithm for Bayesian Variable Selection},
year={2016},
volume={10},
number={2},
pages={366-375},
doi={10.1109/JSTSP.2015.2496546}
}

@inproceedings{argyriou:12,
author = {Argyriou, Andreas and Foygel, Rina and Srebro, Nathan},
booktitle = {Advances in Neural Information Processing Systems},
editor = {F. Pereira and C.J. Burges and L. Bottou and K.Q. Weinberger},
pages = {},
publisher = {Curran Associates, Inc.},
title = {Sparse Prediction with the k-Support Norm},
url = {https://proceedings.neurips.cc/paper/2012/file/99bcfcd754a98ce89cb86f73acc04645-Paper.pdf},
volume = {25},
year = {2012}
}

@article{zou:05,
author = {Zou, Hui and Hastie, Trevor},
title = {Regularization and variable selection via the elastic net},
journal = {Journal of the Royal Statistical Society: Series B (Statistical Methodology)},
volume = {67},
number = {2},
pages = {301-320},
keywords = {Grouping effect, LARS algorithm, Lasso, Penalization, p≫n problem, Variable selection},
doi = {https://doi.org/10.1111/j.1467-9868.2005.00503.x},
url = {https://rss.onlinelibrary.wiley.com/doi/abs/10.1111/j.1467-9868.2005.00503.x},
eprint = {https://rss.onlinelibrary.wiley.com/doi/pdf/10.1111/j.1467-9868.2005.00503.x},
year = {2005}
}
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ classifiers = [
"Programming Language :: Python :: 3.8",
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
]

[project.urls]
Expand Down
Loading

0 comments on commit a836f4b

Please sign in to comment.