plot_confusion_matrix {cvms}    R Documentation

Plot a confusion matrix

Description

[Experimental]

Creates a ggplot2 object representing a confusion matrix with counts, overall percentages, row percentages and column percentages. An extra row and column with sum tiles and the total count can be added.

The confusion matrix can be created with evaluate(). See `Examples`.

While this function is intended to be very flexible (hence the large number of arguments), the defaults should work in most cases for most users. See the Examples.

Usage

plot_confusion_matrix(
  conf_matrix,
  target_col = "Target",
  prediction_col = "Prediction",
  counts_col = "N",
  class_order = NULL,
  add_sums = FALSE,
  add_counts = TRUE,
  add_normalized = TRUE,
  add_row_percentages = TRUE,
  add_col_percentages = TRUE,
  diag_percentages_only = FALSE,
  rm_zero_percentages = TRUE,
  rm_zero_text = TRUE,
  add_zero_shading = TRUE,
  add_arrows = TRUE,
  counts_on_top = FALSE,
  palette = "Blues",
  intensity_by = "counts",
  theme_fn = ggplot2::theme_minimal,
  place_x_axis_above = TRUE,
  rotate_y_text = TRUE,
  digits = 1,
  font_counts = font(),
  font_normalized = font(),
  font_row_percentages = font(),
  font_col_percentages = font(),
  arrow_size = 0.048,
  arrow_nudge_from_text = 0.065,
  tile_border_color = NA,
  tile_border_size = 0.1,
  tile_border_linetype = "solid",
  sums_settings = sum_tile_settings(),
  darkness = 0.8
)

Arguments

conf_matrix

Confusion matrix tibble with each combination of targets and predictions along with their counts.

E.g. for a binary classification:

Target    Prediction  N
class_1   class_1     5
class_1   class_2     9
class_2   class_1     3
class_2   class_2     2

Such a tibble is created by the various evaluation functions in cvms, like evaluate().

Note: If you supply the results from evaluate() or confusion_matrix() directly, the confusion matrix tibble is extracted automatically, if possible.
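A minimal sketch of both routes, assuming cvms and tibble (a cvms dependency) are installed. The binary example above is built by hand as a tibble, and confusion_matrix() counts a few made-up targets and predictions. Both results are passed to plot_confusion_matrix().

```r
library(cvms)
library(tibble)

# Hand-made confusion matrix tibble matching the binary example above:
# one row per Target/Prediction combination with its count
conf_mat <- tibble(
  Target = c("class_1", "class_1", "class_2", "class_2"),
  Prediction = c("class_1", "class_2", "class_1", "class_2"),
  N = c(5, 9, 3, 2)
)
p1 <- plot_confusion_matrix(conf_mat)

# Let confusion_matrix() count the combinations (made-up data);
# its output is passed directly and the tibble is extracted automatically
cm_res <- confusion_matrix(
  targets = c("class_1", "class_2", "class_1", "class_2"),
  predictions = c("class_1", "class_1", "class_2", "class_2")
)
p2 <- plot_confusion_matrix(cm_res)
```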

target_col

Name of column with target levels.

prediction_col

Name of column with prediction levels.

counts_col

Name of column with a count for each combination of the target and prediction levels.
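If your columns are named differently, point the three arguments at them. In this sketch the column names `truth`, `guess` and `n_obs` and the counts are made up for illustration.

```r
library(cvms)
library(tibble)

# Confusion matrix with non-default (hypothetical) column names
my_cm <- tibble(
  truth = c("a", "a", "b", "b"),
  guess = c("a", "b", "a", "b"),
  n_obs = c(12, 3, 4, 9)
)

# Tell plot_confusion_matrix() which column holds what
p <- plot_confusion_matrix(
  my_cm,
  target_col = "truth",
  prediction_col = "guess",
  counts_col = "n_obs"
)
```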

class_order

Names of the classes in `conf_matrix` in the desired order. When NULL, the classes are ordered alphabetically.
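A sketch of a custom class order, using made-up counts for three hypothetical classes (`cat`, `dog`, `fox`) that would otherwise be shown alphabetically.

```r
library(cvms)
library(tibble)

# Made-up three-class counts covering every Target/Prediction combination
cm3 <- tibble(
  Target = rep(c("cat", "dog", "fox"), each = 3),
  Prediction = rep(c("cat", "dog", "fox"), times = 3),
  N = c(8, 1, 1, 2, 7, 1, 0, 2, 6)
)

# Show the classes as fox, dog, cat instead of alphabetically
p <- plot_confusion_matrix(cm3, class_order = c("fox", "dog", "cat"))
```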

add_sums

Add tiles with the row/column sums. Also adds a total count tile. (Logical)

The appearance of these tiles can be specified in `sums_settings`.

Note: Adding the sum tiles with a palette requires the ggnewscale package.

add_counts

Add the counts to the middle of the tiles. (Logical)

add_normalized

Normalize the counts to percentages and add them to the middle of the tiles. (Logical)

add_row_percentages

Add the row percentages, i.e. how big a part of its row the tile makes up. (Logical)

By default, the row percentage is placed to the right of the tile, rotated 90 degrees.

add_col_percentages

Add the column percentages, i.e. how big a part of its column the tile makes up. (Logical)

By default, the column percentage is placed at the bottom of the tile.
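To make the three percentage types concrete, the base-R sketch below computes them for the binary example under `conf_matrix`. It is an illustration, not the package's internal code, and it assumes (consistent with the Value section) that columns correspond to targets and rows to predictions.

```r
# Binary example from the conf_matrix argument
cm <- data.frame(
  Target = c("class_1", "class_1", "class_2", "class_2"),
  Prediction = c("class_1", "class_2", "class_1", "class_2"),
  N = c(5, 9, 3, 2)
)

# Overall (normalized) percentage: share of all observations
cm$Normalized <- 100 * cm$N / sum(cm$N)

# Column percentage: share of the tiles with the same Target
cm$Col_Percentage <- 100 * cm$N / ave(cm$N, cm$Target, FUN = sum)

# Row percentage: share of the tiles with the same Prediction
cm$Row_Percentage <- 100 * cm$N / ave(cm$N, cm$Prediction, FUN = sum)

# Round to one digit, matching the default digits = 1
cm_rounded <- cbind(cm[1:3], round(cm[4:6], 1))
print(cm_rounded)
```

For the class_1/class_1 tile this gives 26.3% of all observations, 35.7% of its column (5 of the 14 class_1 targets) and 62.5% of its row (5 of the 8 class_1 predictions).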

diag_percentages_only

Whether to show the row and column percentages only in the diagonal tiles. (Logical)

rm_zero_percentages

Whether to remove row and column percentages when the count is 0. (Logical)

rm_zero_text

Whether to remove counts and normalized percentages when the count is 0. (Logical)

add_zero_shading

Add an image of skewed lines to tiles with a count of zero. (Logical)

Note: Adding the zero-shading requires the rsvg and ggimage packages.

add_arrows

Add arrows to the row and column percentages. (Logical)

Note: Adding the arrows requires the rsvg and ggimage packages.

counts_on_top

Switch the counts and normalized counts, such that the counts are on top. (Logical)
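A sketch combining these options on made-up three-class counts that include zero tiles. Since the zero-shading and the arrows need rsvg and ggimage, the example only enables them when both packages can be loaded; `has_images` is a hypothetical helper variable.

```r
library(cvms)
library(tibble)

# Made-up three-class counts with some zero tiles
cm3 <- tibble(
  Target = rep(c("A", "B", "C"), each = 3),
  Prediction = rep(c("A", "B", "C"), times = 3),
  N = c(10, 0, 2, 1, 8, 0, 0, 3, 9)
)

# Only use the image-based layers when their packages are available
has_images <- requireNamespace("rsvg", quietly = TRUE) &&
  requireNamespace("ggimage", quietly = TRUE)

p <- plot_confusion_matrix(
  cm3,
  diag_percentages_only = TRUE,  # row/column percentages on the diagonal only
  add_zero_shading = has_images, # shade zero-count tiles when possible
  add_arrows = has_images,       # arrows next to the percentages when possible
  counts_on_top = TRUE           # counts above the normalized counts
)
```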

palette

Color scheme. Passed directly to `palette` in ggplot2::scale_fill_distiller.

Try these palettes: "Greens", "Oranges", "Greys", "Purples", "Reds", as well as the default "Blues".

intensity_by

The measure that should control the color intensity of the tiles. Either `counts` or `normalized`. For the latter, the color limits become 0-100, so the intensities can be compared more easily across plots.
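For example, a green palette with intensities driven by the overall percentages, using the binary example from `conf_matrix` (assuming cvms and tibble are installed):

```r
library(cvms)
library(tibble)

cm <- tibble(
  Target = c("class_1", "class_1", "class_2", "class_2"),
  Prediction = c("class_1", "class_2", "class_1", "class_2"),
  N = c(5, 9, 3, 2)
)

# Color by the normalized counts (limits 0-100) with a green palette,
# so the shading is comparable to other plots made the same way
p <- plot_confusion_matrix(
  cm,
  palette = "Greens",
  intensity_by = "normalized"
)
```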

theme_fn

The ggplot2 theme function to apply.

place_x_axis_above

Move the x-axis text to the top and reverse the levels such that the "correct" diagonal goes from top left to bottom right. (Logical)

rotate_y_text

Whether to rotate the y-axis text to be vertical instead of horizontal. (Logical)
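A sketch of turning both axis options off, so the x-axis text stays at the bottom (without the level reversal described above) and the y-axis text stays horizontal. The counts are the binary example from `conf_matrix`.

```r
library(cvms)
library(tibble)

cm <- tibble(
  Target = c("class_1", "class_1", "class_2", "class_2"),
  Prediction = c("class_1", "class_2", "class_1", "class_2"),
  N = c(5, 9, 3, 2)
)

# Keep the x-axis at the bottom and the y-axis text horizontal
p <- plot_confusion_matrix(
  cm,
  place_x_axis_above = FALSE,
  rotate_y_text = FALSE
)
```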

digits

Number of digits to round to (percentages only). Set to a negative number for no rounding.

Can be set for each font individually via the font_* arguments.

font_counts

List of font settings for the counts. Can be provided with font().

font_normalized

List of font settings for the normalized counts. Can be provided with font().

font_row_percentages

List of font settings for the row percentages. Can be provided with font().

font_col_percentages

List of font settings for the column percentages. Can be provided with font().
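A sketch of setting a global rounding with `digits` and overriding it for individual text elements. It assumes font() accepts `size` and `digits` arguments; see the font() documentation for the full list of settings.

```r
library(cvms)
library(tibble)

cm <- tibble(
  Target = c("class_1", "class_1", "class_2", "class_2"),
  Prediction = c("class_1", "class_2", "class_1", "class_2"),
  N = c(5, 9, 3, 2)
)

p <- plot_confusion_matrix(
  cm,
  digits = 2,                              # default rounding for percentages
  font_counts = font(size = 6),            # larger count text
  font_row_percentages = font(digits = 0), # whole-number row percentages
  font_col_percentages = font(digits = 0)  # whole-number column percentages
)
```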

arrow_size

Size of arrow icons. (Numeric)

The value is divided by sqrt(nrow(conf_matrix)) before being passed on to ggimage::geom_icon().
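As a worked example of the scaling: a complete confusion matrix tibble has one row per target/prediction combination, so 2 classes give 4 rows and 3 classes give 9. The helper `effective_arrow_size()` below is hypothetical and only reproduces the division described above; it is not part of cvms.

```r
# Hypothetical helper: arrow_size divided by sqrt(nrow(conf_matrix))
effective_arrow_size <- function(arrow_size, conf_matrix) {
  arrow_size / sqrt(nrow(conf_matrix))
}

# Stand-in tibbles; only their number of rows matters here
two_class <- data.frame(N = 1:4)   # 2 x 2 combinations
three_class <- data.frame(N = 1:9) # 3 x 3 combinations

effective_arrow_size(0.048, two_class)   # 0.048 / 2 = 0.024
effective_arrow_size(0.048, three_class) # 0.048 / 3 = 0.016
```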

arrow_nudge_from_text

Distance from the percentage text to the arrow. (Numeric)

tile_border_color

Color of the tile borders. Passed as `colour` to ggplot2::geom_tile.

tile_border_size

Size of the tile borders. Passed as `size` to ggplot2::geom_tile.

tile_border_linetype

Linetype for the tile borders. Passed as `linetype` to ggplot2::geom_tile.
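For instance, dashed black tile borders (the size and linetype are illustrative values):

```r
library(cvms)
library(tibble)

cm <- tibble(
  Target = c("class_1", "class_1", "class_2", "class_2"),
  Prediction = c("class_1", "class_2", "class_1", "class_2"),
  N = c(5, 9, 3, 2)
)

# Draw dashed black borders around every tile
p <- plot_confusion_matrix(
  cm,
  tile_border_color = "black",
  tile_border_size = 0.5,
  tile_border_linetype = "dashed"
)
```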

sums_settings

A list of settings for the appearance of the sum tiles. Can be provided with sum_tile_settings().
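A sketch of customizing the sum tiles, assuming sum_tile_settings() accepts `palette` and `label` arguments. As noted under `add_sums`, using a palette for the sum tiles requires the ggnewscale package.

```r
library(cvms)
library(tibble)

cm <- tibble(
  Target = c("class_1", "class_1", "class_2", "class_2"),
  Prediction = c("class_1", "class_2", "class_1", "class_2"),
  N = c(5, 9, 3, 2)
)

# Orange sum tiles with a custom label for the total count tile
p <- plot_confusion_matrix(
  cm,
  add_sums = TRUE,
  sums_settings = sum_tile_settings(
    palette = "Oranges",
    label = "Total"
  )
)
```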

darkness

How dark the darkest colors should be, between 0 and 1, where 1 is darkest.

Technically, a lower value increases the upper limit in ggplot2::scale_fill_distiller.
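To compare the effect, the same matrix can be plotted with the darkest setting and with a lighter one (0.5 is an arbitrary illustrative value):

```r
library(cvms)
library(tibble)

cm <- tibble(
  Target = c("class_1", "class_1", "class_2", "class_2"),
  Prediction = c("class_1", "class_2", "class_1", "class_2"),
  N = c(5, 9, 3, 2)
)

p_dark <- plot_confusion_matrix(cm, darkness = 1)    # darkest tiles
p_light <- plot_confusion_matrix(cm, darkness = 0.5) # lighter tiles
```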

Details

Inspired by Antoine Sachet's answer at https://stackoverflow.com/a/53612391/11832955

Value

A ggplot2 object representing a confusion matrix. Color intensity depends on either the counts (default) or the overall percentages.

By default, each tile has the normalized count (overall percentage) and count in the middle, the column percentage at the bottom, and the row percentage to the right and rotated 90 degrees.

In the "correct" diagonal (top left to bottom right, by default), the column percentages are the class-level sensitivity scores, while the row percentages are the class-level positive predictive values.
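This link can be checked by hand. In the base-R sketch below (not cvms code), xtabs() arranges the binary example from `conf_matrix` with predictions as rows and targets as columns; dividing the diagonal by the column sums gives sensitivity, and dividing it by the row sums gives positive predictive value.

```r
# Binary example from the conf_matrix argument
cm <- data.frame(
  Target = c("class_1", "class_1", "class_2", "class_2"),
  Prediction = c("class_1", "class_2", "class_1", "class_2"),
  N = c(5, 9, 3, 2)
)

# Cross-tabulate the counts: rows = Prediction, columns = Target
tab <- xtabs(N ~ Prediction + Target, data = cm)

# Diagonal divided by column sums: sensitivity per class
sensitivity <- diag(tab) / colSums(tab)

# Diagonal divided by row sums: positive predictive value per class
ppv <- diag(tab) / rowSums(tab)

round(100 * sensitivity, 1)
round(100 * ppv, 1)
```

These match the diagonal column percentages (35.7% and 40.0%) and row percentages (62.5% and 18.2%) of that example.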

Author(s)

Ludvig Renbo Olsen, r-pkgs@ludvigolsen.dk

See Also

Other plotting functions: font(), plot_metric_density(), plot_probabilities_ecdf(), plot_probabilities(), sum_tile_settings()

Examples


# Attach packages
library(cvms)
library(ggplot2)

# Two classes

# Create targets and predictions data frame
data <- data.frame(
  "target" = c("A", "B", "A", "B", "A", "B", "A", "B",
               "A", "B", "A", "B", "A", "B", "A", "A"),
  "prediction" = c("B", "B", "A", "A", "A", "B", "B", "B",
                   "B", "B", "A", "B", "A", "A", "A", "A"),
  stringsAsFactors = FALSE
)

# Evaluate predictions and create confusion matrix
eval <- evaluate(
  data = data,
  target_col = "target",
  prediction_cols = "prediction",
  type = "binomial"
)

# Inspect confusion matrix tibble
eval[["Confusion Matrix"]][[1]]

# Plot confusion matrix
# Supply confusion matrix tibble directly
plot_confusion_matrix(eval[["Confusion Matrix"]][[1]])
# Plot first confusion matrix in evaluate() output
plot_confusion_matrix(eval)

# Add sum tiles
plot_confusion_matrix(eval, add_sums = TRUE)

# Three (or more) classes

# Create targets and predictions data frame
data <- data.frame(
  "target" = c("A", "B", "C", "B", "A", "B", "C",
               "B", "A", "B", "C", "B", "A"),
  "prediction" = c("C", "B", "A", "C", "A", "B", "B",
                   "C", "A", "B", "C", "A", "C"),
  stringsAsFactors = FALSE
)

# Evaluate predictions and create confusion matrix
eval <- evaluate(
  data = data,
  target_col = "target",
  prediction_cols = "prediction",
  type = "multinomial"
)

# Inspect confusion matrix tibble
eval[["Confusion Matrix"]][[1]]

# Plot confusion matrix
# Supply confusion matrix tibble directly
plot_confusion_matrix(eval[["Confusion Matrix"]][[1]])
# Plot first confusion matrix in evaluate() output
plot_confusion_matrix(eval)

# Add sum tiles
plot_confusion_matrix(eval, add_sums = TRUE)

# Counts only
plot_confusion_matrix(
  eval[["Confusion Matrix"]][[1]],
  add_normalized = FALSE,
  add_row_percentages = FALSE,
  add_col_percentages = FALSE
)

# Change color palette to green
# Change theme to theme_light()
plot_confusion_matrix(
  eval[["Confusion Matrix"]][[1]],
  palette = "Greens",
  theme_fn = ggplot2::theme_light
)

# The output is a ggplot2 object
# that you can add layers to
# Here we change the axis labels
plot_confusion_matrix(eval[["Confusion Matrix"]][[1]]) +
  ggplot2::labs(x = "True", y = "Guess")


[Package cvms version 1.3.3 Index]