predict.bp {causalOT}R Documentation

Predict method for barycentric projection models

Description

Predict method for barycentric projection models

Usage

## S3 method for class 'bp'
predict(
  object,
  newdata = NULL,
  source.sample,
  cost_function = NULL,
  niter = 1000,
  tol = 1e-07,
  ...
)

Arguments

object

An object of class "bp"

newdata

a data.frame containing new observations. If NULL (the default), predictions are made for the data the model was fit on.

source.sample

a vector giving the sample each observation arises from

cost_function

a cost metric between observations

niter

Number of iterations to run the barycentric projection for powers > 2.

tol

Tolerance on the optimization problem for projections with powers > 2.

...

Additional arguments passed to the L-BFGS optimizer in the torch package.

Examples

if (torch::torch_is_installed()) {
  set.seed(23483)

  # Simulation settings for the Hainmueller data-generating design
  n <- 2^5
  pp <- 6
  overlap <- "low"
  design <- "A"
  estimate <- "ATT"

  data <- causalOT::Hainmueller$new(
    n = n, p = pp,
    design = design, overlap = overlap
  )
  data$gen_data()

  # Weights for the chosen estimand, computed with method = "NNM"
  weights <- causalOT::calc_weight(
    x = data,
    z = NULL, y = NULL,
    estimand = estimate,
    method = "NNM"
  )

  df <- data.frame(y = data$get_y(), z = data$get_z(), data$get_x())

  # Undebiased fit; samples are separated on column z.
  # niter is kept small so the example runs quickly.
  fit <- causalOT::barycentric_projection(
    y ~ ., data = df,
    weight = weights,
    separate.samples.on = "z", niter = 2
  )

  # Debiased fit with otherwise identical settings
  fit_d <- causalOT::barycentric_projection(
    y ~ ., data = df,
    weight = weights,
    separate.samples.on = "z", debias = TRUE, niter = 2
  )

  # Predictions without newdata; source.sample gives the sample of each row
  undebiased_predictions <- predict(fit, source.sample = df$z)
  debiased_predictions <- predict(fit_d, source.sample = df$z)

  # In-sample, the debiased fit reproduces the observed outcomes while the
  # undebiased fit does not.
  isTRUE(all.equal(unname(undebiased_predictions), df$y)) # FALSE
  isTRUE(all.equal(unname(debiased_predictions), df$y)) # TRUE
}

[Package causalOT version 1.0.2 Index]