get_dl_model {rTLsDeep}    R Documentation

Selecting deep learning modeling approaches

Description

This function selects and returns the deep learning model to be used with the fit_dl_model function for post-hurricane, individual tree-level damage classification.

Usage

get_dl_model(
  model_type = "vgg",
  img_width = 256,
  img_height = 256,
  lr_rate = 1e-04,
  tensorflow_dir = NA,
  channels,
  class_list
)

Arguments

model_type

A character string naming the deep learning model to be used. Available models: "vgg", "resnet", "inception", "densenet", "efficientnet", "simple".
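If model_type is built programmatically, it can be checked against the names above before get_dl_model is called. The sketch below is not part of rTLsDeep: check_model_type is a hypothetical helper, and it only reproduces the list of names given on this page. It does not know how get_dl_model itself handles unsupported values.

```r
# Hypothetical helper (not part of rTLsDeep): stop early if model_type is not
# one of the model names listed on this help page.
check_model_type <- function(model_type) {
  available <- c("vgg", "resnet", "inception", "densenet",
                 "efficientnet", "simple")
  # Require a single, exactly matching character string
  if (!is.character(model_type) || length(model_type) != 1 ||
      !model_type %in% available) {
    stop("model_type must be one of: ", paste(available, collapse = ", "))
  }
  model_type
}

check_model_type("resnet")
```

The check uses an exact %in% match rather than match.arg(), so abbreviations such as "dense" are rejected instead of being silently expanded.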

img_width

A numeric value giving the width, in pixels, of the images used for training. Default: 256.

img_height

A numeric value giving the height, in pixels, of the images used for training. Default: 256.

lr_rate

A numeric value indicating the learning rate. Default: 0.0001.

tensorflow_dir

A character string giving the directory of the TensorFlow Python environment. A guide to installing the environment is available at https://doi.org/10.5281/zenodo.3929709. Default: NA.

channels

A numeric value giving the number of channels (bands) of the input images.

class_list

A character or numeric vector listing the post-hurricane, individual tree-level damage classes, e.g. c("1","2","3","4","5","6").
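In the examples below, class_list is taken from the entries of the training directory, with each damage class stored in its own subfolder. The following is a minimal sketch of that convention, assuming one subfolder per class. The helper classes_from_dir is hypothetical and not part of rTLsDeep. It uses list.dirs() instead of list.files(), so stray files in the training directory are not picked up as classes.

```r
# Hypothetical helper: derive class labels from the subfolder names of a
# training directory, assuming one subfolder per damage class.
classes_from_dir <- function(path) {
  # recursive = FALSE lists only the immediate subfolders (not path itself);
  # full.names = FALSE returns bare folder names such as "1", "2", ...
  sort(list.dirs(path, full.names = FALSE, recursive = FALSE))
}

# Build a throwaway directory with six class folders and one stray file
train_dir <- file.path(tempdir(), "train_demo")
for (cls in as.character(1:6)) {
  dir.create(file.path(train_dir, cls), recursive = TRUE, showWarnings = FALSE)
}
invisible(file.create(file.path(train_dir, "notes.txt")))

class_list <- classes_from_dir(train_dir)
print(class_list)
```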

Value

A list containing the model object, built with the supplied parameters, and the model_type used.

Examples


# Set the directory of the TensorFlow Python environment
# This is required when running deep learning on a local computer with a GPU
# Installation guide: https://doi.org/10.5281/zenodo.3929709
tensorflow_dir = NA

# define model type
model_type = "simple"
#model_type = "vgg"
#model_type = "inception"
#model_type = "resnet"
#model_type = "densenet"
#model_type = "efficientnet"

train_image_files_path = system.file('extdata', 'train', package='rTLsDeep')
test_image_files_path = system.file('extdata', 'validation', package='rTLsDeep')
img_width <- 256
img_height <- 256
class_list_train = unique(list.files(train_image_files_path))
class_list_test = unique(list.files(test_image_files_path))
lr_rate = 0.0001
target_size <- c(img_width, img_height)
channels = 4

# Install TensorFlow if the Python module is not available
if (!reticulate::py_module_available("tensorflow")) {
  tensorflow::install_tensorflow()
}

# get model
model = get_dl_model(model_type=model_type,
                    img_width=img_width,
                    img_height=img_height,
                    channels=channels,
                    lr_rate = lr_rate,
                    tensorflow_dir = tensorflow_dir,
                    class_list = class_list_train)



[Package rTLsDeep version 0.0.5 Index]