predict {sMTL}    R Documentation

predict: predict on smtl model object

Description

Generates predictions from a fitted sMTL model object.

Usage

predict(model, X, lambda_1 = NA, lambda_2 = NA, lambda_z = NA, stack = FALSE)

Arguments

model

An sMTL model object returned from the smtl() function

X

A matrix of features

lambda_1

An optional numeric scalar specifying which lambda_1 to use for prediction. Only needed if the model object was fit on a path (multiple hyperparameter values).

lambda_2

An optional numeric scalar specifying which lambda_2 to use for prediction. Only needed if the model object was fit on a path (multiple hyperparameter values).

lambda_z

An optional numeric scalar specifying which lambda_z to use for prediction. Only needed if the model object was fit on a path (multiple hyperparameter values).

stack

An optional boolean specifying whether to calculate and apply stacking weights (only for Domain Generalization problems).
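When a model is fit on a path, the lambda arguments identify a single fit among several. The sketch below illustrates that lookup in base R only. It does not call sMTL, and the grid layout and the helper `select_path_index()` are hypothetical names introduced for illustration, not part of the package.

```r
# Hypothetical hyperparameter grid for a model fit on a path. The column
# names mirror the predict() arguments; the layout itself is an assumption.
grid <- expand.grid(lambda_1 = c(0.1, 0.2, 0.3),
                    lambda_2 = 0,
                    lambda_z = c(0.01, 0.05, 0.1))

# Hypothetical helper: return the grid row matching the requested values.
# Arguments left as NA are not used to filter.
select_path_index <- function(grid, lambda_1 = NA, lambda_2 = NA, lambda_z = NA) {
  requested <- c(lambda_1 = lambda_1, lambda_2 = lambda_2, lambda_z = lambda_z)
  keep <- rep(TRUE, nrow(grid))
  for (nm in names(requested)) {
    if (!is.na(requested[[nm]])) {
      # Compare with a tolerance to avoid floating point mismatches
      keep <- keep & abs(grid[[nm]] - requested[[nm]]) < 1e-8
    }
  }
  idx <- which(keep)
  if (length(idx) != 1) {
    stop("Requested hyperparameters must identify exactly one fit on the path")
  }
  idx
}

idx <- select_path_index(grid, lambda_1 = 0.1, lambda_z = 0.01)
grid[idx, ]
```

Requesting only lambda_1 = 0.1 on this grid would match three rows and stop with an error, which is why every hyperparameter that varies along the path needs a value.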

Value

A matrix of task-specific predictions for multi-task/multi-label problems or, for Domain Generalization problems, average and multi-study stacking predictions.
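To make the shape of the return value concrete, here is a minimal base R sketch of how such predictions could be assembled from task-specific linear coefficients. The coefficient layout (intercepts in the first row, one column per task) and the stacking weights `w` are assumptions made for illustration. They are not taken from the package internals.

```r
set.seed(1)
n <- 6; p <- 4; K <- 3   # observations, features, tasks (illustrative sizes)
X <- matrix(rnorm(n * p), n, p)

# Assumed coefficient layout: first row holds intercepts, remaining rows
# hold slopes, one column per task.
beta <- matrix(rnorm((p + 1) * K), p + 1, K)

# Task-specific predictions: an n x K matrix, one column per task
task_preds <- cbind(1, X) %*% beta
colnames(task_preds) <- paste0("task_", seq_len(K))

# Domain Generalization style summaries of the same predictions
avg_pred <- rowMeans(task_preds)            # simple average across studies
w <- c(0.5, 0.3, 0.2)                       # hypothetical stacking weights
stack_pred <- as.vector(task_preds %*% w)   # weighted combination

dim(task_preds)
```

In the multi-task setting each column answers "what does the model for task k predict for these rows?", whereas the averaged and stacked vectors combine the study-specific models into a single prediction for data from an unseen domain.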

Examples


#####################################################################################
##### First Time Loading, Julia is Installed and Julia Path is Known ######
#####################################################################################
# fit model
## Not run: 

if (identical(Sys.getenv("AUTO_JULIA_INSTALL"), "true")) { ## The examples are quite time consuming
## Run initialization and automatic installation if necessary
mod <- smtl(y = y, 
            X = X, 
            study = task, 
            s = 5, 
            commonSupp = FALSE,
            lambda_1 = c(0.1, 0.2, 0.3),
            lambda_z = c(0.01, 0.05, 0.1))

# make predictions
preds <- sMTL::predict.smtl(model = mod,
                            X = X,
                            lambda_1 = 0.1,
                            lambda_z = 0.01)
}
                       
## End(Not run)
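The example above assumes that `y`, `X`, and `task` already exist in the workspace. A minimal sketch for simulating inputs of that general shape follows. The dimensions, the sparse coefficient matrix `B`, and the noise level are illustrative assumptions, and the exact input format that smtl() expects should be checked against its own help page.

```r
set.seed(1)
K <- 3      # number of tasks (illustrative)
n_k <- 50   # observations per task
p <- 20     # number of features

# Task indicator: one integer label per observation
task <- rep(seq_len(K), each = n_k)

# Feature matrix with one row per observation
X <- matrix(rnorm(K * n_k * p), nrow = K * n_k, ncol = p)

# Sparse task-specific coefficients: 5 nonzero features per task,
# chosen to line up with s = 5 in the example call
B <- matrix(0, p, K)
B[1:5, ] <- rnorm(5 * K)

# Outcome: each row uses the coefficients of its own task, plus noise
y <- rowSums(X * t(B[, task])) + rnorm(K * n_k)
```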
                       

[Package sMTL version 0.1.0 Index]