Predictor {distillML} | R Documentation
Predictor class description
Description
A wrapper class for generic ML models (xgboost, random forests, BART, rpart, etc.) that standardizes the predictions of different algorithms so that they are compatible with the interpretability functions.
The required arguments are model, data, and y. The other arguments are optional and depend on the use case. type should be used only when a prediction function (predict.func) is NOT specified.
The model's outputs must be predicted values for regression, or predicted probabilities for classification. For classification problems with more than two categories, the output is the vector of predicted probabilities for the category specified by class. Because this wrapper is intended for ML interpretability, other types of predictions (e.g. predictions that return the predicted factor level) are not allowed.
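As a minimal sketch of this output contract, the snippet below uses only base R (no distillML calls) and a hypothetical binary outcome is_virginica derived from iris. For regression the output is a numeric vector with one value per row. For binary classification, glm's default predict output is on the log-odds (link) scale, so type = "response" is needed to obtain probabilities; a factor of predicted labels, as produced by thresholding, is the kind of output the wrapper does not accept.

```r
# Minimal sketch of the prediction contract, base R only.
# `is_virginica` is a hypothetical binary outcome built for illustration.
df <- iris
df$is_virginica <- as.integer(df$Species == "virginica")
df$Species <- NULL

# Regression: the output is a numeric vector of predicted values.
reg_fit <- lm(Sepal.Length ~ Sepal.Width + Petal.Length + Petal.Width, data = df)
reg_pred <- predict(reg_fit, newdata = df)

# Binary classification: the output must be probabilities, so request the
# response scale instead of the default log-odds (link) scale.
clf_fit <- glm(is_virginica ~ Sepal.Length + Sepal.Width,
               data = df, family = binomial)
clf_prob <- predict(clf_fit, newdata = df, type = "response")

# Thresholding gives a factor of predicted labels: NOT an acceptable output
# for the wrapper, which needs the probabilities themselves.
clf_label <- factor(ifelse(clf_prob > 0.5, "virginica", "other"))
```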
Public fields
data
The data used to train the model. This should be a data frame matching the one the model was trained on, including the label or outcome column.
model
The trained model object to wrap in a Predictor. If the model has no generic predict method, the user must provide a custom prediction function that accepts a data frame.
task
The prediction task the model is trained to perform ('classification' or 'regression').
class
The class for which predictions are returned. Specify this to get predictions (such as probabilities) for an observation belonging to a specific class (e.g. Male or Female). This parameter is required for classification models whose predictions contain more than a single vector (one per class).
prediction.function
An optional prediction function, used when the model has no generic predict method. It should take a data frame and return a vector with one prediction per observation (row) in the data frame.
y
The name of the outcome feature in the 'data' data frame.
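The sketch below shows the kind of custom prediction function meant here, for a model that has no predict method. It assumes a model fitted with base R's lm.fit(), which returns a plain list rather than an "lm" object. The names features and make_predict_func are hypothetical and not part of distillML. This page does not spell out exactly which arguments distillML passes to predict.func, so the one-argument signature used here (a data frame in, a vector out) follows the description above and should be checked against the package before use.

```r
# Sketch: a "model" without a predict() method, plus a custom prediction
# function that takes a data frame and returns one prediction per row.
# `features` and `make_predict_func` are hypothetical names for illustration.
features <- c("Sepal.Width", "Petal.Length", "Petal.Width")

# lm.fit() returns a plain list (not an "lm" object), so there is no
# predict() method to dispatch to.
x_train <- cbind(1, as.matrix(iris[, features]))
raw_fit <- lm.fit(x = x_train, y = iris$Sepal.Length)

make_predict_func <- function(fit, features) {
  function(newdata) {
    # Rebuild the design matrix used for fitting: intercept + features.
    x_new <- cbind(1, as.matrix(newdata[, features, drop = FALSE]))
    drop(x_new %*% fit$coefficients)
  }
}

predict_sepal_length <- make_predict_func(raw_fit, features)
preds <- predict_sepal_length(iris)
```

Wrapping the fitted coefficients in a closure keeps the feature list and the model together, so the resulting function needs nothing but a data frame.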
Methods
Public methods
Method new()
Usage
Predictor$new( model = NULL, data = NULL, predict.func = NULL, y = NULL, task = NULL, class = NULL, type = NULL )
Arguments
model
The trained model object to wrap in a Predictor. If the model has no generic predict method, the user must provide a custom prediction function that accepts a data frame.
data
The data used to train the model. This should be a data frame matching the one the model was trained on, including the label or outcome column.
predict.func
An optional prediction function, used when the model has no generic predict method. It should take a data frame and return a vector with one prediction per observation (row) in the data frame.
y
The name of the outcome feature in the 'data' data frame.
task
The prediction task the model is trained to perform ('classification' or 'regression').
class
The class for which predictions are returned. Specify this to get predictions (such as probabilities) for an observation belonging to a specific class (e.g. Male or Female). This parameter is required for classification models whose predictions contain more than a single vector (one per class).
type
The type of predictions to return (e.g. 'response' for predicted probabilities in classification). This argument should only be used if no predict.func is specified.
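To illustrate what selecting a class means for a multi-class model, the sketch below fits a classification tree with rpart (a recommended package that ships with standard R installations). predict(..., type = "prob") returns a matrix with one probability column per class, and keeping the column named by the requested class yields the single probability vector the wrapper works with. The helper predict_class_prob is hypothetical and only mirrors the idea; it is not how distillML implements the class argument internally.

```r
# Sketch: what choosing a `class` means for a multi-class model.
# `predict_class_prob` is a hypothetical helper, not part of distillML.
library(rpart)

tree <- rpart(Species ~ ., data = iris)

# type = "prob" returns a matrix: one row per observation, one column per class.
prob_matrix <- predict(tree, newdata = iris, type = "prob")

# Keep only the column for the requested class, giving a single vector of
# probabilities (one per observation).
predict_class_prob <- function(model, newdata, class) {
  probs <- predict(model, newdata = newdata, type = "prob")
  if (!class %in% colnames(probs)) {
    stop("Unknown class: ", class)
  }
  unname(probs[, class])
}

versicolor_prob <- predict_class_prob(tree, iris, "versicolor")
```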
Returns
A 'Predictor' object.
Method clone()
The objects of this class are cloneable with this method.
Usage
Predictor$clone(deep = FALSE)
Arguments
deep
Whether to make a deep clone.
Note
This class wraps a machine learning model to provide a standardized prediction method across different models. A prediction method must be constructed, with type as an optional argument.
Examples
library(distillML)
library(Rforestry)
set.seed(491)
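data <- iris
# Hold out 20% of the rows as a test set
test_ind <- sample(1:nrow(data), nrow(data) %/% 5)
train_reg <- data[-test_ind, ]
test_reg <- data[test_ind, ]
# Train the forest on the training split only, predicting Sepal.Length (column 1)
forest <- forestry(x = train_reg[, -1],
                   y = train_reg[, 1])
forest_predictor <- Predictor$new(model = forest, data = train_reg,
                                  y = "Sepal.Length", task = "regression")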