plot_ensemble {ensModelVis}    R Documentation

Description

Draws a plot of the predictions made by each model in an ensemble. For classification the plot is a heatmap; for regression it is a scatterplot.

Usage

plot_ensemble(
  truth,
  tibble_pred,
  incorrect = FALSE,
  tibble_prob = NULL,
  order = NULL,
  facet = FALSE
)

Arguments

truth

The y variable. In regression this is a numeric vector; in classification it is a factor.

tibble_pred

A data.frame of predictions with one column per candidate model and one row per observation, in the same order as truth.

incorrect

If TRUE, observations that were correctly classified by all models are removed, except for a single observation per class, so the plot concentrates on misclassifications. Classification only.
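The filtering rule behind incorrect = TRUE can be sketched in base R as follows. This is only an illustration of the documented behaviour, not the package's internal code: the helper keep_incorrect() and the toy data are hypothetical, and the choice of keeping the first correctly classified observation of each class is an assumption.

```r
# Hypothetical helper illustrating the documented `incorrect = TRUE` rule.
# Not part of ensModelVis; returns the row indices that would be kept.
keep_incorrect <- function(truth, tibble_pred) {
  # Character matrix of predictions; each row is compared with truth[i]
  pred_mat <- as.matrix(tibble_pred)
  all_correct <- apply(pred_mat == as.character(truth), 1, all)
  # Keep every row that at least one model got wrong ...
  keep <- !all_correct
  # ... plus (assumption) the first all-correct row of each class
  first_per_class <- !duplicated(as.character(truth)[all_correct])
  keep[which(all_correct)[first_per_class]] <- TRUE
  which(keep)
}

# Toy data: two models, five observations
truth <- factor(c("a", "a", "b", "b", "b"))
preds <- data.frame(M1 = factor(c("a", "a", "b", "a", "b")),
                    M2 = factor(c("a", "b", "b", "b", "b")))
keep_incorrect(truth, preds)  # rows 1 to 4; row 5 is a second all-correct "b"
```

Rows 2 and 4 are misclassified by at least one model and always survive; rows 1 and 3 are retained as the single correctly classified representatives of classes "a" and "b".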

tibble_prob

If not NULL, a data.frame with the same column names as tibble_pred, giving for each observation the probability each model assigned to its predicted class. These probabilities set the transparency of the plotted predictions. Classification only.

order

By default, models are ordered by accuracy (classification) or RMSE (regression). Any other ordering, for example by AUC, can be supplied as a data.frame with the same column names as tibble_pred.
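A custom ordering could be built along the lines of the sketch below, which computes per-model accuracy and per-model RMSE in base R. It assumes that order is a one-row data.frame holding one score per model, with the same column names as tibble_pred; that shape, and whether higher or lower scores are placed first, are assumptions not confirmed by this page. The toy data are hypothetical.

```r
# Classification: accuracy of each model (share of correct predictions)
truth <- factor(c("a", "a", "b", "b", "b"))
preds <- data.frame(M1 = factor(c("a", "a", "b", "a", "b")),
                    M2 = factor(c("a", "b", "b", "b", "a")))
acc <- sapply(preds, function(p) mean(as.character(p) == as.character(truth)))
# One row, one column per model, same names as the prediction data.frame
order_acc <- as.data.frame(t(acc))

# Regression: root mean squared error of each model
y <- c(1, 2, 3)
reg_preds <- data.frame(A = c(1, 2, 4), B = c(2, 3, 4))
rmse <- sapply(reg_preds, function(p) sqrt(mean((p - y)^2)))
order_rmse <- as.data.frame(t(rmse))

order_acc   # M1 = 0.8, M2 = 0.6
order_rmse  # A = sqrt(1/3), B = 1
```

Under these assumptions a scores table such as order_acc would then be passed as plot_ensemble(truth, preds, order = order_acc); scores from any other metric (AUC, log-loss, ...) could be arranged the same way.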

facet

If TRUE, the plot is faceted by model, one panel per model. Regression only.

Value

A ggplot object.

Examples

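library(ensModelVis)
data(iris)

# Linear discriminant analysis (MASS)
if (require("MASS")) {
  lda.model <- lda(Species ~ ., data = iris)
  lda.pred <- predict(lda.model)  # list with $class and $posterior
}

# Random forest (ranger); $predictions holds class labels
if (require("ranger")) {
  ranger.model <- ranger(Species ~ ., data = iris)
  ranger.pred <- predict(ranger.model, iris)
}

# Heatmap of predicted classes, one column per model
plot_ensemble(iris$Species,
              data.frame(LDA = lda.pred$class,
                         RF = ranger.pred$predictions))

# Keep observations misclassified by at least one model,
# plus one correctly classified observation per class
plot_ensemble(iris$Species,
              data.frame(LDA = lda.pred$class,
                         RF = ranger.pred$predictions),
              incorrect = TRUE)

# Probability forest: $predictions is a matrix of class probabilities
if (require("ranger")) {
  ranger.model <- ranger(Species ~ ., data = iris, probability = TRUE)
  ranger.prob <- predict(ranger.model, iris)
}

# Transparency reflects the probability of each model's predicted class
plot_ensemble(iris$Species,
              data.frame(LDA = lda.pred$class,
                         RF = ranger.pred$predictions),
              tibble_prob = data.frame(LDA = apply(lda.pred$posterior, 1, max),
                                       RF = apply(ranger.prob$predictions, 1, max)))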
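The examples above cover classification only. A regression call might look like the sketch below; the two linear models are an illustrative choice, and the exact layout of the resulting scatterplot is not specified on this page. Only the documented arguments truth, tibble_pred and facet are used.

```r
library(ensModelVis)
data(iris)

# Two linear models for a numeric response (illustrative choice of models)
lm.full  <- lm(Sepal.Length ~ ., data = iris)
lm.petal <- lm(Sepal.Length ~ Petal.Length, data = iris)

# One column of fitted values per model, rows aligned with the truth
reg.pred <- data.frame(Full = fitted(lm.full),
                       Petal = fitted(lm.petal))

# Numeric truth gives a scatterplot; facet = TRUE requests one panel per model
p <- plot_ensemble(iris$Sepal.Length, reg.pred, facet = TRUE)
p
```

Because the default regression ordering is by RMSE, no order argument is needed here.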

[Package ensModelVis version 0.1.0 Index]