policy_learn {polle}                R Documentation

Create Policy Learner

Description

policy_learn() is used to specify a policy learning method, e.g. Q-learning, doubly robust Q-learning, policy tree learning, or outcome weighted learning. Evaluating the policy learner on a policy_data object returns a policy object.
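As a minimal sketch (assuming pd is an existing policy_data object, see policy_data()), a learner is first specified and then evaluated directly on the data:

  pl <- policy_learn(type = "ql")                 # specify a Q-learning policy learner
  po <- pl(policy_data = pd, q_models = q_glm())  # evaluate it; returns a policy object
  head(get_policy(po)(pd))                        # actions dictated by the learned policy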

Usage

policy_learn(
  type = "ql",
  control = list(),
  alpha = 0,
  full_history = FALSE,
  L = 1,
  cross_fit_g_models = TRUE,
  save_cross_fit_models = FALSE,
  future_args = list(future.seed = TRUE),
  name = type
)

## S3 method for class 'policy_learn'
print(x, ...)

## S3 method for class 'policy_object'
print(x, ...)

Arguments

type

Type of policy learner method:

  • "ql": Quality/Q-learning.

  • "drql": Doubly Robust Q-learning.

  • "blip": Doubly Robust blip-learning (only for dichotomous actions).

  • "ptl": Policy Tree Learning.

  • "owl": Outcome Weighted Learning.

  • "earl": Efficient Augmentation and Relaxation Learning (only single stage).

  • "rwl": Residual Weighted Learning (only single stage).

control

List of control arguments. Values (and default values) are set using control_{type}(). Key arguments include:
control_drql():

  • qv_models: Single element or list of V-restricted Q-models created by q_glm(), q_rf(), q_sl() or similar functions.

control_blip():

  • blip_models: Single element or list of V-restricted blip-models created by q_glm(), q_rf(), q_sl() or similar functions.

control_ptl():

  • policy_vars: Character vector/string or list of character vectors/strings. Variable names used to construct the V-restricted policy tree. The names must be a subset of the history names, see get_history_names().

  • hybrid: If TRUE, policytree::hybrid_policy_tree() is used to fit a policy tree.

  • depth: Integer or integer vector. The depth of the fitted policy tree for each stage.

control_owl():

  • policy_vars: As in control_ptl().

  • loss: Loss function. The options are "hinge", "ramp", "logit", "logit.lasso", "l2", "l2.lasso".

  • kernel: Type of kernel used by the support vector machine. The options are "linear", "rbf".

  • augment: If TRUE the outcomes are augmented.

control_earl()/control_rwl():

  • moPropen: Propensity model of class "ModelObj", see modelObj::modelObj.

  • moMain: Main effects outcome model of class "ModelObj".

  • moCont: Contrast outcome model of class "ModelObj".

  • regime: An object of class formula specifying the design of the policy.

  • surrogate: The surrogate 0-1 loss function. The options are "logit", "exp", "hinge", "sqhinge", "huber".

  • kernel: The options are "linear", "poly", "radial".
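As a sketch (assuming the history contains state variables named L and BB), the control list is built with the control_{type}() function matching the chosen type:

  # V-restricted policy tree learning on selected history variables:
  policy_learn(type = "ptl",
               control = control_ptl(policy_vars = c("L", "BB"),
                                     depth = 2,
                                     hybrid = FALSE))
  # outcome weighted learning with a ramp loss and linear kernel:
  policy_learn(type = "owl",
               control = control_owl(loss = "ramp", kernel = "linear"))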

alpha

Probability threshold for determining realistic actions.

full_history

If TRUE, the full history is used to fit each policy function (e.g. QV-model, policy tree). If FALSE, the single-stage/"Markov-type" history is used to fit each policy function.

L

Number of folds for cross-fitting nuisance models.

cross_fit_g_models

If TRUE, the g-models are cross-fitted when L > 1.

save_cross_fit_models

If TRUE, the cross-fitted models will be saved.

future_args

Arguments passed to future.apply::future_apply().
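As a combined sketch, a cross-fitted doubly robust learner restricted to realistic actions could be specified as follows (setting a parallel backend via future::plan() is optional):

  # future::plan("multisession")  # optional: run the folds in parallel via future.apply
  policy_learn(
    type = "drql",
    control = control_drql(qv_models = q_glm(formula = ~ .)),
    alpha = 0.05,               # only actions with propensity >= 0.05 are considered realistic
    L = 2,                      # 2-fold cross-fitting of the nuisance models
    cross_fit_g_models = TRUE,
    future_args = list(future.seed = TRUE)
  )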

name

Character string naming the policy learner; defaults to the value of type.

x

Object of class "policy_object" or "policy_learn".

...

Additional arguments passed to print.

Value

Function of inherited class "policy_learn". Evaluating the function on a policy_data object returns an object of class policy_object. A policy object is a list containing all or some of the following elements:

q_functions

Fitted Q-functions. Object of class "nuisance_functions".

g_functions

Fitted g-functions. Object of class "nuisance_functions".

action_set

Sorted character vector describing the action set, i.e., the possible actions at each stage.

alpha

Numeric. Probability threshold to determine realistic actions.

K

Integer. Maximal number of stages.

qv_functions

(only if type = "drql") Fitted V-restricted Q-functions. Contains a fitted model for each stage and action.

ptl_objects

(only if type = "ptl") Fitted V-restricted policy trees. Contains a policy_tree for each stage.

ptl_designs

(only if type = "ptl") Specification of the V-restricted design matrix for each stage

S3 generics

The following S3 generic functions are available for an object of class "policy_object":

get_g_functions()

Extract the fitted g-functions.

get_q_functions()

Extract the fitted Q-functions.

get_policy()

Extract the fitted policy.

get_policy_functions()

Extract the fitted policy function for a given stage.

get_policy_actions()

Extract the (fitted) policy actions.
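As a brief sketch, given a fitted policy object po (see the examples below), the components can be inspected with these generics:

  get_g_functions(po)                         # fitted g-functions (if any were fitted)
  get_q_functions(po)                         # fitted Q-functions
  pf1 <- get_policy_functions(po, stage = 1)  # policy function for stage 1
  head(get_policy(po)(pd))                    # policy actions on the policy_data object pd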

References

Doubly Robust Q-learning (type = "drql"): Luedtke, Alexander R., and Mark J. van der Laan. "Super-learning of an optimal dynamic treatment rule." The International Journal of Biostatistics 12.1 (2016): 305-332. doi:10.1515/ijb-2015-0052.

Policy Tree Learning (type = "ptl"): Zhou, Zhengyuan, Susan Athey, and Stefan Wager. "Offline multi-action policy learning: Generalization and optimization." Operations Research (2022). doi:10.1287/opre.2022.2271.

(Augmented) Outcome Weighted Learning (type = "owl"): Liu, Ying, et al. "Augmented outcome-weighted learning for estimating optimal dynamic treatment regimens." Statistics in Medicine 37.26 (2018): 3776-3788. doi:10.1002/sim.7844.

See Also

policy_eval()

Examples

library("polle")
### Two stages:
d <- sim_two_stage(5e2, seed=1)
pd <- policy_data(d,
                  action = c("A_1", "A_2"),
                  baseline = c("BB"),
                  covariates = list(L = c("L_1", "L_2"),
                                    C = c("C_1", "C_2")),
                  utility = c("U_1", "U_2", "U_3"))
pd

### V-restricted (Doubly Robust) Q-learning

# specifying the learner:
pl <- policy_learn(
  type = "drql",
  control = control_drql(qv_models = list(q_glm(formula = ~ C_1 + BB),
                                          q_glm(formula = ~ L_1 + BB))),
  full_history = TRUE
)

# evaluating the learned policy
pe <- policy_eval(policy_data = pd,
                  policy_learn = pl,
                  q_models = q_glm(),
                  g_models = g_glm())
pe
# getting the policy object:
po <- get_policy_object(pe)
# inspecting the fitted QV-model for each action stratum at stage 1:
po$qv_functions$stage_1
head(get_policy(pe)(pd))

[Package polle version 1.4]