freeze_weights {keras} | R Documentation |
Freeze and unfreeze weights
Description
Freeze weights in a model or layer so that they are no longer trainable.
Usage
freeze_weights(object, from = NULL, to = NULL, which = NULL)
unfreeze_weights(object, from = NULL, to = NULL, which = NULL)
Arguments
object |
Keras model or layer object |
from |
Layer instance, layer name, or layer index within model |
to |
Layer instance, layer name, or layer index within model |
which |
layer names, integer positions, layers, logical vector (of
|
Note
The from and to layer arguments are both inclusive.
When applied to a model, the freeze or unfreeze is a global operation over all layers in the model (i.e. layers not within the specified range will be set to the opposite value, e.g. unfrozen for a call to freeze).
Models must be compiled again after weights are frozen or unfrozen.
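The range rules above can be illustrated without keras. The following is a minimal base-R sketch, not the package's implementation: set_trainable_range(), freeze_sketch() and unfreeze_sketch() are hypothetical helpers that only compute the trainable flag each layer would end up with. It assumes that a negative index counts back from the last layer, as the from = -5 and to = -6 calls in the Examples suggest.

```r
# Hypothetical helper (not part of keras): returns the `trainable` flag that
# each of `n` layers would end up with under the rules described in the Note.
set_trainable_range <- function(n, from = NULL, to = NULL, trainable) {
  # Resolve a possibly negative index (assumption: -1 is the last layer).
  resolve <- function(i, default) {
    if (is.null(i)) return(default)
    if (i < 0) n + i + 1 else i
  }
  from <- resolve(from, 1)
  to <- resolve(to, n)
  # Both ends of the range are inclusive.
  in_range <- seq_len(n) >= from & seq_len(n) <= to
  # Layers in the range get `trainable`; all other layers get the opposite.
  ifelse(in_range, trainable, !trainable)
}

freeze_sketch <- function(n, from = NULL, to = NULL) {
  set_trainable_range(n, from, to, trainable = FALSE)
}

unfreeze_sketch <- function(n, from = NULL, to = NULL) {
  set_trainable_range(n, from, to, trainable = TRUE)
}

n_layers <- 8
frozen_last_5 <- freeze_sketch(n_layers, from = -5)
unfrozen_first_3 <- unfreeze_sketch(n_layers, to = -6)

print(frozen_last_5)
print(identical(frozen_last_5, unfrozen_first_3))
```

Under these assumptions, freezing the last five of eight layers and unfreezing the first three yield the same flags, because each call also sets every layer outside its range to the opposite value.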
Examples
## Not run:
conv_base <- application_vgg16(
  weights = "imagenet",
  include_top = FALSE,
  input_shape = c(150, 150, 3)
)
# freeze its weights
freeze_weights(conv_base)
conv_base
# create a composite model that includes the base + more layers
model <- keras_model_sequential() %>%
  conv_base() %>%
  layer_flatten() %>%
  layer_dense(units = 256, activation = "relu") %>%
  layer_dense(units = 1, activation = "sigmoid")
# compile
model %>% compile(
  loss = "binary_crossentropy",
  optimizer = optimizer_rmsprop(learning_rate = 2e-5),
  metrics = c("accuracy")
)
model
print(model, expand_nested = TRUE)
# unfreeze weights from "block5_conv1" on
unfreeze_weights(conv_base, from = "block5_conv1")
# compile again since we froze or unfroze weights
model %>% compile(
  loss = "binary_crossentropy",
  optimizer = optimizer_rmsprop(learning_rate = 2e-5),
  metrics = c("accuracy")
)
conv_base
print(model, expand_nested = TRUE)
# freeze only the last 5 layers
freeze_weights(conv_base, from = -5)
conv_base
# equivalently, also freeze only the last 5 layers
unfreeze_weights(conv_base, to = -6)
conv_base
# freeze only layers of a certain type, e.g., BatchNorm layers
batch_norm_layer_class_name <- class(layer_batch_normalization())[1]
is_batch_norm_layer <- function(x) inherits(x, batch_norm_layer_class_name)
model <- application_efficientnet_b0()
freeze_weights(model, which = is_batch_norm_layer)
model
# equivalent to:
for (layer in model$layers) {
  if (is_batch_norm_layer(layer))
    layer$trainable <- FALSE
  else
    layer$trainable <- TRUE
}
## End(Not run)
[Package keras version 2.15.0 Index]