bl_imp.GeDSboost {GeDS}    R Documentation
Base Learner Importance for GeDSboost objects
Description
This function calculates the in-bag mean squared error (MSE) reduction
ascribable to each of the base-learners with regard to the final prediction
of the component-wise gradient boosted model encapsulated in a
GeDSboost-Class object. Essentially, it measures the decrease in MSE
attributable to each base-learner each time it is selected across the
boosting iterations, and aggregates these decreases. This provides a measure
of how much each base-learner contributes to the overall improvement in the
model's accuracy, as reflected by the decrease in MSE. This function is
adapted from varimp and is compatible with the available mboost-package
methods for varimp, including plot, print and as.data.frame.
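The aggregation described above can be illustrated with a minimal base-R sketch. It assumes that the in-bag MSE after each boosting iteration and the name of the base-learner selected at each iteration are already available. The helper mse_reduction_by_learner() and the toy values below are hypothetical, are not part of GeDS, and only show the idea; the package internals may differ.

```r
# Minimal sketch (assumption): credit each iteration's in-bag MSE decrease to
# the base-learner selected at that iteration, then sum per base-learner.
# mse_reduction_by_learner() is a hypothetical helper, not a GeDS function.
mse_reduction_by_learner <- function(mse, selected) {
  # mse:      numeric vector of length M + 1; mse[1] is the in-bag MSE of the
  #           starting fit, mse[m + 1] the in-bag MSE after iteration m
  # selected: character vector of length M; selected[m] is the base-learner
  #           chosen at iteration m
  stopifnot(length(mse) == length(selected) + 1)
  # Decrease in in-bag MSE achieved at each iteration
  reduction <- -diff(mse)
  # Aggregate the decreases by base-learner
  tapply(reduction, selected, sum)
}

# Hypothetical toy trace of four boosting iterations
mse <- c(100, 70, 55, 50, 48)
selected <- c("f(hipcirc)", "f(kneebreadth)", "f(hipcirc)", "f(anthro3a)")
imp <- mse_reduction_by_learner(mse, selected)
print(imp)
```

Because the per-iteration decreases telescope, the aggregated reductions add up to the overall decrease in in-bag MSE from the starting fit to the last iteration (in the toy trace, 100 - 48 = 52).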
Usage
## S3 method for class 'GeDSboost'
bl_imp(object, boosting_iter_only = FALSE, ...)
Arguments
object              an object of class
boosting_iter_only  logical value, if
...                 potentially further arguments.
Details
See varimp for details.
Value
An object of class varimp with available plot, print and as.data.frame
methods.
References
Hothorn T., Buehlmann P., Kneib T., Schmid M. and Hofner B. (2022). mboost: Model-Based Boosting. R package version 2.9-7, https://CRAN.R-project.org/package=mboost.
Examples
library(GeDS)
library(TH.data)
set.seed(290875)
data("bodyfat", package = "TH.data")
data <- bodyfat
# Fit a component-wise gradient boosted GeDS model without an initial learner
Gmodboost <- NGeDSboost(formula = DEXfat ~ f(hipcirc) + f(kneebreadth) + f(anthro3a),
                        data = data, initial_learner = FALSE)
# In-bag MSE of the linear, quadratic and cubic fits
MSE_Gmodboost_linear <- mean((data$DEXfat - Gmodboost$predictions$pred_linear)^2)
MSE_Gmodboost_quadratic <- mean((data$DEXfat - Gmodboost$predictions$pred_quadratic)^2)
MSE_Gmodboost_cubic <- mean((data$DEXfat - Gmodboost$predictions$pred_cubic)^2)
# Print MSE
cat("\n", "MEAN SQUARED ERROR", "\n",
    "Linear NGeDSboost:", MSE_Gmodboost_linear, "\n",
    "Quadratic NGeDSboost:", MSE_Gmodboost_quadratic, "\n",
    "Cubic NGeDSboost:", MSE_Gmodboost_cubic, "\n")
# Base-learner importance (stored under a different name so that the
# function bl_imp() is not masked)
imp <- bl_imp(Gmodboost)
print(imp)
plot(imp)