add_epred_draws {tidybayes}    R Documentation

Add draws from the posterior fit, predictions, or residuals of a model to a data frame

Description

Given a data frame and a model, adds draws from the linear/link-level predictor, the expectation of the posterior predictive, the posterior predictive, or the residuals of a model to the data frame in a long format.

Usage

add_epred_draws(
  newdata,
  object,
  ...,
  value = ".epred",
  ndraws = NULL,
  seed = NULL,
  re_formula = NULL,
  category = ".category",
  dpar = NULL
)

epred_draws(
  object,
  newdata,
  ...,
  value = ".epred",
  ndraws = NULL,
  seed = NULL,
  re_formula = NULL,
  category = ".category",
  dpar = NULL
)

## Default S3 method:
epred_draws(
  object,
  newdata,
  ...,
  value = ".epred",
  seed = NULL,
  category = NULL
)

## S3 method for class 'stanreg'
epred_draws(
  object,
  newdata,
  ...,
  value = ".epred",
  ndraws = NULL,
  seed = NULL,
  re_formula = NULL,
  category = ".category",
  dpar = NULL
)

## S3 method for class 'brmsfit'
epred_draws(
  object,
  newdata,
  ...,
  value = ".epred",
  ndraws = NULL,
  seed = NULL,
  re_formula = NULL,
  category = ".category",
  dpar = NULL
)

add_linpred_draws(
  newdata,
  object,
  ...,
  value = ".linpred",
  ndraws = NULL,
  seed = NULL,
  re_formula = NULL,
  category = ".category",
  dpar = NULL,
  n
)

linpred_draws(
  object,
  newdata,
  ...,
  value = ".linpred",
  ndraws = NULL,
  seed = NULL,
  re_formula = NULL,
  category = ".category",
  dpar = NULL,
  n,
  scale
)

## Default S3 method:
linpred_draws(
  object,
  newdata,
  ...,
  value = ".linpred",
  seed = NULL,
  category = NULL
)

## S3 method for class 'stanreg'
linpred_draws(
  object,
  newdata,
  ...,
  value = ".linpred",
  ndraws = NULL,
  seed = NULL,
  re_formula = NULL,
  category = ".category",
  dpar = NULL
)

## S3 method for class 'brmsfit'
linpred_draws(
  object,
  newdata,
  ...,
  value = ".linpred",
  ndraws = NULL,
  seed = NULL,
  re_formula = NULL,
  category = ".category",
  dpar = NULL
)

add_predicted_draws(
  newdata,
  object,
  ...,
  value = ".prediction",
  ndraws = NULL,
  seed = NULL,
  re_formula = NULL,
  category = ".category",
  n
)

predicted_draws(
  object,
  newdata,
  ...,
  value = ".prediction",
  ndraws = NULL,
  seed = NULL,
  re_formula = NULL,
  category = ".category",
  n,
  prediction
)

## Default S3 method:
predicted_draws(
  object,
  newdata,
  ...,
  value = ".prediction",
  seed = NULL,
  category = ".category"
)

## S3 method for class 'stanreg'
predicted_draws(
  object,
  newdata,
  ...,
  value = ".prediction",
  ndraws = NULL,
  seed = NULL,
  re_formula = NULL,
  category = ".category"
)

## S3 method for class 'brmsfit'
predicted_draws(
  object,
  newdata,
  ...,
  value = ".prediction",
  ndraws = NULL,
  seed = NULL,
  re_formula = NULL,
  category = ".category"
)

add_residual_draws(
  newdata,
  object,
  ...,
  value = ".residual",
  ndraws = NULL,
  seed = NULL,
  re_formula = NULL,
  category = ".category",
  n
)

residual_draws(
  object,
  newdata,
  ...,
  value = ".residual",
  ndraws = NULL,
  seed = NULL,
  re_formula = NULL,
  category = ".category",
  n,
  residual
)

## Default S3 method:
residual_draws(object, newdata, ...)

## S3 method for class 'brmsfit'
residual_draws(
  object,
  newdata,
  ...,
  value = ".residual",
  ndraws = NULL,
  seed = NULL,
  re_formula = NULL,
  category = ".category"
)

Arguments

newdata

Data frame to generate predictions from.

object

A supported Bayesian model fit that can provide fits and predictions. Supported models are listed in the second section of tidybayes-models: Models Supporting Prediction. Other functions in this package (like spread_draws()) support a wider range of models, but to work with add_epred_draws(), add_predicted_draws(), etc., a model must provide an interface for generating predictions. More generic Bayesian modeling interfaces like runjags and rstan are therefore not directly supported by these functions; only wrappers around those languages that provide predictions, like rstanarm and brms, are supported here.

...

Additional arguments passed to the underlying prediction method for the type of model given.

value

The name of the output column:

  • for [add_]epred_draws(), defaults to ".epred".

  • for [add_]predicted_draws(), defaults to ".prediction".

  • for [add_]linpred_draws(), defaults to ".linpred".

  • for [add_]residual_draws(), defaults to ".residual".

ndraws

The number of draws to return, or NULL to return all draws.

seed

A seed to use when subsampling draws (i.e. when ndraws is not NULL).
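
As a rough illustration of what ndraws and seed control, the sketch below subsamples a matrix of draws in base R. The matrix all_draws and the helper subsample_draws() are hypothetical stand-ins, not part of tidybayes, and the package's internal subsampling may differ. The point is only that fixing the seed makes the chosen subset of draws reproducible.

```r
# Hypothetical stand-in for posterior draws:
# 1000 draws (rows) for 5 rows of newdata (columns).
set.seed(1)
all_draws <- matrix(rnorm(1000 * 5), nrow = 1000, ncol = 5)

# Hypothetical helper (not a tidybayes function): keep `ndraws` randomly
# chosen draws, or return all draws when `ndraws` is NULL.
subsample_draws <- function(draws, ndraws = NULL, seed = NULL) {
  if (is.null(ndraws)) return(draws)
  if (!is.null(seed)) set.seed(seed)
  keep <- sample.int(nrow(draws), ndraws)  # sample without replacement
  draws[keep, , drop = FALSE]
}

sub_a <- subsample_draws(all_draws, ndraws = 100, seed = 1234)
sub_b <- subsample_draws(all_draws, ndraws = 100, seed = 1234)
full  <- subsample_draws(all_draws)

# The same seed selects the same 100 draws both times.
identical(sub_a, sub_b)
```

With the real functions, the analogous call would pass ndraws and seed directly, e.g. add_epred_draws(newdata, object, ndraws = 100, seed = 1234).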

re_formula

Formula containing group-level effects to be considered in the prediction. If NULL (default), include all group-level effects; if NA, include no group-level effects. Some model types (such as brms::brmsfit and rstanarm::stanreg objects) allow marginalizing over grouping factors by specifying new levels of a factor in newdata. In the case of brms::brm(), you must also pass allow_new_levels = TRUE to this function to include new levels (see brms::posterior_predict()).

category

For some ordinal, multinomial, and multivariate models (notably, brms::brm() models but not rstanarm::stan_polr() models), multiple sets of rows will be returned per input row for epred_draws() or predicted_draws(), depending on the model type. For ordinal/multinomial models, these rows correspond to different categories of the response variable. For multivariate models, they correspond to different response variables. The category argument specifies the name of the column that holds the category names (or variable names) in the resulting data frame. The default name of this column (".category") reflects the fact that this functionality was originally used only for ordinal models and has since been reused for multivariate models. Multiple rows per response are returned only for some model types because tidybayes tidies whatever output the underlying modeling function provides, and that output differs between modeling functions on this point. See vignette("tidy-brms") and vignette("tidy-rstanarm") for examples of dealing with output from ordinal models using both approaches.
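
The "multiple rows per input row" behavior can be pictured with a small base-R sketch. It assumes the underlying prediction method returns a draws x rows x categories array of category probabilities for an ordinal model. The array here is simulated, and the object names and category labels are hypothetical, so this is not output from a real fit or the package's own tidying code.

```r
# Hypothetical array of category probabilities:
# 2 draws x 2 input rows x 3 response categories.
set.seed(3)
raw <- array(runif(2 * 2 * 3), dim = c(2, 2, 3))

# Normalize so probabilities sum to 1 within each (draw, row) pair.
probs <- sweep(raw, c(1, 2), apply(raw, c(1, 2), sum), "/")

# Flatten to long format. expand.grid() varies its first factor fastest,
# which matches the column-major order of as.vector() on the array.
long <- expand.grid(
  .draw = 1:2,
  .row = 1:2,
  .category = c("low", "mid", "high")
)
long$.epred <- as.vector(probs)

# Each (input row, draw) pair now has one output row per category.
long[long$.row == 1 & long$.draw == 1, ]
```

Passing a different category argument would only change the name of the .category column in a result shaped like this.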

dpar

For add_epred_draws() and add_linpred_draws(): Should distributional regression parameters be included in the output? Valid only for models that support distributional regression parameters, such as submodels for variance parameters (as in brms::brm()). If TRUE, distributional regression parameters are included in the output as additional columns named after each parameter. Alternative names can be provided using a list or named vector, e.g. c(sigma.hat = "sigma") would output the "sigma" parameter from a model as a column named "sigma.hat". If NULL (the default) or FALSE, distributional regression parameters are not included.

n

(Deprecated). Use ndraws.

scale

(Deprecated). Use the appropriate function (epred_draws() or linpred_draws()) depending on what type of distribution you want. For linpred_draws(), you may want the transform argument. See rstanarm::posterior_linpred() or brms::posterior_linpred().

prediction, residual

(Deprecated). Use value.

Details

Consider a model like:

\begin{array}{rcl} y &\sim& \textrm{SomeDist}(\theta_1, \theta_2)\\ f_1(\theta_1) &=& \alpha_1 + \beta_1 x\\ f_2(\theta_2) &=& \alpha_2 + \beta_2 x \end{array}

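This model has:

  • an outcome variable, y

  • a response distribution, \textrm{SomeDist}, having parameters \theta_1 (with link function f_1) and \theta_2 (with link function f_2)

  • a single predictor, x

  • coefficients \alpha_1, \beta_1, \alpha_2, and \beta_2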

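We fit this model to some observed data, y_\textrm{obs}, and predictors, x_\textrm{obs}. Given new values of predictors, x_\textrm{new}, supplied in the data frame newdata, the functions for posterior draws are defined as follows:

  • add_predicted_draws() adds draws from the posterior predictive distribution, p(y_\textrm{new} | x_\textrm{new}, y_\textrm{obs}), to the data.

  • add_epred_draws() adds draws from the expectation of the posterior predictive distribution, E(y_\textrm{new} | x_\textrm{new}, y_\textrm{obs}), to the data.

  • add_linpred_draws() adds draws from the posterior linear (link-level) predictors, such as \alpha_1 + \beta_1 x_\textrm{new}, to the data.

  • add_residual_draws() adds draws from the residuals of the model to the data.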

The corresponding functions without add_ as a prefix are alternate spellings with the opposite order of the first two arguments: e.g. add_predicted_draws(newdata, object) versus predicted_draws(object, newdata). This facilitates use in data processing pipelines that start either with a data frame or a model.

Given equal choice between the two, the spellings prefixed with add_ are preferred.
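
To make the distinction between link-level, expectation, and posterior predictive draws concrete, here is a minimal base-R sketch for a simplified one-parameter case of the model above, assuming SomeDist is Poisson and f_1 is the log link. The "posterior draws" of alpha_1 and beta_1 are simulated placeholders rather than output from a real model fit, and all object names are hypothetical.

```r
set.seed(42)
n_draws <- 2000

# Placeholder posterior draws for the coefficients (not from a real fit).
alpha_1 <- rnorm(n_draws, mean = 0.5, sd = 0.05)
beta_1  <- rnorm(n_draws, mean = 0.3, sd = 0.02)

x_new <- 2  # a single new predictor value

# Link-level (linear) predictor: the quantity add_linpred_draws() targets.
linpred <- alpha_1 + beta_1 * x_new

# Expectation of the posterior predictive: invert the log link.
# This is the quantity add_epred_draws() targets.
epred <- exp(linpred)

# Posterior predictive: one simulated outcome per draw, which adds
# observation-level noise. This is what add_predicted_draws() targets.
prediction <- rpois(n_draws, lambda = epred)

# Predictions vary more than their expectations.
c(sd_epred = sd(epred), sd_prediction = sd(prediction))
```

This is why intervals computed from .prediction are typically wider than intervals computed from .epred at the same predictor values.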

Value

A data frame (actually, a tibble) with a .row column (a factor grouping rows from the input newdata), a .chain column (the chain each draw came from, or NA if the model does not provide chain information), an .iteration column (the iteration the draw came from, or NA if the model does not provide iteration information), and a .draw column (a unique index corresponding to each draw from the distribution). In addition, each function adds a column whose name is specified by the value argument: by default ".epred" for epred_draws(), ".linpred" for linpred_draws(), ".prediction" for predicted_draws(), and ".residual" for residual_draws(). For convenience, the resulting data frame comes grouped by the original input rows.
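
The long format described here can be sketched in base R as follows. This is a simplified stand-in built from a simulated draws matrix: it omits the .chain and .iteration columns and the grouping, uses a plain data frame rather than a tibble, and all object names are hypothetical.

```r
# Hypothetical inputs: 3 rows of newdata and a 4 x 3 matrix of draws
# (one row per draw, one column per row of newdata).
newdata <- data.frame(x = c(1, 2, 3))
set.seed(7)
draws <- matrix(rnorm(4 * 3), nrow = 4, ncol = 3)

# Index of the newdata row that each output row belongs to.
idx <- rep(seq_len(nrow(newdata)), each = nrow(draws))

# One output row per (input row, draw) pair. as.vector() reads the matrix
# column by column, which matches the ordering of idx and .draw.
long <- data.frame(
  x      = newdata$x[idx],
  .row   = idx,
  .draw  = rep(seq_len(nrow(draws)), times = nrow(newdata)),
  .epred = as.vector(draws)
)

head(long)
```

The result has nrow(newdata) times the number of draws rows, which is why using all draws on a fine prediction grid can produce large data frames.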

Author(s)

Matthew Kay

See Also

See add_draws() for a variant of these functions for use with packages that do not yet have explicit support for them. See spread_draws() for manipulating posteriors directly.

Examples

## Not run: 

library(tidybayes)
library(ggplot2)
library(dplyr)
library(brms)
library(modelr)

theme_set(theme_light())

m_mpg = brm(mpg ~ hp * cyl, data = mtcars,
  # 1 chain / few iterations just so example runs quickly
  # do not use in practice
  chains = 1, iter = 500)

# draw 100 lines from the posterior means and overplot them
mtcars %>%
  group_by(cyl) %>%
  data_grid(hp = seq_range(hp, n = 101)) %>%
  # NOTE: only use ndraws here when making spaghetti plots; for
  # plotting intervals it is always best to use all draws (omit ndraws)
  add_epred_draws(m_mpg, ndraws = 100) %>%
  ggplot(aes(x = hp, y = mpg, color = ordered(cyl))) +
  geom_line(aes(y = .epred, group = paste(cyl, .draw)), alpha = 0.25) +
  geom_point(data = mtcars)

# plot posterior predictive intervals
mtcars %>%
  group_by(cyl) %>%
  data_grid(hp = seq_range(hp, n = 101)) %>%
  add_predicted_draws(m_mpg) %>%
  ggplot(aes(x = hp, y = mpg, color = ordered(cyl))) +
  stat_lineribbon(aes(y = .prediction), .width = c(.99, .95, .8, .5), alpha = 0.25) +
  geom_point(data = mtcars) +
  scale_fill_brewer(palette = "Greys")


## End(Not run)

[Package tidybayes version 3.0.6 Index]