| OTProblem {causalOT} | R Documentation |
Object Oriented OT Problem
Description
Object Oriented OT Problem
Usage
OTProblem(measure_1, measure_2, ...)
Arguments
measure_1: An object of class Measure
measure_2: An object of class Measure
...: Not used at this time
Value
An R6 object of class "OTProblem"
Public fields
device: The torch::torch_device() of the data.
dtype: The torch::torch_dtype of the data.
selected_delta: The delta value selected after choose_hyperparameters
selected_lambda: The lambda value selected after choose_hyperparameters
Active bindings
loss: Prints the current value of the objective. Only available after the OTProblem$solve() method has been run.
penalty: Returns a list of the lambda and delta penalties that will be iterated through. To set these values, use the OTProblem$setup_arguments() function.
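A minimal sketch of reading the fields and bindings described above, reusing the toy measures from the Examples section. It assumes that a single OTProblem built from one fixed and one weight-adapted measure can be set up and solved on its own, which this page does not state explicitly. The block is guarded so that it does nothing when causalOT or a torch backend is unavailable.

```r
# Guard: run only when causalOT and a torch backend are available.
have_pkgs <- requireNamespace("causalOT", quietly = TRUE) &&
  requireNamespace("torch", quietly = TRUE) &&
  torch::torch_is_installed()

if (have_pkgs) {
  # Toy measures copied from the Examples section.
  m1 <- causalOT::Measure(x = matrix(1, 100, 10))
  m2 <- causalOT::Measure(x = matrix(2, 100, 10), adapt = "weights")
  ot <- causalOT::OTProblem(m1, m2)

  # Public fields: where the data live and how they are stored.
  print(ot$device)
  print(ot$dtype)

  # The penalty binding lists the lambda and delta values to iterate over.
  ot$setup_arguments(lambda = 1000)
  print(ot$penalty)

  # The loss binding is documented as available only after solve() has run.
  ot$solve(niter = 1, torch_optim = torch::optim_rmsprop)
  print(ot$loss)
}
```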
Methods
Public methods
Method add()
adds o2 to the OTProblem
Usage
OTProblem$add(o2)
Arguments
o2: A number or object of class OTProblem
Method subtract()
subtracts o2 from OTProblem
Usage
OTProblem$subtract(o2)
Arguments
o2: A number or object of class OTProblem
Method multiply()
multiplies OTProblem by o2
Usage
OTProblem$multiply(o2)
Arguments
o2: A number or an object of class OTProblem
Method divide()
divides OTProblem by o2
Usage
OTProblem$divide(o2)
Arguments
o2: A number or object of class OTProblem
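The Examples section combines problems with ordinary arithmetic operators (0.5 * ot1 + 0.5 * ot2). The sketch below shows that operator form next to a direct call to one of the four methods above. How the operators relate to the methods, and whether a method returns a new object or modifies the problem in place, is not stated on this page, so the code only prints the class of whatever comes back rather than relying on it.

```r
# Guard: run only when causalOT and a torch backend are available.
have_pkgs <- requireNamespace("causalOT", quietly = TRUE) &&
  requireNamespace("torch", quietly = TRUE) &&
  torch::torch_is_installed()

if (have_pkgs) {
  library(causalOT)

  # Toy measures copied from the Examples section.
  m1 <- Measure(x = matrix(1, 100, 10))
  m2 <- Measure(x = matrix(2, 100, 10), adapt = "weights")
  m3 <- Measure(x = matrix(3, 102, 10))

  ot1 <- OTProblem(m1, m2)
  ot2 <- OTProblem(m3, m2)

  # Operator form, exactly as written in the Examples section.
  ot_ops <- 0.5 * ot1 + 0.5 * ot2
  print(ot_ops)

  # Method form: call multiply() directly with a number.
  # Assumption: the return value is inspected, not relied upon.
  scaled <- ot1$multiply(0.5)
  print(class(scaled))
}
```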
Method setup_arguments()
Usage
OTProblem$setup_arguments(
  lambda,
  delta,
  grid.length = 7L,
  cost.function = NULL,
  p = 2,
  cost.online = "auto",
  debias = TRUE,
  diameter = NULL,
  ot_niter = 1000L,
  ot_tol = 0.001
)
Arguments
lambda: The penalty parameters to try for the OT problems. If not provided, the function will select some values.
delta: The constraint parameters to try for the balance function problems, if any.
grid.length: The number of hyperparameters to try if none are provided.
cost.function: The cost function for the data. Can be any function that takes arguments x1, x2, p. Defaults to the Euclidean distance.
p: The power to raise the cost matrix by. Default is 2.
cost.online: Should online costs be used? Default is "auto"; "tensorized" stores the cost matrix in memory, while "online" calculates it on the fly.
debias: Should debiased OT problems be used? Defaults to TRUE.
diameter: Diameter of the cost function.
ot_niter: Number of iterations to run the OT problems.
ot_tol: The tolerance for convergence of the OT problems.
Returns
NULL
Examples
ot$setup_arguments(lambda = c(1000,10))
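To illustrate the shape of a cost.function that takes arguments x1, x2 and p, here is a base R sketch of a pairwise Euclidean distance raised to the power p. The name lp_cost is hypothetical, and the sketch assumes plain numeric matrices with one observation per row. This page does not say whether the package passes matrices or torch tensors to the function, so treat this as an illustration of the signature rather than a drop-in argument.

```r
# Hypothetical cost function with the (x1, x2, p) signature.
# Assumes x1 and x2 are numeric matrices with one observation per row.
lp_cost <- function(x1, x2, p = 2) {
  # Squared Euclidean distances via |a|^2 + |b|^2 - 2 a.b
  sq <- outer(rowSums(x1^2), rowSums(x2^2), "+") - 2 * tcrossprod(x1, x2)
  # Clamp tiny negative values caused by floating-point rounding.
  sq <- pmax(sq, 0)
  # Euclidean distance raised to the power p.
  sqrt(sq)^p
}

x <- matrix(1, 100, 10)
y <- matrix(2, 100, 10)
cost <- lp_cost(x, y, p = 2)
dim(cost)  # 100 100
```

Every row of x differs from every row of y by 1 in each of the 10 coordinates, so with p = 2 each entry of the cost matrix is 10 up to floating-point error.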
Method solve()
Solve the OTProblem at each parameter value. Must run setup_arguments first.
Usage
OTProblem$solve(
  niter = 1000L,
  tol = 1e-05,
  optimizer = c("torch", "frank-wolfe"),
  torch_optim = torch::optim_lbfgs,
  torch_scheduler = torch::lr_reduce_on_plateau,
  torch_args = NULL,
  osqp_args = NULL,
  quick.balance.function = TRUE
)
Arguments
niter: The number of iterations to run the solver at each combination of hyperparameter values.
tol: The tolerance for convergence.
optimizer: The optimizer to use. One of "torch" or "frank-wolfe".
torch_optim: The torch_optimizer to use. Default is torch::optim_lbfgs.
torch_scheduler: The torch::lr_scheduler to use. Default is torch::lr_reduce_on_plateau.
torch_args: Arguments passed to the torch optimizer and scheduler.
osqp_args: Arguments passed to osqp::osqpSettings() if appropriate.
quick.balance.function: Should osqp::osqp() be used to select balance function constraints (delta) or not. Default is TRUE.
Examples
ot$solve(niter = 1, torch_optim = torch::optim_rmsprop)
Method choose_hyperparameters()
Selects the hyperparameter values through a bootstrap algorithm
Usage
OTProblem$choose_hyperparameters(
  n_boot_lambda = 100L,
  n_boot_delta = 1000L,
  lambda_bootstrap = Inf
)
Arguments
n_boot_lambda: The number of bootstrap iterations to run when selecting lambda.
n_boot_delta: The number of bootstrap iterations to run when selecting delta.
lambda_bootstrap: The penalty parameter to use when selecting lambda. Higher numbers run faster.
Examples
ot$choose_hyperparameters(n_boot_lambda = 10,
n_boot_delta = 10,
lambda_bootstrap = Inf)
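The sketch below follows the sequence from the Examples section (setup, solve, choose) and then reads the selected_lambda and selected_delta fields described under Public fields. The single bootstrap iteration keeps the run short and is not a sensible setting for real use. The block is guarded so that it does nothing when causalOT or a torch backend is unavailable.

```r
# Guard: run only when causalOT and a torch backend are available.
have_pkgs <- requireNamespace("causalOT", quietly = TRUE) &&
  requireNamespace("torch", quietly = TRUE) &&
  torch::torch_is_installed()

if (have_pkgs) {
  library(causalOT)

  # Problem setup copied from the Examples section.
  m1 <- Measure(x = matrix(1, 100, 10))
  m2 <- Measure(x = matrix(2, 100, 10), adapt = "weights")
  m3 <- Measure(x = matrix(3, 102, 10))
  ot <- 0.5 * OTProblem(m1, m2) + 0.5 * OTProblem(m3, m2)

  ot$setup_arguments(lambda = 1000)
  ot$solve(niter = 1, torch_optim = torch::optim_rmsprop)
  ot$choose_hyperparameters(n_boot_lambda = 1,
                            n_boot_delta = 1,
                            lambda_bootstrap = Inf)

  # Fields filled in by choose_hyperparameters().
  print(ot$selected_lambda)
  print(ot$selected_delta)
}
```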
Method info()
Provides diagnostics after the solve and choose_hyperparameters methods have been run.
Usage
OTProblem$info()
Returns
A list with slots
- loss: The final loss values
- iterations: The number of iterations run for each combination of parameters
- balance.function.differences: The final differences in the balance functions
- hyperparam.metrics: A list of the bootstrap evaluation for delta and lambda values
Examples
ot$info()
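As a sketch of how the returned list might be inspected, the block below runs the sequence from the Examples section and then looks at the slots named above. The variable name diagnostics is arbitrary, and the structure inside each slot is not described on this page, so the code only prints them.

```r
# Guard: run only when causalOT and a torch backend are available.
have_pkgs <- requireNamespace("causalOT", quietly = TRUE) &&
  requireNamespace("torch", quietly = TRUE) &&
  torch::torch_is_installed()

if (have_pkgs) {
  library(causalOT)

  # Problem setup copied from the Examples section.
  m1 <- Measure(x = matrix(1, 100, 10))
  m2 <- Measure(x = matrix(2, 100, 10), adapt = "weights")
  m3 <- Measure(x = matrix(3, 102, 10))
  ot <- 0.5 * OTProblem(m1, m2) + 0.5 * OTProblem(m3, m2)

  ot$setup_arguments(lambda = 1000)
  ot$solve(niter = 1, torch_optim = torch::optim_rmsprop)
  ot$choose_hyperparameters(n_boot_lambda = 1,
                            n_boot_delta = 1,
                            lambda_bootstrap = Inf)

  # Collect the diagnostics and inspect the documented slots.
  diagnostics <- ot$info()
  print(names(diagnostics))
  print(diagnostics$loss)
  print(diagnostics$iterations)
}
```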
Method clone()
The objects of this class are cloneable with this method.
Usage
OTProblem$clone(deep = FALSE)
Arguments
deep: Whether to make a deep clone.
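The difference between a shallow and a deep clone is general R6 behavior, illustrated here with two hypothetical toy classes (Inner and Outer) rather than with OTProblem itself. A shallow clone shares any R6 objects stored in its fields with the original, while a deep clone copies them. How the torch-backed contents of an OTProblem behave under a deep clone is not covered on this page.

```r
# Toy illustration of R6 cloning; Inner and Outer are hypothetical classes.
clone_demo <- NULL

if (requireNamespace("R6", quietly = TRUE)) {
  Inner <- R6::R6Class("Inner", public = list(value = 1))
  Outer <- R6::R6Class("Outer", public = list(
    inner = NULL,
    initialize = function() {
      self$inner <- Inner$new()
    }
  ))

  original <- Outer$new()
  shallow  <- original$clone()             # shares the Inner object
  deep     <- original$clone(deep = TRUE)  # gets its own Inner object

  # Modify the original after cloning.
  original$inner$value <- 2

  clone_demo <- list(shallow = shallow$inner$value,  # follows the original: 2
                     deep    = deep$inner$value)     # unaffected: 1
  print(clone_demo)
}
```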
Examples
## ------------------------------------------------
## Method `OTProblem(measure_1, measure_2)`
## ------------------------------------------------
if (torch::torch_is_installed()) {
  # setup measures
  x <- matrix(1, 100, 10)
  m1 <- Measure(x = x)
  y <- matrix(2, 100, 10)
  m2 <- Measure(x = y, adapt = "weights")
  z <- matrix(3, 102, 10)
  m3 <- Measure(x = z)
  # setup OT problems
  ot1 <- OTProblem(m1, m2)
  ot2 <- OTProblem(m3, m2)
  ot <- 0.5 * ot1 + 0.5 * ot2
  print(ot)
  ## ------------------------------------------------
  ## Method `OTProblem$setup_arguments`
  ## ------------------------------------------------
  ot$setup_arguments(lambda = 1000)
  ## ------------------------------------------------
  ## Method `OTProblem$solve`
  ## ------------------------------------------------
  ot$solve(niter = 1, torch_optim = torch::optim_rmsprop)
  ## ------------------------------------------------
  ## Method `OTProblem$choose_hyperparameters`
  ## ------------------------------------------------
  ot$choose_hyperparameters(n_boot_lambda = 1,
                            n_boot_delta = 1,
                            lambda_bootstrap = Inf)
  ## ------------------------------------------------
  ## Method `OTProblem$info`
  ## ------------------------------------------------
  ot$info()
}