ot_distance {causalOT}    R Documentation

Optimal Transport Distance

Description

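Computes the optimal transport distance between two samples, with an optional penalty and debiasing. For objects of class causalWeights, the distance is computed both before and after adjustment.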

Usage

ot_distance(
  x1,
  x2 = NULL,
  a = NULL,
  b = NULL,
  penalty,
  p = 2,
  cost = NULL,
  debias = TRUE,
  online.cost = "auto",
  diameter = NULL,
  niter = 1000,
  tol = 1e-07
)

## S3 method for class 'causalWeights'
ot_distance(
  x1,
  x2 = NULL,
  a = NULL,
  b = NULL,
  penalty,
  p = 2,
  cost = NULL,
  debias = TRUE,
  online.cost = "auto",
  diameter = NULL,
  niter = 1000,
  tol = 1e-07
)

## S3 method for class 'matrix'
ot_distance(
  x1,
  x2,
  a = NULL,
  b = NULL,
  penalty,
  p = 2,
  cost = NULL,
  debias = TRUE,
  online.cost = "auto",
  diameter = NULL,
  niter = 1000,
  tol = 1e-07
)

## S3 method for class 'array'
ot_distance(
  x1,
  x2,
  a = NULL,
  b = NULL,
  penalty,
  p = 2,
  cost = NULL,
  debias = TRUE,
  online.cost = "auto",
  diameter = NULL,
  niter = 1000,
  tol = 1e-07
)

## S3 method for class 'torch_tensor'
ot_distance(
  x1,
  x2,
  a = NULL,
  b = NULL,
  penalty,
  p = 2,
  cost = NULL,
  debias = TRUE,
  online.cost = "auto",
  diameter = NULL,
  niter = 1000,
  tol = 1e-07
)

Arguments

x1

Either an object of class causalWeights or a matrix of the covariates in the first sample.

x2

NULL or a matrix of the covariates in the second sample.

a

Empirical measure of the first sample. If NULL, assumes each observation gets equal mass. Ignored for objects of class causalWeights.

b

Empirical measure of the second sample. If NULL, assumes each observation gets equal mass. Ignored for objects of class causalWeights.

penalty

The penalty of the optimal transport distance to use. If missing or NULL, the function will try to guess a suitable value depending on whether debias is TRUE or FALSE.

p

The power of the L_p distance metric. Default is 2.

cost

A user-supplied cost function. It should take arguments x1, x2, and p. Default is NULL.

debias

TRUE or FALSE. Should the debiased optimal transport distances be used? Default is TRUE.

online.cost

How to calculate the distance matrix. One of "auto", "tensorized", or "online".

diameter

The diameter of the metric space, if known. Default is NULL.

niter

The maximum number of iterations for the Sinkhorn updates. Default is 1000.

tol

The tolerance for convergence. Default is 1e-07.
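
To make the roles of a, b, penalty, cost, niter, and tol more concrete, the following is a minimal base R sketch of a plain (non-debiased) Sinkhorn loop. It is not the causalOT implementation: the names lp_cost and sinkhorn_sketch are hypothetical, the cost function only mirrors the (x1, x2, p) signature described above, and treating penalty as the entropic regularization strength is an assumption.

```r
# Hypothetical cost with the (x1, x2, p) signature: pairwise sum_k |x_ik - y_jk|^p.
# What ot_distance expects a custom cost to return is an assumption here.
lp_cost <- function(x1, x2, p = 2) {
  n <- nrow(x1)
  m <- nrow(x2)
  C <- matrix(0, n, m)
  for (i in seq_len(n)) {
    # subtract row i of x1 from every row of x2, then sum |difference|^p
    C[i, ] <- rowSums(abs(sweep(x2, 2, x1[i, ]))^p)
  }
  C
}

# Illustrative Sinkhorn iterations (no debiasing, no log-domain stabilization,
# so a very small penalty may underflow in exp()).
sinkhorn_sketch <- function(x1, x2, a = NULL, b = NULL, penalty,
                            p = 2, cost = NULL, niter = 1000, tol = 1e-7) {
  n <- nrow(x1)
  m <- nrow(x2)
  if (is.null(a)) a <- rep(1 / n, n)  # equal mass on each observation
  if (is.null(b)) b <- rep(1 / m, m)
  if (is.null(cost)) cost <- lp_cost
  C <- cost(x1, x2, p)
  K <- exp(-C / penalty)              # Gibbs kernel
  u <- rep(1, n)
  v <- rep(1, m)
  for (it in seq_len(niter)) {
    u_new <- a / as.vector(K %*% v)
    v <- b / as.vector(crossprod(K, u_new))
    converged <- max(abs(u_new - u)) < tol
    u <- u_new
    if (converged) break
  }
  P <- u * K * rep(v, each = n)       # transport plan diag(u) K diag(v)
  list(plan = P, cost = sum(P * C), iterations = it)
}

# Usage: two small samples with equal mass on each observation
set.seed(1)
x1 <- matrix(stats::rnorm(20), 5, 4)
x2 <- matrix(stats::rnorm(24), 6, 4)
res <- sinkhorn_sketch(x1, x2, penalty = 100)
```

In this sketch, niter caps the number of updates and tol stops the loop early once the scaling vector u changes by less than the tolerance between iterations.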

Value

For objects of class matrix, a numeric value giving the optimal transport distance. For objects of class causalWeights, a list with the distances before ('pre') and after ('post') adjustment.
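
For orientation, the quantity usually meant by a debiased, penalized optimal transport distance is the Sinkhorn divergence. The formulas below give that standard definition; whether causalOT uses exactly this form and scaling is an assumption, and the symbols \lambda (the penalty), c (the cost), and \Pi(a, b) (the set of couplings with marginals a and b) are notation introduced here.

```latex
\mathrm{OT}_{\lambda}(a, b)
  = \min_{P \in \Pi(a, b)} \sum_{i,j} P_{ij}\, c(x_i, y_j)
    + \lambda\, \mathrm{KL}\!\left(P \,\middle\|\, a \otimes b\right),
\qquad
S_{\lambda}(a, b)
  = \mathrm{OT}_{\lambda}(a, b)
    - \tfrac{1}{2}\, \mathrm{OT}_{\lambda}(a, a)
    - \tfrac{1}{2}\, \mathrm{OT}_{\lambda}(b, b).
```

Under this definition, debias = FALSE would correspond to \mathrm{OT}_{\lambda} and debias = TRUE to S_{\lambda}, which is zero when the two weighted samples coincide.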

Methods (by class)

Examples

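if (torch::torch_is_installed()) {
  # simulate 10 observations of 5 covariates and a binary treatment indicator
  x <- matrix(stats::rnorm(10 * 5), 10, 5)
  z <- stats::rbinom(10, 1, 0.5)
  weights <- calc_weight(x = x, z = z, method = "Logistic", estimand = "ATT")

  # causalWeights method: returns distances before ('pre') and after ('post') adjustment
  ot1 <- ot_distance(x1 = weights, penalty = 100,
                     p = 2, debias = TRUE, online.cost = "auto",
                     diameter = NULL)

  # matrix method: pass the covariates and weights of each group directly
  ot2 <- ot_distance(x1 = x[z == 0, ], x2 = x[z == 1, ],
                     a = weights@w0 / sum(weights@w0), b = weights@w1,
                     penalty = 100, p = 2, debias = TRUE,
                     online.cost = "auto", diameter = NULL)

  # the post-adjustment distance is expected to match the matrix method
  all.equal(ot1$post, ot2)
}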

[Package causalOT version 1.0.2 Index]