ot_distance {causalOT}    R Documentation
Optimal Transport Distance
Description
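Computes the entropically penalized optimal transport distance between two (optionally weighted) samples using Sinkhorn iterations. When debias = TRUE, the debiased version of the distance is used.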
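As a reference point, here is one standard way to write the penalized problem. It is an assumption, not a transcription of causalOT's code. It assumes that penalty plays the role of \varepsilon, that the cost is the Euclidean distance raised to the power p, and that a and b are the sample weights, each summing to one.

```latex
C_{ij} = \lVert x_i - y_j \rVert^{p}, \qquad
\mathrm{OT}_\varepsilon(a, b) = \min_{P \ge 0,\; P\mathbf{1} = a,\; P^\top\mathbf{1} = b}
  \sum_{i,j} P_{ij} C_{ij} + \varepsilon\, \mathrm{KL}\!\left(P \,\middle\|\, a b^\top\right)

S_\varepsilon(a, b) = \mathrm{OT}_\varepsilon(a, b)
  - \tfrac{1}{2}\,\mathrm{OT}_\varepsilon(a, a)
  - \tfrac{1}{2}\,\mathrm{OT}_\varepsilon(b, b)
```

Under these assumptions, S_\varepsilon (the Sinkhorn divergence) corresponds to debias = TRUE. It subtracts each sample's transport cost against itself, and that self-cost is strictly positive whenever \varepsilon > 0.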
Usage
ot_distance(
  x1,
  x2 = NULL,
  a = NULL,
  b = NULL,
  penalty,
  p = 2,
  cost = NULL,
  debias = TRUE,
  online.cost = "auto",
  diameter = NULL,
  niter = 1000,
  tol = 1e-07
)
## S3 method for class 'causalWeights'
ot_distance(
  x1,
  x2 = NULL,
  a = NULL,
  b = NULL,
  penalty,
  p = 2,
  cost = NULL,
  debias = TRUE,
  online.cost = "auto",
  diameter = NULL,
  niter = 1000,
  tol = 1e-07
)
## S3 method for class 'matrix'
ot_distance(
  x1,
  x2,
  a = NULL,
  b = NULL,
  penalty,
  p = 2,
  cost = NULL,
  debias = TRUE,
  online.cost = "auto",
  diameter = NULL,
  niter = 1000,
  tol = 1e-07
)
## S3 method for class 'array'
ot_distance(
  x1,
  x2,
  a = NULL,
  b = NULL,
  penalty,
  p = 2,
  cost = NULL,
  debias = TRUE,
  online.cost = "auto",
  diameter = NULL,
  niter = 1000,
  tol = 1e-07
)
## S3 method for class 'torch_tensor'
ot_distance(
  x1,
  x2,
  a = NULL,
  b = NULL,
  penalty,
  p = 2,
  cost = NULL,
  debias = TRUE,
  online.cost = "auto",
  diameter = NULL,
  niter = 1000,
  tol = 1e-07
)
Arguments
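x1
Either an object of class causalWeights or a matrix of the covariates in the first sample.
x2
A matrix of the covariates in the second sample. Not needed for objects of class causalWeights (defaults to NULL).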
a
Empirical measure (weights) of the first sample. If NULL, each observation receives equal mass. Ignored for objects of class causalWeights.
b
Empirical measure (weights) of the second sample. If NULL, each observation receives equal mass. Ignored for objects of class causalWeights.
penalty
The penalty of the optimal transport distance. If missing or NULL, the function tries to guess a suitable value depending on whether debias is TRUE or FALSE.
p
The power of the cost function. Default is 2.
cost
Supply your own cost function. Should take arguments |
debias
TRUE or FALSE. Should the debiased optimal transport distance be used?
online.cost |
How to calculate the distance matrix. One of "auto", "tensorized", or "online". |
diameter |
The diameter of the metric space, if known. Default is NULL. |
niter
The maximum number of iterations for the Sinkhorn updates.
tol
The tolerance for convergence of the Sinkhorn updates.
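The following minimal base-R sketch uses log-domain Sinkhorn iterations to show how penalty, p, niter, tol, a, b and debias interact. It is not causalOT's implementation. The helper names cost_matrix(), logsumexp(), entropic_ot() and sinkhorn_divergence() are hypothetical. The sketch assumes that the cost is the Euclidean distance raised to the power p and that penalty is the entropic regularization strength. It also builds the full cost matrix in memory, which is roughly what a "tensorized" computation does.

```r
# Minimal sketch of entropic OT via log-domain Sinkhorn iterations.
# Hypothetical helpers: cost_matrix(), logsumexp(), entropic_ot(),
# sinkhorn_divergence(). Assumptions: cost = Euclidean distance^p,
# `penalty` = entropic regularization strength.

# Full n x m cost matrix ||x_i - y_j||^p, held in memory.
cost_matrix <- function(x, y, p = 2) {
  sq <- outer(rowSums(x^2), rowSums(y^2), "+") - 2 * tcrossprod(x, y)
  pmax(sq, 0)^(p / 2)  # clamp tiny negative round-off before the power
}

# Numerically stable log(sum(exp(v))).
logsumexp <- function(v) {
  m <- max(v)
  m + log(sum(exp(v - m)))
}

# Entropic OT value between samples x (weights a) and y (weights b).
entropic_ot <- function(x, y, a = NULL, b = NULL, penalty = 1, p = 2,
                        niter = 1000, tol = 1e-7) {
  n <- nrow(x)
  m <- nrow(y)
  if (is.null(a)) a <- rep(1 / n, n)  # equal mass when weights are NULL
  if (is.null(b)) b <- rep(1 / m, m)
  C <- cost_matrix(x, y, p)
  eps <- penalty
  f <- numeric(n)  # dual potential for the first sample
  g <- numeric(m)  # dual potential for the second sample
  for (it in seq_len(niter)) {
    # f_i = -eps * log sum_j b_j exp((g_j - C_ij) / eps)
    f_new <- -eps * apply(sweep(-C, 2, g, "+") / eps, 1,
                          function(r) logsumexp(r + log(b)))
    # g_j = -eps * log sum_i a_i exp((f_i - C_ij) / eps)
    g_new <- -eps * apply(sweep(-C, 1, f_new, "+") / eps, 2,
                          function(cl) logsumexp(cl + log(a)))
    converged <- max(abs(f_new - f), abs(g_new - g)) < tol
    f <- f_new
    g <- g_new
    if (converged) break  # stop early once potentials stop moving
  }
  # At the optimum the dual objective reduces to <a, f> + <b, g>.
  sum(a * f) + sum(b * g)
}

# Debiased version: subtract half of each sample's self-transport cost.
sinkhorn_divergence <- function(x, y, a = NULL, b = NULL,
                                penalty = 1, p = 2, ...) {
  entropic_ot(x, y, a, b, penalty, p, ...) -
    0.5 * entropic_ot(x, x, a, a, penalty, p, ...) -
    0.5 * entropic_ot(y, y, b, b, penalty, p, ...)
}

# Usage on simulated data: y is shifted away from x.
set.seed(1)
x <- matrix(rnorm(20 * 2), 20, 2)
y <- matrix(rnorm(15 * 2, mean = 1), 15, 2)
entropic_ot(x, y, penalty = 1)
sinkhorn_divergence(x, y, penalty = 1)
sinkhorn_divergence(x, x, penalty = 1)  # self-comparison cancels to zero
```

The penalized cost of a sample against itself, entropic_ot(x, x), is strictly positive, so the undebiased distance between identical samples is not zero. The debiased form cancels that self-cost, which is why the last call returns zero.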
Value
For objects of class matrix, a numeric value giving the optimal transport distance. For objects of class causalWeights, a list with the distances before adjustment ('pre') and after adjustment ('post').
Methods (by class)
- ot_distance(causalWeights): method for causalWeights class
- ot_distance(matrix): method for matrices
- ot_distance(array): method for arrays
- ot_distance(torch_tensor): method for torch_tensors
Examples
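if (torch::torch_is_installed()) {
  set.seed(2023)  # make the simulated data reproducible
  # simulate 10 observations of 5 covariates and a binary treatment indicator
  x <- matrix(stats::rnorm(10 * 5), 10, 5)
  z <- stats::rbinom(10, 1, 0.5)

  # estimate ATT weights with logistic regression
  weights <- calc_weight(x = x, z = z, method = "Logistic", estimand = "ATT")

  # distances from the causalWeights object, returned as 'pre' and 'post'
  ot1 <- ot_distance(x1 = weights, penalty = 100,
                     p = 2, debias = TRUE, online.cost = "auto",
                     diameter = NULL)

  # the adjusted distance from the covariate matrices directly,
  # with the control weights normalized to sum to one
  ot2 <- ot_distance(x1 = x[z == 0, ], x2 = x[z == 1, ],
                     a = weights@w0 / sum(weights@w0), b = weights@w1,
                     penalty = 100, p = 2, debias = TRUE,
                     online.cost = "auto", diameter = NULL)

  # the two approaches should agree
  all.equal(ot1$post, ot2)
}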