gan_update_step {RGAN}    R Documentation

gan_update_step

Description

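Performs a single update step of GAN training. A mini batch of real data and a batch of noise are sampled, and the discriminator and generator networks are updated with their respective optimizers according to the value function.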
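A single GAN update step, as implied by the arguments below, can be sketched with the torch package for R. This is a minimal illustration under stated assumptions, not RGAN's internal code. It assumes that sample_noise receives a shape vector, that value_function returns a list with elements d_loss and g_loss, and that weight_clipper is called on the discriminator after its update. The helper sketch_update_step() and all example inputs are hypothetical.

```r
library(torch)

# Hypothetical sketch of one GAN update step (not RGAN's implementation)
sketch_update_step <- function(real_data, batch_size, noise_dim, sample_noise,
                               device = "cpu", g_net, d_net, g_optim, d_optim,
                               value_function, weight_clipper) {
  # Mini batch of real observations (rows sampled without replacement)
  idx <- sample(real_data$shape[1], batch_size)
  real_batch <- real_data[idx, ]$to(device = device)

  # Discriminator step: fake data is detached so only d_net receives gradients
  noise <- sample_noise(c(batch_size, noise_dim))$to(device = device)
  fake_batch <- g_net(noise)$detach()
  d_loss <- value_function(d_net(real_batch), d_net(fake_batch))$d_loss
  d_optim$zero_grad()
  d_loss$backward()
  d_optim$step()
  weight_clipper(d_net)  # e.g. WGAN weight clipping; may be a no-op

  # Generator step: fresh noise, gradients flow through d_net back to g_net
  noise <- sample_noise(c(batch_size, noise_dim))$to(device = device)
  g_loss <- value_function(d_net(real_batch), d_net(g_net(noise)))$g_loss
  g_optim$zero_grad()
  g_loss$backward()
  g_optim$step()

  list(d_loss = d_loss$item(), g_loss = g_loss$item())
}

# Example inputs (all hypothetical)
torch_manual_seed(1)
set.seed(1)
real_data <- torch_randn(100, 2)
g_net <- nn_sequential(nn_linear(2, 16), nn_relu(), nn_linear(16, 2))
d_net <- nn_sequential(nn_linear(2, 16), nn_relu(), nn_linear(16, 1))

# Binary cross-entropy losses on discriminator logits
bce_value_fct <- function(real_scores, fake_scores) {
  list(
    d_loss = nnf_binary_cross_entropy_with_logits(real_scores, torch_ones_like(real_scores)) +
      nnf_binary_cross_entropy_with_logits(fake_scores, torch_zeros_like(fake_scores)),
    g_loss = nnf_binary_cross_entropy_with_logits(fake_scores, torch_ones_like(fake_scores))
  )
}

losses <- sketch_update_step(
  real_data, batch_size = 50, noise_dim = 2, sample_noise = torch_randn,
  g_net = g_net, d_net = d_net,
  g_optim = optim_adam(g_net$parameters, lr = 1e-4),
  d_optim = optim_adam(d_net$parameters, lr = 1e-4),
  value_function = bce_value_fct,
  weight_clipper = function(d_net) invisible(NULL)
)
```

The discriminator is updated first on detached fake data, and the generator is then updated through the (just updated) discriminator. This is the usual alternating scheme, but the exact ordering inside gan_update_step is not stated on this page.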

Usage

gan_update_step(
  data,
  batch_size,
  noise_dim,
  sample_noise,
  device = "cpu",
  g_net,
  d_net,
  g_optim,
  d_optim,
  value_function,
  weight_clipper
)

Arguments

data

The input data set. Must be a matrix, array, torch::torch_tensor or torch::dataset.
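As an illustration, a numeric matrix can be converted to a torch tensor with torch::torch_tensor(). The example data below are made up.

```r
library(torch)

# Made-up example data: 100 observations of 2 variables
x <- matrix(rnorm(200), ncol = 2)

# Convert the matrix to a 32-bit float tensor, the dtype torch layers use by default
x_tensor <- torch_tensor(x, dtype = torch_float())
x_tensor$shape
```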

batch_size

The number of training samples drawn into each mini batch.
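Drawing a mini batch amounts to sampling rows of the data without replacement. The following is a minimal base R sketch, in which sample_batch() is a hypothetical helper.

```r
# Hypothetical helper: draw `batch_size` rows without replacement
sample_batch <- function(data, batch_size) {
  idx <- sample(nrow(data), batch_size)
  data[idx, , drop = FALSE]
}

set.seed(42)
x <- matrix(rnorm(200), ncol = 2)  # 100 observations, 2 columns
batch <- sample_batch(x, batch_size = 50)
dim(batch)
```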

noise_dim

The dimension of the GAN noise vector z.

sample_noise

A function that samples noise as a torch::torch_tensor.
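A noise sampler can be as simple as a wrapper around torch::torch_randn(). The sketch assumes that the sampler receives the desired shape (batch size by noise dimension) as a vector. The helpers normal_noise() and uniform_noise() are hypothetical.

```r
library(torch)

# Hypothetical noise samplers: each takes a shape vector such as c(batch_size, noise_dim)
normal_noise <- function(size) torch_randn(size)          # standard normal
uniform_noise <- function(size) torch_rand(size) * 2 - 1  # uniform on [-1, 1)

z <- normal_noise(c(50, 2))
u <- uniform_noise(c(50, 2))
```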

device

The device on which training is done (e.g. "cpu" or "cuda"). Defaults to "cpu".
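One common pattern is to choose "cuda" only when a GPU is available, which can be checked with torch::cuda_is_available().

```r
library(torch)

# Use the GPU when one is visible to torch, otherwise fall back to the CPU
device <- if (cuda_is_available()) "cuda" else "cpu"

# Tensors and networks must live on the same device during training
x <- torch_randn(4, 2)$to(device = device)
```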

g_net

The generator network. Expects a neural network provided as a torch::nn_module.

d_net

The discriminator network. Expects a neural network provided as a torch::nn_module.
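The following sketch builds small fully connected networks with torch::nn_sequential(). The layer sizes and activations are arbitrary assumptions. The only structural requirement it illustrates is that the generator maps noise_dim inputs to the data dimension, while the discriminator maps the data dimension to a single score.

```r
library(torch)

noise_dim <- 2  # length of the noise vector z
data_dim <- 2   # number of columns in the data

# Generator: noise vector -> synthetic observation
g_net <- nn_sequential(
  nn_linear(noise_dim, 32), nn_relu(),
  nn_linear(32, data_dim)
)

# Discriminator: observation -> one unbounded score (logit)
d_net <- nn_sequential(
  nn_linear(data_dim, 32), nn_leaky_relu(0.2),
  nn_linear(32, 1)
)

scores <- d_net(g_net(torch_randn(5, noise_dim)))
```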

g_optim

The optimizer for the generator network. Expects an optimizer created by a torch::optim_xxx function, e.g. torch::optim_adam().

d_optim

The optimizer for the discriminator network. Expects an optimizer created by a torch::optim_xxx function, e.g. torch::optim_adam().
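The two optimizer objects could be set up as shown below. The names base_lr and ttur_factor, and their values, are illustrative. Giving the discriminator a larger learning rate than the generator corresponds to a two time-scale update rule (TTUR).

```r
library(torch)

g_net <- nn_linear(2, 2)  # stand-ins for real generator/discriminator networks
d_net <- nn_linear(2, 1)

base_lr <- 1e-4   # illustrative generator learning rate
ttur_factor <- 4  # illustrative ratio of discriminator to generator learning rate

g_optim <- optim_adam(g_net$parameters, lr = base_lr)
d_optim <- optim_adam(d_net$parameters, lr = base_lr * ttur_factor)
```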

value_function

The value function for GAN training. Expects a function that takes the discriminator scores of real and fake data as input and returns a list with the discriminator loss and the generator loss. For reference see: . For convenience, three loss functions, "original", "wasserstein" and "f-wgan", are already implemented.
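A user-supplied value function might look like the sketches below. They assume that the discriminator returns raw logits and that the returned list uses the element names d_loss and g_loss. Both functions are illustrative and are not the package's own "original" or "wasserstein" implementations. The first function uses the non-saturating form of the original GAN generator loss.

```r
library(torch)

# Original GAN losses via binary cross-entropy on logits (non-saturating generator loss)
original_value_fct <- function(real_scores, fake_scores) {
  d_loss <- nnf_binary_cross_entropy_with_logits(real_scores, torch_ones_like(real_scores)) +
    nnf_binary_cross_entropy_with_logits(fake_scores, torch_zeros_like(fake_scores))
  g_loss <- nnf_binary_cross_entropy_with_logits(fake_scores, torch_ones_like(fake_scores))
  list(d_loss = d_loss, g_loss = g_loss)
}

# Wasserstein losses: the critic widens the score gap, the generator raises fake scores
wasserstein_value_fct <- function(real_scores, fake_scores) {
  d_loss <- torch_mean(fake_scores) - torch_mean(real_scores)
  g_loss <- -torch_mean(fake_scores)
  list(d_loss = d_loss, g_loss = g_loss)
}

v <- original_value_fct(torch_zeros(4, 1), torch_zeros(4, 1))
w <- wasserstein_value_fct(torch_ones(4, 1), torch_zeros(4, 1))
```

With all scores at zero, each cross-entropy term equals log(2). The original discriminator loss is therefore 2 * log(2) and the generator loss is log(2).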

weight_clipper

The Wasserstein GAN places constraints on the weights of the discriminator, so the weights are clipped during training.
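A weight clipper in the style of the original Wasserstein GAN clamps every discriminator parameter to a fixed interval. The helper clip_weights() and the bound 0.01 are illustrative assumptions.

```r
library(torch)

# Hypothetical clipper: clamp each discriminator parameter in place to [-clip_value, clip_value]
clip_weights <- function(d_net, clip_value = 0.01) {
  with_no_grad({
    for (p in d_net$parameters) {
      p$clamp_(min = -clip_value, max = clip_value)
    }
  })
  invisible(d_net)
}

d_net <- nn_linear(2, 1)
clip_weights(d_net)
```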

Value

A function


[Package RGAN version 0.1.1 Index]