dataset_rejection_resample {tfdatasets} | R Documentation
A transformation that resamples a dataset to a target distribution.
Description
A transformation that resamples a dataset to a target distribution.
Usage
dataset_rejection_resample(
  dataset,
  class_func,
  target_dist,
  initial_dist = NULL,
  seed = NULL,
  name = NULL
)
Arguments
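dataset |
A tf.Dataset. |
class_func |
A function mapping an element of the input dataset to a
scalar tf.int32 tensor. Values should be in [0, num_classes). |
target_dist |
A floating point type tensor, shaped [num_classes]. |
initial_dist |
(Optional.) A floating point type tensor, shaped
[num_classes]. If not provided, the true class distribution is
estimated live in a streaming fashion. |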
seed |
(Optional.) Integer seed for the resampler. |
name |
(Optional.) A name for the tf.data operation. |
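The relationship between class_func and target_dist can be checked in plain R before building a pipeline. The sketch below is illustrative only: the vector x and the variable names are hypothetical, and it assumes, as in the underlying tf.data API, that class ids are zero-based and that target_dist has one entry per class.

```r
# Hypothetical input values, 1-based as in the Examples section.
x <- c(1L, 2L, 2L, 1L)

# The same mapping the example passes as class_func: 1 -> 0, 2 -> 1.
class_ids <- (x - 1L) %% 2L

target_dist <- c(0.6, 0.4)
num_classes <- length(target_dist)

# Every class id should fall in [0, num_classes).
ids_in_range <- all(class_ids >= 0 & class_ids < num_classes)

# A target distribution should sum to 1 (up to floating point error).
dist_is_normalised <- isTRUE(all.equal(sum(target_dist), 1))
```

Both checks are plain logical values, so they can be wrapped in stopifnot() ahead of the call to dataset_rejection_resample().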
Value
A tf.Dataset
Examples
## Not run:
library(tfdatasets)
initial_dist <- c(.5, .5)
target_dist <- c(.6, .4)
num_classes <- length(initial_dist)
num_samples <- 100000
data <- sample.int(num_classes, num_samples, prob = initial_dist, replace = TRUE)
dataset <- tensor_slices_dataset(data)
tally <- c(0, 0)
# Replacement function so that `add(x) <- v` increments x by v.
`add<-` <- function(x, value) x + value
# tfautograph::autograph({
#   for (i in dataset)
#     add(tally[as.numeric(i)]) <- 1
# })
dataset %>%
  as_array_iterator() %>%
  iterate(function(i) {
    add(tally[i]) <<- 1
  }, simplify = FALSE)
# The value of `tally` will be close to c(50000, 50000) as
# per the `initial_dist` distribution.
tally # c(50287, 49713)
tally <- c(0, 0)
dataset %>%
  dataset_rejection_resample(
    class_func = function(x) (x - 1) %% 2,
    target_dist = target_dist,
    initial_dist = initial_dist
  ) %>%
  as_array_iterator() %>%
  iterate(function(element) {
    names(element) <- c("class_id", "i")
    add(tally[element$i]) <<- 1
  }, simplify = FALSE)
# The value of `tally` will now be close to c(75000, 50000),
# a 60/40 split, thus satisfying the `target_dist` distribution.
tally # c(74822, 49921)
## End(Not run)
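The general idea of rejection resampling can be sketched in base R without TensorFlow. This illustrates the technique under simple assumptions and is not the tfdatasets implementation: each element is kept with a probability proportional to target_dist / initial_dist for its class, so the kept elements follow the target distribution. The names accept_prob, keep, and props are hypothetical.

```r
# Illustrative base-R sketch of rejection resampling; not tfdatasets code.
set.seed(42)

initial_dist <- c(0.5, 0.5)
target_dist  <- c(0.6, 0.4)

# Acceptance probability per class: proportional to target / initial,
# scaled so that the class with the largest ratio is always kept.
ratio <- target_dist / initial_dist
accept_prob <- ratio / max(ratio)

# Draw 1-based class labels from the initial distribution.
x <- sample.int(length(initial_dist), 100000, replace = TRUE,
                prob = initial_dist)

# Keep each element with the acceptance probability of its class.
keep <- runif(length(x)) < accept_prob[x]
resampled <- x[keep]

# Empirical class proportions of the kept elements.
props <- as.numeric(table(resampled)) / length(resampled)
props
```

With these inputs props should come out close to target_dist. This sketch only discards elements, so it returns fewer elements than it receives; the number of elements produced by dataset_rejection_resample() may differ.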
[Package tfdatasets version 2.17.0 Index]