classify {tidyfit} | R Documentation |
Classification on tidy data
Description
This function is a wrapper to fit many different types of linear classification models on a (grouped) tibble.
Arguments
.data |
a data frame, data frame extension (e.g. a tibble), or a lazy data frame (e.g. from dbplyr or dtplyr). The data frame can be grouped. |
formula |
an object of class "formula": a symbolic description of the model to be fitted. |
... |
name-function pairs of models to be estimated. See 'Details'. |
.cv |
type of 'rsample' cross validation procedure to use to determine optimal hyperparameter values. Default is |
.cv_args |
additional settings to pass to the 'rsample' cross validation function. |
.weights |
optional name of column containing sample weights. |
.mask |
optional vector of column names to ignore. Can be useful when using the 'y ~ .' formula syntax. |
.return_slices |
logical. Should the output of individual cross validation slices be returned or only the final fit. Default is |
.return_grid |
logical. Should the output of the individual hyperparameter grids be returned or only the best fitting set of hyperparameters. Default is |
.tune_each_group |
logical. Should optimal hyperparameters be selected for each group or once across all groups. Default is |
.force_cv |
logical. Should models be evaluated across all cross validation slices, even if no hyperparameters are tuned. Default is |
Details
classify fits all models passed in ... using the m function. The models can be passed as name-function pairs (e.g. ols = m("lm")) or without including a name.
Hyperparameters are tuned automatically using the '.cv' and '.cv_args' arguments, or can be passed to m() (e.g. lasso = m("lasso", lambda = 0.5)). See the individual model functions (?m()) for an overview of hyperparameters.
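The two ways of passing models described above can be sketched as follows. This is a minimal illustration that reuses the data set and the "lasso" method from the Examples section; the model name lasso_fixed is a hypothetical label chosen here, not something the package prescribes.

```r
library(tidyfit)

# Binary response: 1 if the return is positive, 0 otherwise
data <- tidyfit::Factor_Industry_Returns
data <- dplyr::mutate(data, Return = ifelse(Return > 0, 1, 0))

# Named model ('lasso_fixed' is an illustrative name) with a fixed hyperparameter
fit_named <- classify(data, Return ~ ., lasso_fixed = m("lasso", lambda = 0.1),
                      .mask = c("Date", "Industry"))

# The same model passed without a name
fit_unnamed <- classify(data, Return ~ ., m("lasso", lambda = 0.1),
                        .mask = c("Date", "Industry"))
```

Naming models is mainly useful when several variants of the same method are fitted in one call and need to be told apart in the output.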
Cross validation is performed using the 'rsample' package with possible methods including:
'initial_split' (simple train-test split)
'initial_time_split' (train-test split with retained order)
'vfold_cv' (aka kfold cross validation)
'loo_cv' (leave-one-out)
'rolling_origin' (generalized time series cross validation, e.g. rolling or expanding windows)
'sliding_window', 'sliding_index', 'sliding_period' (specialized time series splits)
'bootstraps'
'group_vfold_cv', 'group_bootstraps'
See package documentation for 'rsample' for all available methods.
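As a rough sketch of how a cross validation method from the list above might be selected, the call below chooses between two lambda values using 5-fold cross validation. It assumes that '.cv_args' accepts a named list of arguments for the chosen 'rsample' function (here v, the number of folds in vfold_cv); the data and model are taken from the Examples section.

```r
library(tidyfit)

# Binary response: 1 if the return is positive, 0 otherwise
data <- tidyfit::Factor_Industry_Returns
data <- dplyr::mutate(data, Return = ifelse(Return > 0, 1, 0))

# Choose between the two lambda values using 5-fold cross validation.
# Assumption: .cv_args is a named list passed on to rsample::vfold_cv.
fit_cv <- classify(
  data, Return ~ .,
  m("lasso", lambda = c(0.001, 0.1)),
  .cv = "vfold_cv",
  .cv_args = list(v = 5),
  .mask = c("Date", "Industry")
)
```

For time-ordered data, one of the time series methods (e.g. 'initial_time_split' or 'rolling_origin') is likely a better fit than random folds.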
The negative log loss is used to validate performance in the cross validation.
Note that arguments for weights are automatically passed to the functions by setting the '.weights' argument. Weights are also considered during cross validation by calculating weighted versions of the cross validation loss function.
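To make the validation metric concrete, the following base R sketch computes a weighted log loss for a binary response. It is an illustration only: the function name weighted_log_loss, the clipping constant eps and the normalization by the sum of weights are assumptions made here, and the package's internal implementation may differ.

```r
# Hypothetical helper (not part of tidyfit): weighted binary log loss.
# y: 0/1 responses, p: predicted probabilities of class 1, w: sample weights.
weighted_log_loss <- function(y, p, w = rep(1, length(y)), eps = 1e-15) {
  stopifnot(length(y) == length(p), length(y) == length(w))
  # Clip probabilities away from 0 and 1 to avoid log(0)
  p <- pmin(pmax(p, eps), 1 - eps)
  loss <- -(y * log(p) + (1 - y) * log(1 - p))
  # Weighted average of the per-observation losses
  sum(w * loss) / sum(w)
}

y <- c(1, 0, 1, 1)
p <- c(0.9, 0.2, 0.6, 0.4)

weighted_log_loss(y, p)                    # equal weights
weighted_log_loss(y, p, w = c(1, 1, 2, 2)) # later observations count double
```

Lower values indicate better calibrated predictions; with weights, observations with larger weights contribute more to the score.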
classify can handle both binomial and multinomial response distributions; however, not all underlying methods are capable of handling a multinomial response.
Value
A tidyfit.models frame containing model details for each group.
The 'tidyfit.models' frame consists of 4 different components:
A group of identifying columns (e.g. model name, data groups, grid IDs)
A 'model_object' column, which contains the fitted model.
A nested 'settings' column containing model arguments and hyperparameters
Columns showing errors, warnings and messages (if applicable)
Coefficients, predictions, fitted values or residuals can be accessed using the built-in coef, predict, fitted and resid methods. Note that all coefficients are transformed to ensure comparability across methods.
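A short sketch of the accessor methods named above, using the model from the Examples section. It assumes that predict takes the fitted frame followed by a data frame of new observations; here the training data is reused purely for illustration.

```r
library(tidyfit)

# Binary response: 1 if the return is positive, 0 otherwise
data <- tidyfit::Factor_Industry_Returns
data <- dplyr::mutate(data, Return = ifelse(Return > 0, 1, 0))

fit <- classify(data, Return ~ ., m("lasso", lambda = 0.1),
                .mask = c("Date", "Industry"))

coefs <- coef(fit)          # coefficient estimates
preds <- predict(fit, data) # predictions (assumed: new data as second argument)
fits  <- fitted(fit)        # fitted values
resids <- resid(fit)        # residuals
```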
Author(s)
Johann Pfitzinger
See Also
regress, coef.tidyfit.models and predict.tidyfit.models methods
Examples
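library(tidyfit)
# Load the data and convert returns to a binary response
data <- tidyfit::Factor_Industry_Returns
data <- dplyr::mutate(data, Return = ifelse(Return > 0, 1, 0))
# Fit a lasso classifier with two lambda values, ignoring the identifier columns
fit <- classify(data, Return ~ ., m("lasso", lambda = c(0.001, 0.1)), .mask = c("Date", "Industry"))
# Print the models frame
tidyr::unnest(fit, settings)
# View coefficients
coef(fit)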