predict.ObliqueForest {aorsf} | R Documentation |
Prediction for ObliqueForest Objects
Description
Compute predicted values from an oblique random forest. Predictions may be returned in aggregate (i.e., averaged over all the trees) or separately for each tree.
Usage
## S3 method for class 'ObliqueForest'
predict(
  object,
  new_data = NULL,
  pred_type = NULL,
  pred_horizon = NULL,
  pred_aggregate = TRUE,
  pred_simplify = FALSE,
  oobag = FALSE,
  na_action = NULL,
  boundary_checks = TRUE,
  n_thread = NULL,
  verbose_progress = NULL,
  ...
)
Arguments
object |
(ObliqueForest) a trained oblique random forest object (see orsf). |
new_data |
a data.frame, tibble, or data.table to compute predictions in. |
pred_type |
(character) the type of predictions to compute. Valid options for survival include 'risk' (the default), 'surv', 'chf', and 'mort'.
For classification: 'prob' (the default) and 'class'.
For regression: 'mean'.
|
pred_horizon |
(double) Only relevant for survival forests.
A value or vector indicating the time(s) that predictions will be
calibrated to. E.g., if you were predicting risk of incident heart
failure within the next 10 years, then |
pred_aggregate |
(logical) If |
pred_simplify |
(logical) If |
oobag |
(logical) If |
na_action |
(character) what should happen when
|
boundary_checks |
(logical) if |
n_thread |
(integer) number of threads to use while computing predictions. Default is 0, which allows a suitable number of threads to be used based on availability. |
verbose_progress |
(logical) if |
... |
Further arguments passed to or from other methods (not currently used). |
Details
new_data must have the same columns with equivalent types as the data used to train object. Also, factors in new_data must not have levels that were not in the data used to train object.
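The requirements above can be checked before calling predict. The following is a minimal base R sketch, not part of aorsf: check_new_data is a hypothetical helper name, and comparing the first class of each column is an assumed stand-in for "equivalent types".

```r
# Hypothetical helper (not part of aorsf): list the ways in which `new_data`
# disagrees with the training data on columns, types, or factor levels.
check_new_data <- function(train, new_data, outcome = character()) {

  problems <- character()

  # predictors are all training columns except the outcome column(s)
  predictors <- setdiff(names(train), outcome)

  missing_cols <- setdiff(predictors, names(new_data))
  if (length(missing_cols) > 0) {
    problems <- c(problems, paste("missing column:", missing_cols))
  }

  for (col in intersect(predictors, names(new_data))) {

    # assumption: "equivalent type" is approximated by the first class
    if (class(train[[col]])[1] != class(new_data[[col]])[1]) {
      problems <- c(problems, paste("type mismatch:", col))
      next
    }

    if (is.factor(train[[col]])) {
      # values present in new_data that the training data never saw
      new_levels <- setdiff(unique(as.character(new_data[[col]])),
                            c(levels(train[[col]]), NA))
      if (length(new_levels) > 0) {
        problems <- c(problems,
                      paste0("unseen level in ", col, ": ", new_levels))
      }
    }
  }

  problems
}

# toy data, for illustration only
train <- data.frame(y = c(1, 2, 3),
                    x = c(0.5, 1.5, 2.5),
                    g = factor(c("a", "b", "a")))

new_ok  <- data.frame(x = 1, g = factor("b", levels = c("a", "b")))
new_bad <- data.frame(x = "1", g = factor("c"))

check_new_data(train, new_ok, outcome = "y")   # nothing to report
check_new_data(train, new_bad, outcome = "y")  # reports x and level "c"
```

An empty result means none of the checked problems were found; it does not guarantee that predict will accept the data.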
pred_horizon values should not exceed the maximum follow-up time in object's training data, but if you truly want to do this, set boundary_checks = FALSE and you can use a pred_horizon as large as you want. Note that predictions beyond the maximum follow-up time in the object's training data are equal to predictions at the maximum follow-up time, because aorsf does not estimate survival beyond its maximum observed time.
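To see why predictions stop changing past the maximum follow-up time, consider a toy step-function survival curve in base R. This is only an illustration of the "carry the last value forward" behavior described above: the numbers are made up, surv_at is a hypothetical helper, and nothing here reflects how aorsf computes predictions internally.

```r
# Toy survival curve on a grid of observed times (illustrative numbers only)
obs_time <- c(100, 500, 1000, 1500)
surv     <- c(0.98, 0.90, 0.80, 0.72)

# Step-function lookup: method = "constant" uses the most recent observed
# value, and rule = 2 reuses the boundary value for times outside the
# observed range instead of returning NA.
surv_at <- function(times) {
  approx(x = obs_time, y = surv, xout = times,
         method = "constant", rule = 2)$y
}

surv_at(1500)          # value at the maximum observed time
surv_at(c(2000, 1e6))  # same value, however far past the maximum
```

Because no information exists beyond the last observed time, any horizon past it can only repeat the final estimate.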
If unspecified, pred_horizon may be automatically specified as the value used for oobag_pred_horizon when object was created (see orsf).
Value
a matrix of predictions. Column j of the matrix corresponds to value j in pred_horizon. Row i of the matrix corresponds to row i in new_data.
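Since the printed matrices in the Examples carry no row or column names, it can help to label them yourself. The sketch below uses base R only and copies the risk values from the Survival example; the "row_" and "t_" prefixes are arbitrary choices for illustration.

```r
pred_horizon <- c(500, 1000, 1500)

# risk predictions copied from the Survival example (5 rows of new_data)
risk <- matrix(
  c(0.013648562, 0.058393393, 0.11184029,
    0.003811413, 0.026857586, 0.04774151,
    0.030548361, 0.100600301, 0.14847107,
    0.040381075, 0.169596943, 0.27018952,
    0.001484698, 0.006663576, 0.01337655),
  ncol = length(pred_horizon), byrow = TRUE
)

# label rows by position in new_data (i) and columns by horizon (j)
dimnames(risk) <- list(paste0("row_", seq_len(nrow(risk))),
                       paste0("t_", pred_horizon))

# risk for the 4th row of new_data at the 2nd horizon (time 1000)
risk["row_4", "t_1000"]
```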
Examples
library(aorsf)
Classification
set.seed(329)

index_train <- sample(nrow(penguins_orsf), 150)

penguins_orsf_train <- penguins_orsf[index_train, ]
penguins_orsf_test <- penguins_orsf[-index_train, ]

fit_clsf <- orsf(data = penguins_orsf_train,
                 formula = species ~ .)
Predict probability for each class or the predicted class:
# predicted probabilities, the default
predict(fit_clsf,
        new_data = penguins_orsf_test[1:5, ],
        pred_type = 'prob')

##         Adelie  Chinstrap      Gentoo
## [1,] 0.9405310 0.04121955 0.018249405
## [2,] 0.9628988 0.03455909 0.002542096
## [3,] 0.9032074 0.08510528 0.011687309
## [4,] 0.9300133 0.05209040 0.017896329
## [5,] 0.7965703 0.16243492 0.040994821

# predicted class (as a matrix by default)
predict(fit_clsf,
        new_data = penguins_orsf_test[1:5, ],
        pred_type = 'class')

##      [,1]
## [1,]    1
## [2,]    1
## [3,]    1
## [4,]    1
## [5,]    1

# predicted class (as a factor if you use simplify)
predict(fit_clsf,
        new_data = penguins_orsf_test[1:5, ],
        pred_type = 'class',
        pred_simplify = TRUE)

## [1] Adelie Adelie Adelie Adelie Adelie
## Levels: Adelie Chinstrap Gentoo
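The three outputs above appear to be related: the class with the largest predicted probability is Adelie in every row, which matches both the integer code 1 and the factor output. The base R sketch below reproduces that link from the printed probabilities. It assumes the predicted class is the one with the highest probability and that integer codes index the class columns starting from 1; neither is stated on this page, though both agree with the output shown.

```r
# class probabilities copied from the output printed above
prob <- matrix(
  c(0.9405310, 0.04121955, 0.018249405,
    0.9628988, 0.03455909, 0.002542096,
    0.9032074, 0.08510528, 0.011687309,
    0.9300133, 0.05209040, 0.017896329,
    0.7965703, 0.16243492, 0.040994821),
  ncol = 3, byrow = TRUE,
  dimnames = list(NULL, c("Adelie", "Chinstrap", "Gentoo"))
)

# column index of the largest probability in each row
class_index <- max.col(prob, ties.method = "first")

# convert the indices to a factor with the class names as levels
class_label <- factor(colnames(prob)[class_index],
                      levels = colnames(prob))

class_index
class_label
```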
Regression
set.seed(329)

index_train <- sample(nrow(penguins_orsf), 150)

penguins_orsf_train <- penguins_orsf[index_train, ]
penguins_orsf_test <- penguins_orsf[-index_train, ]

fit_regr <- orsf(data = penguins_orsf_train,
                 formula = bill_length_mm ~ .)
Predict the mean value of the outcome:
predict(fit_regr,
        new_data = penguins_orsf_test[1:5, ],
        pred_type = 'mean')

##          [,1]
## [1,] 37.74136
## [2,] 37.42367
## [3,] 37.04598
## [4,] 39.89602
## [5,] 39.14848
Survival
Begin by fitting an oblique survival random forest:
set.seed(329)

index_train <- sample(nrow(pbc_orsf), 150)

pbc_orsf_train <- pbc_orsf[index_train, ]
pbc_orsf_test <- pbc_orsf[-index_train, ]

fit_surv <- orsf(data = pbc_orsf_train,
                 formula = Surv(time, status) ~ . - id,
                 oobag_pred_horizon = 365.25 * 5)
Predict risk, survival, or cumulative hazard at one or several times:
# predicted risk, the default
predict(fit_surv,
        new_data = pbc_orsf_test[1:5, ],
        pred_type = 'risk',
        pred_horizon = c(500, 1000, 1500))

##             [,1]        [,2]       [,3]
## [1,] 0.013648562 0.058393393 0.11184029
## [2,] 0.003811413 0.026857586 0.04774151
## [3,] 0.030548361 0.100600301 0.14847107
## [4,] 0.040381075 0.169596943 0.27018952
## [5,] 0.001484698 0.006663576 0.01337655

# predicted survival, i.e., 1 - risk
predict(fit_surv,
        new_data = pbc_orsf_test[1:5, ],
        pred_type = 'surv',
        pred_horizon = c(500, 1000, 1500))

##           [,1]      [,2]      [,3]
## [1,] 0.9863514 0.9416066 0.8881597
## [2,] 0.9961886 0.9731424 0.9522585
## [3,] 0.9694516 0.8993997 0.8515289
## [4,] 0.9596189 0.8304031 0.7298105
## [5,] 0.9985153 0.9933364 0.9866235
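The relationship surv = 1 - risk noted in the comment can be confirmed directly from the two printed matrices. The base R check below copies the values shown above and allows a small tolerance, since the printed numbers are rounded.

```r
# values copied from the 'risk' output
risk <- matrix(
  c(0.013648562, 0.058393393, 0.11184029,
    0.003811413, 0.026857586, 0.04774151,
    0.030548361, 0.100600301, 0.14847107,
    0.040381075, 0.169596943, 0.27018952,
    0.001484698, 0.006663576, 0.01337655),
  ncol = 3, byrow = TRUE
)

# values copied from the 'surv' output
surv <- matrix(
  c(0.9863514, 0.9416066, 0.8881597,
    0.9961886, 0.9731424, 0.9522585,
    0.9694516, 0.8993997, 0.8515289,
    0.9596189, 0.8304031, 0.7298105,
    0.9985153, 0.9933364, 0.9866235),
  ncol = 3, byrow = TRUE
)

# largest absolute discrepancy between 1 - risk and surv;
# anything below 1e-6 is explained by rounding in the printed output
max_diff <- max(abs((1 - risk) - surv))
max_diff < 1e-6
```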
# predicted cumulative hazard function
# (expected number of events for person i at time j)
predict(fit_surv,
        new_data = pbc_orsf_test[1:5, ],
        pred_type = 'chf',
        pred_horizon = c(500, 1000, 1500))

##             [,1]        [,2]       [,3]
## [1,] 0.015395388 0.067815817 0.14942956
## [2,] 0.004022524 0.028740305 0.05424314
## [3,] 0.034832754 0.127687156 0.20899732
## [4,] 0.059978334 0.233048809 0.42562310
## [5,] 0.001651365 0.007173177 0.01393016
Predict mortality, defined as the number of events in the forest’s population if all observations had characteristics like the current observation. This type of prediction does not require you to specify a prediction horizon.
predict(fit_surv,
        new_data = pbc_orsf_test[1:5, ],
        pred_type = 'mort')

##           [,1]
## [1,] 23.405016
## [2,] 15.362916
## [3,] 26.180648
## [4,] 36.515629
## [5,]  5.856674