gan_update_step {RGAN} | R Documentation
gan_update_step
Description
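Performs one update step of GAN training on a mini batch drawn from data: the discriminator d_net and the generator g_net are updated with their optimizers d_optim and g_optim, based on the losses returned by value_function, and the discriminator weights are clipped with weight_clipper.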
Usage
gan_update_step(
data,
batch_size,
noise_dim,
sample_noise,
device = "cpu",
g_net,
d_net,
g_optim,
d_optim,
value_function,
weight_clipper
)
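The arguments in the usage above map onto the usual structure of one GAN update step. The toy sketch below illustrates that structure in base R for a Wasserstein GAN on one-dimensional data. It assumes a linear critic and a linear generator, uses plain gradient descent in place of d_optim and g_optim, and derives the gradients by hand because base R has no autograd. The function toy_wgan_update_step and its params, lr and clip arguments are hypothetical, and the order of steps (discriminator update, weight clipping, generator update) is an assumption rather than a description of RGAN's internals.

```r
# Toy, base-R sketch of ONE Wasserstein GAN update step with weight clipping.
# Hypothetical: not RGAN's implementation. Gradients are derived by hand
# because base R has no autograd.
#   Critic:    f(x) = w * x          (a bias would cancel in the Wasserstein loss)
#   Generator: g(z) = z %*% a + b    (a has length noise_dim)
toy_wgan_update_step <- function(data, batch_size, noise_dim, sample_noise,
                                 params, lr = 0.05, clip = 0.01) {
  # 1. Draw a mini batch of batch_size rows from the (one-column) data matrix
  real <- data[sample(nrow(data), batch_size), 1]

  # 2. Sample noise and generate fake data with the current generator
  z <- matrix(sample_noise(batch_size * noise_dim), nrow = batch_size)
  fake <- as.vector(z %*% params$a) + params$b

  # 3. Critic step on d_loss = mean(f(fake)) - mean(f(real));
  #    its gradient with respect to w is mean(fake) - mean(real)
  params$w <- params$w - lr * (mean(fake) - mean(real))

  # 4. Weight clipping keeps the critic weight in [-clip, clip]
  params$w <- min(max(params$w, -clip), clip)

  # 5. Generator step on g_loss = -mean(f(g(z))) with fresh noise;
  #    gradients: d/da = -w * colMeans(z), d/db = -w
  z <- matrix(sample_noise(batch_size * noise_dim), nrow = batch_size)
  params$a <- params$a + lr * params$w * colMeans(z)
  params$b <- params$b + lr * params$w
  params
}
```

Calling the sketch repeatedly, feeding the returned params back in, corresponds to running several training iterations.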
Arguments
data | The input data set. Must be a matrix, array, torch::torch_tensor or torch::dataset.
batch_size | The number of training samples drawn into the mini batch for the update step.
noise_dim | The dimension of the GAN noise vector z.
sample_noise | A function that samples the noise vector z and returns it as a torch::torch_tensor (e.g. torch::torch_randn).
device | The device on which training is done (e.g. "cpu" or "cuda"). Defaults to "cpu".
g_net | The generator network, provided as a torch::nn_module.
d_net | The discriminator network, provided as a torch::nn_module.
g_optim | The optimizer for the generator network. Expects a torch optimizer (torch::optim_xxx), e.g. torch::optim_adam().
d_optim | The optimizer for the discriminator network. Expects a torch optimizer (torch::optim_xxx), e.g. torch::optim_adam().
value_function | The value function for GAN training. Expects a function that takes the discriminator scores of real and fake data as input and returns a list with the discriminator loss and the generator loss. For reference see: . For convenience, three value functions, "original", "wasserstein" and "f-wgan", are already implemented.
weight_clipper | Clips the weights of the discriminator during training. The Wasserstein GAN requires the discriminator (critic) to be Lipschitz continuous, which is enforced by clipping its weights to a fixed range.
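To make the value_function and weight_clipper contracts concrete, the base-R sketch below works on plain numeric vectors instead of torch tensors. Each value function takes real and fake discriminator scores (treated here as raw logits) and returns a list with d_loss and g_loss. All names are hypothetical, and the non-saturating generator loss in the "original"-style sketch is a common choice assumed here, not confirmed to match RGAN's "original" implementation.

```r
# Hypothetical base-R sketches of the value_function and weight_clipper
# contracts. Scores and weights are plain numeric vectors here; in RGAN they
# would be torch tensors. None of these names come from the package.

# log(sigmoid(x)) written as -log(1 + exp(-x)) for raw scores (logits)
log_sigmoid <- function(x) -log1p(exp(-x))

# "original"-style value function: binary cross-entropy on logits for the
# discriminator, non-saturating loss for the generator (an assumption)
original_value_sketch <- function(real_scores, fake_scores) {
  d_loss <- -mean(log_sigmoid(real_scores)) - mean(log_sigmoid(-fake_scores))
  g_loss <- -mean(log_sigmoid(fake_scores))
  list(d_loss = d_loss, g_loss = g_loss)
}

# "wasserstein"-style value function: the critic maximises the score gap
wasserstein_value_sketch <- function(real_scores, fake_scores) {
  list(
    d_loss = mean(fake_scores) - mean(real_scores),
    g_loss = -mean(fake_scores)
  )
}

# Weight clipper: clamps every weight vector in the list into [-clip, clip]
clip_weights_sketch <- function(weights, clip = 0.01) {
  lapply(weights, function(w) pmin(pmax(w, -clip), clip))
}
```

For example, original_value_sketch(0, 0) gives d_loss = 2 * log(2) and g_loss = log(2), the values for a discriminator that cannot tell real from fake. A torch version would replace mean(), pmin() and pmax() with the corresponding tensor operations.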
Value
A function