gan_trainer {RGAN} R Documentation

gan_trainer

Description

Provides a function to quickly train a GAN model.

Usage

gan_trainer(
  data,
  noise_dim = 2,
  noise_distribution = "normal",
  value_function = "original",
  data_type = "tabular",
  generator = NULL,
  generator_optimizer = NULL,
  discriminator = NULL,
  discriminator_optimizer = NULL,
  base_lr = 1e-04,
  ttur_factor = 4,
  weight_clipper = NULL,
  batch_size = 50,
  epochs = 150,
  plot_progress = FALSE,
  plot_interval = "epoch",
  eval_dropout = FALSE,
  synthetic_examples = 500,
  plot_dimensions = c(1, 2),
  device = "cpu"
)

Arguments

data

The input data set. Needs to be a matrix, array, torch::torch_tensor or torch::dataset.

noise_dim

The dimension (length) of the GAN noise vector z. Defaults to 2.

noise_distribution

The noise distribution. Expects a function that samples from a distribution and returns a torch_tensor. For convenience, "normal" and "uniform" automatically select a matching sampling function. Defaults to "normal".
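A custom noise function can be sketched as follows. The sketch assumes, which is not stated above, that gan_trainer calls the function with a single shape vector such as c(batch_size, noise_dim). The names wide_normal and sd are hypothetical.

```r
# Hypothetical sampler: standard normal noise scaled by `sd`.
# Assumption: the trainer passes the desired shape as one integer vector.
wide_normal <- function(size, sd = 2) {
  torch::torch_randn(size) * sd
}

# Try it only when torch and its backend are available
if (requireNamespace("torch", quietly = TRUE) && torch::torch_is_installed()) {
  z <- wide_normal(c(50, 2))
  print(z$shape)  # shape of the sampled noise batch
}
```

Under that assumption it would be passed as gan_trainer(data, noise_distribution = wide_normal).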

value_function

The value function for GAN training. Expects a function that takes the discriminator scores of real and fake data as input and returns a list with the discriminator loss and the generator loss. For convenience, three loss functions are already implemented: "original", "wasserstein" and "f-wgan". Defaults to "original".
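A custom value function might look like the sketch below, which implements the non-saturating variant of the original GAN loss. Several things here are assumptions rather than documented behaviour: that real scores come first, that the scores are probabilities in (0, 1), and that the returned list uses the element names d_loss and g_loss. The function name is hypothetical.

```r
# Hypothetical custom value function (non-saturating GAN loss).
# Assumptions: scores lie in (0, 1); list names d_loss / g_loss are a guess.
nonsaturating_value_fct <- function(real_scores, fake_scores) {
  eps <- 1e-8  # guards against log(0)
  # Discriminator: reward high scores on real data and low scores on fakes
  d_loss <- -mean(log(real_scores + eps) + log(1 - fake_scores + eps))
  # Generator: reward fakes that the discriminator scores highly
  g_loss <- -mean(log(fake_scores + eps))
  list(d_loss = d_loss, g_loss = g_loss)
}

# Plain numeric vectors stand in for score tensors in this demonstration
losses <- nonsaturating_value_fct(c(0.9, 0.8), c(0.2, 0.1))
print(losses)
```

The body uses only log(), mean() and arithmetic operators, which torch also provides for tensors, so the same function should work on tensor inputs as well.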

data_type

The type of the input data, either "tabular" or "image". Defaults to "tabular".

generator

The generator network. Expects a neural network provided as torch::nn_module. Defaults to NULL, which creates a simple fully connected neural network.

generator_optimizer

The optimizer for the generator network. Expects a torch::optim_xxx function, e.g. torch::optim_adam(). Defaults to NULL, which sets up torch::optim_adam(g_net$parameters, lr = base_lr).
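As a rough illustration of supplying a custom generator and optimizer, the sketch below builds a small fully connected network whose input width equals noise_dim and whose output width equals the number of data columns. build_generator, hidden and the layer sizes are hypothetical choices, not the package's default architecture.

```r
# Hypothetical helper: a small fully connected generator mapping
# noise_dim inputs to data_dim outputs (one output per data column).
build_generator <- function(noise_dim, data_dim, hidden = 32) {
  torch::nn_sequential(
    torch::nn_linear(noise_dim, hidden),
    torch::nn_leaky_relu(0.2),
    torch::nn_linear(hidden, data_dim)
  )
}

# Run the demonstration only when torch and its backend are available
if (requireNamespace("torch", quietly = TRUE) && torch::torch_is_installed()) {
  g_net <- build_generator(noise_dim = 2, data_dim = 2)
  # Same optimizer type as the documented default, with a custom learning rate
  g_optim <- torch::optim_adam(g_net$parameters, lr = 2e-04)
  fake <- g_net(torch::torch_randn(c(5, 2)))
  print(fake$shape)  # five synthetic rows with two columns each
}
```

The pair would then be passed as generator = g_net and generator_optimizer = g_optim.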

discriminator

The discriminator network. Expects a neural network provided as torch::nn_module. Defaults to NULL, which creates a simple fully connected neural network.

discriminator_optimizer

The optimizer for the discriminator network. Expects a torch::optim_xxx function, e.g. torch::optim_adam(). Defaults to NULL, which sets up torch::optim_adam(d_net$parameters, lr = base_lr * ttur_factor).

base_lr

The base learning rate for the optimizers. Default is 0.0001. Only used if no optimizer is explicitly passed to the trainer.

ttur_factor

A multiplier for the learning rate of the discriminator, used to implement the two time-scale update rule. Defaults to 4. Only used if no discriminator optimizer is explicitly passed to the trainer.
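With the default settings, the two learning rates implied by the default optimizers work out as follows. The variable names generator_lr and discriminator_lr are only illustrative.

```r
# Default values taken from the Usage section
base_lr <- 1e-04
ttur_factor <- 4

# The default generator optimizer uses base_lr directly,
# the default discriminator optimizer uses base_lr * ttur_factor
generator_lr <- base_lr
discriminator_lr <- base_lr * ttur_factor

print(c(generator = generator_lr, discriminator = discriminator_lr))
```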

weight_clipper

The Wasserstein GAN constrains the weights of the discriminator, so the weights are clipped during training. Defaults to NULL.
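The idea of weight clipping can be illustrated on a plain numeric vector. This sketch only shows the operation itself: the signature that gan_trainer expects for weight_clipper is not described here, and clip_weights and the bound of 0.01 are hypothetical.

```r
# Hypothetical illustration of clipping: force every weight into
# the interval [-bound, bound]
clip_weights <- function(w, bound = 0.01) {
  pmax(pmin(w, bound), -bound)
}

weights <- c(-0.5, 0.005, 0.3)
clipped <- clip_weights(weights)
print(clipped)
```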

batch_size

The number of training samples in each mini-batch. Defaults to 50.

epochs

The number of training epochs. Defaults to 150.

plot_progress

Monitor training progress with plots. Defaults to FALSE.

plot_interval

Number of training steps between plots. Input number of steps or "epoch". Defaults to "epoch".
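To get a feel for how batch_size, epochs and plot_interval relate, the sketch below counts training steps for a hypothetical data set of 1000 rows. It assumes that one step processes one mini-batch and that an incomplete final batch is dropped. Neither is stated above.

```r
# Hypothetical data set size; batch_size and epochs are the documented defaults
n_rows <- 1000
batch_size <- 50
epochs <- 150

# Assumption: one step per full mini-batch, remainder dropped
steps_per_epoch <- n_rows %/% batch_size
total_steps <- steps_per_epoch * epochs

print(c(steps_per_epoch = steps_per_epoch, total_steps = total_steps))
```

Under these assumptions, plot_interval = "epoch" would correspond to a plot every 20 steps for this data set.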

eval_dropout

Should dropout be applied during the sampling of synthetic data? Defaults to FALSE.

synthetic_examples

Number of synthetic examples that should be generated. Defaults to 500. For image data, a much smaller number such as 16 is more reasonable.

plot_dimensions

The dimensions of the data to show when training progress is monitored with a plot. Defaults to c(1, 2), i.e. the first two columns of the tabular data.

device

The device on which training should be done, e.g. "cpu" or "cuda". Defaults to "cpu".
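One way to choose the device at run time is sketched below. pick_device is a hypothetical helper. It falls back to "cpu" whenever torch, its backend, or a CUDA device is unavailable.

```r
# Hypothetical helper: prefer "cuda" when a GPU is usable, otherwise "cpu"
pick_device <- function() {
  has_torch <- requireNamespace("torch", quietly = TRUE) &&
    torch::torch_is_installed()
  if (has_torch && torch::cuda_is_available()) "cuda" else "cpu"
}

device <- pick_device()
print(device)
```

The result can then be passed on as gan_trainer(data, device = device).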

Value

gan_trainer trains the neural networks and returns an object of class trained_RGAN that contains the last generator, discriminator and the respective optimizers, as well as the settings.

Examples

## Not run: 
# Before running the first time the torch backend needs to be installed
torch::install_torch()
# Load data
data <- sample_toydata()
# Build new transformer
transformer <- data_transformer$new()
# Fit transformer to data
transformer$fit(data)
# Transform data and store as new object
transformed_data <- transformer$transform(data)
# Train the default GAN
trained_gan <- gan_trainer(transformed_data)
# Sample synthetic data from the trained GAN
synthetic_data <- sample_synthetic_data(trained_gan, transformer)
# Plot the results
GAN_update_plot(
  data = data,
  synth_data = synthetic_data,
  main = "Real and Synthetic Data after Training"
)

## End(Not run)

[Package RGAN version 0.1.1 Index]