Predictor {distillML}    R Documentation

Predictor class description

Description

A wrapper class for generic ML models (xgboost, random forests, BART, rpart, etc.) that standardizes the predictions of different algorithms so that they are compatible with the interpretability functions.

The required arguments are model, data, and y. The other arguments are optional and depend on the use case. The 'type' argument should be used only when no prediction function (predict.func) is specified.

The model's predictions must be numeric values for regression, or probabilities for classification. For classification problems with more than two categories, the output is the vector of probabilities for the category specified by 'class'. Because this wrapper is intended for ML interpretability, other types of predictions (e.g., predictions returned as factor labels) are not allowed.
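As a sketch of what selecting a 'class' means for a multiclass model, the base R snippet below reduces a matrix of class probabilities (one column per class) to the single probability vector for one class. The helper select_class_probs() and the toy probabilities are hypothetical and only illustrate the idea; this is not distillML's internal implementation.

```r
# Hypothetical helper (not part of distillML): reduce a matrix of class
# probabilities (one column per class) to the single vector for `class`.
select_class_probs <- function(prob_matrix, class) {
  if (!class %in% colnames(prob_matrix)) {
    stop("class '", class, "' not found among: ",
         paste(colnames(prob_matrix), collapse = ", "))
  }
  as.numeric(prob_matrix[, class])
}

# Toy probabilities for three observations and the three iris species
probs <- matrix(c(0.7, 0.2, 0.1,
                  0.1, 0.8, 0.1,
                  0.3, 0.3, 0.4),
                nrow = 3, byrow = TRUE,
                dimnames = list(NULL, c("setosa", "versicolor", "virginica")))

versicolor_probs <- select_class_probs(probs, "versicolor")
print(versicolor_probs)  # 0.2 0.8 0.3
```

The result is one probability per observation, which is the shape of output the interpretability functions expect for classification.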

Public fields

data

The data used to train the model. This should be a data frame matching the one the model was trained on, including the label or outcome column.

model

The trained model object to wrap in a Predictor object. If this model doesn't have a generic predict method, the user must provide a custom prediction function that accepts a data frame.

task

The prediction task the model is trained to perform ('classification' or 'regression').

class

The class for which predictions are returned. Specify this to get predictions (such as probabilities) for an observation belonging to a specific class (e.g., Male or Female). This parameter is required for classification models that return more than a single vector of predictions.

prediction.function

An optional prediction function, used when the model doesn't have a generic predict method. It should take a data frame and return a vector containing one prediction per observation in the data frame.

y

The name of the outcome feature in the 'data' data frame.
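When a model has no generic predict method, the custom prediction function described above is simply an R function that maps a data frame to one prediction per row. The sketch below assumes a hypothetical "model" stored as a plain list of least-squares coefficients; raw_model and my_predict() are illustrative names, not part of distillML. Whether predict.func must have exactly this one-argument signature, or also receives the model object, is an assumption worth checking against the package source before use.

```r
# A "model" with no predict method: a plain list holding least-squares
# coefficients fitted by hand (hypothetical object, for illustration only)
features <- c("Sepal.Width", "Petal.Length")
X <- cbind(1, as.matrix(iris[, features]))        # intercept + features
coefs <- qr.solve(X, iris$Sepal.Length)           # least-squares fit via QR
raw_model <- list(coefficients = coefs, features = features)

# Custom prediction function: takes a data frame and returns one numeric
# prediction per row (names dropped so the output is a plain vector)
my_predict <- function(newdata) {
  design <- cbind(1, as.matrix(newdata[, raw_model$features]))
  as.numeric(design %*% raw_model$coefficients)
}

preds <- my_predict(iris[1:5, ])
print(preds)
```

Keeping the function's output a bare numeric vector, rather than a matrix or a named list, matches the "one prediction per observation" contract described for prediction.function.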

Methods

Public methods


Method new()

Usage
Predictor$new(
  model = NULL,
  data = NULL,
  predict.func = NULL,
  y = NULL,
  task = NULL,
  class = NULL,
  type = NULL
)
Arguments
model

The trained model object to wrap in a Predictor object. If this model doesn't have a generic predict method, the user must provide a custom prediction function that accepts a data frame.

data

The data used to train the model. This should be a data frame matching the one the model was trained on, including the label or outcome column.

predict.func

An optional prediction function, needed if the model doesn't have a generic predict method. It should take a data frame and return a vector containing one prediction per observation in the data frame.

y

The name of the outcome feature in the 'data' data frame.

task

The prediction task the model is trained to perform ('classification' or 'regression').

class

The class for which predictions are returned. Specify this to get predictions (such as probabilities) for an observation belonging to a specific class (e.g., Male or Female). This parameter is required for classification models that return more than a single vector of predictions.

type

The type of prediction requested from the model's predict method (e.g., 'response' for predicted probabilities in classification). This argument should only be used if no predict.func is specified.
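To illustrate why 'type' matters, the base R sketch below fits a logistic regression with glm() and compares predict(type = "link"), which returns log-odds, with predict(type = "response"), which returns probabilities. The binary is_versicolor outcome is constructed from iris purely for illustration.

```r
# Binary outcome derived from iris: is the flower versicolor? (illustrative)
df <- transform(iris, is_versicolor = as.integer(Species == "versicolor"))
fit <- glm(is_versicolor ~ Sepal.Length + Sepal.Width,
           data = df, family = binomial)

# "link" gives log-odds on the real line; "response" gives probabilities
link_preds <- predict(fit, newdata = df, type = "link")
prob_preds <- predict(fit, newdata = df, type = "response")

# The two are related by the logistic function
print(range(prob_preds))
print(all.equal(unname(prob_preds), unname(plogis(link_preds))))  # TRUE
```

For a model like this, only the "response" output is on the probability scale that the description above requires for classification.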

Returns

A 'Predictor' object.


Method clone()

The objects of this class are cloneable with this method.

Usage
Predictor$clone(deep = FALSE)
Arguments
deep

Whether to make a deep clone.
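Predictor$new() and $clone(deep) appear to follow R6 class conventions. The sketch below, assuming the R6 package is installed, shows the general R6 difference between a shallow and a deep clone using two hypothetical classes, Counter and Holder: a shallow clone shares fields that are themselves R6 objects, while a deep clone copies them. How this affects a particular Predictor depends on which of its fields, if any, are R6 objects.

```r
library(R6)

# Hypothetical classes for illustration only
Counter <- R6Class("Counter",
  public = list(
    count = 0,
    add = function() {
      self$count <- self$count + 1
      invisible(self)
    }
  )
)

Holder <- R6Class("Holder",
  public = list(
    counter = NULL,
    # Create the Counter per instance, not as a shared class-level default
    initialize = function() self$counter <- Counter$new()
  )
)

original <- Holder$new()
shallow  <- original$clone()             # deep = FALSE: shares the Counter
deep     <- original$clone(deep = TRUE)  # deep = TRUE: copies the Counter

original$counter$add()
print(shallow$counter$count)  # 1, same Counter object as original
print(deep$counter$count)     # 0, independent copy
```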

Note

This class wraps a machine learning model to provide a standardized prediction method across different models. The prediction method is constructed either from the model's generic predict method, with the optional 'type' argument, or from the supplied predict.func.

Examples


library(distillML)
library(Rforestry)
set.seed(491)
data <- iris

test_ind <- sample(1:nrow(data), nrow(data)%/%5)
train_reg <- data[-test_ind,]
test_reg <- data[test_ind,]


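# Train the forest on the training split only, predicting Sepal.Length
# (column 1) from the remaining columns
forest <- forestry(x = train_reg[, -1],
                   y = train_reg[, 1])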

forest_predictor <- Predictor$new(model = forest, data=train_reg,
                                  y="Sepal.Length", task = "regression")


[Package distillML version 0.1.0.13 Index]