This document evaluates a backoff-based n-gram prediction model applied to three text sources: Twitter, Blogs, and News. We load and sample each source, examine term frequencies, train a basic model, and discuss enhancements for accuracy and runtime efficiency.
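The chunks below assume the following packages are attached; this setup block is a minimal sketch of the dependencies the rest of the report relies on.

library(dplyr)     # count(), filter(), arrange(), group_by(), ...
library(tidyr)     # separate() for splitting n-grams into word columns
library(stringr)   # str_detect(), str_split() for token cleaning
library(purrr)     # map_chr() for applying the predictor to test phrases
library(tibble)    # tibble() constructor
library(tidytext)  # unnest_tokens() for word and n-gram tokenization
library(tictoc)    # tic()/toc() wall-clock timing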
# Read a raw text file and return a reproducible random sample of n lines.
read_sample <- function(path, n = 5000, seed = 123) {
  set.seed(seed)
  lines <- readLines(path, encoding = "UTF-8", skipNul = TRUE)
  sample(lines, size = min(n, length(lines)))  # guard against short files
}
twitter_lines <- read_sample("data/en_US.twitter.txt")
blogs_lines <- read_sample("data/en_US.blogs.txt")
news_lines <- read_sample("data/en_US.news.txt")
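As a quick sanity check on the samples (an illustrative addition; it plays no role in the model itself), we can tabulate lines and rough word counts per source:

# Rough per-source size summary: sampled lines and whitespace-separated words.
sample_summary <- tibble(
  source = c("Twitter", "Blogs", "News"),
  lines  = c(length(twitter_lines), length(blogs_lines), length(news_lines)),
  words  = c(
    sum(lengths(strsplit(twitter_lines, "\\s+"))),
    sum(lengths(strsplit(blogs_lines, "\\s+"))),
    sum(lengths(strsplit(news_lines, "\\s+")))
  )
)
sample_summary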
twitter_df <- tibble(text = twitter_lines)

# Tokenize into lowercase words, keeping only alphabetic tokens (apostrophes
# allowed, so contractions like "can't" survive).
tokens <- twitter_df %>%
  unnest_tokens(word, text) %>%
  filter(str_detect(word, "^[a-z']+$"))

# Unigram frequency table; it serves both as the backoff of last resort and
# as the frequency source for the coverage analysis below.
unigram_model <- tokens %>%
  count(word, sort = TRUE)
unigrams <- unigram_model
bigram_model <- twitter_df %>%
  unnest_tokens(bigram, text, token = "ngrams", n = 2) %>%
  filter(!is.na(bigram)) %>%  # unnest_tokens yields NA for too-short lines
  count(bigram, sort = TRUE) %>%
  separate(bigram, into = c("w1", "w2"), sep = " ") %>%
  group_by(w1) %>%
  slice_max(n, n = 5) %>%  # keep the top 5 continuations per context word
  ungroup()
trigram_model <- twitter_df %>%
  unnest_tokens(trigram, text, token = "ngrams", n = 3) %>%
  filter(!is.na(trigram)) %>%  # drop NA rows from too-short lines
  count(trigram, sort = TRUE) %>%
  separate(trigram, into = c("w1", "w2", "w3"), sep = " ") %>%
  group_by(w1, w2) %>%
  slice_max(n, n = 3) %>%  # keep the top 3 continuations per bigram context
  ungroup()
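Since runtime efficiency is a stated concern, it is worth checking the in-memory size of each table before any pruning; the check below is an illustrative addition using base R's object.size():

# Approximate in-memory size of each n-gram table.
for (m in c("unigram_model", "bigram_model", "trigram_model")) {
  cat(m, ":", format(object.size(get(m)), units = "MB"), "\n")
}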
# Backoff prediction: try the trigram table on the last two words, fall back
# to the bigram table on the last word, then to the most frequent unigram.
predict_next_word <- function(input_text) {
  # Normalize the input the same way the corpus was tokenized: lowercase,
  # with whitespace collapsed.
  words <- str_split(str_to_lower(str_squish(input_text)), " ")[[1]]
  len <- length(words)
  if (len >= 2) {
    pred <- trigram_model %>%
      filter(w1 == words[len - 1], w2 == words[len]) %>%
      arrange(desc(n)) %>%
      pull(w3)
    if (length(pred) > 0) return(pred[1])
  }
  if (len >= 1) {
    pred <- bigram_model %>%
      filter(w1 == words[len]) %>%
      arrange(desc(n)) %>%
      pull(w2)
    if (length(pred) > 0) return(pred[1])
  }
  unigram_model$word[1]
}
test_phrases <- c(
  "I love",
  "You are the",
  "This is so",
  "I believe",
  "When I was",
  "The most important",
  "The president said",
  "According to the",
  "In the event"
)
results <- map_chr(test_phrases, predict_next_word)
data.frame(Phrase = test_phrases, Prediction = results)
##                Phrase Prediction
## 1              I love        you
## 2         You are the       best
## 3          This is so    excited
## 4           I believe         in
## 5          When I was          a
## 6  The most important      great
## 7  The president said          i
## 8    According to the     movies
## 9        In the event          i
tic("Prediction")
predict_next_word("I can't wait")
## [1] "to"
toc()
## Prediction: 0.005 sec elapsed
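tic()/toc() gives a single wall-clock reading; for a more stable estimate the call can be repeated, for example with the microbenchmark package (an assumed extra dependency, not used elsewhere in this report):

library(microbenchmark)
# Repeat the prediction 100 times to average out timer noise.
microbenchmark(predict_next_word("I can't wait"), times = 100)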
total_words <- sum(unigrams$n)
# Smallest number of top-ranked words whose cumulative frequency reaches
# 50% (resp. 90%) of all word occurrences in the sample.
coverage_50 <- which(cumsum(unigrams$n) >= total_words * 0.5)[1]
coverage_90 <- which(cumsum(unigrams$n) >= total_words * 0.9)[1]
data.frame(coverage_50, coverage_90)
##   coverage_50 coverage_90
## 1         121        3675
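About 120 top-ranked words already cover half of all word occurrences and under 4,000 cover 90%, so pruning rare words should shrink the tables with little loss of accuracy. A minimal sketch, keeping only n-grams whose words fall inside the 90% coverage vocabulary:

# Vocabulary that reaches 90% coverage of word occurrences.
keep_words <- unigrams$word[seq_len(coverage_90)]
# Prune contexts and continuations outside the vocabulary.
bigram_pruned <- bigram_model %>%
  filter(w1 %in% keep_words, w2 %in% keep_words)
trigram_pruned <- trigram_model %>%
  filter(w1 %in% keep_words, w2 %in% keep_words, w3 %in% keep_words)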
This predictive model will be deployed in a Shiny app with an input box and a single-word prediction output.
Planned improvements include:

- Source-specific tuning using the blogs and news samples
- Faster lookups via data.table or DuckDB (see the sketch below)
- Fallback predictions from word vectors when no n-gram match exists
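As a sketch of the data.table route (trigram_dt and predict_trigram_dt are illustrative names, not part of the pipeline above), keying the table on the two-word context turns each lookup into a binary search instead of a full scan:

library(data.table)
# Convert once and key on the context columns.
trigram_dt <- as.data.table(trigram_model)
setkey(trigram_dt, w1, w2)

predict_trigram_dt <- function(w_prev, w_last) {
  hits <- trigram_dt[.(w_prev, w_last), nomatch = NULL]  # keyed lookup
  if (nrow(hits) == 0) return(NA_character_)
  hits[which.max(n), w3]
}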
# Persist the three tables for the Shiny app.
save(unigram_model, bigram_model, trigram_model, file = "models.RData")
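On the Shiny side, the server script only needs to restore these tables once at startup so that each prediction is a pure lookup (a sketch; the exact file layout depends on the app):

# In global.R (or at the top of app.R):
load("models.RData")  # restores unigram_model, bigram_model, trigram_model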