augment.model_fit {parsnip}    R Documentation
Augment data with predictions
Description
augment() will add column(s) for predictions to the given data.
Usage
## S3 method for class 'model_fit'
augment(x, new_data, eval_time = NULL, ...)
Arguments
x: A model_fit object produced by fit.model_spec() or fit_xy.model_spec().
new_data: A data frame or matrix.
eval_time: For censored regression models, a vector of time points at which the survival probability is estimated.
...: Not currently used.
Details
Regression
For regression models, a .pred column is added. If x was created using fit.model_spec() and new_data contains the regression outcome column, a .resid column (the observed outcome minus .pred) is also added.
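A minimal sketch of the residual behavior described above, using the same mtcars split as the Examples section. It assumes parsnip is installed; the object names reg_fit, aug_with_y and aug_no_y are illustrative only. The last line checks that .resid equals the observed mpg minus .pred.

```r
library(parsnip)

car_trn <- mtcars[11:32, ]
car_tst <- mtcars[1:10, ]

# Formula interface: fit.model_spec() records the outcome name (mpg)
reg_fit <-
  linear_reg() %>%
  set_engine("lm") %>%
  fit(mpg ~ ., data = car_trn)

# new_data contains mpg, so both .pred and .resid are added
aug_with_y <- augment(reg_fit, car_tst)

# Dropping the outcome column (column 1) leaves only .pred
aug_no_y <- augment(reg_fit, car_tst[, -1])

names(aug_with_y)
names(aug_no_y)

# .resid should be the observed outcome minus the prediction
all.equal(aug_with_y$.resid, aug_with_y$mpg - aug_with_y$.pred)
```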
Classification
For classification models, the results can include a column called .pred_class as well as class probability columns named .pred_{level}. Which of these columns appear depends on the prediction types available for the model.
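The sketch below shows how the .pred_{level} names are formed from the outcome's factor levels, assuming parsnip and modeldata are installed and reusing the two_class_dat split from the Examples section. The helper vector prob_cols is a hypothetical name introduced here. Because the glm engine supports both class and probability predictions, both kinds of columns should appear, and each row's probabilities should sum to one.

```r
library(parsnip)

# two_class_dat has predictors A and B and a factor outcome Class
data(two_class_dat, package = "modeldata")
cls_trn <- two_class_dat[-(1:10), ]
cls_tst <- two_class_dat[1:10, ]

cls_fit <-
  logistic_reg() %>%
  set_engine("glm") %>%
  fit(Class ~ ., data = cls_trn)

cls_aug <- augment(cls_fit, cls_tst)

# One probability column per outcome level, named .pred_{level}
prob_cols <- paste0(".pred_", levels(cls_trn$Class))
names(cls_aug)

# The class probabilities in each row should sum to one
rowSums(as.data.frame(cls_aug[, prob_cols]))
```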
Censored Regression
For these models, predictions for the expected time and the survival probability are created (if the model engine supports them). If the model supports survival prediction, the eval_time argument is required.
If survival predictions are created and new_data contains a survival::Surv() object, additional columns for inverse probability of censoring weights (IPCW) are also created (see the tidymodels.org page in the References below). These weights allow performance metrics to be computed with the yardstick package.
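A rough sketch of the censored regression case, under several assumptions: the censored package is installed and registers the "survival" engine for survival_reg(), the lung data come from the survival package, and the column name surv, the evaluation times and the object names are hypothetical choices. The outcome is stored as a Surv column so that it is also present in new_data. The exact columns nested inside .pred (including the IPCW columns mentioned above) may vary between parsnip versions, so the sketch only prints them.

```r
library(parsnip)
library(censored)  # assumption: provides the "survival" engine for survival_reg()
library(survival)

# Keep two predictors and store the outcome as a single Surv column
lung_df <- data.frame(age = lung$age, sex = lung$sex)
lung_df$surv <- Surv(lung$time, lung$status)

srv_trn <- lung_df[-(1:10), ]
srv_tst <- lung_df[1:10, ]

srv_fit <-
  survival_reg() %>%
  set_engine("survival") %>%
  fit(surv ~ age + sex, data = srv_trn)

# eval_time is required because this engine supports survival prediction
srv_aug <- augment(srv_fit, srv_tst, eval_time = c(100, 300, 500))

# Expect .pred (a list column) and .pred_time
names(srv_aug)

# Each element of .pred holds one row per evaluation time; since srv_tst
# contains a Surv column, IPCW columns should also be included here
srv_aug$.pred[[1]]
```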
References
https://www.tidymodels.org/learn/statistics/survival-metrics/
Examples
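library(parsnip)

# Split mtcars into training (rows 11-32) and test (rows 1-10) sets
car_trn <- mtcars[11:32, ]
car_tst <- mtcars[ 1:10, ]

# Formula interface: the outcome (mpg) is recorded in the fit
reg_form <-
  linear_reg() %>%
  set_engine("lm") %>%
  fit(mpg ~ ., data = car_trn)

# x/y interface: predictors and outcome are passed separately
reg_xy <-
  linear_reg() %>%
  set_engine("lm") %>%
  fit_xy(car_trn[, -1], car_trn$mpg)

# Formula fit: .resid is added only when new_data contains mpg
augment(reg_form, car_tst)
augment(reg_form, car_tst[, -1])

# x/y fit: only .pred is added
augment(reg_xy, car_tst)
augment(reg_xy, car_tst[, -1])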
# ------------------------------------------------------------------------------
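# Two-class example data: predictors A and B, factor outcome Class (column 3)
data(two_class_dat, package = "modeldata")
cls_trn <- two_class_dat[-(1:10), ]
cls_tst <- two_class_dat[ 1:10, ]

# Formula interface
cls_form <-
  logistic_reg() %>%
  set_engine("glm") %>%
  fit(Class ~ ., data = cls_trn)

# x/y interface
cls_xy <-
  logistic_reg() %>%
  set_engine("glm") %>%
  fit_xy(cls_trn[, -3], cls_trn$Class)

# Each call adds .pred_class plus the .pred_Class1 and .pred_Class2 probability columns
augment(cls_form, cls_tst)
augment(cls_form, cls_tst[, -3])
augment(cls_xy, cls_tst)
augment(cls_xy, cls_tst[, -3])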