predict.textspace {ruimtehol} | R Documentation
Predict using a Starspace model
Description
The prediction functionality allows you to retrieve the following types of elements from a Starspace model:
- generic: get general Starspace predictions in detail
- labels: get the similarity of your text to all the labels of the Starspace model
- embedding: document embeddings of your text (shorthand for starspace_embedding)
- knn: k-nearest neighbouring (most similar) elements of the model dictionary compared to your input text (shorthand for starspace_knn)
Usage
## S3 method for class 'textspace'
predict(
object,
newdata,
type = c("generic", "labels", "knn", "embedding"),
k = 5L,
sep = " ",
basedoc,
...
)
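The type argument lists all allowed values in its default, the usual R idiom for a choice argument: when type is not supplied, the first value, 'generic', is used. A minimal base R sketch of that idiom, assuming the method resolves type with match.arg() (this is an assumption, and resolve_type() is a hypothetical helper, not part of ruimtehol):

```r
# Hypothetical helper mirroring the 'type' default in the usage above.
# match.arg() returns the first choice when the argument is left at its
# default; otherwise it checks, with partial matching, that the supplied
# value is one of the allowed choices.
resolve_type <- function(type = c("generic", "labels", "knn", "embedding")) {
  match.arg(type)
}

resolve_type()                # "generic"
resolve_type("knn")           # "knn"
resolve_type("emb")           # "embedding" (partial match)
try(resolve_type("cosine"))   # signals an error: not one of the allowed choices
```

Under that assumption, a misspelled type fails early with an informative error instead of silently falling back to another prediction type.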
Arguments
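object |
an object of class textspace, such as the models returned by embed_tagspace or embed_sentencespace in the Examples |
newdata |
a data frame with columns doc_id and text, or a character vector of texts (see the Examples) |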
type |
character string: either 'generic', 'labels', 'embedding', 'knn'. Defaults to 'generic' |
k |
integer with the number of predictions to make. Defaults to 5. Only used in case type is set to 'generic' or 'knn' |
sep |
character string used to split the text in newdata into terms. Defaults to " " |
basedoc |
optional, either a character vector of possible elements to predict or the path to a file in labelDoc format, containing basedocs, i.e. the set of possible things to predict, if these differ from the ones in the training data. Only used in case type is set to 'generic' |
... |
not used |
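The Examples below prepare newdata by splitting each question on non-word characters, dropping empty strings and re-joining the tokens with single spaces, so that splitting on the default sep = " " recovers exactly those tokens. A minimal base R sketch of that preparation as a reusable step; make_newdata() and the sample strings are hypothetical and not part of ruimtehol:

```r
# Hypothetical helper: build a newdata data frame with doc_id and text
# columns, using the same tokenisation as the Examples section.
make_newdata <- function(x, doc_id = seq_along(x), sep = " ") {
  tokens <- strsplit(x, "\\W")                            # split on non-word characters
  tokens <- lapply(tokens, function(tok) tok[tok != ""])  # drop empty strings
  text <- vapply(tokens, paste, character(1), collapse = sep)
  data.frame(doc_id = doc_id, text = text, stringsAsFactors = FALSE)
}

nd <- make_newdata(c("What are the figures?", "  Double   spaces, here!"),
                   doc_id = c("q1", "q2"))
nd$text   # c("What are the figures", "Double spaces here")

# Splitting on sep now yields clean tokens without empty strings.
strsplit(nd$text, " ", fixed = TRUE)[[2]]   # c("Double", "spaces", "here")
```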
Value
The following is returned, depending on the argument type:
- In case type is set to 'generic': a list, one element for each row or element in newdata. Each list element is a list with elements
  - doc_id: the identifier of the text
  - text: the character string with the text
  - prediction: data.frame with columns label, label_starspace and similarity indicating the predicted label and the similarity of the text to the label
  - terms: a list with elements basedoc_index and basedoc_terms indicating the position in basedoc and the terms which are part of the dictionary which are used to find the similarity
- In case type is set to 'labels': a data.frame, namely newdata with one column added for each label in the Starspace model. These columns contain the similarities of the text to each label, computed with embedding_similarity using either cosine or dot product, whichever was used during model training.
- In case type is set to 'embedding': a matrix of document embeddings, one embedding for each text in newdata, as returned by starspace_embedding. The rownames of this matrix are set to the document identifiers of newdata.
- In case type is set to 'knn': a list of data.frames, one for each row or element in newdata. Each of these data frames contains the columns doc_id, label, similarity and rank, indicating the k-nearest neighbouring (most similar) elements of the model dictionary compared to your input text, as returned by starspace_knn.
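Conceptually, the 'labels' and 'knn' outputs both compare a document embedding with other embeddings (labels, or elements of the model dictionary) and, for 'knn', keep the k most similar ones. The base R sketch below illustrates that idea with made-up 3-dimensional embeddings; all names and numbers are hypothetical, and this is not ruimtehol's implementation. It computes dot-product and cosine similarities (which of the two applies depends on how the model was trained) and ranks the candidates into a data frame with the same columns as one element of the 'knn' output.

```r
# Hypothetical embeddings: one document (1 x 3) and three labels (3 x 3).
doc <- matrix(c(0.2, 0.9, 0.1), nrow = 1, dimnames = list("doc1", NULL))
labels <- rbind(defence   = c(0.1,  1.0, 0.0),
                education = c(0.9,  0.1, 0.3),
                justice   = c(0.0, -0.5, 1.0))

# Similarity of every row of x with every row of y.
dot_sim <- function(x, y) x %*% t(y)
cosine_sim <- function(x, y) {
  x <- x / sqrt(rowSums(x^2))   # scale each row to unit length
  y <- y / sqrt(rowSums(y^2))
  x %*% t(y)
}

# One column per label, comparable to the columns added by type = "labels".
dot_sim(doc, labels)
cosine_sim(doc, labels)

# The k most similar candidates, shaped like one element of type = "knn".
knn_sketch <- function(doc, candidates, k = 2, sim_fun = cosine_sim) {
  sim <- sim_fun(doc, candidates)[1, ]
  top <- order(sim, decreasing = TRUE)[seq_len(min(k, length(sim)))]
  data.frame(doc_id = rownames(doc), label = names(sim)[top],
             similarity = unname(sim[top]), rank = seq_along(top),
             stringsAsFactors = FALSE)
}
knn_sketch(doc, labels, k = 2)   # defence first, then education
```

Normalising the rows first makes cosine similarity insensitive to the length of the embeddings, whereas the dot product also rewards longer vectors; with these toy numbers both measures give the same ranking.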
Examples
data(dekamer, package = "ruimtehol")
## tokenise the questions on non-word characters and re-join the tokens with single spaces
dekamer$text <- strsplit(dekamer$question, "\\W")
dekamer$text <- lapply(dekamer$text, FUN = function(x) x[x != ""])
dekamer$text <- sapply(dekamer$text,
                       FUN = function(x) paste(x, collapse = " "))
## set the seed before sampling so that both the train/test split and the model are reproducible
set.seed(123456789)
idx <- sample(nrow(dekamer), size = round(nrow(dekamer) * 0.9))
traindata <- dekamer[idx, ]
testdata <- dekamer[-idx, ]
model <- embed_tagspace(x = traindata$text,
y = traindata$question_theme_main,
early_stopping = 0.8,
dim = 10, minCount = 5)
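## type = "generic" (default): detailed predictions for each document
scores <- predict(model, testdata)
## type = "labels": similarity of each document to every label
scores <- predict(model, testdata, type = "labels")
str(scores)
## type = "embedding": one document embedding per row of newdata
emb <- predict(model, testdata[, c("doc_id", "text")], type = "embedding")
## type = "knn": the 3 most similar dictionary elements for the first 5 documents
knn <- predict(model, testdata[1:5, c("doc_id", "text")], type = "knn", k = 3)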
## Not run:
library(udpipe)
data(dekamer, package = "ruimtehol")
dekamer <- subset(dekamer, question_theme_main == "DEFENSIEBELEID")
x <- udpipe(dekamer$question, "dutch", tagger = "none", parser = "none", trace = 100)
x <- x[, c("doc_id", "sentence_id", "sentence", "token")]
set.seed(123456789)
model <- embed_sentencespace(x, dim = 15, epoch = 5, minCount = 5)
## generic predictions, using the unique sentences as the set of candidates (basedoc)
scores <- predict(model, "Wat zijn de cijfers qua doorstroming van 2016?",
                  basedoc = unique(x$sentence), k = 3)
str(scores)
## clean up for CRAN: remove the udpipe model file downloaded by udpipe()
file.remove(list.files(pattern = "\\.udpipe$"))
## End(Not run)