grafzahl {grafzahl}    R Documentation

Fine tune a pretrained Transformer model for texts

Description

Fine tune (or train) a pretrained Transformer model on your labelled training data x and y. The prediction task is classification if regression is FALSE (the default) or regression if regression is TRUE.

Usage

grafzahl(
  x,
  y = NULL,
  model_name = "xlm-roberta-base",
  regression = FALSE,
  output_dir,
  cuda = detect_cuda(),
  num_train_epochs = 4,
  train_size = 0.8,
  args = NULL,
  cleanup = TRUE,
  model_type = NULL,
  manual_seed = floor(runif(1, min = 1, max = 721831)),
  verbose = TRUE
)

## Default S3 method:
grafzahl(
  x,
  y = NULL,
  model_name = "xlm-roberta-base",
  regression = FALSE,
  output_dir,
  cuda = detect_cuda(),
  num_train_epochs = 4,
  train_size = 0.8,
  args = NULL,
  cleanup = TRUE,
  model_type = NULL,
  manual_seed = floor(runif(1, min = 1, max = 721831)),
  verbose = TRUE
)

## S3 method for class 'corpus'
grafzahl(
  x,
  y = NULL,
  model_name = "xlm-roberta-base",
  regression = FALSE,
  output_dir,
  cuda = detect_cuda(),
  num_train_epochs = 4,
  train_size = 0.8,
  args = NULL,
  cleanup = TRUE,
  model_type = NULL,
  manual_seed = floor(runif(1, min = 1, max = 721831)),
  verbose = TRUE
)

textmodel_transformer(...)

## S3 method for class 'character'
grafzahl(
  x,
  y = NULL,
  model_name = "xlmroberta",
  regression = FALSE,
  output_dir,
  cuda = detect_cuda(),
  num_train_epochs = 4,
  train_size = 0.8,
  args = NULL,
  cleanup = TRUE,
  model_type = NULL,
  manual_seed = floor(runif(1, min = 1, max = 721831)),
  verbose = TRUE
)

Arguments

x

the corpus or character vector of texts on which the model will be trained. Depending on train_size, some of the texts will be held out for cross-validation.

y

training labels. This can be a single string naming the docvar of the corpus that contains the training labels; a character or factor vector of training labels; or NULL if the corpus has exactly one docvar and that docvar contains the training labels. If x is a character vector, y must be a vector of the same length.
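To make the three accepted forms of y concrete, here is a minimal sketch. The helper resolve_labels() is hypothetical and is not part of grafzahl; it uses a plain data.frame as a stand-in for the corpus docvars, and it is not the package's internal logic.

```r
# Hypothetical helper, NOT part of grafzahl: resolve the three accepted forms
# of `y`, using a plain data.frame as a stand-in for the corpus docvars.
resolve_labels <- function(docvars, y = NULL) {
  if (is.null(y)) {
    # NULL: docvars must have exactly one column, and it holds the labels
    if (ncol(docvars) != 1L) {
      stop("y is NULL but docvars does not have exactly one column")
    }
    return(docvars[[1L]])
  }
  if (is.character(y) && length(y) == 1L && y %in% names(docvars)) {
    # a single string naming the docvar that holds the labels
    return(docvars[[y]])
  }
  # otherwise y is itself a vector of labels, one per document
  if (length(y) != nrow(docvars)) {
    stop("y must contain one label per document")
  }
  y
}

dv <- data.frame(value = c("pos", "neg", "pos"), stringsAsFactors = FALSE)
resolve_labels(dv)            # NULL: the single docvar is used
resolve_labels(dv, "value")   # a docvar selected by name
resolve_labels(dv, factor(c("a", "b", "a")))  # labels supplied directly
```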

model_name

string, either 1) the name of a model on the Hugging Face website or 2) the local path of a model.
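The two kinds of model_name can be told apart by whether the string points to an existing local directory. The following is only a sketch under that assumption. model_source() is a hypothetical helper, not part of grafzahl, and the precedence it gives local directories is assumed rather than taken from the package.

```r
# Hypothetical helper, NOT part of grafzahl: classify a model_name, assuming
# an existing local directory takes precedence over a Hugging Face name.
model_source <- function(model_name) {
  if (dir.exists(model_name)) "local" else "huggingface"
}

model_source(tempdir())              # "local"
model_source("vinai/bertweet-base")  # "huggingface", unless such a directory exists locally
```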

regression

logical, if TRUE, the task is regression; otherwise it is classification.

output_dir

string, location of the output model. If missing, the model will be stored in a temporary directory. Important: if this directory already exists, it will be overwritten.
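Because an existing output_dir is overwritten, a cautious caller may want to pass a path that is guaranteed to be new. This is a minimal sketch. fresh_output_dir() is a hypothetical helper, not part of grafzahl, and relies only on base R's tempfile() and dir.exists().

```r
# Hypothetical helper, NOT part of grafzahl: return a path for output_dir that
# does not exist yet, and refuse a path that would be overwritten.
fresh_output_dir <- function(path = tempfile("grafzahl_model_")) {
  if (dir.exists(path)) {
    stop("output_dir already exists and would be overwritten: ", path)
  }
  path
}

out <- fresh_output_dir()
## model <- grafzahl(x, y, output_dir = out)
```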

cuda

logical, whether to use CUDA; defaults to detect_cuda().

num_train_epochs

numeric. If train_size is not exactly 1.0, training uses early stopping, and the maximum number of epochs tried is this number times 5 (i.e. 4 * 5 = 20 by default). If train_size is exactly 1.0, the model is trained for exactly this number of epochs.
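The epoch rule above reduces to a short calculation. max_epochs() is a hypothetical helper that only restates the documented arithmetic; it is not a function exported by grafzahl.

```r
# Hypothetical helper, NOT part of grafzahl: the epoch ceiling described for
# num_train_epochs.
max_epochs <- function(num_train_epochs = 4, train_size = 0.8) {
  if (train_size == 1) {
    num_train_epochs        # no held-out data: exactly this many epochs
  } else {
    num_train_epochs * 5    # early stopping: at most five times as many
  }
}

max_epochs()                # 20
max_epochs(train_size = 1)  # 4
```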

train_size

numeric, proportion of the data in x and y that is actually used for training. The rest is used for cross-validation.
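As an illustration of what train_size controls, the sketch below splits document indices into a training part and a held-out part. split_indices() is hypothetical. grafzahl's actual splitting procedure (random or stratified, and how it rounds) is not documented here, so this is only an assumed simple random split.

```r
# Hypothetical sketch, NOT grafzahl's internal code: a simple random split of
# n documents into a training part (proportion train_size) and a held-out part.
split_indices <- function(n, train_size = 0.8) {
  n_train <- floor(n * train_size)
  train <- sort(sample.int(n, n_train))
  list(train = train, validation = setdiff(seq_len(n), train))
}

set.seed(1)
parts <- split_indices(10, 0.8)
lengths(parts)  # train = 8, validation = 2
```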

args

list, additional parameters passed to the underlying Simple Transformers library.

cleanup

logical, if TRUE, the runs directory generated during training will be removed once training is done.

model_type

a string indicating the model_type of the input model. If NULL, it will be inferred from model_name. Supported model types are listed in supported_model_types.

manual_seed

numeric, random seed
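The default manual_seed is drawn with runif(), so it depends on R's random number generator state. This is why the Examples call set.seed() before grafzahl(). The sketch below reuses the default expression from the Usage section. The wrapper name default_seed() is hypothetical.

```r
# The default manual_seed expression from the Usage section, wrapped in a
# hypothetical helper: an integer-valued number between 1 and 721830.
default_seed <- function() floor(runif(1, min = 1, max = 721831))

set.seed(20190721)
s1 <- default_seed()
set.seed(20190721)
s2 <- default_seed()
identical(s1, s2)  # TRUE: set.seed() makes the default seed reproducible
```

Passing manual_seed explicitly is the other way to make a run reproducible.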

verbose

logical, if TRUE, debug messages will be displayed

...

parameters passed to grafzahl()

Value

a grafzahl S3 object with the following items:

call

original function call

input_data

input data passed to the underlying Python function

output_dir

location of the output model

model_type

model type

model_name

model name

regression

whether or not it is a regression model

levels

factor levels of y

manual_seed

random seed

meta

metadata about the current session

See Also

predict.grafzahl()

Examples

if (detect_conda() && interactive()) {
library(quanteda)
set.seed(20190721)
## Using the default cross validation method
model1 <- grafzahl(unciviltweets, model_type = "bertweet", model_name = "vinai/bertweet-base")
predict(model1)

## Using LIME
input <- corpus(ecosent, text_field = "headline")
training_corpus <- corpus_subset(input, !gold)
model2 <- grafzahl(x = training_corpus,
                   y = "value",
                   model_name = "GroNLP/bert-base-dutch-cased")
test_corpus <- corpus_subset(input, gold)
predicted_sentiment <- predict(model2, test_corpus)
library(lime)
sentences <- c("Dijsselbloem pessimistisch over snelle stappen Grieken",
               "Aandelenbeurzen zetten koersopmars voort")
explainer <- lime(training_corpus, model2)
explanations <- explain(sentences, explainer, n_labels = 1,
                        n_features = 2)
plot_text_explanations(explanations)
}

[Package grafzahl version 0.0.11 Index]