predict.brulee_logistic_reg {brulee} | R Documentation |
Predict from a brulee_logistic_reg
Description
Predict from a brulee_logistic_reg
Usage
## S3 method for class 'brulee_logistic_reg'
predict(object, new_data, type = NULL, epoch = NULL, ...)
Arguments
object |
A |
new_data |
A data frame or matrix of new predictors. |
type |
A single character. The type of predictions to generate. Valid options are:
|
epoch |
An integer for the epoch to make predictions. If this value
is larger than the maximum number that was fit, a warning is issued and the
parameters from the last epoch are used. If left |
... |
Not used, but required for extensibility. |
Value
A tibble of predictions. The number of rows in the tibble is guaranteed
to be the same as the number of rows in new_data
.
Examples
if (torch::torch_is_installed()) {
library(recipes)
library(yardstick)
data(penguins, package = "modeldata")
penguins <- penguins %>% na.omit()
set.seed(122)
in_train <- sample(1:nrow(penguins), 200)
penguins_train <- penguins[ in_train,]
penguins_test <- penguins[-in_train,]
rec <- recipe(sex ~ ., data = penguins_train) %>%
step_dummy(all_nominal_predictors()) %>%
step_normalize(all_numeric_predictors())
set.seed(3)
fit <- brulee_logistic_reg(rec, data = penguins_train, epochs = 5)
fit
predict(fit, penguins_test)
predict(fit, penguins_test, type = "prob") %>%
bind_cols(penguins_test) %>%
roc_curve(sex, .pred_female) %>%
autoplot()
}
[Package brulee version 0.3.0 Index]