FNN {gnn}    R Documentation
Generative Moment Matching Network
Description
Constructor for a generative feedforward neural network (FNN) model, an object of S3 class "gnn_FNN".
Usage
FNN(dim = c(2, 2), activation = c(rep("relu", length(dim) - 2), "sigmoid"),
batch.norm = FALSE, dropout.rate = 0, loss.fun = "MMD", n.GPU = 0, ...)
Arguments
dim: integer vector of length at least two giving the dimensions of the input layer, the hidden layer(s) (if any) and the output layer (in this order).
activation: character vector of length length(dim) - 1 giving the activation functions of the hidden layer(s) and the output layer (in this order); by default "relu" for all hidden layers and "sigmoid" for the output layer.
loss.fun: loss function; one of "MMD" (maximum mean discrepancy, the default), "CvM" (Cramér-von Mises) or "MSE" (mean squared error).
batch.norm: logical indicating whether a batch normalization layer is added after each hidden layer.
dropout.rate: number in [0, 1] giving the fraction of units dropped; if positive, a dropout layer with this rate is added after each hidden layer.
n.GPU: non-negative integer giving the number of GPUs to be used (if the GPU version of TensorFlow is installed).
...: additional arguments passed to the underlying loss function; see loss.fun.
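For instance (a sketch with illustrative argument values, not defaults):

## 3d input/output, two hidden layers of 100 units each, batch normalization
## and 10% dropout after each hidden layer, trained under the MMD loss
fnn3d <- FNN(dim = c(3, 100, 100, 3),
             activation = c("relu", "relu", "sigmoid"),
             batch.norm = TRUE, dropout.rate = 0.1, loss.fun = "MMD")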
Details
The S3 class "gnn_FNN" is a subclass of the S3 class "gnn_GNN" which in turn is a subclass of "gnn_Model".
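This hierarchy can be checked directly (a minimal sketch; it assumes a working keras/TensorFlow installation so that FNN() can build the underlying model):

fnn <- FNN() # default bivariate 2-2 network
class(fnn) # expected to include "gnn_FNN", "gnn_GNN", "gnn_Model"
inherits(fnn, "gnn_GNN") # TRUE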
Value
FNN() returns an object of S3 class "gnn_FNN" with components
model: FNN model (a keras object inheriting from the R6 classes "keras.engine.training.Model", "keras.engine.network.Network", "keras.engine.base_layer.Layer" and "python.builtin.object", or a raw object).
type: character string indicating the type of model.
dim: see above.
activation: see above.
batch.norm: see above.
dropout.rate: see above.
n.param: numbers of trainable, non-trainable and total parameters.
loss.type: type of loss function (character).
n.train: number of training samples (NA_integer_ unless trained).
batch.size: batch size (NA_integer_ unless trained).
n.epoch: number of epochs (NA_integer_ unless trained).
loss: numeric(n.epoch) containing the loss function values per epoch.
time: object of S3 class "proc_time" containing the training time (if trained).
prior: matrix containing a (sub-)sample of the prior (if trained).
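These components can be inspected as for any list-like S3 object (a sketch, assuming fnn was constructed as above; before training, the training-related components are NA_integer_):

str(fnn, max.level = 1) # overview of all components
fnn[["n.param"]] # trainable, non-trainable and total parameter counts
fnn[["n.train"]] # NA_integer_ for an untrained model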
Author(s)
Marius Hofert and Avinash Prasad
References
Li, Y., Swersky, K. and Zemel, R. (2015). Generative moment matching networks. Proceedings of Machine Learning Research, 37 (International Conference on Machine Learning), 1718–1727. See http://proceedings.mlr.press/v37/li15.pdf (2019-08-24).
Dziugaite, G. K., Roy, D. M. and Ghahramani, Z. (2015). Training generative neural networks via maximum mean discrepancy optimization. AUAI Press, 258–267. See http://www.auai.org/uai2015/proceedings/papers/230.pdf (2019-08-24).
Hofert, M., Prasad, A. and Zhu, M. (2020). Quasi-random sampling for multivariate distributions via generative neural networks. Journal of Computational and Graphical Statistics, doi:10.1080/10618600.2020.1868302.
Hofert, M., Prasad, A. and Zhu, M. (2020). Multivariate time-series modeling with generative neural networks. See https://arxiv.org/abs/2002.10645.
Hofert, M., Prasad, A. and Zhu, M. (2020). Applications of multivariate quasi-random sampling with neural networks. See https://arxiv.org/abs/2012.08036.
Examples
if(TensorFlow_available()) { # rather restrictive (due to R-Forge, winbuilder)
library(gnn) # for being standalone
## Training data
d <- 2 # bivariate case
P <- matrix(0.9, nrow = d, ncol = d); diag(P) <- 1 # correlation matrix
ntrn <- 60000 # training data sample size
set.seed(271)
library(nvmix)
X <- abs(rNorm(ntrn, scale = P)) # componentwise absolute values of N(0,P) sample
## Plot a subsample
m <- 2000 # subsample size for plots
opar <- par(pty = "s")
plot(X[1:m,], xlab = expression(X[1]), ylab = expression(X[2])) # plot |X|
U <- apply(X, 2, rank) / (ntrn + 1) # pseudo-observations of |X|
plot(U[1:m,], xlab = expression(U[1]), ylab = expression(U[2])) # visual check
## Model 1: A basic feedforward neural network (FNN) with MSE loss function
fnn <- FNN(c(d, 300, d), loss.fun = "MSE") # define the FNN
fnn <- fitGNN(fnn, data = U, n.epoch = 40) # train with batch optimization
plot(fnn, kind = "loss") # plot the loss after each epoch
## Model 2: A GMMN (FNN with MMD loss function)
gmmn <- FNN(c(d, 300, d)) # define the GMMN (initialized with random weights)
## For training we use mini-batch optimization (batch size < nrow(U)).
## For a roughly fair comparison (in terms of gradient steps) to the FNN
## above, we train the GMMN with mini-batches of size 500 for 10 epochs.
library(keras) # for callback_early_stopping()
## We monitor the loss function and stop early if the loss function
## over the last patience-many epochs has changed by less than min_delta
## in absolute value. Then we keep the weights that led to the smallest
## loss seen throughout training.
gmmn <- fitGNN(gmmn, data = U, batch.size = 500, n.epoch = 10,
callbacks = callback_early_stopping(monitor = "loss",
min_delta = 1e-3, patience = 3,
restore_best_weights = TRUE))
plot(gmmn, kind = "loss") # plot the loss after each epoch
## Note:
## - Obviously, in a real-world application, batch.size and n.epoch
## should be (much) larger (e.g., batch.size = 5000, n.epoch = 300).
## - Training is not reproducible (due to keras).
## Model 3: An FNN with CvM loss function
fnnCvM <- FNN(c(d, 300, d), loss.fun = "CvM")
fnnCvM <- fitGNN(fnnCvM, data = U, batch.size = 500, n.epoch = 10,
callbacks = callback_early_stopping(monitor = "loss",
min_delta = 1e-3, patience = 3,
restore_best_weights = TRUE))
plot(fnnCvM, kind = "loss") # plot the loss after each epoch
## Sample from the different models
set.seed(271)
V.fnn <- rGNN(fnn, size = m)
set.seed(271)
V.gmmn <- rGNN(gmmn, size = m)
set.seed(271)
V.fnnCvM <- rGNN(fnnCvM, size = m)
## Joint plot of the training subsample with the PRNs generated from the
## three trained models. Clearly, the MSE loss cannot be used to learn
## the distribution correctly.
layout(matrix(1:4, ncol = 2, byrow = TRUE))
plot(U[1:m,], xlab = expression(U[1]), ylab = expression(U[2]), cex = 0.2)
mtext("Training subsample", side = 4, line = 0.4, adj = 0)
plot(V.fnn, xlab = expression(V[1]), ylab = expression(V[2]), cex = 0.2)
mtext("Trained NN with MSE loss", side = 4, line = 0.4, adj = 0)
plot(V.gmmn, xlab = expression(V[1]), ylab = expression(V[2]), cex = 0.2)
mtext("Trained NN with MMD loss", side = 4, line = 0.4, adj = 0)
plot(V.fnnCvM, xlab = expression(V[1]), ylab = expression(V[2]), cex = 0.2)
mtext("Trained NN with CvM loss", side = 4, line = 0.4, adj = 0)
## Joint plot of the training subsample with QRNs from the three trained models
library(qrng) # for sobol()
V.fnn. <- rGNN(fnn, size = m, method = "sobol", randomize = "Owen")
V.gmmn. <- rGNN(gmmn, size = m, method = "sobol", randomize = "Owen")
V.fnnCvM. <- rGNN(fnnCvM, size = m, method = "sobol", randomize = "Owen")
plot(U[1:m,], xlab = expression(U[1]), ylab = expression(U[2]), cex = 0.2)
mtext("Training subsample", side = 4, line = 0.4, adj = 0)
plot(V.fnn., xlab = expression(V[1]), ylab = expression(V[2]), cex = 0.2)
mtext("Trained NN with MSE loss", side = 4, line = 0.4, adj = 0)
plot(V.gmmn., xlab = expression(V[1]), ylab = expression(V[2]), cex = 0.2)
mtext("Trained NN with MMD loss", side = 4, line = 0.4, adj = 0)
plot(V.fnnCvM., xlab = expression(V[1]), ylab = expression(V[2]), cex = 0.2)
mtext("Trained NN with CvM loss", side = 4, line = 0.4, adj = 0)
layout(1)
par(opar)
}