adaQN_free {stochQN}    R Documentation
adaQN Free-Mode Optimizer
Description
Optimizes an empirical (possibly non-convex) loss function over batches of sample data. Unlike the function/class 'adaQN', this version lets the user perform all calculations externally. The object is driven through a function that returns the type of calculation requested, and the result of that calculation is fed back through the methods 'update_gradient' and 'update_fun'.
Order in which requests are made:
    ========== loop ===========
    * calc_grad
    ... (repeat calc_grad)
    if max_incr > 0:
        * calc_fun_val_batch
    if 'use_grad_diff':
        * calc_grad_big_batch (skipped if below max_incr)
    ===========================
After creating the object with this function, call 'run_adaQN_free' on it to obtain the first request.
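A caller typically answers each request by dispatching on its 'task' field. The helper below is a hypothetical sketch, not part of stochQN: 'answer_request', 'grad_fn', 'fun_fn' and 'big_grad_fn' are illustrative names, and it assumes that the gradient requested by 'calc_grad_big_batch' is also passed back through 'update_gradient'. It only relies on the request fields 'task' and 'requested_on' used in the Examples section.

```r
## Hypothetical helper (not part of stochQN): maps a request returned by
## 'run_adaQN_free' to the calculation the caller must perform.
## 'grad_fn', 'fun_fn' and 'big_grad_fn' are user-supplied functions of 'x'.
answer_request <- function(req, grad_fn, fun_fn, big_grad_fn = grad_fn) {
  x <- req$requested_on
  switch(req$task,
    ## Regular gradient on the current batch
    calc_grad = list(update = "gradient", value = grad_fn(x)),
    ## Function value, requested only when 'max_incr' is not NULL
    calc_fun_val_batch = list(update = "fun", value = fun_fn(x)),
    ## Gradient on a larger batch, requested only with 'use_grad_diff = TRUE'
    ## (assumed to be fed back through 'update_gradient' as well)
    calc_grad_big_batch = list(update = "gradient", value = big_grad_fn(x)),
    stop("Unexpected task: ", req$task)
  )
}

## Example with a hand-built request list mimicking 'run_adaQN_free' output
req <- list(task = "calc_fun_val_batch", requested_on = c(1, 1))
ans <- answer_request(req,
                      grad_fn = function(x) 2 * x,
                      fun_fn = function(x) sum(x^2))
print(ans$update)
```

Inside a loop like the one in the Examples section, 'ans$update' then indicates whether to call update_gradient(optimizer, ans$value) or update_fun(optimizer, ans$value).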
Usage
adaQN_free(mem_size = 10, fisher_size = 100, bfgs_upd_freq = 20,
max_incr = 1.01, min_curvature = 1e-04, scal_reg = 1e-04,
rmsprop_weight = 0.9, y_reg = NULL, use_grad_diff = FALSE,
check_nan = TRUE, nthreads = -1)
Arguments
mem_size
Number of correction pairs to store for the approximation of Hessian-vector products.
fisher_size
Number of gradients to store for calculating products with the empirical Fisher matrix, which is built from those gradients. If 'NULL' is passed, 'use_grad_diff' is forced to 'TRUE'.
bfgs_upd_freq
Number of iterations (batches) after which to generate a BFGS correction pair.
max_incr
Maximum allowed ratio between the function value on the validation set evaluated at the average of 'x' over the current epoch and the same value evaluated at the average of 'x' over the previous epoch. If the ratio exceeds this threshold, the BFGS and Fisher memories are reset and 'x' is reverted to its previous average. Pass 'NULL' to disable function-increase checking.
min_curvature
Minimum value of (s * y) / (s * s), where '*' denotes the dot product, required to accept a correction pair. Pass 'NULL' for no minimum.
scal_reg
Regularization parameter to use in the denominator of the AdaGrad and RMSProp scaling.
rmsprop_weight
If not 'NULL', the RMSProp formula (with this weight) is used instead of AdaGrad for the initialization of the approximate inverse Hessian.
y_reg
Regularizer for the 'y' vector: y_reg * s is added to 'y'. Pass 'NULL' for no regularization.
use_grad_diff
Whether to create the correction pairs from differences between gradients instead of from the empirical Fisher matrix. These gradients are calculated on a larger batch than the regular ones (of size batch_size * bfgs_upd_freq). If 'TRUE', the empirical Fisher matrix is not used.
check_nan
Whether to check for variables becoming NaN after each iteration, and to revert the step if they do (this also resets the BFGS and Fisher memories).
nthreads
Number of parallel threads to use. If set to -1, the number of available threads is detected and all of them are used. Note, however, that not all computations can be parallelized, and the BLAS backend might use a different number of threads.
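How 'fisher_size', 'y_reg' and 'min_curvature' combine when forming a correction pair can be illustrated with the minimal base-R sketch below. It is not stochQN's internal code: the function name 'make_correction_pair', the arguments 'x_avg_new', 'x_avg_old' and 'grad_store', and the order in which 'y_reg' and the curvature test are applied are assumptions. As in the adaQN method, 's' is taken as the difference between averaged iterates and 'y' as the empirical Fisher matrix (built from the stored gradients) times 's'.

```r
## Sketch (hypothetical, not stochQN's internals) of forming and screening
## a correction pair (s, y) with the empirical Fisher matrix.
##  - 'x_avg_new', 'x_avg_old': averages of 'x' over the last two windows
##    of 'bfgs_upd_freq' iterations
##  - 'grad_store': matrix whose rows are the last 'fisher_size' gradients
make_correction_pair <- function(x_avg_new, x_avg_old, grad_store,
                                 y_reg = NULL, min_curvature = 1e-4) {
  s <- x_avg_new - x_avg_old
  ## Empirical Fisher product without forming the matrix:
  ## F s = (1/m) * sum_i g_i (g_i . s) = t(G) %*% (G %*% s) / m
  y <- as.numeric(crossprod(grad_store, grad_store %*% s)) / nrow(grad_store)
  ## Optional regularization: add y_reg * s to y (assumed to happen first)
  if (!is.null(y_reg)) y <- y + y_reg * s
  ## Curvature test: accept only if (s . y) / (s . s) >= min_curvature
  ss <- sum(s * s)
  if (ss == 0) return(NULL)
  if (!is.null(min_curvature) && sum(s * y) / ss < min_curvature) {
    return(NULL)
  }
  list(s = s, y = y)
}

## Two stored gradients (rows), iterate average moved from (0, 0) to (1, 1)
pair <- make_correction_pair(c(1, 1), c(0, 0), rbind(c(1, 0), c(0, 2)))
print(pair$y)
```

With 'use_grad_diff = TRUE', 'y' would instead come from a difference of big-batch gradients, and no gradient store would be needed.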
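The stored pairs ('mem_size') and the AdaGrad/RMSProp initialization ('rmsprop_weight', 'scal_reg') meet in the product of the approximate inverse Hessian with a gradient. The sketch below uses the standard L-BFGS two-loop recursion (Wright and Nocedal, ch. 7) with a diagonal initial matrix built from squared-gradient accumulators. All function names are hypothetical, and adding 'scal_reg' inside the square root is an assumption rather than a documented detail of stochQN.

```r
## Accumulate squared gradients: AdaGrad sums them, while RMSProp (when
## 'rmsprop_weight' is not NULL) keeps an exponentially weighted average.
update_accum <- function(accum, g, rmsprop_weight = NULL) {
  if (is.null(rmsprop_weight)) {
    accum + g^2
  } else {
    rmsprop_weight * accum + (1 - rmsprop_weight) * g^2
  }
}

## Diagonal of the initial inverse-Hessian approximation H0
## (assumption: 'scal_reg' is added inside the square root)
h0_diag <- function(accum, scal_reg = 1e-4) 1 / sqrt(accum + scal_reg)

## L-BFGS two-loop recursion returning an approximation of H %*% g,
## using H0 = diag(h0) instead of a scaled identity.
## 's_list', 'y_list': stored pairs, oldest first (at most 'mem_size').
two_loop <- function(g, s_list, y_list, h0) {
  k <- length(s_list)
  q <- g
  alpha <- numeric(k)
  rho <- numeric(k)
  for (i in rev(seq_len(k))) {          # newest to oldest
    rho[i] <- 1 / sum(y_list[[i]] * s_list[[i]])
    alpha[i] <- rho[i] * sum(s_list[[i]] * q)
    q <- q - alpha[i] * y_list[[i]]
  }
  r <- h0 * q                           # apply the diagonal H0
  for (i in seq_len(k)) {               # oldest to newest
    beta <- rho[i] * sum(y_list[[i]] * r)
    r <- r + s_list[[i]] * (alpha[i] - beta)
  }
  r
}

accum <- update_accum(c(0, 0), c(1, 2))
h0 <- h0_diag(accum)
direction <- two_loop(c(0.5, 2), list(c(1, 1)), list(c(0.5, 2)), h0)
print(direction)
```

A step would then take the form x <- x - step_size * two_loop(g, s_list, y_list, h0_diag(accum, scal_reg)). With a single pair, the recursion maps 'y' back to 's' (the secant condition), which is why 'direction' above equals c(1, 1).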
Value
An 'adaQN_free' object, which is used through the functions 'update_gradient', 'update_fun', and 'run_adaQN_free'.
References
Keskar, N.S. and Berahas, A.S., 2016, September. "adaQN: An Adaptive Quasi-Newton Algorithm for Training RNNs." In Joint European Conference on Machine Learning and Knowledge Discovery in Databases (pp. 1-16). Springer, Cham.
Wright, S. and Nocedal, J., 1999. "Numerical Optimization" (ch. 7). Springer Science.
See Also
update_gradient, update_fun, run_adaQN_free
Examples
### Example optimizing Rosenbrock 2D function
### Note that this example is not stochastic, as the
### function is not evaluated in expectation based on
### batches of data, but rather it has a given absolute
### form that never varies.
library(stochQN)
fr <- function(x) { ## Rosenbrock Banana function
  x1 <- x[1]
  x2 <- x[2]
  100 * (x2 - x1 * x1)^2 + (1 - x1)^2
}
grr <- function(x) { ## Gradient of 'fr'
  x1 <- x[1]
  x2 <- x[2]
  c(-400 * x1 * (x2 - x1 * x1) - 2 * (1 - x1),
     200 * (x2 - x1 * x1))
}
### Initial values of x (a numeric vector, which
### 'run_adaQN_free' updates in place)
x_opt = as.numeric(c(0, 2))
cat(sprintf("Initial values of x: [%.3f, %.3f]\n",
            x_opt[1], x_opt[2]))
### Will use constant step size throughout
### (not recommended)
step_size = 1e-2
### Initialize the optimizer
optimizer = adaQN_free()
### Keep track of the iteration number
curr_iter <- 0
### Run a loop for many iterations
### (Note that some iterations might require more
### than 1 calculation request)
for (i in 1:200) {
  req <- run_adaQN_free(optimizer, x_opt, step_size)
  if (req$task == "calc_grad") {
    update_gradient(optimizer, grr(req$requested_on))
  } else if (req$task == "calc_fun_val_batch") {
    update_fun(optimizer, fr(req$requested_on))
  }

  ### Track progress every 10 iterations
  ### (printed once, when a new iteration number is reached)
  if (req$info$iteration_number > curr_iter) {
    curr_iter <- req$info$iteration_number
    if ((curr_iter %% 10) == 0) {
      cat(sprintf(
        "Iteration %3d - Current function value: %.3f\n",
        req$info$iteration_number, fr(x_opt)))
    }
  }
}
cat(sprintf("Current values of x: [%.3f, %.3f]\n",
            x_opt[1], x_opt[2]))