plot_fit_trajectory {WVPlots} | R Documentation |
Plot the trajectory of a model fit.
Description
Plot a history of model fit performance over the a trajectory of times.
Usage
plot_fit_trajectory(
d,
column_description,
title,
...,
epoch_name = "epoch",
needs_flip = c(),
pick_metric = NULL,
discount_rate = NULL,
draw_ribbon = FALSE,
draw_segments = FALSE,
val_color = "#d95f02",
train_color = "#1b9e77",
pick_color = "#e6ab02"
)
Arguments
d |
data frame to get values from. |
column_description |
description of column measures (data.frame with columns measure, validation, and training). |
title |
character title for plot. |
... |
force later arguments to be bound by name |
epoch_name |
name for epoch or trajectory column. |
needs_flip |
character array of measures that need to be flipped. |
pick_metric |
character metric to maximize. |
discount_rate |
numeric what fraction of over-fit to subtract from validation performance. |
draw_ribbon |
present the difference in training and validation performance as a ribbon rather than two curves? (default FALSE) |
draw_segments |
logical if TRUE draw over-fit/under-fit segments. |
val_color |
color for validation performance curve |
train_color |
color for training performance curve |
pick_color |
color for indicating optimal stopping point |
Details
This visualization can be applied to any staged machine learning algorithm. For example one could plot the performance of a gradient boosting machine as a function of the number of trees added. The fit history data should be in the form given in the example below.
The example below gives a fit plot for a history report from Keras R package. Please see https://win-vector.com/2017/12/23/plotting-deep-learning-model-performance-trajectories/ for some examples and details.
Value
ggplot2 plot
See Also
Examples
if (requireNamespace('data.table', quietly = TRUE)) {
# don't multi-thread during CRAN checks
data.table::setDTthreads(1)
}
d <- data.frame(
epoch = c(1, 2, 3, 4, 5),
val_loss = c(0.3769818, 0.2996994, 0.2963943, 0.2779052, 0.2842501),
val_acc = c(0.8722000, 0.8895000, 0.8822000, 0.8899000, 0.8861000),
loss = c(0.5067290, 0.3002033, 0.2165675, 0.1738829, 0.1410933),
acc = c(0.7852000, 0.9040000, 0.9303333, 0.9428000, 0.9545333) )
cT <- data.frame(
measure = c("minus binary cross entropy", "accuracy"),
training = c("loss", "acc"),
validation = c("val_loss", "val_acc"),
stringsAsFactors = FALSE)
plt <- plot_fit_trajectory(
d,
column_description = cT,
needs_flip = "minus binary cross entropy",
title = "model performance by epoch, dataset, and measure",
epoch_name = "epoch",
pick_metric = "minus binary cross entropy",
discount_rate = 0.1)
print(plt)