grafzahl {grafzahl}    R Documentation
Fine-tune a pretrained Transformer model for texts
Description
Fine-tune (or train) a pretrained Transformer model on your labelled training data x and y. The prediction task can be classification (if regression is FALSE, the default) or regression (if regression is TRUE).
Usage
grafzahl(
  x,
  y = NULL,
  model_name = "xlm-roberta-base",
  regression = FALSE,
  output_dir,
  cuda = detect_cuda(),
  num_train_epochs = 4,
  train_size = 0.8,
  args = NULL,
  cleanup = TRUE,
  model_type = NULL,
  manual_seed = floor(runif(1, min = 1, max = 721831)),
  verbose = TRUE
)

## Default S3 method:
grafzahl(
  x,
  y = NULL,
  model_name = "xlm-roberta-base",
  regression = FALSE,
  output_dir,
  cuda = detect_cuda(),
  num_train_epochs = 4,
  train_size = 0.8,
  args = NULL,
  cleanup = TRUE,
  model_type = NULL,
  manual_seed = floor(runif(1, min = 1, max = 721831)),
  verbose = TRUE
)

## S3 method for class 'corpus'
grafzahl(
  x,
  y = NULL,
  model_name = "xlm-roberta-base",
  regression = FALSE,
  output_dir,
  cuda = detect_cuda(),
  num_train_epochs = 4,
  train_size = 0.8,
  args = NULL,
  cleanup = TRUE,
  model_type = NULL,
  manual_seed = floor(runif(1, min = 1, max = 721831)),
  verbose = TRUE
)
textmodel_transformer(...)
## S3 method for class 'character'
grafzahl(
  x,
  y = NULL,
  model_name = "xlm-roberta-base",
  regression = FALSE,
  output_dir,
  cuda = detect_cuda(),
  num_train_epochs = 4,
  train_size = 0.8,
  args = NULL,
  cleanup = TRUE,
  model_type = NULL,
  manual_seed = floor(runif(1, min = 1, max = 721831)),
  verbose = TRUE
)
Arguments
x: the corpus or character vector of texts on which the model will be trained. Depending on train_size, some of the texts will be held out for validation.

y: training labels. It can either be a single string indicating which docvars of the corpus contains the training labels, or a vector of training labels in either character or factor.

model_name: string indicating either 1) the model name on the Hugging Face website, or 2) the local path of the model.

regression: logical, if TRUE, the prediction task is regression; otherwise classification.

output_dir: string, location of the output model. If missing, the model will be stored in a temporary directory. Important: if this directory exists, it will be overwritten.

cuda: logical, whether to use CUDA; defaults to detect_cuda().

num_train_epochs: numeric, number of training epochs.

train_size: numeric, proportion of data in x and y to be used for training; the rest is used for validation.

args: list, additional parameters to be used in the underlying simpletransformers model.

cleanup: logical, if TRUE, the temporary files generated during training will be removed when training is done.

model_type: a string indicating the model_type of the input model. If NULL, it will be inferred from model_name.

manual_seed: numeric, random seed.

verbose: logical, if TRUE, debug messages will be displayed.

...: parameters passed to grafzahl().
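To illustrate how these arguments fit together, here is a hedged sketch of a regression call; the corpus my_corpus and its numeric docvar "score" are hypothetical, not objects shipped with the package:

```r
## Not run: fine-tune for a regression task, keeping the model in a chosen directory
## `my_corpus` and its numeric docvar "score" are hypothetical
model <- grafzahl(x = my_corpus,
                  y = "score",              # numeric docvar used as the regression target
                  regression = TRUE,        # switch from classification to regression
                  model_name = "xlm-roberta-base",
                  output_dir = "my_model",  # will be overwritten if it already exists
                  train_size = 0.8,         # 20% of the data is held out for validation
                  manual_seed = 42)         # fix the seed for reproducibility
```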
Value
a grafzahl S3 object with the following items:

call: original function call

input_data: input data for the underlying Python function

output_dir: location of the output model

model_type: model type

model_name: model name

regression: whether or not it is a regression model

levels: factor levels of y

manual_seed: random seed

meta: metadata about the current session
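The items listed above are plain list elements of the returned object and can be inspected with $; a minimal sketch, assuming model is a fitted grafzahl object:

```r
## Inspect a fitted model object (the object `model` is hypothetical)
model$model_name   # name of the pretrained model that was fine-tuned
model$output_dir   # where the fine-tuned model was written
model$levels       # factor levels of y (classification models)
model$manual_seed  # random seed used for this run
```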
Examples
if (detect_conda() && interactive()) {
  library(quanteda)
  set.seed(20190721)
  ## Using the default cross validation method
  model1 <- grafzahl(unciviltweets, model_type = "bertweet", model_name = "vinai/bertweet-base")
  predict(model1)

  ## Using LIME
  input <- corpus(ecosent, text_field = "headline")
  training_corpus <- corpus_subset(input, !gold)
  model2 <- grafzahl(x = training_corpus,
                     y = "value",
                     model_name = "GroNLP/bert-base-dutch-cased")
  test_corpus <- corpus_subset(input, gold)
  predicted_sentiment <- predict(model2, test_corpus)
  require(lime)
  sentences <- c("Dijsselbloem pessimistisch over snelle stappen Grieken",
                 "Aandelenbeurzen zetten koersopmars voort")
  explainer <- lime(training_corpus, model2)
  explanations <- explain(sentences, explainer, n_labels = 1,
                          n_features = 2)
  plot_text_explanations(explanations)
}
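Prediction on further unseen texts follows the same predict() idiom as in the example above; a short sketch, assuming the fitted model2 from the example and some hypothetical new headlines:

```r
## Not run: predict labels for new, unlabelled texts (input vector is hypothetical)
new_texts <- c("Beurs sluit hoger", "Werkloosheid loopt op")
predict(model2, new_texts)  # returns one predicted label per input text
```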