OTProblem {causalOT}    R Documentation
Object Oriented OT Problem
Description
An object-oriented optimal transport (OT) problem between two measures.
Usage
OTProblem(measure_1, measure_2, ...)
Arguments
measure_1
An object of class Measure
measure_2
An object of class Measure
...
Not used at this time
Value
An R6 object of class "OTProblem"
Public fields
device
The torch::torch_device() of the data.
dtype
The torch::torch_dtype of the data.
selected_delta
The delta value selected after running choose_hyperparameters().
selected_lambda
The lambda value selected after running choose_hyperparameters().
Active bindings
loss
Prints the current value of the objective. Only available after the OTProblem$solve() method has been run.
penalty
Returns a list of the lambda and delta penalties that will be iterated through. To set these values, use the OTProblem$setup_arguments() function.
Methods
Public methods
Method add()
Adds o2 to the OTProblem.
Usage
OTProblem$add(o2)
Arguments
o2
A number or object of class OTProblem
Method subtract()
Subtracts o2 from the OTProblem.
Usage
OTProblem$subtract(o2)
Arguments
o2
A number or object of class OTProblem
Method multiply()
Multiplies the OTProblem by o2.
Usage
OTProblem$multiply(o2)
Arguments
o2
A number or an object of class OTProblem
Method divide()
Divides the OTProblem by o2.
Usage
OTProblem$divide(o2)
Arguments
o2
A number or object of class OTProblem
Method setup_arguments()
Sets the hyperparameter values (lambda, delta) and the options used to solve the OT problems.
Usage
OTProblem$setup_arguments( lambda, delta, grid.length = 7L, cost.function = NULL, p = 2, cost.online = "auto", debias = TRUE, diameter = NULL, ot_niter = 1000L, ot_tol = 0.001 )
Arguments
lambda
The penalty parameters to try for the OT problems. If not provided, the function will select values automatically.
delta
The constraint parameters to try for the balance function problems, if any.
grid.length
The number of hyperparameter values to try when lambda or delta is not provided.
cost.function
The cost function for the data. Can be any function that takes the arguments x1, x2, and p. Defaults to the Euclidean distance.
p
The power to raise the cost matrix by. Default is 2.
cost.online
Should online costs be used? One of "auto" (the default), "tensorized", or "online". "tensorized" stores the cost matrix in memory, while "online" calculates it on the fly.
debias
Should debiased OT problems be used? Defaults to TRUE
diameter
Diameter of the cost function.
ot_niter
Number of iterations to run the OT problems
ot_tol
The tolerance for convergence of the OT problems
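The cost.function argument accepts any function of x1, x2, and p. The sketch below shows one way to write a Euclidean cost raised to the power p in base R. It assumes x1 and x2 are plain numeric matrices with observations in rows; inside causalOT the data may be held as torch tensors, so a real cost.function may need tensor operations instead. The name euclidean_cost is hypothetical.

```r
# Hypothetical cost function with the (x1, x2, p) signature described above.
# Assumes x1 (n x d) and x2 (m x d) are base-R numeric matrices.
euclidean_cost <- function(x1, x2, p = 2) {
  # Squared Euclidean distances via ||a||^2 + ||b||^2 - 2 <a, b>
  sq <- outer(rowSums(x1^2), rowSums(x2^2), "+") - 2 * tcrossprod(x1, x2)
  sq <- pmax(sq, 0)  # guard against tiny negative values from rounding
  # Euclidean distance raised to the power p, returned as an n x m matrix
  sq^(p / 2)
}

x1 <- matrix(c(0, 0,
               3, 4), nrow = 2, byrow = TRUE)
x2 <- matrix(c(0, 0), nrow = 1)
euclidean_cost(x1, x2, p = 2)  # 2 x 1 matrix with entries 0 and 25
```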
Returns
NULL
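With cost.online = "tensorized" the full cost matrix is kept in memory, one entry per pair of observations, so its size grows with n × m. A back-of-the-envelope helper makes this concrete. The name cost_matrix_bytes is hypothetical, and the bytes per entry (8 for double precision, 4 for single precision) depend on the dtype actually used, which is an assumption here.

```r
# Hypothetical helper: bytes needed to store an n x m cost matrix in memory.
# bytes_per_entry: 8 for double precision, 4 for single precision (assumed).
cost_matrix_bytes <- function(n, m, bytes_per_entry = 8) {
  n * m * bytes_per_entry
}

cost_matrix_bytes(102, 100)         # 81600 bytes, e.g. m3 vs m2 in the examples
cost_matrix_bytes(1e5, 1e5) / 2^30  # about 74.5 GiB for 100,000 x 100,000
```

For large samples this is why computing costs on the fly ("online") can be preferable, at the price of recomputing them.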
Examples
ot$setup_arguments(lambda = c(1000,10))
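To make the roles of lambda, debias, ot_niter, and ot_tol more concrete, here is a small base-R sketch of entropic OT solved with Sinkhorn scaling, plus the usual debiased (Sinkhorn divergence) correction. It assumes lambda acts as the entropic penalty and that debias = TRUE subtracts the two self-transport terms; the package's own solver works on torch tensors and may differ in scaling and stopping rules. The functions entropic_ot, pow_cost, and debiased_ot are hypothetical.

```r
# Hypothetical sketch: entropic OT solved by Sinkhorn scaling (base R only).
# Assumption: lambda is the entropic penalty; niter and tol play the roles of
# ot_niter and ot_tol. This is NOT the package's solver.
entropic_ot <- function(a, b, C, lambda, niter = 1000L, tol = 1e-3) {
  K <- exp(-C / lambda)                     # Gibbs kernel
  u <- rep(1, length(a))
  v <- rep(1, length(b))
  for (i in seq_len(niter)) {
    u <- a / as.vector(K %*% v)             # fit row marginals
    v <- b / as.vector(crossprod(K, u))     # fit column marginals
    row_err <- sum(abs(u * as.vector(K %*% v) - a))
    if (row_err < tol) break                # stop once marginals match
  }
  P <- sweep(u * K, 2, v, "*")              # plan = diag(u) K diag(v)
  # Primal value <P, C> + lambda * KL(P || a b^T), using 0 * log(0) = 0
  kl_terms <- ifelse(P > 0, P * log(P / outer(a, b)), 0)
  sum(P * C) + lambda * sum(kl_terms)
}

# Cost ||x1_i - x2_j||^p for numeric matrices with observations in rows
pow_cost <- function(x1, x2, p = 2) {
  sq <- outer(rowSums(x1^2), rowSums(x2^2), "+") - 2 * tcrossprod(x1, x2)
  pmax(sq, 0)^(p / 2)
}

# Debiased value: subtract half of each self-transport term
debiased_ot <- function(x, y, a, b, lambda, p = 2, ...) {
  entropic_ot(a, b, pow_cost(x, y, p), lambda, ...) -
    0.5 * entropic_ot(a, a, pow_cost(x, x, p), lambda, ...) -
    0.5 * entropic_ot(b, b, pow_cost(y, y, p), lambda, ...)
}

x <- matrix(c(0, 1), ncol = 1)
y <- matrix(c(5, 6), ncol = 1)
w <- c(0.5, 0.5)
debiased_ot(x, y, w, w, lambda = 1)  # clearly positive: the measures differ
debiased_ot(x, x, w, w, lambda = 1)  # 0 for identical inputs
```

The second call illustrates the point of debiasing: without the correction, the entropic term makes the OT value between a measure and itself nonzero.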
Method solve()
Solve the OTProblem at each parameter value. Must run setup_arguments first.
Usage
OTProblem$solve( niter = 1000L, tol = 1e-05, optimizer = c("torch", "frank-wolfe"), torch_optim = torch::optim_lbfgs, torch_scheduler = torch::lr_reduce_on_plateau, torch_args = NULL, osqp_args = NULL, quick.balance.function = TRUE )
Arguments
niter
The number of iterations to run the solver at each combination of hyperparameter values
tol
The tolerance for convergence
optimizer
The optimizer to use. One of "torch" or "frank-wolfe"
torch_optim
The torch_optimizer to use. Default is torch::optim_lbfgs.
torch_scheduler
The torch::lr_scheduler to use. Default is torch::lr_reduce_on_plateau.
torch_args
Arguments passed to the torch optimizer and scheduler
osqp_args
Arguments passed to osqp::osqpSettings(), if appropriate.
quick.balance.function
Should osqp::osqp() be used to select the balance function constraints (delta)? Default is TRUE.
Examples
ot$solve(niter = 1, torch_optim = torch::optim_rmsprop)
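The "frank-wolfe" choice for optimizer refers to the Frank-Wolfe (conditional gradient) method. As a generic illustration only, not the package's implementation, the base-R sketch below minimizes a smooth convex function over the probability simplex, the kind of constraint set that sample weights live in. The function frank_wolfe_simplex and the toy objective are hypothetical.

```r
# Generic Frank-Wolfe (conditional gradient) over the probability simplex.
# grad_fn returns the gradient of a smooth convex objective at w.
frank_wolfe_simplex <- function(grad_fn, k, niter = 1000L, tol = 1e-5) {
  w <- rep(1 / k, k)                   # feasible start: uniform weights
  for (it in seq_len(niter)) {
    g <- grad_fn(w)
    s <- numeric(k)
    s[which.min(g)] <- 1               # the linear minimizer is a vertex
    gap <- sum(g * (w - s))            # duality gap, bounds f(w) - f* from above
    if (gap < tol) break
    step <- 2 / (it + 1)               # classic step 2 / (t + 2) with t = it - 1
    w <- (1 - step) * w + step * s     # convex combination stays feasible
  }
  w
}

# Toy objective (hypothetical): f(w) = sum((w - target)^2)
target <- c(0.2, 0.3, 0.5)
w_hat <- frank_wolfe_simplex(function(w) 2 * (w - target), k = 3)
round(w_hat, 2)
```

Each iterate is a convex combination of simplex vertices, so no projection step is needed; that projection-free property is the usual reason to choose Frank-Wolfe for simplex-constrained weights.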
Method choose_hyperparameters()
Selects the hyperparameter values through a bootstrap algorithm
Usage
OTProblem$choose_hyperparameters( n_boot_lambda = 100L, n_boot_delta = 1000L, lambda_bootstrap = Inf )
Arguments
n_boot_lambda
The number of bootstrap iterations to run when selecting lambda
n_boot_delta
The number of bootstrap iterations to run when selecting delta
lambda_bootstrap
The penalty parameter to use when selecting lambda. Higher numbers run faster.
Examples
ot$choose_hyperparameters(n_boot_lambda = 10, n_boot_delta = 10, lambda_bootstrap = Inf)
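The general idea of bootstrap-based hyperparameter selection can be sketched generically: for each candidate value, resample the data, evaluate a criterion on each resample, and keep the candidate with the smallest average. This is a hypothetical illustration; bootstrap_select, criterion_fn, and the toy criterion are placeholders, not the selection rule used by choose_hyperparameters().

```r
# Hypothetical bootstrap grid search: pick the candidate with the lowest
# mean criterion over bootstrap resamples of the rows of `data`.
bootstrap_select <- function(data, candidates, criterion_fn, n_boot = 100L) {
  n <- nrow(data)
  scores <- vapply(candidates, function(cand) {
    boot_vals <- replicate(n_boot, {
      idx <- sample.int(n, n, replace = TRUE)  # bootstrap resample of rows
      criterion_fn(data[idx, , drop = FALSE], cand)
    })
    mean(boot_vals)
  }, numeric(1))
  list(selected = candidates[which.min(scores)], scores = scores)
}

# Toy criterion (hypothetical): squared error of a shrunken column mean
set.seed(42)
dat <- matrix(rnorm(200, mean = 1), ncol = 2)
crit <- function(d, shrink) sum((colMeans(d) * shrink - 1)^2)
res <- bootstrap_select(dat, candidates = c(0, 0.5, 1), crit, n_boot = 50)
res$selected
```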
Method info()
Provides diagnostics after the solve() and choose_hyperparameters() methods have been run.
Usage
OTProblem$info()
Returns
A list with slots:
- loss: the final loss values
- iterations: the number of iterations run for each combination of parameters
- balance.function.differences: the final differences in the balance functions
- hyperparam.metrics: a list of the bootstrap evaluations for the delta and lambda values
Examples
ot$info()
Method clone()
The objects of this class are cloneable with this method.
Usage
OTProblem$clone(deep = FALSE)
Arguments
deep
Whether to make a deep clone.
Examples
## ------------------------------------------------
## Method `OTProblem(measure_1, measure_2)`
## ------------------------------------------------
if (torch::torch_is_installed()) {
  # set up measures
  x <- matrix(1, 100, 10)
  m1 <- Measure(x = x)
  y <- matrix(2, 100, 10)
  m2 <- Measure(x = y, adapt = "weights")
  z <- matrix(3, 102, 10)
  m3 <- Measure(x = z)
  # set up OT problems
  ot1 <- OTProblem(m1, m2)
  ot2 <- OTProblem(m3, m2)
  ot <- 0.5 * ot1 + 0.5 * ot2
  print(ot)
  ## ------------------------------------------------
  ## Method `OTProblem$setup_arguments`
  ## ------------------------------------------------
  ot$setup_arguments(lambda = 1000)
  ## ------------------------------------------------
  ## Method `OTProblem$solve`
  ## ------------------------------------------------
  ot$solve(niter = 1, torch_optim = torch::optim_rmsprop)
  ## ------------------------------------------------
  ## Method `OTProblem$choose_hyperparameters`
  ## ------------------------------------------------
  ot$choose_hyperparameters(n_boot_lambda = 1,
                            n_boot_delta = 1,
                            lambda_bootstrap = Inf)
  ## ------------------------------------------------
  ## Method `OTProblem$info`
  ## ------------------------------------------------
  ot$info()
}