| cotOptions {causalOT} | R Documentation |
Options available for the COT method
Description
Options available for the COT method
Usage
cotOptions(
lambda = NULL,
delta = NULL,
opt.direction = c("dual", "primal"),
debias = TRUE,
p = 2,
cost.function = NULL,
cost.online = "auto",
diameter = NULL,
balance.formula = NULL,
quick.balance.function = TRUE,
grid.length = 7L,
torch.optimizer = torch::optim_rmsprop,
torch.scheduler = torch::lr_multiplicative,
niter = 2000,
nboot = 100L,
lambda.bootstrap = 0.05,
tol = 1e-04,
device = NULL,
dtype = NULL,
...
)
Arguments
lambda |
The penalty parameter for the entropy penalized optimal transport. Default is NULL. Can be a single number or a set of numbers to try. |
delta |
The bound for balancing functions if they are being used. Only available for biased entropy penalized optimal transport. Can be a single number or a set of numbers to try. |
opt.direction |
Should the optimizer solve the primal or the dual problem? Should be one of "dual" or "primal", with a default of "dual" since it is typically faster. |
debias |
Should debiased optimal transport be used? TRUE or FALSE. |
p |
The power p used in the default L_p^p cost function. Default is 2. |
cost.function |
A function to calculate the pairwise costs. Should take arguments x1, x2, and p. |
cost.online |
Should an online cost algorithm be used? One of "auto", "online", or "tensorized". "tensorized" is the offline option. |
diameter |
The diameter of the covariate space, if known. Default is NULL. |
balance.formula |
Formula for the balancing functions. |
quick.balance.function |
TRUE or FALSE denoting whether balance function constraints should be selected via a linear program (TRUE) or just checked for feasibility (FALSE). Default is TRUE. |
grid.length |
The number of penalty parameters to explore in a grid search if none are provided in arguments |
torch.optimizer |
The torch optimizer to use for methods using debiased entropy penalized optimal transport. If |
torch.scheduler |
The scheduler for the optimizer. Defaults to torch::lr_multiplicative. |
niter |
The number of iterations to run the solver |
nboot |
The number of iterations for the bootstrap to select the final penalty parameters. |
lambda.bootstrap |
The penalty parameter to use for the bootstrap hyperparameter selection of lambda. |
tol |
The tolerance for convergence |
device |
An object of class torch_device. |
dtype |
An object of class torch_dtype. |
... |
Arguments passed to the solvers. See details |
Value
A list of class cotOptions with the following slots:
- lambda: The penalty parameter for the optimal transport distance.
- delta: The constraint for the balancing functions.
- opt.direction: Whether to solve the primal or dual optimization problem.
- debias: TRUE or FALSE, indicating whether debiased optimal transport distances are used.
- balance.formula: The formula giving how to generate the balancing functions.
- quick.balance.function: TRUE or FALSE, indicating whether quick balance functions will be run.
- grid.length: The number of parameters to check in a grid search for the best parameters.
- p: The power of the cost function.
- cost.online: Whether online costs are used.
- cost.function: The user-supplied cost function, if one was supplied.
- diameter: The diameter of the covariate space.
- torch.optimizer: The torch optimizer used for Sinkhorn Divergences.
- torch.scheduler: The scheduler for the torch optimizer.
- solver.options: The arguments to be passed to the torch.optimizer.
- scheduler.options: The arguments to be passed to the torch.scheduler.
- osqp.options: Arguments passed to the osqp function if quick balance functions are used.
- niter: The number of iterations to run the solver.
- nboot: The number of bootstrap samples.
- lambda.bootstrap: The penalty parameter to use for the bootstrap hyperparameter selection.
- tol: The tolerance for convergence.
- device: An object of class torch_device.
- dtype: An object of class torch_dtype.
Solvers and distances
The function is set up to direct the COT optimizer to run one of two basic methods: debiased entropy penalized optimal transport (Sinkhorn Divergences) or entropy penalized optimal transport (Sinkhorn Distances).
Sinkhorn Distances
The optimal transport problem solved is \min_w OT_\lambda(w,b), where
OT_\lambda(w,b) = \min_P \sum_{ij} C(x_i, x_j) P_{ij} + \lambda \sum_{ij} P_{ij}\log(P_{ij}),
such that the rows of the matrix P sum to w and the columns sum to b. Here C(x_i, x_j) is the cost between units i and j.
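To make the objective concrete, the following base R sketch evaluates the penalized cost for one feasible transport plan. It does not call causalOT; the function name ot_objective, the toy data, and the choice of the independent coupling as the plan are all assumptions made for illustration.

```r
# Entropy-penalized OT objective evaluated at a given transport plan P.
# ot_objective is a hypothetical helper, not part of causalOT.
ot_objective <- function(cost, P, lambda) {
  # Treat 0 * log(0) as 0 so that sparse plans are handled.
  entropy_term <- sum(ifelse(P > 0, P * log(P), 0))
  sum(cost * P) + lambda * entropy_term
}

x0 <- c(0, 1, 2)      # control units (toy data)
x1 <- c(1, 3)         # treated units (toy data)
w  <- rep(1 / 3, 3)   # candidate weights on the controls
b  <- rep(1 / 2, 2)   # empirical weights on the treated units

# Cost 0.5 * (x_i - x_j)^2 between every control and treated unit.
cost <- 0.5 * outer(x0, x1, function(u, v) (u - v)^2)

# The independent coupling is always feasible: rows sum to w, columns to b.
P <- outer(w, b)
stopifnot(isTRUE(all.equal(rowSums(P), w)))
stopifnot(isTRUE(all.equal(colSums(P), b)))

value <- ot_objective(cost, P, lambda = 1)
```

Because this plan is feasible but not necessarily optimal, value is only an upper bound on OT_\lambda(w,b) for these weights.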
Sinkhorn Divergences
The Sinkhorn Divergence solves
\min_w OT_\lambda(w,b) - 0.5 OT_\lambda(w,w) - 0.5 OT_\lambda(b,b).
The solver for this function uses the torch package in R and by default will use the optim_rmsprop solver. Your desired torch optimizer can be passed via torch.optimizer with a scheduler passed via torch.scheduler. GPU support is available as detailed in the torch package. Additional arguments in ... are passed as extra arguments to the torch optimizer and schedulers as appropriate.
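The debiasing terms can be illustrated with a small base R sketch. It uses plain Sinkhorn iterations for fixed weights rather than the torch solver described above, and it does not optimize over w. The helper names (sq_cost, ot_lambda, sinkhorn_divergence) and the toy data are assumptions, not causalOT functions.

```r
# Pairwise cost 0.5 * ||x_i - y_j||^2 between the rows of x and y.
sq_cost <- function(x, y) {
  d2 <- outer(rowSums(x^2), rowSums(y^2), "+") - 2 * x %*% t(y)
  0.5 * pmax(d2, 0)
}

# Entropy-penalized OT value for fixed marginals a and b, computed with
# plain (non-log-domain) Sinkhorn iterations; suitable only for moderate lambda.
ot_lambda <- function(x, y, a, b, lambda, niter = 1000) {
  cost <- sq_cost(x, y)
  K <- exp(-cost / lambda)
  v <- rep(1, length(b))
  for (i in seq_len(niter)) {
    u <- a / as.vector(K %*% v)
    v <- b / as.vector(crossprod(K, u))
  }
  P <- outer(u, v) * K   # plan of the form diag(u) K diag(v)
  sum(cost * P) + lambda * sum(P * log(P))
}

# OT_lambda(w, b) - 0.5 * OT_lambda(w, w) - 0.5 * OT_lambda(b, b)
sinkhorn_divergence <- function(x0, x1, w, b, lambda) {
  ot_lambda(x0, x1, w, b, lambda) -
    0.5 * ot_lambda(x0, x0, w, w, lambda) -
    0.5 * ot_lambda(x1, x1, b, b, lambda)
}

x0 <- matrix(c(0, 1, 2), ncol = 1)   # control covariates (toy data)
x1 <- matrix(c(3, 4, 5), ncol = 1)   # treated covariates (toy data)
w  <- rep(1 / 3, 3)
b  <- rep(1 / 3, 3)

div_shifted <- sinkhorn_divergence(x0, x1, w, b, lambda = 1)
div_same    <- sinkhorn_divergence(x0, x0, w, w, lambda = 1)
```

The two self-transport terms make the divergence vanish when both samples and weights coincide (div_same), which the uncorrected Sinkhorn distance does not do.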
Function balancing
There may be certain functions of the covariates that we wish to balance within some tolerance, \delta. For these functions B, we require
\frac{\sum_{i: Z_i = 0} w_i B(x_i) - \sum_{j: Z_j = 1} B(x_j)/n_1}{\sigma} \leq \delta,
where in this case we are targeting balance with the treatment group for the ATT and \sigma is the pooled standard deviation prior to balancing.
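As a rough illustration of the balance constraint, the base R sketch below computes the standardized gap for a single balancing function. The helper balance_gap and the toy data are hypothetical, and the pooled standard deviation formula used here (the classical degrees-of-freedom weighted version) is an assumption; the package may define \sigma differently.

```r
# Standardized difference between the weighted control mean and the
# treated mean of one balancing function B. Hypothetical helper.
balance_gap <- function(Bx, z, w) {
  b0 <- Bx[z == 0]
  b1 <- Bx[z == 1]
  n0 <- length(b0)
  n1 <- length(b1)
  # ASSUMPTION: classical pooled standard deviation of the unweighted groups.
  sigma <- sqrt(((n0 - 1) * var(b0) + (n1 - 1) * var(b1)) / (n0 + n1 - 2))
  (sum(w * b0) - sum(b1) / n1) / sigma
}

z  <- c(0, 0, 0, 1, 1)      # treatment indicator (toy data)
Bx <- c(1, 2, 3, 2, 4)      # B(x_i) for each unit (toy data)
w  <- c(0.2, 0.3, 0.5)      # weights on the three control units

gap   <- balance_gap(Bx, z, w)
delta <- 0.1
# ASSUMPTION: a two-sided check; the formula above states the bound one-sided.
is_balanced <- abs(gap) <= delta
```

With these toy weights the gap is well outside the tolerance, so a solver would need to move weight toward the larger control values.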
Cost functions
The cost function specifies pairwise distances. If argument cost.function is NULL, the function will default to using L_p^p distances with a default p = 2 supplied by the argument p. So for p = 2, the cost between units x_i and x_j will be
C(x_i, x_j) = \frac{1}{2} \| x_i - x_j \|_2^2.
If cost.function is provided, it should be a function that takes arguments x1, x2, and p: function(x1, x2, p){...}.
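A user-supplied cost function with the signature described above might look like the following base R sketch. The name lp_cost is hypothetical, the cost omits the 1/2 factor shown for the default p = 2 case, and this page does not state whether the package passes plain matrices or torch tensors, so the matrix handling here is an assumption.

```r
# Hypothetical custom cost: pairwise L_p^p distance between the rows of
# x1 and x2, returned as an nrow(x1) by nrow(x2) matrix.
lp_cost <- function(x1, x2, p) {
  x1 <- as.matrix(x1)
  x2 <- as.matrix(x2)
  stopifnot(ncol(x1) == ncol(x2))
  cost <- matrix(0, nrow(x1), nrow(x2))
  # Accumulate |x1[i, k] - x2[j, k]|^p over the covariates k.
  for (k in seq_len(ncol(x1))) {
    cost <- cost + abs(outer(x1[, k], x2[, k], "-"))^p
  }
  cost
}

a <- rbind(c(0, 0), c(1, 1))   # two units, two covariates (toy data)
b <- rbind(c(3, 4))            # one unit (toy data)

cost_l2 <- lp_cost(a, b, p = 2)
cost_l1 <- lp_cost(a, b, p = 1)
```

Such a function would presumably be supplied as cotOptions(cost.function = lp_cost, p = 1), with p forwarded to it as the third argument.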
Examples
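if (torch::torch_is_installed()) {
  # A single fixed penalty with an explicitly chosen torch optimizer
  opts1 <- cotOptions(lambda = 1e3, torch.optimizer = torch::optim_rmsprop)
  # lambda = NULL: no penalty supplied, so a grid search is used (see grid.length)
  opts2 <- cotOptions(lambda = NULL)
  # A user-supplied set of penalty values to try
  opts3 <- cotOptions(lambda = seq(0.1, 100, length.out = 7))
}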