Improvements to the neural dual solver #219
Conversation
Codecov Report
Additional details and impacted files:

```diff
@@            Coverage Diff             @@
##             main     #219      +/-   ##
==========================================
- Coverage   87.19%   85.93%    -1.27%
==========================================
  Files          50       52       +2
  Lines        5334     5587     +253
  Branches      806      852      +46
==========================================
+ Hits         4651     4801     +150
- Misses        563      663     +100
- Partials      120      123       +3
```
Hi Brandon!!! Thanks for the holiday present :) This looks really nice! Just a very short review with some comments to start; I'll go deeper after the holidays are over.
Thanks Brandon, a few more comments.
Should we discuss more specifically the items you have outlined in the PR?
(I'll fix that test soon too)
Thanks a lot Brandon for this amazing contribution!!
Hi @bamos, now that ICML is behind us, do you want to push this? Do you need any help?
Yes! I'll rebase on top of the main branch and see if I can fix the inverse map over the next few days. (I think it's a matter of finding the right hyper-parameters...) Then it should be ready for a final review/merge.
Finally got it working! :) The key to making this work was to use a non-convex MLP for the potential instead of the ICNNs. Unfortunately, the ICNNs were extremely difficult to debug in this setting and I couldn't figure out how to get the inverse map working correctly. (It seems like even numerically solving for the inverse still results in a lot of artifacts, so it may indicate that the forward potential with the ICNN is also off a little, although it visually looks OK.) I also added a […]

Tomorrow I'll make a pass through the notebook (it's not updated with these results) and should be able to address the remaining TODOs here to have a final version of this PR ready for review. I'll make the notebook start with the simpler task that the ICNN can fit, and then afterwards show how an MLP can be swapped in and used for this harder setting.
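For context, the non-convex alternative can be as simple as a standard MLP that returns a scalar potential value. Here is a minimal sketch of that idea; the layer sizes and activation are illustrative, not the PR's exact model:

```python
import flax.linen as nn


class PotentialMLP(nn.Module):
  """A plain non-convex MLP potential, sketched as an alternative to an ICNN."""
  dim_hidden: tuple = (128, 128)

  @nn.compact
  def __call__(self, x):
    z = x
    for d in self.dim_hidden:
      z = nn.leaky_relu(nn.Dense(d)(z))
    return nn.Dense(1)(z).squeeze(-1)  # scalar potential value
```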
The GPU tests failed here with an error. Flax just released v0.6.5 a few hours ago, which I think is causing this. It could be fixed either by bumping up the jax version or by upper-bounding the flax version. I just added 839719e to try to fix this, but I can revert/edit if you prefer something else.
Hi @michalk8 @marcocuturi, I have a version ready to be reviewed and merged! Can you please go through the latest version and let me know if there's anything else you want me to get in? I tried to re-request reviews from both of you, but for some reason GitHub only lets me send it to one of you.
Also @marcocuturi, I tried the idea you've mentioned before of starting with the ICNN weights negative and annealing them to be positive by shifting them throughout training, i.e., […]
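For illustration only, one hypothetical way such an annealing could look; this is an assumption, not the exact scheme from the experiment:

```python
import jax.numpy as jnp


def annealed_kernel(w, step, total_steps):
  """Shift a possibly-negative kernel toward non-negativity over training.

  At step 0 the raw weights are used as-is; by `total_steps` every negative
  entry has been shifted up to zero, recovering the usual ICNN constraint.
  This is a hypothetical sketch, not the PR's exact scheme.
  """
  frac = jnp.clip(step / total_steps, 0.0, 1.0)
  return w + frac * jnp.abs(jnp.minimum(w, 0.0))  # equals relu(w) at frac=1
```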
Seems like new […]
@bamos could you please also add some tests for the […]?
Yeah, I'll try adding every combination of supported networks to the tests, using ICNN, MLP (potential), and MLP (gradient). And I can add some for the extra modes too (like […]).
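A sketch of how those combinations could be parametrized; the test name and model labels here are hypothetical, not the suite's actual fixtures:

```python
import itertools

import pytest


# Parametrize over all pairs of supported potential models; the solver
# wiring inside the test body is omitted here.
@pytest.mark.parametrize(
    "neural_f,neural_g",
    itertools.product(["icnn", "mlp_potential", "mlp_gradient"], repeat=2),
)
def test_w2_neural_dual(neural_f, neural_g):
    ...
```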
I noticed the test also has a […]
I think having the toy example dataset in […]
I just pushed efa2507, addressing both of these.
Something else interesting: the entropic potentials are looking really nice on this toy data! It could make sense to use the entropic potentials trained on a small set of samples to initialize/pre-train the neural potentials. /cc @APooladian

```python
from ott.geometry.pointcloud import PointCloud
from ott.problems.linear import linear_problem
from ott.problems.linear.potentials import EntropicPotentials
from ott.problems.nn import dataset
from ott.solvers.linear import sinkhorn

train_dataloaders, valid_dataloaders, input_dim = dataset.create_gaussian_mixture_samplers(
    name_source="square_five", name_target="square_four")
eval_data_source = next(valid_dataloaders.source_iter)
eval_data_target = next(valid_dataloaders.target_iter)

geom = PointCloud(eval_data_source, eval_data_target, epsilon=0.001)
prob = linear_problem.LinearProblem(geom)
out = sinkhorn.solve(
    geom, threshold=0.005, lse_mode=True, jit=False, max_iterations=1000)

entpots = EntropicPotentials(out.f, out.g, prob)
entpots.plot_ot_map(eval_data_source, eval_data_target)
```

gives:
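A hedged sketch of that initialization idea, regressing a neural potential onto the entropic one before the neural training starts; `f_apply`, `params`, and the optimizer settings below are assumptions, not existing API:

```python
import jax
import jax.numpy as jnp
import optax


def pretrain_on_entropic(f_apply, params, samples, entropic_potential, steps=500):
  """Fit a neural potential to the entropic potential's values on samples."""
  targets = jax.vmap(entropic_potential)(samples)  # per-sample potential values
  opt = optax.adam(1e-3)
  opt_state = opt.init(params)

  @jax.jit
  def train_step(params, opt_state):
    def loss_fn(p):
      preds = jax.vmap(lambda x: f_apply(p, x))(samples)
      return jnp.mean((preds - targets) ** 2)  # simple L2 regression

    grads = jax.grad(loss_fn)(params)
    updates, opt_state = opt.update(grads, opt_state, params)
    return optax.apply_updates(params, updates), opt_state

  for _ in range(steps):
    params, opt_state = train_step(params, opt_state)
  return params
```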
Looks great, I got just a few minor comments (mostly regarding documentation). After these, I'll be happy to merge!
Great! Just put them in at 45f33be.
This is a very nice contribution @bamos, LGTM!
* fix D101 and B028 (stack level for warnings) introduced in #219
* add vscode to gitignore
* fix D102 on costs.py and try to simplify inheritance methods
* try to fix D102 everywhere
* fix D103
* remove comment from .flake8
* address comments
* more comments
* re-introduce tree-flatten in costs
* fix D101 and B028 (stack level for warnings) introduced in #219
* add vscode to gitignore
* fix D102 on costs.py and try to simplify inheritance methods
* try to fix D102 everywhere
* fix D103
* remove comment from .flake8
* add ruff to pyproject.toml
* add ruff
* add documentation on ruff select and add unfixable
* add target version
* add comment on tool.ruff
* pyproject.toml
* format
* remove artifacts from wrong merge
* address ruff comments
* fix ruff imports
* Initial batch of updates to neuraldual code. This swaps the order of f and g potentials and adds:
  1. an option to add a conjugate solver that fine-tunes g's prediction when updating f, along with a default solver using JaxOpt's LBFGS optimizer
  2. amortization modes to learn g
  3. an option to update both potentials with the same batch (in parallel)
  4. the option to model the *gradient* of the g potential with a neural network and a simple MLP implementing this
  5. a callback so the training progress can be monitored
* icnn: only don't use an activation on the first layer
* set default amortization loss to regression
* Improve notebook
  1. Update f ICNN size from [64, 64, 64, 64] to [128, 128], use LReLU activations, and increase init_std to 1.0
  2. Use an MLP to model the gradient of g
  3. Use the default JaxOpt LBFGS conjugate solver and update g to regress onto it
  4. Increase batch size from 1000 to 10000
  5. Sample the data in batches rather than sequentially
  6. Tweak Adam (use weight decay and a cosine schedule)
* update icnn to use the posdef potential last
* X,Y -> source,target
* don't fine-tune when the conjugate solver is not provided
* provides_gradient->returns_potential
* NeuralDualSolver->W2NeuralDual
* ConjugateSolverLBFGS->FenchelConjugateLBFGS
* rm wandb in logging comment
* remove line split
* BFGS->LBFGS
* address review comments from @michalk8
* also permit the MLP to model the potential (in addition to the gradient)
* add alternating update directions option and ability to initialize with existing parameters
* add option to finetune g when getting the potentials
* ->back_and_forth
* fix typing
* add back the W2 distance estimate and penalty
* fix logging with back-and-forth
* add weight term
* default MLP to returns_potential=True
* fix newline
* update icnn_inits nb
* bump down init_std
* make bib style consistent
* ->W2NeuralDual
* pass on docs
* limit flax version
* neural_dual nb: use two examples, finalize potentials with stable hyper-params
* update pin for jax/flax testing dependency
* address review comments from @michalk8
* fix potential type hints
* re-run nb with latest code
* address review comments from @michalk8
* potential_{value,gradient} -> functions
* add latest nb (no outputs)
* fix typos
* plot_ot_map: keep source/target measures fixed
* fix types
* add review comments from @michalk8 and @marcocuturi
* consolidate models
* inverse->forward
* support back_and_forth when using a gradient mapping, improve docs
* move plotting code
* update nbs
* update intro text
* conjugate_solver -> conjugate_solvers
* address review comments from @michalk8
* update docstring
* create problems.nn.dataset, add more tests
* fix indentation
* polish docs and restructure dataset to address @michalk8's comments
* move plotting code back to ott.problems.linear.potentials
* fix test and dataloaders->dataset
* update notebooks to latest code
* bump up batch size and update notebooks
* address review comments from @michalk8 (sorry for multiple force-pushes)
Hi @marcocuturi @bunnech @michalk8, here is a WIP PR with the updates to the neural dual solver that we've discussed. It improves the example from:
to this:
My paper has some more context on these updates, adding more options for fine-tuning and amortizing the conjugate approximation. I've marked this as a WIP because I'm still debugging the inverse map in the example; otherwise, the rest of the code is ready for an initial review, as it would be good to get your thoughts on the core updates and the API/doc changes. I can also call and chat a little more about these soon if that would be helpful. Here are quick summaries of my updates to the code and notebook.
Updates to the core neural solver
* […] the DualPotentials class (#182 (comment))

Updates to the notebook/example
(I'm still tweaking these a little, as the results are still slightly suboptimal)
* […] and modified the architecture
* […] update g to regress onto it
I've also tried running the code with an approach closer to the older implementation, i.e., the approach from Makkuva et al. that models g with an ICNN and uses 10 inner objective-based updates, but I couldn't get it to work as nicely as the version with conjugate fine-tuning + regression.
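For reference, the conjugate fine-tuning amounts to refining the network's guess for grad g(y) by numerically solving the Fenchel conjugate problem. A minimal sketch with JaxOpt's LBFGS; the `f` callable and the initial guess `x0` are assumptions:

```python
import jax.numpy as jnp
import jaxopt


def finetune_conjugate(f, y, x0, maxiter=20):
  """Refine x0 (e.g. the network's grad-g prediction at y) toward grad f*(y).

  Solves min_x f(x) - <x, y>, whose minimizer is the conjugate's gradient;
  an amortization loss can then regress g onto this fine-tuned target.
  """
  solver = jaxopt.LBFGS(fun=lambda x: f(x) - jnp.dot(x, y), maxiter=maxiter)
  return solver.run(x0).params
```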
Remaining discussion and TODOs