ot:
Multi-lib backend for POT
Batch operations for optimal transport.
Solvers related to Bregman projections for entropic regularized OT
Efficient combinatorial optimization for low transport cost bijections based on Binary Space Partitioning trees (BSP-OT).
CO-Optimal Transport solver
Domain adaptation with optimal transport
Simple example datasets
Dimension reduction with OT
Factored OT solvers (low rank, cost or OT plan)
Optimal transport for Gaussian distributions
Optimal transport for Gaussian Mixtures
Layers and functions for optimal transport in Graph Neural Networks.
Solvers related to Gromov-Wasserstein problems.
Low rank OT solvers
Solvers for the original linear program OT problem.
Optimal Transport maps and variants
Generic solvers for regularized OT or its semi-relaxed version.
Efficient 1D solver for the partial optimal transport problem.
Functions for plotting OT matrices
Regularization path OT solvers
Solvers related to (balanced) sliced transport.
Spectral-Grassmann optimal transport for linear operators.
Smooth and Sparse (KL and L2 reg.) and sparsity-constrained OT solvers.
Stochastic solvers for regularized OT.
Solvers related to Unbalanced Optimal Transport problems.
Various useful functions
Weak optimal transport solvers
ot functions
Warning
The list of automatically imported sub-modules is as follows: ot.lp, ot.bregman, ot.optim, ot.utils, ot.datasets, ot.gromov, ot.smooth, ot.stochastic, ot.partial, ot.regpath, ot.unbalanced, ot.sliced, ot.mapping.
The following sub-modules are not imported due to additional dependencies:
- ot.dr: depends on pymanopt and autograd.
- ot.plot: depends on matplotlib.
Compute the entropic regularized Wasserstein barycenter of distributions \(\mathbf{A}\)
The function solves the following optimization problem:
\[\mathbf{a} = \mathop{\arg\min}_{\mathbf{a}} \quad \sum_i W_{reg}(\mathbf{a}, \mathbf{a}_i)\]
where:
\(W_{reg}(\cdot,\cdot)\) is the entropic regularized Wasserstein distance (see ot.bregman.sinkhorn()) if method is sinkhorn, sinkhorn_stabilized or sinkhorn_log.
\(\mathbf{a}_i\) are training distributions in the columns of matrix \(\mathbf{A}\)
reg and \(\mathbf{M}\) are respectively the regularization term and the cost matrix for OT
The algorithm used for solving the problem is the Sinkhorn-Knopp matrix scaling algorithm as proposed in [3]
A (array-like, shape (dim, n_hists)) – n_hists training distributions \(\mathbf{a}_i\) of size dim
M (array-like, shape (dim, dim)) – loss matrix for OT
reg (float) – Regularization term > 0
method (str (optional)) – method used for the solver either ‘sinkhorn’ or ‘sinkhorn_stabilized’ or ‘sinkhorn_log’
weights (array-like, shape (n_hists,)) – Weights of each histogram \(\mathbf{a}_i\) on the simplex (barycentric coordinates)
numItermax (int, optional) – Max number of iterations
stopThr (float, optional) – Stop threshold on error (>0)
verbose (bool, optional) – Print information along iterations
log (bool, optional) – record log if True
warn (bool, optional) – if True, raises a warning if the algorithm doesn’t converge.
a ((dim,) array-like) – Wasserstein barycenter
log (dict) – log dictionary, returned only if log==True in parameters
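The Sinkhorn-Knopp barycenter iterations (iterative Bregman projections, [3]) can be sketched in plain NumPy as follows. This is a minimal illustration of the plain 'sinkhorn' variant only, without log-domain stabilization or a stopping threshold; the helper name sinkhorn_barycenter_sketch is hypothetical and this is not POT's own implementation.

```python
import numpy as np

def sinkhorn_barycenter_sketch(A, M, reg, weights=None, n_iter=1000):
    """Entropic barycenter of the columns of A via iterative Bregman projections.

    Hypothetical helper for illustration (plain 'sinkhorn' variant, no log-domain
    stabilization), not POT's implementation.
    """
    dim, n_hists = A.shape
    if weights is None:
        weights = np.full(n_hists, 1.0 / n_hists)  # uniform barycentric coordinates
    K = np.exp(-M / reg)                 # Gibbs kernel
    v = np.ones((dim, n_hists))          # one scaling pair (u_k, v_k) per histogram
    for _ in range(n_iter):
        u = A / (K @ v)                  # project: first marginal of coupling k is a_k
        Ktu = K.T @ u
        # the common second marginal is the weighted geometric mean of the current ones
        bary = np.exp(np.log(v * Ktu) @ weights)
        v = bary[:, None] / Ktu          # project: second marginal of every coupling is bary
    return bary

# usage: barycenter of two mirrored bumps on a 50-point grid
grid = np.arange(50)
M = ((grid[:, None] - grid[None, :]) ** 2).astype(float)
M /= M.max()
a1 = np.exp(-((grid - 12) ** 2) / (2 * 2.5 ** 2))
a1 /= a1.sum()
A = np.stack([a1, a1[::-1]], axis=1)     # second histogram is the mirror image of the first
bary = sinkhorn_barycenter_sketch(A, M, reg=2e-2)
```

Each pass alternates two projections: one makes coupling k match its input histogram, the other forces all couplings to share the same second marginal, which is the barycenter estimate.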
References
Compute the entropic unbalanced Wasserstein barycenter of \(\mathbf{A}\).
The function solves the following optimization problem for \(\mathbf{a}\):
\[\mathbf{a} = \mathop{\arg\min}_{\mathbf{a}} \quad \sum_i W_{u_{reg}}(\mathbf{a}, \mathbf{a}_i)\]
where:
\(W_{u_{reg}}(\cdot,\cdot)\) is the unbalanced entropic regularized Wasserstein distance (see ot.unbalanced.sinkhorn_unbalanced())
\(\mathbf{a}_i\) are training distributions in the columns of matrix \(\mathbf{A}\)
reg and \(\mathbf{M}\) are respectively the regularization term and the cost matrix for OT
reg_m is the marginal relaxation hyperparameter
The algorithm used for solving the problem is the generalized Sinkhorn-Knopp matrix scaling algorithm as proposed in [10]
A (array-like, shape (dim, n_hists)) – n_hists training distributions \(\mathbf{a}_i\) of dimension dim
M (array-like, shape (dim, dim)) – ground metric matrix for OT.
reg (float) – Entropy regularization term > 0
reg_m (float) – Marginal relaxation term > 0
weights (array-like, shape (n_hists,) optional) – Weight of each distribution (barycentric coordinates). If None, uniform weights are used.
numItermax (int, optional) – Max number of iterations
stopThr (float, optional) – Stop threshold on error (> 0)
verbose (bool, optional) – Print information along iterations
log (bool, optional) – record log if True
a (array-like, shape (dim,)) – Unbalanced Wasserstein barycenter
log (dict) – log dictionary, returned only if log==True in parameters
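The generalized scaling of [10] replaces the balanced Sinkhorn updates u = a / (K v) and v = b / (K^T u) by the same updates raised to the power fi = reg_m / (reg_m + reg). The sketch below shows that update for a single pair of histograms rather than for the barycenter, assuming KL relaxation of both marginals with weight reg_m and a kernel K = exp(-M / reg); the helper name is hypothetical.

```python
import numpy as np

def unbalanced_sinkhorn_sketch(a, b, M, reg, reg_m, n_iter=2000):
    """Entropic unbalanced OT plan with KL-relaxed marginals (hypothetical helper)."""
    K = np.exp(-M / reg)
    fi = reg_m / (reg_m + reg)      # fi -> 1 (reg_m -> inf) recovers balanced Sinkhorn
    u = np.ones_like(a)
    v = np.ones_like(b)
    for _ in range(n_iter):
        u = (a / (K @ v)) ** fi     # damped version of the balanced update u = a / Kv
        v = (b / (K.T @ u)) ** fi   # damped version of the balanced update v = b / K^T u
    return u[:, None] * K * v[None, :]   # plan = diag(u) K diag(v)

x = np.linspace(0, 1, 20)
M = (x[:, None] - x[None, :]) ** 2
a = np.full(20, 1 / 20)
b = np.exp(-((x - 0.7) ** 2) / 0.02)
b /= b.sum()
G_tight = unbalanced_sinkhorn_sketch(a, b, M, reg=0.05, reg_m=1e4)  # nearly balanced
G_loose = unbalanced_sinkhorn_sketch(a, b, M, reg=0.05, reg_m=0.1)  # marginals relaxed
```

With a large reg_m the plan's marginals are pulled close to a and b; with a small reg_m they are allowed to deviate.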
References
Benamou, J. D., Carlier, G., Cuturi, M., Nenna, L., & Peyré, G. (2015). Iterative Bregman projections for regularized transportation problems. SIAM Journal on Scientific Computing, 37(2), A1111-A1138.
Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016). Scaling algorithms for unbalanced transport problems. arXiv preprint arXiv:1607.05816.
Computes the Wasserstein distance on the circle using the binary search algorithm proposed in [44]. Samples need to be in \(S^1\cong [0,1[\). If they are on \(\mathbb{R}\), the values are taken modulo 1. If the values are on \(S^1\subset\mathbb{R}^2\), their coordinates must first be computed, e.g. with the atan2 function.
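The distance is computed as
\[W_p^p(u,v) = \inf_{\theta\in\mathbb{R}} \int_0^1 |F_u^{-1}(q) - (F_v-\theta)^{-1}(q)|^p \, \mathrm{d}q\]
where: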
\(F_u\) and \(F_v\) are respectively the cdfs of \(u\) and \(v\)
For values \(x=(x_1,x_2)\in S^1\), their coordinates must first be computed with
\[u = \frac{\pi + \mathrm{atan2}(-x_2,-x_1)}{2\pi}\]
using e.g. ot.utils.get_coordinate_circle(x).
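A minimal NumPy version of the coordinate mapping \(u = (\pi + \mathrm{atan2}(-x_2,-x_1)) / (2\pi)\) is shown below; circle_coordinates is a hypothetical helper and is not necessarily identical to ot.utils.get_coordinate_circle.

```python
import numpy as np

def circle_coordinates(x):
    """Map points (x1, x2) on the unit circle to u in [0, 1[ (hypothetical helper)."""
    u = (np.pi + np.arctan2(-x[:, 1], -x[:, 0])) / (2 * np.pi)
    return u % 1.0   # guard: atan2 can return +pi, which would give u == 1

pts = np.array([[1.0, 0.0], [0.0, 1.0], [-1.0, 0.0], [0.0, -1.0]])
u = circle_coordinates(pts)   # angles 0, pi/2, pi, 3pi/2 map to 0, 0.25, 0.5, 0.75
```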
The function runs on all backends except TensorFlow and JAX.
u_values (ndarray, shape (n, ...)) – samples in the source domain (coordinates on [0,1[)
v_values (ndarray, shape (n, ...)) – samples in the target domain (coordinates on [0,1[)
u_weights (ndarray, shape (n, ...), optional) – samples weights in the source domain
v_weights (ndarray, shape (n, ...), optional) – samples weights in the target domain
p (float, optional (default=1)) – Power p used for computing the Wasserstein distance
Lm (int, optional) – Lower bound dC
Lp (int, optional) – Upper bound dC
tm (float, optional) – Lower bound theta
tp (float, optional) – Upper bound theta
eps (float, optional) – Stopping condition
require_sort (bool, optional) – If True, sort the values.
log (bool, optional) – If True, also returns the optimal theta
loss (float/array-like, shape (…)) – Batched cost associated to the optimal transportation
log (dict, optional) – log dictionary returned only if log==True in parameters
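For p=1, a known closed form avoids the search over \(\theta\): the circular \(W_1\) distance equals \(\min_\alpha \int_0^1 |F_u(t) - F_v(t) - \alpha| \, \mathrm{d}t\), and a minimizing \(\alpha\) is a median of \(F_u - F_v\) weighted by interval length. The sketch below assumes uniform sample weights and 1D inputs; circle_w1 is a hypothetical helper meant as a cross-check, not the binary search of [44].

```python
import numpy as np

def circle_w1(u_values, v_values):
    """W1 on the circle [0, 1[ between uniform empirical measures (hypothetical helper)."""
    u = np.sort(np.asarray(u_values, dtype=float) % 1.0)
    v = np.sort(np.asarray(v_values, dtype=float) % 1.0)
    # breakpoints of both CDFs split [0, 1[ into intervals where F_u - F_v is constant
    t = np.unique(np.concatenate(([0.0, 1.0], u, v)))
    left, lengths = t[:-1], np.diff(t)
    F_u = np.searchsorted(u, left, side="right") / len(u)   # right-continuous CDFs
    F_v = np.searchsorted(v, left, side="right") / len(v)
    diff = F_u - F_v
    # optimal shift alpha: a median of diff weighted by interval length
    order = np.argsort(diff)
    cum = np.cumsum(lengths[order])
    alpha = diff[order][np.searchsorted(cum, 0.5 * cum[-1])]
    return float(np.sum(lengths * np.abs(diff - alpha)))

w = circle_w1([0.2, 0.5, 0.8], [0.4, 0.5, 0.7])   # about 0.1, as in the Examples section
```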
Examples
>>> import numpy as np
>>> import ot
>>> u = np.array([[0.2,0.5,0.8]])%1
>>> v = np.array([[0.4,0.5,0.7]])%1
>>> ot.binary_search_circle(u.T, v.T, p=1)
array([0.1])
References
Delon, Julie, Julien Salomon, and Andrei Sobolevski. “Fast transport optimization for Monge costs on the circle.” SIAM Journal on Applied Mathematics 70.7 (2010): 2239-2258.
This solver provides a good and fast approximation of the combinatorial problem of finding a bijection between two point clouds that minimizes the transport cost:
To do so, it generates \(n_{plans}\) random bijective BSP matchings, merges them together to obtain a bijection of low transport cost. Log-linear complexity in the number of points. Algorithm 2 & 3 from [84].
Note
There is no guarantee on the quality of the returned bijection, but the method is highly scalable on the CPU. Worst cases are obtained between point clouds that are very similar (e.g. two samples from the same distribution), where the solver can get stuck in local minima, but it works well when the point clouds are very different. The method also works best for the standard squared Euclidean cost (p=2), as this cost enables an efficient BSP construction heuristic (with a cubic dependence on the dimension, this feature is disabled for dimensions larger than 64).
X (array-like, shape (n_samples, dimension))
Y (array-like, shape (n_samples, dimension))
n_plans (int) – The number of BSP Matchings used to compute the final bijection.
p (int, optional) – The power of the ground metric (default 2 for squared euclidean, -1 for infinity norm).
initial_perm (array-like, shape (n_samples,), optional) – Bijection to use for initializing merging (optional).
gaussian_slicing ("auto" or bool, optional) – If true then uses the Gaussian slicing heuristic to improve matching quality. Comes with a cubic complexity with dimension. If ‘auto’, then the heuristic is used for dimensions smaller than 64, and disabled for larger dimensions.
seed (int, optional) – Random seed for reproducibility (default 0).
cost (float) – The transport cost of the final bijection.
perm (array-like, shape (n_samples,)) – The final bijection, stored as a permutation (e.g. a list of numbers) such that X[i] is assigned to Y[perm[i]].
perms (array-like, shape (n_plans,n_samples)) – The intermediary bijections used to compute the final one.
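To make the meaning of cost and perm concrete, the sketch below evaluates a bijection and builds a deliberately simplified stand-in for the solver: it keeps the cheapest of n_plans random 1D-projection matchings instead of building BSP trees and merging plans as in [84]. A brute-force search over all permutations gives the exact optimum for tiny inputs. All helper names are hypothetical, and averaging the cost over points is an assumption about how the cost is normalized.

```python
import itertools
import numpy as np

def bijection_cost(X, Y, perm, p=2):
    """Mean of ||X[i] - Y[perm[i]]||^p (normalising by n is an assumption)."""
    return float(np.mean(np.linalg.norm(X - Y[perm], axis=1) ** p))

def best_projection_matching(X, Y, n_plans=10, p=2, seed=0):
    """Keep the cheapest of n_plans random 1D-sort matchings (hypothetical stand-in)."""
    rng = np.random.default_rng(seed)
    best_cost, best_perm = np.inf, None
    for _ in range(n_plans):
        theta = rng.normal(size=X.shape[1])
        theta /= np.linalg.norm(theta)
        ix, iy = np.argsort(X @ theta), np.argsort(Y @ theta)
        perm = np.empty(len(X), dtype=int)
        perm[ix] = iy          # k-th smallest projection of X -> k-th smallest of Y
        cost = bijection_cost(X, Y, perm, p)
        if cost < best_cost:
            best_cost, best_perm = cost, perm
    return best_cost, best_perm

def brute_force_cost(X, Y, p=2):
    """Exact optimum over all n! bijections; only for tiny n."""
    n = len(X)
    return min(bijection_cost(X, Y, np.array(s), p) for s in itertools.permutations(range(n)))

rng = np.random.default_rng(1)
X, Y = rng.normal(size=(6, 2)), rng.normal(size=(6, 2)) + 1.0
cost, perm = best_projection_matching(X, Y)   # X[i] is assigned to Y[perm[i]]
```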
Compute distance between samples in \(\mathbf{x_1}\) and \(\mathbf{x_2}\)
Note
This function is backend-compatible and will work on arrays from all compatible backends for the following metrics: ‘sqeuclidean’, ‘euclidean’, ‘cityblock’, ‘minkowski’, ‘cosine’, ‘correlation’.
x1 (array-like, shape (n1,d)) – matrix with n1 samples of size d
x2 (array-like, shape (n2,d), optional) – matrix with n2 samples of size d (if None then \(\mathbf{x_2} = \mathbf{x_1}\))
metric (str | callable, optional) – ‘sqeuclidean’ or ‘euclidean’ on all backends. On NumPy, the function also accepts the following metrics from scipy.spatial.distance.cdist: ‘braycurtis’, ‘canberra’, ‘chebyshev’, ‘cityblock’, ‘correlation’, ‘cosine’, ‘dice’, ‘euclidean’, ‘hamming’, ‘jaccard’, ‘kulczynski1’, ‘mahalanobis’, ‘matching’, ‘minkowski’, ‘rogerstanimoto’, ‘russellrao’, ‘seuclidean’, ‘sokalmichener’, ‘sokalsneath’, ‘sqeuclidean’, ‘wminkowski’, ‘yule’.
p (float, optional) – p-norm for the Minkowski and the Weighted Minkowski metrics. Default value is 2.
w (array-like, rank 1) – Weights for the weighted metrics.
backend (str, optional) – Backend to use for the computation. If ‘auto’, the backend is automatically selected based on the input data. If ‘scipy’, the scipy.spatial.distance.cdist function is used (and gradients are detached).
use_tensor (bool, optional) – If True, uses tensorized computation for the distance matrix, which can cause memory issues for large datasets. Default is False; the parameter is used only for the ‘cityblock’ and ‘minkowski’ metrics.
nx (Backend, optional) – Backend to perform computations on. If omitted, the backend defaults to that of x1.
M (array-like, shape (n1, n2)) – distance matrix computed with the given metric
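For the 'sqeuclidean' metric, the (n1, n2) matrix can be built without an (n1, n2, d) intermediate array through the expansion \(\|x - y\|^2 = \|x\|^2 + \|y\|^2 - 2\langle x, y\rangle\). The sketch below is a hypothetical helper; whether ot.dist uses exactly this formula internally is an assumption.

```python
import numpy as np

def sqeuclidean_sketch(x1, x2=None):
    """Pairwise squared Euclidean distances, shape (n1, n2) (hypothetical helper)."""
    if x2 is None:
        x2 = x1                                  # mirrors the documented default x2 = x1
    sq1 = np.sum(x1 ** 2, axis=1)[:, None]       # (n1, 1) squared norms
    sq2 = np.sum(x2 ** 2, axis=1)[None, :]       # (1, n2) squared norms
    M = sq1 + sq2 - 2.0 * x1 @ x2.T              # ||x||^2 + ||y||^2 - 2 <x, y>
    return np.maximum(M, 0.0)                    # clip tiny negatives caused by round-off

rng = np.random.default_rng(0)
x1, x2 = rng.normal(size=(4, 3)), rng.normal(size=(5, 3))
D = sqeuclidean_sketch(x1, x2)
```

The clipping step matters in practice: the expansion is cheaper than broadcasting but can produce values like -1e-16 on the diagonal.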
Batched version of ot.dist; use it to compute many distance matrices in parallel.
X1 (array-like, shape (b,n1,d)) – b matrices with n1 samples of size d
X2 (array-like, shape (b,n2,d), optional) – b matrices with n2 samples of size d (if None then \(\mathbf{X_2} = \mathbf{X_1}\))
metric (str, optional) – ‘sqeuclidean’, ‘euclidean’, ‘minkowski’ or ‘kl’
p (float, optional) – p-norm for the Minkowski metrics. Default value is 2.
nx (Backend, optional) – Backend to perform computations on. If omitted, the backend defaults to that of X1.
M (array-like, shape (b, n1, n2)) – distance matrices computed with the given metric
Examples
>>> import numpy as np
>>> from ot.batch import dist_batch
>>> X1 = np.random.randn(5, 10, 3)
>>> X2 = np.random.randn(5, 15, 3)
>>> M = dist_batch(X1, X2, metric="euclidean")
>>> M.shape
(5, 10, 15)
See also
ot.dist: equivalent non-batched function.
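Solves the Earth Mover's distance problem and returns the OT matrix
\[\gamma = \mathop{\arg\min}_{\gamma} \quad \langle \gamma, \mathbf{M} \rangle_F \quad \text{s.t.} \ \gamma \mathbf{1} = \mathbf{a}, \ \gamma^T \mathbf{1} = \mathbf{b}, \ \gamma \geq 0\]
where: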
\(\mathbf{M}\) is the metric cost matrix
\(\mathbf{a}\) and \(\mathbf{b}\) are the sample weights
Warning
Note that the \(\mathbf{M}\) matrix in numpy needs to be a C-order numpy.array in float64 format. It will be converted if it is not in this format.
Note
This function is backend-compatible and will work on arrays from all compatible backends. But the algorithm uses the C++ CPU backend which can lead to copy overhead on GPU arrays.
Note
This function will cast the computed transport plan to the data type of the provided input with the following priority: \(\mathbf{a}\), then \(\mathbf{b}\), then \(\mathbf{M}\) if marginals are not provided. Casting to an integer tensor might result in a loss of precision. If this behaviour is unwanted, please make sure to provide a floating point input.
Note
An error will be raised if the vectors \(\mathbf{a}\) and \(\mathbf{b}\) do not sum to the same value.
Uses the algorithm proposed in [1].
a ((ns,) array-like, float) – Source histogram (uniform weight if empty list)
b ((nt,) array-like, float) – Target histogram (uniform weight if empty list)
M ((ns,nt) array-like or sparse matrix, float) –
Loss matrix. Can be:
Dense array (c-order array in numpy with type float64)
Sparse matrix in backend’s format (scipy.sparse.coo_matrix for NumPy backend, torch.sparse_coo_tensor for PyTorch backend, etc.)
numItermax (int, optional (default=100000)) – The maximum number of iterations before stopping the optimization algorithm if it has not converged.
log (bool, optional (default=False)) – If True, returns a dictionary containing the cost and dual variables. Otherwise returns only the optimal transportation matrix.
center_dual (boolean, optional (default=True)) – If True, centers the dual potential using function ot.lp.center_ot_dual().
numThreads (int or "max", optional (default=1)) – Deprecated compatibility parameter. The network simplex solver no longer uses OpenMP, so this parameter is ignored.
check_marginals (bool, optional (default=True)) – If True, checks that the marginals have equal mass. If False, skips the check.
potentials_init (tuple of two arrays (alpha, beta), optional (default=None)) – Warmstart dual potentials to accelerate convergence. Should be a tuple (alpha, beta) where alpha is shape (ns,) and beta is shape (nt,). These potentials are used to guide initial pivots in the network simplex. Typically obtained from a previous EMD solve or Sinkhorn approximation.
Note: the solver automatically detects sparse format using the backend’s issparse() method. For sparse inputs:
A memory-efficient sparse EMD algorithm is used
The transport plan is returned as a sparse matrix in the backend’s format
scipy.sparse (NumPy), torch.sparse (PyTorch), etc. are supported
JAX and TensorFlow backends don’t support sparse matrices
gamma (array-like or sparse matrix, shape (ns, nt)) – Optimal transportation matrix for the given parameters.
For dense inputs: returns a dense array
For sparse inputs: returns a sparse matrix in the backend’s format (e.g., scipy.sparse.coo_matrix for NumPy, torch.sparse_coo_tensor for PyTorch)
log (dict, optional) – If input log is True, a dictionary containing the cost, dual variables, and exit status.
Examples
Simple example with obvious solution. The function emd accepts lists and performs automatic conversion to numpy arrays
>>> import ot
>>> a=[.5,.5]
>>> b=[.5,.5]
>>> M=[[0.,1.],[1.,0.]]
>>> ot.emd(a, b, M)
array([[0.5, 0. ],
[0. , 0.5]])
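For uniform weights and ns == nt, the Birkhoff-von Neumann theorem guarantees that some optimal plan is a permutation matrix scaled by 1/n, so tiny instances can be checked by brute force. The helper below is hypothetical and exponential in n; it is only a sanity check against the network simplex result, not a replacement for it.

```python
import itertools
import numpy as np

def emd_uniform_bruteforce(M):
    """Exact OT plan for uniform weights and a square cost matrix (hypothetical helper).

    Enumerates all permutations, so it is only usable for very small n.
    """
    M = np.asarray(M, dtype=float)
    n = M.shape[0]
    rows = np.arange(n)
    # pick the permutation with the smallest total cost sum_i M[i, s(i)]
    best = min(itertools.permutations(range(n)), key=lambda s: M[rows, list(s)].sum())
    G = np.zeros((n, n))
    G[rows, list(best)] = 1.0 / n          # each source point sends its mass 1/n
    return G

G = emd_uniform_bruteforce([[0., 1.], [1., 0.]])   # same plan as the ot.emd example
```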
References
Bonneel, N., Van De Panne, M., Paris, S., & Heidrich, W. (2011, December). Displacement interpolation using Lagrangian mass transport. In ACM Transactions on Graphics (TOG) (Vol. 30, No. 6, p. 158). ACM.
See also
ot.bregman.sinkhorn: Entropic regularized OT
ot.optim.cg: General regularized OT
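Solves the Earth Mover's distance problem and returns the loss
\[W = \min_{\gamma} \quad \langle \gamma, \mathbf{M} \rangle_F \quad \text{s.t.} \ \gamma \mathbf{1} = \mathbf{a}, \ \gamma^T \mathbf{1} = \mathbf{b}, \ \gamma \geq 0\]
where: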
\(\mathbf{M}\) is the metric cost matrix
\(\mathbf{a}\) and \(\mathbf{b}\) are the sample weights
Note
This function is backend-compatible and will work on arrays from all compatible backends. But the algorithm uses the C++ CPU backend which can lead to copy overhead on GPU arrays.
Note
This function will cast the computed transport plan and transportation loss to the data type of the provided input with the following priority: \(\mathbf{a}\), then \(\mathbf{b}\), then \(\mathbf{M}\) if marginals are not provided. Casting to an integer tensor might result in a loss of precision. If this behaviour is unwanted, please make sure to provide a floating point input.
Note
An error will be raised if the vectors \(\mathbf{a}\) and \(\mathbf{b}\) do not sum to the same value.
Uses the algorithm proposed in [1].
a ((ns,) array-like, float64) – Source histogram (uniform weight if empty list)
b ((nt,) array-like, float64) – Target histogram (uniform weight if empty list)
M ((ns,nt) array-like or sparse matrix, float64) –
Loss matrix. Can be:
Dense array (c-order array in numpy with type float64)
Sparse matrix in backend’s format (scipy.sparse.coo_matrix for NumPy backend, torch.sparse_coo_tensor for PyTorch backend, etc.)
processes (int, optional (default=1)) – Number of processes used for multiple EMD computations (deprecated)
numItermax (int, optional (default=100000)) – The maximum number of iterations before stopping the optimization algorithm if it has not converged.
log (boolean, optional (default=False)) – If True, returns a dictionary containing the cost and dual variables. Otherwise returns only the optimal transportation cost.
return_matrix (boolean, optional (default=False)) – If True, returns the optimal transportation matrix in the log.
center_dual (boolean, optional (default=True)) – If True, centers the dual potential using function ot.lp.center_ot_dual().
numThreads (int or "max", optional (default=1)) – Deprecated compatibility parameter. The network simplex solver no longer uses OpenMP, so this parameter is ignored.
check_marginals (bool, optional (default=True)) – If True, checks that the marginals have equal mass. If False, skips the check.
potentials_init (tuple of two arrays (alpha, beta), optional (default=None)) – Warmstart dual potentials to accelerate convergence. Should be a tuple (alpha, beta) where alpha is shape (ns,) and beta is shape (nt,). These potentials are used to guide initial pivots in the network simplex. Typically obtained from a previous EMD solve or Sinkhorn approximation.
Note: the solver automatically detects sparse format using the backend’s issparse() method. For sparse inputs:
A memory-efficient sparse EMD algorithm is used
Edges not included are treated as having infinite cost (forbidden)
scipy.sparse (NumPy), torch.sparse (PyTorch), etc. are supported
JAX and TensorFlow backends don’t support sparse matrices
W (float, array-like) – Optimal transportation loss for the given parameters
log (dict) – If input log is true, a dictionary containing the cost, dual variables (u, v), exit status, and optionally the optimal transportation matrix (G) if return_matrix is True
Examples
Simple example with obvious solution. The function emd2 accepts lists and performs automatic conversion to numpy arrays
>>> import ot
>>> a=[.5,.5]
>>> b=[.5,.5]
>>> M=[[0.,1.],[1.,0.]]
>>> ot.emd2(a,b,M)
0.0
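The loss returned by emd2 is the Frobenius product \(\langle \gamma, \mathbf{M} \rangle_F\) of the optimal plan with the cost matrix. The hypothetical helper below recomputes it from a given plan after a mass check in the spirit of check_marginals; the tolerance value is an arbitrary choice.

```python
import numpy as np

def plan_cost(G, M, a, b, atol=1e-9):
    """Return <G, M>_F after checking that G couples a and b (hypothetical helper)."""
    G, M = np.asarray(G, float), np.asarray(M, float)
    a, b = np.asarray(a, float), np.asarray(b, float)
    if not np.isclose(a.sum(), b.sum(), atol=atol):
        raise ValueError("a and b must have the same total mass")
    if not (np.allclose(G.sum(axis=1), a, atol=atol) and np.allclose(G.sum(axis=0), b, atol=atol)):
        raise ValueError("G does not have marginals (a, b)")
    return float(np.sum(G * M))   # Frobenius inner product

W = plan_cost([[0.5, 0.0], [0.0, 0.5]], [[0., 1.], [1., 0.]], [.5, .5], [.5, .5])
```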
References
Bonneel, N., Van De Panne, M., Paris, S., & Heidrich, W. (2011, December). Displacement interpolation using Lagrangian mass transport. In ACM Transactions on Graphics (TOG) (Vol. 30, No. 6, p. 158). ACM.
See also
ot.bregman.sinkhorn: Entropic regularized OT
ot.optim.cg: General regularized OT
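Solves the Earth Mover's distance problem between 1d measures and returns the loss
\[W = \min_{\gamma} \quad \sum_i \sum_j \gamma_{ij} \, d(x_a[i], x_b[j]) \quad \text{s.t.} \ \gamma \mathbf{1} = a, \ \gamma^T \mathbf{1} = b, \ \gamma \geq 0\]
where: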
d is the metric
\(x_a\) and \(x_b\) are the samples
a and b are the sample weights
This implementation only supports metrics of the form \(d(x, y) = |x - y|^p\).
Uses the algorithm detailed in [1].
x_a (ndarray of float64, shape (ns,) or (ns, 1)) – Source dirac locations (on the real line)
x_b (ndarray of float64, shape (nt,) or (nt, 1)) – Target dirac locations (on the real line)
a (ndarray of float64, shape (ns,), optional) – Source histogram (default is uniform weight)
b (ndarray of float64, shape (nt,), optional) – Target histogram (default is uniform weight)
metric (str, optional (default='sqeuclidean')) – Metric to be used. Only works with one of the strings ‘sqeuclidean’, ‘minkowski’, ‘cityblock’, or ‘euclidean’.
p (float, optional (default=1.0)) – The p-norm to apply if metric=‘minkowski’
dense (boolean, optional (default=True)) – If True, returns \(\gamma\) as a dense ndarray of shape (ns, nt). Otherwise returns a sparse representation using scipy’s coo_matrix format. Only used if log is set to True. Due to implementation details, this function runs faster when dense is set to False.
log (boolean, optional (default=False)) – If True, returns a dictionary containing the transportation matrix. Otherwise returns only the loss.
loss (float) – Cost associated to the optimal transportation
log (dict) – If input log is True, a dictionary containing the Optimal transportation matrix for the given parameters
Examples
Simple example with obvious solution. The function emd2_1d accepts lists and performs automatic conversion to numpy arrays
>>> import ot
>>> a=[.5, .5]
>>> b=[.5, .5]
>>> x_a = [2., 0.]
>>> x_b = [0., 3.]
>>> ot.emd2_1d(x_a, x_b, a, b)
0.5
>>> ot.emd2_1d(x_a, x_b)
0.5
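In 1D, the optimal cost for \(d(x, y) = |x - y|^p\) can be written with quantile functions, \(W = \int_0^1 |F_a^{-1}(t) - F_b^{-1}(t)|^p \, \mathrm{d}t\), which for discrete measures reduces to a sum over the merged cumulative weights. The sketch below is a hypothetical helper (p=2 corresponds to metric='sqeuclidean'); it reproduces the example above.

```python
import numpy as np

def emd2_1d_sketch(x_a, x_b, a=None, b=None, p=2):
    """Optimal 1D cost for |x - y|^p via quantile functions (hypothetical helper)."""
    x_a, x_b = np.asarray(x_a, float), np.asarray(x_b, float)
    a = np.full(len(x_a), 1.0 / len(x_a)) if a is None else np.asarray(a, float)
    b = np.full(len(x_b), 1.0 / len(x_b)) if b is None else np.asarray(b, float)
    ia, ib = np.argsort(x_a), np.argsort(x_b)
    xa, xb = x_a[ia], x_b[ib]
    ca, cb = np.cumsum(a[ia]), np.cumsum(b[ib])
    # merged cumulative levels cut [0, 1] into pieces on which both quantiles are constant
    levels = np.sort(np.concatenate((ca, cb)))
    widths = np.diff(np.concatenate(([0.0], levels)))
    mid = levels - widths / 2
    qa = xa[np.minimum(np.searchsorted(ca, mid), len(xa) - 1)]   # F_a^{-1}(t)
    qb = xb[np.minimum(np.searchsorted(cb, mid), len(xb) - 1)]   # F_b^{-1}(t)
    return float(np.sum(widths * np.abs(qa - qb) ** p))

loss = emd2_1d_sketch([2., 0.], [0., 3.])   # 0.5, matching the ot.emd2_1d example
```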
References
Peyré, G., & Cuturi, M. (2017). “Computational Optimal Transport”, 2018.
See also
ot.lp.emd2: EMD for multidimensional distributions
ot.lp.emd_1d: EMD for 1d distributions (returns the transportation matrix instead of the cost)
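Solves the Earth Mover's distance problem with lazy cost computation and returns the loss
\[W = \min_{\gamma} \quad \langle \gamma, \mathbf{M}(\mathbf{X}_a, \mathbf{X}_b) \rangle_F \quad \text{s.t.} \ \gamma \mathbf{1} = \mathbf{a}, \ \gamma^T \mathbf{1} = \mathbf{b}, \ \gamma \geq 0\]
where: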
\(\mathbf{M}(\mathbf{X}_a, \mathbf{X}_b)\) is computed on-the-fly from coordinates
\(\mathbf{a}\) and \(\mathbf{b}\) are the sample weights
Note
This function computes distances on-the-fly during the network simplex algorithm, avoiding the O(ns*nt) memory cost of pre-computing the full cost matrix. Memory usage is O(ns+nt) instead.
Note
This function is backend-compatible and will work on arrays from all compatible backends. But the algorithm uses the C++ CPU backend which can lead to copy overhead on GPU arrays.
X_a ((ns, dim) array-like, float64) – Source sample coordinates
X_b ((nt, dim) array-like, float64) – Target sample coordinates
a ((ns,) array-like, float64, optional) – Source histogram (uniform weight if None)
b ((nt,) array-like, float64, optional) – Target histogram (uniform weight if None)
metric (str, optional (default='sqeuclidean')) –
Distance metric for cost computation. Options:
‘sqeuclidean’: Squared Euclidean distance
‘euclidean’: Euclidean distance
‘cityblock’: Manhattan/L1 distance
numItermax (int, optional (default=100000)) – Maximum number of iterations before stopping if not converged
log (boolean, optional (default=False)) – If True, returns a dictionary containing the cost, dual variables, and optionally the transport plan (sparse format)
return_matrix (boolean, optional (default=False)) – If True, returns the optimal transportation matrix in the log (sparse format)
center_dual (boolean, optional (default=True)) – If True, centers the dual potential using ot.lp.center_ot_dual()
check_marginals (bool, optional (default=True)) – If True, checks that the marginals have equal mass
potentials_init (tuple of (ns,) and (nt,) arrays, optional) – Initial dual potentials (u, v) to warmstart the solver. If provided, the solver starts from these potentials instead of a cold start.
W (float) – Optimal transportation loss
log (dict) – If input log is True, a dictionary containing:
cost: the optimal transportation cost
u, v: dual variables
warning: solver status message
result_code: solver return code
G: (optional) sparse transport plan if return_matrix=True
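The on-the-fly idea can be illustrated in Python: a cost entry is computed from the coordinates only when it is needed, so only X_a and X_b are stored. The helpers below are hypothetical; the actual solver does this inside its C++ network simplex.

```python
import numpy as np

def lazy_cost(X_a, X_b, metric="sqeuclidean"):
    """Return c(i, j) that computes a single cost entry on demand (hypothetical helper)."""
    def c(i, j):
        d = X_a[i] - X_b[j]
        if metric == "sqeuclidean":
            return float(d @ d)
        if metric == "euclidean":
            return float(np.sqrt(d @ d))
        if metric == "cityblock":
            return float(np.abs(d).sum())
        raise ValueError(f"unsupported metric: {metric}")
    return c

def sparse_plan_cost(rows, cols, vals, cost):
    """<G, M>_F for a plan given as (row, col, value) triplets, without forming M."""
    return sum(v * cost(i, j) for i, j, v in zip(rows, cols, vals))

rng = np.random.default_rng(0)
X_a, X_b = rng.normal(size=(4, 3)), rng.normal(size=(5, 3))
W = sparse_plan_cost([0, 1, 3], [2, 0, 4], [0.2, 0.3, 0.5], lazy_cost(X_a, X_b))
```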
See also
ot.emd2: EMD with pre-computed cost matrix
ot.lp.emd_c_lazy: Low-level C++ lazy solver
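Solves the Earth Mover's distance problem between 1d measures and returns the OT matrix
\[\gamma = \mathop{\arg\min}_{\gamma} \quad \sum_i \sum_j \gamma_{ij} \, d(x_a[i], x_b[j]) \quad \text{s.t.} \ \gamma \mathbf{1} = a, \ \gamma^T \mathbf{1} = b, \ \gamma \geq 0\]
where: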
d is the metric
\(x_a\) and \(x_b\) are the samples
a and b are the sample weights
This implementation only supports metrics of the form \(d(x, y) = |x - y|^p\).
Uses the algorithm detailed in [1].
x_a (ndarray of float64, shape (ns,) or (ns, 1)) – Source dirac locations (on the real line)
x_b (ndarray of float64, shape (nt,) or (nt, 1)) – Target dirac locations (on the real line)
a (ndarray of float64, shape (ns,), optional) – Source histogram (default is uniform weight)
b (ndarray of float64, shape (nt,), optional) – Target histogram (default is uniform weight)
metric (str, optional (default='sqeuclidean')) – Metric to be used. Only works with one of the strings ‘sqeuclidean’, ‘minkowski’, ‘cityblock’, or ‘euclidean’.
p (float, optional (default=1.0)) – The p-norm to apply if metric=‘minkowski’
dense (boolean, optional (default=True)) – If True, returns \(\gamma\) as a dense ndarray of shape (ns, nt). Otherwise returns a sparse representation using scipy’s coo_matrix format. Due to implementation details, this function runs faster when ‘sqeuclidean’, ‘minkowski’, ‘cityblock’, or ‘euclidean’ metrics are used.
log (boolean, optional (default=False)) – If True, returns a dictionary containing the cost. Otherwise returns only the optimal transportation matrix.
check_marginals (bool, optional (default=True)) – If True, checks that the marginals have equal mass. If False, skips the check.
gamma (ndarray, shape (ns, nt)) – Optimal transportation matrix for the given parameters
log (dict) – If input log is True, a dictionary containing the cost and the indices of the non-zero elements of the transportation matrix
Examples
Simple example with obvious solution. The function emd_1d accepts lists and performs automatic conversion to numpy arrays
>>> import ot
>>> a=[.5, .5]
>>> b=[.5, .5]
>>> x_a = [2., 0.]
>>> x_b = [0., 3.]
>>> ot.emd_1d(x_a, x_b, a, b)
array([[0. , 0.5],
[0.5, 0. ]])
>>> ot.emd_1d(x_a, x_b)
array([[0. , 0.5],
[0.5, 0. ]])
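For \(d(x, y) = |x - y|^p\) with \(p \geq 1\), the monotone coupling that matches sorted source atoms to sorted target atoms is optimal, and it can be built with a north-west corner sweep. The sketch below is a hypothetical helper returning a dense plan only; it reproduces the example above.

```python
import numpy as np

def emd_1d_plan_sketch(x_a, x_b, a=None, b=None):
    """Monotone (north-west corner) 1D OT plan, dense (ns, nt) (hypothetical helper)."""
    x_a, x_b = np.asarray(x_a, float), np.asarray(x_b, float)
    n, m = len(x_a), len(x_b)
    a = np.full(n, 1.0 / n) if a is None else np.asarray(a, float)
    b = np.full(m, 1.0 / m) if b is None else np.asarray(b, float)
    ia, ib = np.argsort(x_a), np.argsort(x_b)
    G = np.zeros((n, m))
    i = j = 0
    ra, rb = a[ia[0]], b[ib[0]]          # remaining mass of the current source/target atom
    while i < n and j < m:
        w = min(ra, rb)                  # move as much mass as both atoms allow
        G[ia[i], ib[j]] += w
        ra -= w
        rb -= w
        if ra <= 1e-15:                  # source atom exhausted: next larger x_a
            i += 1
            ra = a[ia[i]] if i < n else 0.0
        if rb <= 1e-15:                  # target atom exhausted: next larger x_b
            j += 1
            rb = b[ib[j]] if j < m else 0.0
    return G

G = emd_1d_plan_sketch([2., 0.], [0., 3.])   # [[0, 0.5], [0.5, 0]], as in the example
```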
References
Peyré, G., & Cuturi, M. (2017). “Computational Optimal Transport”, 2018.
See also
ot.lp.emd: EMD for multidimensional distributions
ot.lp.emd2_1d: EMD for 1d distributions (returns the cost instead of the transportation matrix)
Computes the 1 dimensional OT loss between two (batched) empirical distributions
and returns the dual potentials and the loss, i.e. such that
Warning
This function only works with PyTorch or JAX, as it backpropagates through the wasserstein_1d function.
u_values (array-like, shape (n, ...)) – locations of the first empirical distribution
v_values (array-like, shape (m, ...)) – locations of the second empirical distribution
u_weights (array-like, shape (n, ...), optional) – weights of the first empirical distribution, if None then uniform weights are used
v_weights (array-like, shape (m, ...), optional) – weights of the second empirical distribution, if None then uniform weights are used
p (int, optional) – order of the ground metric used, should be at least 1, default is 1
require_sort (bool, optional) – If True, sorts the atom locations of the distributions; if False, they are assumed to be sorted before being passed to the function. Default is True.
f (array-like shape (n, …)) – First dual potential
g (array-like shape (m, …)) – Second dual potential
loss (float/array-like, shape (…)) – the batched EMD
Computes the Expected Sliced cost and plan between two datasets X_s and X_t of shapes (ns, d) and (nt, d). Given a set of n_projections projection directions, the expected sliced plan is obtained by averaging the n_projections 1d optimal transport plans between the projections of X_s and X_t on each direction. Expected Sliced was introduced in [87] and further studied in [86].
Note
The computation ignores potential ambiguities in the projections: if two points from the same measure have the same projection on a direction, then multiple sorting permutations are possible. To avoid combinatorial explosion, only one permutation is retained, which strays from theory in pathological cases.
Warning
TensorFlow and JAX only return dense plans, as they do not support sparse matrices well.
X_s (array-like, shape (ns, d)) – The first set of vectors.
X_t (array-like, shape (nt, d)) – The second set of vectors.
a (ndarray of float64, shape (ns,), optional) – Source histogram (default is uniform weight)
b (ndarray of float64, shape (nt,), optional) – Target histogram (default is uniform weight)
projections (shape (dim, n_projections), optional) – Projection matrix (n_projections and seed are not used in this case). Default is None
metric (str, optional (default='sqeuclidean')) – Metric to be used. Only works with one of the strings ‘sqeuclidean’, ‘minkowski’, ‘cityblock’, or ‘euclidean’.
p (float, optional (default=2)) – The p-norm to apply if metric=‘minkowski’
n_projections (int, optional) – The number of projection directions. Required if projections is None.
seed (int, optional) – The seed for the random number generator for sampling projections, in case projections is None. Default is None.
beta (float, optional) – Inverse-temperature parameter which weights each projection’s contribution to the expected plan. Default is 0 (uniform weighting).
dense (boolean, optional (default=True)) – If True, returns \(\gamma\) as a dense ndarray of shape (ns, nt). Otherwise returns a sparse representation using scipy’s coo_matrix format.
batch_size (int, optional) – If specified, compute the distance in batches of size batch_size to avoid memory issues for large datasets. Default is None (no batching).
log (bool, optional) – If True, returns additional logging information. Default is False.
plan (ndarray, shape (ns, nt) or coo_matrix if dense is False) – Optimal transportation matrix for the given parameters.
cost (float) – The cost associated to the optimal permutation.
log_dict (dict, optional) – A dictionary containing intermediate computations for logging purposes. Returned only if log is True.
References
Tanguy, E., Chapel, L., Delon, J. (2025). Sliced Transport Plans. arXiv preprint 2506.03661.
Liu, X., Diaz Martin, R., Bai Y., Shahbazi A., Thorpe M., Aldroubi A., Kolouri, S. (2024). Expected Sliced Transport Plans. International Conference on Learning Representations.
Examples
>>> import ot
>>> import numpy as np
>>> x=np.array([[3.,3.], [1.,1.]])
>>> y=np.array([[2.,2.5], [3.,2.]])
>>> projections=np.array([[1, 0], [0, 1]])
>>> plan, cost = ot.expected_sliced_plan(x, y, projections=projections)
>>> plan
array([[0.25, 0.25],
[0.25, 0.25]])
>>> cost
2.625
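With uniform weights, ns == nt and the default beta=0, the expected sliced plan is the average of the permutation plans obtained by sorting the projections along each direction. The hypothetical helper below assumes these conditions and the squared Euclidean cost; it reproduces the plan and cost of the example above.

```python
import numpy as np

def expected_sliced_plan_uniform(X_s, X_t, projections):
    """Average of the 1D sorting plans along each column of `projections`.

    Hypothetical helper: uniform weights, ns == nt, beta = 0 (uniform weighting of
    directions), squared Euclidean cost, dense output.
    """
    n, n_proj = len(X_s), projections.shape[1]
    plan = np.zeros((n, n))
    for theta in projections.T:                       # one direction per column
        ix, iy = np.argsort(X_s @ theta), np.argsort(X_t @ theta)
        plan[ix, iy] += 1.0 / (n * n_proj)            # matched ranks share mass 1/n
    C = ((X_s[:, None, :] - X_t[None, :, :]) ** 2).sum(-1)
    return plan, float(np.sum(plan * C))

x = np.array([[3., 3.], [1., 1.]])
y = np.array([[2., 2.5], [3., 2.]])
plan, cost = expected_sliced_plan_uniform(x, y, np.array([[1, 0], [0, 1]]))
```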
Solves the factored OT problem and returns the OT plans and the intermediate distribution
This function solves the following OT problem [40]
where :
\(\mu_a\) and \(\mu_b\) are empirical distributions.
\(\mu\) is an empirical distribution with r samples
And returns the two OT plans between
Note
This function is backend-compatible and will work on arrays from all compatible backends. But the algorithm uses the C++ CPU backend which can lead to copy overhead on GPU arrays.
Uses the conditional gradient algorithm to solve the problem proposed in [39].
Xa ((ns,d) array-like, float) – Source samples
Xb ((nt,d) array-like, float) – Target samples
a ((ns,) array-like, float) – Source histogram (uniform weight if empty list)
b ((nt,) array-like, float) – Target histogram (uniform weight if empty list)
numItermax (int, optional) – Max number of iterations
stopThr (float, optional) – Stop threshold on the relative variation (>0)
verbose (bool, optional) – Print information along iterations
log (bool, optional) – record log if True
Ga (array-like, shape (ns, r)) – Optimal transportation matrix between source and the intermediate distribution
Gb (array-like, shape (r, nt)) – Optimal transportation matrix between the intermediate and target distribution
X (array-like, shape (r, d)) – Support of the intermediate distribution
log (dict, optional) – If input log is true, a dictionary containing the cost and dual variables and exit status
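The two returned plans can be combined into a single ns × nt coupling by gluing them through the intermediate points, \(G = G_a \, \mathrm{diag}(1/w) \, G_b\), where \(w\) holds the masses of the r intermediate points. The hypothetical helper below illustrates this standard construction, assuming the row sums of Gb equal the column sums of Ga; it is not part of the documented return values.

```python
import numpy as np

def glue_plans(Ga, Gb):
    """Compose Ga (ns, r) and Gb (r, nt) into one (ns, nt) coupling (hypothetical helper).

    G[i, j] = sum_k Ga[i, k] * Gb[k, j] / w[k], with w = Ga.sum(axis=0) the masses of
    the intermediate points; assumes Gb.sum(axis=1) == w.
    """
    w = Ga.sum(axis=0)
    inv_w = np.divide(1.0, w, out=np.zeros_like(w), where=w > 0)   # skip empty points
    return (Ga * inv_w[None, :]) @ Gb

a = np.full(4, 0.25)
b = np.full(3, 1 / 3)
Ga = np.array([[.25, 0.], [.25, 0.], [0., .25], [0., .25]])   # source -> 2 intermediate points
Gb = np.array([[1 / 3, 1 / 6, 0.], [0., 1 / 6, 1 / 3]])        # intermediate points -> target
G = glue_plans(Ga, Gb)
```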
References
Forrow, A., Hütter, J. C., Nitzan, M., Rigollet, P., Schiebinger, G., & Weed, J. (2019, April). Statistical optimal transport via factored couplings. In The 22nd International Conference on Artificial Intelligence and Statistics (pp. 2454-2465). PMLR.
See also
ot.bregman.sinkhorn: Entropic regularized OT
ot.optim.cg: General regularized OT
Returns the Fused Gromov-Wasserstein transport between \((\mathbf{C_1}, \mathbf{Y_1}, \mathbf{p})\) and \((\mathbf{C_2}, \mathbf{Y_2}, \mathbf{q})\) with pairwise distance matrix \(\mathbf{M}\) between node feature matrices \(\mathbf{Y_1}\) and \(\mathbf{Y_2}\) (see [24]).
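The function solves the following optimization problem using Conditional Gradient:
\[\mathbf{T}^* \in \mathop{\arg\min}_\mathbf{T} \quad (1 - \alpha) \langle \mathbf{T}, \mathbf{M} \rangle_F + \alpha \sum_{i,j,k,l} L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l}) \mathbf{T}_{i,j} \mathbf{T}_{k,l}\]
\[\text{s.t.} \quad \mathbf{T} \mathbf{1} = \mathbf{p}, \quad \mathbf{T}^T \mathbf{1} = \mathbf{q}, \quad \mathbf{T} \geq 0\]
where: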
\(\mathbf{M}\): metric cost matrix between features across domains
\(\mathbf{C_1}\): Metric cost matrix in the source space
\(\mathbf{C_2}\): Metric cost matrix in the target space
\(\mathbf{p}\): distribution in the source space
\(\mathbf{q}\): distribution in the target space
L: loss function to account for the misfit between the similarity and feature matrices
\(\alpha\): trade-off parameter
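The sketch below makes the objective concrete. It evaluates the objective by brute force for a fixed plan, using the square loss \(L(x, y) = (x - y)^2\). The helper fgw_objective is hypothetical and is meant only for small sanity checks, not as a replacement for the solver, because it builds an (ns, nt, ns, nt) tensor.

```python
import numpy as np


def fgw_objective(T, M, C1, C2, alpha):
    """(1 - alpha) <T, M>_F + alpha * sum L(C1[i,k], C2[j,l]) T[i,j] T[k,l], square loss."""
    L = (C1[:, None, :, None] - C2[None, :, None, :]) ** 2  # L[i, j, k, l]
    gw_term = np.sum(L * T[:, :, None, None] * T[None, None, :, :])
    linear_term = np.sum(T * M)
    return float((1 - alpha) * linear_term + alpha * gw_term)


C = np.array([[0.0, 1.0], [1.0, 0.0]])  # same structure on both sides
M = np.array([[0.0, 1.0], [1.0, 0.0]])  # features match node i with node i
p = q = np.array([0.5, 0.5])
fgw_identity = fgw_objective(np.diag(p), M, C, C, alpha=0.5)       # node i -> node i
fgw_product = fgw_objective(np.outer(p, q), M, C, C, alpha=0.5)    # default p q^T plan
```

The identity matching costs nothing on either term, while the product plan pays on both.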
Note
This function is backend-compatible and will work on arrays from all compatible backends. But the algorithm uses the C++ CPU backend which can lead to copy overhead on GPU arrays.
Note
All computations in the conditional gradient solver are done with numpy to limit memory overhead.
Note
This function will cast the computed transport plan to the data type of the provided input \(\mathbf{M}\). Casting to an integer tensor might result in a loss of precision. If this behaviour is unwanted, please make sure to provide a floating point input.
M (array-like, shape (ns, nt)) – Metric cost matrix between features across domains
C1 (array-like, shape (ns, ns)) – Metric cost matrix representative of the structure in the source space
C2 (array-like, shape (nt, nt)) – Metric cost matrix representative of the structure in the target space
p (array-like, shape (ns,), optional) – Distribution in the source space. If left at its default value None, the uniform distribution is used.
q (array-like, shape (nt,), optional) – Distribution in the target space. If left at its default value None, the uniform distribution is used.
loss_fun (str, optional) – Loss function used for the solver
symmetric (bool, optional) – Whether C1 and C2 are to be assumed symmetric. If left at its default value None, a symmetry test is conducted. If set to True (resp. False), C1 and C2 are assumed symmetric (resp. asymmetric).
alpha (float, optional) – Trade-off parameter (0 < alpha < 1)
armijo (bool, optional) – If True, the step of the line-search is found via an Armijo search. Otherwise the closed form is used. If there are convergence issues, use False.
G0 (array-like, shape (ns,nt), optional) – If None, the initial transport plan of the solver is pq^T. Otherwise G0 must satisfy the marginal constraints and will be used as the initial transport plan of the solver.
log (bool, optional) – record log if True
max_iter (int, optional) – Max number of iterations
tol_rel (float, optional) – Stop threshold on relative error (>0)
tol_abs (float, optional) – Stop threshold on absolute error (>0)
**kwargs (dict) – parameters can be directly passed to the ot.optim.cg solver
T (array-like, shape (ns, nt)) – Optimal transportation matrix for the given parameters.
log (dict) – Log dictionary, returned only if log==True in parameters.
References
Vayer Titouan, Chapel Laetitia, Flamary Rémi, Tavenard Romain and Courty Nicolas “Optimal Transport for structured data with application on graphs”, International Conference on Machine Learning (ICML). 2019.
Chowdhury, S., & Mémoli, F. (2019). The gromov–wasserstein distance between networks and stable network invariants. Information and Inference: A Journal of the IMA, 8(4), 757-787.
Returns the Fused Gromov-Wasserstein distance between \((\mathbf{C_1}, \mathbf{Y_1}, \mathbf{p})\) and \((\mathbf{C_2}, \mathbf{Y_2}, \mathbf{q})\) with pairwise distance matrix \(\mathbf{M}\) between node feature matrices \(\mathbf{Y_1}\) and \(\mathbf{Y_2}\) (see [24]).
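The function solves the following optimization problem using Conditional Gradient:
\[\mathbf{FGW} = \mathop{\min}_\mathbf{T} \quad (1 - \alpha) \langle \mathbf{T}, \mathbf{M} \rangle_F + \alpha \sum_{i,j,k,l} L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l}) \mathbf{T}_{i,j} \mathbf{T}_{k,l}\]
\[\text{s.t.} \quad \mathbf{T} \mathbf{1} = \mathbf{p}, \quad \mathbf{T}^T \mathbf{1} = \mathbf{q}, \quad \mathbf{T} \geq 0\]
where: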
\(\mathbf{M}\): metric cost matrix between features across domains
\(\mathbf{C_1}\): Metric cost matrix in the source space
\(\mathbf{C_2}\): Metric cost matrix in the target space
\(\mathbf{p}\): distribution in the source space
\(\mathbf{q}\): distribution in the target space
L: loss function to account for the misfit between the similarity and feature matrices
\(\alpha\): trade-off parameter
Note that when using backends, this loss function is differentiable with respect to the matrices (C1, C2, M) and the weights (p, q) for the quadratic loss, using the gradients from [38].
Note
This function is backend-compatible and will work on arrays from all compatible backends. But the algorithm uses the C++ CPU backend which can lead to copy overhead on GPU arrays.
Note
All computations in the conditional gradient solver are done with numpy to limit memory overhead.
Note
This function will cast the computed transport plan to the data type of the provided input \(\mathbf{M}\). Casting to an integer tensor might result in a loss of precision. If this behaviour is unwanted, please make sure to provide a floating point input.
M (array-like, shape (ns, nt)) – Metric cost matrix between features across domains
C1 (array-like, shape (ns, ns)) – Metric cost matrix representative of the structure in the source space.
C2 (array-like, shape (nt, nt)) – Metric cost matrix representative of the structure in the target space.
p (array-like, shape (ns,), optional) – Distribution in the source space. If left at its default value None, the uniform distribution is used.
q (array-like, shape (nt,), optional) – Distribution in the target space. If left at its default value None, the uniform distribution is used.
loss_fun (str, optional) – Loss function used for the solver.
symmetric (bool, optional) – Whether C1 and C2 are to be assumed symmetric. If left at its default value None, a symmetry test is conducted. If set to True (resp. False), C1 and C2 are assumed symmetric (resp. asymmetric).
alpha (float, optional) – Trade-off parameter (0 < alpha < 1)
armijo (bool, optional) – If True, the step of the line-search is found via an Armijo search. Otherwise the closed form is used. If there are convergence issues, use False.
G0 (array-like, shape (ns,nt), optional) – If None, the initial transport plan of the solver is pq^T. Otherwise G0 must satisfy the marginal constraints and will be used as the initial transport plan of the solver.
log (bool, optional) – Record log if True.
max_iter (int, optional) – Max number of iterations
tol_rel (float, optional) – Stop threshold on relative error (>0)
tol_abs (float, optional) – Stop threshold on absolute error (>0)
**kwargs (dict) – Parameters can be directly passed to the ot.optim.cg solver.
fgw-distance (float) – Fused Gromov-Wasserstein distance for the given parameters.
log (dict) – Log dictionary, returned only if log==True in parameters.
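When G0 is supplied, it must satisfy the marginal constraints. The minimal NumPy check below uses the hypothetical helper check_initial_plan to compare a candidate plan against the default \(\mathbf{p}\mathbf{q}^T\) initialization. Converting the plan to float also avoids the integer-casting caveat mentioned in the notes.

```python
import numpy as np


def check_initial_plan(G0, p, q, atol=1e-9):
    """Return True if G0 is a nonnegative coupling of p and q (hypothetical helper)."""
    G0 = np.asarray(G0, dtype=float)
    return (
        G0.shape == (len(p), len(q))
        and bool(np.all(G0 >= 0))
        and np.allclose(G0.sum(axis=1), p, atol=atol)  # row marginal
        and np.allclose(G0.sum(axis=0), q, atol=atol)  # column marginal
    )


p = np.array([0.25, 0.25, 0.5])
q = np.array([0.5, 0.5])
default_G0 = np.outer(p, q)        # the p q^T plan used when G0 is None
bad_G0 = np.full((3, 2), 1.0 / 6)  # uniform plan: row sums 1/3 do not match p
```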
References
Vayer Titouan, Chapel Laetitia, Flamary Rémi, Tavenard Romain and Courty Nicolas “Optimal Transport for structured data with application on graphs” International Conference on Machine Learning (ICML). 2019.
C. Vincent-Cuaz, T. Vayer, R. Flamary, M. Corneli, N. Courty, Online Graph Dictionary Learning, International Conference on Machine Learning (ICML), 2021.
Chowdhury, S., & Mémoli, F. (2019). The gromov–wasserstein distance between networks and stable network invariants. Information and Inference: A Journal of the IMA, 8(4), 757-787.
Returns the Gromov-Wasserstein barycenters of S measured similarity matrices \((\mathbf{C}_s)_{1 \leq s \leq S}\)
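The function solves the following optimization problem with block coordinate descent:
\[\mathbf{C}^* = \mathop{\arg\min}_{\mathbf{C}\in \mathbb{R}^{N \times N}} \quad \sum_s \lambda_s \mathrm{GW}(\mathbf{C}, \mathbf{C}_s, \mathbf{p}, \mathbf{p}_s)\]
where: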
\(\mathbf{C}_s\): metric cost matrix
\(\mathbf{p}_s\): distribution
N (int) – Size of the targeted barycenter
Cs (list of S array-like of shape (ns, ns)) – Metric cost matrices
ps (list of S array-like of shape (ns,), optional) – Sample weights in the S spaces. If left at their default value None, uniform distributions are used.
p (array-like, shape (N,), optional) – Weights in the targeted barycenter. If left at its default value None, the uniform distribution is used.
lambdas (list of float, optional) – List of the S spaces’ weights. If left at its default value None, uniform weights are used.
loss_fun (callable, optional) – tensor-matrix multiplication function based on specific loss function
symmetric (bool, optional) – Whether the structures are to be assumed symmetric. Default value is True. If set to True (resp. False), the structures are assumed symmetric (resp. asymmetric).
armijo (bool, optional) – If True, the step of the line-search is found via an Armijo search. Otherwise the closed form is used. If there are convergence issues, use False.
max_iter (int, optional) – Max number of iterations
tol (float, optional) – Stop threshold on relative error (>0)
stop_criterion (str, optional. Default is 'barycenter'.) – Stop criterion taking values in [‘barycenter’, ‘loss’]. If set to ‘barycenter’ uses absolute norm variations of estimated barycenters. Else if set to ‘loss’ uses the relative variations of the loss.
warmstartT (bool, optional) – Whether to warm-start the transport plans in the successive Gromov-Wasserstein transport problems.
verbose (bool, optional) – Print information along iterations.
log (bool, optional) – Record log if True.
init_C (bool | array-like, shape(N,N)) – Random initial value for the \(\mathbf{C}\) matrix provided by user.
random_state (int or RandomState instance, optional) – Fix the seed for reproducibility
C (array-like, shape (N, N)) – Similarity matrix in the barycenter space (permuted arbitrarily)
log (dict) – Only returned when log=True. It contains the keys:
\(\mathbf{T}\): list of (N, ns) transport matrices
\(\mathbf{p}\): (N,) barycenter weights
values used in convergence evaluation.
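For the square loss, Peyré et al. (2016) give a closed-form update of the barycenter from the current plans: \(\mathbf{C} \leftarrow \frac{1}{\mathbf{p}\mathbf{p}^T} \sum_s \lambda_s \mathbf{T}_s \mathbf{C}_s \mathbf{T}_s^T\), with element-wise division and each \(\mathbf{T}_s\) of shape (N, ns) as in the log entry. The NumPy sketch below performs one such step. The helper name is hypothetical.

```python
import numpy as np


def update_barycenter_square_loss(p, lambdas, Ts, Cs):
    """C = sum_s lambda_s T_s C_s T_s^T / (p p^T), element-wise division.

    p: (N,) barycenter weights; Ts: list of (N, ns) plans; Cs: list of (ns, ns) matrices.
    """
    weighted = sum(lam * (T @ C @ T.T) for lam, T, C in zip(lambdas, Ts, Cs))
    return weighted / np.outer(p, p)


rng = np.random.default_rng(0)
N = 3
p = np.full(N, 1.0 / N)
Ca = rng.random((N, N))
Ca = (Ca + Ca.T) / 2
Cb = rng.random((N, N))
Cb = (Cb + Cb.T) / 2
T_identity = np.diag(p)  # couples barycenter node i with input node i
C = update_barycenter_square_loss(p, [0.3, 0.7], [T_identity, T_identity], [Ca, Cb])
```

With identity couplings, the update reduces to the weighted average 0.3 Ca + 0.7 Cb, which gives a quick check of the formula.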
References
Gabriel Peyré, Marco Cuturi, and Justin Solomon, “Gromov-Wasserstein averaging of kernel and distance matrices.” International Conference on Machine Learning (ICML). 2016.
Returns the Gromov-Wasserstein transport between \((\mathbf{C_1}, \mathbf{p})\) and \((\mathbf{C_2}, \mathbf{q})\).
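The function solves the following optimization problem using Conditional Gradient:
\[\mathbf{T}^* \in \mathop{\arg\min}_\mathbf{T} \quad \sum_{i,j,k,l} L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l}) \mathbf{T}_{i,j} \mathbf{T}_{k,l}\]
\[\text{s.t.} \quad \mathbf{T} \mathbf{1} = \mathbf{p}, \quad \mathbf{T}^T \mathbf{1} = \mathbf{q}, \quad \mathbf{T} \geq 0\]
where: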
\(\mathbf{C_1}\): Metric cost matrix in the source space.
\(\mathbf{C_2}\): Metric cost matrix in the target space.
\(\mathbf{p}\): Distribution in the source space.
\(\mathbf{q}\): Distribution in the target space.
L: Loss function to account for the misfit between the similarity matrices.
Note
This function is backend-compatible and will work on arrays from all compatible backends. But the algorithm uses the C++ CPU backend which can lead to copy overhead on GPU arrays.
Note
All computations in the conditional gradient solver are done with numpy to limit memory overhead.
Note
This function will cast the computed transport plan to the data type of the provided input \(\mathbf{C}_1\). Casting to an integer tensor might result in a loss of precision. If this behaviour is unwanted, please make sure to provide a floating point input.
C1 (array-like, shape (ns, ns)) – Metric cost matrix in the source space.
C2 (array-like, shape (nt, nt)) – Metric cost matrix in the target space.
p (array-like, shape (ns,), optional) – Distribution in the source space. If left at its default value None, the uniform distribution is used.
q (array-like, shape (nt,), optional) – Distribution in the target space. If left at its default value None, the uniform distribution is used.
loss_fun (str, optional) – Loss function used for the solver, either ‘square_loss’ or ‘kl_loss’.
symmetric (bool, optional) – Whether C1 and C2 are to be assumed symmetric. If left at its default value None, a symmetry test is conducted. If set to True (resp. False), C1 and C2 are assumed symmetric (resp. asymmetric).
verbose (bool, optional) – Print information along iterations.
log (bool, optional) – Record log if True.
armijo (bool, optional) – If True, the step of the line-search is found via an armijo search. Else closed form is used. If there are convergence issues, use False.
G0 (array-like, shape (ns,nt), optional) – If None, the initial transport plan of the solver is pq^T. Otherwise G0 must satisfy marginal constraints and will be used as initial transport of the solver.
max_iter (int, optional) – Max number of iterations.
tol_rel (float, optional) – Stop threshold on relative error (>0).
tol_abs (float, optional) – Stop threshold on absolute error (>0).
**kwargs (dict) – Parameters can be directly passed to the ot.optim.cg solver.
T (array-like, shape (ns, nt)) –
Coupling between the two spaces that minimizes:
\(\sum_{i,j,k,l} L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l}) \mathbf{T}_{i,j} \mathbf{T}_{k,l}\)
log (dict) – Convergence information and loss.
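For the square loss, the inner sum over (k, l) factorizes, so the four-index tensor never has to be built. When \(\mathbf{T}\) has marginals \(\mathbf{p}\) and \(\mathbf{q}\), the decomposition from Peyré et al. (2016) gives \(\sum_{k,l} (\mathbf{C_1}_{i,k} - \mathbf{C_2}_{j,l})^2 \mathbf{T}_{k,l} = (\mathbf{C_1}^{\odot 2}\mathbf{p})_i + (\mathbf{C_2}^{\odot 2}\mathbf{q})_j - 2(\mathbf{C_1}\mathbf{T}\mathbf{C_2}^T)_{i,j}\). The sketch below checks this decomposition against brute force. The helper names are hypothetical, and this is not claimed to be POT's exact internal code.

```python
import numpy as np


def gw_tensor_square_loss(C1, C2, T, p, q):
    """tens[i, j] = sum_{k,l} (C1[i,k] - C2[j,l])**2 * T[k,l].

    Valid when T has row marginal p and column marginal q.
    """
    constC = ((C1 ** 2) @ p)[:, None] + ((C2 ** 2) @ q)[None, :]
    return constC - 2 * C1 @ T @ C2.T


def gw_loss_bruteforce(C1, C2, T):
    """Same objective via the full (ns, nt, ns, nt) loss tensor, for checking only."""
    L = (C1[:, None, :, None] - C2[None, :, None, :]) ** 2
    return float(np.sum(L * T[:, :, None, None] * T[None, None, :, :]))


rng = np.random.default_rng(0)
xs, xt = rng.normal(size=(5, 2)), rng.normal(size=(4, 3))
C1 = np.sum((xs[:, None, :] - xs[None, :, :]) ** 2, axis=-1)  # (5, 5) squared distances
C2 = np.sum((xt[:, None, :] - xt[None, :, :]) ** 2, axis=-1)  # (4, 4)
p, q = np.full(5, 1 / 5), np.full(4, 1 / 4)
T = np.outer(p, q)  # any coupling of p and q works here
fast = float(np.sum(gw_tensor_square_loss(C1, C2, T, p, q) * T))
slow = gw_loss_bruteforce(C1, C2, T)
```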
References
Gabriel Peyré, Marco Cuturi, and Justin Solomon, “Gromov-Wasserstein averaging of kernel and distance matrices.” International Conference on Machine Learning (ICML). 2016.
Mémoli, Facundo. Gromov–Wasserstein distances and the metric approach to object matching. Foundations of computational mathematics 11.4 (2011): 417-487.
Chowdhury, S., & Mémoli, F. (2019). The gromov–wasserstein distance between networks and stable network invariants. Information and Inference: A Journal of the IMA, 8(4), 757-787.
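Returns the Gromov-Wasserstein loss \(\mathbf{GW}\) between \((\mathbf{C_1}, \mathbf{p})\) and \((\mathbf{C_2}, \mathbf{q})\). To recover the Gromov-Wasserstein distance as defined in [13], compute \(d_{GW} = \frac{1}{2} \sqrt{\mathbf{GW}}\).
The function solves the following optimization problem using Conditional Gradient:
\[\mathbf{GW} = \mathop{\min}_\mathbf{T} \quad \sum_{i,j,k,l} L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l}) \mathbf{T}_{i,j} \mathbf{T}_{k,l}\]
\[\text{s.t.} \quad \mathbf{T} \mathbf{1} = \mathbf{p}, \quad \mathbf{T}^T \mathbf{1} = \mathbf{q}, \quad \mathbf{T} \geq 0\]
where: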
\(\mathbf{C_1}\): Metric cost matrix in the source space
\(\mathbf{C_2}\): Metric cost matrix in the target space
\(\mathbf{p}\): distribution in the source space
\(\mathbf{q}\): distribution in the target space
L: loss function to account for the misfit between the similarity matrices
Note that when using backends, this loss function is differentiable with respect to the matrices (C1, C2) and the weights (p, q) for the quadratic loss, using the gradients from [38].
Note
This function is backend-compatible and will work on arrays from all compatible backends. But the algorithm uses the C++ CPU backend which can lead to copy overhead on GPU arrays.
Note
All computations in the conditional gradient solver are done with numpy to limit memory overhead.
Note
This function will cast the computed transport plan to the data type of the provided input \(\mathbf{C}_1\). Casting to an integer tensor might result in a loss of precision. If this behaviour is unwanted, please make sure to provide a floating point input.
C1 (array-like, shape (ns, ns)) – Metric cost matrix in the source space
C2 (array-like, shape (nt, nt)) – Metric cost matrix in the target space
p (array-like, shape (ns,), optional) – Distribution in the source space. If left at its default value None, the uniform distribution is used.
q (array-like, shape (nt,), optional) – Distribution in the target space. If left at its default value None, the uniform distribution is used.
loss_fun (str) – Loss function used for the solver, either ‘square_loss’ or ‘kl_loss’
symmetric (bool, optional) – Whether C1 and C2 are to be assumed symmetric. If left at its default value None, a symmetry test is conducted. If set to True (resp. False), C1 and C2 are assumed symmetric (resp. asymmetric).
verbose (bool, optional) – Print information along iterations
log (bool, optional) – Record log if True
armijo (bool, optional) – If True, the step of the line-search is found via an Armijo search. Otherwise the closed form is used. If there are convergence issues, use False.
G0 (array-like, shape (ns,nt), optional) – If None, the initial transport plan of the solver is pq^T. Otherwise G0 must satisfy the marginal constraints and will be used as the initial transport plan of the solver.
max_iter (int, optional) – Max number of iterations
tol_rel (float, optional) – Stop threshold on relative error (>0)
tol_abs (float, optional) – Stop threshold on absolute error (>0)
**kwargs (dict) – Parameters can be directly passed to the ot.optim.cg solver
gw_dist (float) – Gromov-Wasserstein distance
log (dict) – Convergence information and coupling matrix
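The loss only compares intra-space distances, so relabeling the points of one space changes nothing: with the matching permutation coupling, the loss is zero. The sketch below checks this by brute force with the square loss. It then converts the loss to \(d_{GW} = \frac{1}{2}\sqrt{\mathbf{GW}}\). gw_loss is a hypothetical helper.

```python
import numpy as np


def gw_loss(C1, C2, T):
    """Square-loss GW objective sum_{i,j,k,l} (C1[i,k] - C2[j,l])**2 T[i,j] T[k,l]."""
    L = (C1[:, None, :, None] - C2[None, :, None, :]) ** 2
    return float(np.sum(L * T[:, :, None, None] * T[None, None, :, :]))


rng = np.random.default_rng(1)
x = rng.normal(size=(4, 2))
C1 = np.sqrt(np.sum((x[:, None, :] - x[None, :, :]) ** 2, axis=-1))  # Euclidean distances
perm = np.array([2, 0, 3, 1])
C2 = C1[np.ix_(perm, perm)]  # the same space with its points relabeled

# Couple target point j with source point perm[j], uniform weights 1/4.
T = np.zeros((4, 4))
T[perm, np.arange(4)] = 0.25

gw = gw_loss(C1, C2, T)
d_gw = 0.5 * np.sqrt(gw)  # distance as in [13], from the loss
```

Any other coupling, such as the product plan, gives a strictly positive loss here.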
References
Gabriel Peyré, Marco Cuturi, and Justin Solomon, “Gromov-Wasserstein averaging of kernel and distance matrices.” International Conference on Machine Learning (ICML). 2016.
Mémoli, Facundo. Gromov–Wasserstein distances and the metric approach to object matching. Foundations of computational mathematics 11.4 (2011): 417-487.
C. Vincent-Cuaz, T. Vayer, R. Flamary, M. Corneli, N. Courty, Online Graph Dictionary Learning, International Conference on Machine Learning (ICML), 2021.
Chowdhury, S., & Mémoli, F. (2019). The gromov–wasserstein distance between networks and stable network invariants. Information and Inference: A Journal of the IMA, 8(4), 757-787.
Computes the Linear Circular Optimal Transport distance from [78] using \(\eta=\mathrm{Unif}(S^1)\) as reference measure. Samples need to be in \(S^1\cong [0,1[\). If they lie on \(\mathbb{R}\), the values are taken modulo 1. If the values are on \(S^1\subset\mathbb{R}^2\), the coordinates must first be computed, e.g. with the atan2 function.
General loss returned:
where \(\hat{\mu}(x)=F_{\mu}^{-1}(x-\int z\mathrm{d}\mu(z)+\frac12) - x\) for all \(x\in [0,1[\), and \(d_{S^1}(x,y)=\min(|x-y|, 1-|x-y|)\) for \(x,y\in [0,1[\).
u_values (ndarray, shape (n, ...)) – samples in the source domain (coordinates on [0,1[)
v_values (ndarray, shape (n, ...), optional) – samples in the target domain (coordinates on [0,1[), if None, compute distance against uniform distribution
u_weights (ndarray, shape (n, ...), optional) – Sample weights in the source domain
v_weights (ndarray, shape (n, ...), optional) – Sample weights in the target domain
loss (float/array-like, shape (…)) – Batched cost associated to the linear optimal transportation
Examples
>>> import numpy as np
>>> from ot import linear_circular_ot
>>> u = np.array([[0.2,0.5,0.8]])%1
>>> v = np.array([[0.4,0.5,0.7]])%1
>>> linear_circular_ot(u.T, v.T)
array([0.0127])
References
Martin, R. D., Medri, I., Bai, Y., Liu, X., Yan, K., Rohde, G. K., & Kolouri, S. (2024). LCOT: Linear Circular Optimal Transport. International Conference on Learning Representations.
Computes the linear spherical sliced Wasserstein distance from [79].
General loss returned:
where \(\mu,\nu\in\mathcal{P}(S^{d-1})\) are two probability measures on the sphere, \(\mathrm{LCOT}_2\) is the linear circular optimal transport distance, and \(P^U_\# \mu\) stands for the pushforward of \(\mu\) by the projection \(\forall x\in S^{d-1},\ P^U(x) = \frac{U^Tx}{\|U^Tx\|_2}\).
X_s (ndarray, shape (n_samples_a, dim)) – Samples in the source domain
X_t (ndarray, shape (n_samples_b, dim), optional) – Samples in the target domain. If None, computes the distance against the uniform distribution on the sphere.
a (ndarray, shape (n_samples_a,), optional) – Sample weights in the source domain
b (ndarray, shape (n_samples_b,), optional) – Sample weights in the target domain
n_projections (int, optional) – Number of projections used for the Monte-Carlo approximation
projections (shape (n_projections, dim, 2), optional) – Projection matrix (n_projections and seed are not used in this case)
seed (int or RandomState or None, optional) – Seed used for random number generator
log (bool, optional) – if True, linear_sliced_wasserstein_sphere returns the projections used and their associated LCOT.
cost (float) – Linear Spherical Sliced Wasserstein Cost
log (dict, optional) – Log dictionary, returned only if log==True in parameters
Examples
>>> import ot
>>> import numpy as np
>>> n_samples_a = 20
>>> X = np.random.normal(0., 1., (n_samples_a, 5))
>>> X = X / np.sqrt(np.sum(X**2, -1, keepdims=True))
>>> ot.linear_sliced_wasserstein_sphere(X, X, seed=0)
0.0
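The projection \(P^U\) in the definition maps each sample on \(S^{d-1}\) to the circle. The minimal NumPy sketch below covers that step only. It makes two illustrative assumptions that may differ from POT's internals: U is drawn with orthonormal columns via a QR decomposition, and angles are mapped to [0, 1[ by dividing by \(2\pi\). The function names are hypothetical.

```python
import numpy as np


def project_to_circle(X, U):
    """P^U(x) = U^T x / ||U^T x||_2 for each row x of X; U has shape (d, 2)."""
    Z = X @ U  # row i is U^T x_i
    return Z / np.linalg.norm(Z, axis=1, keepdims=True)


def circle_coordinates(Z):
    """Map unit vectors of R^2 to [0, 1[ through their angle (one possible convention)."""
    return (np.arctan2(Z[:, 1], Z[:, 0]) / (2 * np.pi)) % 1.0


rng = np.random.default_rng(0)
X = rng.normal(size=(20, 5))
X /= np.linalg.norm(X, axis=1, keepdims=True)  # samples on S^4
U, _ = np.linalg.qr(rng.normal(size=(5, 2)))   # (5, 2) with orthonormal columns
coords = circle_coordinates(project_to_circle(X, U))
```

Each projection turns the spherical problem into a circular one, to which \(\mathrm{LCOT}_2\) then applies.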
References
Liu, X., Bai, Y., Martín, R. D., Shi, K., Shahbazi, A., Landman, B. A., Chang, C., & Kolouri, S. (2025). Linear Spherical Sliced Optimal Transport: A Fast Metric for Comparing Spherical Data. International Conference on Learning Representations.
Solve the entropic regularized Gromov-Wasserstein transport problem under low nonnegative rank constraints on the couplings and cost matrices.
Squared Euclidean distance matrices are considered for the target and source distributions.
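The function solves the following optimization problem:
\[\mathop{\min}_{(Q,R,g) \in \mathcal{C}(a,b,r)} \quad \mathcal{Q}_{A,B}(Q\,\mathrm{diag}(1/g)\,R^T) - \mathrm{reg} \cdot H((Q,R,g))\]
where: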
\(A\) is the (dim_a, dim_a) square pairwise cost matrix of the source domain.
\(B\) is the (dim_b, dim_b) square pairwise cost matrix of the target domain.
\(\mathcal{Q}_{A,B}\) is the quadratic objective function of the Gromov-Wasserstein plan.
\(Q\) and \(R\) are the low-rank matrix decomposition of the Gromov-Wasserstein plan.
\(g\) is the weight vector for the low-rank decomposition of the Gromov-Wasserstein plan.
\(\mathbf{a}\) and \(\mathbf{b}\) are source and target weights (histograms, both sum to 1).
\(r\) is the rank of the Gromov-Wasserstein plan.
\(\mathcal{C}(a,b,r)\) are the low-rank couplings of the OT problem.
\(H((Q,R,g))\) are the values of the three respective entropies evaluated for each term.
X_s (array-like, shape (n_samples_a, dim_Xs)) – Samples in the source domain
X_t (array-like, shape (n_samples_b, dim_Xt)) – Samples in the target domain
a (array-like, shape (n_samples_a,), optional) – Sample weights in the source domain. If left at its default value None, the uniform distribution is used.
b (array-like, shape (n_samples_b,), optional) – Sample weights in the target domain. If left at its default value None, the uniform distribution is used.
reg (float, optional) – Regularization term >=0
rank (int, optional. Default is None. (>0)) – Nonnegative rank of the OT plan. If None, min(ns, nt) is considered.
alpha (float, optional. Default is 1e-10. (>0 and <1/r)) – Lower bound for the weight vector g.
rescale_cost (bool, optional. Default is True) – Rescale the low rank factorization of the sqeuclidean cost matrix
seed_init (int, optional. Default is 49. (>0)) – Random state for the ‘random’ initialization of low rank couplings
gamma_init (str, optional. Default is "rescale".) – Initialization strategy for gamma, either ‘rescale’ or ‘theory’. Gamma is a constant that scales the convergence criterion of the Mirror Descent optimization scheme used to compute the low-rank couplings (Q, R and g).
numItermax (int, optional. Default is 1000.) – Max number of iterations for Low Rank GW
stopThr (float, optional. Default is 1e-4.) – Stop threshold on error (>0) for Low Rank GW. The error is the sum of the Kullback-Leibler divergences computed for each low rank coupling (Q, R and g), scaled using gamma.
numItermax_dykstra (int, optional. Default is 10000.) – Max number of iterations for the Dykstra algorithm
stopThr_dykstra (float, optional. Default is 1e-3.) – Stop threshold on error (>0) in Dykstra
cost_factorized_Xs (tuple, optional. Default is None) – Tuple with two pre-computed low rank decompositions (A1, A2) of the source cost matrix. Both matrices should have a shape of (n_samples_a, dim_Xs + 2). If None, the low rank cost matrices will be computed as sqeuclidean cost matrices.
cost_factorized_Xt (tuple, optional. Default is None) – Tuple with two pre-computed low rank decompositions (B1, B2) of the target cost matrix. Both matrices should have a shape of (n_samples_b, dim_Xt + 2). If None, the low rank cost matrices will be computed as sqeuclidean cost matrices.
warn (bool, optional) – If True, raises a warning if the low rank GW algorithm does not converge.
warn_dykstra (bool, optional) – If True, raises a warning if the Dykstra algorithm does not converge.
log (bool, optional) – record log if True
Q (array-like, shape (n_samples_a, r)) – First low-rank matrix decomposition of the OT plan
R (array-like, shape (n_samples_b, r)) – Second low-rank matrix decomposition of the OT plan
g (array-like, shape (r, )) – Weight vector for the low-rank decomposition of the OT
log (dict (lazy_plan, value and value_linear)) – Log dictionary, returned only if log==True in parameters
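Low-rank cost factors are what let the solver avoid dense (n, n) cost matrices. For the squared Euclidean cost, the identity \(\|x_i - y_j\|^2 = \|x_i\|^2 + \|y_j\|^2 - 2\langle x_i, y_j\rangle\) gives factors with dim + 2 columns, which matches the shape documented for cost_factorized_Xs and cost_factorized_Xt. The sketch below builds one such pair. Whether POT constructs exactly these factors internally is an assumption, and factorize_sqeuclidean is a hypothetical name.

```python
import numpy as np


def factorize_sqeuclidean(X, Y):
    """Return A1 (n, d + 2) and A2 (m, d + 2) with A1 @ A2.T == ||x_i - y_j||^2."""
    n, m = X.shape[0], Y.shape[0]
    A1 = np.hstack([np.sum(X ** 2, axis=1, keepdims=True), np.ones((n, 1)), -2 * X])
    A2 = np.hstack([np.ones((m, 1)), np.sum(Y ** 2, axis=1, keepdims=True), Y])
    return A1, A2


rng = np.random.default_rng(0)
X_s = rng.normal(size=(6, 3))
A1, A2 = factorize_sqeuclidean(X_s, X_s)  # source cost: X_s against itself
full = np.sum((X_s[:, None, :] - X_s[None, :, :]) ** 2, axis=-1)
```

Storing (A1, A2) takes O(n (d + 2)) memory, compared with O(n^2) for the full matrix.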
References
Scetbon, M., Peyré, G. & Cuturi, M. (2022). “Linear-Time Gromov-Wasserstein Distances using Low Rank Couplings and Costs”. In International Conference on Machine Learning (ICML), 2022.
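Solve the entropic regularized optimal transport problem under low nonnegative rank constraints on the couplings.
The function solves the following optimization problem:
\[\mathop{\inf}_{(\mathbf{Q},\mathbf{R},\mathbf{g}) \in \mathcal{C}(\mathbf{a},\mathbf{b},r)} \langle \mathbf{C}, \mathbf{Q}\,\mathrm{diag}(1/\mathbf{g})\,\mathbf{R}^T \rangle - \mathrm{reg} \cdot H((\mathbf{Q}, \mathbf{R}, \mathbf{g}))\]
where: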
\(\mathbf{C}\) is the (dim_a, dim_b) metric cost matrix
\(H((\mathbf{Q}, \mathbf{R}, \mathbf{g}))\) are the values of the three respective entropies evaluated for each term.
\(\mathbf{Q}\) and \(\mathbf{R}\) are the low-rank matrix decomposition of the OT plan
\(\mathbf{g}\) is the weight vector for the low-rank decomposition of the OT plan
\(\mathbf{a}\) and \(\mathbf{b}\) are source and target weights (histograms, both sum to 1)
\(r\) is the rank of the OT plan
\(\mathcal{C}(\mathbf{a}, \mathbf{b}, r)\) are the low-rank couplings of the OT problem
X_s (array-like, shape (n_samples_a, dim)) – samples in the source domain
X_t (array-like, shape (n_samples_b, dim)) – samples in the target domain
a (array-like, shape (n_samples_a,)) – Sample weights in the source domain
b (array-like, shape (n_samples_b,)) – Sample weights in the target domain
reg (float, optional) – Regularization term >0
rank (int, optional. Default is None. (>0)) – Nonnegative rank of the OT plan. If None, min(ns, nt) is considered.
alpha (float, optional. Default is 1e-10. (>0 and <1/r)) – Lower bound for the weight vector g.
rescale_cost (bool, optional. Default is True) – Rescale the low rank factorization of the sqeuclidean cost matrix
init (str, optional. Default is 'random'.) – Initialization strategy for the low rank couplings. ‘random’, ‘deterministic’ or ‘kmeans’
reg_init (float, optional. Default is 1e-1. (>0)) – Regularization term for a ‘kmeans’ init. If None, 1 is considered.
seed_init (int, optional. Default is 49. (>0)) – Random state for a ‘random’ or ‘kmeans’ init strategy.
gamma_init (str, optional. Default is "rescale".) – Initialization strategy for gamma, either ‘rescale’ or ‘theory’. Gamma is a constant that scales the convergence criterion of the Mirror Descent optimization scheme used to compute the low-rank couplings (Q, R and g).
numItermax (int, optional. Default is 2000.) – Max number of iterations for the Dykstra algorithm
stopThr (float, optional. Default is 1e-7.) – Stop threshold on error (>0) in Dykstra
warn (bool, optional) – If True, raises a warning if the algorithm does not converge.
log (bool, optional) – record log if True
Q (array-like, shape (n_samples_a, r)) – First low-rank matrix decomposition of the OT plan
R (array-like, shape (n_samples_b, r)) – Second low-rank matrix decomposition of the OT plan
g (array-like, shape (r, )) – Weight vector for the low-rank decomposition of the OT
log (dict (lazy_plan, value and value_linear)) – Log dictionary, returned only if log==True in parameters
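The solver returns the plan in factored form. The sketch below assumes that \(\mathbf{Q}\) has marginals \((\mathbf{a}, \mathbf{g})\) and \(\mathbf{R}\) has marginals \((\mathbf{b}, \mathbf{g})\). Under that assumption it checks that \(\mathbf{Q}\,\mathrm{diag}(1/\mathbf{g})\,\mathbf{R}^T\) is a coupling of \(\mathbf{a}\) and \(\mathbf{b}\). It also shows how the linear cost can be evaluated from the factors alone when \(\mathbf{C} = \mathbf{A}_1\mathbf{A}_2^T\). The block-structured Q and R are a toy choice, and lazy_plan_cost is a hypothetical helper.

```python
import numpy as np


def lazy_plan_cost(A1, A2, Q, R, g):
    """<A1 @ A2.T, Q @ diag(1/g) @ R.T>_F computed without forming (ns, nt) matrices."""
    # sum_{d,k} (A1^T Q)[d, k] * (A2^T R)[d, k] / g[k]
    return float(np.sum((A1.T @ Q) * (A2.T @ R) / g))


a, b = np.full(4, 1 / 4), np.full(6, 1 / 6)
g = np.array([0.5, 0.5])
# Toy factors: Q couples (a, g) and R couples (b, g) by splitting points into 2 blocks.
Q = np.kron(np.eye(2), np.full((2, 1), 1 / 4))  # (4, 2)
R = np.kron(np.eye(2), np.full((3, 1), 1 / 6))  # (6, 2)
P = (Q / g) @ R.T  # dense (4, 6) plan, formed here only for checking

rng = np.random.default_rng(0)
A1, A2 = rng.random((4, 3)), rng.random((6, 3))
dense_cost = float(np.sum((A1 @ A2.T) * P))
lazy_cost = lazy_plan_cost(A1, A2, Q, R, g)
```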
References
Scetbon, M., Cuturi, M., & Peyré, G. (2021). “Low-rank Sinkhorn Factorization”. In International Conference on Machine Learning.
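Computes a Monte-Carlo approximation of the max p-Sliced Wasserstein distance
\[\mathrm{Max\text{-}SW}_p(\mu, \nu) = \max_{\theta \in \mathbb{S}^{d-1}} \left[\mathcal{W}_p^p(\theta_\# \mu, \theta_\# \nu)\right]^{\frac{1}{p}}\]
where:
\(\theta_\# \mu\) stands for the pushforward of \(\mu\) by the projection \(\mathbb{R}^d \ni X \mapsto \langle \theta, X \rangle\)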
X_s (ndarray, shape (n_samples_a, dim)) – samples in the source domain
X_t (ndarray, shape (n_samples_b, dim)) – samples in the target domain
a (ndarray, shape (n_samples_a,), optional) – Sample weights in the source domain
b (ndarray, shape (n_samples_b,), optional) – Sample weights in the target domain
n_projections (int, optional) – Number of projections used for the Monte-Carlo approximation
p (float, optional) – Power p used for computing the sliced Wasserstein distance
projections (shape (dim, n_projections), optional) – Projection matrix (n_projections and seed are not used in this case)
seed (int or RandomState or None, optional) – Seed used for random number generator
log (bool, optional) – If True, max_sliced_wasserstein_distance returns the projections used and their associated EMD.
scaler (None, object with .transform(), or callable, optional) – Preprocessing applied to X_s and X_t before computing the distance. Useful for normalizing inputs when features have very different scales. Accepted values:
- None: no preprocessing (default).
- Object with a .transform() method: e.g. an ot.utils.DataScaler fitted on a representative sample. This is the recommended way to get stable, consistent normalization across multiple calls (e.g. when using SWD as a loss in mini-batch training).
- Callable: any function, lambda, or PyTorch transform, applied directly as scaler(X_s) and scaler(X_t).
See ot.utils.DataScaler for a backend-aware scaler that supports joint fitting on multiple distributions.
cost (float) – Max-sliced Wasserstein cost
log (dict, optional) – Log dictionary, returned only if log==True in parameters
Examples
>>> import numpy as np
>>> import ot
>>> n_samples_a = 20
>>> X = np.random.normal(0., 1., (n_samples_a, 5))
>>> ot.max_sliced_wasserstein_distance(X, X, seed=0)
0.0
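The sketch below assumes equal sample sizes and uniform weights; POT also handles weights and different sizes. Under that assumption, each projected 1D problem is solved by sorting. The sketch takes the maximum over a finite set of directions, which is what makes the estimate Monte-Carlo. The function names are hypothetical, and the projections follow the (dim, n_projections) layout documented for the projections parameter.

```python
import numpy as np


def max_sliced_wasserstein_mc(X_s, X_t, projections, p=2):
    """Max over the given directions of the 1D W_p between projected samples.

    projections: (dim, n_projections) array whose columns are unit directions.
    """
    proj_s = np.sort(X_s @ projections, axis=0)  # (n, n_projections), sorted per direction
    proj_t = np.sort(X_t @ projections, axis=0)
    wpp = np.mean(np.abs(proj_s - proj_t) ** p, axis=0)  # 1D W_p^p per direction
    return float(np.max(wpp) ** (1.0 / p))


def random_directions(dim, n_projections, seed=0):
    """Directions drawn uniformly on the sphere by normalizing Gaussian vectors."""
    rng = np.random.default_rng(seed)
    theta = rng.normal(size=(dim, n_projections))
    return theta / np.linalg.norm(theta, axis=0, keepdims=True)


rng = np.random.default_rng(0)
X = rng.normal(size=(10, 2))
Y = X + np.array([3.0, 0.0])  # translation by 3 along the first axis
```

For a pure translation, every direction \(\theta\) gives \(3|\theta_1|\). The maximum is therefore reached along the first axis, and a finite random set of directions can only underestimate it.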
References
Deshpande, I., Hu, Y. T., Sun, R., Pyrros, A., Siddiqui, N., Koyejo, S., … & Schwing, A. G. (2019). Max-sliced wasserstein distance and its use for gans. In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (pp. 10648-10656).
Merge several bijections between two point clouds to obtain a new one with low transport cost. The new bijection is guaranteed to have a transport cost no greater than the cost of any of the input bijections. Based on simple local/global swapping strategy, with a linear complexity in the number of points. Algorithm 3 from [84].
X (array-like, shape (n_samples, dimension)) – First point cloud
Y (array-like, shape (n_samples, dimension)) – Second point cloud
perms (array-like, shape (n_plans,n_samples)) – The bijections to merge, stored as permutations (e.g. a list of numbers)
p (int, optional) – The power of the ground metric (default 2 for squared euclidean). If p is -1, the infinity norm is used.
cost (float) – The transport cost of the merged bijection.
perm (array-like, shape (n_samples,)) – The merged bijection, stored as a permutation (e.g. a list of numbers) such that X[i] is assigned to Y[perm[i]].
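The sketch below illustrates the guarantee that the merged bijection is no worse than any input, using a much simpler quadratic-time stand-in. It starts from the cheapest input permutation and applies pairwise swaps only when they lower the cost, so the cost can never increase. This is not Algorithm 3 of [84], which runs in linear time. The function names and the mean-over-points cost normalization are assumptions, and the infinity-norm case (p = -1) is not covered.

```python
import numpy as np


def bijection_cost(X, Y, perm, p=2):
    """Mean of ||X[i] - Y[perm[i]]||_2 ** p (normalization is an assumption)."""
    return float(np.mean(np.linalg.norm(X - Y[perm], axis=1) ** p))


def merge_by_best_then_swaps(X, Y, perms, p=2):
    """Hypothetical stand-in: keep the cheapest input, then apply improving swaps."""
    costs = [bijection_cost(X, Y, np.asarray(pm), p) for pm in perms]
    perm = np.array(perms[int(np.argmin(costs))], copy=True)
    n = len(perm)
    for i in range(n):
        for j in range(i + 1, n):
            old = (np.linalg.norm(X[i] - Y[perm[i]]) ** p
                   + np.linalg.norm(X[j] - Y[perm[j]]) ** p)
            new = (np.linalg.norm(X[i] - Y[perm[j]]) ** p
                   + np.linalg.norm(X[j] - Y[perm[i]]) ** p)
            if new < old:  # swapping targets of i and j strictly helps
                perm[i], perm[j] = perm[j], perm[i]
    return bijection_cost(X, Y, perm, p), perm


rng = np.random.default_rng(0)
X = rng.normal(size=(8, 2))
Y = rng.normal(size=(8, 2))
perms = [rng.permutation(8) for _ in range(3)]
cost, perm = merge_by_best_then_swaps(X, Y, perms)
```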
Computes the cost and permutation associated to the min-Pivot Sliced Discrepancy (introduced as min-SWGG in [85] and studied further in [86]). Given the supports X_s and X_t of two discrete uniform measures with ns and nt atoms in dimension d, the min-Pivot Sliced Discrepancy goes through n_projections different projections of the measures on random directions and retains the coupling that yields the lowest cost between X_s and X_t (compared in \(\mathbb{R}^d\)). When ns=nt, it gives
where \(\sigma_k\) is a permutation such that ordering the projections on the axis projections[k, :] matches \(X_s[i, :]\) to \(X_t[\sigma_k(i), :]\).
Note
The computation ignores potential ambiguities in the projections: if two points from the same measure have the same projection on a direction, then multiple sorting permutations are possible. To avoid combinatorial explosion, only one permutation is retained, which strays from theory in pathological cases.
Warning
TensorFlow and JAX only return dense plans, as they do not support sparse matrices well.
X_s (array-like, shape (ns, d)) – The first set of vectors.
X_t (array-like, shape (nt, d)) – The second set of vectors.
a (ndarray of float64, shape (ns,), optional) – Source histogram (default is uniform weight)
b (ndarray of float64, shape (nt,), optional) – Target histogram (default is uniform weight)
projections (shape (dim, n_projections), optional) – Projection matrix (n_projections and seed are not used in this case). Default is None
metric (str, optional (default='sqeuclidean')) – Metric to be used. Must be one of ‘sqeuclidean’, ‘minkowski’, ‘cityblock’ or ‘euclidean’.
p (float, optional (default=1.0)) – The p-norm to apply if metric=’minkowski’
n_projections (int, optional) – The number of projection directions. Required if projections is None.
seed (int, optional) – The seed for the random number generator for sampling projections, in case projections is None. Default is None.
batch_size (int, optional) – If specified, compute the distance in batches of size batch_size to avoid memory issues for large datasets. Default is None (no batching).
dense (boolean, optional (default=True)) – If True, returns \(\gamma\) as a dense ndarray of shape (ns, nt). Otherwise returns a sparse representation using scipy’s coo_matrix format.
log (bool, optional) – If True, returns additional logging information. Default is False.
plan (ndarray, shape (ns, nt) or coo_matrix if dense is False) – Optimal transportation matrix for the given parameters.
cost (float) – The cost associated to the optimal permutation.
log_dict (dict, optional) – A dictionary containing intermediate computations for logging purposes. Returned only if log is True.
References
Mahey, G., Chapel, L., Gasso, G., Bonet, C., & Courty, N. (2023). Fast Optimal Transport through Sliced Generalized Wasserstein Geodesics. Advances in Neural Information Processing Systems, 36, 35350–35385.
Tanguy, E., Chapel, L., Delon, J. (2025). Sliced Transport Plans. arXiv preprint 2506.03661.
Examples
>>> import ot
>>> import numpy as np
>>> x=np.array([[3.,3.], [1.,1.]])
>>> y=np.array([[2.,2.5], [3.,2.]])
>>> projections=np.array([[1, 0], [0, 1]])
>>> plan, cost = ot.min_sliced_transport_plan(x, y, projections=projections)
>>> plan
array([[0. , 0.5],
[0.5, 0. ]])
>>> cost
2.125
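For intuition, the selection rule can be sketched in plain NumPy. This is a hypothetical helper (min_pivot_sliced_sketch is not part of POT) restricted to uniform weights, ns = nt and the squared Euclidean metric. Directions are taken as the columns of a (dim, n_projections) array, and a stable sort breaks ties by index, in line with the Note above. On the example data it recovers the same plan and cost.

```python
import numpy as np


def min_pivot_sliced_sketch(X_s, X_t, projections):
    """Hypothetical sketch: uniform weights, ns == nt, squared Euclidean cost.

    projections has shape (dim, n_projections); each column is one direction.
    Returns (plan, cost) for the direction whose sorting permutation is cheapest.
    """
    n = X_s.shape[0]
    if X_t.shape[0] != n:
        raise ValueError("this sketch only handles ns == nt")
    best_cost, best_perm = np.inf, None
    for k in range(projections.shape[1]):
        theta = projections[:, k]
        # Matching the sorted ranks of the 1D projections defines sigma_k.
        idx_s = np.argsort(X_s @ theta, kind="stable")
        idx_t = np.argsort(X_t @ theta, kind="stable")
        perm = np.empty(n, dtype=int)
        perm[idx_s] = idx_t  # perm[i] = sigma_k(i)
        # The permutation is evaluated in R^d, not on the line.
        cost = np.mean(np.sum((X_s - X_t[perm]) ** 2, axis=1))
        if cost < best_cost:
            best_cost, best_perm = cost, perm
    plan = np.zeros((n, n))
    plan[np.arange(n), best_perm] = 1.0 / n
    return plan, best_cost


x = np.array([[3.0, 3.0], [1.0, 1.0]])
y = np.array([[2.0, 2.5], [3.0, 2.0]])
plan, cost = min_pivot_sliced_sketch(x, y, np.eye(2))
print(plan)  # same plan as in the example above
print(cost)  # 2.125
```

Here the first axis matches x[0] with y[1] and x[1] with y[0] (mean squared cost 2.125), while the second axis gives 3.125, so the first permutation is kept.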
Computes the closed form of the 2-Wasserstein distance between samples and the uniform distribution on \(S^1\). Samples need to be in \(S^1\cong [0,1[\). If they lie on \(\mathbb{R}\), their values are taken modulo 1. If they lie on \(S^1\subset\mathbb{R}^2\), their coordinates must first be computed, e.g. with the atan2 function.
where:
\(\nu=\mathrm{Unif}(S^1)\) and \(\mu_n = \sum_{i=1}^n \alpha_i \delta_{x_i}\)
For values \(x=(x_1,x_2)\in S^1\), it is required to first get their coordinates with
using e.g. ot.utils.get_coordinate_circle(x).
u_values (ndarray, shape (n, ...)) – Samples
u_weights (ndarray, shape (n, ...), optional) – samples weights in the source domain
loss (float/array-like, shape (…)) – Batched cost associated to the optimal transportation
Examples
>>> import numpy as np
>>> import ot
>>> x0 = np.array([[0], [0.2], [0.4]])
>>> ot.semidiscrete_wasserstein2_unif_circle(x0)
array([0.02111111])
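The value above can be cross-checked with a short NumPy sketch. It assumes the shift-invariant characterization of optimal transport on the circle against the uniform measure, \(W_2^2(\mu_n, \nu) = \min_{c} \int_0^1 |F_{\mu_n}^{-1}(t) - t - c|^2 \,\mathrm{d}t\), whose minimizer is \(c = \sum_i \alpha_i x_i - 1/2\). The helper w2_unif_circle_sketch is hypothetical, handles a single 1D sample set (no batching over trailing dimensions), and is not the POT implementation.

```python
import numpy as np


def w2_unif_circle_sketch(x, weights=None):
    """Hypothetical sketch of W_2^2(mu_n, Unif(S^1)) for 1D samples x.

    mu_n = sum_i weights[i] * delta_{x[i]} (uniform weights by default).
    """
    x = np.mod(np.asarray(x, dtype=float).ravel(), 1.0)  # wrap onto [0, 1)
    if weights is None:
        w = np.full(x.shape, 1.0 / x.size)
    else:
        w = np.asarray(weights, dtype=float).ravel()
    order = np.argsort(x)
    x, w = x[order], w[order]
    right = np.cumsum(w)  # F_i
    left = right - w  # F_{i-1}
    # g(t) = F^{-1}(t) - t equals x_i - t on [F_{i-1}, F_i); integrate g^2 exactly.
    int_g2 = np.sum(((right - x) ** 3 - (left - x) ** 3) / 3.0)
    mean_g = np.sum(w * x) - 0.5  # integral of g over [0, 1]
    # The optimal rotation removes the mean: min_c int (g - c)^2 = int g^2 - (int g)^2.
    return int_g2 - mean_g ** 2


print(w2_unif_circle_sketch([0.0, 0.2, 0.4]))  # about 0.0211111
print(w2_unif_circle_sketch([0.0]))  # 1/12 for a single Dirac mass
```

The single-Dirac case is a useful sanity check: the squared geodesic distance from a point to a uniform point on a circle of length 1 averages to 1/12.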
References
Bonet, C., Berg, P., Courty, N., Septier, F., Drumetz, L., & Pham, M. T. (2023). Spherical sliced-wasserstein. International Conference on Learning Representations.
Solve the entropic regularization optimal transport problem and return the OT matrix
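The function solves the following optimization problem:
\[\gamma = \mathop{\arg\min}_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F + \mathrm{reg}\cdot\Omega(\gamma)\]
\[\text{s.t.} \quad \gamma \mathbf{1} = \mathbf{a}, \quad \gamma^T \mathbf{1} = \mathbf{b}, \quad \gamma \geq 0\]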
where :
\(\mathbf{M}\) is the (dim_a, dim_b) metric cost matrix
\(\Omega\) is the entropic regularization term \(\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})\)
\(\mathbf{a}\) and \(\mathbf{b}\) are source and target weights (histograms, both sum to 1)
Note
This function is backend-compatible and will work on arrays from all compatible backends.
The algorithm used for solving the problem is the Sinkhorn-Knopp matrix scaling algorithm as proposed in [2]
Choosing a Sinkhorn solver
By default, and when using a regularization parameter that is not too small, the default Sinkhorn solver should be enough. If you need to use a small regularization to get sharper OT matrices, you should use the ot.bregman.sinkhorn_stabilized() solver that will avoid numerical errors. This last solver can be very slow in practice and might not even converge to a reasonable OT matrix in a finite time. This is why ot.bregman.sinkhorn_epsilon_scaling(), which relies on iterating the value of the regularization (and using warm start), sometimes leads to better solutions. Note that the greedy version of Sinkhorn, ot.bregman.greenkhorn(), can also lead to a speedup, and the screening version of Sinkhorn, ot.bregman.screenkhorn(), aims at providing a fast approximation of the Sinkhorn problem. For use on GPU and gradient computation with a small number of iterations, we strongly recommend the ot.bregman.sinkhorn_log() solver, which does not need to check for numerical problems.
a (array-like, shape (dim_a,)) – samples weights in the source domain
b (array-like, shape (dim_b,) or ndarray, shape (dim_b, n_hists)) – samples weights in the target domain; if \(\mathbf{b}\) is a matrix, compute Sinkhorn with multiple targets and fixed \(\mathbf{M}\) (returns OT loss + dual variables in log)
M (array-like, shape (dim_a, dim_b)) – loss matrix
reg (float) – Regularization term >0
method (str) – method used for the solver, either ‘sinkhorn’, ’sinkhorn_log’, ‘greenkhorn’, ‘sinkhorn_stabilized’ or ‘sinkhorn_epsilon_scaling’; see those functions for specific parameters
numItermax (int, optional) – Max number of iterations
stopThr (float, optional) – Stop threshold on error (>0)
verbose (bool, optional) – Print information along iterations
log (bool, optional) – record log if True
warn (bool, optional) – if True, raises a warning if the algorithm doesn’t converge.
warmstart (tuple of arrays, shape (dim_a, dim_b), optional) – Initialization of dual potentials. If provided, the dual potentials should be given (that is the logarithm of the u,v sinkhorn scaling vectors)
gamma (array-like, shape (dim_a, dim_b)) – Optimal transportation matrix for the given parameters
log (dict) – log dictionary returned only if log==True in parameters
Examples
>>> import ot
>>> a=[.5, .5]
>>> b=[.5, .5]
>>> M=[[0., 1.], [1., 0.]]
>>> ot.sinkhorn(a, b, M, 1)
array([[0.36552929, 0.13447071],
[0.13447071, 0.36552929]])
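A minimal NumPy sketch of the Sinkhorn-Knopp matrix scaling iterations behind the default method='sinkhorn' follows. It assumes dense inputs and performs no log-domain stabilization; sinkhorn_knopp_sketch is a hypothetical name, not the POT implementation. On the example above it reproduces the printed plan.

```python
import numpy as np


def sinkhorn_knopp_sketch(a, b, M, reg, num_iter_max=1000, stop_thr=1e-9):
    """Hypothetical sketch of Sinkhorn-Knopp scaling (no stabilization)."""
    a, b, M = (np.asarray(arr, dtype=float) for arr in (a, b, M))
    K = np.exp(-M / reg)  # Gibbs kernel
    u = np.ones_like(a)
    v = np.ones_like(b)
    for _ in range(num_iter_max):
        v = b / (K.T @ u)  # rescale to match the column marginals
        u = a / (K @ v)  # rescale to match the row marginals
        # Rows now match a exactly; stop once the column marginals match b too.
        err = np.abs(v * (K.T @ u) - b).sum()
        if err < stop_thr:
            break
    return u[:, None] * K * v[None, :]  # gamma = diag(u) K diag(v)


gamma = sinkhorn_knopp_sketch([0.5, 0.5], [0.5, 0.5], [[0.0, 1.0], [1.0, 0.0]], 1.0)
print(np.round(gamma, 8))
```

With small reg, exp(-M / reg) underflows and the divisions above break down, which is the numerical issue the stabilized and log-domain solvers address.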
References
M. Cuturi, Sinkhorn Distances : Lightspeed Computation of Optimal Transport, Advances in Neural Information Processing Systems (NIPS) 26, 2013
Schmitzer, B. (2016). Stabilized Sparse Scaling Algorithms for Entropy Regularized Transport Problems. arXiv preprint arXiv:1610.06519.
Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016). Scaling algorithms for unbalanced transport problems. arXiv preprint arXiv:1607.05816.
Feydy, J., Séjourné, T., Vialard, F. X., Amari, S. I., Trouvé, A., & Peyré, G. (2019, April). Interpolating between optimal transport and MMD using Sinkhorn divergences. In The 22nd International Conference on Artificial Intelligence and Statistics (pp. 2681-2690). PMLR.
See also
ot.lp.emd : Unregularized OT
ot.optim.cg : General regularized OT
ot.bregman.sinkhorn_knopp : Classic Sinkhorn [2]
ot.bregman.sinkhorn_stabilized
ot.bregman.sinkhorn_epsilon_scaling
Solve the entropic regularization optimal transport problem and return the loss
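The function solves the following optimization problem:
\[\gamma^* = \mathop{\arg\min}_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F + \mathrm{reg}\cdot\Omega(\gamma)\]
\[\text{s.t.} \quad \gamma \mathbf{1} = \mathbf{a}, \quad \gamma^T \mathbf{1} = \mathbf{b}, \quad \gamma \geq 0\]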
where :
\(\mathbf{M}\) is the (dim_a, dim_b) metric cost matrix
\(\Omega\) is the entropic regularization term \(\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})\)
\(\mathbf{a}\) and \(\mathbf{b}\) are source and target weights (histograms, both sum to 1)
and returns \(\langle \gamma^*, \mathbf{M} \rangle_F\) (without the entropic contribution).
Note
This function is backend-compatible and will work on arrays from all compatible backends.
The algorithm used for solving the problem is the Sinkhorn-Knopp matrix scaling algorithm as proposed in [2]
Choosing a Sinkhorn solver
By default, and when using a regularization parameter that is not too small, the default Sinkhorn solver should be enough. If you need to use a small regularization to get sharper OT matrices, you should use the ot.bregman.sinkhorn_log() solver that will avoid numerical errors. This last solver can be very slow in practice and might not even converge to a reasonable OT matrix in a finite time. This is why ot.bregman.sinkhorn_epsilon_scaling(), which relies on iterating the value of the regularization (and using warm start), sometimes leads to better solutions. Note that the greedy version of Sinkhorn, ot.bregman.greenkhorn(), can also lead to a speedup, and the screening version of Sinkhorn, ot.bregman.screenkhorn(), aims at providing a fast approximation of the Sinkhorn problem. For use on GPU and gradient computation with a small number of iterations, we strongly recommend the ot.bregman.sinkhorn_log() solver, which does not need to check for numerical problems.
a (array-like, shape (dim_a,)) – samples weights in the source domain
b (array-like, shape (dim_b,) or ndarray, shape (dim_b, n_hists)) – samples weights in the target domain; if \(\mathbf{b}\) is a matrix, compute Sinkhorn with multiple targets and fixed \(\mathbf{M}\) (returns OT loss + dual variables in log)
M (array-like, shape (dim_a, dim_b)) – loss matrix
reg (float) – Regularization term >0
method (str) – method used for the solver, either ‘sinkhorn’, ’sinkhorn_log’ or ‘sinkhorn_stabilized’; see those functions for specific parameters
numItermax (int, optional) – Max number of iterations
stopThr (float, optional) – Stop threshold on error (>0)
verbose (bool, optional) – Print information along iterations
log (bool, optional) – record log if True
warn (bool, optional) – if True, raises a warning if the algorithm doesn’t converge.
warmstart (tuple of arrays, shape (dim_a, dim_b), optional) – Initialization of dual potentials. If provided, the dual potentials should be given (that is the logarithm of the u,v sinkhorn scaling vectors)
W (float/array-like, shape (n_hists,)) – Optimal transportation loss for the given parameters
log (dict) – log dictionary returned only if log==True in parameters
Examples
>>> import ot
>>> a=[.5, .5]
>>> b=[.5, .5]
>>> M=[[0., 1.], [1., 0.]]
>>> ot.sinkhorn2(a, b, M, 1)
0.26894142136999516
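sinkhorn2 returns only the linear term \(\langle \gamma^*, \mathbf{M} \rangle_F\). The sketch below runs the iterations in the log domain, in the spirit of the ot.bregman.sinkhorn_log() solver recommended above, and evaluates that linear term. The helper names are hypothetical and the stopping rule is an assumption, not POT's.

```python
import numpy as np


def logsumexp(x, axis):
    # Numerically stable log(sum(exp(x))) along one axis.
    x_max = np.max(x, axis=axis, keepdims=True)
    return np.squeeze(x_max, axis=axis) + np.log(np.sum(np.exp(x - x_max), axis=axis))


def sinkhorn2_log_sketch(a, b, M, reg, num_iter_max=1000, stop_thr=1e-9):
    """Hypothetical log-domain Sinkhorn returning <gamma, M>_F."""
    a, b, M = (np.asarray(arr, dtype=float) for arr in (a, b, M))
    log_a, log_b = np.log(a), np.log(b)
    f = np.zeros_like(a)  # dual potentials: reg * log of the scaling vectors
    g = np.zeros_like(b)
    for _ in range(num_iter_max):
        # gamma_ij = exp((f_i + g_j - M_ij) / reg); solve each marginal in turn.
        f = reg * (log_a - logsumexp((g[None, :] - M) / reg, axis=1))
        g = reg * (log_b - logsumexp((f[:, None] - M) / reg, axis=0))
        log_gamma = (f[:, None] + g[None, :] - M) / reg
        # Columns match b after the g-update; check the row marginals.
        if np.abs(np.exp(log_gamma).sum(axis=1) - a).sum() < stop_thr:
            break
    gamma = np.exp(log_gamma)
    # Linear part only: the entropic term is not included, as in sinkhorn2.
    return np.sum(gamma * M)


loss = sinkhorn2_log_sketch([0.5, 0.5], [0.5, 0.5], [[0.0, 1.0], [1.0, 0.0]], 1.0)
print(loss)  # about 0.2689414
```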
References
M. Cuturi, Sinkhorn Distances : Lightspeed Computation of Optimal Transport, Advances in Neural Information Processing Systems (NIPS) 26, 2013
Schmitzer, B. (2016). Stabilized Sparse Scaling Algorithms for Entropy Regularized Transport Problems. arXiv preprint arXiv:1610.06519.
Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016). Scaling algorithms for unbalanced transport problems. arXiv preprint arXiv:1607.05816.
Altschuler J., Weed J., Rigollet P. : Near-linear time approximation algorithms for optimal transport via Sinkhorn iteration, Advances in Neural Information Processing Systems (NIPS) 30, 2017
Feydy, J., Séjourné, T., Vialard, F. X., Amari, S. I., Trouvé, A., & Peyré, G. (2019, April). Interpolating between optimal transport and MMD using Sinkhorn divergences. In The 22nd International Conference on Artificial Intelligence and Statistics (pp. 2681-2690). PMLR.
See also
ot.lp.emd : Unregularized OT
ot.optim.cg : General regularized OT
ot.bregman.sinkhorn_knopp : Classic Sinkhorn [2]
ot.bregman.greenkhorn : Greenkhorn [21]
ot.bregman.sinkhorn_stabilized
Solve the entropic regularization optimal transport problem with non-convex group lasso regularization
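The function solves the following optimization problem:
\[\gamma = \mathop{\arg\min}_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F + \mathrm{reg}\cdot\Omega_e(\gamma) + \eta \, \Omega_g(\gamma)\]
\[\text{s.t.} \quad \gamma \mathbf{1} = \mathbf{a}, \quad \gamma^T \mathbf{1} = \mathbf{b}, \quad \gamma \geq 0\]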
where :
\(\mathbf{M}\) is the (ns, nt) metric cost matrix
\(\Omega_e\) is the entropic regularization term \(\Omega_e (\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})\)
\(\Omega_g\) is the group lasso regularization term \(\Omega_g(\gamma)=\sum_{i,c} \|\gamma_{i,\mathcal{I}_c}\|^{1/2}_1\) where \(\mathcal{I}_c\) are the indices of samples from class c in the source domain.
\(\mathbf{a}\) and \(\mathbf{b}\) are source and target weights (sum to 1)
The algorithm used for solving the problem is the generalized conditional gradient as proposed in [5, 7].
a (array-like (ns,)) – samples weights in the source domain
labels_a (array-like (ns,)) – labels of samples in the source domain
b (array-like (nt,)) – samples weights in the target domain
M (array-like (ns,nt)) – loss matrix
reg (float) – Regularization term for entropic regularization >0
eta (float, optional) – Regularization term for group lasso regularization >0
numItermax (int, optional) – Max number of iterations
numInnerItermax (int, optional) – Max number of iterations (inner sinkhorn solver)
stopInnerThr (float, optional) – Stop threshold on error (inner sinkhorn solver) (>0)
verbose (bool, optional) – Print information along iterations
log (bool, optional) – record log if True
gamma ((ns, nt) array-like) – Optimal transportation matrix for the given parameters
log (dict) – log dictionary returned only if log==True in parameters
References
N. Courty; R. Flamary; D. Tuia; A. Rakotomamonjy, “Optimal Transport for Domain Adaptation,” in IEEE Transactions on Pattern Analysis and Machine Intelligence , vol.PP, no.99, pp.1-1
Rakotomamonjy, A., Flamary, R., & Courty, N. (2015). Generalized conditional gradient: analysis of convergence and applications. arXiv preprint arXiv:1510.06567.
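To see why \(\Omega_g\) encourages each target sample to receive mass from a single source class, the regularizer can be evaluated directly. The sketch below reads each group as the source samples of class c sent to one target sample j, i.e. it sums \((\sum_{i \in \mathcal{I}_c} \gamma_{i,j})^{1/2}\) over classes and target columns; that reading of the indices and the helper name group_lasso_value_sketch are assumptions, not POT code.

```python
import numpy as np


def group_lasso_value_sketch(gamma, labels_a):
    """Hypothetical helper: sum over classes c and target columns j of
    sqrt(sum_{i in I_c} gamma[i, j]), i.e. the l1 norm of each group to the power 1/2."""
    gamma = np.asarray(gamma, dtype=float)
    labels_a = np.asarray(labels_a)
    total = 0.0
    for c in np.unique(labels_a):
        mass_per_column = gamma[labels_a == c].sum(axis=0)  # l1 norm of each group
        total += np.sqrt(mass_per_column).sum()
    return total


labels = [0, 0, 1]
# Both plans have marginals a = (0.25, 0.25, 0.5) and b = (0.5, 0.5).
gamma_pure = [[0.25, 0.0], [0.25, 0.0], [0.0, 0.5]]  # each column fed by one class
gamma_mixed = [[0.125, 0.125], [0.125, 0.125], [0.25, 0.25]]  # classes mixed
print(group_lasso_value_sketch(gamma_pure, labels))  # sqrt(2), about 1.414
print(group_lasso_value_sketch(gamma_mixed, labels))  # 2.0
```

Because the square root is concave, concentrating a class's mass on fewer target columns lowers the penalty, which is the behaviour the regularizer is meant to promote.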
See also
ot.lp.emd : Unregularized OT
ot.bregman.sinkhorn : Entropic regularized OT
ot.optim.cg : General regularized OT
Solve the unbalanced entropic regularization optimal transport problem and return the OT plan
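The function solves the following optimization problem:
\[\gamma = \mathop{\arg\min}_{\gamma \geq 0} \quad \langle \gamma, \mathbf{M} \rangle_F + \mathrm{reg} \cdot \mathrm{KL}(\gamma, \mathbf{c}) + \mathrm{reg_{m1}} \cdot \mathrm{KL}(\gamma \mathbf{1}, \mathbf{a}) + \mathrm{reg_{m2}} \cdot \mathrm{KL}(\gamma^T \mathbf{1}, \mathbf{b})\]
with \(\mathrm{reg_{m1}}\) and \(\mathrm{reg_{m2}}\) the marginal relaxation terms given by reg_m.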
where :
\(\mathbf{M}\) is the (dim_a, dim_b) metric cost matrix
\(\mathbf{a}\) and \(\mathbf{b}\) are source and target unbalanced distributions
\(\mathbf{c}\) is a reference distribution for the regularization
KL is the Kullback-Leibler divergence
The algorithm used for solving the problem is the generalized Sinkhorn-Knopp matrix scaling algorithm as proposed in [10, 25]
Warning
Starting from version 0.9.5, the default value has been changed to reg_type=’kl’ instead of reg_type=’entropy’. This makes the function more consistent with the literature and the other solvers. If you want to use the entropy regularization, please set reg_type=’entropy’ explicitly.
a (array-like, shape (dim_a,)) – Unnormalized histogram of dimension dim_a. If a is an empty list or array ([]), then a is set to uniform distribution.
b (array-like, shape (dim_b,)) – One or multiple unnormalized histograms of dimension dim_b. If b is an empty list or array ([]), then b is set to uniform distribution. If many, compute all the OT costs \((\mathbf{a}, \mathbf{b}_i)_i\)
M (array-like, shape (dim_a, dim_b)) – loss matrix
reg (float) – Entropy regularization term > 0
reg_m (float or indexable object of length 1 or 2) – Marginal relaxation term. If \(\mathrm{reg_{m}}\) is a scalar or an indexable object of length 1, then the same \(\mathrm{reg_{m}}\) is applied to both marginal relaxations. The entropic balanced OT can be recovered using \(\mathrm{reg_{m}}=float("inf")\). For semi-relaxed case, use either \(\mathrm{reg_{m}}=(float("inf"), scalar)\) or \(\mathrm{reg_{m}}=(scalar, float("inf"))\). If \(\mathrm{reg_{m}}\) is an array, it must have the same backend as input arrays (a, b, M).
method (str) – method used for the solver, either ‘sinkhorn’, ‘sinkhorn_stabilized’, ‘sinkhorn_translation_invariant’ or ‘sinkhorn_reg_scaling’; see those functions for specific parameters
reg_type (string, optional) –
Regularizer term. Can take two values:
Negative entropy: ‘entropy’: \(\Omega(\gamma) = \sum_{i,j} \gamma_{i,j} \log(\gamma_{i,j}) - \sum_{i,j} \gamma_{i,j}\). This is equivalent (up to a constant) to \(\Omega(\gamma) = \text{KL}(\gamma, 1_{dim_a} 1_{dim_b}^T)\).
Kullback-Leibler divergence (default): ‘kl’: \(\Omega(\gamma) = \text{KL}(\gamma, \mathbf{a} \mathbf{b}^T)\).
c (array-like, shape (dim_a, dim_b), optional (default=None)) – Reference measure for the regularization. If None, then use \(\mathbf{c} = \mathbf{a} \mathbf{b}^T\). If \(\texttt{reg_type}=\)’entropy’, then \(\mathbf{c} = 1_{dim_a} 1_{dim_b}^T\).
warmstart (tuple of arrays, shape (dim_a, dim_b), optional) – Initialization of dual potentials. If provided, the dual potentials should be given (that is the logarithm of the u, v sinkhorn scaling vectors).
numItermax (int, optional) – Max number of iterations
stopThr (float, optional) – Stop threshold on error (>0)
verbose (bool, optional) – Print information along iterations
log (bool, optional) – record log if True
if n_hists == 1 –
Optimal transportation matrix for the given parameters
log dictionary returned only if log is True
else –
the OT distance between \(\mathbf{a}\) and each of the histograms \(\mathbf{b}_i\)
log dictionary returned only if log is True
Examples
>>> import ot
>>> import numpy as np
>>> a=[.5, .5]
>>> b=[.5, .5]
>>> M=[[0., 1.], [1., 0.]]
>>> np.round(ot.sinkhorn_unbalanced(a, b, M, 1, 1), 7)
array([[0.3220536, 0.1184769],
[0.1184769, 0.3220536]])
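The generalized Sinkhorn iterations of [10] can be sketched in NumPy for the default reg_type='kl', where the reference measure is \(\mathbf{c} = \mathbf{a} \mathbf{b}^T\). Each scaling update is damped by the exponent \(\mathrm{reg_m} / (\mathrm{reg_m} + \mathrm{reg})\) coming from the KL marginal penalties. This is a hypothetical helper, not POT's implementation, and it assumes the same reg_m on both marginals; on the example above it agrees with the printed plan to about 1e-6.

```python
import numpy as np


def sinkhorn_unbalanced_sketch(a, b, M, reg, reg_m, num_iter_max=1000, stop_thr=1e-12):
    """Hypothetical generalized Sinkhorn for KL-relaxed marginals, c = a b^T."""
    a, b, M = (np.asarray(arr, dtype=float) for arr in (a, b, M))
    K = np.outer(a, b) * np.exp(-M / reg)  # kernel built on the reference c = a b^T
    fi = reg_m / (reg_m + reg)  # damping exponent from the KL marginal penalties
    u = np.ones_like(a)
    v = np.ones_like(b)
    for _ in range(num_iter_max):
        u_prev = u
        u = (a / (K @ v)) ** fi
        v = (b / (K.T @ u)) ** fi
        if np.max(np.abs(u - u_prev)) < stop_thr:
            break
    return u[:, None] * K * v[None, :]  # gamma = diag(u) K diag(v)


gamma = sinkhorn_unbalanced_sketch([0.5, 0.5], [0.5, 0.5], [[0.0, 1.0], [1.0, 0.0]], 1.0, 1.0)
print(np.round(gamma, 7))
```

As reg_m grows, fi tends to 1 and the updates reduce to the balanced Sinkhorn-Knopp scaling, consistent with recovering balanced entropic OT with reg_m=float("inf").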
References
M. Cuturi, Sinkhorn Distances : Lightspeed Computation of Optimal Transport, Advances in Neural Information Processing Systems (NIPS) 26, 2013
Schmitzer, B. (2016). Stabilized Sparse Scaling Algorithms for Entropy Regularized Transport Problems. arXiv preprint arXiv:1610.06519.
Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016). Scaling algorithms for unbalanced transport problems. arXiv preprint arXiv:1607.05816.
Frogner C., Zhang C., Mobahi H., Araya-Polo M., Poggio T. : Learning with a Wasserstein Loss, Advances in Neural Information Processing Systems (NIPS) 2015
Séjourné, T., Vialard, F. X., & Peyré, G. (2022). Faster unbalanced optimal transport: Translation invariant sinkhorn and 1-d frank-wolfe. In International Conference on Artificial Intelligence and Statistics (pp. 4995-5021). PMLR.
See also
ot.unbalanced.sinkhorn_knopp_unbalanced : Unbalanced Classic Sinkhorn [10]
ot.unbalanced.sinkhorn_stabilized_unbalanced : Unbalanced Stabilized Sinkhorn [9, 10]
ot.unbalanced.sinkhorn_reg_scaling_unbalanced : Unbalanced Sinkhorn with epsilon scaling [9, 10]
ot.unbalanced.sinkhorn_unbalanced_translation_invariant : Translation Invariant Unbalanced Sinkhorn [73]
Solve the entropic regularization unbalanced optimal transport problem and return the cost
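The function solves the following optimization problem:
\[W = \min_{\gamma \geq 0} \quad \langle \gamma, \mathbf{M} \rangle_F + \mathrm{reg} \cdot \mathrm{KL}(\gamma, \mathbf{c}) + \mathrm{reg_{m1}} \cdot \mathrm{KL}(\gamma \mathbf{1}, \mathbf{a}) + \mathrm{reg_{m2}} \cdot \mathrm{KL}(\gamma^T \mathbf{1}, \mathbf{b})\]
with \(\mathrm{reg_{m1}}\) and \(\mathrm{reg_{m2}}\) the marginal relaxation terms given by reg_m.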
where :
\(\mathbf{M}\) is the (dim_a, dim_b) metric cost matrix
\(\mathbf{a}\) and \(\mathbf{b}\) are source and target unbalanced distributions
\(\mathbf{c}\) is a reference distribution for the regularization
KL is the Kullback-Leibler divergence
The algorithm used for solving the problem is the generalized Sinkhorn-Knopp matrix scaling algorithm as proposed in [10, 25]
Warning
Starting from version 0.9.5, the default value has been changed to reg_type=’kl’ instead of reg_type=’entropy’. This makes the function more consistent with the literature and the other solvers. If you want to use the entropy regularization, please set reg_type=’entropy’ explicitly.
a (array-like, shape (dim_a,)) – Unnormalized histogram of dimension dim_a. If a is an empty list or array ([]), then a is set to uniform distribution.
b (array-like, shape (dim_b,)) – One or multiple unnormalized histograms of dimension dim_b. If b is an empty list or array ([]), then b is set to uniform distribution. If many, compute all the OT costs \((\mathbf{a}, \mathbf{b}_i)_i\)
M (array-like, shape (dim_a, dim_b)) – loss matrix
reg (float) – Entropy regularization term > 0
reg_m (float or indexable object of length 1 or 2) – Marginal relaxation term. If \(\mathrm{reg_{m}}\) is a scalar or an indexable object of length 1, then the same \(\mathrm{reg_{m}}\) is applied to both marginal relaxations. The entropic balanced OT can be recovered using \(\mathrm{reg_{m}}=float("inf")\). For semi-relaxed case, use either \(\mathrm{reg_{m}}=(float("inf"), scalar)\) or \(\mathrm{reg_{m}}=(scalar, float("inf"))\). If \(\mathrm{reg_{m}}\) is an array, it must have the same backend as input arrays (a, b, M).
method (str) – method used for the solver, either ‘sinkhorn’, ‘sinkhorn_stabilized’, ‘sinkhorn_translation_invariant’ or ‘sinkhorn_reg_scaling’; see those functions for specific parameters
reg_type (string, optional) –
Regularizer term. Can take two values:
Negative entropy: ‘entropy’: \(\Omega(\gamma) = \sum_{i,j} \gamma_{i,j} \log(\gamma_{i,j}) - \sum_{i,j} \gamma_{i,j}\). This is equivalent (up to a constant) to \(\Omega(\gamma) = \text{KL}(\gamma, 1_{dim_a} 1_{dim_b}^T)\).
Kullback-Leibler divergence: ‘kl’: \(\Omega(\gamma) = \text{KL}(\gamma, \mathbf{a} \mathbf{b}^T)\).
c (array-like, shape (dim_a, dim_b), optional (default=None)) – Reference measure for the regularization. If None, then use \(\mathbf{c} = \mathbf{a} \mathbf{b}^T\). If \(\texttt{reg_type}=\)’entropy’, then \(\mathbf{c} = 1_{dim_a} 1_{dim_b}^T\).
warmstart (tuple of arrays, shape (dim_a, dim_b), optional) – Initialization of dual potentials. If provided, the dual potentials should be given (that is the logarithm of the u,v sinkhorn scaling vectors).
returnCost (string, optional (default = "linear")) – If returnCost = “linear”, then return the linear part of the unbalanced OT loss. If returnCost = “total”, then return the total unbalanced OT loss.
numItermax (int, optional) – Max number of iterations
stopThr (float, optional) – Stop threshold on error (>0)
verbose (bool, optional) – Print information along iterations
log (bool, optional) – record log if True
ot_cost (array-like, shape (n_hists,)) – the OT cost between \(\mathbf{a}\) and each of the histograms \(\mathbf{b}_i\)
log (dict) – log dictionary returned only if log is True
Examples
>>> import ot
>>> import numpy as np
>>> a=[.5, .10]
>>> b=[.5, .5]
>>> M=[[0., 1.],[1., 0.]]
>>> np.round(ot.unbalanced.sinkhorn_unbalanced2(a, b, M, 1., 1.), 8)
0.19600125
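returnCost separates the linear part \(\langle \gamma, \mathbf{M} \rangle_F\) from the total objective. The sketch below evaluates both for a given plan. It assumes reg_type='kl' (reference \(\mathbf{c} = \mathbf{a}\mathbf{b}^T\)), the generalized divergence \(\mathrm{KL}(p, q) = \sum p \log(p/q) - p + q\), and the same reg_m on both marginals. The helper names are hypothetical, and whether POT's ‘total’ matches this evaluation term by term is an assumption.

```python
import numpy as np


def generalized_kl(p, q):
    # KL(p, q) = sum p log(p / q) - p + q, with the convention 0 log 0 = 0 (q > 0).
    p, q = np.asarray(p, dtype=float), np.asarray(q, dtype=float)
    safe_p = np.where(p > 0, p, 1.0)
    return np.sum(np.where(p > 0, p * np.log(safe_p / q), 0.0) - p + q)


def uot_costs_sketch(gamma, M, a, b, reg, reg_m):
    """Hypothetical helper returning (linear, total) for a given plan gamma."""
    gamma, M, a, b = (np.asarray(arr, dtype=float) for arr in (gamma, M, a, b))
    linear = np.sum(gamma * M)
    total = (
        linear
        + reg * generalized_kl(gamma, np.outer(a, b))
        + reg_m * generalized_kl(gamma.sum(axis=1), a)
        + reg_m * generalized_kl(gamma.sum(axis=0), b)
    )
    return linear, total


a = b = [0.5, 0.5]
M = [[0.0, 1.0], [1.0, 0.0]]
# Independent coupling: all KL terms vanish, so linear == total == 0.5.
lin_indep, tot_indep = uot_costs_sketch(np.outer(a, b), M, a, b, 1.0, 1.0)
# Empty plan: no transport cost, but every KL term pays the full reference mass.
lin_zero, tot_zero = uot_costs_sketch(np.zeros((2, 2)), M, a, b, 1.0, 1.0)
print(lin_indep, tot_indep, lin_zero, tot_zero)
```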
References
M. Cuturi, Sinkhorn Distances : Lightspeed Computation of Optimal Transport, Advances in Neural Information Processing Systems (NIPS) 26, 2013
Schmitzer, B. (2016). Stabilized Sparse Scaling Algorithms for Entropy Regularized Transport Problems. arXiv preprint arXiv:1610.06519.
Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016). Scaling algorithms for unbalanced transport problems. arXiv preprint arXiv:1607.05816.
Frogner C., Zhang C., Mobahi H., Araya-Polo M., Poggio T. : Learning with a Wasserstein Loss, Advances in Neural Information Processing Systems (NIPS) 2015
Séjourné, T., Vialard, F. X., & Peyré, G. (2022). Faster unbalanced optimal transport: Translation invariant sinkhorn and 1-d frank-wolfe. In International Conference on Artificial Intelligence and Statistics (pp. 4995-5021). PMLR.
See also
ot.unbalanced.sinkhorn_knopp : Unbalanced Classic Sinkhorn [10]
ot.unbalanced.sinkhorn_stabilized : Unbalanced Stabilized Sinkhorn [9, 10]
ot.unbalanced.sinkhorn_reg_scaling : Unbalanced Sinkhorn with epsilon scaling [9, 10]
ot.unbalanced.sinkhorn_unbalanced_translation_invariant : Translation Invariant Unbalanced Sinkhorn [73]
Compute the Sliced Unbalanced Optimal Transport (SUOT) between two empirical distributions. The 1D UOT problem is computed with KL regularization and solved with a Frank-Wolfe algorithm, see [82].
The Sliced Unbalanced Optimal Transport (SUOT) is defined as
with \(P^\theta(x)=\langle x,\theta\rangle\) and \(\lambda\) the uniform distribution on the unit sphere.
Warning
This function only works with PyTorch or JAX, as it uses automatic differentiation to solve the 1D UOT problems. It is not maintained in JAX.
X_s (ndarray, shape (n_samples_a, dim)) – samples in the source domain
X_t (ndarray, shape (n_samples_b, dim)) – samples in the target domain
reg_m (float or indexable object of length 1 or 2) – Marginal relaxation term. If reg_m is a scalar or an indexable object of length 1, then the same reg_m is applied to both marginal relaxations. The balanced OT can be recovered using reg_m=float(“inf”). For semi-relaxed case, use either reg_m=(float(“inf”), scalar) or reg_m=(scalar, float(“inf”)). If reg_m is an array, it must have the same backend as input arrays (X_s, X_t).
a (ndarray, shape (n_samples_a,), optional) – samples weights in the source domain
b (ndarray, shape (n_samples_b,), optional) – samples weights in the target domain
n_projections (int, optional) – Number of projections used for the Monte-Carlo approximation
p (float, optional (default=2)) – Power p used for computing the sliced Wasserstein
projections (shape (dim, n_projections), optional) – Projection matrix (n_projections and seed are not used in this case)
seed (int or RandomState or None, optional) – Seed used for random number generator
numItermax (int, optional) – Max number of iterations
log (bool, optional) – if True, returns the projections used and their associated UOTs and reweighted marginals.
loss (float/array-like, shape (…)) – SUOT
log (dict, optional) – If log is True, then returns a dictionary containing the projection directions used, the projected UOTs, and reweighted marginals on each slice.
References
Bonet, C., Nadjahi, K., Séjourné, T., Fatras, K., & Courty, N. (2025). Slicing Unbalanced Optimal Transport. Transactions on Machine Learning Research.
See also
ot.unbalanced.uot_1d : 1D OT problem
ot.unbalanced.unbalanced_sliced_ot : Unbalanced SOT problem
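Computes a Monte-Carlo approximation of the p-Sliced Wasserstein distance
\[\mathcal{SWD}_p(\mu, \nu) = \left(\mathbb{E}_{\theta \sim \mathcal{U}(\mathbb{S}^{d-1})}\left[\mathcal{W}_p^p(\theta_\# \mu, \theta_\# \nu)\right]\right)^{\frac{1}{p}}\]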
where :
\(\theta_\# \mu\) stands for the pushforward of the projection \(X \in \mathbb{R}^d \mapsto \langle \theta, X \rangle\)
X_s (ndarray, shape (n_samples_a, dim)) – samples in the source domain
X_t (ndarray, shape (n_samples_b, dim)) – samples in the target domain
a (ndarray, shape (n_samples_a,), optional) – samples weights in the source domain
b (ndarray, shape (n_samples_b,), optional) – samples weights in the target domain
n_projections (int, optional) – Number of projections used for the Monte-Carlo approximation
p (float, optional) – Power p used for computing the sliced Wasserstein
projections (shape (dim, n_projections), optional) – Projection matrix (n_projections and seed are not used in this case)
seed (int or RandomState or None, optional) – Seed used for random number generator
log (bool, optional) – if True, sliced_wasserstein_distance returns the projections used and their associated EMD.
scaler (None, object with .transform(), or callable, optional) –
Preprocessing applied to X_s and X_t before computing the distance. Useful for normalizing inputs when features have very different scales.
None : no preprocessing (default)
Object with .transform() method : e.g. an ot.utils.DataScaler fitted on a representative sample. This is the recommended way to get stable, consistent normalization across multiple calls (e.g. when using SWD as a loss in mini-batch training).
Callable : any function, lambda, or PyTorch transform applied directly as scaler(X_s) and scaler(X_t).
See ot.utils.DataScaler for a backend-aware scaler that supports joint fitting on multiple distributions.
cost (float) – Sliced Wasserstein Cost
log (dict, optional) – log dictionary returned only if log==True in parameters
Examples
>>> import ot
>>> import numpy as np
>>> n_samples_a = 20
>>> X = np.random.normal(0., 1., (n_samples_a, 5))
>>> ot.sliced_wasserstein_distance(X, X, seed=0)
0.0
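The Monte-Carlo estimator can be sketched in NumPy for the simple case of uniform weights and equal sample sizes, where 1D optimal transport reduces to matching sorted projections. Directions are drawn uniformly on the sphere by normalizing Gaussian vectors. sliced_wasserstein_sketch is a hypothetical name; unlike the documented function it ignores a, b, projections, scaler and log.

```python
import numpy as np


def sliced_wasserstein_sketch(X_s, X_t, n_projections=50, p=2, seed=None):
    """Hypothetical Monte-Carlo SW_p for uniform weights, n_samples_a == n_samples_b."""
    X_s, X_t = np.asarray(X_s, dtype=float), np.asarray(X_t, dtype=float)
    rng = np.random.default_rng(seed)
    dim = X_s.shape[1]
    # Uniform directions on the unit sphere: normalized Gaussian columns.
    theta = rng.normal(size=(dim, n_projections))
    theta /= np.linalg.norm(theta, axis=0, keepdims=True)
    # Equal-size uniform samples in 1D: the optimal matching pairs sorted values.
    proj_s = np.sort(X_s @ theta, axis=0)
    proj_t = np.sort(X_t @ theta, axis=0)
    w_pp = np.mean(np.abs(proj_s - proj_t) ** p, axis=0)  # W_p^p for each direction
    return np.mean(w_pp) ** (1.0 / p)  # average over directions, then p-th root


rng = np.random.default_rng(0)
X = rng.normal(size=(20, 5))
print(sliced_wasserstein_sketch(X, X, seed=0))  # 0.0, as in the example above
Y = np.arange(5.0)[:, None]
print(sliced_wasserstein_sketch(Y, Y + 2.0, seed=0))  # 2.0: in 1D every direction is +/-1
```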
References
Bonneel, Nicolas, et al. “Sliced and radon wasserstein barycenters of measures.” Journal of Mathematical Imaging and Vision 51.1 (2015): 22-45
Compute the spherical sliced-Wasserstein discrepancy.
where:
\(P^U_\# \mu\) stands for the pushforward of the projection \(\forall x\in S^{d-1},\ P^U(x) = \frac{U^Tx}{\|U^Tx\|_2}\)
The function is backend-compatible, but TensorFlow and JAX are not supported.
X_s (ndarray, shape (n_samples_a, dim)) – Samples in the source domain
X_t (ndarray, shape (n_samples_b, dim)) – Samples in the target domain
a (ndarray, shape (n_samples_a,), optional) – samples weights in the source domain
b (ndarray, shape (n_samples_b,), optional) – samples weights in the target domain
n_projections (int, optional) – Number of projections used for the Monte-Carlo approximation
p (float, optional (default=2)) – Power p used for computing the spherical sliced Wasserstein
projections (shape (n_projections, dim, 2), optional) – Projection matrix (n_projections and seed are not used in this case)
seed (int or RandomState or None, optional) – Seed used for random number generator
log (bool, optional) – if True, sliced_wasserstein_sphere returns the projections used and their associated EMD.
cost (float) – Spherical Sliced Wasserstein Cost
log (dict, optional) – log dictionary returned only if log==True in parameters
Examples
>>> import ot
>>> import numpy as np
>>> n_samples_a = 20
>>> X = np.random.normal(0., 1., (n_samples_a, 5))
>>> X = X / np.sqrt(np.sum(X**2, -1, keepdims=True))
>>> ot.sliced_wasserstein_sphere(X, X, seed=0)
0.0
References
Bonet, C., Berg, P., Courty, N., Septier, F., Drumetz, L., & Pham, M. T. (2023). Spherical sliced-wasserstein. International Conference on Learning Representations.
Compute the 2-spherical sliced-Wasserstein discrepancy with respect to the uniform distribution on the sphere.
where
\(\mu_n=\sum_{i=1}^n \alpha_i \delta_{x_i}\)
\(\nu=\mathrm{Unif}(S^{d-1})\)
X_s (ndarray, shape (n_samples_a, dim)) – Samples in the source domain
a (ndarray, shape (n_samples_a,), optional) – samples weights in the source domain
n_projections (int, optional) – Number of projections used for the Monte-Carlo approximation
projections (shape (n_projections, dim, 2), optional) – Projection matrix (n_projections and seed are not used in this case)
seed (int or RandomState or None, optional) – Seed used for random number generator
log (bool, optional) – if True, returns the projections used and their associated EMD.
cost (float) – Spherical Sliced Wasserstein Cost
log (dict, optional) – log dictionary returned only if log==True in parameters
Examples
>>> import ot
>>> import numpy as np
>>> np.random.seed(42)
>>> x0 = np.random.randn(500,3)
>>> x0 = x0 / np.sqrt(np.sum(x0**2, -1, keepdims=True))
>>> ssw = ot.sliced_wasserstein_sphere_unif(x0, seed=42)
>>> np.allclose(ssw, 0.01734, atol=1e-3)
True
References
Bonet, C., Berg, P., Courty, N., Septier, F., Drumetz, L., & Pham, M. T. (2023). Spherical sliced-wasserstein. International Conference on Learning Representations.
Solve the discrete optimal transport problem and return an OTResult object
The function solves the following general optimal transport problem
\[\min_{\mathbf{T} \geq 0} \quad \sum_{i,j} T_{i,j} M_{i,j} + \lambda_r R(\mathbf{T}) + \lambda_1 U(\mathbf{T}\mathbf{1}, \mathbf{a}) + \lambda_2 U(\mathbf{T}^T\mathbf{1}, \mathbf{b})\]
The regularization is selected with reg (\(\lambda_r\)) and reg_type. By default reg=None and there is no regularization. The unbalanced marginal penalization can be selected with unbalanced (\((\lambda_1, \lambda_2)\)) and unbalanced_type. By default unbalanced=None and the function solves the exact optimal transport problem (respecting the marginals).
M (array-like, shape (dim_a, dim_b)) – Loss matrix
a (array-like, shape (dim_a,), optional) – Samples weights in the source domain (default is uniform)
b (array-like, shape (dim_b,), optional) – Samples weights in the target domain (default is uniform)
reg (float, optional) – Regularization weight \(\lambda_r\), by default None (no reg., exact OT)
c (array-like, shape (dim_a, dim_b), optional (default=None)) – Reference measure for the regularization. If None, then use \(\mathbf{c} = \mathbf{a} \mathbf{b}^T\). If \(\texttt{reg_type}=\)’entropy’, then \(\mathbf{c} = 1_{dim_a} 1_{dim_b}^T\).
reg_type (str, optional) – Type of regularization \(R\), either “KL”, “L2” or “entropy”; by default “KL”. A tuple of functions can be provided for the general solver (see cg). This is only used when reg!=None.
unbalanced (float or indexable object of length 1 or 2) – Marginal relaxation term. If it is a scalar or an indexable object of length 1, then the same relaxation is applied to both marginal relaxations. The balanced OT can be recovered using \(unbalanced=float("inf")\). For semi-relaxed case, use either \(unbalanced=(float("inf"), scalar)\) or \(unbalanced=(scalar, float("inf"))\). If unbalanced is an array, it must have the same backend as input arrays (a, b, M).
unbalanced_type (str, optional) – Type of unbalanced penalization function \(U\) either “KL”, “L2”, “TV”, by default “KL”.
method (str, optional) – Method for solving the problem when multiple algorithms are available, default None for automatic selection.
n_threads (int, optional) – Number of OMP threads for exact OT solver, by default 1
max_iter (int, optional) – Maximum number of iterations, by default None (default values in each solver)
plan_init (array-like, shape (dim_a, dim_b), optional) – Initialization of the OT plan for iterative methods, by default None
potentials_init ((array-like(dim_a,),array-like(dim_b,)), optional) – Initialization of the OT dual potentials for iterative methods (Sinkhorn, unbalanced) or warmstart for exact EMD solver. For EMD, these potentials are used to guide initial pivots in the network simplex. By default None
tol (float, optional) – Tolerance for solution precision, by default None (default values in each solver)
verbose (bool, optional) – Print information in the solver, by default False
grad (str, optional) – Type of gradient computation, either ‘autodiff’, ‘envelope’, ‘last_step’ or ‘detach’, used only for the Sinkhorn solver. By default ‘autodiff’ provides gradients w.r.t. all outputs (plan, value, value_linear) but with an important memory cost. ‘envelope’ provides gradients only for value, and the other outputs are detached. This is useful for memory saving when only the value is needed. ‘last_step’ provides gradients only for the last iteration of the Sinkhorn solver, but provides gradients for both the OT plan and the objective values. ‘detach’ does not compute the gradients for the Sinkhorn solver.
res – Result of the optimization problem. The information can be obtained as follows:
res.plan : OT plan \(\mathbf{T}\)
res.potentials : OT dual potentials
res.value : Optimal value of the optimization problem
res.value_linear : Linear OT loss with the optimal OT plan
See OTResult for more information.
OTResult()
Notes
The following methods are available for solving the OT problems:
Classical exact OT problem [1] (default parameters):
can be solved with the following code:
res = ot.solve(M, a, b)
Entropic regularized OT [2] (when reg!=None):
can be solved with the following code:
# default is ``"KL"`` regularization (``reg_type="KL"``)
res = ot.solve(M, a, b, reg=1.0)
# or for original Sinkhorn paper formulation [2]
res = ot.solve(M, a, b, reg=1.0, reg_type='entropy')
# Use envelope theorem differentiation for memory saving
res = ot.solve(M, a, b, reg=1.0, grad='envelope') # M, a, b are torch tensors
res.value.backward() # only the value is differentiable
Note that by default the Sinkhorn solver uses automatic differentiation to compute the gradients of the values and plan. This can be changed with the grad parameter. The envelope mode computes the gradients only for the value and the other outputs are detached. This is useful for memory saving when only the gradient of value is needed.
Quadratic regularized OT [17] (when reg!=None and reg_type="L2"):
can be solved with the following code:
res = ot.solve(M, a, b, reg=1.0, reg_type='L2')
Unbalanced OT [41] (when unbalanced!=None):
can be solved with the following code:
# default is ``"KL"``
res = ot.solve(M, a, b, unbalanced=1.0)
# quadratic unbalanced OT
res = ot.solve(M, a, b, unbalanced=1.0, unbalanced_type='L2')
# TV = partial OT
res = ot.solve(M, a, b, unbalanced=1.0, unbalanced_type='TV')
Regularized unbalanced OT [34] (when unbalanced!=None and reg!=None):
can be solved with the following code:
# default is ``"KL"`` for both
res = ot.solve(M, a, b, reg=1.0, unbalanced=1.0)
# quadratic unbalanced OT with KL regularization
res = ot.solve(M, a, b, reg=1.0, unbalanced=1.0, unbalanced_type='L2')
# both quadratic
res = ot.solve(M, a, b, reg=1.0, reg_type='L2', unbalanced=1.0, unbalanced_type='L2')
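The `grad='envelope'` option relies on the envelope theorem: at the optimum, the derivative of the regularized OT value with respect to \(\mathbf{M}\) is the optimal plan \(\mathbf{T}\) itself, so no backpropagation through the Sinkhorn iterations is needed. The NumPy sketch below checks this numerically by comparing a central finite difference of the value with the corresponding entry of the plan. `entropic_value` is a hypothetical helper, not part of POT, and it assumes reg_type="KL" with reference \(\mathbf{a}\mathbf{b}^T\).

```python
import numpy as np

def entropic_value(M, a, b, reg, n_iter=5000):
    """Return (<T, M> + reg * KL(T | a b^T), T) at the Sinkhorn optimum (hypothetical helper)."""
    ref = np.outer(a, b)
    K = ref * np.exp(-M / reg)
    u, v = np.ones_like(a), np.ones_like(b)
    for _ in range(n_iter):
        u = a / (K @ v)
        v = b / (K.T @ u)
    T = u[:, None] * K * v[None, :]
    kl = np.sum(T * np.log(T / ref) - T + ref)
    return np.sum(T * M) + reg * kl, T

# Envelope theorem check: d value / d M[i, j] should equal T[i, j].
rng = np.random.default_rng(1)
M = rng.random((3, 4))
a, b = np.full(3, 1 / 3), np.full(4, 1 / 4)
value, T = entropic_value(M, a, b, reg=0.5)
h = 1e-5
E = np.zeros_like(M)
E[0, 1] = h
finite_diff = (entropic_value(M + E, a, b, 0.5)[0] - entropic_value(M - E, a, b, 0.5)[0]) / (2 * h)
print(finite_diff, T[0, 1])  # the two numbers agree to several decimals
```

This is why the envelope mode can return a differentiable value while detaching the plan: the gradient of the value only needs the plan computed in the forward pass.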
References
Bonneel, N., Van De Panne, M., Paris, S., & Heidrich, W. (2011, December). Displacement interpolation using Lagrangian mass transport. In ACM Transactions on Graphics (TOG) (Vol. 30, No. 6, p. 158). ACM.
M. Cuturi, Sinkhorn Distances : Lightspeed Computation of Optimal Transport, Advances in Neural Information Processing Systems (NIPS) 26, 2013
Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016). Scaling algorithms for unbalanced transport problems. arXiv preprint arXiv:1607.05816.
Blondel, M., Seguy, V., & Rolet, A. (2018). Smooth and Sparse Optimal Transport. Proceedings of the Twenty-First International Conference on Artificial Intelligence and Statistics (AISTATS).
Feydy, J., Séjourné, T., Vialard, F. X., Amari, S. I., Trouvé, A., & Peyré, G. (2019, April). Interpolating between optimal transport and MMD using Sinkhorn divergences. In The 22nd International Conference on Artificial Intelligence and Statistics (pp. 2681-2690). PMLR.
Chapel, L., Flamary, R., Wu, H., Févotte, C., and Gasso, G. (2021). Unbalanced optimal transport through non-negative penalized linear regression. NeurIPS.
Solve the discrete OT barycenter problem over source distributions optimizing the barycenter support using Block-Coordinate Descent.
The function solves the following general OT barycenter problem
where the cost matrices \(\mathbf{M}^{(k)}\) from each input distribution \((\mathbf{X}^{(k)}, \mathbf{a}^{(k)})\) to the barycenter domain are computed as \(M^{(k)}_{i,j} = d(x^{(k)}_i,x_j)\), where \(d\) is a metric (by default the squared Euclidean distance). For common metrics, the barycenter update is computed in closed form. For balanced OT, the metric parameter can also be any callable function, or list of functions, that computes the distance from an input to the barycenter. In this case, the barycenter is updated by gradient descent using the provided metric(s) and the optimal transport plan(s) at each iteration. The barycenter probability weights are fixed to \(\mathbf{b}\).
The regularization is selected with reg (\(\lambda_r\)) and reg_type. By default reg=None and there is no regularization. The unbalanced marginal penalization can be selected with unbalanced (\(\lambda_u\)) and unbalanced_type. By default unbalanced=None and the function solves the exact optimal transport problem (respecting the marginals).
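For the default squared Euclidean metric, the block-coordinate descent described above alternates two steps: compute an exact OT plan from each source to the current barycenter support, then move each barycenter point to the weighted average of the source points it receives. The NumPy sketch below illustrates this under strong simplifying assumptions (uniform weights, every cloud of the same size n, exact OT by brute force over permutations, barycentric weights w summing to 1). `exact_plan_uniform` and `bcd_barycenter` are hypothetical helpers, not POT functions.

```python
import itertools
import numpy as np

def exact_plan_uniform(Xs, Xb):
    """Exact OT plan between two uniform clouds of equal size n, by brute force.

    With uniform weights and n == m, an optimal plan can be taken to be a
    permutation matrix divided by n, so every permutation is tried (tiny n only).
    """
    n = len(Xs)
    M = ((Xs[:, None, :] - Xb[None, :, :]) ** 2).sum(axis=-1)  # squared Euclidean cost
    rows = np.arange(n)
    best = min(itertools.permutations(range(n)), key=lambda p: M[rows, list(p)].sum())
    T = np.zeros((n, n))
    T[rows, list(best)] = 1.0 / n
    return T, float(np.sum(T * M))

def bcd_barycenter(X_list, w, X_init, n_iter=10):
    """Alternate exact OT plans and the closed-form squared Euclidean update."""
    Xb = X_init.copy()
    b = np.full(len(Xb), 1.0 / len(Xb))  # fixed, uniform barycenter weights
    losses = []
    for _ in range(n_iter):
        # Step 1: with the support fixed, solve one OT problem per source.
        results = [exact_plan_uniform(Xk, Xb) for Xk in X_list]
        losses.append(sum(wk * val for wk, (T, val) in zip(w, results)))
        # Step 2: with the plans fixed, x_j = (1 / b_j) * sum_k w_k sum_i T^(k)_ij x^(k)_i
        # (valid because the weights w sum to 1).
        Xb = sum(wk * T.T @ Xk for wk, (T, _val), Xk in zip(w, results, X_list)) / b[:, None]
    return Xb, losses
```

Because each step minimizes the objective over one block of variables while the other is held fixed, the recorded loss never increases from one iteration to the next.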
X_a_list (list of array-like, shape (n_samples_k, dim)) – List of the sample arrays of the N source distributions
n (int) – Number of samples in the barycenter domain
a_list (list of array-like, shape (n_samples_k,), optional) – List of sample weights for each source distribution (default is uniform)
w (list of array-like, shape (N,), optional) – Barycentric weights of the N source distributions (default is uniform)
X_b_init (array-like, shape (n, dim), optional) – Initialization of the barycenter samples (default is Gaussian random sampling)
b (array-like, shape (n,), optional) – Barycenter weights (default is uniform)
metric (str, callable or list of callables, optional) – Metric used to compute the cost matrix, by default “sqeuclidean”. It can be a list of N callables (bary, source), one per source distribution, to use a different metric for each source distribution. In this case, the barycenter is updated by gradient descent using the provided metric(s) and the optimal transport plan(s) at each iteration. If a single callable is provided, the same cost function is used for all source distributions.
reg (float, optional) – Regularization weight \(\lambda_r\), by default None (no reg., exact OT)
c (array-like, shape (dim_a, dim_b), optional (default=None)) – Reference measure for the regularization. If None, then use \(\mathbf{c} = \mathbf{a}^{(k)} \mathbf{b}^T\). If \(\texttt{reg_type}=\)’entropy’, then \(\mathbf{c} = 1_{|a^{(k)}|} 1_{|b|}^T\).
reg_type (str, optional) – Type of regularization \(R\) either “KL”, “L2”, “entropy”, by default “KL”
unbalanced (float or indexable object of length 1 or 2) – Marginal relaxation term. If it is a scalar or an indexable object of length 1, the same relaxation is applied to both marginals. Balanced OT is recovered with \(unbalanced=float("inf")\). For the semi-relaxed case, use either \(unbalanced=(float("inf"), scalar)\) or \(unbalanced=(scalar, float("inf"))\). If unbalanced is an array, it must have the same backend as the input arrays (a, b, M).
unbalanced_type (str, optional) – Type of unbalanced penalization function \(U\) either “KL”, “L2”, “TV”, by default “KL”
lazy (bool, optional) – Return an OTResultlazy object to reduce memory cost when True, by default False
method (str, optional) – Method for solving the problem; this can be used to select the solver for unbalanced problems (see ot.solve), or to select a specific large scale solver.
auto_bary_method (str, optional) – For balanced OT with callable metric functions, the barycenter update method: either ‘L2_barycentric_proj’ (default) for Euclidean barycentric projection, or ‘true_fixed_point’ for fixed-point iterates using the North West Corner multi-marginal gluing method.
warmstart (bool, optional) – Use the previous OT or potentials as initialization for the next inner solver iteration, by default False.
stopping_criterion (str, optional) – Stopping criterion for the outer loop of the BCD solver, by default ‘loss’. Either ‘loss’ to monitor the optimized objective or ‘bary’ to monitor variations of the barycenter w.r.t. the Frobenius norm.
max_iter_bary (int, optional) – Maximum number of iterations for the outer loop of the BCD solver, by default 1000.
tol_bary (float, optional) – Tolerance for solution precision of the barycenter problem, by default 1e-5.
random_state (int, optional) – Random seed for the initialization of the barycenter samples, by default 0. Only used if X_b_init is None.
verbose (bool, optional) – Print information in the solver, by default False
kwargs (optional) – Additional parameters for the inner solver (see ot.solve_sample and ot.lp.free_support_barycenter_generic_costs)
res – Result of the optimization problem. The information can be obtained as follows:
res.X : Barycenter samples
res.b : Barycenter weights
res.value : Optimal value of the optimization problem
res.value_linear : Linear OT loss with the optimal OT plan
res.list_res : List of OTResult for each inner OT problem (one per source distribution)
res.log : log of the optimization process (if log=True)
See BaryResult for more information.
Notes
The following methods are available for solving barycenter problems with respect to these inner OT problems:
Classical exact OT problem [1] (default parameters):
can be solved with the following code for various cost metrics between the source distributions and the barycenter:
# for squared Euclidean cost, where closed-form solutions are used to update the barycenter
res = ot.solve_bary_sample([x1, x2], n, [a1, a2], w, metric='sqeuclidean')
# for uniform sample weights and barycentric weights (the defaults)
res = ot.solve_bary_sample([x1, x2], n, metric='sqeuclidean')
# for other cost functions, where the barycenter is updated with gradient descent using PyTorch
# refer to the documentation and examples for more details.
Entropic regularized OT [2] (when reg!=None):
can be solved with the following code:
# default is ``"KL"`` regularization (``reg_type="KL"``)
res = ot.solve_bary_sample([x1, x2], n, [a1, a2], w, reg=1.0)
# or for original Sinkhorn paper formulation [2]
res = ot.solve_bary_sample([x1, x2], n, [a1, a2], w, reg=1.0, reg_type='entropy')
Quadratic regularized OT [17] (when reg!=None and reg_type="L2"):
can be solved with the following code:
res = ot.solve_bary_sample([x1, x2], n, [a1, a2], w, reg=1.0, reg_type='L2')
Unbalanced OT [41] (when unbalanced!=None):
can be solved with the following code:
# default is ``"KL"``
res = ot.solve_bary_sample([x1, x2], n, [a1, a2], w, unbalanced=1.0)
# quadratic unbalanced OT
res = ot.solve_bary_sample([x1, x2], n, [a1, a2], w, unbalanced=1.0, unbalanced_type='L2')
# TV = partial OT
res = ot.solve_bary_sample([x1, x2], n, [a1, a2], w, unbalanced=1.0, unbalanced_type='TV')
Regularized unbalanced OT [34] (when unbalanced!=None and reg!=None):
can be solved with the following code:
# default is ``"KL"`` for both
res = ot.solve_bary_sample([x1, x2], n, [a1, a2], w, reg=1.0, unbalanced=1.0)
# quadratic unbalanced OT with KL regularization
res = ot.solve_bary_sample([x1, x2], n, [a1, a2], w, reg=1.0, unbalanced=1.0, unbalanced_type='L2')
# both quadratic
res = ot.solve_bary_sample([x1, x2], n, [a1, a2], w, reg=1.0, reg_type='L2', unbalanced=1.0, unbalanced_type='L2')
References
Cuturi, Marco, and Arnaud Doucet. “Fast computation of Wasserstein barycenters.” International Conference on Machine Learning. 2014.
Álvarez-Esteban, Pedro C., et al. “A fixed-point approach to barycenters in Wasserstein space.” Journal of Mathematical Analysis and Applications 441.2 (2016): 744-762.
Tanguy, Eloi and Delon, Julie and Gozlan, Nathaël (2024). Computing barycenters of Measures for Generic Transport Costs. arXiv preprint 2501.04016.
Batched version of ot.solve: use it to solve many entropic OT problems in parallel.
M (array-like, shape (B, ns, nt)) – Cost matrix
reg (float) – Regularization parameter for entropic regularization
a (array-like, shape (B, ns)) – Source distribution (optional). If None, uniform distribution is used.
b (array-like, shape (B, nt)) – Target distribution (optional). If None, uniform distribution is used.
max_iter (int) – Maximum number of iterations
tol (float) – Tolerance for convergence
solver (str) – Solver to use, either ‘log_sinkhorn’ or ‘sinkhorn’. Default is “log_sinkhorn” which is more stable.
reg_type (str, optional) – Type of regularization \(R\), either “KL” or “entropy”. Default is “entropy”.
grad (str, optional) – Type of gradient computation, either ‘autodiff’, ‘envelope’, ‘last_step’ or ‘detach’, used only for the Sinkhorn solver. By default, ‘autodiff’ provides gradients w.r.t. all outputs (plan, value, value_linear) but at a significant memory cost. ‘envelope’ provides gradients only for value, and the other outputs are detached; this saves memory when only the value is needed. ‘last_step’ backpropagates only through the last iteration of the Sinkhorn solver, but provides gradients for both the OT plan and the objective values. ‘detach’ does not compute any gradients for the Sinkhorn solver.
res – Result of the optimization problem. The information can be obtained as follows:
res.plan : OT plan \(\mathbf{T}\)
res.potentials : OT dual potentials
res.value : Optimal value of the optimization problem
res.value_linear : Linear OT loss with the optimal OT plan
See OTResult for more information.
OTResult()
Examples
>>> import numpy as np
>>> from ot.batch import solve_batch, dist_batch
>>> X = np.random.randn(5, 10, 3) # 5 batches of 10 samples in 3D
>>> Y = np.random.randn(5, 15, 3) # 5 batches of 15 samples in 3D
>>> M = dist_batch(X, Y, metric="euclidean") # Compute cost matrices
>>> reg = 0.1
>>> result = solve_batch(M, reg)
>>> result.plan.shape # Optimal transport plans for each batch
(5, 10, 15)
>>> result.value.shape # Optimal transport values for each batch
(5,)
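Batching amounts to running the same Sinkhorn updates with a leading batch axis, so all B problems advance together in a few vectorized operations. The following NumPy sketch of a log-domain Sinkhorn is written in the spirit of the default ‘log_sinkhorn’ solver but is not its implementation: `log_sinkhorn_batch` is a hypothetical name, uniform marginals are assumed when a or b is None, and only the plans and the linear cost \(\langle \mathbf{M}, \mathbf{T} \rangle\) are returned.

```python
import numpy as np

def _logsumexp(x, axis):
    """Numerically stable log(sum(exp(x))) along one axis."""
    m = x.max(axis=axis, keepdims=True)
    return (m + np.log(np.exp(x - m).sum(axis=axis, keepdims=True))).squeeze(axis)

def log_sinkhorn_batch(M, reg, a=None, b=None, n_iter=1000):
    """Entropic OT for a batch of cost matrices M of shape (B, ns, nt) (hypothetical helper)."""
    B, ns, nt = M.shape
    a = np.full((B, ns), 1.0 / ns) if a is None else a
    b = np.full((B, nt), 1.0 / nt) if b is None else b
    f, g = np.zeros((B, ns)), np.zeros((B, nt))  # dual potentials, one row per problem
    for _ in range(n_iter):
        # f_i = reg * log a_i - reg * LSE_j((g_j - M_ij) / reg), for every batch element at once
        f = reg * np.log(a) - reg * _logsumexp((g[:, None, :] - M) / reg, axis=2)
        g = reg * np.log(b) - reg * _logsumexp((f[:, :, None] - M) / reg, axis=1)
    T = np.exp((f[:, :, None] + g[:, None, :] - M) / reg)  # plans, shape (B, ns, nt)
    value_linear = np.sum(T * M, axis=(1, 2))  # <M, T> per problem, shape (B,)
    return T, value_linear
```

Working with log-potentials rather than scalings avoids the overflow and underflow of \(\exp(-\mathbf{M}/\lambda_r)\) for small regularization, which is the usual reason a log-domain solver is the more stable default.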
See also
ot.batch.dist_batch: batched computation of the cost matrices M.
ot.solve: non-batched version of the OT solver.
Solve the discrete (Fused) Gromov-Wasserstein problem and return an OTResult object
The function solves the following optimization problem:
The regularization is selected with reg (\(\lambda_r\)) and reg_type. By default reg=None and there is no regularization. The unbalanced marginal penalization can be selected with unbalanced (\(\lambda_u\)) and unbalanced_type. By default unbalanced=None and the function solves the exact optimal transport problem (respecting the marginals).
Ca (array-like, shape (dim_a, dim_a)) – Cost matrix in the source domain
Cb (array-like, shape (dim_b, dim_b)) – Cost matrix in the target domain
M (array-like, shape (dim_a, dim_b), optional) – Linear cost matrix for Fused Gromov-Wasserstein (default is None).
a (array-like, shape (dim_a,), optional) – Samples weights in the source domain (default is uniform)
b (array-like, shape (dim_b,), optional) – Samples weights in the target domain (default is uniform)
loss (str, optional) – Type of loss function, either "L2" or "KL", by default "L2"
symmetric (bool, optional) – Whether to use the symmetric version of the Gromov-Wasserstein problem. By default None, which tests whether the matrices are symmetric; set it to True or False to skip the test.
reg (float, optional) – Regularization weight \(\lambda_r\), by default None (no reg., exact OT)
reg_type (str, optional) – Type of regularization \(R\), by default “entropy” (only used when reg!=None)
alpha (float, optional) – Weight of the quadratic term (alpha*Gromov) and the linear term ((1-alpha)*Wass) in the Fused Gromov-Wasserstein problem. Not used for the Gromov problem (when M is not provided). By default alpha=None, which corresponds to alpha=1 for the Gromov problem (M==None) and alpha=0.5 for the Fused Gromov-Wasserstein problem (M!=None).
unbalanced (float, optional) – Unbalanced penalization weight \(\lambda_u\), by default None (balanced OT). Not implemented yet for “KL” unbalanced penalization function \(U\). Corresponds to the total transport mass for partial OT.
unbalanced_type (str, optional) – Type of unbalanced penalization function \(U\) either “KL”, “semirelaxed”, “partial”, by default “KL”.
n_threads (int, optional) – Number of OMP threads for exact OT solver, by default 1
method (str, optional) – Method for solving the problem when multiple algorithms are available, default None for automatic selection.
max_iter (int, optional) – Maximum number of iterations, by default None (uses each solver's default value)
plan_init (array-like, shape (dim_a, dim_b), optional) – Initialization of the OT plan for iterative methods, by default None
tol (float, optional) – Tolerance for solution precision, by default None (uses each solver's default value)
verbose (bool, optional) – Print information in the solver, by default False
res – Result of the optimization problem. The information can be obtained as follows:
res.plan : OT plan \(\mathbf{T}\)
res.potentials : OT dual potentials
res.value : Optimal value of the optimization problem
res.value_linear : Linear OT loss with the optimal OT plan
res.value_quad : Quadratic (GW) part of the OT loss with the optimal OT plan
See OTResult for more information.
OTResult()
Notes
The following methods are available for solving the Gromov-Wasserstein problem:
Classical Gromov-Wasserstein (GW) problem [3] (default parameters):
can be solved with the following code:
res = ot.solve_gromov(Ca, Cb) # uniform weights
res = ot.solve_gromov(Ca, Cb, a=a, b=b) # given weights
res = ot.solve_gromov(Ca, Cb, loss='KL') # KL loss
plan = res.plan # GW plan
value = res.value # GW value
Fused Gromov-Wasserstein (FGW) problem [24] (when M!=None):
can be solved with the following code:
res = ot.solve_gromov(Ca, Cb, M) # uniform weights, alpha=0.5 (default)
res = ot.solve_gromov(Ca, Cb, M, a=a, b=b, alpha=0.1) # given weights and alpha
plan = res.plan # FGW plan
loss_linear_term = res.value_linear # Wasserstein part of the loss
loss_quad_term = res.value_quad # Gromov part of the loss
loss = res.value # FGW value
Regularized (Fused) Gromov-Wasserstein (GW) problem [12] (when reg!=None):
can be solved with the following code:
res = ot.solve_gromov(Ca, Cb, reg=1.0) # GW entropy regularization (default)
res = ot.solve_gromov(Ca, Cb, M, a=a, b=b, reg=10, alpha=0.1) # FGW with entropy
plan = res.plan # FGW plan
loss_linear_term = res.value_linear # Wasserstein part of the loss
loss_quad_term = res.value_quad # Gromov part of the loss
loss = res.value # FGW value (including regularization)
Semi-relaxed (Fused) Gromov-Wasserstein (GW) [48] (when unbalanced_type='semirelaxed'):
can be solved with the following code:
res = ot.solve_gromov(Ca, Cb, unbalanced_type='semirelaxed') # semirelaxed GW
res = ot.solve_gromov(Ca, Cb, unbalanced_type='semirelaxed', reg=1) # entropic semirelaxed GW
res = ot.solve_gromov(Ca, Cb, M, unbalanced_type='semirelaxed', alpha=0.1) # semirelaxed FGW
plan = res.plan # FGW plan
right_marginal = res.marginal_b # right marginal of the plan
Partial (Fused) Gromov-Wasserstein (GW) problem [29] (when unbalanced_type='partial'):
can be solved with the following code:
res = ot.solve_gromov(Ca, Cb, unbalanced_type='partial', unbalanced=0.8) # partial GW with m=0.8
res = ot.solve_gromov(Ca, Cb, M, unbalanced_type='partial', unbalanced=0.8, alpha=0.5) # partial FGW with m=0.8
References
Mémoli, F. (2011). Gromov–Wasserstein distances and the metric approach to object matching. Foundations of computational mathematics, 11(4), 417-487.
Gabriel Peyré, Marco Cuturi, and Justin Solomon (2016), Gromov-Wasserstein averaging of kernel and distance matrices International Conference on Machine Learning (ICML).
Vayer, T., Chapel, L., Flamary, R., Tavenard, R. and Courty, N. (2019). Optimal Transport for structured data with application on graphs Proceedings of the 36th International Conference on Machine Learning (ICML).
Cédric Vincent-Cuaz, Rémi Flamary, Marco Corneli, Titouan Vayer, Nicolas Courty (2022). Semi-relaxed Gromov-Wasserstein divergence and applications on graphs. International Conference on Learning Representations (ICLR), 2022.
Chapel, L., Alaya, M., Gasso, G. (2020). Partial Optimal Transport with Applications on Positive-Unlabeled Learning, Advances in Neural Information Processing Systems (NeurIPS), 2020.
Solves a batch of Gromov-Wasserstein optimal transport problems using proximal gradient [12, 81]. For each problem in the batch, solves:
If \(\mathbf{M}\) and \(\alpha\) are given, solves the more general fused Gromov-Wasserstein problem:
Writing the objective as \((1-\alpha) \langle \mathbf{M}, \mathbf{T} \rangle + \alpha \langle \mathcal{L} \otimes \mathbf{T}, \mathbf{T} \rangle\), the solver uses proximal gradient descent where each iteration is:
This can be rewritten as:
where \(H(\mathbf{T})\) is the entropy of \(\mathbf{T}\). Thus each iteration can be solved using the Bregman projection solver implemented in bregman_log_projection_batch.
Note that the inner optimization problem does not need to be solved exactly. In practice, it is sufficient to set max_iter_inner to a low value (e.g. 20) and/or tol_inner to a high value (e.g. 1e-2).
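A minimal NumPy sketch of the proximal scheme for a single problem (no batch axis), assuming the square loss \(L(x, y) = (x - y)^2\) and symmetric Ca and Cb. Each outer step linearizes the GW objective at the current plan and solves a KL-proximal problem, which is a Sinkhorn projection of the kernel \(\mathbf{T}^{(k)} \odot \exp(-\nabla / \epsilon)\). The tensor product uses the standard decomposition of the square loss (Peyré et al. [12]), so the 4D tensor is never formed, and it uses the plan's actual marginals, so it stays exact even when the inner loop stops early. All function names are hypothetical and this is not the batched POT solver.

```python
import numpy as np

def tensor_product_sq(C1, C2, T):
    """(L x T)_ij = sum_kl (C1_ik - C2_jl)^2 T_kl without building the 4D tensor."""
    p, q = T.sum(axis=1), T.sum(axis=0)  # actual marginals of T, so this is exact for any T
    return (C1**2 @ p)[:, None] + (C2**2 @ q)[None, :] - 2 * C1 @ T @ C2.T

def gw_loss(C1, C2, T):
    return float(np.sum(tensor_product_sq(C1, C2, T) * T))

def sinkhorn_with_kernel(K, a, b, n_iter):
    """Bregman (KL) projection of the kernel K onto the plans with marginals a, b."""
    u, v = np.ones_like(a), np.ones_like(b)
    for _ in range(n_iter):
        u = a / (K @ v)
        v = b / (K.T @ u)
    return u[:, None] * K * v[None, :]

def proximal_gw(C1, C2, a, b, epsilon=1.0, max_iter=50, max_iter_inner=20):
    """Proximal gradient for GW with a KL proximal term (hypothetical helper)."""
    T = np.outer(a, b)  # product-coupling initialization
    for _ in range(max_iter):
        G = 2 * tensor_product_sq(C1, C2, T)  # gradient of <L x T, T> for symmetric C1, C2
        # argmin_T <G, T> + epsilon * KL(T | T_k) under the marginal constraints is a
        # Sinkhorn projection of T_k * exp(-G / epsilon). Shifting G by a constant only
        # rescales the kernel, which the scalings absorb, and helps avoid underflow.
        K = T * np.exp(-(G - G.min()) / epsilon)
        T = sinkhorn_with_kernel(K, a, b, max_iter_inner)
    return T
```

Small values of epsilon make each proximal step more aggressive but push the kernel toward underflow; a log-domain inner solver would be the natural fix in that regime.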
Ca (array-like, shape (B, n, n, d) or (B, n, n)) – Sample affinity matrices of the source distributions
Cb (array-like, shape (B, m, m, d) or (B, m, m)) – Sample affinity matrices of the target distributions
a (array-like, shape (B, n), optional) – Marginal distribution of the source samples. If None, uniform distribution is used.
b (array-like, shape (B, m), optional) – Marginal distribution of the target samples. If None, uniform distribution is used.
loss (str, optional) – Type of loss function, can be ‘sqeuclidean’ or ‘kl’ or a QuadraticMetric instance.
symmetric (bool, optional) – Whether Ca and Cb are assumed to be symmetric. If left at its default value None, a symmetry test is conducted; if set to True (resp. False), Ca and Cb are assumed symmetric (resp. asymmetric).
M (array-like, shape (dim_a, dim_b), optional) – Linear cost matrix for Fused Gromov-Wasserstein (default is None).
alpha (float, optional) – Weight of the quadratic term (alpha*Gromov) and the linear term ((1-alpha)*Wass) in the Fused Gromov-Wasserstein problem. Not used for the Gromov problem (when M is not provided). By default alpha=None, which corresponds to alpha=1 for the Gromov problem (M==None) and alpha=0.5 for the Fused Gromov-Wasserstein problem (M!=None).
epsilon (float, optional) – Regularization parameter for proximal gradient descent. Default is 1e-2.
T_init (array-like, shape (B, n, m), optional) – Initial transport plan. If None, it is initialized to uniform distribution.
max_iter (int, optional) – Maximum number of iterations for the proximal gradient descent. Default is 50.
tol (float, optional) – Tolerance for convergence of the proximal gradient descent. Default is 1e-5.
max_iter_inner (int, optional) – Maximum number of iterations for the inner Bregman projection. Default is 50.
tol_inner (float, optional) – Tolerance for convergence of the inner Bregman projection. Default is 1e-5.
grad (str, optional) – Type of gradient computation, either ‘autodiff’, ‘envelope’ or ‘detach’. ‘autodiff’ provides gradients w.r.t. all outputs (plan, value, value_linear) but at a significant memory cost. ‘envelope’ provides gradients only for (value, value_linear). ‘detach’ is the fastest option but provides no gradients. Default is ‘detach’.
assume_inner_convergence (bool, optional) – If True, assumes that the inner Bregman projection has always converged, i.e. the transport plan satisfies the marginal constraints. This enables faster computation of the tensor product but might give inaccurate results (e.g. negative values of the loss). Default is True.
res – Result of the optimization problem. The information can be obtained as follows:
res.plan : OT plan \(\mathbf{T}\)
res.potentials : OT dual potentials
res.value : Optimal value of the optimization problem
res.value_linear : Linear OT loss with the optimal OT plan
res.value_quad : Quadratic OT loss with the optimal OT plan
See OTResult for more information.
OTResult()
See also
ot.batch.tensor_batch: for computing the cost tensor L.
ot.solve_gromov: non-batched solver for Gromov-Wasserstein. Note that the non-batched solver uses a different algorithm (conditional gradient) and might not give the same results.
References
Gabriel Peyré, Marco Cuturi, and Justin Solomon, “Gromov-Wasserstein averaging of kernel and distance matrices.” International Conference on Machine Learning (ICML). 2016.
Xu, H., Luo, D., & Carin, L. (2019). “Scalable Gromov-Wasserstein learning for graph partitioning and matching.” Advances in neural information processing systems (NeurIPS). 2019.
Solve the discrete optimal transport problem using the samples in the source and target domains.
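The function solves the following general optimal transport problem
\[\min_{\mathbf{T}\geq 0} \quad \sum_{i,j} T_{i,j}M_{i,j} + \lambda_r R(\mathbf{T}) + \lambda_u U(\mathbf{T}\mathbf{1},\mathbf{a}) + \lambda_u U(\mathbf{T}^T\mathbf{1},\mathbf{b})\]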
where the cost matrix \(\mathbf{M}\) is computed from the samples in the source and target domains such that \(M_{i,j} = d(x_i,y_j)\) where \(d\) is a metric (by default the squared Euclidean distance).
The regularization is selected with reg (\(\lambda_r\)) and reg_type. By default reg=None and there is no regularization. The unbalanced marginal penalization can be selected with unbalanced (\(\lambda_u\)) and unbalanced_type. By default unbalanced=None and the function solves the exact optimal transport problem (respecting the marginals).
X_a (array-like, shape (n_samples_a, dim)) – samples in the source domain
X_b (array-like, shape (n_samples_b, dim)) – samples in the target domain
a (array-like, shape (dim_a,), optional) – Samples weights in the source domain (default is uniform)
b (array-like, shape (dim_b,), optional) – Samples weights in the target domain (default is uniform)
reg (float, optional) – Regularization weight \(\lambda_r\), by default None (no reg., exact OT)
c (array-like, shape (dim_a, dim_b), optional (default=None)) – Reference measure for the regularization. If None, then use \(\mathbf{c} = \mathbf{a} \mathbf{b}^T\). If \(\texttt{reg_type}=\)’entropy’, then \(\mathbf{c} = 1_{dim_a} 1_{dim_b}^T\).
reg_type (str, optional) – Type of regularization \(R\) either “KL”, “L2”, “entropy”, by default “KL”
unbalanced (float or indexable object of length 1 or 2) – Marginal relaxation term. If it is a scalar or an indexable object of length 1, the same relaxation is applied to both marginals. Balanced OT is recovered with \(unbalanced=float("inf")\). For the semi-relaxed case, use either \(unbalanced=(float("inf"), scalar)\) or \(unbalanced=(scalar, float("inf"))\). If unbalanced is an array, it must have the same backend as the input arrays (a, b, M).
unbalanced_type (str, optional) – Type of unbalanced penalization function \(U\) either “KL”, “L2”, “TV”, by default “KL”
lazy (bool, optional) – Return an OTResultlazy object to reduce memory cost when True, by default False
batch_size (int, optional) – Batch size for lazy solver, by default None (uses each solver's default value)
method (str, optional) – Method for solving the problem; this can be used to select the solver for unbalanced problems (see ot.solve), or to select a specific large scale solver.
n_threads (int, optional) – Number of OMP threads for exact OT solver, by default 1
max_iter (int, optional) – Maximum number of iterations, by default None (uses each solver's default value)
plan_init (array-like, shape (dim_a, dim_b), optional) – Initialization of the OT plan for iterative methods, by default None
rank (int, optional) – Rank of the OT matrix for lazy solvers (method=’factored’ or method=’nystroem’), by default 100
scaling (float, optional) – Scaling factor for the epsilon scaling lazy solvers (method=’geomloss’), by default 0.95
potentials_init ((array-like(dim_a,),array-like(dim_b,)), optional) – Initialization of the OT dual potentials for iterative methods (Sinkhorn, unbalanced) or warmstart for exact EMD solver. For EMD, these potentials are used to guide initial pivots in the network simplex. By default None
tol (float, optional) – Tolerance for solution precision, by default None (uses each solver's default value)
verbose (bool, optional) – Print information in the solver, by default False
grad (str, optional) – Type of gradient computation, either ‘autodiff’ or ‘envelope’, used only for the Sinkhorn solver. By default, ‘autodiff’ provides gradients w.r.t. all outputs (plan, value, value_linear) but at a significant memory cost. ‘envelope’ provides gradients only for value, and the other outputs are detached; this saves memory when only the value is needed.
random_state (int, optional) – The random state for sampling the components in each distribution for method=’nystroem’.
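The memory saving of lazy=True with batch_size comes from never storing the full cost matrix or plan: with a log-domain Sinkhorn, only the potentials \(\mathbf{f}, \mathbf{g}\) are kept, cost blocks are recomputed from the samples, and plan entries \(T_{i,j} = \exp((f_i + g_j - M_{i,j})/\lambda_r)\) are materialized on demand. The NumPy sketch below illustrates the idea for the squared Euclidean cost. `lazy_sinkhorn` is a hypothetical helper, not POT's lazy solver, and its peak memory is governed by the block shape (roughly batch_size × max(n, m) × dim) rather than by n × m.

```python
import numpy as np

def sq_dists(X, Y):
    """Squared Euclidean cost block between two sample sets."""
    return ((X[:, None, :] - Y[None, :, :]) ** 2).sum(axis=-1)

def lse(x, axis):
    """Numerically stable log-sum-exp along one axis."""
    m = x.max(axis=axis, keepdims=True)
    return (m + np.log(np.exp(x - m).sum(axis=axis, keepdims=True))).squeeze(axis)

def lazy_sinkhorn(Xa, Xb, a, b, reg, batch_size=100, n_iter=500):
    """Log-domain Sinkhorn that never materializes the full n x m cost matrix."""
    n, m = len(Xa), len(Xb)
    f, g = np.zeros(n), np.zeros(m)
    for _ in range(n_iter):
        for s in range(0, n, batch_size):  # row blocks of shape (batch, m)
            Mb = sq_dists(Xa[s:s + batch_size], Xb)
            f[s:s + batch_size] = reg * np.log(a[s:s + batch_size]) - reg * lse((g[None, :] - Mb) / reg, axis=1)
        for s in range(0, m, batch_size):  # column blocks of shape (n, batch)
            Mb = sq_dists(Xa, Xb[s:s + batch_size])
            g[s:s + batch_size] = reg * np.log(b[s:s + batch_size]) - reg * lse((f[:, None] - Mb) / reg, axis=0)

    def plan_rows(start, stop):
        """Materialize only rows start:stop of the plan T_ij = exp((f_i + g_j - M_ij) / reg)."""
        return np.exp((f[start:stop, None] + g[None, :] - sq_dists(Xa[start:stop], Xb)) / reg)

    return f, g, plan_rows
```

The trade-off is recomputation: every iteration rebuilds the cost blocks from the samples, exchanging O(n × m) memory for extra arithmetic.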
res – Result of the optimization problem. The information can be obtained as follows:
res.plan : OT plan \(\mathbf{T}\)
res.potentials : OT dual potentials
res.value : Optimal value of the optimization problem
res.value_linear : Linear OT loss with the optimal OT plan
res.lazy_plan : Lazy OT plan (when lazy=True or lazy method)
See OTResult for more information.
OTResult()
Notes
The following methods are available for solving the OT problems:
Classical exact OT problem [1] (default parameters):
can be solved with the following code:
res = ot.solve_sample(xa, xb, a, b)
# for uniform weights
res = ot.solve_sample(xa, xb)
Entropic regularized OT [2] (when reg!=None):
can be solved with the following code:
# default is ``"KL"`` regularization (``reg_type="KL"``)
res = ot.solve_sample(xa, xb, a, b, reg=1.0)
# or for original Sinkhorn paper formulation [2]
res = ot.solve_sample(xa, xb, a, b, reg=1.0, reg_type='entropy')
# lazy solver of memory complexity O(n)
res = ot.solve_sample(xa, xb, a, b, reg=1.0, lazy=True, batch_size=100)
# lazy OT plan
lazy_plan = res.lazy_plan
# Use envelope theorem differentiation for memory saving
res = ot.solve_sample(xa, xb, a, b, reg=1.0, grad='envelope')
res.value.backward() # only the value is differentiable
Note that by default the Sinkhorn solver uses automatic differentiation to compute the gradients of the values and plan. This can be changed with the grad parameter. The envelope mode computes the gradients only for the value and the other outputs are detached. This is useful for memory saving when only the gradient of value is needed.
An efficient solver with compiled CPU/CUDA code based on geomloss/PyKeOps is also available and can be used with the following code:
# automatic solver
res = ot.solve_sample(xa, xb, a, b, reg=1.0, method='geomloss')
Warning
The geomloss solver is a thin wrapper around the geomloss.ot.solve_sample function. The API is still under development and some features might be missing. Please refer to the geomloss documentation for more information.
Quadratic regularized OT [17] (when reg!=None and reg_type="L2"):
can be solved with the following code:
res = ot.solve_sample(xa, xb, a, b, reg=1.0, reg_type='L2')
Unbalanced OT [41] (when unbalanced!=None):
can be solved with the following code:
# default is ``"KL"``
res = ot.solve_sample(xa, xb, a, b, unbalanced=1.0)
# quadratic unbalanced OT
res = ot.solve_sample(xa, xb, a, b, unbalanced=1.0, unbalanced_type='L2')
# TV = partial OT
res = ot.solve_sample(xa, xb, a, b, unbalanced=1.0, unbalanced_type='TV')
Regularized unbalanced OT [34] (when unbalanced!=None and reg!=None):
can be solved with the following code:
# default is ``"KL"`` for both
res = ot.solve_sample(xa, xb, a, b, reg=1.0, unbalanced=1.0)
# quadratic unbalanced OT with KL regularization
res = ot.solve_sample(xa, xb, a, b, reg=1.0, unbalanced=1.0, unbalanced_type='L2')
# both quadratic
res = ot.solve_sample(xa, xb, a, b, reg=1.0, reg_type='L2', unbalanced=1.0, unbalanced_type='L2')
Factored OT [2] (when method='factored'):
This method solves the following OT problem [40]
\[\min_{\mu} \quad W_2^2(\mu_a,\mu) + W_2^2(\mu,\mu_b)\]
where \(\mu\) is a uniformly weighted empirical distribution (supported on rank points), \(\mu_a\) and \(\mu_b\) are the empirical measures associated with the samples in the source and target domains, and \(W_2\) is the Wasserstein distance. This problem is solved using exact OT solvers for reg=None and the Sinkhorn solver for reg!=None. The solution provides two transport plans that can be used to recover a low rank OT plan between the two distributions.
res = ot.solve_sample(xa, xb, method='factored', rank=10)
# recover the lazy low rank plan
factored_solution_lazy = res.lazy_plan
# recover the full low rank plan
factored_solution = factored_solution_lazy[:]
Nystroem OT [76] (when method='nystroem'):
Corresponds to a low rank approximation of entropic OT (for a squared Euclidean cost) that runs in linear time.
Gaussian Bures-Wasserstein [2] (when method='gaussian'):
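This method computes the squared Gaussian Bures-Wasserstein distance between two Gaussian distributions estimated from the empirical distributions:
\(\mathcal{W}_2^2(\mu_s, \mu_t) = \left\lVert \mathbf{m}_s - \mathbf{m}_t \right\rVert^2 + \mathcal{B}(\Sigma_s, \Sigma_t)^{2}\)
where
\(\mathcal{B}(\Sigma_s, \Sigma_t)^{2} = \mathrm{Tr}\left(\Sigma_s + \Sigma_t - 2 \left(\Sigma_s^{1/2}\Sigma_t\Sigma_s^{1/2}\right)^{1/2}\right)\)
The covariances and means are estimated from the data.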
res = ot.solve_sample(xa, xb, method='gaussian')
# recover the squared Gaussian Bures-Wasserstein distance
BW_dist = res.value
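The closed form \(\|\mathbf{m}_s - \mathbf{m}_t\|^2 + \mathrm{Tr}(\Sigma_s + \Sigma_t - 2(\Sigma_s^{1/2}\Sigma_t\Sigma_s^{1/2})^{1/2})\) behind this method can be evaluated directly with NumPy. This sketch uses a hypothetical helper gaussian_bw2, the biased (1/n) covariance estimate and no covariance regularization, so it may not match the estimator used by ot.solve_sample exactly.

```python
import numpy as np


def sqrtm_psd(S):
    # square root of a symmetric positive semi-definite matrix via eigh
    vals, vecs = np.linalg.eigh(S)
    return (vecs * np.sqrt(np.clip(vals, 0.0, None))) @ vecs.T


def gaussian_bw2(xs, xt):
    """Squared Bures-Wasserstein distance between Gaussians fitted to xs and xt."""
    ms, mt = xs.mean(axis=0), xt.mean(axis=0)
    Cs = np.cov(xs, rowvar=False, bias=True)   # biased (1/n) covariance estimates
    Ct = np.cov(xt, rowvar=False, bias=True)
    Cs12 = sqrtm_psd(Cs)
    inner = Cs12 @ Ct @ Cs12
    cross = sqrtm_psd(0.5 * (inner + inner.T))  # symmetrize against round-off
    bures2 = np.trace(Cs + Ct - 2.0 * cross)
    return float(np.sum((ms - mt) ** 2) + bures2)


rng = np.random.default_rng(0)
xs = rng.normal(size=(200, 3))
print(gaussian_bw2(xs, xs + np.array([1.0, 2.0, 0.0])))  # approximately 5.0 (mean shift only)
```

A pure translation leaves the covariance unchanged, so only the squared mean shift remains.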
Wasserstein 1d [1] (when method='1D'):
This method computes the Wasserstein distance between two 1d distributions estimated from the empirical distributions. For multivariate data the distances are computed independently for each dimension.
res = ot.solve_sample(xa, xb, method='1D')
# recover the squared Wasserstein distances
W_dists = res.value
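In one dimension with equal sample sizes and uniform weights, the optimal plan matches sorted values, so each per-dimension cost reduces to the mean of \(|x_{(i)} - y_{(i)}|^p\) over the sorted samples. The following NumPy sketch (hypothetical helper, restricted to that equal-size, uniform-weight case) computes these per-dimension values:

```python
import numpy as np


def wasserstein_1d_per_dim(xa, xb, p=2):
    """Per-dimension W_p^p between equal-size, uniformly weighted samples."""
    xa_sorted = np.sort(xa, axis=0)   # sort each column (dimension) independently
    xb_sorted = np.sort(xb, axis=0)
    return np.mean(np.abs(xa_sorted - xb_sorted) ** p, axis=0)


rng = np.random.default_rng(0)
xa = rng.normal(size=(100, 2))
xb = xa + np.array([1.0, 3.0])
print(wasserstein_1d_per_dim(xa, xb))  # approximately [1. 9.]
```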
References
Bonneel, N., Van De Panne, M., Paris, S., & Heidrich, W. (2011, December). Displacement interpolation using Lagrangian mass transport. In ACM Transactions on Graphics (TOG) (Vol. 30, No. 6, p. 158). ACM.
Cuturi, M. (2013). Sinkhorn Distances: Lightspeed Computation of Optimal Transport. Advances in Neural Information Processing Systems (NIPS) 26.
Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016). Scaling algorithms for unbalanced transport problems. arXiv preprint arXiv:1607.05816.
Blondel, M., Seguy, V., & Rolet, A. (2018). Smooth and Sparse Optimal Transport. Proceedings of the Twenty-First International Conference on Artificial Intelligence and Statistics (AISTATS).
Feydy, J., Séjourné, T., Vialard, F. X., Amari, S. I., Trouvé, A., & Peyré, G. (2019, April). Interpolating between optimal transport and MMD using Sinkhorn divergences. In The 22nd International Conference on Artificial Intelligence and Statistics (pp. 2681-2690). PMLR.
Forrow, A., Hütter, J. C., Nitzan, M., Rigollet, P., Schiebinger, G., & Weed, J. (2019, April). Statistical optimal transport via factored couplings. In The 22nd International Conference on Artificial Intelligence and Statistics (pp. 2454-2465). PMLR.
Chapel, L., Flamary, R., Wu, H., Févotte, C., and Gasso, G. (2021). Unbalanced optimal transport through non-negative penalized linear regression. NeurIPS.
Scetbon, M., Cuturi, M., & Peyré, G. (2021). Low-rank Sinkhorn Factorization. In International Conference on Machine Learning.
Altschuler, J., Bach, F., Rudi, A., Niles-Weed, J. (2019). Massively scalable Sinkhorn distances via the Nyström method. NeurIPS.
Batched version of ot.solve: use it to solve many entropic OT problems in parallel.
M (array-like, shape (B, ns, nt)) – Cost matrix
reg (float) – Regularization parameter for entropic regularization
metric (str, optional) – ‘sqeuclidean’, ‘euclidean’, ‘minkowski’ or ‘kl’
p (float, optional) – p-norm for the Minkowski metrics. Default value is 2.
a (array-like, shape (B, ns)) – Source distribution (optional). If None, uniform distribution is used.
b (array-like, shape (B, nt)) – Target distribution (optional). If None, uniform distribution is used.
max_iter (int) – Maximum number of iterations
tol (float) – Tolerance for convergence
solver (str) – Solver to use, either ‘log_sinkhorn’ or ‘sinkhorn’. Default is “log_sinkhorn” which is more stable.
reg_type (str, optional) – Type of regularization \(R\): either “KL” or “entropy”. Default is “entropy”.
grad (str, optional) – Type of gradient computation, one of ‘autodiff’, ‘envelope’, ‘last_step’ or ‘detach’, used only for the Sinkhorn solver. By default, ‘autodiff’ provides gradients with respect to all outputs (plan, value, value_linear) but at a significant memory cost. ‘envelope’ provides gradients only for the value, and the other outputs are detached; this saves memory when only the value is needed. ‘last_step’ provides gradients only through the last iteration of the Sinkhorn solver, but for both the OT plan and the objective values. ‘detach’ does not compute gradients through the Sinkhorn solver.
res – Result of the optimization problem. The information can be obtained as follows:
res.plan : OT plan \(\mathbf{T}\)
res.potentials : OT dual potentials
res.value : Optimal value of the optimization problem
res.value_linear : Linear OT loss with the optimal OT plan
See OTResult for more information.
OTResult()
See also
ot.batch.solve_batch: solver for computing the optimal T from an arbitrary cost matrix M.
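To illustrate what solving many entropic problems in parallel means, the following NumPy sketch runs log-domain Sinkhorn iterations (in the spirit of the ‘log_sinkhorn’ option) on a whole batch of cost matrices at once. It assumes uniform marginals, the plain entropy regularizer and a fixed iteration count; the function names are hypothetical, and the documented function additionally handles tol, other backends and the grad modes.

```python
import numpy as np


def logsumexp(x, axis):
    # numerically stable log(sum(exp(x))) along `axis`
    m = np.max(x, axis=axis, keepdims=True)
    return np.squeeze(m, axis=axis) + np.log(np.sum(np.exp(x - m), axis=axis))


def batched_log_sinkhorn(M, reg, n_iter=500):
    """Entropic OT with uniform marginals for a batch of costs M of shape (B, ns, nt)."""
    B, ns, nt = M.shape
    log_a = np.full((B, ns), -np.log(ns))   # log of uniform source weights
    log_b = np.full((B, nt), -np.log(nt))   # log of uniform target weights
    f = np.zeros((B, ns))                   # one pair of dual potentials per problem
    g = np.zeros((B, nt))
    for _ in range(n_iter):
        f = reg * log_a - reg * logsumexp((g[:, None, :] - M) / reg, axis=2)
        g = reg * log_b - reg * logsumexp((f[:, :, None] - M) / reg, axis=1)
    plan = np.exp((f[:, :, None] + g[:, None, :] - M) / reg)
    value_linear = np.sum(plan * M, axis=(1, 2))   # <T, M> for each problem
    return plan, value_linear


M = np.random.default_rng(0).random((4, 5, 6))   # 4 problems of size 5 x 6
plan, value_linear = batched_log_sinkhorn(M, reg=0.5)
print(plan.shape, value_linear.shape)   # (4, 5, 6) (4,)
```

Every update is a broadcast over the leading batch axis, so the B problems advance together without a Python loop over problems.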
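Compute the Unbalanced Sliced Optimal Transport (USOT) with KL regularization between two empirical distributions \(\mu\) and \(\nu\). The Unbalanced SOT problem reads as
\(\mathrm{USOT}(\mu, \nu) = \inf_{\pi_1,\pi_2} \mathrm{SW}_p^p(\pi_1, \pi_2) + \lambda_1 \mathrm{KL}(\pi_1\|\mu) + \lambda_2 \mathrm{KL}(\pi_2\|\nu)\)
where the infimum is over nonnegative measures \(\pi_1, \pi_2\) on \(\mathbb{R}^d\) with equal total mass, and \(\lambda_1, \lambda_2\) are the marginal relaxation terms given by reg_m.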
The USOT problem is solved with a Frank-Wolfe algorithm as proposed in [82].
Warning
This function only works with the PyTorch or JAX backends, as it uses automatic differentiation to compute the 1D potentials. JAX support is not maintained.
X_s (ndarray, shape (n_samples_a, dim)) – samples in the source domain
X_t (ndarray, shape (n_samples_b, dim)) – samples in the target domain
reg_m (float or indexable object of length 1 or 2) – Marginal relaxation term. If reg_m is a scalar or an indexable object of length 1, then the same reg_m is applied to both marginal relaxations. The balanced OT can be recovered using reg_m=float(“inf”). For semi-relaxed case, use either reg_m=(float(“inf”), scalar) or reg_m=(scalar, float(“inf”)). If reg_m is an array, it must have the same backend as input arrays (X_s, X_t).
a (ndarray, shape (n_samples_a,), optional) – samples weights in the source domain
b (ndarray, shape (n_samples_b,), optional) – samples weights in the target domain
n_projections (int, optional) – Number of projections used for the Monte-Carlo approximation
p (float, optional, by default =2) – Power p used for computing the sliced Wasserstein
projections (shape (dim, n_projections), optional) – Projection matrix (n_projections and seed are not used in this case)
seed (int or RandomState or None, optional) – Seed used for random number generator
numItermax (int, optional)
log (bool, optional) – if True, returns the SOT loss, the projections used, their associated EMD and the full mass of the reweighted marginals.
a_reweighted (array-like shape (n, …)) – First marginal reweighted
b_reweighted (array-like shape (m, …)) – Second marginal reweighted
loss (float/array-like, shape (…)) – USOT
log (dict, optional) – If log is True, then returns a dictionary containing the projection directions used, the 1D OT losses, the SOT loss and the full mass of reweighted marginals.
References
Bonet, C., Nadjahi, K., Séjourné, T., Fatras, K., & Courty, N. (2025). Slicing Unbalanced Optimal Transport. Transactions on Machine Learning Research.
See also
ot.unbalanced.uot_1d: 1D unbalanced OT problem
ot.unbalanced.sliced_unbalanced_ot: SUOT problem
Return a uniform histogram of length n (simplex).
n (int) – number of bins in the histogram
type_as (array-like) – array of the same type of the expected output (numpy/pytorch/jax)
h – histogram of length n such that \(\forall i, \mathbf{h}_i = \frac{1}{n}\)
array-like, shape (n,)
Solves the 1D unbalanced OT problem with KL regularization. The function implements the Frank-Wolfe algorithm to solve the dual problem, as proposed in [73].
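The unbalanced OT problem reads
\(\mathrm{UOT}(\mu,\nu) = \min_{\pi_1,\pi_2} W_p^p(\pi_1,\pi_2) + \lambda_1 \mathrm{KL}(\pi_1\|\mu) + \lambda_2 \mathrm{KL}(\pi_2\|\nu)\)
where \(\mu\) and \(\nu\) are the two empirical distributions, the minimum is over nonnegative measures \(\pi_1, \pi_2\) on \(\mathbb{R}\) with equal total mass, and \(\lambda_1, \lambda_2\) are given by reg_m.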
Warning
This function only works with the PyTorch or JAX backends, as it uses automatic differentiation to compute the potentials. JAX support is not maintained.
u_values (array-like, shape (n, ...)) – locations of the first empirical distribution
v_values (array-like, shape (m, ...)) – locations of the second empirical distribution
reg_m (float or indexable object of length 1 or 2) – Marginal relaxation term. If reg_m is a scalar or an indexable object of length 1, then the same reg_m is applied to both marginal relaxations. The balanced OT can be recovered using reg_m=float(“inf”). For semi-relaxed case, use either reg_m=(float(“inf”), scalar) or reg_m=(scalar, float(“inf”)). If reg_m is an array, it must have the same backend as input arrays (u_values, v_values).
u_weights (array-like, shape (n, ...), optional) – weights of the first empirical distribution, if None then uniform weights are used
v_weights (array-like, shape (m, ...), optional) – weights of the second empirical distribution, if None then uniform weights are used
p (int, optional) – order of the ground metric used, should be at least 1, default is 2
require_sort (bool, optional) – whether to sort the atom locations of the distributions; if False, they are assumed to have been sorted before being passed to the function. Default is True
numItermax (int, optional)
returnCost (string, optional (default = "linear")) – If returnCost = “linear”, then return the linear part of the unbalanced OT loss. If returnCost = “total”, then return the total unbalanced OT loss.
log (bool, optional)
u_reweighted (array-like shape (n, …)) – First marginal reweighted
v_reweighted (array-like shape (m, …)) – Second marginal reweighted
loss (float/array-like, shape (…)) – The batched 1D UOT
log (dict, optional) – If log is True, then returns a dictionary containing the dual potentials, the total cost and the linear cost.
References
Séjourné, T., Vialard, F. X., & Peyré, G. (2022). Faster unbalanced optimal transport: Translation invariant sinkhorn and 1-d frank-wolfe. In International Conference on Artificial Intelligence and Statistics (pp. 4995-5021). PMLR.
Computes the one-dimensional OT loss [15] between two (batched) empirical distributions.
It is formally the p-Wasserstein distance raised to the power p. It is computed in a vectorized way by first building the individual quantile functions and then integrating the p-th power of their absolute difference.
This function should be preferred to emd_1d whenever the backend is different from NumPy, and when gradients with respect to either sample positions or weights are required.
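The quantile-function approach can be sketched for a single (unbatched) pair of weighted 1D measures. Both quantile functions are piecewise constant between the merged cumulative weights, so the integral of \(|F_u^{-1} - F_v^{-1}|^p\) becomes a finite weighted sum. The helper names below are hypothetical and the code is NumPy only, without the batching, backend support and gradients of the documented function.

```python
import numpy as np


def quantile_function(qs, cws, xs):
    # generalized inverse CDF: first atom whose cumulative weight reaches q
    idx = np.searchsorted(cws, qs, side="left")
    return xs[np.clip(idx, 0, len(xs) - 1)]


def wasserstein_1d_sketch(u, v, u_w=None, v_w=None, p=1):
    """W_p^p between two weighted 1D empirical measures (hypothetical helper)."""
    u, v = np.asarray(u, float), np.asarray(v, float)
    u_w = np.full(len(u), 1.0 / len(u)) if u_w is None else np.asarray(u_w, float)
    v_w = np.full(len(v), 1.0 / len(v)) if v_w is None else np.asarray(v_w, float)
    ui, vi = np.argsort(u), np.argsort(v)    # sort the atom locations first
    u, u_w, v, v_w = u[ui], u_w[ui], v[vi], v_w[vi]
    u_cw, v_cw = np.cumsum(u_w), np.cumsum(v_w)
    # both quantile functions are constant between consecutive merged breakpoints
    qs = np.sort(np.concatenate([u_cw, v_cw]))
    delta = np.diff(qs, prepend=0.0)         # length of each constant piece
    diff = np.abs(quantile_function(qs, u_cw, u) - quantile_function(qs, v_cw, v))
    return float(np.sum(delta * diff ** p))


print(wasserstein_1d_sketch([0.0, 1.0], [0.0], u_w=[0.25, 0.75]))  # 0.75
```

In the usage line, three quarters of the mass must travel a distance of 1, which gives the cost 0.75 for p=1.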
u_values (array-like, shape (n, ...)) – locations of the first empirical distributions
v_values (array-like, shape (m, ...)) – locations of the second empirical distributions
u_weights (array-like, shape (n, ...), optional) – weights of the first empirical distributions, if None then uniform weights are used
v_weights (array-like, shape (m, ...), optional) – weights of the second empirical distributions, if None then uniform weights are used
p (int, optional) – order of the ground metric used, should be at least 1 (see [2, Chap. 2]); default is 1
require_sort (bool, optional) – whether to sort the atom locations of the distributions; if False, they are assumed to have been sorted before being passed to the function. Default is True
return_plans (True, False or "coo_tuple", optional) – if True, also returns the optimal transport plan between the two (batched) measures as a coo_matrix; default is False. If “coo_tuple”, returns the optimal transport plans as a tuple (data, rows, cols) of the non-zero elements of the transportation matrix, which is useful for backends that do not support sparse matrices well (e.g. JAX, TensorFlow).
cost (float/array-like, shape (…)) – the batched EMD
plans (list of coo_matrix or namedTuple, optional) – if return_plans is True, returns a list of coo_matrix containing the plans. if return_plans is “coo_tuple”, returns the plans as a list of namedTuple containing the data, rows and cols of the non-zero elements of the transportation matrix.
References
Peyré, G., & Cuturi, M. (2018). Computational Optimal Transport.
Computes the Wasserstein distance on the circle using either [45] for p=1 or the binary search algorithm proposed in [44] otherwise. Samples need to lie in \(S^1\cong [0,1[\). If they are on \(\mathbb{R}\), they are taken modulo 1. If they are on \(S^1\subset\mathbb{R}^2\), their coordinates must first be computed, e.g. with the atan2 function.
General loss returned:
\(OT_{loss} = \inf_{\theta\in\mathbb{R}}\int_0^1 |F_u^{-1}(q) - (F_v-\theta)^{-1}(q)|^p\,\mathrm{d}q\)
For p=1, [45]
\(W_1(u,v) = \int_0^1 |F_u(t)-F_v(t)-\mathrm{LevMed}(F_u-F_v)|\,\mathrm{d}t\)
For values \(x=(x_1,x_2)\in S^1\), their coordinates must first be computed as
\(u = \frac{\pi + \mathrm{atan2}(-x_2,-x_1)}{2\pi}\)
using e.g. ot.utils.get_coordinate_circle(x).
The function runs on all backends except TensorFlow and JAX.
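For p=1, the circle distance \(\int_0^1 |F_u(t) - F_v(t) - \alpha|\,\mathrm{d}t\), with \(\alpha\) the level median of \(F_u - F_v\) (its median with respect to the Lebesgue measure on [0,1[), can be computed exactly for discrete measures: \(F_u - F_v\) is constant between consecutive sample locations, so \(\alpha\) is a weighted median of those constant values with the interval lengths as weights. The NumPy sketch below follows that reasoning for a single pair of measures; the helper name is hypothetical and the p>1 binary search of [44] is not covered.

```python
import numpy as np


def circle_w1_sketch(u, v, u_w=None, v_w=None):
    """Exact W_1 on the circle [0, 1[ between two discrete measures (hypothetical helper)."""
    u = np.asarray(u, float) % 1.0
    v = np.asarray(v, float) % 1.0
    u_w = np.full(len(u), 1.0 / len(u)) if u_w is None else np.asarray(u_w, float)
    v_w = np.full(len(v), 1.0 / len(v)) if v_w is None else np.asarray(v_w, float)
    # D(t) = F_u(t) - F_v(t) only changes at sample locations
    t = np.unique(np.concatenate([[0.0, 1.0], u, v]))
    left, lengths = t[:-1], np.diff(t)
    # right-continuous CDFs evaluated on each interval [t_k, t_{k+1})
    F_u = np.array([u_w[u <= s].sum() for s in left])
    F_v = np.array([v_w[v <= s].sum() for s in left])
    D = F_u - F_v
    # level median: weighted median of D with the interval lengths as weights
    order = np.argsort(D)
    cum = np.cumsum(lengths[order])
    alpha = D[order][np.searchsorted(cum, 0.5 * cum[-1])]
    return float(np.sum(lengths * np.abs(D - alpha)))


print(circle_w1_sketch([0.2, 0.5, 0.8], [0.4, 0.5, 0.7]))  # approximately 0.1
```

The usage line reuses the samples of the documented example and recovers the same value.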
u_values (ndarray, shape (n, ...)) – samples in the source domain (coordinates on [0,1[)
v_values (ndarray, shape (n, ...)) – samples in the target domain (coordinates on [0,1[)
u_weights (ndarray, shape (n, ...), optional) – samples weights in the source domain
v_weights (ndarray, shape (n, ...), optional) – samples weights in the target domain
p (float, optional (default=1)) – Power p used for computing the Wasserstein distance
Lm (int, optional) – Lower bound dC. For p>1.
Lp (int, optional) – Upper bound dC. For p>1.
tm (float, optional) – Lower bound theta. For p>1.
tp (float, optional) – Upper bound theta. For p>1.
eps (float, optional) – Stopping condition. For p>1.
require_sort (bool, optional) – If True, sort the values.
loss – Batched cost associated to the optimal transportation
float/array-like, shape (…)
Examples
>>> import numpy as np
>>> from ot import wasserstein_circle
>>> u = np.array([[0.2,0.5,0.8]])%1
>>> v = np.array([[0.4,0.5,0.7]])%1
>>> wasserstein_circle(u.T, v.T)
array([0.1])
References
Hundrieser, Shayan, Marcel Klatt, and Axel Munk. “The statistics of circular optimal transport.” Directional Statistics for Innovative Applications: A Bicentennial Tribute to Florence Nightingale. Singapore: Springer Nature Singapore, 2022. 57-82.
Delon, Julie, Julien Salomon, and Andrei Sobolevski. “Fast transport optimization for Monge costs on the circle.” SIAM Journal on Applied Mathematics 70.7 (2010): 2239-2258.
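Solves the weak optimal transport problem between two empirical distributions:
\(\gamma = \mathop{\arg\min}_\gamma \sum_i \mathbf{a}_i \left\|X^a_i - \frac{1}{\mathbf{a}_i}\sum_j \gamma_{ij} X^b_j\right\|^2\)
s.t. \(\gamma \mathbf{1} = \mathbf{a}\), \(\gamma^T \mathbf{1} = \mathbf{b}\), \(\gamma \geq 0\)
where: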
\(X^a\) and \(X^b\) are the sample matrices.
\(\mathbf{a}\) and \(\mathbf{b}\) are the sample weights
Note
This function is backend-compatible and will work on arrays from all compatible backends. However, the algorithm uses the C++ CPU backend, which can lead to copy overhead on GPU arrays.
Uses the conditional gradient algorithm to solve the problem proposed in [39].
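For checking a returned plan, the weak OT cost \(\sum_i \mathbf{a}_i \|X^a_i - \frac{1}{\mathbf{a}_i}\sum_j \gamma_{ij} X^b_j\|^2\) can be evaluated in a few lines of NumPy. This is only an evaluator of that objective under the squared Euclidean barycentric cost assumed here, not the conditional gradient solver of [39], and the helper name is hypothetical.

```python
import numpy as np


def weak_ot_objective(Xa, Xb, a, G):
    """sum_i a_i * || Xa_i - (1 / a_i) * sum_j G_ij Xb_j ||^2 (hypothetical helper)."""
    barycenters = (G @ Xb) / a[:, None]   # barycentric image of each source sample
    return float(np.sum(a * np.sum((Xa - barycenters) ** 2, axis=1)))


rng = np.random.default_rng(0)
X = rng.normal(size=(4, 2))
a = np.full(4, 0.25)
print(weak_ot_objective(X, X, a, np.diag(a)))      # 0.0: each point is sent to itself
print(weak_ot_objective(X, X, a, np.outer(a, a)))  # independent plan: every barycenter is the mean
```

Unlike the linear OT cost, the weak cost only penalizes the distance between each source sample and the barycenter of its image, so spreading mass can be free as long as that barycenter stays put.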
Xa ((ns,d) array-like, float) – Source samples
Xb ((nt,d) array-like, float) – Target samples
a ((ns,) array-like, float) – Source histogram (uniform weight if empty list)
b ((nt,) array-like, float) – Target histogram (uniform weight if empty list)
G0 ((ns,nt) array-like, float) – initial guess (default is the independent joint distribution)
numItermax (int, optional) – Max number of iterations
numItermaxEmd (int, optional) – Max number of iterations for emd
stopThr (float, optional) – Stop threshold on the relative variation (>0)
stopThr2 (float, optional) – Stop threshold on the absolute variation (>0)
verbose (bool, optional) – Print information along iterations
log (bool, optional) – record log if True
gamma (array-like, shape (ns, nt)) – Optimal transportation matrix for the given parameters
log (dict, optional) – If input log is true, a dictionary containing the cost and dual variables and exit status
References
Gozlan, N., Roberto, C., Samson, P. M., & Tetali, P. (2017). Kantorovich duality for general transport costs and applications. Journal of Functional Analysis, 273(11), 3327-3405.
See also
ot.bregman.sinkhorn: Entropic regularized OT
ot.optim.cg: General regularized OT