
Improvements to the neural dual solver #219

Merged: 59 commits, Feb 13, 2023
Conversation

@bamos (Contributor) commented Dec 22, 2022

Hi @marcocuturi @bunnech @michalk8, here is a WIP PR with the updates we've discussed to the neural dual solver. It improves the example from:

[image: example output before these updates]

to this:

[image: example output after these updates]

My paper has more context on these updates, which add more options for fine-tuning and amortizing the conjugate approximation. I've marked this as a WIP because I'm still debugging the inverse map in the example; otherwise, the rest of the code is ready for an initial review, and it would be good to get your thoughts on the core updates and API/doc changes. I can also call and chat a little more about these soon if that would be helpful. Here are quick summaries of my updates to the code and notebook.

Updates to the core neural solver

  1. swap the order of f and g potentials (\cc Bugs in the DualPotentials class #182 (comment))
  2. add an option for a conjugate solver that fine-tunes g's prediction when updating f, along with a default solver using JaxOpt's LBFGS optimizer
  3. add amortization modes to learn g: regression and objective (Makkuva et al.); see the sketch after this list
  4. an option to update both potentials with the same batch (in parallel) in addition to the separate inner loop Makkuva et al. uses
  5. the option to model the gradient of the g potential with a neural network, plus a simple MLP implementing this, instead of the ICNN from Makkuva et al.
  6. a callback so the training progress can be monitored (debugging the notebook/code was very difficult without this)
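To make items 2 and 3 concrete, here is a minimal sketch of the idea. This is my illustration, not the actual ott implementation: f, grad_g_apply, and the parameter handling below are assumed placeholders, and only jaxopt's LBFGS is a real dependency.

import jax
import jax.numpy as jnp
import jaxopt


def fine_tune_conjugate(f, y, x_init, maxiter=20):
  """Solve x*(y) = argmax_x <x, y> - f(x) with LBFGS, warm-started at x_init."""
  def neg_obj(x):
    return f(x) - jnp.dot(x, y)  # minimizing this maximizes <x, y> - f(x)
  solver = jaxopt.LBFGS(fun=neg_obj, maxiter=maxiter)
  return solver.run(x_init).params


def regression_amortization_loss(params_g, grad_g_apply, f, batch_y):
  """Regress the grad-g network onto the fine-tuned conjugate solutions."""
  x_pred = jax.vmap(lambda y: grad_g_apply(params_g, y))(batch_y)
  x_star = jax.vmap(lambda y, x0: fine_tune_conjugate(f, y, x0))(batch_y, x_pred)
  x_star = jax.lax.stop_gradient(x_star)  # use as fixed targets; no grads through LBFGS
  return jnp.mean(jnp.sum((x_pred - x_star) ** 2, axis=-1))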

Updates to the notebook/example

(I'm still tweaking these a little, as the results are still somewhat suboptimal)

  1. Update f ICNN size from [64, 64, 64, 64] to [128, 128] and modify the architecture
  2. Use an MLP to model the gradient of g
  3. Use the default JaxOpt LBFGS conjugate solver and
    update g to regress onto it
  4. Increase batch size from 1000 to 10000
  5. Sample the data in batches rather than sequentially
  6. Tweak Adam (use weight decay and a cosine schedule); see the sketch after this list

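For item 6, here is one way such an optimizer could be assembled with optax; the learning rate, step count, and weight decay below are illustrative assumptions rather than the notebook's final values.

import optax

num_train_steps = 20_000  # assumed; not the notebook's actual value
lr_schedule = optax.cosine_decay_schedule(init_value=1e-3,
                                          decay_steps=num_train_steps)
optimizer = optax.adamw(learning_rate=lr_schedule, weight_decay=1e-5)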
I've also tried running the code with an approach closer to the older implementation, i.e., the approach from Makkuva et al. that models g with an ICNN and uses 10 inner objective-based updates, but I can't get it to work as nicely as the version with conjugate fine-tuning + regression.

Remaining discussion and TODOs

  • Decide on best format for the arguments. E.g., should the conjugate solver be a function and the amortization mode a string?
  • Decide on the best defaults to use (a lightweight LBFGS conjugate solver for fine-tuning that regresses g onto it, or something closer to the approach in Makkuva et al.?)
  • Improve the documentation and notebook to clarify these options.
    • Add/enable docs for the conjugate solver and any other significantly new code
  • Update other places using the Neural OT code (like the initialization schemes notebook)
  • Think about any tests that would be useful to add for these updates
  • Fix the inverse map in the notebook
  • Decide on the default ICNN architecture
  • Add the ability to save/load the parameters and to re-initialize the solvers with them


@codecov-commenter commented Dec 22, 2022

Codecov Report

Merging #219 (45f33be) into main (de2c1d7) will decrease coverage by 1.27%.
The diff coverage is 65.35%.


Additional details and impacted files


@@            Coverage Diff             @@
##             main     #219      +/-   ##
==========================================
- Coverage   87.19%   85.93%   -1.27%     
==========================================
  Files          50       52       +2     
  Lines        5334     5587     +253     
  Branches      806      852      +46     
==========================================
+ Hits         4651     4801     +150     
- Misses        563      663     +100     
- Partials      120      123       +3     
Impacted Files Coverage Δ
src/ott/solvers/linear/sinkhorn.py 95.19% <ø> (ø)
src/ott/problems/linear/potentials.py 63.81% <11.76%> (-26.39%) ⬇️
src/ott/solvers/nn/neuraldual.py 55.96% <55.03%> (-14.41%) ⬇️
src/ott/solvers/nn/models.py 80.71% <80.71%> (ø)
src/ott/problems/nn/dataset.py 94.87% <94.87%> (ø)
src/ott/solvers/nn/conjugate_solvers.py 100.00% <100.00%> (ø)

@michalk8 added the enhancement and bugfix labels Dec 23, 2022
@marcocuturi (Contributor) left a comment
Hi Brandon!!! Thanks for the holidays present :) This looks really nice! Just a very short review to start with some comments, will go deeper after holidays are over.

Review threads: src/ott/problems/linear/potentials.py (1), src/ott/solvers/nn/neuraldual.py (3)
@marcocuturi (Contributor) left a comment

Thanks Brandon, a few more comments.

Should we discuss more specifically the items you have outlined in the PR?

Review threads: src/ott/solvers/nn/icnn.py (2), src/ott/solvers/nn/mlp.py (2), src/ott/solvers/nn/neuraldual.py (4)
@bamos (Contributor, Author) commented Jan 3, 2023

(I'll fix that test soon too)

@michalk8 self-requested a review January 4, 2023 09:53
Review threads: pyproject.toml (1), src/ott/solvers/nn/conjugate_solver.py (4), docs/references.bib (1), src/ott/solvers/nn/icnn.py (1), src/ott/solvers/nn/mlp.py (1), src/ott/solvers/nn/neuraldual.py (2), docs/notebooks/neural_dual.ipynb (6)
@bamos (Contributor, Author) commented Jan 4, 2023

Thanks for the review @michalk8! I just pushed 0bf2467 to integrate most of your smaller comments and left some of the larger ones unresolved for now; I'll come back to them in a bigger pass.

@marcocuturi (Contributor) left a comment

Thanks a lot Brandon for this amazing contribution!!

@michalk8 (Collaborator) commented Jan 9, 2023

@bamos, in #233 the notebook parser was changed to use myst-nb in order to provide a better way of specifying links to classes/functions/etc.; just wanted to let you know.

@marcocuturi (Contributor) commented:

Hi @bamos, now that ICML is behind us, do you want to push this? Do you need any help?

@bamos (Contributor, Author) commented Jan 31, 2023

Hi @bamos, now that ICML is behind us do you want to push this?

Yes! I'll rebase on top of the main branch and see if I can fix the inverse map over the next few days. (I think it's a matter of finding the right hyper-parameters...) Then it should be ready for a final review/merge.

@bamos (Contributor, Author) commented Feb 8, 2023

Finally got it working! :)

[image]
[image]


The key to making these work was to use a non-convex MLP for the potential instead of the ICNNs. Unfortunately, the ICNNs were extremely difficult to debug in this setting and I couldn't figure out how to get the inverse map working correctly. (Even numerically solving for the inverse still results in a lot of artifacts, which may indicate that the forward potential with the ICNN is also slightly off, although it visually looks ok.)

I also added a back_and_forth option, inspired by this paper, that alternates between updating the forward and inverse maps. This helps stabilize the potentials to be conjugates of each other; a rough sketch of the alternation follows the figure below. (This also came up in a conversation I was having with @APooladian.) Without this, the non-convex potentials quickly diverge:

[image]
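Schematically, the alternation might look like the following sketch; this is my own illustration of the idea, with update_forward and update_backward standing in for hypothetical per-direction update functions, not the actual W2NeuralDual internals.

def train_back_and_forth(params_f, params_g, update_forward, update_backward,
                         source_iter, target_iter, num_steps):
  """Alternate the fitting direction each step so f and g stay near-conjugate."""
  for step in range(num_steps):
    batch_source, batch_target = next(source_iter), next(target_iter)
    if step % 2 == 0:  # even steps: fit the forward map (source -> target)
      params_f, params_g = update_forward(params_f, params_g,
                                          batch_source, batch_target)
    else:              # odd steps: fit the inverse map (target -> source)
      params_g, params_f = update_backward(params_g, params_f,
                                           batch_target, batch_source)
  return params_f, params_g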

Tomorrow I'll make a pass through the notebook (it's not updated with these results) and should be able to address the remaining TODOs here to have a final version of this PR ready for review. I'll make the notebook start with the simpler task that the ICNN can fit, and then afterwards show how an MLP can be swapped in and used for this harder setting.

@bamos (Contributor, Author) commented Feb 9, 2023

The GPU tests failed here with error

ERROR tests/utils_test.py - AttributeError: module 'jax.sharding' has no attribute 'PartitionSpec'

Flax released v0.6.5 a few hours ago, which I think is causing this. It could be fixed either by bumping the jax version or by upper-bounding the flax version. I just added 839719e to try to fix this, but I can revert/edit if you prefer something else.

@bamos changed the title from "[WIP] Improvements to the neural dual solver" to "Improvements to the neural dual solver" Feb 9, 2023
@bamos marked this pull request as ready for review February 9, 2023 06:05
@bamos requested review from michalk8 and marcocuturi and removed the request for michalk8 and marcocuturi February 9, 2023 06:06
@bamos (Contributor, Author) commented Feb 9, 2023

Hi @michalk8 @marcocuturi, I have a version ready to be reviewed and merged! Can you please go through the latest version and let me know if there's anything else you want me to get in? I tried to re-request reviews from both of you, but for some reason GitHub only lets me send the request to one of you.

@bamos (Contributor, Author) commented Feb 9, 2023

Also @marcocuturi, I tried the idea you mentioned before of starting with the ICNN weights negative and annealing them to be positive by shifting them throughout training, i.e., z_weights = softplus(x + offset) - offset, where offset started at 10 and decreased to 0 over a few thousand training iterations. I wasn't able to quickly make it work nicely: it seems the model started relying too much on the negative values and couldn't correct them as they were pushed to be positive, although there may be a better initialization/annealing strategy than what I tried. I don't have the code readily available to keep building on this, since I hacked in something to condition the positive dense layer on the training iteration, but it could be relatively easy to add back if a step-dependent annealing seems promising elsewhere.
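Written out, the annealed positive-weight parameterization described above would look roughly like this; the linear decay and the anneal_steps value below are assumptions about the schedule, not the exact one that was tried.

import jax
import jax.numpy as jnp


def annealed_positive_weights(raw_weights, step, anneal_steps=2_000):
  """z_weights = softplus(x + offset) - offset, with offset annealed from 10 to 0."""
  offset = 10.0 * jnp.clip(1.0 - step / anneal_steps, 0.0, 1.0)  # assumed linear decay
  return jax.nn.softplus(raw_weights + offset) - offset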

@michalk8 (Collaborator) commented Feb 9, 2023

The GPU tests failed here with error

ERROR tests/utils_test.py - AttributeError: module 'jax.sharding' has no attribute 'PartitionSpec'

Flax just released v0.6.5 a few hours ago that I think is causing this. It could either be fixed by bumping up the jax version or upper-bounding the flax version. I just added 839719e to try to fix this, but can revert/edit if you prefer something else

It seems like the new flax requires jax>=0.4.2 here, whereas our GPU CI needs jax<0.4 because of an issue I haven't been able to resolve so far (jax-ml/jax#13758); I will take another look at it.

Review threads: src/ott/solvers/nn/conjugate_solver.py (6), src/ott/solvers/nn/neuraldual.py (4)
@michalk8 (Collaborator) commented:

@bamos could you please also add some tests for the W2NeuralDual using the potential MLP? Maybe parameterizing the already-present one will be easiest.

@bamos (Contributor, Author) commented Feb 10, 2023

@bamos could you please also add some test for the W2NeuralDual using the potential MLP? Maybe parameterizing the already present one will be the easiest.

Yeah, I'll try adding every combination of supported networks to the tests: ICNN, MLP (potential), and MLP (gradient). I can also add some for the extra modes (like back_and_forth).

@bamos (Contributor, Author) commented Feb 10, 2023

I noticed the test also has a ToyDataset class in there: https://github.com/ott-jax/ott/blob/main/tests/solvers/nn/neuraldual_test.py#L25 (the square_four and square_five modes there are also not used in the testing code). The source code for ToyDataset is duplicated in the notebooks, and in the main neural OT tutorial I've updated it to sample batches. Do you think it makes sense to have this as a module in OTT to help reduce this duplication, or should datasets and sampling remain separate from OTT's core? Connecting to #219 (comment) above, it seems useful to have readily available data samplers for some common tasks somewhere (even outside of the core) that can easily be loaded and plugged into OTT's solvers.

@michalk8 (Collaborator) commented:

I noticed the test also has a ToyDataset class in there: https://github.com/ott-jax/ott/blob/main/tests/solvers/nn/neuraldual_test.py#L25 (also the square_four and square_five modes there are not used in the testing code). The source code for ToyDataset is duplicated in the notebooks, and in the main neural OT tutorial, I've updated it to sample batches. Do you think it makes sense to have this as a module in OTT to help reduce this duplication, or do you think datasets and sampling should remain separate from OTT's core? Connecting to #219 (comment) above it seems useful to have readily available data samplers for some common tasks somewhere (even outside of the core) that can easily be loaded and plugged into OTT's solvers

I think having the toy example dataset in ott/problems/nn/dataset.py would be great to reduce code duplication, make it easier to try out neural OT, and, in the future, to develop a more complex interface (like converting from a LinearProblem).
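For a sense of what such a module could provide, here is a minimal sampler sketch; the function name, signature, and square-corner layout are illustrative assumptions, not the final ott.problems.nn.dataset API.

import jax
import jax.numpy as jnp


def gaussian_mixture_sampler(rng, centers, batch_size=1024, scale=0.1):
  """Yield batches from an isotropic Gaussian mixture centered at `centers`."""
  while True:
    rng, k_idx, k_noise = jax.random.split(rng, 3)
    idx = jax.random.randint(k_idx, (batch_size,), 0, centers.shape[0])
    noise = scale * jax.random.normal(k_noise, (batch_size, centers.shape[1]))
    yield centers[idx] + noise


# e.g. a "square_four"-style target: Gaussians at the corners of a square
corners = jnp.array([[1., 1.], [1., -1.], [-1., 1.], [-1., -1.]])
target_iter = gaussian_mixture_sampler(jax.random.PRNGKey(0), corners)
batch = next(target_iter)  # array of shape (1024, 2)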

@bamos (Contributor, Author) commented Feb 10, 2023

I think having the toy example dataset in ott/problems/nn/dataset.py would be great to reduce code duplication, make it easier to try out neural OT + in the future, develop more complex interface (like converting from LinearProblem).


@bamos could you please also add some test for the W2NeuralDual using the potential MLP? Maybe parameterizing the already present one will be the easiest.

Yeah, I'll try adding every combination of supported networks to the tests; using ICNN, MLP (potential), and MLP (gradient). And can add some for the extra modes too (like back_and_forth)


I just pushed efa2507 addressing both of these.

Review threads: src/ott/solvers/nn/neuraldual.py (4), src/ott/problems/nn/dataset.py (4), src/ott/solvers/nn/models.py (2)
@bamos (Contributor, Author) commented Feb 11, 2023

Something else interesting: the entropic potentials are looking really nice on this toy data! It could make sense to use the entropic potentials trained on a small set of samples to initialize/pre-train the neural potentials. \cc @APooladian

from ott.problems.nn import dataset
from ott.solvers.linear import sinkhorn
from ott.geometry.pointcloud import PointCloud
from ott.problems.linear.potentials import EntropicPotentials
from ott.problems.linear import linear_problem

# Samplers for the two toy measures.
train_dataloaders, valid_dataloaders, input_dim = dataset.create_gaussian_mixture_samplers(
    name_source="square_five", name_target="square_four")

# One validation batch from each measure.
eval_data_source = next(valid_dataloaders.source_iter)
eval_data_target = next(valid_dataloaders.target_iter)

# Solve entropic OT between the two batches with a small epsilon.
geom = PointCloud(eval_data_source, eval_data_target, epsilon=0.001)
prob = linear_problem.LinearProblem(geom)
out = sinkhorn.solve(geom,
                     threshold=0.005, lse_mode=True, jit=False,
                     max_iterations=1000)

# Build the entropic dual potentials and plot the induced transport map.
entpots = EntropicPotentials(out.f, out.g, prob)
entpots.plot_ot_map(eval_data_source, eval_data_target)

gives:

[image]
[image]

@michalk8 (Collaborator) left a comment

Looks great, I have just a few minor comments (mostly regarding documentation). After these, I will be happy to merge!

Review threads: src/ott/problems/nn/dataset.py (3), src/ott/problems/linear/potentials.py (2), docs/tutorials/notebooks/icnn_inits.ipynb (1), docs/problems/nn.rst (1), src/ott/solvers/nn/models.py (1), docs/solvers/nn.rst (1)
@bamos (Contributor, Author) commented Feb 12, 2023

Looks great, I got just a few minor comments (mostly regarding documentation). After these, will be happy to merge!

Great! I just put them in at 45f33be.

@michalk8 (Collaborator) commented:

This is a very nice contribution @bamos, LGTM!

@michalk8 merged commit 40bf4af into ott-jax:main Feb 13, 2023
michalk8 pushed a commit that referenced this pull request Feb 14, 2023
* fix D101 and B028 (stack level for warnings) introduced in #219

* add vscode to gitignore

* fix D102 on costs.py and try to simplify inheritance methods

* try to fix D102 everywhere

* fix D103

* remove comment from .flake8

* address comments

* more comments

* re-introduce tree-flatten in costs
pierreablin pushed a commit to pierreablin/ott that referenced this pull request Feb 14, 2023
* (same commit message as the Feb 14, 2023 push above)
michalk8 pushed a commit that referenced this pull request Feb 16, 2023
* fix D101 and B028 (stack level for warnings) introduced in #219

* add vscode to gitignore

* fix D102 on costs.py and try to simplify inheritance methods

* try to fix D102 everywhere

* fix D103

* remove comment from .flake8

* add ruff to pyproject.toml

* add ruff

* add documentation on ruff select and add unfixable

* add target version

* add comment on tool.ruff

* pyproject.toml

* format

* remove artifacts from wrong merge

* address ruff comments

* fix ruff imports
michalk8 pushed a commit that referenced this pull request Jun 27, 2024
* Initial batch of updates to neuraldual code.

This swaps the order of f and g potentials and adds:

1. an option to add a conjugate solver that fine-tunes g's prediction
   when updating f, along with a default solver using JaxOpt's
   LBFGS optimizer
2. amortization modes to learn g
3. an option to update both potentials with the same batch (in parallel)
4. the option to model the *gradient* of the g potential
   with a neural network and a simple MLP implementing this
5. a callback so the training progress can be monitored

* icnn: only don't use an activation on the first layer

* set default amortization loss to regression

* Improve notebook

1. Update f ICNN size from [64, 64, 64, 64] to [128, 128],
   use LReLU activations, and increase init_std to 1.0
2. Use an MLP to model the gradient of g
3. Use the default JaxOpt LBFGS conjugate solver and
   update g to regress onto it
4. Increase batch size from 1000 to 10000
5. Sample the data in batches rather than sequentially
6. Tweak Adam (use weight decay and a cosine schedule)

* update icnn to use the posdef potential last

* X,Y -> source,target

* don't fine-tune when the conjugate solver is not provided

* provides_gradient->returns_potential

* NeuralDualSolver->W2NeuralDual

* ConjugateSolverLBFGS->FenchelConjugateLBFGS

* rm wandb in logging comment

* remove line split

* BFGS->LBFGS

* address review comments from @michalk8

* also permit the MLP to model the potential (in addition to the gradient)

* add alternating update directions option and ability to initialize with existing parameters

* add option to finetune g when getting the potentials

* ->back_and_forth

* fix typing

* add back the W2 distance estimate and penalty

* fix logging with back-and-forth

* add weight term

* default MLP to returns_potential=True

* fix newline

* update icnn_inits nb

* bump down init_std

* make bib style consistent

* ->W2NeuralDual

* pass on docs

* limit flax version

* neural_dual nb: use two examples, finalize potentials with stable hyper-params

* update pin for jax/flax testing dependency

* address review comments from @michalk8

* fix potential type hints

* re-run nb with latest code

* address review comments from @michalk8

* potential_{value,gradient} -> functions

* add latest nb (no outputs)

* fix typos

* plot_ot_map: keep source/target measures fixed

* fix types

* add review comments from @michalk8 and @marcocuturi

* consolidate models

* inverse->forward

* support back_and_forth when using a gradient mapping, improve docs

* move plotting code

* update nbs

* update intro text

* conjugate_solver -> conjugate_solvers

* address review comments from @michalk8

* update docstring

* create problems.nn.dataset, add more tests

* fix indentation

* polish docs and restructure dataset to address @michalk8's comments

* move plotting code back to ott.problems.linear.potentials

* fix test and dataloaders->dataset

* update notebooks to latest code

* bump up batch size and update notebooks

* address review comments from @michalk8 (sorry for multiple force-pushes)
michalk8 pushed a commit that referenced this pull request Jun 27, 2024
* (same commit message as the Feb 14, 2023 push above)
michalk8 pushed a commit that referenced this pull request Jun 27, 2024
* (same commit message as the Feb 16, 2023 push above)