| TextEmbeddingModel {aifeducation} | R Documentation |
Text embedding model
Description
This R6 class stores a text embedding model which can be used to tokenize, encode, decode, and embed raw texts. The object provides a unique interface for different text processing methods.
Value
Objects of class TextEmbeddingModel transform raw texts into numerical
representations that can be used for downstream tasks. To this end, objects of this class
allow users to tokenize raw texts, to encode tokens into sequences of integers, and to decode
sequences of integers back into tokens.
Public fields
last_training ('list()')
List for storing the history and the results of the last training. This information will be overwritten if a new training is started.
Methods
Public methods
Method new()
Method for creating a new text embedding model.
Usage
TextEmbeddingModel$new(
  model_name = NULL,
  model_label = NULL,
  model_version = NULL,
  model_language = NULL,
  method = NULL,
  ml_framework = aifeducation_config$get_framework()$TextEmbeddingFramework,
  max_length = 0,
  chunks = 1,
  overlap = 0,
  emb_layer_min = "middle",
  emb_layer_max = "2_3_layer",
  emb_pool_type = "average",
  model_dir,
  bow_basic_text_rep,
  bow_n_dim = 10,
  bow_n_cluster = 100,
  bow_max_iter = 500,
  bow_max_iter_cluster = 500,
  bow_cr_criterion = 1e-08,
  bow_learning_rate = 1e-08,
  trace = FALSE
)
Arguments
model_name: string containing the name of the new model.
model_label: string containing the label/title of the new model.
model_version: string version of the model.
model_language: string containing the language which the model represents (e.g., English).
method: string determining the kind of embedding model. Currently the following models are supported: method="bert" for Bidirectional Encoder Representations from Transformers (BERT), method="roberta" for A Robustly Optimized BERT Pretraining Approach (RoBERTa), method="longformer" for Long-Document Transformer, method="funnel" for Funnel-Transformer, method="deberta_v2" for Decoding-enhanced BERT with Disentangled Attention (DeBERTa V2), method="glove" for GlobalVector Clusters, and method="lda" for topic modeling. See details for more information.
ml_framework: string Framework to use for the model. ml_framework="tensorflow" for 'tensorflow' and ml_framework="pytorch" for 'pytorch'. Only relevant for transformer models.
max_length: int determining the maximum length of token sequences used in transformer models. Not relevant for the other methods.
chunks: int Maximum number of chunks. Only relevant for transformer models.
overlap: int determining the number of tokens which should be added at the beginning of the next chunk. Only relevant for BERT models.
emb_layer_min: int or string determining the first layer to be included in the creation of embeddings. An integer corresponds to the layer number; the first layer has the number 1. Instead of an integer the following strings are possible: "start" for the first layer, "middle" for the middle layer, "2_3_layer" for the layer at two-thirds of the model depth, and "last" for the last layer.
emb_layer_max: int or string determining the last layer to be included in the creation of embeddings. An integer corresponds to the layer number; the first layer has the number 1. Instead of an integer the following strings are possible: "start" for the first layer, "middle" for the middle layer, "2_3_layer" for the layer at two-thirds of the model depth, and "last" for the last layer.
emb_pool_type: string determining the method for pooling the token embeddings within each layer. If "cls", only the embedding of the CLS token is used. If "average", the token embeddings of all tokens are averaged (excluding padding tokens).
model_dir: string path to the directory where the BERT model is stored.
bow_basic_text_rep: object of class basic_text_rep created via the function bow_pp_create_basic_text_rep. Only relevant for method="glove_cluster" and method="lda".
bow_n_dim: int Number of dimensions of the GlobalVectors or number of topics for LDA.
bow_n_cluster: int Number of clusters created on the basis of GlobalVectors. Parameter is not relevant for method="lda" and method="bert".
bow_max_iter: int Maximum number of iterations for fitting GlobalVectors and topic models.
bow_max_iter_cluster: int Maximum number of iterations for fitting the clusters if method="glove".
bow_cr_criterion: double Convergence criterion for GlobalVectors.
bow_learning_rate: double Initial learning rate for GlobalVectors.
trace: bool TRUE prints information about the progress; FALSE does not.
Details
method: In the case of method="bert", method="roberta", and method="longformer", a pretrained transformer model must be supplied via model_dir. For method="glove" and method="lda" a new model will be created based on the data provided via bow_basic_text_rep. The original algorithm for GlobalVectors provides only word embeddings, not text embeddings. To obtain text embeddings, the words are clustered based on their word embeddings with k-means.
Returns
Returns an object of class TextEmbeddingModel.
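For illustration, a minimal sketch of creating a transformer-based embedding model. The directory "path/to/bert" is a placeholder for a locally stored, pretrained BERT model (see Details), and all other values are example choices.

library(aifeducation)

# Sketch: wrap a locally stored BERT model in a TextEmbeddingModel interface.
# "path/to/bert" is a hypothetical directory containing a pretrained BERT model.
bert_modeling <- TextEmbeddingModel$new(
  model_name = "bert_embedding",
  model_label = "Text Embedding via BERT",
  model_version = "0.0.1",
  model_language = "english",
  method = "bert",
  ml_framework = "pytorch",
  max_length = 512,
  chunks = 4,
  overlap = 30,
  emb_layer_min = "middle",
  emb_layer_max = "2_3_layer",
  emb_pool_type = "average",
  model_dir = "path/to/bert"
)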
Method load_model()
Method for loading a transformer model into R.
Usage
TextEmbeddingModel$load_model(model_dir, ml_framework = "auto")
Arguments
model_dir: string containing the path to the relevant model directory.
ml_framework: string Determines the machine learning framework for using the model. Possible are ml_framework="pytorch" for 'pytorch', ml_framework="tensorflow" for 'tensorflow', and ml_framework="auto".
Returns
Function does not return a value. It is used for loading a saved transformer model into the R interface.
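A brief sketch, continuing the example above; the directory is again a placeholder.

# Sketch: load a previously saved transformer model back into the interface,
# letting the method choose the framework automatically.
bert_modeling$load_model(
  model_dir = "path/to/saved_model",
  ml_framework = "auto"
)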
Method save_model()
Method for saving a transformer model to disk. Relevant only for transformer models.
Usage
TextEmbeddingModel$save_model(model_dir, save_format = "default")
Arguments
model_dir: string containing the path to the relevant model directory.
save_format: Format for saving the model. For 'tensorflow'/'keras' models "h5" for HDF5. For 'pytorch' models "safetensors" for 'safetensors' or "pt" for 'pytorch' via pickle. Use "default" for the standard format. This is h5 for 'tensorflow'/'keras' models and safetensors for 'pytorch' models.
Returns
Function does not return a value. It is used for saving a transformer model to disk.
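A brief sketch of saving the underlying transformer model; the target directory is a placeholder.

# Sketch: write the transformer model to disk in the framework's default format
# (h5 for 'tensorflow'/'keras', safetensors for 'pytorch').
bert_modeling$save_model(
  model_dir = "path/to/saved_model",
  save_format = "default"
)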
Method encode()
Method for encoding words of raw texts into integers.
Usage
TextEmbeddingModel$encode(
  raw_text,
  token_encodings_only = FALSE,
  to_int = TRUE,
  trace = FALSE
)
Arguments
raw_text: vector containing the raw texts.
token_encodings_only: bool If TRUE, only the token encodings are returned. If FALSE, the complete encoding is returned, which is important for BERT models.
to_int: bool If TRUE, the integer ids of the tokens are returned. If FALSE, the tokens are returned. Argument only applies to transformer models and if token_encodings_only==TRUE.
trace: bool If TRUE, information about the progress is printed. FALSE if not requested.
Returns
list containing the integer sequences of the raw texts with
special tokens.
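A short sketch of encoding two example texts into integer sequences, assuming the model object created above.

# Sketch: encode raw texts into integer id sequences.
example_texts <- c("The weather is nice today.", "Text embeddings are useful.")
encodings <- bert_modeling$encode(
  raw_text = example_texts,
  token_encodings_only = TRUE,
  to_int = TRUE,
  trace = FALSE
)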
Method decode()
Method for decoding a sequence of integers into tokens.
Usage
TextEmbeddingModel$decode(int_seqence, to_token = FALSE)
Arguments
int_seqence: list containing the integer sequences which should be transformed into tokens or plain text.
to_token: bool If FALSE, plain text is returned. If TRUE, a sequence of tokens is returned. Argument only relevant if the model is based on a transformer.
Returns
list of token sequences
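A short sketch that reverses the encoding from the previous example.

# Sketch: turn the integer sequences back into plain text.
decoded_texts <- bert_modeling$decode(
  int_seqence = encodings,
  to_token = FALSE
)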
Method get_special_tokens()
Method for retrieving the special tokens of the model.
Usage
TextEmbeddingModel$get_special_tokens()
Returns
Returns a matrix containing the special tokens in the rows
and their type, token, and id in the columns.
Method embed()
Method for creating text embeddings from raw texts.
If you are using a GPU and run out of memory, reduce the batch size or restart R and switch to CPU-only mode via set_config_cpu_only.
Usage
TextEmbeddingModel$embed(
  raw_text = NULL,
  doc_id = NULL,
  batch_size = 8,
  trace = FALSE
)
Arguments
raw_text: vector containing the raw texts.
doc_id: vector containing the corresponding IDs for every text.
batch_size: int determining the maximal size of every batch.
trace: bool TRUE if information about the progress should be printed to the console.
Returns
Method returns an R6 object of class EmbeddedText. This object
contains the embeddings as a data.frame and information about the
model creating the embeddings.
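A short sketch of embedding the example texts; the document IDs are placeholders.

# Sketch: create text embeddings; the result is an EmbeddedText object.
embeddings <- bert_modeling$embed(
  raw_text = example_texts,
  doc_id = c("doc_1", "doc_2"),
  batch_size = 8,
  trace = FALSE
)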
Method fill_mask()
Method for predicting the tokens behind mask tokens.
Usage
TextEmbeddingModel$fill_mask(text, n_solutions = 5)
Arguments
text: string Text containing mask tokens.
n_solutions: int Number of estimated tokens for every mask.
Returns
Returns a list containing a data.frame for every
mask. The data.frame contains the solutions in the rows and reports
the score, token id, and token string in the columns.
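A short sketch of requesting the five most likely tokens for a masked position. The mask token shown ([MASK]) is the one used by BERT tokenizers; other models may use a different mask token, which can be looked up via get_special_tokens().

# Sketch: predict candidate tokens for the masked position.
solutions <- bert_modeling$fill_mask(
  text = "Paris is the [MASK] of France.",
  n_solutions = 5
)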
Method set_publication_info()
Method for setting the bibliographic information of the model.
Usage
TextEmbeddingModel$set_publication_info(type, authors, citation, url = NULL)
Arguments
type: string Type of information which should be changed/added. type="developer" and type="modifier" are possible.
authors: List of people.
citation: string Citation in free text.
url: string Corresponding URL if applicable.
Returns
Function does not return a value. It is used to set the private members for publication information of the model.
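A short sketch of recording developer information; the person, citation, and URL are placeholders, and passing utils::person objects in a list for the authors argument is an assumption.

# Sketch: store bibliographic information about the model's developers.
bert_modeling$set_publication_info(
  type = "developer",
  authors = list(person(given = "Jane", family = "Doe")),
  citation = "Doe, J. (2024). An example text embedding model.",
  url = "https://example.org"
)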
Method get_publication_info()
Method for getting the bibliographic information of the model.
Usage
TextEmbeddingModel$get_publication_info()
Returns
list of bibliographic information.
Method set_software_license()
Method for setting the license of the model.
Usage
TextEmbeddingModel$set_software_license(license = "GPL-3")
Arguments
license: string containing the abbreviation of the license or the license text.
Returns
Function does not return a value. It is used for setting the private member for the software license of the model.
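A short sketch of setting the licenses; both values below are simply the documented defaults, and the documentation license (see further down) is set analogously.

# Sketch: set the software license and the documentation license.
bert_modeling$set_software_license(license = "GPL-3")
bert_modeling$set_documentation_license(license = "CC BY-SA")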
Method get_software_license()
Method for requesting the license of the model.
Usage
TextEmbeddingModel$get_software_license()
Returns
string License of the model
Method set_documentation_license()
Method for setting the license of the model's documentation.
Usage
TextEmbeddingModel$set_documentation_license(license = "CC BY-SA")
Arguments
license: string containing the abbreviation of the license or the license text.
Returns
Function does not return a value. It is used to set the private member for the documentation license of the model.
Method get_documentation_license()
Method for getting the license of the model's documentation.
Usage
TextEmbeddingModel$get_documentation_license()
Returns
string License of the model's documentation.
Method set_model_description()
Method for setting a description of the model.
Usage
TextEmbeddingModel$set_model_description(
  eng = NULL,
  native = NULL,
  abstract_eng = NULL,
  abstract_native = NULL,
  keywords_eng = NULL,
  keywords_native = NULL
)
Arguments
eng: string A text describing the training of the model, its theoretical and empirical background, and the different output labels in English.
native: string A text describing the training of the model, its theoretical and empirical background, and the different output labels in the native language of the model.
abstract_eng: string A text providing a summary of the description in English.
abstract_native: string A text providing a summary of the description in the native language of the model.
keywords_eng: vector of keywords in English.
keywords_native: vector of keywords in the native language of the model.
Returns
Function does not return a value. It is used to set the private members for the description of the model.
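A short sketch of adding an English description; all texts and keywords are placeholders.

# Sketch: document the model for later users.
bert_modeling$set_model_description(
  eng = "This model was created for demonstration purposes only.",
  abstract_eng = "A small demonstration of a BERT-based text embedding model.",
  keywords_eng = c("text embedding", "BERT", "example")
)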
Method get_model_description()
Method for requesting the model description.
Usage
TextEmbeddingModel$get_model_description()
Returns
list with the description of the model in English
and the native language.
Method get_model_info()
Method for requesting the model information.
Usage
TextEmbeddingModel$get_model_info()
Returns
list of all relevant model information
Method get_package_versions()
Method for requesting a summary of the versions of the R and python packages used for creating the model.
Usage
TextEmbeddingModel$get_package_versions()
Returns
Returns a list containing the versions of the relevant
R and python packages.
Method get_basic_components()
Method for requesting the part of the interface's configuration that is necessary for all models.
Usage
TextEmbeddingModel$get_basic_components()
Returns
Returns a list.
Method get_bow_components()
Method for requesting the part of the interface's configuration that is necessary for bag-of-words models.
Usage
TextEmbeddingModel$get_bow_components()
Returns
Returns a list.
Method get_transformer_components()
Method for requesting the part of the interface's configuration that is necessary for transformer models.
Usage
TextEmbeddingModel$get_transformer_components()
Returns
Returns a list.
Method get_sustainability_data()
Method for requesting a log of tracked energy consumption during training and an estimate of the resulting CO2 equivalents in kg.
Usage
TextEmbeddingModel$get_sustainability_data()
Returns
Returns a matrix containing the tracked energy consumption,
CO2 equivalents in kg, information on the tracker used, and technical
information on the training infrastructure for every training run.
Method get_ml_framework()
Method for requesting the machine learning framework used for the model.
Usage
TextEmbeddingModel$get_ml_framework()
Returns
Returns a string describing the machine learning framework used
for the model.
Method clone()
The objects of this class are cloneable with this method.
Usage
TextEmbeddingModel$clone(deep = FALSE)
Arguments
deepWhether to make a deep clone.
See Also
Other Text Embedding:
EmbeddedText,
combine_embeddings()