predict.tidylda {tidylda} | R Documentation |
Get predictions from a Latent Dirichlet Allocation model
Description
Obtains predictions of topics for new documents from a fitted LDA model
Usage
## S3 method for class 'tidylda'
predict(
object,
new_data,
type = c("prob", "class", "distribution"),
method = c("gibbs", "dot"),
iterations = NULL,
burnin = -1,
no_common_tokens = c("default", "zero", "uniform"),
times = 100,
threads = 1,
verbose = TRUE,
...
)
Arguments
object |
a fitted object of class tidylda |
new_data |
a DTM or TCM of class |
type |
one of "prob", "class", or "distribution". Defaults to "prob". |
method |
one of either "gibbs" or "dot". If "gibbs" Gibbs sampling is used
and |
iterations |
If |
burnin |
If |
no_common_tokens |
behavior when encountering documents that have no tokens
in common with the model. Options are "default", "zero", or "uniform". See Details. |
times |
Integer, number of samples to draw if type = "distribution". |
threads |
Number of parallel threads, defaults to 1. Note: currently ignored; only single-threaded prediction is implemented. |
verbose |
Logical. Should a progress bar be printed to the console?
Only active if |
... |
Additional arguments, currently unused |
Details
If predict.tidylda encounters documents that have no tokens in common with the model in object, it will engage in one of three behaviors based on the setting of no_common_tokens.
default (the default) sets all topics to 0 for offending documents. This enables continued computations downstream in a way that NA would not. However, if no_common_tokens == "default", then predict.tidylda will emit a warning for every such document it encounters.
zero has the same behavior as default, but it emits a message instead of a warning.
uniform sets every topic to 1/k for offending documents, where k is the number of topics. It does not emit a warning or message.
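The three behaviors described above can be sketched in base R. This is only an illustration of the documented outcomes, not the package's internal code: `fallback_row()` and the wording of its warning and message are hypothetical names and text invented for this sketch.

```r
# Hypothetical helper, not part of tidylda: builds the topic row that the
# text above describes for a document sharing no tokens with the model.
fallback_row <- function(k, no_common_tokens = c("default", "zero", "uniform")) {
  no_common_tokens <- match.arg(no_common_tokens)
  if (no_common_tokens == "uniform") {
    return(rep(1 / k, k)) # equal weight on every topic, no warning or message
  }
  if (no_common_tokens == "default") {
    warning("document has no tokens in common with the model")
  } else {
    message("document has no tokens in common with the model")
  }
  rep(0, k) # all-zero row keeps downstream arithmetic free of NA
}

k <- 5
uniform_row <- fallback_row(k, "uniform")
zero_row <- suppressMessages(fallback_row(k, "zero"))
default_row <- suppressWarnings(fallback_row(k, "default"))

sum(uniform_row) # the "uniform" row sums to 1
sum(zero_row)    # the "zero" and "default" rows sum to 0
```

One practical consequence: under "default" and "zero", rows for offending documents do not sum to 1, so code that assumes every row is a probability distribution may need to filter or handle them.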
Value
type gives different outputs depending on whether the user selects "prob", "class", or "distribution". If "prob", the default, returns a "theta" matrix with one row per document and one column per topic. If "class", returns a vector with the topic index of the most likely topic in each document. If "distribution", returns a tibble with one row per parameter per sample. The number of samples is set by the times argument.
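To make the relationship between the "prob" and "class" outputs concrete, here is a small base R sketch. The matrix values and dimnames are made up, and taking the row-wise maximum is an assumption based on the description "most likely topic in each document" rather than on the package source.

```r
# Toy "theta" matrix: 3 documents (rows) by 4 topics (columns).
# Values are invented for illustration; each row sums to 1.
theta <- matrix(
  c(0.70, 0.10, 0.10, 0.10,
    0.05, 0.15, 0.60, 0.20,
    0.25, 0.25, 0.10, 0.40),
  nrow = 3, byrow = TRUE,
  dimnames = list(paste0("doc_", 1:3), paste0("topic_", 1:4))
)

# Index of the highest-probability topic for each document,
# analogous to what type = "class" is described as returning.
top_topic <- apply(theta, 1, which.max)
top_topic
```

Note that `which.max()` returns the first maximum when there is a tie; how the package itself breaks ties is not stated in this page.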
Examples
# load some data
data(nih_sample_dtm)
# fit a model
set.seed(12345)
m <- tidylda(
  data = nih_sample_dtm[1:20, ], k = 5,
  iterations = 200, burnin = 175
)
str(m)
# predict on held-out documents using gibbs sampling "fold in"
p1 <- predict(m, nih_sample_dtm[21:100, ],
  method = "gibbs",
  iterations = 200, burnin = 175
)
# predict on held-out documents using the dot product
p2 <- predict(m, nih_sample_dtm[21:100, ], method = "dot")
# compare the methods
barplot(rbind(p1[1, ], p2[1, ]), beside = TRUE, col = c("red", "blue"))
# predict classes on held-out documents
p3 <- predict(m, nih_sample_dtm[21:100, ],
  method = "gibbs",
  type = "class",
  iterations = 100, burnin = 75
)
# predict distribution on held-out documents
p4 <- predict(m, nih_sample_dtm[21:100, ],
  method = "gibbs",
  type = "distribution",
  iterations = 100, burnin = 75,
  times = 10
)