eval_sbo_predictor {sbo}    R Documentation

Evaluate Stupid Back-off next-word predictions

Description

Evaluate the next-word predictions of a Stupid Back-off N-gram model on a test corpus.

Usage

eval_sbo_predictor(model, test, L = attr(model, "L"))

Arguments

model

a sbo_predictor object.

test

a character vector. A single prediction is performed on each entry of this vector (see Details).

L

Maximum number of predictions for each input sentence (the maximum allowed is attr(model, "L")).

Details

This function provides information on the quality of Stupid Back-off model predictions, such as next-word prediction accuracy or the word-rank distribution of correct predictions, by direct evaluation against a test corpus. For a reasonable estimate of prediction accuracy, the entries of the test vector should be uncorrelated documents (e.g. separate tweets, as in the twitter_test example dataset).

In more detail, eval_sbo_predictor performs the following operations:

  1. Sample a single sentence from each entry of the character vector test.

  2. Sample a single $N$-gram from each sentence obtained in the previous step.

  3. Predict next words from the $(N-1)$-gram prefix.

  4. Return all predictions, together with the true word completions.
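The sampling in steps 1 and 2 can be illustrated with a minimal base-R sketch. This is not the package's implementation: the helper sample_ngram is hypothetical, the sentence splitting and whitespace tokenization are deliberately naive, and the "<BOS>" padding token is an assumption ("<EOS>" is the end-of-sentence token that appears in the Examples).

```r
# Hypothetical helper, not part of sbo: draw one (N-1)-gram prefix and its
# true completion from a single document.
sample_ngram <- function(doc, N = 3) {
  # Step 1: split the document into sentences (naive rule) and sample one.
  sentences <- unlist(strsplit(doc, "[.?!]+\\s*"))
  sentences <- sentences[nzchar(trimws(sentences))]
  if (length(sentences) == 0) return(NULL)
  sentence <- sentences[sample.int(length(sentences), 1)]

  # Tokenize on whitespace. Pad with N - 1 begin markers and one end marker
  # (assumed convention), so that every word and the sentence end can be
  # a prediction target.
  words <- strsplit(trimws(tolower(sentence)), "\\s+")[[1]]
  tokens <- c(rep("<BOS>", N - 1), words, "<EOS>")

  # Step 2: sample the position of the last token of an N-gram.
  end <- (N - 1) + sample.int(length(tokens) - N + 1, 1)
  ngram <- tokens[(end - N + 1):end]

  # Step 3 would pass the prefix to the predictor; here only the prefix and
  # the true completion (used in step 4) are returned.
  list(input = paste(ngram[-N], collapse = " "), true = ngram[N])
}

set.seed(1)
docs <- c("I love this. It is great!", "See you tomorrow")
samples <- lapply(docs, sample_ngram, N = 3)
```

Each element of samples holds one prefix and one true completion, which mirrors the "single prediction per entry" behaviour described for the test argument.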

Value

A tibble containing the input $(N-1)$-grams, the true completions, the predicted completions, and a column indicating whether one of the predictions was correct.
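As a rough illustration of how the correctness indicator relates to the other columns, the base-R sketch below uses made-up data. The object names preds and position are assumptions chosen for the illustration; only true and correct appear as column names in the Examples.

```r
# Mock evaluation data (invented for illustration): each row of `preds`
# holds L = 3 ranked predictions for one input.
preds <- rbind(
  c("you", "the", "<EOS>"),
  c("a", "to", "be"),
  c("<EOS>", "day", "time")
)
true <- c("the", "go", "<EOS>")

# A row counts as correct when the true completion is among its L predictions.
correct <- vapply(seq_along(true),
                  function(i) true[i] %in% preds[i, ],
                  logical(1))

# Position of the true word within the ranked predictions (NA if absent).
position <- vapply(seq_along(true),
                   function(i) match(true[i], preds[i, ]),
                   integer(1))

accuracy <- mean(correct)                                 # top-L accuracy
top1_accuracy <- mean(!is.na(position) & position == 1)   # top-1 accuracy
```

With these mock rows the first and third predictions are hits, so the top-3 accuracy is 2/3 and the top-1 accuracy is 1/3.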

Author(s)

Valerio Gherardi

Examples


# Evaluating next-word predictions from a Stupid Back-off N-gram model
if (suppressMessages(require(dplyr) && require(ggplot2))) {
        p <- sbo_predictor(twitter_predtable)
        set.seed(840) # Set seed for reproducibility
        test <- sample(twitter_test, 500)
        eval <- eval_sbo_predictor(p, test)
        
        ## Compute three-word accuracies
        ## (print() is needed: values inside the if block are not auto-printed)
        eval %>% summarise(accuracy = sum(correct) / n()) %>% print() # Overall accuracy
        eval %>% # Accuracy for in-sentence predictions
                filter(true != "<EOS>") %>%
                summarise(accuracy = sum(correct) / n()) %>%
                print()
        
        ## Make histogram of word-rank distribution for correct predictions
        dict <- attr(twitter_predtable, "dict")
        eval %>%
                filter(correct, true != "<EOS>") %>%
                transmute(rank = match(true, table = dict)) %>%
                ggplot(aes(x = rank)) + geom_histogram(binwidth = 30)
}


[Package sbo version 0.5.0 Index]