cotOptions {causalOT}    R Documentation
Options available for the COT method
Description
Options available for the COT method
Usage
cotOptions(
  lambda = NULL,
  delta = NULL,
  opt.direction = c("dual", "primal"),
  debias = TRUE,
  p = 2,
  cost.function = NULL,
  cost.online = "auto",
  diameter = NULL,
  balance.formula = NULL,
  quick.balance.function = TRUE,
  grid.length = 7L,
  torch.optimizer = torch::optim_rmsprop,
  torch.scheduler = torch::lr_multiplicative,
  niter = 2000,
  nboot = 100L,
  lambda.bootstrap = 0.05,
  tol = 1e-04,
  device = NULL,
  dtype = NULL,
  ...
)
Arguments
lambda
The penalty parameter for the entropy penalized optimal transport. Default is NULL. Can be a single number or a set of numbers to try.
delta
The bound for balancing functions if they are being used. Only available for biased entropy penalized optimal transport. Can be a single number or a set of numbers to try.
opt.direction
Should the optimizer solve the primal or the dual problem? Should be one of "dual" or "primal", with a default of "dual" since it is typically faster.
debias
Should debiased optimal transport be used? TRUE or FALSE.
p
The power of the cost function (the p in the L_p^p cost). Default is 2.
cost.function
A function to calculate the pairwise costs. Should take arguments x1, x2, and p.
cost.online
Should an online cost algorithm be used? One of "auto", "online", or "tensorized". "tensorized" is the offline option.
diameter
The diameter of the covariate space, if known. Default is NULL.
balance.formula
Formula for the balancing functions.
quick.balance.function
TRUE or FALSE denoting whether balance function constraints should be selected via a linear program (TRUE) or just checked for feasibility (FALSE). Default is TRUE.
grid.length
The number of penalty parameters to explore in a grid search if none are provided in the arguments.
torch.optimizer
The torch optimizer to use for methods using debiased entropy penalized optimal transport. If
torch.scheduler
The scheduler for the optimizer. Defaults to torch::lr_multiplicative.
niter
The number of iterations to run the solver.
nboot
The number of iterations for the bootstrap to select the final penalty parameters.
lambda.bootstrap
The penalty parameter to use for the bootstrap hyperparameter selection of lambda.
tol
The tolerance for convergence.
device
An object of class torch_device.
dtype
An object of class torch_dtype.
...
Arguments passed to the solvers. See the details below.
Value
A list of class cotOptions with the following slots:
- lambda: The penalty parameter for the optimal transport distance.
- delta: The constraint for the balancing functions.
- opt.direction: Whether to solve the primal or dual optimization problem.
- debias: TRUE or FALSE, whether debiased optimal transport distances are used.
- balance.formula: The formula giving how to generate the balancing functions.
- quick.balance.function: TRUE or FALSE, whether quick balance functions will be run.
- grid.length: The number of parameters to check in a grid search of best parameters.
- p: The power of the cost function.
- cost.online: Whether online costs are used.
- cost.function: The user-supplied cost function, if supplied.
- diameter: The diameter of the covariate space.
- torch.optimizer: The torch optimizer used for Sinkhorn Divergences.
- torch.scheduler: The scheduler for the torch optimizer.
- solver.options: The arguments to be passed to the torch.optimizer.
- scheduler.options: The arguments to be passed to the torch.scheduler.
- osqp.options: Arguments passed to the osqp function if quick balance functions are used.
- niter: The number of iterations to run the solver.
- nboot: The number of bootstrap samples.
- lambda.bootstrap: The penalty parameter to use for the bootstrap hyperparameter selection.
- tol: The tolerance for convergence.
- device: An object of class torch_device.
- dtype: An object of class torch_dtype.
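A quick way to see these slots is to build an options object and print its names. The sketch below is guarded the same way as the Examples section and relies only on what is stated above, namely that cotOptions returns a list of class cotOptions carrying these slots; it makes no assumption about slot order or about how each value is stored internally.

```r
# Sketch: inspect the slots of a cotOptions object.
# Requires the causalOT and torch packages, as in the Examples section.
if (requireNamespace("causalOT", quietly = TRUE) &&
    requireNamespace("torch", quietly = TRUE) &&
    torch::torch_is_installed()) {
  opts <- causalOT::cotOptions(lambda = c(0.1, 1, 10), p = 2)
  print(class(opts))   # should include "cotOptions"
  print(names(opts))   # the slot names listed above
  print(opts$lambda)   # the stored penalty parameter(s)
}
```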
Solvers and distances
The function is set up to direct the COT optimizer to run one of two basic methods: debiased entropy penalized optimal transport (Sinkhorn Divergences) or entropy penalized optimal transport (Sinkhorn Distances).
Sinkhorn Distances
The optimal transport problem solved is min_w OT_\lambda(w,b), where
OT_\lambda(w,b) = \sum_{ij} C(x_i, x_j) P_{ij} + \lambda \sum_{ij} P_{ij}\log(P_{ij}),
such that the rows of the matrix P_{ij} sum to w and the columns sum to b. In this case C(x_i, x_j) is the cost between units i and j.
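To make the objective concrete, here is a minimal base R sketch of the classic Sinkhorn scaling iterations for OT_\lambda(w,b) with fixed weights w and b. It illustrates the formula above and is not the solver causalOT uses; the helper name sinkhorn_ot is hypothetical. The optimal plan has the form P_{ij} = u_i \exp(-C_{ij}/\lambda) v_j, and u and v are rescaled in turn until the row sums match w and the column sums match b.

```r
# Illustrative sketch only (hypothetical helper, not causalOT's solver):
# entropy penalized OT between weights w (rows) and b (columns).
sinkhorn_ot <- function(C, w, b, lambda, niter = 1000, tol = 1e-9) {
  K <- exp(-C / lambda)                 # Gibbs kernel exp(-C_ij / lambda)
  u <- rep(1, length(w))
  v <- rep(1, length(b))
  for (it in seq_len(niter)) {
    u_old <- u
    u <- w / as.vector(K %*% v)         # rescale so row sums match w
    v <- b / as.vector(crossprod(K, u)) # rescale so column sums match b
    if (max(abs(u - u_old)) < tol) break
  }
  P <- outer(u, v) * K                  # P_ij = u_i K_ij v_j
  keep <- P > 0                         # skip 0 * log(0) terms
  list(P = P,
       value = sum(C * P) + lambda * sum(P[keep] * log(P[keep])))
}

# Usage: 3 control units and 2 treated units with 2 covariates each,
# using the default half squared Euclidean cost.
set.seed(1)
x0 <- matrix(rnorm(6), nrow = 3)
x1 <- matrix(rnorm(4), nrow = 2)
C <- 0.5 * outer(rowSums(x0^2), rowSums(x1^2), "+") - tcrossprod(x0, x1)
res <- sinkhorn_ot(C, w = rep(1/3, 3), b = rep(1/2, 2), lambda = 1)
rowSums(res$P)  # approximately w
colSums(res$P)  # approximately b
```

For small \lambda the kernel exp(-C/\lambda) can underflow to zero, which is why practical solvers usually work with the dual potentials in the log domain instead.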
Sinkhorn Divergences
The Sinkhorn Divergence solves
min_w OT_\lambda(w,b) - 0.5 OT_\lambda(w,w) - 0.5 OT_\lambda(b,b).
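The debiased objective can be evaluated for a fixed set of weights in the same spirit. The sketch below uses hypothetical helper names and plain base R rather than the package's torch solver; it runs Sinkhorn in the log domain, which stays stable for smaller \lambda, and assumes the default half squared Euclidean cost described under Cost functions. Because of the entropy term, OT_\lambda(w,w) is generally not zero; subtracting the two self-transport terms makes the divergence vanish when the two samples and their weights coincide.

```r
# Illustrative sketch only (hypothetical helpers, base R rather than torch):
# debiased Sinkhorn divergence with the half squared Euclidean cost.
log_sum_exp <- function(x) {
  m <- max(x)
  m + log(sum(exp(x - m)))
}

# C_ij = 0.5 * ||x_i - y_j||_2^2; pmax() clips tiny negative rounding errors.
half_sq_cost <- function(X, Y) {
  pmax(0.5 * outer(rowSums(X^2), rowSums(Y^2), "+") - tcrossprod(X, Y), 0)
}

# OT_lambda(a, b) via log-domain Sinkhorn on dual potentials f and g,
# with plan P_ij = exp((f_i + g_j - C_ij) / lambda).
entropic_ot <- function(C, a, b, lambda, niter = 500) {
  f <- numeric(length(a))
  g <- numeric(length(b))
  for (it in seq_len(niter)) {
    f <- lambda * (log(a) - apply(sweep(-C, 2, g, "+") / lambda, 1, log_sum_exp))
    g <- lambda * (log(b) - apply(sweep(-C, 1, f, "+") / lambda, 2, log_sum_exp))
  }
  P <- exp((outer(f, g, "+") - C) / lambda)
  keep <- P > 0
  sum(C * P) + lambda * sum(P[keep] * log(P[keep]))
}

# OT(w, b) - 0.5 OT(w, w) - 0.5 OT(b, b) for control x0 and treated x1.
sinkhorn_divergence <- function(x0, x1, w, b, lambda) {
  entropic_ot(half_sq_cost(x0, x1), w, b, lambda) -
    0.5 * entropic_ot(half_sq_cost(x0, x0), w, w, lambda) -
    0.5 * entropic_ot(half_sq_cost(x1, x1), b, b, lambda)
}

set.seed(2)
x0 <- matrix(rnorm(20), nrow = 10)
x1 <- matrix(rnorm(10, mean = 1), nrow = 5)
sinkhorn_divergence(x0, x1, rep(1/10, 10), rep(1/5, 5), lambda = 1)
```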
The solver for this problem uses the torch package in R and by default uses the optim_rmsprop optimizer. A different torch optimizer can be passed via torch.optimizer, with a scheduler passed via torch.scheduler. GPU support is available as detailed in the torch package. Additional arguments in ... are passed as extra arguments to the torch optimizer and schedulers as appropriate.
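As a sketch of how these arguments fit together, the call below swaps in torch::optim_adam, keeps the default scheduler, and picks a CUDA device only when torch reports one. It assumes the causalOT and torch packages are installed; the values of lambda, niter and tol are arbitrary illustrative choices, and no extra optimizer arguments are passed through ... here.

```r
# Sketch: choosing the torch optimizer, scheduler, device and dtype for the
# Sinkhorn Divergence solver. Requires the causalOT and torch packages.
if (requireNamespace("causalOT", quietly = TRUE) &&
    requireNamespace("torch", quietly = TRUE) &&
    torch::torch_is_installed()) {
  # Use a GPU if torch can see one, otherwise fall back to the CPU.
  dev <- if (torch::cuda_is_available()) {
    torch::torch_device("cuda")
  } else {
    torch::torch_device("cpu")
  }
  opts_gpu <- causalOT::cotOptions(
    lambda = 10,
    debias = TRUE,
    torch.optimizer = torch::optim_adam,        # Adam instead of RMSprop
    torch.scheduler = torch::lr_multiplicative, # the documented default
    niter = 1000,
    tol = 1e-5,
    device = dev,
    dtype = torch::torch_double()
  )
}
```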
Function balancing
There may be certain functions of the covariates that we wish to balance within some tolerance, \delta. For these functions B, we will desire
\frac{\sum_{i: Z_i = 0} w_i B(x_i) - \sum_{j: Z_j = 1} B(x_j)/n_1}{\sigma} \leq \delta,
where in this case we are targeting balance with the treatment group for the ATT, and n_1 is the number of treated units. Here \sigma is the pooled standard deviation prior to balancing.
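The constraint can be checked directly for a given set of weights. In this base R sketch (hypothetical function balance_check), each column of B0 and B1 holds one balancing function evaluated on the control and treated units, and the control weights w are assumed to sum to 1. Two details are assumptions rather than something stated above: the pooled standard deviation is taken as sqrt((var_0 + var_1)/2), and the bound is checked in absolute value so that imbalance in either direction counts.

```r
# Illustrative sketch (hypothetical helper, not causalOT code).
# B0: balancing functions on control units (one column per function).
# B1: the same functions on treated units. w: control weights summing to 1.
balance_check <- function(B0, B1, w, delta) {
  B0 <- as.matrix(B0)
  B1 <- as.matrix(B1)
  # Weighted control mean minus treated mean (ATT target), per function.
  gap <- colSums(w * B0) - colMeans(B1)
  # Pooled SD before weighting; sqrt((var0 + var1) / 2) is an assumption.
  sigma <- sqrt((apply(B0, 2, var) + apply(B1, 2, var)) / 2)
  std_diff <- gap / sigma
  # Two-sided check (assumption): |standardized difference| <= delta.
  list(std_diff = std_diff, ok = abs(std_diff) <= delta)
}

B0 <- cbind(c(0, 1, 2, 3))   # one balancing function, 4 control units
B1 <- cbind(c(1, 2))         # same function, 2 treated units
balance_check(B0, B1, w = rep(0.25, 4), delta = 0.1)$ok   # TRUE
balance_check(B0, B1, w = c(1, 0, 0, 0), delta = 0.1)$ok  # FALSE
```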
Cost functions
The cost function specifies pairwise distances. If argument cost.function is NULL, the function will default to using L_p^p distances with a default p = 2 supplied by the argument p. So for p = 2, the cost between units x_i and x_j will be
C(x_i, x_j) = \frac{1}{2} \| x_i - x_j \|_2^2.
If cost.function is provided, it should be a function that takes arguments x1, x2, and p: function(x1, x2, p){...}.
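A user-supplied cost.function might look like the sketch below, which generalizes the default to (1/p) \sum_k |x_{ik} - x_{jk}|^p. Both the 1/p scaling (chosen to match the 1/2 factor shown for p = 2) and the assumption that x1 and x2 arrive as numeric matrices with one unit per row are illustrative; if the package passes torch tensors instead, the arithmetic would need to use torch operations. The name lp_cost is hypothetical.

```r
# Hypothetical user-supplied cost with the function(x1, x2, p) signature.
# Assumes x1 (n1 x d) and x2 (n2 x d) are numeric matrices, one unit per row.
lp_cost <- function(x1, x2, p) {
  x1 <- as.matrix(x1)
  x2 <- as.matrix(x2)
  C <- matrix(0, nrow(x1), nrow(x2))
  for (k in seq_len(ncol(x1))) {
    # Add |x1[i, k] - x2[j, k]|^p for covariate k to every (i, j) pair.
    C <- C + abs(outer(x1[, k], x2[, k], "-"))^p
  }
  C / p  # 1/p scaling, matching the 1/2 factor shown for p = 2
}

x1 <- rbind(c(0, 0))
x2 <- rbind(c(3, 4), c(0, 0))
lp_cost(x1, x2, p = 2)  # 1 x 2 matrix: 12.5 and 0
```

Under those assumptions it would be supplied as cotOptions(cost.function = lp_cost, p = 2).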
Examples
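if (torch::torch_is_installed()) {
  # A single, fixed penalty parameter with the RMSprop optimizer
  opts1 <- cotOptions(lambda = 1e3, torch.optimizer = torch::optim_rmsprop)
  # lambda = NULL: penalty parameters are searched over a grid of grid.length values
  opts2 <- cotOptions(lambda = NULL)
  # A user-specified grid of 7 candidate penalty parameters
  opts3 <- cotOptions(lambda = seq(0.1, 100, length.out = 7))
}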