method_user {WeightIt}    R Documentation

User-Defined Functions for Estimating Weights

Description

This page explains the details of estimating weights using a user-defined function. The function must take in arguments that are passed to it by weightit() or weightitMSM() and return a vector of weights or a list containing the weights.

To supply a user-defined function, pass the function object itself to the method argument; for example, for a function fun, use method = fun.

Point Treatments

The following arguments are automatically passed to the user-defined function, which should have named parameters corresponding to them:

None of these parameters are required to be in the fitting function. These are simply those that are automatically available.

Any additional arguments supplied to weightit() are also passed on to the fitting function. weightit() checks that these arguments correspond to parameters of the fitting function and throws an error if an incorrectly named argument is supplied and the fitting function does not include ... as a parameter.
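
As a minimal sketch of how an extra argument might reach the fitting function, the hypothetical function below declares a link parameter that is not among the automatically supplied arguments. The name ps_link, the link parameter, the simulated data, and the assumption of a 0/1 numeric treatment are all illustrative; the function is called directly here so that the example runs without weightit().

```r
# Simulated stand-ins for the `treat` and `covs` inputs (illustrative only)
set.seed(1)
n <- 200
covs <- data.frame(x1 = rnorm(n), x2 = rbinom(n, 1, 0.5))
treat <- rbinom(n, 1, plogis(0.8 * covs$x1 - 0.5 * covs$x2))

# Hypothetical fitting function; `link` is a user-chosen extra parameter
ps_link <- function(treat, covs, link = "logit") {
  d <- data.frame(treat = treat, covs)
  fit <- glm(treat ~ ., data = d, family = binomial(link = link))
  p <- fit$fitted.values
  # Inverse probability weights for a 0/1 treatment (ATE-style)
  w <- ifelse(treat == 1, 1 / p, 1 / (1 - p))
  list(w = w, ps = p)
}

out_logit  <- ps_link(treat, covs)
out_probit <- ps_link(treat, covs, link = "probit")

# Parameter names against which a supplied argument would be matched
names(formals(ps_link))
```

With weightit(), the same argument would be supplied alongside the usual ones, for example method = ps_link, link = "probit". Because ps_link has no ... parameter, a misspelled argument such as lnk = "probit" would be expected to trigger the error described above.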

The fitting function must return either a numeric vector of weights or a list (or list-like object) with an entry named either "w" or "weights". If a list is returned, it can contain other named entries, but only entries named "w", "weights", "ps", and "fit.obj" will be processed. "ps" is a vector of propensity scores, and "fit.obj" should be an object used in the fitting process that a user may want to examine; it is included in the weightit output object as "obj" when include.obj = TRUE. The "ps" and "fit.obj" components are optional, but "weights" or "w" is required.
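
The two accepted return forms can be illustrated with a pair of hypothetical fitting functions. The names ipw_vector and ipw_list, the simulated data, and the assumption of a 0/1 numeric treatment are illustrative only; both functions are called directly so that the sketch does not depend on weightit().

```r
# Simulated stand-ins for the `treat` and `covs` inputs (illustrative only)
set.seed(123)
n <- 200
covs <- data.frame(x1 = rnorm(n), x2 = rbinom(n, 1, 0.4))
treat <- rbinom(n, 1, plogis(0.5 * covs$x1 - 0.3 * covs$x2))

# Form 1: return a bare numeric vector of weights
ipw_vector <- function(treat, covs) {
  d <- data.frame(treat = treat, covs)
  p <- glm(treat ~ ., data = d, family = binomial())$fitted.values
  ifelse(treat == 1, 1 / p, 1 / (1 - p))
}

# Form 2: return a list with the required "weights" (or "w") entry
# plus the optional "ps" and "fit.obj" entries
ipw_list <- function(treat, covs) {
  d <- data.frame(treat = treat, covs)
  fit <- glm(treat ~ ., data = d, family = binomial())
  p <- fit$fitted.values
  list(weights = ifelse(treat == 1, 1 / p, 1 / (1 - p)),
       ps = p,
       fit.obj = fit)
}

w_vec <- ipw_vector(treat, covs)
out   <- ipw_list(treat, covs)
```

Both functions produce the same weights; the list form additionally carries the propensity scores and the fitted glm object, which would be retained as "obj" when include.obj = TRUE is set in weightit().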

Longitudinal Treatments

Longitudinal treatments can be handled in two ways: by running the point-treatment fitting function at each time point and multiplying the resulting weights together, or by running a method that accommodates multiple time points and outputs a single set of weights. For the former, weightitMSM() can be used with the user-defined function just as it is with weightit(). The latter approach is not yet supported by weightitMSM(), though it may be in a future release.
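
The first strategy can be sketched by hand as follows. This is only an illustration of multiplying per-time-point weights, not the internals of weightitMSM(); the simulated two-period data, the variable names, and the function name point_ipw are assumptions.

```r
# Simulated two-period data (illustrative only)
set.seed(42)
n <- 300
X0 <- rnorm(n)                                    # baseline covariate
A1 <- rbinom(n, 1, plogis(0.4 * X0))              # treatment at time 1
X1 <- rnorm(n, mean = 0.3 * A1 + 0.5 * X0)        # covariate after time 1
A2 <- rbinom(n, 1, plogis(0.4 * X1 + 0.6 * A1))   # treatment at time 2

# Hypothetical point-treatment fitting function for a 0/1 treatment
point_ipw <- function(treat, covs) {
  d <- data.frame(treat = treat, covs)
  p <- glm(treat ~ ., data = d, family = binomial())$fitted.values
  ifelse(treat == 1, 1 / p, 1 / (1 - p))
}

# Weights at each time point, conditioning on the history so far
w1 <- point_ipw(A1, data.frame(X0 = X0))
w2 <- point_ipw(A2, data.frame(X1 = X1, A1 = A1, X0 = X0))

# Final weights are the product across time points
w_msm <- w1 * w2
summary(w_msm)
```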

See Also

weightit(), weightitMSM()

Examples


library("WeightIt")
library("cobalt")
data("lalonde", package = "cobalt")

#A user-defined version of method = "ps"
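my.ps <- function(treat, covs, estimand, focal = NULL) {
  # Drop redundant columns so the covariates are full rank
  covs <- make_full_rank(covs)
  d <- data.frame(treat, covs)
  # formula() on a data frame uses the first column (treat) as the response
  f <- formula(d)
  # Propensity scores from a logistic regression
  ps <- glm(f, data = d, family = "binomial")$fitted.values
  # Convert propensity scores to weights for the requested estimand
  w <- get_w_from_ps(ps, treat = treat, estimand = estimand,
                     focal = focal)

  list(w = w, ps = ps)
}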

#Balancing covariates between treatment groups (binary)
(W1 <- weightit(treat ~ age + educ + married +
                  nodegree + re74, data = lalonde,
                method = my.ps, estimand = "ATT"))
summary(W1)
bal.tab(W1)

data("msmdata")
(W2 <- weightitMSM(list(A_1 ~ X1_0 + X2_0,
                        A_2 ~ X1_1 + X2_1 +
                          A_1 + X1_0 + X2_0,
                        A_3 ~ X1_2 + X2_2 +
                          A_2 + X1_1 + X2_1 +
                          A_1 + X1_0 + X2_0),
                   data = msmdata,
                   method = my.ps))

summary(W2)
bal.tab(W2)

# Kernel balancing using the `kbal` package, available
# using `devtools::install_github("chadhazlett/KBAL")`.
# Only the ATT and ATC are available.

## Not run: 
  kbal.fun <- function(treat, covs, estimand, focal, verbose, ...) {
    args <- list(...)
    # A non-empty `focal` is expected for the ATT or ATC; otherwise stop
    if (length(focal) > 0) {
      treat <- as.numeric(treat == focal)
    } else {
      stop("`estimand` must be \"ATT\" or \"ATC\".", call. = FALSE)
    }

    args <- args[names(args) %in% names(formals(kbal::kbal))]
    args$allx <- covs
    args$treatment <- treat
    args$printprogress <- verbose

    cat_cols <- apply(covs, 2, function(x) length(unique(x)) <= 2)

    if (all(cat_cols)) {
      args$cat_data <- TRUE
      args$mixed_data <- FALSE
      args$scale_data <- FALSE
      args$linkernel <- FALSE
      args$drop_MC <- FALSE
    }
    else if (any(cat_cols)) {
      args$cat_data <- FALSE
      args$mixed_data <- TRUE
      args$cat_columns <- colnames(covs)[cat_cols]
      args$allx[,!cat_cols] <- scale(args$allx[,!cat_cols])
      args$cont_scale <- 1
    }
    else {
      args$cat_data <- FALSE
      args$mixed_data <- FALSE
    }

    k.out <- do.call(kbal::kbal, args)
    w <- k.out$w

    list(w = w, fit.obj = k.out)
  }

  (Wk <- weightit(treat ~ age + educ + married +
                    nodegree + re74, data = lalonde,
                  method = kbal.fun, estimand = "ATT",
                  include.obj = TRUE))
  summary(Wk)
  bal.tab(Wk, stats = c("m", "ks"))

## End(Not run)


[Package WeightIt version 1.1.0 Index]