adaQN_free {stochQN}    R Documentation

adaQN Free-Mode Optimizer

Description

Optimizes an empirical (perhaps non-convex) loss function over batches of sample data. Compared to the function/class 'adaQN', this version lets the user perform all calculations externally. The user interacts with the object only through a function that returns a request type ('run_adaQN_free') and through the methods 'update_gradient' and 'update_fun', which feed it the requested calculation.

Order in which requests are made:

========== loop ===========

* calc_grad
   ... (repeat calc_grad)
if max_incr > 0:
   * calc_fun_val_batch
if 'use_grad_diff':
   * calc_grad_big_batch (skipped if below max_incr)

===========================
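Each request in the loop above is answered by evaluating the requested quantity at 'req$requested_on' and passing the result back to the optimizer. The following is a minimal sketch of a dispatcher covering all three request types, assuming 'stochQN' is attached. The helper 'answer_request' and its arguments 'grad_fn', 'fun_fn' and 'big_grad_fn' are hypothetical, and routing the 'calc_grad_big_batch' result through 'update_gradient' is an assumption that this page does not confirm.

```r
## Hypothetical helper (not part of stochQN): answers one request
## returned by run_adaQN_free(). Assumes the 'stochQN' package is
## attached, so that update_gradient() and update_fun() are available.
##   grad_fn     : gradient on the current batch, evaluated at x
##   fun_fn      : function value on the validation batch, evaluated at x
##   big_grad_fn : gradient on the larger batch (only requested when
##                 'use_grad_diff' is TRUE)
answer_request <- function(optimizer, req, grad_fn, fun_fn,
                           big_grad_fn = grad_fn) {
  x <- req$requested_on
  switch(req$task,
    calc_grad = update_gradient(optimizer, grad_fn(x)),
    calc_fun_val_batch = update_fun(optimizer, fun_fn(x)),
    ## Assumption: the big-batch gradient is also passed via update_gradient()
    calc_grad_big_batch = update_gradient(optimizer, big_grad_fn(x)),
    stop("Unexpected request type: ", req$task)
  )
  ## Return the task name invisibly so callers can log it if desired
  invisible(req$task)
}
```

With the Rosenbrock functions 'fr' and 'grr' defined in the Examples section, one step of the loop there could be written as 'answer_request(optimizer, run_adaQN_free(optimizer, x_opt, step_size), grr, fr)'.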

After creating the optimizer with this function, call 'run_adaQN_free' on it to obtain the first request.

Usage

adaQN_free(mem_size = 10, fisher_size = 100, bfgs_upd_freq = 20,
  max_incr = 1.01, min_curvature = 1e-04, scal_reg = 1e-04,
  rmsprop_weight = 0.9, y_reg = NULL, use_grad_diff = FALSE,
  check_nan = TRUE, nthreads = -1)

Arguments

mem_size

Number of correction pairs to store for approximation of Hessian-vector products.

fisher_size

Number of gradients to store for calculating the product of the empirical Fisher matrix with gradients. Passing 'NULL' forces 'use_grad_diff' to 'TRUE'.

bfgs_upd_freq

Number of iterations (batches) after which to generate a BFGS correction pair.

max_incr

Maximum allowed ratio between the function value on the validation set evaluated at the average of 'x' over the current epoch and the value evaluated at the average of 'x' over the previous epoch. If the ratio exceeds this threshold, the BFGS and Fisher memories are reset and 'x' is reverted to its previous average. Pass 'NULL' to disable function-increase checking.

min_curvature

Minimum value of (s' y) / (s' s), i.e. the dot product of the correction vectors 's' and 'y' divided by the squared norm of 's', required to accept a correction pair. Pass 'NULL' for no minimum.

scal_reg

Regularization parameter to use in the denominator for AdaGrad and RMSProp scaling.

rmsprop_weight

If not 'NULL', the RMSProp formula is used instead of AdaGrad to initialize the approximate inverse Hessian.

y_reg

Regularizer for the 'y' vector: 'y_reg * s' is added to 'y'. Pass 'NULL' for no regularization.

use_grad_diff

Whether to create the correction pairs from differences between gradients instead of from the empirical Fisher matrix. These gradients are calculated on a larger batch than the regular ones (of size batch_size * bfgs_upd_freq, where batch_size is the size of the regular batches). If 'TRUE', the empirical Fisher matrix will not be used.

check_nan

Whether to check after each iteration if any variable has become NaN, and to revert the step if so (this also resets the BFGS and Fisher memories).

nthreads

Number of parallel threads to use. If set to -1, will determine the number of available threads and use all of them. Note however that not all the computations can be parallelized, and the BLAS backend might use a different number of threads.
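Several of the arguments above control simple formulas. The sketch below writes them out as stand-alone R functions to make the roles of 'fisher_size', 'y_reg', 'min_curvature', 'scal_reg', 'rmsprop_weight' and 'max_incr' concrete. All function names are hypothetical, the formulas are standard textbook forms rather than stochQN's internal code, and two details (applying 'y_reg' before the curvature check, and adding 'scal_reg' under the square root) are assumptions.

```r
## Hypothetical helpers (not part of stochQN) sketching the textbook
## formulas behind several adaQN_free() arguments. They are NOT the
## package's internal code; assumed details are marked as such.

## fisher_size: empirical Fisher matrix F = (1/m) * sum_i g_i g_i'
## built from the m stored gradients (the rows of G). Returns F %*% v
## without forming F explicitly.
fisher_times <- function(G, v) {
  as.numeric(crossprod(G, G %*% v)) / nrow(G)
}

## y_reg and min_curvature: add y_reg * s to y, then accept the pair
## (s, y) only if s'y / s's >= min_curvature (NULL disables the check).
## Assumption: the regularization is applied before the curvature check.
check_correction_pair <- function(s, y, min_curvature, y_reg = NULL) {
  if (!is.null(y_reg)) y <- y + y_reg * s
  accepted <- is.null(min_curvature) ||
    (sum(s * y) / sum(s * s)) >= min_curvature
  list(accepted = accepted, y = y)
}

## scal_reg and rmsprop_weight: running accumulator of squared gradients
## and the diagonal scaling derived from it.
## Assumption: scal_reg is added under the square root.
update_scaling <- function(acc, g, scal_reg, rmsprop_weight = NULL) {
  if (is.null(rmsprop_weight)) {
    acc <- acc + g^2                                           # AdaGrad: running sum
  } else {
    acc <- rmsprop_weight * acc + (1 - rmsprop_weight) * g^2   # RMSProp: moving average
  }
  list(acc = acc, scaling = 1 / sqrt(acc + scal_reg))
}

## max_incr: TRUE if the validation value at this epoch's average 'x'
## exceeds max_incr times the value at the previous epoch's average,
## which would trigger a memory reset. Assumes positive function values.
needs_reset <- function(f_curr, f_prev, max_incr) {
  if (is.null(max_incr)) return(FALSE)  # NULL disables the check
  (f_curr / f_prev) > max_incr
}
```

For example, 'check_correction_pair(c(1, 0), c(2, 0), 1e-4)$accepted' is TRUE, since s' y / s' s = 2 exceeds the default 'min_curvature'.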

Value

An 'adaQN_free' object, which can be used through the functions 'update_gradient', 'update_fun', and 'run_adaQN_free'.

See Also

update_gradient, update_fun, run_adaQN_free

Examples

### Example optimizing Rosenbrock 2D function
### Note that this example is not stochastic, as the
### function is not evaluated in expectation based on
### batches of data, but rather it has a given absolute
### form that never varies.
library(stochQN)


fr <- function(x) { ## Rosenbrock Banana function
  x1 <- x[1]
  x2 <- x[2]
  100 * (x2 - x1 * x1)^2 + (1 - x1)^2
}
grr <- function(x) { ## Gradient of 'fr'
  x1 <- x[1]
  x2 <- x[2]
  c(-400 * x1 * (x2 - x1 * x1) - 2 * (1 - x1),
    200 * (x2 - x1 * x1))
}


### Initial values of x
### ('x_opt' is modified in place by 'run_adaQN_free')
x_opt <- as.numeric(c(0, 2))
cat(sprintf("Initial values of x: [%.3f, %.3f]\n",
            x_opt[1], x_opt[2]))

### Will use constant step size throughout
### (not recommended)
step_size <- 1e-2

### Initialize the optimizer
optimizer <- adaQN_free()

### Keep track of the iteration number
curr_iter <- 0

### Run a loop for many iterations
### (Note that some iterations might require more
###  than 1 calculation request)
for (i in 1:200) {
  req <- run_adaQN_free(optimizer, x_opt, step_size)
  if (req$task == "calc_grad") {
    update_gradient(optimizer, grr(req$requested_on))
  } else if (req$task == "calc_fun_val_batch") {
    update_fun(optimizer, fr(req$requested_on))
  }

  ### Track progress every 10 iterations
  ### (print only once, when a new iteration number is reached,
  ###  since one iteration can span several requests)
  if (req$info$iteration_number > curr_iter) {
    curr_iter <- req$info$iteration_number
    if ((curr_iter %% 10) == 0) {
      cat(sprintf(
        "Iteration %3d - Current function value: %.3f\n",
        curr_iter, fr(x_opt)))
    }
  }
}
cat(sprintf("Current values of x: [%.3f, %.3f]\n",
            x_opt[1], x_opt[2]))

[Package stochQN version 0.1.2-1 Index]