conf_mat {yardstick} | R Documentation |
Confusion Matrix for Categorical Data
Description
Calculates a cross-tabulation of observed and predicted classes.
Usage
conf_mat(data, ...)
## S3 method for class 'data.frame'
conf_mat(
  data,
  truth,
  estimate,
  dnn = c("Prediction", "Truth"),
  case_weights = NULL,
  ...
)
## S3 method for class 'conf_mat'
tidy(x, ...)
Arguments
data |
A data frame or a |
... |
Not used. |
truth |
The column identifier for the true class results
(that is a |
estimate |
The column identifier for the predicted class
results (that is also |
dnn |
A character vector of dimnames for the table. |
case_weights |
The optional column identifier for case weights.
This should be an unquoted column name that evaluates to a numeric column
in |
x |
A |
Details
For conf_mat() objects, a broom tidy() method has been created
that collapses the cell counts by cell into a data frame for
easy manipulation.
There is also a summary() method that computes various classification
metrics at once. See summary.conf_mat().
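As a brief illustration of the summary() method, the sketch below builds a confusion matrix from the two_class_example data set that ships with yardstick and summarizes it. The object names cm_two and metrics are chosen here for illustration only.

```r
library(yardstick)

# `two_class_example` is bundled with yardstick; `truth` and `predicted`
# are factor columns with the same two levels.
data("two_class_example")

cm_two <- conf_mat(two_class_example, truth, predicted)

# summary() returns one row per metric, with the columns
# .metric, .estimator and .estimate
metrics <- summary(cm_two)
metrics
```

Because the result is an ordinary tibble, it can be filtered or joined like any other data frame, for example to keep only the accuracy row.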
There is a ggplot2::autoplot() method for quickly visualizing the
matrix. Both heatmap and mosaic plot types are implemented.
The function requires that the factors have exactly the same levels.
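One common way to meet the level requirement is to build the prediction factor from the levels of the observed factor. The following is a minimal sketch on made-up data; the data frame toy, its columns, and the class labels are hypothetical and not part of the package.

```r
library(yardstick)

# Hypothetical toy data: class "c" is observed but never predicted, so a
# plain factor(pred) call would create a factor without the "c" level.
lvls <- c("a", "b", "c")
toy <- data.frame(
  obs  = factor(c("a", "b", "c", "a", "b"), levels = lvls),
  pred = c("a", "b", "b", "a", "a"),
  stringsAsFactors = FALSE
)

# Give `pred` the same levels, in the same order, as `obs`
toy$pred <- factor(toy$pred, levels = levels(toy$obs))

cm_toy <- conf_mat(toy, obs, pred)

# Rows are predictions and columns are true classes (see `dnn`)
cm_toy$table
```

Setting the levels explicitly keeps never-predicted classes in the table as rows of zeros instead of dropping them.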
Value
conf_mat() produces an object with class conf_mat. This contains the
table and other objects. tidy.conf_mat() generates a tibble with columns
name (the cell identifier) and value (the cell count).
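To make the shape of the tidied output concrete, here is a small sketch using the two_class_example data set bundled with yardstick. The object names cm_tidy_demo and cells are illustrative only.

```r
library(yardstick)

data("two_class_example")

cm_tidy_demo <- conf_mat(two_class_example, truth, predicted)

# One row per cell of the 2 x 2 table: `name` identifies the cell
# and `value` holds its count
cells <- tidy(cm_tidy_demo)
cells

# The cell counts add up to the number of rows in the input data
sum(cells$value) == nrow(two_class_example)
```

The long format is convenient for joining cell counts from several matrices, which is how the per-fold proportions are computed in the Examples section.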
When used on a grouped data frame, conf_mat() returns a tibble containing
columns for the groups along with conf_mat, a list-column
where each element is a conf_mat object.
See Also
summary.conf_mat() for computing a large number of metrics from one
confusion matrix.
Examples
library(yardstick)
library(dplyr)
data("hpc_cv")
# The confusion matrix from a single assessment set (i.e. fold)
cm <- hpc_cv %>%
  filter(Resample == "Fold01") %>%
  conf_mat(obs, pred)
cm
# Now compute the average confusion matrix across all folds in
# terms of the proportion of the data contained in each cell.
# First get the raw cell counts per fold using the `tidy` method
library(tidyr)
cells_per_resample <- hpc_cv %>%
  group_by(Resample) %>%
  conf_mat(obs, pred) %>%
  mutate(tidied = lapply(conf_mat, tidy)) %>%
  unnest(tidied)
# Get the totals per resample
counts_per_resample <- hpc_cv %>%
  group_by(Resample) %>%
  summarize(total = n()) %>%
  left_join(cells_per_resample, by = "Resample") %>%
  # Compute the proportions
  mutate(prop = value / total) %>%
  group_by(name) %>%
  # Average
  summarize(prop = mean(prop))
counts_per_resample
# Now reshape these into a matrix
mean_cmat <- matrix(counts_per_resample$prop, byrow = TRUE, ncol = 4)
rownames(mean_cmat) <- levels(hpc_cv$obs)
colnames(mean_cmat) <- levels(hpc_cv$obs)
round(mean_cmat, 3)
# The confusion matrix can quickly be visualized using autoplot()
library(ggplot2)
autoplot(cm, type = "mosaic")
autoplot(cm, type = "heatmap")