FNN {gnn} | R Documentation |
Generative Moment Matching Network
Constructor for a generative feedforward neural network (FNN) model, an object of S3 class "gnn_FNN".
FNN(dim = c(2, 2), activation = c(rep("relu", length(dim) - 2), "sigmoid"),
batch.norm = FALSE, dropout.rate = 0, loss.fun = "MMD", n.GPU = 0, ...)
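The default for activation assigns "relu" to every hidden layer and "sigmoid" to the output layer, so it has length length(dim) - 1 (the input layer has no activation function). A quick base-R check of how this default expands for a few choices of dim; default_activation() is a hypothetical helper that merely copies the default expression from the usage line above.

```r
## Hypothetical helper: evaluate FNN()'s default 'activation' for a given 'dim'
default_activation <- function(dim) c(rep("relu", length(dim) - 2), "sigmoid")

default_activation(c(2, 2))           # no hidden layer:   "sigmoid"
default_activation(c(2, 300, 2))      # one hidden layer:  "relu" "sigmoid"
default_activation(c(2, 300, 100, 2)) # two hidden layers: "relu" "relu" "sigmoid"
```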
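dim | integer vector of length at least two giving the dimensions of the input layer, the hidden layer(s) (if any) and the output layer (in this order); for example, c(2, 300, 2) in the examples below specifies a bivariate input, one hidden layer with 300 neurons and a bivariate output. |
activation | character vector of length length(dim) - 1 specifying the activation functions of the hidden layer(s) and of the output layer (in this order); the input layer has no activation function. By default, "relu" is used for all hidden layers and "sigmoid" for the output layer. |
loss.fun | loss function used for training, specified as a character string; the examples use "MMD" (the default, which makes the model a GMMN), "MSE" and "CvM". |
batch.norm | logical indicating whether batch normalization layers are to be added (default FALSE). |
dropout.rate | numeric value in [0, 1] specifying the fraction of inputs to be dropped by dropout layers; the default 0 means no dropout. |
n.GPU | non-negative integer specifying the number of GPUs available for training (default 0). |
... | additional arguments passed to |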
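To make dim and activation concrete, here is a minimal base-R sketch of the map such a network computes. It assumes fully connected (dense) layers with bias terms, a standard normal prior and, for illustration only, randomly drawn weights; it is not how gnn builds the model (gnn uses keras), and fnn_forward() is a hypothetical helper. It shows how dim = c(d, 300, d) with the default activations turns a d-dimensional prior sample into outputs in (0, 1)^d.

```r
relu    <- function(x) { x[x < 0] <- 0; x } # keeps matrix dimensions
sigmoid <- function(x) 1 / (1 + exp(-x))

## Hypothetical helper: forward pass of a dense FNN with layer sizes 'dim'
## and one activation per non-input layer; weights and biases are random
fnn_forward <- function(Z, dim,
                        activation = c(rep("relu", length(dim) - 2), "sigmoid")) {
  stopifnot(ncol(Z) == dim[1], length(activation) == length(dim) - 1)
  act <- list(relu = relu, sigmoid = sigmoid)
  H <- Z
  for (l in seq_len(length(dim) - 1)) {
    W <- matrix(rnorm(dim[l] * dim[l + 1], sd = 1 / sqrt(dim[l])),
                nrow = dim[l], ncol = dim[l + 1]) # weight matrix of layer l
    b <- rnorm(dim[l + 1], sd = 0.1)             # bias vector of layer l
    H <- act[[activation[l]]](sweep(H %*% W, 2, b, "+")) # affine map, then activation
  }
  H
}

set.seed(1)
d <- 2
Z <- matrix(rnorm(1000 * d), ncol = d) # prior sample (assumed standard normal)
V <- fnn_forward(Z, dim = c(d, 300, d))
dim(V)   # 1000 x 2
range(V) # within (0, 1) because of the sigmoid output layer
```

Training (as done by fitGNN() in the examples) replaces the random weights by weights that minimize the chosen loss.fun.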
The S3 class "gnn_FNN" is a subclass of the class "gnn_GNN", which in turn is a subclass of
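In S3, being a subclass means that the object's class attribute lists the more specific class first, so a method written for "gnn_GNN" is also found for a "gnn_FNN" object unless a more specific method exists. A minimal sketch with a hypothetical generic describe() and a mock object (not a real gnn object; any further superclass is omitted):

```r
## Hypothetical generic and methods, for illustration only (not part of gnn)
describe <- function(x, ...) UseMethod("describe")
describe.gnn_GNN <- function(x, ...) "generic GNN method"
describe.gnn_FNN <- function(x, ...) paste("FNN-specific step, then:", NextMethod())

obj <- structure(list(), class = c("gnn_FNN", "gnn_GNN")) # mock object
inherits(obj, "gnn_GNN") # TRUE
describe(obj)            # "FNN-specific step, then: generic GNN method"
```

NextMethod() hands the call on to the next class in the class vector, which is how a subclass method can reuse the behavior of its parent class.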
FNN() returns an object of S3 class "gnn_FNN" with components
FNN model (a keras object inheriting from the R6 classes, or a raw object).
character string indicating the type of model.
dim | see above. |
see above.
see above.
see above.
number of trainable, non-trainable and total number of parameters.
type of loss function.
number of training samples (not available unless trained).
batch.size | batch size (not available unless trained). |
n.epoch | number of epochs (not available unless trained). |
loss | numeric vector containing the loss function values per epoch. |
time | object containing the training time (if trained). |
prior | (sub-)sample of the prior (if trained). |
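For a plain dense network, the number of trainable parameters follows from dim alone: the layer mapping dim[l] to dim[l + 1] neurons has dim[l] * dim[l + 1] weights plus dim[l + 1] biases. The sketch below assumes exactly this structure (no batch normalization, which in keras adds further, partly non-trainable, parameters), so it is only a cross-check for batch.norm = FALSE; count_dense_params() is a hypothetical helper.

```r
## Hypothetical helper: trainable parameters of a dense FNN with layer sizes 'dim'
count_dense_params <- function(dim) {
  n.in  <- head(dim, -1) # number of inputs of each dense layer
  n.out <- tail(dim, -1) # number of neurons of each dense layer
  sum(n.in * n.out + n.out) # weights + biases
}

count_dense_params(c(2, 300, 2)) # 1502 = (2*300 + 300) + (300*2 + 2)
```

Dropout layers have no parameters, so dropout.rate does not change this count.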
Marius Hofert and Avinash Prasad
Li, Y., Swersky, K. and Zemel, R. (2015). Generative moment matching networks. Proceedings of Machine Learning Research, 37 (International Conference on Machine Learning), 1718–1727. See http://proceedings.mlr.press/v37/li15.pdf (2019-08-24)
Dziugaite, G. K., Roy, D. M. and Ghahramani, Z. (2015). Training generative neural networks via maximum mean discrepancy optimization. AUAI Press, 258–267. See http://www.auai.org/uai2015/proceedings/papers/230.pdf (2019-08-24)
Hofert, M., Prasad, A. and Zhu, M. (2020). Quasi-random sampling for multivariate distributions via generative neural networks. Journal of Computational and Graphical Statistics, doi:10.1080/10618600.2020.1868302.
Hofert, M., Prasad, A. and Zhu, M. (2020). Multivariate time-series modeling with generative neural networks. See https://arxiv.org/abs/2002.10645.
Hofert, M., Prasad, A. and Zhu, M. (2020). Applications of multivariate quasi-random sampling with neural networks. See https://arxiv.org/abs/2012.08036.
if(TensorFlow_available()) { # rather restrictive (due to R-Forge, winbuilder)
library(gnn) # for being standalone
## Training data
d <- 2 # bivariate case
P <- matrix(0.9, nrow = d, ncol = d); diag(P) <- 1 # correlation matrix
ntrn <- 60000 # training data sample size
library(nvmix) # for rNorm()
X <- abs(rNorm(ntrn, scale = P)) # componentwise absolute values of N(0,P) sample
## Plot a subsample
m <- 2000 # subsample size for plots
opar <- par(pty = "s")
plot(X[1:m,], xlab = expression(X[1]), ylab = expression(X[2])) # plot |X|
U <- apply(X, 2, rank) / (ntrn + 1) # pseudo-observations of |X|
plot(U[1:m,], xlab = expression(U[1]), ylab = expression(U[2])) # visual check
## Model 1: A basic feedforward neural network (FNN) with MSE loss function
fnn <- FNN(c(d, 300, d), loss.fun = "MSE") # define the FNN
fnn <- fitGNN(fnn, data = U, n.epoch = 40) # train with batch optimization
plot(fnn, kind = "loss") # plot the loss after each epoch
## Model 2: A GMMN (FNN with MMD loss function)
gmmn <- FNN(c(d, 300, d)) # define the GMMN (initialized with random weights)
## For training we need to use a mini-batch optimization (batch size < nrow(U)).
## With batch.size = 500, each epoch consists of ntrn / 500 = 120 gradient
## steps, so 10 epochs amount to 1200 gradient steps (Model 1 uses 40 epochs
## of batch optimization instead, so the number of gradient steps differs).
library(keras) # for callback_early_stopping()
## We monitor the loss function and stop earlier if the loss function
## over the last patience-many epochs has changed by less than min_delta
## in absolute value. Then we keep the weights that led to the smallest
## loss seen throughout training.
gmmn <- fitGNN(gmmn, data = U, batch.size = 500, n.epoch = 10,
callbacks = callback_early_stopping(monitor = "loss",
min_delta = 1e-3, patience = 3,
restore_best_weights = TRUE))
plot(gmmn, kind = "loss") # plot the loss after each epoch
## Note:
## - Obviously, in a real-world application, batch.size and n.epoch
## should be (much) larger (e.g., batch.size = 5000, n.epoch = 300).
## - Training is not reproducible (due to keras).
## Model 3: A FNN with CvM loss function
fnnCvM <- FNN(c(d, 300, d), loss.fun = "CvM")
fnnCvM <- fitGNN(fnnCvM, data = U, batch.size = 500, n.epoch = 10,
callbacks = callback_early_stopping(monitor = "loss",
min_delta = 1e-3, patience = 3,
restore_best_weights = TRUE))
plot(fnnCvM, kind = "loss") # plot the loss after each epoch
## Sample from the different models
V.fnn <- rGNN(fnn, size = m)
V.gmmn <- rGNN(gmmn, size = m)
V.fnnCvM <- rGNN(fnnCvM, size = m)
## Joint plot of the training subsample with pseudo-random samples (PRNs) from
## the three trained models. The FNN trained with the MSE clearly fails to
## learn the distribution.
layout(matrix(1:4, ncol = 2, byrow = TRUE))
plot(U[1:m,], xlab = expression(U[1]), ylab = expression(U[2]), cex = 0.2)
mtext("Training subsample", side = 4, line = 0.4, adj = 0)
plot(V.fnn, xlab = expression(V[1]), ylab = expression(V[2]), cex = 0.2)
mtext("Trained NN with MSE loss", side = 4, line = 0.4, adj = 0)
plot(V.gmmn, xlab = expression(V[1]), ylab = expression(V[2]), cex = 0.2)
mtext("Trained NN with MMD loss", side = 4, line = 0.4, adj = 0)
plot(V.fnnCvM, xlab = expression(V[1]), ylab = expression(V[2]), cex = 0.2)
mtext("Trained NN with CvM loss", side = 4, line = 0.4, adj = 0)
## Joint plot of the training subsample with quasi-random samples (QRNs) from
## the three trained models
library(qrng) # for sobol()
V.fnn. <- rGNN(fnn, size = m, method = "sobol", randomize = "Owen")
V.gmmn. <- rGNN(gmmn, size = m, method = "sobol", randomize = "Owen")
V.fnnCvM. <- rGNN(fnnCvM, size = m, method = "sobol", randomize = "Owen")
plot(U[1:m,], xlab = expression(U[1]), ylab = expression(U[2]), cex = 0.2)
mtext("Training subsample", side = 4, line = 0.4, adj = 0)
plot(V.fnn., xlab = expression(V[1]), ylab = expression(V[2]), cex = 0.2)
mtext("Trained NN with MSE loss", side = 4, line = 0.4, adj = 0)
plot(V.gmmn., xlab = expression(V[1]), ylab = expression(V[2]), cex = 0.2)
mtext("Trained NN with MMD loss", side = 4, line = 0.4, adj = 0)
plot(V.fnnCvM., xlab = expression(V[1]), ylab = expression(V[2]), cex = 0.2)
mtext("Trained NN with CvM loss", side = 4, line = 0.4, adj = 0)
layout(1) # restore the single-panel layout
par(opar) # restore the graphical parameters
}
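Model 2 is a GMMN because it is trained with the MMD loss, which compares a generated sample with a training (mini-)batch through a kernel rather than point by point as the MSE does. Below is a minimal base-R sketch of the biased estimator of the squared MMD with a Gaussian kernel of a single, fixed bandwidth. The kernel(s) and bandwidth(s) used by gnn's own loss are not shown in this page and may differ, so mmd2() and gauss_kernel() are hypothetical helpers and the bandwidth 0.25 is an arbitrary choice.

```r
## Gaussian kernel matrix between the rows of A and the rows of B
gauss_kernel <- function(A, B, bandwidth) {
  D2 <- outer(rowSums(A^2), rowSums(B^2), "+") - 2 * A %*% t(B) # squared distances
  exp(-pmax(D2, 0) / (2 * bandwidth^2))
}

## Hypothetical helper: biased estimator of the squared MMD between samples X and Y
mmd2 <- function(X, Y, bandwidth = 0.25) {
  mean(gauss_kernel(X, X, bandwidth)) + mean(gauss_kernel(Y, Y, bandwidth)) -
    2 * mean(gauss_kernel(X, Y, bandwidth))
}

set.seed(271)
n  <- 500
X  <- matrix(runif(2 * n), ncol = 2) # independent uniforms
X2 <- matrix(runif(2 * n), ncol = 2) # a second sample from the same distribution
Y  <- cbind(X[, 1], X[, 1])          # comonotone sample (same margins, other dependence)
mmd2(X, X2) # small: same distribution
mmd2(X, Y)  # larger: same margins, different dependence structure
```

Because the MMD is zero only when the two distributions agree (for a characteristic kernel such as the Gaussian), minimizing it pushes the generated sample towards the whole distribution of the training data, including its dependence structure.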