luz_metric {luz}    R Documentation

Creates a new luz metric

Description

Creates a new luz metric

Usage

luz_metric(
  name = NULL,
  ...,
  private = NULL,
  active = NULL,
  parent_env = parent.frame(),
  inherit = NULL
)

Arguments

name

String naming the new metric.

...

Named public methods and fields (for example, abbrev). You should implement at least initialize, update and compute. See the Details section for more information.

private

An optional list of private members, which can be functions and non-functions.

active

An optional list of active binding functions.

parent_env

An environment to use as the parent of newly-created objects.

inherit

An R6ClassGenerator object to inherit from; in other words, a superclass. This is captured as an unevaluated expression which is evaluated in parent_env each time an object is instantiated.

Details

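To implement a new luz_metric, we need to implement 3 methods:

initialize: sets up the initial state of the metric. Metrics are initialized every epoch, for both training and validation.

update: updates the internal state of the metric. It runs at every training or validation step and takes preds and target as parameters.

compute: uses the internal state to compute the metric value.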

Optionally, you can provide an abbrev field that gives the metric an abbreviation, used when displaying metric information in the console or in tracking records. If no abbrev is passed, the class name is used.

Let’s take a look at the implementation of luz_metric_accuracy so you can see how to implement a new one:

luz_metric_accuracy <- luz_metric(
  # An abbreviation to be shown in progress bars, or 
  # when printing progress
  abbrev = "Acc", 
  # Initial setup for the metric. Metrics are initialized
  # every epoch, for both training and validation
  initialize = function() {
    self$correct <- 0
    self$total <- 0
  },
  # Runs at every training or validation step and updates
  # the internal state. The update function takes `preds`
  # and `target` as parameters.
  update = function(preds, target) {
    pred <- torch::torch_argmax(preds, dim = 2)
    self$correct <- self$correct + (pred == target)$
      to(dtype = torch::torch_float())$
      sum()$
      item()
    self$total <- self$total + pred$numel()
  },
  # Use the internal state to query the metric value
  compute = function() {
    self$correct/self$total
  }
)

Note: It is good practice for compute to return regular R values rather than torch tensors; other parts of luz expect this.
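
The initialize, update and compute cycle described above can be sketched in base R, without luz or torch, to show how state accumulates across batches and why compute ends up returning a plain numeric. This is only an illustration of the contract, not how luz implements metrics: make_accuracy and run_epoch are hypothetical helper names, and preds is assumed to be a plain matrix of class scores with one row per observation.

```r
# Plain-R stand-in for a metric: a list of closures sharing one state.
# `make_accuracy` is a hypothetical helper, not part of luz.
make_accuracy <- function() {
  state <- new.env()
  list(
    # Reset the counters; called once at the start of each epoch
    initialize = function() {
      state$correct <- 0
      state$total <- 0
    },
    # Accumulate counts for one batch
    update = function(preds, target) {
      # Index of the highest score in each row (1-based class index)
      pred <- max.col(preds, ties.method = "first")
      state$correct <- state$correct + sum(pred == target)
      state$total <- state$total + length(target)
    },
    # Turn the accumulated state into a regular R numeric
    compute = function() {
      state$correct / state$total
    }
  )
}

# Hypothetical driver mimicking one epoch: reset, update per batch, compute
run_epoch <- function(metric, batches) {
  metric$initialize()
  for (batch in batches) {
    metric$update(batch$preds, batch$target)
  }
  metric$compute()
}

batches <- list(
  list(preds = rbind(c(0.9, 0.1), c(0.2, 0.8)), target = c(1, 2)),
  list(preds = rbind(c(0.3, 0.7), c(0.6, 0.4)), target = c(1, 1))
)

acc <- run_epoch(make_accuracy(), batches)
print(acc)
```

Because initialize is called again at the start of every epoch, reusing the same metric object for a second epoch does not carry counts over from the first.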

Value

Returns a new luz metric.

See Also

Other luz_metrics: luz_metric_accuracy(), luz_metric_binary_accuracy_with_logits(), luz_metric_binary_accuracy(), luz_metric_binary_auroc(), luz_metric_mae(), luz_metric_mse(), luz_metric_multiclass_auroc(), luz_metric_rmse()

Examples

luz_metric_accuracy <- luz_metric(
  # An abbreviation to be shown in progress bars, or
  # when printing progress
  abbrev = "Acc",
  # Initial setup for the metric. Metrics are initialized
  # every epoch, for both training and validation
  initialize = function() {
    self$correct <- 0
    self$total <- 0
  },
  # Runs at every training or validation step and updates
  # the internal state. The update function takes `preds`
  # and `target` as parameters.
  update = function(preds, target) {
    pred <- torch::torch_argmax(preds, dim = 2)
    self$correct <- self$correct + (pred == target)$
      to(dtype = torch::torch_float())$
      sum()$
      item()
    self$total <- self$total + pred$numel()
  },
  # Use the internal state to query the metric value
  compute = function() {
    self$correct/self$total
  }
)


[Package luz version 0.4.0 Index]