cotOptions {causalOT}    R Documentation

Options available for the COT method

Description

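Creates a list of options (class cotOptions) that controls how the COT method solves its optimal transport problem and selects its penalty parameters.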

Usage

cotOptions(
  lambda = NULL,
  delta = NULL,
  opt.direction = c("dual", "primal"),
  debias = TRUE,
  p = 2,
  cost.function = NULL,
  cost.online = "auto",
  diameter = NULL,
  balance.formula = NULL,
  quick.balance.function = TRUE,
  grid.length = 7L,
  torch.optimizer = torch::optim_rmsprop,
  torch.scheduler = torch::lr_multiplicative,
  niter = 2000,
  nboot = 100L,
  lambda.bootstrap = 0.05,
  tol = 1e-04,
  device = NULL,
  dtype = NULL,
  ...
)

Arguments

lambda

The penalty parameter for the entropy penalized optimal transport. Can be a single number or a vector of numbers to try. Default is NULL, in which case a grid of values is searched (see grid.length).

delta

The bound for the balancing functions, if they are used. Only available for biased (not debiased) entropy penalized optimal transport. Can be a single number or a vector of numbers to try. Default is NULL.

opt.direction

Should the optimizer solve the primal or the dual problem? One of "dual" or "primal"; the default is "dual" since it is typically faster.

debias

Should debiased optimal transport (Sinkhorn divergences) be used? TRUE or FALSE. Default is TRUE.

p

The power p used in the L_p^p cost function. Default is 2.

cost.function

A function to calculate the pairwise costs. Should take arguments x1, x2, and p. Default is NULL.

cost.online

Should an online cost algorithm be used? One of "auto" (the default), "online", or "tensorized", where "tensorized" is the offline option.

diameter

The diameter of the covariate space, if known. Default is NULL.

balance.formula

Formula for the balancing functions.

quick.balance.function

TRUE or FALSE denoting whether balance function constraints should be selected via a linear program (TRUE) or just checked for feasibility (FALSE). Default is TRUE.

grid.length

The number of penalty parameters to explore in a grid search if none are provided in arguments lambda or delta.

torch.optimizer

The torch optimizer to use for methods using debiased entropy penalized optimal transport. If debias is FALSE or opt.direction is "primal", torch::optim_lbfgs() is used by default. Otherwise torch::optim_rmsprop() is used.

torch.scheduler

The scheduler for the optimizer. Defaults to torch::lr_multiplicative().

niter

The number of iterations to run the solver.

nboot

The number of bootstrap iterations used to select the final penalty parameters.

lambda.bootstrap

The penalty parameter to use for the bootstrap hyperparameter selection of lambda.

tol

The tolerance for convergence.

device

An object of class torch_device denoting the device on which the data will be located. Default is NULL, which will try to use a GPU if one is available.

dtype

An object of class torch_dtype that determines the data type of the data, e.g., double, float, or integer. Default is NULL, which will try to select one for you.

...

Arguments passed to the solvers. See the Sinkhorn Divergences section below.

Value

A list of class cotOptions with the following slots

Solvers and distances

The function is set up to direct the COT optimizer to run one of two basic methods, chosen with the debias argument: debiased entropy penalized optimal transport (Sinkhorn divergences) or entropy penalized optimal transport (Sinkhorn distances).

Sinkhorn Distances

The optimal transport problem solved is \min_w OT_\lambda(w,b), where

OT_\lambda(w,b) = \min_P \sum_{ij} C(x_i, x_j) P_{ij} + \lambda \sum_{ij} P_{ij}\log(P_{ij}),

and the minimum is taken over nonnegative matrices P whose rows sum to w and whose columns sum to b. Here C(x_i, x_j) is the cost between units i and j.
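For intuition, the inner minimization over P can be solved with the classic Sinkhorn matrix-scaling iterations. The following is a minimal base-R sketch for illustration only, not the solver causalOT uses (the package works through torch). The helper names half_sq_cost(), sinkhorn_plan(), and ot_lambda() are hypothetical, w and b are assumed to be probability vectors, and the cost is the p = 2 cost from the Cost functions section.

```r
# Hypothetical helper: C(x_i, x_j) = 0.5 * ||x_i - x_j||_2^2 for every pair of rows
half_sq_cost <- function(x1, x2) {
  sq <- 0.5 * outer(rowSums(x1^2), rowSums(x2^2), "+") - tcrossprod(x1, x2)
  pmax(sq, 0)  # guard against tiny negative values from rounding
}

# Hypothetical Sinkhorn solver (sketch, not causalOT's code) for
#   min_P sum(C * P) + lambda * sum(P * log(P))
#   subject to rowSums(P) = w and colSums(P) = b
sinkhorn_plan <- function(C, w, b, lambda, niter = 1000, tol = 1e-9) {
  K <- exp(-C / lambda)                  # Gibbs kernel
  u <- rep(1, length(w))
  v <- rep(1, length(b))
  for (it in seq_len(niter)) {
    u <- w / as.vector(K %*% v)          # rescale so row sums match w
    v <- b / as.vector(crossprod(K, u))  # rescale so column sums match b
    P <- sweep(u * K, 2, v, "*")         # P = diag(u) K diag(v)
    if (max(abs(rowSums(P) - w)) < tol) break
  }
  P
}

# Objective value OT_lambda(w, b) at a given transport plan P
ot_lambda <- function(C, P, lambda) sum(C * P) + lambda * sum(P * log(P))

set.seed(1)
x0 <- matrix(rnorm(10), nrow = 5)  # 5 control units, 2 covariates
x1 <- matrix(rnorm(8), nrow = 4)   # 4 treated units, 2 covariates
w <- rep(1 / 5, 5)
b <- rep(1 / 4, 4)
C <- half_sq_cost(x0, x1)
P <- sinkhorn_plan(C, w, b, lambda = 1)
ot_lambda(C, P, lambda = 1)
```

Very small values of lambda relative to the costs make exp(-C / lambda) underflow in this naive version; log-domain updates are the usual remedy.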

Sinkhorn Divergences

The Sinkhorn Divergence solves

\min_w OT_\lambda(w,b) - \frac{1}{2} OT_\lambda(w,w) - \frac{1}{2} OT_\lambda(b,b).
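The debiasing terms can be illustrated by evaluating the objective above for fixed weights. This is a self-contained base-R sketch, not causalOT's torch solver; ent_ot() and sinkhorn_div() are hypothetical names, and the squared Euclidean cost with a factor of 1/2 is assumed. Because the two self terms subtract half of OT_\lambda(w,w) and half of OT_\lambda(b,b), the divergence between a distribution and itself is exactly zero, unlike the entropy penalized distance.

```r
# Hypothetical helper: entropy penalized OT value between weighted samples xa (weights a)
# and xb (weights b), using Sinkhorn scaling and cost 0.5 * ||x_i - x_j||^2
ent_ot <- function(xa, xb, a, b, lambda, niter = 1000, tol = 1e-9) {
  C <- pmax(0.5 * outer(rowSums(xa^2), rowSums(xb^2), "+") - tcrossprod(xa, xb), 0)
  K <- exp(-C / lambda)
  u <- rep(1, length(a))
  v <- rep(1, length(b))
  for (it in seq_len(niter)) {
    u <- a / as.vector(K %*% v)
    v <- b / as.vector(crossprod(K, u))
    P <- sweep(u * K, 2, v, "*")         # P = diag(u) K diag(v)
    if (max(abs(rowSums(P) - a)) < tol) break
  }
  sum(C * P) + lambda * sum(P * log(P))
}

# Hypothetical Sinkhorn divergence: OT(w,b) - 0.5 OT(w,w) - 0.5 OT(b,b)
sinkhorn_div <- function(x0, x1, w, b, lambda) {
  ent_ot(x0, x1, w, b, lambda) -
    0.5 * ent_ot(x0, x0, w, w, lambda) -
    0.5 * ent_ot(x1, x1, b, b, lambda)
}

set.seed(1)
x0 <- matrix(rnorm(10), nrow = 5)  # control units
x1 <- matrix(rnorm(8), nrow = 4)   # treated units
w <- rep(1 / 5, 5)
b <- rep(1 / 4, 4)
sinkhorn_div(x0, x1, w, b, lambda = 1)
```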

The solver for this problem uses the torch package in R and by default uses the optim_rmsprop optimizer. A different torch optimizer can be supplied via torch.optimizer and a scheduler via torch.scheduler. GPU support is available as described in the torch package documentation. Additional arguments in ... are passed as extra arguments to the torch optimizer and scheduler as appropriate.
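As a sketch of swapping the optimizer, assuming causalOT and a working torch installation are available: torch::optim_adam is a standard torch optimizer, but the assumption that lr is forwarded through ... to it rests only on the description above, and opts_adam is a hypothetical name.

```r
# Sketch: runs only if causalOT and a working torch backend are installed
if (requireNamespace("causalOT", quietly = TRUE) && torch::torch_is_installed()) {
  library(causalOT)
  # Replace the default RMSprop optimizer with Adam for the debiased dual problem;
  # `lr` is assumed to be passed on to the optimizer through `...`
  opts_adam <- cotOptions(
    lambda = 10,
    debias = TRUE,
    opt.direction = "dual",
    torch.optimizer = torch::optim_adam,
    lr = 0.01
  )
}
```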

Function balancing

There may be certain functions of the covariates that we wish to balance within some tolerance, \delta. For these functions B, we require

\left| \frac{\sum_{i: Z_i = 0} w_i B(x_i) - \sum_{j: Z_j = 1} B(x_j)/n_1}{\sigma} \right| \leq \delta,

where, in this case, we are targeting balance with the treatment group for the ATT, n_1 is the number of treated units, and \sigma is the pooled standard deviation prior to balancing.
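The constraint can be checked directly for a given set of control weights. In this base-R sketch, balance_check() is a hypothetical helper, not part of causalOT. It assumes the constraint is two-sided (absolute standardized difference), that the control weights sum to 1, and that the pooled standard deviation is sqrt((var_control + var_treated) / 2); the package may define \sigma differently.

```r
# Hypothetical helper: standardized differences for the ATT balance constraint
# B0: n0 x k matrix of balance functions evaluated on control units (Z = 0)
# B1: n1 x k matrix of the same functions evaluated on treated units (Z = 1)
# w:  control weights, assumed to sum to 1
balance_check <- function(B0, B1, w, delta) {
  # Assumption: pooled SD computed before weighting as sqrt((var0 + var1) / 2)
  sigma <- sqrt((apply(B0, 2, var) + apply(B1, 2, var)) / 2)
  smd <- (colSums(w * B0) - colMeans(B1)) / sigma  # w * B0 scales row i by w[i]
  list(smd = smd, ok = all(abs(smd) <= delta))
}

set.seed(2)
x0 <- rnorm(50)              # control covariate
x1 <- rnorm(30, mean = 0.5)  # treated covariate
B0 <- cbind(x = x0, x_sq = x0^2)
B1 <- cbind(x = x1, x_sq = x1^2)
w <- rep(1 / 50, 50)         # uniform control weights, before any balancing
balance_check(B0, B1, w, delta = 0.1)
```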

Cost functions

The cost function specifies the pairwise costs. If argument cost.function is NULL, the function defaults to L_p^p distances with p = 2 by default, as supplied by the argument p. For example, with p = 2 the cost between units x_i and x_j is

C(x_i, x_j) = \frac{1}{2} \| x_i - x_j \|_2^2.

If cost.function is provided, it should be a function that takes arguments x1, x2, and p: function(x1, x2, p){...}.
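A custom cost with the required (x1, x2, p) signature might look like the base-R sketch below. lp_cost() is a hypothetical name, and it assumes x1 and x2 arrive as numeric matrices with units in rows; causalOT may instead pass torch tensors, in which case the body would need torch operations. Halving the p = 2 result reproduces the default cost shown above.

```r
# Hypothetical custom cost: L_p^p distance between every row of x1 and every row of x2
# Assumption: x1 and x2 are numeric matrices (units in rows, covariates in columns)
lp_cost <- function(x1, x2, p) {
  cost <- vapply(seq_len(nrow(x2)), function(j) {
    rowSums(abs(sweep(x1, 2, x2[j, ], "-"))^p)  # costs from all x1 rows to x2[j, ]
  }, numeric(nrow(x1)))
  matrix(cost, nrow = nrow(x1), ncol = nrow(x2))  # keep matrix shape when nrow(x1) == 1
}

set.seed(3)
xc <- matrix(rnorm(6), nrow = 3)  # 3 control units
xt <- matrix(rnorm(4), nrow = 2)  # 2 treated units
C <- 0.5 * lp_cost(xc, xt, p = 2)  # matches the documented default cost for p = 2
C
```

Such a function would then be supplied as cotOptions(cost.function = lp_cost), subject to the tensor caveat above.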

Examples

if (torch::torch_is_installed()) {
  opts1 <- cotOptions(lambda = 1e3, torch.optimizer = torch::optim_rmsprop)
  opts2 <- cotOptions(lambda = NULL)
  opts3 <- cotOptions(lambda = seq(0.1, 100, length.out = 7))
}
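Two further option sets, sketched under the assumption that causalOT and a working torch backend are installed: one pins the computation to the CPU in double precision with the real torch helpers torch::torch_device() and torch::torch_double(), and one requests the non-debiased (Sinkhorn distance) problem in the primal, for which the torch.optimizer entry states that torch::optim_lbfgs() is used. The object names opts4 and opts5 are only illustrative.

```r
# Sketch: runs only if causalOT and a working torch backend are installed
if (requireNamespace("causalOT", quietly = TRUE) && torch::torch_is_installed()) {
  library(causalOT)
  # Keep the data on the CPU and use double precision
  opts4 <- cotOptions(lambda = 10,
                      device = torch::torch_device("cpu"),
                      dtype = torch::torch_double())
  # Non-debiased entropy penalized OT solved in the primal
  opts5 <- cotOptions(lambda = 10, debias = FALSE, opt.direction = "primal")
}
```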

[Package causalOT version 1.0.2 Index]