dataset_rejection_resample {tfdatasets}    R Documentation

A transformation that resamples a dataset to a target distribution.

Description

A transformation that resamples a dataset so that the distribution of classes, as assigned by class_func, approximates target_dist.

Usage

dataset_rejection_resample(
  dataset,
  class_func,
  target_dist,
  initial_dist = NULL,
  seed = NULL,
  name = NULL
)

Arguments

dataset

A tf.Dataset

class_func

A function mapping an element of the input dataset to a scalar tf.int32 tensor. Values should be in [0, num_classes).

target_dist

A floating point type tensor, shaped [num_classes], giving the desired class distribution.

initial_dist

(Optional.) A floating point type tensor, shaped [num_classes]. If not provided, the true class distribution is estimated live in a streaming fashion.

seed

(Optional.) Integer seed for the resampler.

name

(Optional.) A name for the tf.data operation.

Value

A tf.Dataset

Examples

## Not run: 
initial_dist <- c(.5, .5)
target_dist <- c(.6, .4)
num_classes <- length(initial_dist)
num_samples <- 100000
data <- sample.int(num_classes, num_samples, prob = initial_dist, replace = TRUE)
dataset <- tensor_slices_dataset(data)
tally <- c(0, 0)
# Replacement function: `add(x) <- value` increments x by value
`add<-` <- function (x, value) x + value
# tfautograph::autograph({
#   for(i in dataset)
#     add(tally[as.numeric(i)]) <- 1
# })
dataset %>%
  as_array_iterator() %>%
  iterate(function(i) {
    add(tally[i]) <<- 1
  }, simplify = FALSE)
# The value of `tally` will be close to c(50000, 50000) as
# per the `initial_dist` distribution.
tally # c(50287, 49713)

tally <- c(0, 0)
dataset %>%
  dataset_rejection_resample(
    class_func = function(x) (x-1) %% 2,
    target_dist = target_dist,
    initial_dist = initial_dist
  ) %>%
  as_array_iterator() %>%
  iterate(function(element) {
    names(element) <- c("class_id", "i")
    add(tally[element$i]) <<- 1
  }, simplify = FALSE)
# The value of `tally` will now be close to c(75000, 50000),
# thus satisfying the `target_dist` distribution.
tally # c(74822, 49921)

## End(Not run)
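
The idea behind resampling to a target distribution can be illustrated without TensorFlow. The following is a minimal base R sketch of plain rejection sampling, assuming each element is kept with a probability proportional to target_dist / initial_dist for its class. It is an illustration only and not the tf.data implementation, which may behave differently (for example, in how many elements it emits). The names ratio, accept_prob, keep and resampled are introduced here for illustration.

```r
# Simplified illustration of rejection resampling in base R.
# NOTE: this is an assumption-based sketch, not the tfdatasets implementation.
set.seed(42)
initial_dist <- c(.5, .5)
target_dist <- c(.6, .4)
num_classes <- length(initial_dist)
num_samples <- 100000

# Draw class labels (1 or 2) according to the initial distribution
classes <- sample.int(num_classes, num_samples, prob = initial_dist, replace = TRUE)

# Per-class acceptance probability: proportional to target / initial,
# scaled so that the class with the largest ratio is always accepted
ratio <- target_dist / initial_dist
accept_prob <- ratio / max(ratio)

# Keep each element with the acceptance probability of its class
keep <- runif(num_samples) < accept_prob[classes]
resampled <- classes[keep]

# Empirical class proportions after resampling, expected to be near target_dist
prop.table(table(resampled))
```

In this sketch elements are only ever dropped, so the resampled vector is shorter than the input; the proportions of the surviving elements are what approximate target_dist.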

[Package tfdatasets version 2.17.0 Index]