get_dl_model {rTLsDeep} | R Documentation
Selecting deep learning modeling approaches
Description
This function selects and returns the deep learning approach to be used with the fit_dl_model function for post-hurricane individual tree-level damage classification.
Usage
get_dl_model(
model_type = "vgg",
img_width = 256,
img_height = 256,
lr_rate = 1e-04,
tensorflow_dir = NA,
channels,
class_list
)
Arguments
model_type
A character string naming the deep learning model to use. Available models: "vgg", "resnet", "inception", "densenet", "efficientnet", "simple". Default: "vgg".
img_width
A numeric value giving the width of the images used for training. Default: 256.
img_height
A numeric value giving the height of the images used for training. Default: 256.
lr_rate
A numeric value giving the learning rate. Default: 0.0001.
tensorflow_dir
A character string giving the directory of the tensorflow python environment. An installation guide for the environment is available at https://doi.org/10.5281/zenodo.3929709. Default: NA.
channels
A numeric value giving the number of channels/bands of the input images.
class_list
A character or numeric vector listing the post-hurricane individual tree-level damage classes, e.g.: c("1","2","3","4","5","6").
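The argument descriptions above can be turned into a quick pre-flight check before the model is built. The following is a minimal sketch, not part of rTLsDeep: the helper name check_dl_args and its error messages are hypothetical, and the list of valid model types is copied from the model_type description.

```r
# Hypothetical helper (not part of rTLsDeep): validates the arguments
# described above before they are passed to get_dl_model().
valid_model_types <- c("vgg", "resnet", "inception",
                       "densenet", "efficientnet", "simple")

check_dl_args <- function(model_type, img_width, img_height,
                          channels, class_list) {
  # model_type must be exactly one of the documented names
  if (!is.character(model_type) || length(model_type) != 1 ||
      !model_type %in% valid_model_types) {
    stop("model_type must be one of: ",
         paste(valid_model_types, collapse = ", "))
  }
  # image dimensions and channel count must be single positive numbers
  for (value in list(img_width, img_height, channels)) {
    if (!is.numeric(value) || length(value) != 1 || value <= 0) {
      stop("img_width, img_height and channels must be single positive numbers")
    }
  }
  # at least one damage class is needed
  if (length(class_list) < 1) {
    stop("class_list must contain at least one class")
  }
  invisible(list(model_type = model_type,
                 target_size = c(img_width, img_height),
                 channels = channels,
                 n_classes = length(unique(class_list))))
}

checked <- check_dl_args("simple", 256, 256, 4, c("1", "2", "3"))
print(checked$n_classes)
```

A check of this kind fails early on a misspelled model name, before any tensorflow session is started.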
Value
Returns a list containing the model object, built with the required parameters, and the model_type used.
Examples
# Set directory to tensorflow (python environment)
# This is required when running deep learning on a local computer with a GPU
# Guide to install here: https://doi.org/10.5281/zenodo.3929709
tensorflow_dir = NA
# define model type
model_type = "simple"
#model_type = "vgg"
#model_type = "inception"
#model_type = "resnet"
#model_type = "densenet"
#model_type = "efficientnet"
train_image_files_path = system.file('extdata', 'train', package='rTLsDeep')
test_image_files_path = system.file('extdata', 'validation', package='rTLsDeep')
img_width <- 256
img_height <- 256
class_list_train = unique(list.files(train_image_files_path))
class_list_test = unique(list.files(test_image_files_path))
lr_rate = 0.0001
target_size <- c(img_width, img_height)
channels = 4
# get model
# install tensorflow if the python module is not available
if (!reticulate::py_module_available('tensorflow')) {
  tensorflow::install_tensorflow()
}
model = get_dl_model(model_type = model_type,
                     img_width = img_width,
                     img_height = img_height,
                     channels = channels,
                     lr_rate = lr_rate,
                     tensorflow_dir = tensorflow_dir,
                     class_list = class_list_train)
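
The example above builds class_list_train from the entries of the training directory. The sketch below illustrates that step in isolation, assuming one subdirectory per damage class. The temporary directory and the class folder names are hypothetical stand-ins, not the data shipped with rTLsDeep.

```r
# Create a throwaway training directory with one folder per class
# (hypothetical class names "1", "2", "3")
train_dir <- tempfile("train_")
dir.create(train_dir)
for (cls in c("1", "2", "3")) {
  dir.create(file.path(train_dir, cls))
}

# Same call pattern as in the example above
class_list_train <- unique(list.files(train_dir))

# Alternative that keeps only subdirectories and ignores stray files
class_dirs <- list.dirs(train_dir, full.names = FALSE, recursive = FALSE)

print(class_list_train)

# Clean up the temporary directory
unlink(train_dir, recursive = TRUE)
```

Note that list.files() also returns plain files, so a stray file in the training directory would end up in class_list; list.dirs() avoids this.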