luz_metric_multiclass_auroc {luz}    R Documentation
Computes the multi-class AUROC
Description
By default, the same definition as the Keras AUC metric is used. This is equivalent to the 'micro' averaging method of scikit-learn's roc_auc_score(). See the scikit-learn documentation for details.
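The 'micro' method treats every (sample, class) pair as one binary decision. It one-hot encodes the labels, stacks the scores and labels of all classes, and computes an ordinary binary AUROC on the result. The base-R sketch below illustrates that definition on the data from the Examples section. It is not luz's implementation. The helpers binary_auroc() and micro_auroc() are hypothetical, and they compute the exact rank-based AUROC rather than the threshold-based approximation luz uses.

```r
# Hypothetical helper: exact binary AUROC via the rank-based
# (Mann-Whitney) formula, i.e. the probability that a random positive
# scores higher than a random negative, with ties counting one half.
binary_auroc <- function(scores, labels) {
  pos <- scores[labels == 1]
  neg <- scores[labels == 0]
  mean(outer(pos, neg, ">") + 0.5 * outer(pos, neg, "=="))
}

# Hypothetical helper: 'micro' AUROC. y_pred is an n x K matrix of class
# scores and y_true holds 1-based class labels. The labels are one-hot
# encoded, then scores and labels are stacked column by column so that
# every (sample, class) pair becomes one binary observation.
micro_auroc <- function(y_pred, y_true) {
  one_hot <- outer(y_true, seq_len(ncol(y_pred)), "==") * 1
  binary_auroc(as.vector(y_pred), as.vector(one_hot))
}

# Same data as in the Examples section, without torch.
p <- c(0.9, 0.8, 0.4, 0.5, 0.3, 0.2)
y_pred <- cbind(1 - p, p)
y_true <- c(2L, 2L, 2L, 1L, 1L, 1L)

micro_auroc(y_pred, y_true)                 # 32.5 / 36, about 0.903
binary_auroc(p, as.integer(y_true == 2L))   # 8 / 9, about 0.889
```

The stacked value differs from the plain binary AUROC of the second column because, after stacking, positives of one column are also compared against negatives of the other column.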
Usage
luz_metric_multiclass_auroc(
  num_thresholds = 200,
  thresholds = NULL,
  from_logits = FALSE,
  average = c("micro", "macro", "weighted", "none")
)
Arguments
num_thresholds
Number of thresholds used to compute the confusion matrices when thresholds is NULL. In that case, the thresholds are num_thresholds values linearly spaced in the unit interval.
thresholds
(optional) A numeric vector of thresholds. If passed, these are used to compute the confusion matrices and num_thresholds is ignored.
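from_logits
If TRUE, the predictions are treated as logits and converted to probabilities before the metric is computed. If FALSE (the default), the predictions are expected to already be probabilities.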
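average
The averaging method:
- 'micro': stacks all classes and computes the AUROC as if it were a single binary classification problem.
- 'macro': computes the AUROC for each class and returns their unweighted mean.
- 'weighted': computes the AUROC for each class and returns their mean, weighted by the number of instances of each class.
- 'none': returns the AUROC for each class separately.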
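The averaging modes differ only in how per-class results are combined. The base-R sketch below computes all four on a small, deliberately imbalanced three-class toy dataset, where the modes give visibly different values. The helper average_auroc() is hypothetical and uses the exact rank-based AUROC rather than luz's threshold-based approximation, so it illustrates the definitions, not luz's code.

```r
# Hypothetical helper: exact binary AUROC (rank-based, ties count one half).
binary_auroc <- function(scores, labels) {
  pos <- scores[labels == 1]
  neg <- scores[labels == 0]
  mean(outer(pos, neg, ">") + 0.5 * outer(pos, neg, "=="))
}

# Hypothetical helper illustrating the four averaging modes.
# y_pred: n x K matrix of class probabilities; y_true: labels in 1..K.
average_auroc <- function(y_pred, y_true,
                          average = c("micro", "macro", "weighted", "none")) {
  average <- match.arg(average)
  one_hot <- outer(y_true, seq_len(ncol(y_pred)), "==") * 1
  if (average == "micro") {
    # Pool every (sample, class) pair into one binary problem.
    return(binary_auroc(as.vector(y_pred), as.vector(one_hot)))
  }
  # One-vs-rest AUROC for each class (column).
  per_class <- vapply(seq_len(ncol(y_pred)), function(j) {
    binary_auroc(y_pred[, j], one_hot[, j])
  }, numeric(1))
  switch(average,
    macro = mean(per_class),
    weighted = weighted.mean(per_class, colSums(one_hot)),
    none = per_class
  )
}

# Imbalanced toy data: four samples of class 1, one each of classes 2 and 3.
y_pred <- rbind(
  c(0.7, 0.2, 0.1),
  c(0.6, 0.3, 0.1),
  c(0.5, 0.1, 0.4),
  c(0.2, 0.5, 0.3),
  c(0.3, 0.4, 0.3),
  c(0.1, 0.3, 0.6)
)
y_true <- c(1L, 1L, 1L, 1L, 2L, 3L)

average_auroc(y_pred, y_true, "none")      # 0.875 0.800 1.000
average_auroc(y_pred, y_true, "macro")     # about 0.892
average_auroc(y_pred, y_true, "weighted")  # about 0.883
average_auroc(y_pred, y_true, "micro")     # about 0.868
```

Here the 'weighted' result is pulled toward class 1's AUROC because class 1 has four of the six samples, while 'macro' gives the two single-sample classes the same weight as class 1.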
Details
Note that, unlike the AUC for binary classification, this metric can be affected by class imbalance.
Currently, the AUC is approximated using the 'interpolation' summation method described in the Keras AUC documentation.
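The general idea of a threshold-based approximation can be sketched in base R for a single binary problem (the 'micro' method would apply it to the stacked scores and labels). True and false positives are counted at each threshold, and trapezoids are summed under the resulting ROC points. This sketch assumes that a sample counts as a positive prediction when its score is strictly greater than the threshold, and it uses a plain trapezoidal rule as our reading of the Keras 'interpolation' scheme for ROC curves. The helper threshold_auroc() is hypothetical and is not luz's code.

```r
# Hypothetical helper: threshold-based AUROC approximation.
# Assumptions: a score counts as a positive prediction when it is strictly
# greater than the threshold, and the area is summed with the trapezoidal rule.
threshold_auroc <- function(scores, labels, thresholds) {
  # Add -Inf and Inf so the curve always reaches (1, 1) and (0, 0).
  thresholds <- sort(unique(c(-Inf, thresholds, Inf)))
  n_pos <- sum(labels == 1)
  n_neg <- sum(labels == 0)
  # One confusion matrix (TP and FP counts) per threshold, turned into rates.
  tpr <- vapply(thresholds, function(t) sum(scores > t & labels == 1) / n_pos,
                numeric(1))
  fpr <- vapply(thresholds, function(t) sum(scores > t & labels == 0) / n_neg,
                numeric(1))
  # Thresholds are ascending, so both rates are non-increasing.
  i <- seq_len(length(thresholds) - 1)
  sum((fpr[i] - fpr[i + 1]) * (tpr[i] + tpr[i + 1]) / 2)
}

scores <- c(0.9, 0.8, 0.4, 0.5, 0.3, 0.2)
labels <- c(1, 1, 1, 0, 0, 0)

# 200 values linearly spaced in the unit interval (the num_thresholds default).
threshold_auroc(scores, labels, seq(0, 1, length.out = 200))  # 0.8888889 (8 / 9)
# Every observed score used as a threshold.
threshold_auroc(scores, labels, scores)                       # 0.8888889 (8 / 9)
# A single coarse threshold loses resolution.
threshold_auroc(scores, labels, 0.5)                          # 0.8333333 (5 / 6)
```

With enough thresholds to separate every pair of distinct scores, the trapezoidal sum matches the exact AUROC; with too few thresholds, ROC points are skipped and the estimate drifts.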
See Also
Other luz_metrics: luz_metric_accuracy(), luz_metric_binary_accuracy_with_logits(), luz_metric_binary_accuracy(), luz_metric_binary_auroc(), luz_metric_mae(), luz_metric_mse(), luz_metric_rmse(), luz_metric()
Examples
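if (torch::torch_is_installed()) {
  library(torch)
  # Shift the binary labels {0, 1} to the 1-based class indices {1, 2}.
  actual <- c(1, 1, 1, 0, 0, 0) + 1L
  # Scores for the second class; the first column is their complement.
  predicted <- c(0.9, 0.8, 0.4, 0.5, 0.3, 0.2)
  predicted <- cbind(1 - predicted, predicted)
  y_true <- torch_tensor(as.integer(actual))
  y_pred <- torch_tensor(predicted)
  # Use every predicted value as a threshold, with 'micro' averaging.
  m <- luz_metric_multiclass_auroc(thresholds = as.numeric(predicted),
                                   average = "micro")
  # Instantiate the metric and feed it the data in three mini-batches.
  m <- m$new()
  m$update(y_pred[1:2, ], y_true[1:2])
  m$update(y_pred[3:4, ], y_true[3:4])
  m$update(y_pred[5:6, ], y_true[5:6])
  m$compute()
}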