This repository provides a full implementation of our method for estimating Schrödinger bridges from i.i.d. samples of the source and target measures (see the paper cited below). Our approach, called the Sinkhorn bridge, expresses the time-dependent drift as a function of the static potentials that solve the entropic optimal transport problem on the data. The GIF below visualizes our approach for computing the Sinkhorn bridge on three low-dimensional datasets commonly used in the machine learning literature.
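As a rough illustration of what "drift as a function of the static potentials" means, the formula below is a sketch derived under assumed conventions, which may differ from the paper's notation and constants: quadratic cost $\tfrac12\|x-y\|^2$, entropic regularization $\varepsilon$, a Brownian reference process with diffusion coefficient $\sqrt{\varepsilon}$, target samples $Y_1,\dots,Y_m$ with uniform weights, and $\hat g_j$ the target-side entropic OT potential evaluated at $Y_j$. The bridge is then simulated forward from $Z_0 \sim \mu$.

```math
\hat b_t(z) = \frac{1}{1-t}\left( -z + \frac{\sum_{j=1}^m Y_j \exp\!\Big( \big(\hat g_j - \tfrac{\|z - Y_j\|^2}{2(1-t)}\big)/\varepsilon \Big)}{\sum_{j=1}^m \exp\!\Big( \big(\hat g_j - \tfrac{\|z - Y_j\|^2}{2(1-t)}\big)/\varepsilon \Big)} \right),
\qquad dZ_t = \hat b_t(Z_t)\,dt + \sqrt{\varepsilon}\,dB_t,\quad t \in [0,1).
```

At $t = 0$ the softmax weights coincide with the conditional entropic OT coupling of $z$ over the target samples, so the drift starts out as the barycentric projection minus $z$.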
Jupyter notebooks that replicate the experiments in our paper can be found in `examples`. These include the 2D examples shown in the GIF above, the Gaussian-to-Gaussian setting, and a benchmark setting due to Gushchin et al. (2023), in which we estimate a non-trivial drift in high dimensions.
Our approach is simulation-free: we reduce the problem of estimating the drift that defines the Schrödinger bridge (in either the forward or backward direction) to estimating the potentials of the entropic optimal transport coupling on the data, which are computed with Sinkhorn's algorithm (Cuturi, 2013). We provide implementations in both the POT and OTT-JAX frameworks. The method has three hyperparameters. First, the user sets the noise level `eps`, which is passed to Sinkhorn's algorithm and is also used to define the drift. Second, the user specifies the duration of the drift `tau` in [0, 1). Third, the user sets the number of steps of the Euler–Maruyama discretization, `Nsteps`. From there, the estimator takes care of the bridging process, which can be initialized at new samples from the source measure. (Note: at `tau = 1`, the bridge collapses onto the training data!)
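For intuition, the pipeline described above can be sketched in plain NumPy. This is a minimal, hypothetical sketch and not the repository's API: the function names `sinkhorn_potentials`, `sinkhorn_drift`, and `sinkhorn_bridge` are illustrative, Sinkhorn's algorithm is written out by hand in the log domain instead of calling POT or OTT-JAX, and we assume uniform empirical weights, the cost `||x - y||^2 / 2`, and a forward SDE `dZ = b_t(Z) dt + sqrt(eps) dB`. Only `eps`, `tau`, and `Nsteps` correspond to the hyperparameters named in this README.

```python
import numpy as np

# Hypothetical sketch; not the repository's API. Assumes cost ||x - y||^2 / 2,
# uniform sample weights, and forward SDE dZ = b_t(Z) dt + sqrt(eps) dB.


def logsumexp(a, axis):
    """Numerically stable log(sum(exp(a))) along `axis`."""
    m = np.max(a, axis=axis, keepdims=True)
    return np.squeeze(m, axis=axis) + np.log(np.sum(np.exp(a - m), axis=axis))


def sinkhorn_potentials(X, Y, eps, n_iter=500):
    """Log-domain Sinkhorn between uniform empirical measures on X and Y.

    The coupling is pi_ij = exp((f_i + g_j - C_ij) / eps) / (n * m).
    """
    C = 0.5 * np.sum((X[:, None, :] - Y[None, :, :]) ** 2, axis=-1)
    n, m = C.shape
    f, g = np.zeros(n), np.zeros(m)
    for _ in range(n_iter):
        # Alternate updates that enforce the row and column marginals.
        f = -eps * logsumexp((g[None, :] - C) / eps - np.log(m), axis=1)
        g = -eps * logsumexp((f[:, None] - C) / eps - np.log(n), axis=0)
    return f, g


def sinkhorn_drift(Z, t, Y, g, eps):
    """Forward drift at time t in [0, 1) for a batch of points Z of shape (k, d)."""
    sq = np.sum((Z[:, None, :] - Y[None, :, :]) ** 2, axis=-1)  # (k, m)
    logits = (g[None, :] - sq / (2.0 * (1.0 - t))) / eps
    w = np.exp(logits - logsumexp(logits, axis=1)[:, None])  # softmax over targets
    return (w @ Y - Z) / (1.0 - t)


def sinkhorn_bridge(Z0, Y, g, eps, tau, Nsteps, rng):
    """Euler-Maruyama simulation of the bridge on [0, tau] starting from Z0."""
    dt = tau / Nsteps
    Z = Z0.copy()
    for k in range(Nsteps):
        t = k * dt
        noise = rng.standard_normal(Z.shape)
        Z = Z + sinkhorn_drift(Z, t, Y, g, eps) * dt + np.sqrt(eps * dt) * noise
    return Z


# Usage on toy data: transport a standard Gaussian toward a shifted one.
rng = np.random.default_rng(0)
X = rng.standard_normal((200, 2))        # training source samples
Y = rng.standard_normal((200, 2)) + 4.0  # training target samples
f, g = sinkhorn_potentials(X, Y, eps=0.1)
Z_new = rng.standard_normal((100, 2))    # fresh samples from the source measure
Z_tau = sinkhorn_bridge(Z_new, Y, g, eps=0.1, tau=0.9, Nsteps=50, rng=rng)
```

Working in the log domain keeps the Sinkhorn updates stable for small `eps`. Note that the forward drift only needs the target-side potential `g`, and that its factor `1 / (1 - t)` grows without bound as `t` approaches 1.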
If you found this code helpful or are building upon this work, please cite:

Aram-Alexandre Pooladian and Jonathan Niles-Weed. "Plug-in estimation of Schrödinger bridges." arXiv preprint arXiv:2408.11686, 2024.
```bibtex
@article{pooladian2024plug,
  title={Plug-in estimation of Schr\"odinger bridges},
  author={Pooladian, Aram-Alexandre and Niles-Weed, Jonathan},
  journal={arXiv preprint arXiv:2408.11686},
  year={2024}
}
```