| TextEmbeddingClassifierNeuralNet {aifeducation} | R Documentation |
Text embedding classifier with a neural net
Description
Abstract class for neural nets with 'keras'/'tensorflow' and 'pytorch'.
Value
Objects of this class are used for assigning texts to classes/categories. For
the creation and training of a classifier an object of class EmbeddedText and a factor
are necessary. The object of class EmbeddedText contains the numerical text
representations (text embeddings) of the raw texts generated by an object of class
TextEmbeddingModel. The factor contains the classes/categories for every
text. Missing values (unlabeled cases) are supported. For predictions, an object of class
EmbeddedText has to be used which was created with the same text embedding model as
the one used for training.
Public fields
model ('tensorflow_model()')
Field for storing the tensorflow model after loading.
model_config ('list()')
List for storing information about the configuration of the model. This information is used to predict new data.
model_config$n_rec: Number of recurrent layers.
model_config$n_hidden: Number of dense layers.
model_config$target_levels: Levels of the target variable. Do not change this manually.
model_config$input_variables: Order and name of the input variables. Do not change this manually.
model_config$init_config: List storing all parameters passed to method new().
last_training ('list()')
List for storing the history and the results of the last training. This information will be overwritten if a new training is started.
last_training$learning_time: Duration of the training process.
config$history: History of the last training.
config$data: Object of class table storing the initial frequencies of the passed data.
config$data_pb: Matrix storing the number of additional cases (test and training) added during balanced pseudo-labeling. The rows refer to folds and final training. The columns refer to the steps during pseudo-labeling.
config$data_bsc_test: Matrix storing the number of cases for each category used for testing during the phase of balanced synthetic units. Please note that the frequencies include original and synthetic cases. In case the number of original and synthetic cases exceeds the limit for the majority classes, the frequency represents the number of cases created by cluster analysis.
config$date: Time when the last training finished.
config$config: List storing which kind of estimation was requested during the last training.
config$config$use_bsc: TRUE if balanced synthetic cases were requested. FALSE if not.
config$config$use_baseline: TRUE if a baseline estimation was requested. FALSE if not.
config$config$use_bpl: TRUE if balanced pseudo-labeling was requested. FALSE if not.
reliability ('list()')
List for storing central reliability measures of the last training.
reliability$test_metric: Array containing the reliability measures for the validation data for every fold, method, and step (in case of pseudo-labeling).
reliability$test_metric_mean: Array containing the reliability measures for the validation data for every method and step (in case of pseudo-labeling). The values represent the mean values across the folds.
reliability$raw_iota_objects: List containing all iota_object objects generated with the package iotarelr for every fold at the start and the end of the last training.
reliability$raw_iota_objects$iota_objects_start: List of objects with class iotarelr_iota2 containing the estimated iota reliability of the second generation for the baseline model for every fold. If the estimation of the baseline model is not requested, the list is set to NULL.
reliability$raw_iota_objects$iota_objects_end: List of objects with class iotarelr_iota2 containing the estimated iota reliability of the second generation for the final model for every fold. Depending on the requested training method, these values refer to the baseline model, a model trained on balanced synthetic cases, balanced pseudo-labeling, or a combination of balanced synthetic cases and pseudo-labeling.
reliability$raw_iota_objects$iota_objects_start_free: List of objects with class iotarelr_iota2 containing the estimated iota reliability of the second generation for the baseline model for every fold. If the estimation of the baseline model is not requested, the list is set to NULL. Please note that the model is estimated without forcing the Assignment Error Matrix to be in line with the assumption of weak superiority.
reliability$raw_iota_objects$iota_objects_end_free: List of objects with class iotarelr_iota2 containing the estimated iota reliability of the second generation for the final model for every fold. Depending on the requested training method, these values refer to the baseline model, a model trained on balanced synthetic cases, balanced pseudo-labeling, or a combination of balanced synthetic cases and pseudo-labeling. Please note that the model is estimated without forcing the Assignment Error Matrix to be in line with the assumption of weak superiority.
reliability$iota_object_start: Object of class iotarelr_iota2 as a mean of the individual objects for every fold. If the estimation of the baseline model is not requested, the object is set to NULL.
reliability$iota_object_start_free: Object of class iotarelr_iota2 as a mean of the individual objects for every fold. If the estimation of the baseline model is not requested, the object is set to NULL. Please note that the model is estimated without forcing the Assignment Error Matrix to be in line with the assumption of weak superiority.
reliability$iota_object_end: Object of class iotarelr_iota2 as a mean of the individual objects for every fold. Depending on the requested training method, this object refers to the baseline model, a model trained on balanced synthetic cases, balanced pseudo-labeling, or a combination of balanced synthetic cases and pseudo-labeling.
reliability$iota_object_end_free: Object of class iotarelr_iota2 as a mean of the individual objects for every fold. Depending on the requested training method, this object refers to the baseline model, a model trained on balanced synthetic cases, balanced pseudo-labeling, or a combination of balanced synthetic cases and pseudo-labeling. Please note that the model is estimated without forcing the Assignment Error Matrix to be in line with the assumption of weak superiority.
reliability$standard_measures_end: Object of class list containing the final measures for precision, recall, and F1 for every fold. Depending on the requested training method, these values refer to the baseline model, a model trained on balanced synthetic cases, balanced pseudo-labeling, or a combination of balanced synthetic cases and pseudo-labeling.
reliability$standard_measures_mean: matrix containing the mean measures for precision, recall, and F1 at the end of every fold.
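For orientation, a brief sketch of accessing these fields after training; the object name classifier is illustrative and assumes a trained instance of this class:
# Mean reliability measures for every method and step (averaged over folds)
classifier$reliability$test_metric_mean
# Mean iota object (second generation) of the final models across folds
classifier$reliability$iota_object_end
# Duration of the last training run
classifier$last_training$learning_time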
Methods
Public methods
- TextEmbeddingClassifierNeuralNet$new()
- TextEmbeddingClassifierNeuralNet$train()
- TextEmbeddingClassifierNeuralNet$predict()
- TextEmbeddingClassifierNeuralNet$check_embedding_model()
- TextEmbeddingClassifierNeuralNet$get_model_info()
- TextEmbeddingClassifierNeuralNet$get_text_embedding_model()
- TextEmbeddingClassifierNeuralNet$set_publication_info()
- TextEmbeddingClassifierNeuralNet$get_publication_info()
- TextEmbeddingClassifierNeuralNet$set_software_license()
- TextEmbeddingClassifierNeuralNet$get_software_license()
- TextEmbeddingClassifierNeuralNet$set_documentation_license()
- TextEmbeddingClassifierNeuralNet$get_documentation_license()
- TextEmbeddingClassifierNeuralNet$set_model_description()
- TextEmbeddingClassifierNeuralNet$get_model_description()
- TextEmbeddingClassifierNeuralNet$save_model()
- TextEmbeddingClassifierNeuralNet$load_model()
- TextEmbeddingClassifierNeuralNet$get_package_versions()
- TextEmbeddingClassifierNeuralNet$get_sustainability_data()
- TextEmbeddingClassifierNeuralNet$get_ml_framework()
- TextEmbeddingClassifierNeuralNet$clone()
Method new()
Creating a new instance of this class.
Usage
TextEmbeddingClassifierNeuralNet$new(
data_embeddings = NULL,
ml_framework = aifeducation_config$get_framework(),
name = NULL,
label = NULL,
text_embeddings = NULL,
targets = NULL,
hidden = c(128),
rec = c(128),
self_attention_heads = 0,
intermediate_size = NULL,
attention_type = "fourier",
add_pos_embedding = TRUE,
rec_dropout = 0.1,
repeat_encoder = 1,
dense_dropout = 0.4,
recurrent_dropout = 0.4,
encoder_dropout = 0.1,
optimizer = "adam"
)
Arguments
ml_framework string Framework to use for training and inference. ml_framework="tensorflow" for 'tensorflow' and ml_framework="pytorch" for 'pytorch'.
name Character Name of the new classifier. Please refer to common name conventions. Free text can be used with parameter label.
label Character Label for the new classifier. Here you can use free text.
text_embeddings An object of class TextEmbeddingModel.
targets factor containing the target values of the classifier.
hidden vector containing the number of neurons for each dense layer. The length of the vector determines the number of dense layers. If you do not want any dense layers, set this parameter to NULL.
rec vector containing the number of neurons for each recurrent layer. The length of the vector determines the number of recurrent layers. If you do not want any recurrent layers, set this parameter to NULL.
self_attention_heads integer determining the number of attention heads for a self-attention layer. Only relevant if attention_type="multihead".
intermediate_size int determining the size of the projection layer within each transformer encoder.
attention_type string Choose the relevant attention type. Possible values are "fourier" and "multihead".
add_pos_embedding bool TRUE if positional embedding should be used.
rec_dropout double ranging between 0 and lower 1, determining the dropout between bidirectional GRU layers.
repeat_encoder int determining how many times the encoder should be added to the network.
dense_dropout double ranging between 0 and lower 1, determining the dropout between dense layers.
recurrent_dropout double ranging between 0 and lower 1, determining the recurrent dropout for each recurrent layer. Only relevant for keras models.
encoder_dropout double ranging between 0 and lower 1, determining the dropout for the dense projection within the encoder layers.
optimizer Object of class keras.optimizers.
Returns
Returns an object of class TextEmbeddingClassifierNeuralNet which is ready for training.
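A minimal, illustrative sketch of creating a classifier; example_embeddings and example_targets are assumed placeholder objects holding the text embeddings and the named factor of classes described above, and the parameter values shown simply mirror the defaults from the usage section:
# Illustrative sketch: 'example_embeddings' and 'example_targets' are assumptions
classifier <- TextEmbeddingClassifierNeuralNet$new(
  ml_framework = "tensorflow",
  name = "movie_review_classifier",
  label = "Classifier for Estimating Movie Review Sentiment",
  text_embeddings = example_embeddings,
  targets = example_targets,
  hidden = c(128),
  rec = c(128),
  attention_type = "fourier",
  optimizer = "adam"
)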
Method train()
Method for training a neural net.
Usage
TextEmbeddingClassifierNeuralNet$train(
data_embeddings,
data_targets,
data_n_test_samples = 5,
balance_class_weights = TRUE,
use_baseline = TRUE,
bsl_val_size = 0.25,
use_bsc = TRUE,
bsc_methods = c("dbsmote"),
bsc_max_k = 10,
bsc_val_size = 0.25,
bsc_add_all = FALSE,
use_bpl = TRUE,
bpl_max_steps = 3,
bpl_epochs_per_step = 1,
bpl_dynamic_inc = FALSE,
bpl_balance = FALSE,
bpl_max = 1,
bpl_anchor = 1,
bpl_min = 0,
bpl_weight_inc = 0.02,
bpl_weight_start = 0,
bpl_model_reset = FALSE,
sustain_track = TRUE,
sustain_iso_code = NULL,
sustain_region = NULL,
sustain_interval = 15,
epochs = 40,
batch_size = 32,
dir_checkpoint,
trace = TRUE,
keras_trace = 2,
pytorch_trace = 2,
n_cores = 2
)
Arguments
data_embeddings Object of class EmbeddedText.
data_targets Factor containing the labels for the cases stored in data_embeddings. The factor must be named and has to use the same names used in data_embeddings.
data_n_test_samples int determining the number of cross-fold samples.
balance_class_weights bool If TRUE, class weights are generated based on the frequencies of the training data with the method 'Inverse Class Frequency'. If FALSE, each class has the weight 1.
use_baseline bool TRUE if the calculation of a baseline model is requested. This option is only relevant for use_bsc=TRUE or use_bpl=TRUE. If both are FALSE, a baseline model is calculated.
bsl_val_size double between 0 and 1, indicating the proportion of cases of each class which should be used for the validation sample during the estimation of the baseline model. The remaining cases are part of the training data.
use_bsc bool TRUE if the estimation should integrate balanced synthetic cases. FALSE if not.
bsc_methods vector containing the methods for generating synthetic cases via 'smotefamily'. Multiple methods can be passed. Currently bsc_methods=c("adas"), bsc_methods=c("smote") and bsc_methods=c("dbsmote") are possible.
bsc_max_k int determining the maximal number of k which is used for creating synthetic units.
bsc_val_size double between 0 and 1, indicating the proportion of cases of each class which should be used for the validation sample during the estimation with synthetic cases.
bsc_add_all bool If FALSE, only the synthetic cases necessary to fill the gap between a class and the major class are added to the data. If TRUE, all generated synthetic cases are added to the data.
use_bpl bool TRUE if the estimation should integrate balanced pseudo-labeling. FALSE if not.
bpl_max_steps int determining the maximum number of steps during pseudo-labeling.
bpl_epochs_per_step int Number of training epochs within every step.
bpl_dynamic_inc bool If TRUE, only a specific percentage of cases is included during each step. The percentage is determined by step/bpl_max_steps. If FALSE, all cases are used.
bpl_balance bool If TRUE, the same number of cases for every category/class of the pseudo-labeled data is used during training. That is, the number of cases is determined by the minor class/category.
bpl_max double between 0 and 1, setting the maximal level of confidence for considering a case for pseudo-labeling.
bpl_anchor double between 0 and 1, indicating the reference point for sorting the new cases of every label. See notes for more details.
bpl_min double between 0 and 1, setting the minimal level of confidence for considering a case for pseudo-labeling.
bpl_weight_inc double value determining how much the sample weights should be increased for the cases with pseudo-labels in every step.
bpl_weight_start double Starting value for the weights of the unlabeled cases.
bpl_model_reset bool If TRUE, the model is re-initialized at every step.
sustain_track bool If TRUE, energy consumption is tracked during training via the python library codecarbon.
sustain_iso_code string ISO code (Alpha-3-Code) for the country. This variable must be set if sustainability should be tracked. A list can be found on Wikipedia: https://en.wikipedia.org/wiki/List_of_ISO_3166_country_codes.
sustain_region Region within a country. Only available for the USA and Canada. See the documentation of codecarbon for more information: https://mlco2.github.io/codecarbon/parameters.html
sustain_interval integer Interval in seconds for measuring power usage.
epochs int Number of training epochs.
batch_size int Size of batches.
dir_checkpoint string Path to the directory where the checkpoint during training should be saved. If the directory does not exist, it is created.
trace bool TRUE if information about the estimation phase should be printed to the console.
keras_trace int keras_trace=0 does not print any information about the training process from keras on the console.
pytorch_trace int pytorch_trace=0 does not print any information about the training process from pytorch on the console. pytorch_trace=1 prints a progress bar. pytorch_trace=2 prints one line of information for every epoch.
n_cores int Number of cores used for creating synthetic units.
Details
bsc_max_k: All values from 2 up to bsc_max_k are successively used. If the value of bsc_max_k is too high, it is reduced to a number that allows the calculation of synthetic units.
bpl_anchor: With the help of this value, the new cases are sorted. For this aim, the distance from the anchor is calculated and all cases are arranged in ascending order.
Returns
Function does not return a value. It changes the object into a trained classifier.
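A hedged sketch of a possible training call using a subset of the parameters listed above; classifier, example_embeddings, and example_targets are the illustrative objects from the earlier sketches, and the checkpoint path and ISO code are assumptions:
# Illustrative training call; object names, path, and country code are assumptions
classifier$train(
  data_embeddings = example_embeddings,
  data_targets = example_targets,
  data_n_test_samples = 5,
  use_baseline = TRUE,
  use_bsc = TRUE,
  bsc_methods = c("dbsmote"),
  use_bpl = TRUE,
  bpl_max_steps = 3,
  sustain_track = TRUE,
  sustain_iso_code = "DEU",
  epochs = 40,
  batch_size = 32,
  dir_checkpoint = "checkpoints/movie_review_classifier",
  trace = TRUE
)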
Method predict()
Method for predicting new data with a trained neural net.
Usage
TextEmbeddingClassifierNeuralNet$predict(newdata, batch_size = 32, verbose = 1)
Arguments
newdata Object of class EmbeddedText or data.frame for which predictions should be made.
batch_size int Size of batches.
verbose int verbose=0 does not print any information about the prediction process from keras on the console. verbose=1 prints a progress bar. verbose=2 prints one line of information for every epoch.
Returns
Returns a data.frame containing the predictions and
the probabilities of the different labels for each case.
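An illustrative prediction call; new_embeddings is an assumed object created with the same text embedding model as used for training:
# Illustrative prediction; 'new_embeddings' is an assumption
predictions <- classifier$predict(
  newdata = new_embeddings,
  batch_size = 32
)
# data.frame with the predictions and the probability of every label per case
head(predictions)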
Method check_embedding_model()
Method for checking if the provided text embeddings are created with the same TextEmbeddingModel as the classifier.
Usage
TextEmbeddingClassifierNeuralNet$check_embedding_model(text_embeddings)
Arguments
text_embeddingsObject of class EmbeddedText.
Returns
TRUE if the underlying TextEmbeddingModel is the same.
FALSE if the models differ.
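For instance, one might check new embeddings before prediction (object names are illustrative):
# Returns TRUE if 'new_embeddings' stems from the classifier's own text embedding model
classifier$check_embedding_model(text_embeddings = new_embeddings)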
Method get_model_info()
Method for requesting the model information.
Usage
TextEmbeddingClassifierNeuralNet$get_model_info()
Returns
Returns a list of all relevant model information.
Method get_text_embedding_model()
Method for requesting the text embedding model information.
Usage
TextEmbeddingClassifierNeuralNet$get_text_embedding_model()
Returns
Returns a list of all relevant information on the text embedding model
underlying the classifier.
Method set_publication_info()
Method for setting the publication information of the classifier.
Usage
TextEmbeddingClassifierNeuralNet$set_publication_info(authors, citation, url = NULL)
Arguments
authorsList of authors.
citationFree text citation.
urlURL of a corresponding homepage.
Returns
Function does not return a value. It is used for setting the private members for publication information.
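An illustrative call with placeholder values:
# Placeholder publication information (all values are illustrative assumptions)
classifier$set_publication_info(
  authors = person(given = "Jane", family = "Doe"),
  citation = "Doe, J. (2024). Movie review sentiment classifier.",
  url = "https://example.org/movie_review_classifier"
)
# Retrieve the stored bibliographic information
classifier$get_publication_info()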
Method get_publication_info()
Method for requesting the bibliographic information of the classifier.
Usage
TextEmbeddingClassifierNeuralNet$get_publication_info()
Returns
Returns a list with all saved bibliographic information.
Method set_software_license()
Method for setting the license of the classifier.
Usage
TextEmbeddingClassifierNeuralNet$set_software_license(license = "GPL-3")
Arguments
licensestringcontaining the abbreviation of the license or the license text.
Returns
Function does not return a value. It is used for setting the private member for the software license of the model.
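An illustrative pairing of the license setters and getters (the documentation license methods are described further below); the license strings mirror the defaults from the usage sections:
# Set and query the software license of the classifier
classifier$set_software_license(license = "GPL-3")
classifier$get_software_license()
# Set and query the license of the classifier's documentation
classifier$set_documentation_license(license = "CC BY-SA")
classifier$get_documentation_license()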
Method get_software_license()
Method for getting the license of the classifier.
Usage
TextEmbeddingClassifierNeuralNet$get_software_license()
Returns
string representing the license for the software.
Method set_documentation_license()
Method for setting the license of the classifier's documentation.
Usage
TextEmbeddingClassifierNeuralNet$set_documentation_license(license = "CC BY-SA")
Arguments
licensestringcontaining the abbreviation of the license or the license text.
Returns
Function does not return a value. It is used for setting the private member for the documentation license of the model.
Method get_documentation_license()
Method for getting the license of the classifier's documentation.
Usage
TextEmbeddingClassifierNeuralNet$get_documentation_license()
Returns
Returns the license as a string.
Method set_model_description()
Method for setting a description of the classifier.
Usage
TextEmbeddingClassifierNeuralNet$set_model_description(
eng = NULL,
native = NULL,
abstract_eng = NULL,
abstract_native = NULL,
keywords_eng = NULL,
keywords_native = NULL
)
Arguments
eng string A text describing the training of the learner, its theoretical and empirical background, and the different output labels in English.
native string A text describing the training of the learner, its theoretical and empirical background, and the different output labels in the native language of the classifier.
abstract_eng string A text providing a summary of the description in English.
abstract_native string A text providing a summary of the description in the native language of the classifier.
keywords_eng vector of keywords in English.
keywords_native vector of keywords in the native language of the classifier.
Returns
Function does not return a value. It is used for setting the private members for the description of the model.
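An illustrative call with placeholder texts and keywords:
# Placeholder description (all texts are illustrative assumptions)
classifier$set_model_description(
  eng = "This classifier assigns movie reviews to a positive or negative sentiment class.",
  abstract_eng = "Sentiment classifier for movie reviews.",
  keywords_eng = c("sentiment", "movie reviews")
)
# Retrieve the stored description
classifier$get_model_description()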
Method get_model_description()
Method for requesting the model description.
Usage
TextEmbeddingClassifierNeuralNet$get_model_description()
Returns
list with the description of the classifier in English
and the native language.
Method save_model()
Method for saving a model to 'Keras v3 format', 'tensorflow' SavedModel format or h5 format.
Usage
TextEmbeddingClassifierNeuralNet$save_model(dir_path, save_format = "default")
Arguments
dir_path string Path of the directory where the model should be saved.
save_format Format for saving the model. For 'tensorflow'/'keras' models: "keras" for 'Keras v3 format', "tf" for SavedModel or "h5" for HDF5. For 'pytorch' models: "safetensors" for 'safetensors' or "pt" for 'pytorch' via pickle. Use "default" for the standard format. This is keras for 'tensorflow'/'keras' models and safetensors for 'pytorch' models.
Returns
Function does not return a value. It saves the model to disk.
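For example (the directory path is an assumption):
# Save the trained classifier in the default format for its framework
classifier$save_model(dir_path = "models/movie_review_classifier", save_format = "default")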
Method load_model()
Method for importing a model from 'Keras v3 format', 'tensorflow' SavedModel format or h5 format.
Usage
TextEmbeddingClassifierNeuralNet$load_model(dir_path, ml_framework = "auto")
Arguments
dir_path string Path of the directory where the model is saved.
ml_framework string Determines the machine learning framework for using the model. Possible are ml_framework="pytorch" for 'pytorch', ml_framework="tensorflow" for 'tensorflow', and ml_framework="auto".
Returns
Function does not return a value. It is used to load the weights of a model.
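For example (the directory path is the same assumed location used when saving):
# Restore the saved weights into the classifier, detecting the framework automatically
classifier$load_model(dir_path = "models/movie_review_classifier", ml_framework = "auto")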
Method get_package_versions()
Method for requesting a summary of the R and python packages' versions used for creating the classifier.
Usage
TextEmbeddingClassifierNeuralNet$get_package_versions()
Returns
Returns a list containing the versions of the relevant
R and python packages.
Method get_sustainability_data()
Method for requesting a summary of tracked energy consumption during training and an estimate of the resulting CO2 equivalents in kg.
Usage
TextEmbeddingClassifierNeuralNet$get_sustainability_data()
Returns
Returns a list containing the tracked energy consumption,
CO2 equivalents in kg, information on the tracker used, and technical
information on the training infrastructure.
Method get_ml_framework()
Method for requesting the machine learning framework used for the classifier.
Usage
TextEmbeddingClassifierNeuralNet$get_ml_framework()
Returns
Returns a string describing the machine learning framework used
for the classifier.
Method clone()
The objects of this class are cloneable with this method.
Usage
TextEmbeddingClassifierNeuralNet$clone(deep = FALSE)
Arguments
deepWhether to make a deep clone.