ipot {T4transport}    R Documentation
Wasserstein Distance by Inexact Proximal Point Method
Description
Because linear programming approaches to computing the Wasserstein distance are computationally expensive, Cuturi (2013) proposed an entropic regularization scheme as an efficient approximation to the original problem. This introduces a regularization parameter \lambda > 0 in the term
\lambda h(\Gamma) = \lambda \sum_{m,n} \Gamma_{m,n} \log (\Gamma_{m,n}).
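To make the regularization term concrete, the following is a minimal base-R sketch of \lambda h(\Gamma). The helper entropic_term() is hypothetical (it is not part of T4transport), and it adopts the usual convention 0 \log 0 = 0 for zero entries of the plan.

```r
# Hypothetical helper: lambda * h(Gamma) = lambda * sum(Gamma * log(Gamma)),
# using the convention 0 * log(0) = 0 for zero entries of the plan.
entropic_term <- function(Gamma, lambda) {
  g <- Gamma[Gamma > 0]          # drop zeros so they contribute nothing
  lambda * sum(g * log(g))
}

Gamma <- matrix(0.25, 2, 2)      # uniform plan between two 2-point uniform measures
entropic_term(Gamma, lambda = 1) # 4 * 0.25 * log(0.25) = log(0.25), about -1.386
```

The term is most negative for sharply concentrated plans only in the sense of sign conventions; adding it with weight \lambda pushes the optimizer toward spread-out plans, which is what makes the regularized problem smooth and cheap to solve.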
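The inexact proximal point (IPOT) method of Xie et al. (2020) uses the Bregman divergence generated by h as a proximal term: each outer iteration solves the proximal subproblem around the current plan only approximately, with a few inner Sinkhorn-type scaling iterations, so that the iterates approach the exact (unregularized) Wasserstein distance. IPOT is known to be relatively robust to the choice of the regularization parameter \lambda, and empirically a very small number of inner-loop iterations, such as L=1, is sufficient.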
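The IPOT iteration can be sketched in base R as follows. This illustrates Algorithm 1 of Xie et al. (2020) under stated assumptions and is not the T4transport implementation: ipot_sketch() and its arguments are hypothetical, the cost C is taken to be D^p, and treating beta as the counterpart of lambda is an assumption.

```r
# Minimal IPOT sketch after Xie et al. (2020), Algorithm 1, in base R.
# ipot_sketch() is a hypothetical helper, not T4transport's implementation.
#   C    : (M x N) cost matrix, e.g. D^p
#   a, b : marginal weights of length M and N (each summing to 1)
#   beta : proximal parameter (assumed counterpart of lambda)
#   L    : number of inner scaling iterations per outer step (L >= 1)
ipot_sketch <- function(C, a, b, beta = 1, L = 1, maxiter = 1000) {
  G     <- exp(-C / beta)                  # Gibbs kernel of the cost
  Gamma <- matrix(1, nrow(C), ncol(C))     # initial plan: all ones
  sigma <- rep(1 / ncol(C), ncol(C))       # column scaling, kept across steps
  for (t in seq_len(maxiter)) {
    Q <- G * Gamma                         # proximal kernel (elementwise product)
    for (l in seq_len(L)) {                # inexact inner solve
      delta <- a / as.vector(Q %*% sigma)
      sigma <- b / as.vector(crossprod(Q, delta))
    }
    Gamma <- sweep(delta * Q, 2, sigma, "*")  # diag(delta) %*% Q %*% diag(sigma)
  }
  list(cost = sum(C * Gamma), plan = Gamma)
}

# Toy problem whose exact optimal cost is 1.3 (plan [[0.3, 0.2], [0, 0.5]]).
C   <- matrix(c(0, 1, 4, 1), 2, 2)         # column-major: C[1,2] = 4, C[2,1] = 1
out <- ipot_sketch(C, a = c(0.5, 0.5), b = c(0.3, 0.7), beta = 1, L = 1)
out$cost                                   # should be close to 1.3
```

Carrying sigma across outer iterations and stopping the inner loop after L steps is what makes the method inexact; a fixed iteration count replaces a convergence test here only to keep the sketch short.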
Usage
ipot(X, Y, p = 2, wx = NULL, wy = NULL, lambda = 1, ...)
ipotD(D, p = 2, wx = NULL, wy = NULL, lambda = 1, ...)
Arguments
X
an (M\times P) matrix of row observations.
Y
an (N\times P) matrix of row observations.
p
an exponent for the order of the distance (default: 2).
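wx
a length-M marginal density (nonnegative weights summing to 1) on the rows of X.
wy
a length-N marginal density (nonnegative weights summing to 1) on the rows of Y.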
lambda
a regularization parameter (default: 1).
...
extra parameters, including the number of inner-loop iterations L (see Description).
D
an (M\times N) matrix of pairwise distances between the M and N observations.
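ipotD() takes a precomputed (M\times N) distance matrix instead of raw observations. A minimal base-R sketch for building such a matrix from the rows of X and Y follows; the helper cross_dist() is hypothetical, and using Euclidean ground distance is an assumption.

```r
# Hypothetical helper: Euclidean distances between every row of X and every
# row of Y, giving an (M x N) matrix of the kind ipotD() expects (assumption).
cross_dist <- function(X, Y) {
  sq <- outer(rowSums(X^2), rowSums(Y^2), "+") - 2 * tcrossprod(X, Y)
  sqrt(pmax(sq, 0))               # clamp tiny negatives caused by rounding
}

set.seed(100)
X <- matrix(rnorm(20 * 2, mean = -1), ncol = 2)
Y <- matrix(rnorm(30 * 2, mean = +1), ncol = 2)
D <- cross_dist(X, Y)             # 20 x 30
```

With T4transport attached, ipotD(D, p = 2) could then be compared against ipot(X, Y, p = 2); whether the two agree exactly depends on ipot() using Euclidean distance internally, which is an assumption here.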
Value
a named list containing
- distance
the \mathcal{W}_p distance value.
- iteration
the number of iterations it took to converge.
- plan
an (M\times N) nonnegative matrix for the optimal transport plan.
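The returned distance and plan are related through the definition of \mathcal{W}_p. A minimal sketch, assuming the standard definition \mathcal{W}_p = (\sum_{m,n} \Gamma_{m,n} D_{m,n}^p)^{1/p}; the helper wp_from_plan() is hypothetical and not part of T4transport.

```r
# Hypothetical helper: W_p from a transport plan and a distance matrix, assuming
# the standard definition W_p = (sum_{m,n} Gamma_{m,n} * D_{m,n}^p)^(1/p).
wp_from_plan <- function(plan, D, p = 2) {
  stopifnot(identical(dim(plan), dim(D)))  # plan and D must both be (M x N)
  sum(plan * D^p)^(1 / p)
}

D    <- matrix(c(0, 2, 2, 0), 2, 2)         # two points on each side, 2 apart
swap <- matrix(c(0, 0.5, 0.5, 0), 2, 2)     # move all mass across
wp_from_plan(swap, D, p = 2)                # sqrt(0.5 * 4 + 0.5 * 4) = 2
```

Applied to the plan from ipot() and a matching distance matrix, this could serve as a rough cross-check of the reported distance, provided the package uses the same ground distance.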
References
Xie Y, Wang X, Wang R, Zha H (2020). “A Fast Proximal Point Method for Computing Exact Wasserstein Distance.” In Adams RP, Gogate V (eds.), Proceedings of the 35th Uncertainty in Artificial Intelligence Conference, volume 115 of Proceedings of Machine Learning Research, 433–453.
Examples
#-------------------------------------------------------------------
# Wasserstein Distance between Samples from Two Bivariate Normal
#
# * class 1 : samples from Gaussian with mean=(-1, -1)
# * class 2 : samples from Gaussian with mean=(+1, +1)
#-------------------------------------------------------------------
library(T4transport)  # provides ipot() and wasserstein()
## SMALL EXAMPLE
set.seed(100)
m = 20
n = 30
X = matrix(rnorm(m*2, mean=-1),ncol=2) # m obs. for X
Y = matrix(rnorm(n*2, mean=+1),ncol=2) # n obs. for Y
## COMPARE WITH WASSERSTEIN
outw = wasserstein(X, Y)
ipt1 = ipot(X, Y, lambda=1)
ipt2 = ipot(X, Y, lambda=10)
## VISUALIZE : SHOW THE PLAN AND DISTANCE
pmw = paste0("wasserstein plan ; dist=",round(outw$distance,2))
pm1 = paste0("ipot lbd=1 ; dist=",round(ipt1$distance,2))
pm2 = paste0("ipot lbd=10; dist=",round(ipt2$distance,2))
opar <- par(no.readonly=TRUE)
par(mfrow=c(1,3))
image(outw$plan, axes=FALSE, main=pmw)
image(ipt1$plan, axes=FALSE, main=pm1)
image(ipt2$plan, axes=FALSE, main=pm2)
par(opar)