gan_trainer {RGAN}    R Documentation
gan_trainer
Description
Provides a function to quickly train a GAN model.
Usage
gan_trainer(
data,
noise_dim = 2,
noise_distribution = "normal",
value_function = "original",
data_type = "tabular",
generator = NULL,
generator_optimizer = NULL,
discriminator = NULL,
discriminator_optimizer = NULL,
base_lr = 1e-04,
ttur_factor = 4,
weight_clipper = NULL,
batch_size = 50,
epochs = 150,
plot_progress = FALSE,
plot_interval = "epoch",
eval_dropout = FALSE,
synthetic_examples = 500,
plot_dimensions = c(1, 2),
device = "cpu"
)
Arguments
data
The input data set. Must be a matrix, array, torch::torch_tensor or torch::dataset.
noise_dim
The dimension of the GAN noise vector z. Defaults to 2.
noise_distribution
The noise distribution. Expects a function that samples from a distribution and returns a torch_tensor. For convenience, "normal" and "uniform" select a built-in sampling function. Defaults to "normal".
value_function
The value function for GAN training. Expects a function that takes the discriminator scores of real and fake data as input and returns a list with the discriminator loss and the generator loss. For reference see: . For convenience, three value functions, "original", "wasserstein" and "f-wgan", are already implemented. Defaults to "original".
data_type
The type of the data, either "tabular" or "image". Defaults to "tabular".
generator
The generator network. Expects a neural network provided as torch::nn_module. Default is NULL, which creates a simple fully connected neural network.
generator_optimizer
The optimizer for the generator network. Expects a torch::optim_xxx function, e.g. torch::optim_adam(). Default is NULL, which sets up an optimizer with learning rate base_lr.
discriminator
The discriminator network. Expects a neural network provided as torch::nn_module. Default is NULL, which creates a simple fully connected neural network.
discriminator_optimizer
The optimizer for the discriminator network. Expects a torch::optim_xxx function, e.g. torch::optim_adam(). Default is NULL, which sets up an optimizer with learning rate base_lr * ttur_factor.
base_lr
The base learning rate for the optimizers. Defaults to 0.0001. Only used if no optimizer is explicitly passed to the trainer.
ttur_factor
A multiplier for the learning rate of the discriminator, implementing the two time-scale update rule (TTUR). Defaults to 4.
weight_clipper
The Wasserstein GAN constrains the weights of the discriminator (to keep it approximately Lipschitz continuous), so the weights are clipped during training. Defaults to NULL.
batch_size
The number of training samples selected into the mini-batch for training. Defaults to 50.
epochs
The number of training epochs. Defaults to 150.
plot_progress
Monitor training progress with plots. Defaults to FALSE.
plot_interval
Number of training steps between plots: either a number of steps or "epoch" to plot once per epoch. Defaults to "epoch".
eval_dropout
Should dropout be applied during the sampling of synthetic data? Defaults to FALSE.
synthetic_examples
Number of synthetic examples that should be generated. Defaults to 500. For image data, a smaller number such as 16 is more reasonable.
plot_dimensions
If training progress is monitored with plots, the dimensions of the data to plot. Defaults to c(1, 2), i.e. the first two columns of the tabular data.
device
The device on which training is done (e.g. "cpu" or "cuda"). Defaults to "cpu".
Value
gan_trainer trains the neural networks and returns an object of class trained_RGAN that contains the final generator and discriminator, their respective optimizers, and the settings.
Examples
## Not run:
# Before running the first time the torch backend needs to be installed
torch::install_torch()
# Load the package and the toy data
library(RGAN)
data <- sample_toydata()
# Build a new data transformer
transformer <- data_transformer$new()
# Fit transformer to data
transformer$fit(data)
# Transform data and store as new object
transformed_data <- transformer$transform(data)
# Train the default GAN
trained_gan <- gan_trainer(transformed_data)
# Sample synthetic data from the trained GAN
synthetic_data <- sample_synthetic_data(trained_gan, transformer)
# Plot the results
GAN_update_plot(data = data,
synth_data = synthetic_data,
main = "Real and Synthetic Data after Training")
## End(Not run)