predict.brulee_linear_reg {brulee}R Documentation

Predict from a brulee_linear_reg

Description

Predict from a brulee_linear_reg

Usage

## S3 method for class 'brulee_linear_reg'
predict(object, new_data, type = NULL, epoch = NULL, ...)

Arguments

object

A brulee_linear_reg object.

new_data

A data frame or matrix of new predictors.

type

A single character. The type of predictions to generate. Valid options are:

  • "numeric" for numeric predictions.

epoch

An integer for the epoch to make predictions. If this value is larger than the maximum number that was fit, a warning is issued and the parameters from the last epoch are used. If left NULL, the epoch associated with the smallest loss is used.

...

Not used, but required for extensibility.

Value

A tibble of predictions. The number of rows in the tibble is guaranteed to be the same as the number of rows in new_data.

Examples


if (torch::torch_is_installed()) {

 data(ames, package = "modeldata")

 ames$Sale_Price <- log10(ames$Sale_Price)

 set.seed(1)
 in_train <- sample(1:nrow(ames), 2000)
 ames_train <- ames[ in_train,]
 ames_test  <- ames[-in_train,]

 # Using recipe
 library(recipes)

 ames_rec <-
  recipe(Sale_Price ~ Longitude + Latitude, data = ames_train) %>%
    step_normalize(all_numeric_predictors())

 set.seed(2)
 fit <- brulee_linear_reg(ames_rec, data = ames_train,
                           epochs = 50, batch_size = 32)

 predict(fit, ames_test)
}


[Package brulee version 0.3.0 Index]