OTProblem {causalOT}    R Documentation

Object Oriented OT Problem

Description

An R6 object-oriented interface for setting up and solving optimal transport (OT) problems between objects of class Measure.

Usage

OTProblem(measure_1, measure_2, ...)

Arguments

measure_1

An object of class Measure

measure_2

An object of class Measure

...

Not used at this time

Value

An R6 object of class "OTProblem"

Public fields

device

The torch::torch_device() of the data.

dtype

The torch::torch_dtype of the data.

selected_delta

The delta value selected by choose_hyperparameters().

selected_lambda

The lambda value selected by choose_hyperparameters().

Active bindings

loss

The current value of the objective. Only available after the OTProblem$solve() method has been run.

penalty

Returns a list of the lambda and delta penalty values that will be iterated through. To set these values, use the OTProblem$setup_arguments() function.

Methods

Public methods


Method add()

Adds o2 to the OTProblem.

Usage
OTProblem$add(o2)
Arguments
o2

A number or object of class OTProblem


Method subtract()

Subtracts o2 from the OTProblem.

Usage
OTProblem$subtract(o2)
Arguments
o2

A number or object of class OTProblem


Method multiply()

Multiplies the OTProblem by o2.

Usage
OTProblem$multiply(o2)
Arguments
o2

A number or an object of class OTProblem


Method divide()

Divides the OTProblem by o2.

Usage
OTProblem$divide(o2)
Arguments
o2

A number or object of class OTProblem
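The arithmetic methods let several OTProblem objects be combined into a single weighted objective. The sketch below averages three problems that share the measure whose weights are adapted. It assumes causalOT and a working torch installation, uses only the `*` (number times OTProblem) and `+` (OTProblem plus OTProblem) operators that appear in the package Examples, and the data and object names (`target`, `src1`, `src2`, `src3`, `ot_avg`) are hypothetical.

```r
# A minimal sketch, assuming causalOT and a working torch installation.
library(causalOT)

# Illustrative data: `target` is the measure whose weights are adapted,
# and src1, src2, src3 are fixed source measures (hypothetical names).
target <- Measure(x = matrix(2, 100, 10), adapt = "weights")
src1 <- Measure(x = matrix(1, 100, 10))
src2 <- Measure(x = matrix(3, 102, 10))
src3 <- Measure(x = matrix(4, 90, 10))

# Equally weighted sum of three OT problems. Only number * OTProblem and
# OTProblem + OTProblem are used, matching the package Examples.
ot_avg <- (1 / 3) * OTProblem(src1, target) +
  (1 / 3) * OTProblem(src2, target) +
  (1 / 3) * OTProblem(src3, target)

print(ot_avg)
```

The explicit methods (add(), multiply(), and so on) are listed above; the operator form is used here only because it is the form shown in the package's own Examples.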


Method setup_arguments()

Sets the hyperparameter values and solver options for the OT problems. Must be run before solve().

Usage
OTProblem$setup_arguments(
      lambda,
      delta,
      grid.length = 7L,
      cost.function = NULL,
      p = 2,
      cost.online = "auto",
      debias = TRUE,
      diameter = NULL,
      ot_niter = 1000L,
      ot_tol = 0.001
    )
Arguments
lambda

The penalty parameters to try for the OT problems. If not provided, the function will select a grid of values.

delta

The constraint parameters to try for the balance function problems, if any.

grid.length

The number of hyperparameter values to try if lambda or delta is not provided.

cost.function

The cost function for the data. Can be any function that takes arguments x1, x2, and p. Defaults to the Euclidean distance.

p

The power to raise the cost matrix by. Default is 2

cost.online

How the cost matrix is computed. "tensorized" stores the full cost matrix in memory, while "online" calculates it on the fly. Default is "auto".

debias

Should debiased OT problems be used? Defaults to TRUE

diameter

Diameter of the cost function, i.e., the largest cost between any two observations.

ot_niter

The maximum number of iterations used to solve each OT problem.

ot_tol

The tolerance for convergence of the OT problems
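This page does not spell out how lambda and debias enter the objective. As a hedged reference point only, the standard entropically penalized OT problem and its debiased form (the Sinkhorn divergence) are written out here; it is an assumption that causalOT follows exactly this convention. Here alpha and beta are the two measures with weights a_i and b_j on points x_i and y_j, Pi(a, b) is the set of couplings with those marginals, and c is the cost function raised to the power p.

```latex
\mathrm{OT}_\lambda(\alpha, \beta)
  = \min_{\pi \in \Pi(a, b)} \sum_{i,j} c(x_i, y_j)\,\pi_{ij}
  + \lambda \sum_{i,j} \pi_{ij} \log\frac{\pi_{ij}}{a_i b_j},
\qquad
\Pi(a, b) = \{\pi \ge 0 : \pi \mathbf{1} = a,\ \pi^\top \mathbf{1} = b\}

S_\lambda(\alpha, \beta)
  = \mathrm{OT}_\lambda(\alpha, \beta)
  - \tfrac{1}{2}\,\mathrm{OT}_\lambda(\alpha, \alpha)
  - \tfrac{1}{2}\,\mathrm{OT}_\lambda(\beta, \beta)
```

Under this convention, debias = TRUE would correspond to using S_lambda rather than OT_lambda, and larger lambda means a smoother, faster-to-solve problem.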

Returns

NULL
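Because setup_arguments() returns NULL, its effect is easiest to check through the penalty active binding. A minimal sketch, assuming causalOT and a working torch installation; the toy data mirror the package Examples and the object names are illustrative.

```r
# A minimal sketch, assuming causalOT and a working torch installation.
library(causalOT)

m1 <- Measure(x = matrix(1, 100, 10))
m2 <- Measure(x = matrix(2, 100, 10), adapt = "weights")
ot <- OTProblem(m1, m2)

# Two penalty values to iterate over; the cost is the default Euclidean
# distance raised to p = 2, stored in memory ("tensorized").
ot$setup_arguments(lambda = c(1000, 10), p = 2, cost.online = "tensorized")

# The penalty active binding lists the lambda and delta values solve() will use.
str(ot$penalty)
```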

Examples
 ot$setup_arguments(lambda = c(1000, 10))
    

Method solve()

Solves the OTProblem at each hyperparameter value. setup_arguments() must be run first.

Usage
OTProblem$solve(
      niter = 1000L,
      tol = 1e-05,
      optimizer = c("torch", "frank-wolfe"),
      torch_optim = torch::optim_lbfgs,
      torch_scheduler = torch::lr_reduce_on_plateau,
      torch_args = NULL,
      osqp_args = NULL,
      quick.balance.function = TRUE
    )
Arguments
niter

The number of iterations to run the solver at each combination of hyperparameter values.

tol

The tolerance for convergence

optimizer

The optimizer to use. One of "torch" or "frank-wolfe"

torch_optim

The torch optimizer to use. Default is torch::optim_lbfgs.

torch_scheduler

The torch::lr_scheduler to use. Default is torch::lr_reduce_on_plateau

torch_args

Arguments passed to the torch optimizer and scheduler

osqp_args

Arguments passed to osqp::osqpSettings() if appropriate

quick.balance.function

Should osqp::osqp() be used to select the balance function constraints (delta)? Default is TRUE.

Examples
 ot$solve(niter = 1, torch_optim = torch::optim_rmsprop)
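The loss active binding only becomes available once solve() has run, so a complete setup-then-solve sequence is the natural way to inspect it. A minimal sketch, assuming causalOT and a working torch installation; the small niter value only keeps the example fast and is not a recommended setting.

```r
# A minimal sketch, assuming causalOT and a working torch installation.
library(causalOT)

m1 <- Measure(x = matrix(1, 100, 10))
m2 <- Measure(x = matrix(2, 100, 10), adapt = "weights")
ot <- OTProblem(m1, m2)

# setup_arguments() must run before solve().
ot$setup_arguments(lambda = 1000)

# A few RMSprop iterations per hyperparameter value, just to populate results.
ot$solve(niter = 5, torch_optim = torch::optim_rmsprop)

# `loss` is an active binding and is only available after solve() has run.
print(ot$loss)
```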
    

Method choose_hyperparameters()

Selects the hyperparameter values through a bootstrap algorithm

Usage
OTProblem$choose_hyperparameters(
      n_boot_lambda = 100L,
      n_boot_delta = 1000L,
      lambda_bootstrap = Inf
    )
Arguments
n_boot_lambda

The number of bootstrap iterations to run when selecting lambda

n_boot_delta

The number of bootstrap iterations to run when selecting delta

lambda_bootstrap

The penalty parameter to use when selecting lambda. Higher numbers run faster.
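After the bootstrap selection, the chosen values are stored in the public fields selected_lambda and selected_delta. A minimal sketch, assuming causalOT and a working torch installation; the bootstrap counts are reduced from the defaults (100 and 1000) only to keep it fast, and it is an assumption that selected_delta may be NULL when the problem has no balance-function constraints.

```r
# A minimal sketch, assuming causalOT and a working torch installation.
library(causalOT)

m1 <- Measure(x = matrix(1, 100, 10))
m2 <- Measure(x = matrix(2, 100, 10), adapt = "weights")
ot <- OTProblem(m1, m2)

ot$setup_arguments(lambda = c(1000, 10))
ot$solve(niter = 5, torch_optim = torch::optim_rmsprop)

# Few bootstrap draws to keep the sketch fast; the defaults are 100 and 1000.
ot$choose_hyperparameters(n_boot_lambda = 10, n_boot_delta = 10,
                          lambda_bootstrap = Inf)

# The selected values are stored in public fields.
print(ot$selected_lambda)
# Assumption: this may be NULL when no balance-function constraints are set.
print(ot$selected_delta)
```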

Examples
 ot$choose_hyperparameters(n_boot_lambda = 10,
                           n_boot_delta = 10,
                           lambda_bootstrap = Inf)
    

Method info()

Provides diagnostics after the solve() and choose_hyperparameters() methods have been run.

Usage
OTProblem$info()
Returns

A list of diagnostic information.

Examples
 ot$info()
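Since the individual slots of the returned list are not documented on this page, the safest way to explore them is to inspect the object directly. A minimal sketch, assuming causalOT and a working torch installation; `diagnostics` is an illustrative name.

```r
# A minimal sketch, assuming causalOT and a working torch installation.
library(causalOT)

m1 <- Measure(x = matrix(1, 100, 10))
m2 <- Measure(x = matrix(2, 100, 10), adapt = "weights")
ot <- OTProblem(m1, m2)

# info() is meant to be called after solve() and choose_hyperparameters().
ot$setup_arguments(lambda = 1000)
ot$solve(niter = 5, torch_optim = torch::optim_rmsprop)
ot$choose_hyperparameters(n_boot_lambda = 10, n_boot_delta = 10,
                          lambda_bootstrap = Inf)

diagnostics <- ot$info()

# Inspect the top level of the list rather than assuming slot names.
print(names(diagnostics))
str(diagnostics, max.level = 1)
```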
    

Method clone()

The objects of this class are cloneable with this method.

Usage
OTProblem$clone(deep = FALSE)
Arguments
deep

Whether to make a deep clone.
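OTProblem is an R6 class, so objects have reference semantics: plain assignment creates an alias, and clone() is needed for an independent copy. A minimal sketch, assuming causalOT and a working torch installation; the object names are illustrative. In R6, deep = TRUE also clones fields that are themselves R6 objects, whereas deep = FALSE shares them between the original and the copy.

```r
# A minimal sketch, assuming causalOT and a working torch installation.
library(causalOT)

m1 <- Measure(x = matrix(1, 100, 10))
m2 <- Measure(x = matrix(2, 100, 10), adapt = "weights")
ot <- OTProblem(m1, m2)
ot$setup_arguments(lambda = 1000)

# Assignment does not copy an R6 object; both names refer to one object.
ot_alias <- ot

# deep = TRUE also clones R6-valued fields instead of sharing them.
ot_copy <- ot$clone(deep = TRUE)

print(identical(ot_alias, ot))  # same underlying object
print(identical(ot_copy, ot))   # a separate object
```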

Examples

## ------------------------------------------------
## Method `OTProblem(measure_1, measure_2)`
## ------------------------------------------------

library(causalOT)

if (torch::torch_is_installed()) {
  # setup measures
  x <- matrix(1, 100, 10)
  m1 <- Measure(x = x)
  
  y <- matrix(2, 100, 10)
  m2 <- Measure(x = y, adapt = "weights")
  
  z <- matrix(3, 102, 10)
  m3 <- Measure(x = z)
  
  # setup OT problems
  ot1 <- OTProblem(m1, m2)
  ot2 <- OTProblem(m3, m2)
  ot <- 0.5 * ot1 + 0.5 * ot2
  print(ot)

## ------------------------------------------------
## Method `OTProblem$setup_arguments`
## ------------------------------------------------

  ot$setup_arguments(lambda = 1000)

## ------------------------------------------------
## Method `OTProblem$solve`
## ------------------------------------------------

  ot$solve(niter = 1, torch_optim = torch::optim_rmsprop)

## ------------------------------------------------
## Method `OTProblem$choose_hyperparameters`
## ------------------------------------------------

  ot$choose_hyperparameters(n_boot_lambda = 1,
                            n_boot_delta = 1, 
                            lambda_bootstrap = Inf)

## ------------------------------------------------
## Method `OTProblem$info`
## ------------------------------------------------

  ot$info()
}

[Package causalOT version 1.0.2 Index]