callback_adaptive_lr {reservr}    R Documentation
Keras Callback for adaptive learning rate with weight restoration
Description
Provides a keras callback similar to keras3::callback_reduce_lr_on_plateau() that additionally restores the best weights seen so far whenever the learning rate is reduced, and that uses slightly more restrictive improvement detection.
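To make the reduce-and-restore behaviour concrete, here is a rough pure-R simulation that needs no keras. It is a hypothetical sketch, not reservr's implementation: `simulate_adaptive_lr()` is an illustrative name, only mode "min" is handled, an epoch counts as improving when it beats the previous best by at least `delta_abs` (the averaged per-epoch criterion described under Arguments is not modelled), and the "weights" that would be restored are identified only by the epoch at which they were saved. Because `losses` is precomputed, the simulation cannot show how restored weights would change later losses.

```r
# Hypothetical, simplified simulation (not reservr's implementation) of a
# reduce-on-plateau schedule with best-weight restoration, for mode "min".
# `losses` is a precomputed sequence of the monitored metric; the "weights"
# that would be restored are identified by the epoch they were saved at.
simulate_adaptive_lr <- function(losses, lr = 1e-3, factor = 0.1,
                                 patience = 10L, delta_abs = 1e-4,
                                 cooldown = 0L, min_lr = 0) {
  best <- Inf
  best_epoch <- NA_integer_
  wait <- 0L
  cooldown_left <- 0L
  ev_epoch <- integer(0)
  ev_lr <- numeric(0)
  ev_restored <- integer(0)
  for (epoch in seq_along(losses)) {
    # Simplified check: a new best must beat the old best by at least delta_abs
    improved <- losses[epoch] < best - delta_abs
    if (improved) {
      best <- losses[epoch]
      best_epoch <- epoch  # snapshot of the best "weights" so far
    }
    if (cooldown_left > 0L) {
      # During cooldown, count down without accumulating patience
      cooldown_left <- cooldown_left - 1L
      next
    }
    wait <- if (improved) 0L else wait + 1L
    if (wait >= patience && lr > min_lr) {
      # Plateau: reduce the learning rate (never below min_lr) and
      # roll back to the weights saved at best_epoch
      lr <- max(lr * factor, min_lr)
      ev_epoch <- c(ev_epoch, epoch)
      ev_lr <- c(ev_lr, lr)
      ev_restored <- c(ev_restored, best_epoch)
      wait <- 0L
      cooldown_left <- cooldown
    }
  }
  data.frame(epoch = ev_epoch, new_lr = ev_lr, restored_from = ev_restored)
}
```

For instance, with losses c(1.0, 0.5, 0.6, 0.7, 0.4, 0.45, 0.46, 0.47), factor = 0.5, patience = 2L and min_lr = 1e-4, the simulated learning rate is halved at epochs 4 and 7, rolling back to the weights from epochs 2 and 5 respectively. In this sketch, a constant loss with cooldown = 1L and patience = 2L produces reductions exactly patience + cooldown = 3 epochs apart.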
Usage
callback_adaptive_lr(
monitor = "val_loss",
factor = 0.1,
patience = 10L,
verbose = 0L,
mode = c("auto", "min", "max"),
delta_abs = 1e-04,
delta_rel = 0,
cooldown = 0L,
min_lr = 0,
restore_weights = TRUE
)
Arguments
monitor: Quantity to be monitored.
factor: Factor by which the learning rate will be reduced, i.e. new_lr = old_lr * factor.
patience: Number of epochs with no significant improvement after which the learning rate will be reduced.
verbose: Integer. Set to 1 to receive update messages.
mode: Optimisation mode. "auto" detects the mode from the name of monitor. "min" monitors for decreasing metrics, "max" monitors for increasing metrics.
delta_abs: Minimum absolute metric improvement per epoch. The learning rate will be reduced if the average improvement is less than delta_abs per epoch for patience epochs.
delta_rel: Minimum relative metric improvement per epoch. The learning rate will be reduced if the average improvement is less than |metric| * delta_rel per epoch for patience epochs.
cooldown: Number of epochs to wait before resuming normal operation after the learning rate has been reduced. The minimum number of epochs between two learning rate reductions is patience + cooldown.
min_lr: Lower bound for the learning rate. If a learning rate reduction would lower the learning rate below min_lr, it is clipped at min_lr instead.
restore_weights: Bool. If TRUE, the best weights seen so far will be restored at each learning rate reduction. This is very useful if the metric oscillates.
Details
Note that while keras3::callback_reduce_lr_on_plateau() automatically logs the learning rate as a metric 'lr', this callback cannot currently do the same from R. If you also want to log the learning rate, add keras3::callback_reduce_lr_on_plateau() with a min_lr at or above the current learning rate (for example min_lr = 1.0, as in the Examples). This effectively disables that callback while still recording the learning rate.
Value
A KerasCallback suitable for passing to keras3::fit().
Examples
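# Simulate 100 exponential observations whose rate depends on a binary group
dist <- dist_exponential()
group <- sample(c(0, 1), size = 100, replace = TRUE)
x <- dist$sample(100, with_params = list(rate = group + 1))
# Fit a single (global) exponential distribution, ignoring the group
global_fit <- fit(dist, x)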
if (interactive()) {
library(keras3)
l_in <- layer_input(shape = 1L)
mod <- tf_compile_model(
inputs = list(l_in),
intermediate_output = l_in,
dist = dist,
optimizer = optimizer_adam(),
censoring = FALSE,
truncation = FALSE
)
# Start training from the parameters of the global fit
tf_initialise_model(mod, global_fit$params)
fit_history <- fit(
mod,
x = as_tensor(group, config_floatx()),
y = as_trunc_obs(x),
epochs = 20L,
callbacks = list(
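# Halve the learning rate after 2 epochs without significant improvement in the
# training loss, never going below 1e-4; the best weights are restored on each reduction
callback_adaptive_lr("loss", factor = 0.5, patience = 2L, verbose = 1L, min_lr = 1.0e-4),
# min_lr = 1.0 effectively disables this callback; it is only used to log lr
callback_reduce_lr_on_plateau("loss", min_lr = 1.0)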
)
)
plot(fit_history)
predicted_means <- predict(mod, data = as_tensor(c(0, 1), config_floatx()))
}