predict.ETM {topicmodels.etm}    R Documentation

Predict functionality for an ETM object.

Description

Predict to which ETM topic a text belongs, or extract the words most emitted by each topic.
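With type = "topics" the result holds topic probabilities for each document, as the examples below show. The single topic a text "belongs to" can then be taken as its most probable topic. The following is a minimal base R sketch. It assumes the result is a matrix with one row per document and one column per topic. The scores matrix here is made up and does not come from predict().

```r
# Made-up topic-probability matrix standing in for the output of
# predict(model, newdata = dtm, type = "topics") (assumption: one row per
# document, one column per topic, each row summing to 1)
scores <- matrix(c(0.70, 0.20, 0.10,
                   0.05, 0.15, 0.80,
                   0.30, 0.40, 0.30),
                 nrow = 3, byrow = TRUE,
                 dimnames = list(c("doc1", "doc2", "doc3"), NULL))

# Assign each document to its most probable topic and keep that probability
best_topic <- apply(scores, MARGIN = 1, FUN = which.max)
best_prob  <- apply(scores, MARGIN = 1, FUN = max)
best_topic   # doc1 = 1, doc2 = 3, doc3 = 2
```

Keeping best_prob next to best_topic makes it easy to spot documents like doc3, whose probability mass is spread over several topics.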

Usage

## S3 method for class 'ETM'
predict(
  object,
  newdata,
  type = c("topics", "terms"),
  batch_size = nrow(newdata),
  normalize = TRUE,
  top_n = 10,
  ...
)

Arguments

object

an object of class ETM

newdata

a bag-of-words document-term matrix in dgCMatrix format. Only used if type = 'topics'.

type

a character string, either 'topics' or 'terms'. 'topics' predicts to which topic each document, encoded as a bag of words, belongs; 'terms' extracts the most emitted terms of each topic.

batch_size

integer giving the number of rows of newdata to process per batch, so that predictions are made chunkwise. Defaults to all rows of newdata, i.e. a single batch. Only used if type = 'topics'.

normalize

logical indicating whether to normalize the bag-of-words data. Defaults to TRUE, matching the default used when building the ETM model. Only used if type = 'topics'.

top_n

integer giving the number of most relevant words to extract for each topic. Only used if type = 'terms'.

...

not used
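The batch_size and normalize arguments can be illustrated with a small sketch in R using the Matrix package. The rows of a toy dgCMatrix are split into consecutive chunks of batch_size rows, and each row is divided by its total count. This sketch assumes that normalizing means scaling each document's counts to sum to 1. It is not the package's internal code, and normalize_rows() is a hypothetical helper.

```r
library(Matrix)

# Toy bag-of-words document-term matrix: 4 documents x 3 terms, stored as a dgCMatrix
dtm <- sparseMatrix(
  i = c(1, 1, 2, 3, 3, 4),
  j = c(1, 2, 2, 1, 3, 3),
  x = c(2, 2, 5, 1, 3, 4),
  dims = c(4, 3),
  dimnames = list(paste0("doc", 1:4), c("apple", "bread", "cheese"))
)

# Hypothetical helper (assumption): scale each document's counts so they sum to 1
normalize_rows <- function(x) {
  totals <- Matrix::rowSums(x)
  out <- Diagonal(x = 1 / totals) %*% x
  dimnames(out) <- dimnames(x)
  out
}

# Split the rows into consecutive chunks of at most batch_size rows,
# normalizing each chunk before it would be scored
batch_size <- 3
starts <- seq(1, nrow(dtm), by = batch_size)
batches <- lapply(starts, function(s) {
  rows <- s:min(s + batch_size - 1, nrow(dtm))
  normalize_rows(dtm[rows, , drop = FALSE])
})

sapply(batches, nrow)           # rows per batch: 3 and 1
Matrix::rowSums(batches[[1]])   # each normalized row sums to 1
```

Smaller batches trade speed for lower peak memory. The last batch simply holds whatever rows remain.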

Value

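Returns for type = 'topics' a matrix of topic probabilities with one row per document in newdata and one column per topic; for type = 'terms', a list with one element per topic holding that topic's top_n most emitted terms.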

See Also

ETM

Examples


library(torch)
library(topicmodels.etm)
path  <- system.file(package = "topicmodels.etm", "example", "example_etm.ckpt")
model <- torch_load(path)

# Get most emitted words for each topic
terminology  <- predict(model, type = "terms", top_n = 5)
terminology

# Get topics probabilities for each document
path   <- system.file(package = "topicmodels.etm", "example", "example_dtm.rds")
dtm    <- readRDS(path)
dtm    <- head(dtm, n = 5)
scores <- predict(model, newdata = dtm, type = "topics")
scores
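The idea behind type = "terms" can be shown without a fitted model. Given a topic-term probability matrix, the top_n terms of a topic are the columns with the highest probabilities in that topic's row. The following is a minimal base R sketch. The beta matrix is made up, top_terms() is a hypothetical helper, and the value returned by predict() may be structured differently.

```r
# Made-up topic-term probability matrix: 2 topics x 5 terms, rows sum to 1
beta <- matrix(c(0.40, 0.30, 0.15, 0.10, 0.05,
                 0.05, 0.10, 0.20, 0.25, 0.40),
               nrow = 2, byrow = TRUE,
               dimnames = list(NULL, c("apple", "bread", "cheese", "wine", "olive")))

# Hypothetical helper: the top_n most probable terms of every topic
top_terms <- function(beta, top_n = 3) {
  terms <- lapply(seq_len(nrow(beta)), function(k) {
    # Order the terms of topic k from most to least probable, keep the first top_n
    ord <- order(beta[k, ], decreasing = TRUE)
    colnames(beta)[head(ord, top_n)]
  })
  names(terms) <- paste0("topic", seq_len(nrow(beta)))
  terms
}

res <- top_terms(beta, top_n = 3)
res$topic1   # "apple" "bread" "cheese"
res$topic2   # "olive" "wine" "cheese"
```

A term can rank highly in more than one topic ("cheese" above), so the top_n lists of different topics may overlap.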


[Package topicmodels.etm version 0.1.0 Index]