optimizer_sgd {keras3} | R Documentation
Gradient descent (with momentum) optimizer.
Description
Update rule for parameter w with gradient g when momentum is 0:
w <- w - learning_rate * g
Update rule when momentum is larger than 0:
velocity <- momentum * velocity - learning_rate * g
w <- w + velocity
When nesterov = TRUE, this rule becomes:
velocity <- momentum * velocity - learning_rate * g
w <- w + momentum * velocity - learning_rate * g
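The three variants above can be written as a single plain-R helper. This is an illustrative sketch only, not part of the keras3 API; the function name sgd_step and its arguments are made up for the example:

sgd_step <- function(w, g, velocity = 0, learning_rate = 0.01,
                     momentum = 0, nesterov = FALSE) {
  if (momentum == 0) {
    # vanilla gradient descent
    w <- w - learning_rate * g
  } else {
    velocity <- momentum * velocity - learning_rate * g
    if (nesterov) {
      w <- w + momentum * velocity - learning_rate * g
    } else {
      w <- w + velocity
    }
  }
  list(w = w, velocity = velocity)
}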
Usage
optimizer_sgd(
learning_rate = 0.01,
momentum = 0,
nesterov = FALSE,
weight_decay = NULL,
clipnorm = NULL,
clipvalue = NULL,
global_clipnorm = NULL,
use_ema = FALSE,
ema_momentum = 0.99,
ema_overwrite_frequency = NULL,
name = "SGD",
...,
loss_scale_factor = NULL,
gradient_accumulation_steps = NULL
)
Arguments
learning_rate: A float, a learning rate schedule instance, or a callable that takes no arguments and returns the actual value to use. The learning rate. Defaults to 0.01.

momentum: Float hyperparameter >= 0 that accelerates gradient descent in the relevant direction and dampens oscillations. 0 is vanilla gradient descent. Defaults to 0.

nesterov: Boolean. Whether to apply Nesterov momentum. Defaults to FALSE.
weight_decay: Float. If set, weight decay is applied.

clipnorm: Float. If set, the gradient of each weight is individually clipped so that its norm is no higher than this value.

clipvalue: Float. If set, the gradient of each weight is clipped to be no higher than this value.

global_clipnorm: Float. If set, the gradient of all weights is clipped so that their global norm is no higher than this value.
use_ema: Boolean, defaults to FALSE. If TRUE, an exponential moving average (EMA) of the model weights is maintained (updated after each training batch), and the model weights are periodically overwritten with their moving average.

ema_momentum: Float, defaults to 0.99. Only used if use_ema = TRUE. The momentum used when computing the EMA of the model's weights: new_average <- ema_momentum * old_average + (1 - ema_momentum) * current_variable_value (see the sketch after this list).

ema_overwrite_frequency: Int or NULL, defaults to NULL. Only used if use_ema = TRUE. Every ema_overwrite_frequency steps, the model variables are overwritten with their moving average. If NULL, the variables are not overwritten during training; the averages can be applied at the end of training via the optimizer's finalize_variable_values() method (this happens automatically after the last epoch when using the built-in fit() training loop).
name: String. The name to use for momentum accumulator weights created by the optimizer.

...: For forward/backward compatibility.
loss_scale_factor: Float or NULL. If a float, the loss is multiplied by this scale factor before gradients are computed, and the gradients are multiplied by its inverse before variables are updated. Useful for preventing underflow during mixed precision training. Alternately, optimizer_loss_scale() will automatically set a loss scale factor.

gradient_accumulation_steps: Int or NULL. If an int, model and optimizer variables are not updated at every step; instead they are updated every gradient_accumulation_steps steps, using the average value of the gradients since the last update ("gradient accumulation"). This can be useful when the batch size is very small, to reduce gradient noise at each update step.
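As a rough illustration of the ema_momentum update described above (plain R, not keras3 code; the variable names and values are made up):

# One EMA update of a single weight value (illustrative only).
ema_momentum <- 0.99
old_average <- 0.50            # previous moving average of the weight
current_variable_value <- 0.42 # weight value after the latest training batch
new_average <- ema_momentum * old_average +
  (1 - ema_momentum) * current_variable_value
new_average  # 0.4992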
Value
An Optimizer instance.
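A minimal usage sketch follows. The model, layer sizes, and loss are hypothetical and only show how the optimizer is passed to compile():

library(keras3)

model <- keras_model_sequential(input_shape = c(784)) |>
  layer_dense(units = 10, activation = "softmax")

model |> compile(
  optimizer = optimizer_sgd(learning_rate = 0.01, momentum = 0.9, nesterov = TRUE),
  loss = "categorical_crossentropy",
  metrics = "accuracy"
)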
See Also
Other optimizers:
optimizer_adadelta()
optimizer_adafactor()
optimizer_adagrad()
optimizer_adam()
optimizer_adam_w()
optimizer_adamax()
optimizer_ftrl()
optimizer_lion()
optimizer_loss_scale()
optimizer_nadam()
optimizer_rmsprop()