Predictor {iml}    R Documentation
Predictor object
Description
A Predictor object holds any machine learning model (mlr, caret, randomForest, ...) and the data to be used for analyzing the model. The interpretation methods in the iml package need the machine learning model to be wrapped in a Predictor object.
Details
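A Predictor object is a container for the prediction model and the data. It gives the interpretation methods one consistent way to obtain predictions, whichever package the model comes from, so that the machine learning model can be analyzed robustly.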
Note: In case of classification, the model should return one column per class with the class probability.
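When a model's own predict() does not return one probability column per class, the predict.function argument of Predictor$new() can reshape its output. The sketch below assumes a binary glm on mtcars (an illustrative choice, not from iml); predict() for a binomial glm returns a single value per row, on the link scale unless type = "response" is given. The helper name glm_prob_predict is hypothetical, and the Predictor step runs only if iml is installed.

```r
# Sketch: a custom predict.function that turns a binary glm's single
# probability vector into one probability column per class.
# glm_prob_predict() is a hypothetical helper, not part of iml.
cars_df <- mtcars
cars_df$am <- factor(cars_df$am, levels = c(0, 1),
                     labels = c("automatic", "manual"))
fit <- glm(am ~ wt + hp, data = cars_df, family = binomial())

glm_prob_predict <- function(model, newdata) {
  # Probability of the second factor level ("manual")
  p <- predict(model, newdata = newdata, type = "response")
  data.frame(automatic = 1 - p, manual = p)
}

probs <- glm_prob_predict(fit, cars_df[1:5, ])
print(probs)

# Wrap the model only if iml is available
if (requireNamespace("iml", quietly = TRUE)) {
  pred <- iml::Predictor$new(fit, data = cars_df, y = "am",
                             predict.function = glm_prob_predict)
  print(pred$predict(cars_df[1:5, ]))
}
```

Returning a data.frame with named columns keeps the class names visible, so they can later be selected with the class argument.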
Public fields
data
data.frame
Data object with the data for the model interpretation.
model
(any)
The machine learning model.
batch.size
numeric(1)
The number of rows fed into the model for prediction at once.
class
character(1)
The class column to be returned.
prediction.colnames
character
The column names of the predictions.
prediction.function
function
The function to predict newdata.
task
character(1)
The inferred prediction task: "classification" or "regression".
Methods
Public methods
Method new()
Create a Predictor object
Usage
Predictor$new( model = NULL, data = NULL, predict.function = NULL, y = NULL, class = NULL, type = NULL, batch.size = 1000 )
Arguments
model
any
The machine learning model. Models from mlr and caret are recommended. Other machine learning models with an S3 predict function work as well, but less robustly (e.g. randomForest).
data
data.frame
The data to be used for analyzing the prediction model. Allowed column classes are: numeric, factor, integer, ordered and character. For some models the data can be extracted automatically. Predictor$new() throws an error when it can't extract the data automatically.
predict.function
function
The function to predict newdata. Only needed if model is not a model from the mlr or caret package. The first argument of predict.function has to be the model, the second newdata: function(model, newdata).
y
character(1) | numeric | factor
The target vector or (preferably) the name of the target column in the data argument. Predictor tries to infer the target automatically from the model.
class
character(1)
The class column to be returned. You should use the column name of the predicted class, e.g. class = "setosa".
type
character(1)
This argument is passed to the prediction function of the model. For regression models you usually don't have to provide the type argument. The classic use case is to say type = "prob" for classification models. Consult the documentation of the machine learning package you use to find which type options you have. If both predict.function and type are used, then type is passed as an argument to predict.function.
batch.size
numeric(1)
The maximum number of rows fed into the model for prediction at once. Currently only respected for FeatureImp, Partial and Interaction.
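The idea behind batch.size can be illustrated in base R: split newdata into chunks of at most batch.size rows, predict each chunk, and stack the results. This is a sketch under the assumption that chunk predictions can simply be combined row-wise; predict_in_batches() is a hypothetical helper and not iml's internal implementation.

```r
# Sketch: predict at most batch_size rows per call and stack the pieces.
# predict_in_batches() is hypothetical, for illustration only.
predict_in_batches <- function(predict_fun, newdata, batch_size = 1000) {
  stopifnot(batch_size >= 1)
  n <- nrow(newdata)
  if (n == 0) {
    return(predict_fun(newdata))
  }
  starts <- seq(1, n, by = batch_size)
  pieces <- lapply(starts, function(start) {
    rows <- start:min(start + batch_size - 1, n)
    predict_fun(newdata[rows, , drop = FALSE])
  })
  # Every piece is a data.frame with the same columns, so rbind() stacks them
  do.call(rbind, pieces)
}

# Usage with a base R regression model; n_calls counts prediction calls
fit <- lm(mpg ~ wt + hp, data = mtcars)
n_calls <- 0
lm_predict <- function(newdata) {
  n_calls <<- n_calls + 1
  data.frame(prediction = predict(fit, newdata = newdata))
}

batched <- predict_in_batches(lm_predict, mtcars, batch_size = 10)
batched_calls <- n_calls  # 4 calls: rows 1-10, 11-20, 21-30, 31-32
all_at_once <- lm_predict(mtcars)
```

Batching caps the number of rows passed to the model per call (and thus peak memory) without changing the predictions themselves.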
Method predict()
Predict new data with the machine learning model.
Usage
Predictor$predict(newdata)
Arguments
newdata
data.frame
Data to predict on.
Method print()
Print the Predictor object.
Usage
Predictor$print()
Method clone()
The objects of this class are cloneable with this method.
Usage
Predictor$clone(deep = FALSE)
Arguments
deep
Whether to make a deep clone.
Examples
library("iml")
library("mlr")
task <- makeClassifTask(data = iris, target = "Species")
learner <- makeLearner("classif.rpart", minsplit = 7, predict.type = "prob")
mod.mlr <- train(learner, task)
mod <- Predictor$new(mod.mlr, data = iris)
mod$predict(iris[1:5, ])
mod <- Predictor$new(mod.mlr, data = iris, class = "setosa")
mod$predict(iris[1:5, ])
library("randomForest")
rf <- randomForest(Species ~ ., data = iris, ntree = 20)
mod <- Predictor$new(rf, data = iris, type = "prob")
mod$predict(iris[50:55, ])
# Feature importance needs the target vector, which needs to be supplied:
mod <- Predictor$new(rf, data = iris, y = "Species", type = "prob")
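With the target supplied, the Predictor can be passed to FeatureImp, one of the methods that respects batch.size. The sketch below mirrors the last Predictor from the examples and assumes the iml and randomForest packages are installed; "ce" (classification error) is one of the loss names FeatureImp accepts, and the seed is an arbitrary choice.

```r
# Sketch: permutation feature importance once the target vector is known.
# Runs only if both iml and randomForest are installed.
if (requireNamespace("iml", quietly = TRUE) &&
    requireNamespace("randomForest", quietly = TRUE)) {
  set.seed(42)
  rf <- randomForest::randomForest(Species ~ ., data = iris, ntree = 20)
  mod <- iml::Predictor$new(rf, data = iris, y = "Species", type = "prob")
  # Importance as the increase in classification error after permuting a feature
  imp <- iml::FeatureImp$new(mod, loss = "ce")
  print(imp$results)
}
```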