metric_auc {keras3}    R Documentation
Approximates the AUC (Area under the curve) of the ROC or PR curves.
Description
The AUC (Area under the curve) of the ROC (Receiver operating characteristic; default) or PR (Precision Recall) curve is a quality measure of binary classifiers. Unlike accuracy, and like cross-entropy losses, ROC-AUC and PR-AUC evaluate all the operational points of a model.
This class approximates AUCs using a Riemann sum. During the metric accumulation phase, predictions are accumulated within predefined buckets by value. The AUC is then computed by interpolating per-bucket averages. These buckets define the evaluated operational points.
This metric creates four local variables, true_positives, true_negatives, false_positives and false_negatives, that are used to compute the AUC. To discretize the AUC curve, a linearly spaced set of thresholds is used to compute pairs of recall and precision values. The area under the ROC curve is therefore computed using the height of the recall values by the false positive rate, while the area under the PR curve is computed using the height of the precision values by the recall.
This value is ultimately returned as auc, an idempotent operation that computes the area under a discretized curve of precision versus recall values (computed using the aforementioned variables). The num_thresholds variable controls the degree of discretization, with larger numbers of thresholds more closely approximating the true AUC. The quality of the approximation may vary dramatically depending on num_thresholds. The thresholds parameter can be used to manually specify thresholds which split the predictions more evenly.
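A quick way to see the effect of num_thresholds, or of manually supplied thresholds, is to score the same predictions under different settings. The sketch below reuses the toy data from the Standalone usage example further down; the manually chosen thresholds are illustrative only, not values taken from this page.
y_true <- c(0, 0, 1, 1)
y_pred <- c(0, 0.5, 0.3, 0.9)

# Coarse discretization (few thresholds)
m_coarse <- metric_auc(num_thresholds = 3)
m_coarse$update_state(y_true, y_pred)
m_coarse$result()

# Fine discretization (default number of thresholds)
m_fine <- metric_auc(num_thresholds = 200)
m_fine$update_state(y_true, y_pred)
m_fine$result()

# Manually chosen thresholds (num_thresholds is then ignored)
m_manual <- metric_auc(thresholds = c(0.1, 0.25, 0.5, 0.75, 0.9))
m_manual$update_state(y_true, y_pred)
m_manual$result()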
For the best approximation of the real AUC, predictions should be distributed approximately uniformly in the range [0, 1] (if from_logits=FALSE). The quality of the AUC approximation may be poor if this is not the case. Setting summation_method to 'minoring' or 'majoring' can help quantify the error in the approximation by providing a lower or upper bound estimate of the AUC.
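As a rough sketch of this bounding behaviour, the same predictions can be scored with all three summation methods; the 'minoring' and 'majoring' results should bracket the 'interpolation' estimate. The toy values below are illustrative assumptions only.
y_true <- c(0, 0, 1, 1)
y_pred <- c(0.1, 0.4, 0.35, 0.8)

auc_lower <- metric_auc(summation_method = "minoring")       # lower bound estimate
auc_mid   <- metric_auc(summation_method = "interpolation")  # default mid-point scheme
auc_upper <- metric_auc(summation_method = "majoring")       # upper bound estimate

# Accumulate the same batch into each metric
for (m in list(auc_lower, auc_mid, auc_upper)) {
  m$update_state(y_true, y_pred)
}

auc_lower$result(); auc_mid$result(); auc_upper$result()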
If sample_weight is NULL, weights default to 1. Use sample_weight of 0 to mask values.
Usage
metric_auc(
...,
num_thresholds = 200L,
curve = "ROC",
summation_method = "interpolation",
name = NULL,
dtype = NULL,
thresholds = NULL,
multi_label = FALSE,
num_labels = NULL,
label_weights = NULL,
from_logits = FALSE
)
Arguments
...
For forward/backward compatibility.
num_thresholds
(Optional) The number of thresholds to use when discretizing the ROC curve. Values must be > 1. Defaults to 200L.
curve
(Optional) Specifies the name of the curve to be computed, either "ROC" (default) or "PR" for the Precision-Recall curve.
summation_method
(Optional) Specifies the Riemann summation method used. "interpolation" (default) applies a mid-point summation scheme; "minoring" and "majoring" instead yield lower and upper bound estimates of the AUC (see Description).
name
(Optional) String name of the metric instance.
dtype
(Optional) Data type of the metric result.
thresholds
(Optional) A list of floating point values to use as the thresholds for discretizing the curve. If set, the num_thresholds argument is ignored.
multi_label
Boolean indicating whether multilabel data should be treated as such, wherein AUC is computed separately for each label and then averaged across labels, or (when FALSE) whether the data should be flattened into a single label before AUC computation (see the sketch after this argument list).
num_labels
(Optional) The number of labels, used when multi_label is TRUE.
label_weights
(Optional) List, array, or tensor of non-negative weights used to compute AUCs for multilabel data. When multi_label is TRUE, the weights are applied to the per-label AUCs when they are averaged into the overall result.
from_logits
Boolean indicating whether the predictions (y_pred in update_state()) are probabilities or logits.
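For the multilabel arguments above, a minimal sketch might look as follows; the two-label matrices are illustrative assumptions, not values from this page.
# Multilabel AUC: one AUC per label, averaged across labels
y_true <- rbind(c(0, 1), c(1, 0), c(1, 1), c(0, 0))
y_pred <- rbind(c(0.2, 0.8), c(0.7, 0.3), c(0.9, 0.6), c(0.1, 0.4))

m <- metric_auc(multi_label = TRUE, num_labels = 2L)
m$update_state(y_true, y_pred)
m$result()  # mean of the per-label AUCs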
Value
A Metric instance is returned. The Metric instance can be passed directly to compile(metrics = ), or used as a standalone object. See ?Metric for example usage.
Usage
Standalone usage:
m <- metric_auc(num_thresholds = 3)
m$update_state(c(0, 0, 1, 1), c(0, 0.5, 0.3, 0.9))
# threshold values are [0 - 1e-7, 0.5, 1 + 1e-7]
# tp = [2, 1, 0], fp = [2, 0, 0], fn = [0, 1, 2], tn = [0, 2, 2]
# tp_rate = recall = [1, 0.5, 0], fp_rate = [1, 0, 0]
# auc = ((((1 + 0.5) / 2) * (1 - 0)) + (((0.5 + 0) / 2) * (0 - 0)))
#     = 0.75
m$result()
## tf.Tensor(0.75, shape=(), dtype=float32)
m$reset_state()
m$update_state(c(0, 0, 1, 1), c(0, 0.5, 0.3, 0.9), sample_weight = c(1, 0, 0, 1))
m$result()
## tf.Tensor(1.0, shape=(), dtype=float32)
Usage with compile() API:
# Reports the AUC of a model outputting a probability.
model |> compile(
  optimizer = 'sgd',
  loss = loss_binary_crossentropy(),
  metrics = list(metric_auc())
)

# Reports the AUC of a model outputting a logit.
model |> compile(
  optimizer = 'sgd',
  loss = loss_binary_crossentropy(from_logits = TRUE),
  metrics = list(metric_auc(from_logits = TRUE))
)
See Also
Other confusion metrics:
metric_false_negatives()
metric_false_positives()
metric_precision()
metric_precision_at_recall()
metric_recall()
metric_recall_at_precision()
metric_sensitivity_at_specificity()
metric_specificity_at_sensitivity()
metric_true_negatives()
metric_true_positives()
Other metrics:
Metric()
custom_metric()
metric_binary_accuracy()
metric_binary_crossentropy()
metric_binary_focal_crossentropy()
metric_binary_iou()
metric_categorical_accuracy()
metric_categorical_crossentropy()
metric_categorical_focal_crossentropy()
metric_categorical_hinge()
metric_cosine_similarity()
metric_f1_score()
metric_false_negatives()
metric_false_positives()
metric_fbeta_score()
metric_hinge()
metric_huber()
metric_iou()
metric_kl_divergence()
metric_log_cosh()
metric_log_cosh_error()
metric_mean()
metric_mean_absolute_error()
metric_mean_absolute_percentage_error()
metric_mean_iou()
metric_mean_squared_error()
metric_mean_squared_logarithmic_error()
metric_mean_wrapper()
metric_one_hot_iou()
metric_one_hot_mean_iou()
metric_poisson()
metric_precision()
metric_precision_at_recall()
metric_r2_score()
metric_recall()
metric_recall_at_precision()
metric_root_mean_squared_error()
metric_sensitivity_at_specificity()
metric_sparse_categorical_accuracy()
metric_sparse_categorical_crossentropy()
metric_sparse_top_k_categorical_accuracy()
metric_specificity_at_sensitivity()
metric_squared_hinge()
metric_sum()
metric_top_k_categorical_accuracy()
metric_true_negatives()
metric_true_positives()