EN_predict {DMTL}    R Documentation

Predictive Modeling using Elastic Net

Description

This function trains an Elastic Net regressor on the training data provided and predicts the response for the test features. This implementation depends on the glmnet package.
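For reference, the glmnet documentation states the elastic-net problem it solves for a Gaussian (squared-error) response as follows. Here N is the number of training samples, λ ≥ 0 sets the overall penalty strength, and α is the mixing parameter listed under Arguments. That EN_predict fits glmnet's Gaussian family is an assumption based on its use as a regressor.

```latex
\min_{\beta_0,\,\beta}\;
\frac{1}{2N}\sum_{i=1}^{N}\left(y_i - \beta_0 - x_i^{\top}\beta\right)^2
+ \lambda\left[\frac{1-\alpha}{2}\,\lVert\beta\rVert_2^2 + \alpha\,\lVert\beta\rVert_1\right]
```

Setting α = 1 leaves only the ℓ1 (lasso) term, and α = 0 leaves only the ℓ2 (ridge) term; values in between blend the two.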

Usage

EN_predict(
  x_train,
  y_train,
  x_test,
  lims,
  optimize = FALSE,
  alpha = 0.8,
  seed = NULL,
  verbose = FALSE,
  parallel = FALSE
)

Arguments

x_train

Training features used to fit the EN regressor.

y_train

Training response used to fit the EN regressor.

x_test

Test features for which response values are to be predicted. If x_test is not given, the function will return the trained model.

lims

Vector giving the lower and upper bounds of the response values used for modeling. If missing, these bounds are estimated from the training response.

optimize

Flag for model tuning. If TRUE, performs a grid search to tune the model parameters; if FALSE, uses the parameters provided. Defaults to FALSE.

alpha

EN mixing parameter, with 0 ≤ alpha ≤ 1. alpha = 1 gives the lasso penalty and alpha = 0 the ridge penalty. Defaults to 0.8. Applies only when optimize = FALSE.

seed

Seed for random number generator (for reproducible outcomes). Defaults to NULL.

verbose

Flag for printing the tuning progress when optimize = TRUE. Defaults to FALSE.

parallel

Flag for allowing parallel processing during the grid search (i.e., when optimize = TRUE). Defaults to FALSE.
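The optimize and seed arguments describe a tuning step whose internals are not documented here. The following is a hypothetical sketch, not the DMTL source, of one common way to grid-search alpha with glmnet: for each candidate alpha, cv.glmnet chooses lambda by cross-validation, and the alpha with the lowest cross-validated error wins. The helper name tune_alpha_sketch, the alpha grid, and the fold count are illustrative assumptions.

```r
library(glmnet)

# Hypothetical sketch of alpha tuning by grid search; the alpha grid and the
# number of folds are illustrative assumptions, not DMTL's settings.
tune_alpha_sketch <- function(x, y, alphas = seq(0, 1, by = 0.1),
                              nfolds = 5, seed = NULL) {
  if (!is.null(seed)) set.seed(seed)   # makes the fold assignment reproducible
  x <- as.matrix(x)
  # Reuse one fold assignment for every alpha so the CV errors are comparable
  foldid <- sample(rep(seq_len(nfolds), length.out = nrow(x)))
  cv_err <- numeric(length(alphas))
  fits <- vector("list", length(alphas))
  for (i in seq_along(alphas)) {
    # cv.glmnet also chooses lambda for this alpha by cross-validation
    fits[[i]] <- cv.glmnet(x, y, alpha = alphas[i], foldid = foldid)
    cv_err[i] <- min(fits[[i]]$cvm)    # CV mean squared error at lambda.min
  }
  best <- which.min(cv_err)
  list(alpha = alphas[best], lambda = fits[[best]]$lambda.min,
       fit = fits[[best]])
}

set.seed(86420)
x <- matrix(rnorm(3000, 0.2, 1.2), ncol = 3)
y <- 0.3 * x[, 1] + 0.1 * x[, 2] - x[, 3] + rnorm(1000, 0, 0.05)
tuned <- tune_alpha_sketch(x[1:800, ], y[1:800], seed = 1)
print(tuned$alpha)
```

Sharing foldid across the alpha values follows the glmnet documentation's advice for comparing different alphas on the same cross-validation splits.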

Value

If x_test is missing, the trained EN regressor.

If x_test is provided, the response values predicted for x_test by the trained model.
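The two return modes can be illustrated with a hypothetical helper built directly on glmnet. This is a sketch, not the DMTL implementation: the name en_predict_sketch and the choice of selecting lambda with cv.glmnet at lambda.min are assumptions.

```r
library(glmnet)

# Hypothetical helper mirroring the documented return behaviour; it is not
# the DMTL source. lambda is chosen by cv.glmnet's 10-fold CV (assumption).
en_predict_sketch <- function(x_train, y_train, x_test = NULL, alpha = 0.8) {
  model <- cv.glmnet(as.matrix(x_train), y_train, alpha = alpha)
  if (is.null(x_test)) return(model)   # no test features: return the model
  # predict() gives a one-column matrix; drop() turns it into a plain vector
  drop(predict(model, newx = as.matrix(x_test), s = "lambda.min"))
}

set.seed(86420)
x <- matrix(rnorm(3000, 0.2, 1.2), ncol = 3)
y <- 0.3 * x[, 1] + 0.1 * x[, 2] - x[, 3] + rnorm(1000, 0, 0.05)
model  <- en_predict_sketch(x[1:800, ], y[1:800], alpha = 0.6)
y_pred <- en_predict_sketch(x[1:800, ], y[1:800], x_test = x[801:1000, ])
```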

Note

The response values are clipped to lie within the range given by lims.
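Bounding values by lims amounts to clipping them to the interval [lims[1], lims[2]]. The base-R sketch below shows that operation; the helper name clip_to_lims is hypothetical, and estimating missing bounds with range(y_train) is an assumption about how DMTL derives lims from the training response.

```r
# Hypothetical helper: clip values to the interval [lims[1], lims[2]].
clip_to_lims <- function(y_pred, lims) {
  pmin(pmax(y_pred, lims[1]), lims[2])
}

y_train <- c(0.2, 1.5, 0.9, 3.1)
lims <- range(y_train)   # assumed estimate of the bounds: c(0.2, 3.1)
clipped <- clip_to_lims(c(-0.4, 1.0, 3.8), lims)
print(clipped)           # 0.2 1.0 3.1
```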

Examples

set.seed(86420)
x <- matrix(rnorm(3000, 0.2, 1.2), ncol = 3)
colnames(x) <- paste0("x", 1:3)
y <- 0.3*x[, 1] + 0.1*x[, 2] - x[, 3] + rnorm(1000, 0, 0.05)

## Get the model only...
model <- EN_predict(x_train = x[1:800, ], y_train = y[1:800], alpha = 0.6)

## Get predictive performance...
y_pred <- EN_predict(x_train = x[1:800, ], y_train = y[1:800], x_test = x[801:1000, ])
y_test <- y[801:1000]
print(performance(y_test, y_pred, measures = "RSQ"))


[Package DMTL version 0.1.2 Index]