This project focuses on building a predictive text model using n-gram language modeling techniques. The goal is to predict the next word given the previous one, two, or three words, leveraging unigram, bigram, and trigram frequency tables derived from a corpus of crude oil news articles. The process includes:
This approach demonstrates how simple probabilistic models can be applied to natural language prediction tasks efficiently, even with limited data.
Text mining is the process of extracting meaningful information from unstructured text data. It involves cleaning, transforming, and analyzing textual content to uncover patterns and insights. Language modeling is the task of assigning probabilities to sequences of words. An n-gram model approximates this by considering only the last (n-1) words of context.
For a sequence of words \(W = (w_1, w_2, \dots, w_k)\), the full probability is: \[ P(W) = P(w_1) \cdot P(w_2 | w_1) \cdot P(w_3 | w_1, w_2) \cdots P(w_k | w_1, \dots, w_{k-1}) \]
Using the Markov assumption, an n-gram model simplifies this to: \[ P(w_k | w_1, \dots, w_{k-1}) \approx P(w_k | w_{k-n+1}, \dots, w_{k-1}) \] For:
If a trigram is not found:
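\[ P_{backoff}(w_i | w_{i-2}, w_{i-1}) = \lambda_3 P(w_i | w_{i-2}, w_{i-1}) + \lambda_2 P(w_i | w_{i-1}) + \lambda_1 P(w_i) \] where \(\lambda_3, \lambda_2, \lambda_1\) are the weights for the trigram, bigram, and unigram terms (e.g., 1.0, 0.4, 0.1 in the Stupid Backoff-style scheme implemented later in this report, where only the highest-order term with a non-zero count is used).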
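To make the backoff rule concrete, the following is a minimal base-R sketch on hypothetical counts. The tables `uni`, `bi`, `tri` and the helper `backoff_score` are illustrative assumptions, not objects from this report; the sketch returns the highest-order relative frequency that is available, scaled by the weight for that level.

```r
# Toy counts (hypothetical, for illustration only)
uni <- c(crude = 4, oil = 6, prices = 3, rose = 1)
bi  <- c("crude oil" = 3, "oil prices" = 2, "prices rose" = 1)
tri <- c("crude oil prices" = 2)

# Stupid-Backoff-style score: use the highest-order n-gram that was observed.
# Assumes w3 is in the unigram table; unseen words would need extra smoothing.
backoff_score <- function(w1, w2, w3, lambda2 = 0.4, lambda1 = 0.1) {
  tri_key <- paste(w1, w2, w3)
  ctx_key <- paste(w1, w2)
  if (!is.na(tri[tri_key]) && !is.na(bi[ctx_key])) {
    return(unname(tri[tri_key] / bi[ctx_key]))        # trigram level, weight 1.0
  }
  bi_key <- paste(w2, w3)
  if (!is.na(bi[bi_key])) {
    return(unname(lambda2 * bi[bi_key] / uni[w2]))    # back off to bigram
  }
  unname(lambda1 * uni[w3] / sum(uni))                # back off to unigram
}

backoff_score("crude", "oil", "prices")  # trigram seen: 2 / 3
backoff_score("the", "oil", "prices")    # bigram only: 0.4 * 2 / 6
backoff_score("the", "oil", "rose")      # unigram only: 0.1 * 1 / 14
```

The scores are not normalised probabilities, which is the usual trade-off of this kind of backoff: it is cheap to compute but only suitable for ranking candidates.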
The goal of EDA is to understand the structure of the text corpus, identify frequent terms, and analyze patterns using unigrams, bigrams, and trigrams.
We use the entire corpus due to its small size.
# built-in `crude` dataset from the `tm` package
library(tm)
data("crude")
corpus <- crude
The following text preprocessing steps were applied:
# Text cleaning with tm package
corpus <- VCorpus(VectorSource(corpus))
corpus <- tm_map(corpus, content_transformer(tolower))
corpus <- tm_map(corpus, removePunctuation)
corpus <- tm_map(corpus, removeNumbers)
corpus <- tm_map(corpus, removeWords, stopwords("english"))
corpus <- tm_map(corpus, stripWhitespace)
# Basic summary of corpus
summary(corpus)
## Length Class Mode
## 1 2 PlainTextDocument list
## 2 2 PlainTextDocument list
## 3 2 PlainTextDocument list
## 4 2 PlainTextDocument list
## 5 2 PlainTextDocument list
## 6 2 PlainTextDocument list
## 7 2 PlainTextDocument list
## 8 2 PlainTextDocument list
## 9 2 PlainTextDocument list
## 10 2 PlainTextDocument list
## 11 2 PlainTextDocument list
## 12 2 PlainTextDocument list
## 13 2 PlainTextDocument list
## 14 2 PlainTextDocument list
## 15 2 PlainTextDocument list
## 16 2 PlainTextDocument list
## 17 2 PlainTextDocument list
## 18 2 PlainTextDocument list
## 19 2 PlainTextDocument list
## 20 2 PlainTextDocument list
# Calculate statistics
doc_lengths <- sapply(corpus, function(x) lengths(strsplit(as.character(x), "\\s+")))
num_docs <- length(corpus)
avg_length <- mean(doc_lengths)
min_length <- min(doc_lengths)
max_length <- max(doc_lengths)
# Vocabulary size
dtm <- DocumentTermMatrix(corpus)
vocab_size <- length(dtm$dimnames$Terms)
# Create summary table
summary_table <- data.frame(
Metric = c("Number of Documents", "Average Document Length", "Minimum Document Length", "Maximum Document Length", "Vocabulary Size"),
Value = c(num_docs, round(avg_length, 2), min_length, max_length, vocab_size)
)
# Display as table
knitr::kable(summary_table, caption = "Dataset Summary Statistics")
| Metric | Value |
|---|---|
| Number of Documents | 20.00 |
| Average Document Length | 122.15 |
| Minimum Document Length | 35.00 |
| Maximum Document Length | 265.00 |
| Vocabulary Size | 951.00 |
dtm <- DocumentTermMatrix(corpus)
m <- as.matrix(dtm)
# word_freq <- sort(rowSums(m), decreasing = TRUE)
word_freq <- sort(colSums(m), decreasing = TRUE)
word_freq_df <- data.frame(word = names(word_freq), freq = word_freq)
head(word_freq_df, 5)
## word freq
## oil oil 85
## said said 73
## prices prices 48
## opec opec 42
## mln mln 31
We compute term frequencies from the document-term matrix. Let \(f_{ij}\) be the frequency of term \(j\) in document \(i\). The total frequency of term \(j\) is \(f_j = \sum_{i=1}^{n} f_{ij}\).
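As a small illustration of the total-frequency definition, the sketch below sums a hypothetical count matrix column by column. The matrix `m_toy` and its values are made up for the example; in the report the same operation is applied to the real document-term matrix.

```r
# Hypothetical 3-document x 3-term count matrix (rows = documents, columns = terms)
m_toy <- matrix(c(2, 0, 1,
                  1, 1, 0,
                  3, 0, 0),
                nrow = 3, byrow = TRUE,
                dimnames = list(paste0("doc", 1:3), c("oil", "opec", "prices")))

# f_j = sum over documents i of f_ij, i.e. the column sums
term_totals <- colSums(m_toy)
sort(term_totals, decreasing = TRUE)
```

Row sums would instead give document lengths, so the choice of `colSums` over `rowSums` matters when the matrix has documents in rows.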
Unigrams are individual words. Analyzing their frequency helps us understand which words are most common in the corpus.
Mathematically, if \(w_i\) is a word, its frequency \(f(w_i)\) is computed as: \[ f(w_i) = \sum_{d \in C} \text{count}(w_i, d) \] where \(C\) is the set of documents in the corpus.
library(tidytext)
library(dplyr)
library(tibble)
# Convert corpus to a clean tibble with character column
text_df <- tibble(text = as.character(sapply(corpus, as.character)))
unigrams <- text_df %>% unnest_tokens(word, text, token="words")
unigram_freq <- unigrams %>% count(word, sort=TRUE)
#head(unigram_freq, 10)
Key Observations :
Bigrams are pairs of consecutive words. They help reveal contextual relationships and common phrases.
Mathematically: \[ f(w_i, w_{i+1}) = \sum_{d \in C} \text{count}((w_i, w_{i+1}), d) \]
\(B = \{(w_i, w_{i+1})\}\) is the set of bigrams.
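The bigram count can be reproduced by hand on a short token vector. The sketch below is a simplified base-R stand-in for what a tokenizer does; the helper name `make_ngrams` and the example tokens are hypothetical.

```r
# Build n-grams from a token vector by pasting consecutive tokens together
make_ngrams <- function(tokens, n = 2) {
  if (length(tokens) < n) return(character(0))
  starts <- seq_len(length(tokens) - n + 1)
  vapply(starts,
         function(i) paste(tokens[i:(i + n - 1)], collapse = " "),
         character(1))
}

tokens <- c("opec", "oil", "prices", "fell", "oil", "prices", "rose")
bigram_counts <- sort(table(make_ngrams(tokens, 2)), decreasing = TRUE)
bigram_counts[["oil prices"]]   # the only bigram that occurs twice
```

Setting `n = 3` in the same helper yields trigrams, so one function covers every order used in this report.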
bigrams <- text_df %>% unnest_tokens(bigram, text, token = "ngrams", n = 2)
bigram_freq <- bigrams %>% count(bigram, sort = TRUE) %>% filter(!is.na(bigram))
#head(bigram_freq, 20)
Key Observations:
Trigrams are sequences of three consecutive words. They capture more complex linguistic patterns.
Mathematically: \[ f(w_i, w_{i+1}, w_{i+2}) = \sum_{d \in C} \text{count}((w_i, w_{i+1}, w_{i+2}), d) \]
\(T = \{(w_i, w_{i+1}, w_{i+2})\}\) is the set of trigrams.
trigrams <- text_df %>% unnest_tokens(trigram, text, token = "ngrams", n = 3)
trigram_freq <- trigrams %>% count(trigram, sort = TRUE)
#head(trigram_freq, 20)
Key Observations :
library(wordcloud)
library(RColorBrewer)
wordcloud(words = word_freq_df$word, freq = word_freq_df$freq, min.freq = 2,
max.words = 100, random.order = FALSE, colors = brewer.pal(8, "Dark2"))
Word clouds visualize term frequency. Larger words indicate higher frequency.
This step takes the bigram and trigram frequency tables (created during EDA) and splits the combined word strings into separate columns. For example, “crude oil” becomes \(w_1 = "crude", w_2 = "oil"\). This makes it easier to filter and match on the previous words when predicting the next word.
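The splitting step can be sketched in base R as follows. The data frame `bg` is a hypothetical two-row excerpt (the counts are taken from the bigram table printed in this report), and `bg_split` is an illustrative name.

```r
# Hypothetical excerpt of bigram counts stored as "w1 w2" strings
bg <- data.frame(bigram = c("crude oil", "oil prices"), n = c(13, 18),
                 stringsAsFactors = FALSE)

# Split each string on the single space and bind the pieces into columns
parts <- do.call(rbind, strsplit(bg$bigram, " ", fixed = TRUE))
bg_split <- data.frame(w1 = parts[, 1], w2 = parts[, 2], n = bg$n,
                       stringsAsFactors = FALSE)

# Candidate next words after "crude"
bg_split$w2[bg_split$w1 == "crude"]
```

With the context word in its own column, a lookup is a plain equality filter rather than a string-prefix match.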
# Assuming you already have these from your EDA:
# unigram_freq, bigram_freq, trigram_freq
# Split bigrams and trigrams into separate columns
library(tidytext)
library(dplyr)
library(tidyr)
library(tibble)
text_df <- tibble(text = as.character(sapply(corpus, as.character)))
# Unigrams
unigrams <- text_df %>% unnest_tokens(word, text, token = "words")
unigram_freq <- unigrams %>% count(word, sort = TRUE)
# Bigrams
bigrams <- text_df %>% unnest_tokens(bigram, text, token = "ngrams", n = 2)
bigram_freq <- bigrams %>% count(bigram, sort = TRUE) %>% filter(!is.na(bigram))
# Trigrams
trigrams <- text_df %>% unnest_tokens(trigram, text, token = "ngrams", n = 3)
trigram_freq <- trigrams %>% count(trigram, sort = TRUE)
# Split bigrams and trigrams
bigram_freq <- bigram_freq %>% separate(bigram, into = c("w1", "w2"), sep = " ")
trigram_freq <- trigram_freq %>% separate(trigram, into = c("w1", "w2", "w3"), sep = " ")
print(unigram_freq)
## # A tibble: 957 × 2
## word n
## <chr> <int>
## 1 oil 85
## 2 said 73
## 3 prices 48
## 4 opec 42
## 5 mln 31
## 6 last 24
## 7 bpd 23
## 8 dlrs 23
## 9 crude 21
## 10 market 20
## # ℹ 947 more rows
print(bigram_freq)
## # A tibble: 1,944 × 3
## w1 w2 n
## <chr> <chr> <int>
## 1 oil prices 18
## 2 mln bpd 14
## 3 crude oil 13
## 4 dlrs barrel 10
## 5 sources said 9
## 6 mln barrels 8
## 7 oil minister 7
## 8 world oil 7
## 9 billion riyals 6
## 10 last month 6
## # ℹ 1,934 more rows
print(trigram_freq)
## # A tibble: 2,158 × 4
## w1 w2 w3 n
## <chr> <chr> <chr> <int>
## 1 world oil prices 6
## 2 sheikh abdulaziz said 4
## 3 ali alkhalifa alsabah 3
## 4 arabian oil minister 3
## 5 barrels per day 3
## 6 emergency opec meeting 3
## 7 hold futures position 3
## 8 industry sources said 3
## 9 kuwaits oil minister 3
## 10 minister hisham nazer 3
## # ℹ 2,148 more rows
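The function converts the input text to lowercase and splits it into words.
Backoff logic: look up the last two words in the trigram table; if nothing matches, look up the last word in the bigram table; if that also fails, return the most frequent unigrams.
top_n controls how many predictions to return (e.g., the top 3).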
predict_next_word <- function(input_text, top_n = 3) {
input_text <- tolower(input_text)
words <- strsplit(input_text, "\\s+")[[1]]
len <- length(words)
# Try trigram prediction
if (len >= 2) {
w1 <- words[len - 1]
w2 <- words[len]
trigram_matches <- trigram_freq %>%
filter(w1 == !!w1, w2 == !!w2) %>%
arrange(desc(n)) %>%
head(top_n)
if (nrow(trigram_matches) > 0) return(trigram_matches$w3)
}
# Backoff to bigram
if (len >= 1) {
w2 <- words[len]
bigram_matches <- bigram_freq %>%
filter(w1 == !!w2) %>%
arrange(desc(n)) %>%
head(top_n)
if (nrow(bigram_matches) > 0) return(bigram_matches$w2)
}
# Backoff to unigram
unigram_matches <- unigram_freq %>%
arrange(desc(n)) %>%
head(top_n)
return(unigram_matches$word)
}
# Test data: all corpus tokens in sequence (the same tokens the n-gram tables were built from, so the scores are in-sample)
test_data <- unigrams$word
Tests the function by predicting the next word after “crude oil”. Returns top 3 predictions based on trigram, bigram, or unigram frequencies.
# Model testing
predict_next_word("crude oil", top_n = 3)
## [1] "one" "canadian" "company"
Loops through a test dataset. For each position, uses previous two words as context and predicts the next word. Compares predictions with actual word:
evaluate_accuracy <- function(test_data, model_func, top_n = NULL) {
correct_top1 <- 0
correct_top2 <- 0
correct_top3 <- 0
total <- 0
for (i in 1:(length(test_data) - 2)) {
context <- paste(test_data[i], test_data[i + 1])
actual <- test_data[i + 2]
predictions <- if (!is.null(top_n)) {
model_func(context, top_n)
} else {
model_func(context)
}
if (length(predictions) > 0) {
if (actual == predictions[1]) correct_top1 <- correct_top1 + 1
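# Note: this counts hits at rank 2 only, so top2_accuracy is not cumulative and can be lower than top1_accuracy
if (length(predictions) >= 2 && actual == predictions[2]) correct_top2 <- correct_top2 + 1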
if (actual %in% predictions) correct_top3 <- correct_top3 + 1
}
total <- total + 1
}
list(
top1_accuracy = correct_top1 / total,
top2_accuracy = correct_top2 / total,
top3_accuracy = correct_top3 / total
)
}
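In the function above, top2_accuracy counts only hits at exactly rank 2, while top3_accuracy is cumulative. If a cumulative top-k figure is wanted for every k, one option is a helper along these lines; `topk_hit`, `topk_accuracy`, and the example predictions are hypothetical and not part of the original evaluation.

```r
# TRUE if the actual word appears among the first k predictions (cumulative top-k)
topk_hit <- function(actual, predictions, k) {
  actual %in% head(predictions, k)
}

# Hypothetical ranked predictions for three test positions
preds   <- list(c("prices", "minister", "market"),
                c("bpd", "barrels", "dlrs"),
                c("said", "oil", "opec"))
actuals <- c("prices", "barrels", "crude")

topk_accuracy <- function(k) {
  mean(mapply(topk_hit, actuals, preds, MoreArgs = list(k = k)))
}

sapply(1:3, topk_accuracy)  # cumulative, so the values never decrease with k
```

Defined this way, top-1 ≤ top-2 ≤ top-3 always holds, which makes the three figures directly comparable.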
Measures how fast the prediction function runs. Runs the prediction 100 times and reports timing statistics.
library(microbenchmark)
timing_result <- microbenchmark(predict_next_word("crude oil", top_n = 3), times = 100)
print(summary(timing_result))
## expr min lq mean median
## 1 predict_next_word("crude oil", top_n = 3) 1.731801 1.837552 2.044218 1.925301
## uq max neval
## 1 2.120551 4.486401 100
Perplexity measures how well the model predicts a sequence of words. Lower perplexity indicates a better model.
\[ Perplexity = 2^{-\frac{1}{N} \sum_{i=1}^{N} \log_2 P(w_i | \text{context})} \]
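As a quick numerical check of the perplexity definition, the sketch below evaluates it on a made-up vector of conditional probabilities. The vector `p` is hypothetical; each entry stands for the probability the model assigned to one test token.

```r
# Hypothetical conditional probabilities assigned to four test tokens
p <- c(0.5, 0.25, 0.125, 0.125)

# Perplexity = 2 ^ ( -(1/N) * sum(log2(p_i)) )
perplexity <- 2^(-mean(log2(p)))
perplexity

# Sanity check: a uniform model over V = 8 words has perplexity 8
uniform_perplexity <- 2^(-mean(log2(rep(1 / 8, 10))))
uniform_perplexity
```

The uniform case gives a useful reading of the number: a perplexity of V means the model is, on average, as uncertain as a fair choice among V words.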
1. Compute Probabilities for n-grams
We need functions to calculate unigram, bigram, and trigram probabilities.
# Total counts for normalization
total_unigrams <- sum(unigram_freq$n)
# Probability functions
get_unigram_prob <- function(word) {
count <- unigram_freq$n[unigram_freq$word == word]
if (length(count) == 0) return(1e-6) # Smoothing for unseen words
return(count / total_unigrams)
}
get_bigram_prob <- function(w1, w2) {
count_bigram <- bigram_freq$n[bigram_freq$w1 == w1 & bigram_freq$w2 == w2]
count_w1 <- unigram_freq$n[unigram_freq$word == w1]
if (length(count_bigram) == 0 || length(count_w1) == 0) return(1e-6)
return(count_bigram / count_w1)
}
get_trigram_prob <- function(w1, w2, w3) {
count_trigram <- trigram_freq$n[trigram_freq$w1 == w1 & trigram_freq$w2 == w2 & trigram_freq$w3 == w3]
count_bigram <- bigram_freq$n[bigram_freq$w1 == w1 & bigram_freq$w2 == w2]
if (length(count_trigram) == 0 || length(count_bigram) == 0) return(1e-6)
return(count_trigram / count_bigram)
}
2. Backoff Probability Function
Combine trigram, bigram, and unigram probabilities using Stupid Backoff weights:
ngram_probs <- function(w1, w2, w3) {
lambda3 <- 1.0
lambda2 <- 0.4
lambda1 <- 0.1
p3 <- get_trigram_prob(w1, w2, w3)
if (p3 > 1e-6) return(p3)
p2 <- get_bigram_prob(w2, w3)
if (p2 > 1e-6) return(lambda2 * p2)
p1 <- get_unigram_prob(w3)
return(lambda1 * p1)
}
3. Calculate Perplexity
calculate_perplexity <- function(test_tokens) {
log_prob_sum <- 0
N <- length(test_tokens) - 2
for (i in 1:N) {
w1 <- test_tokens[i]
w2 <- test_tokens[i + 1]
w3 <- test_tokens[i + 2]
prob <- ngram_probs(w1, w2, w3)
log_prob_sum <- log_prob_sum + log2(prob)
}
perplexity <- 2^(-log_prob_sum / N)
return(perplexity)
}
4. Run Evaluation
# Prepare test data
test_data <- unigrams$word
# Accuracy
accuracy_results <- evaluate_accuracy(test_data, predict_next_word, top_n = 3)
print(accuracy_results)
## $top1_accuracy
## [1] 0.8887521
##
## $top2_accuracy
## [1] 0.06280788
##
## $top3_accuracy
## [1] 0.9671593
# Timing
library(microbenchmark)
timing_result <- microbenchmark(predict_next_word("crude oil", top_n = 3), times = 100)
print(summary(timing_result))
## expr min lq mean median
## 1 predict_next_word("crude oil", top_n = 3) 1.712901 1.775351 1.950392 1.824301
## uq max neval
## 1 1.992851 4.152301 100
# Perplexity
perplexity_value <- calculate_perplexity(test_data)
print(perplexity_value)
## [1] 1.31944
Parameters:
Model Size:
# Compare unigram vs bigram vs trigram
predict_unigram <- function(input_text, top_n = 3) {
unigram_freq %>%
arrange(desc(n)) %>%
head(top_n) %>%
pull(word)
}
predict_bigram <- function(input_text, top_n = 3) {
words <- strsplit(tolower(input_text), "\\s+")[[1]]
w <- tail(words, 1)
bigram_matches <- bigram_freq %>%
filter(w1 == !!w) %>%
arrange(desc(n)) %>%
head(top_n)
if (nrow(bigram_matches) > 0) return(bigram_matches$w2)
return(unigram_freq %>% arrange(desc(n)) %>% head(top_n) %>% pull(word)) # fallback
}
predict_trigram <- function(input_text, top_n = 3) {
words <- strsplit(tolower(input_text), "\\s+")[[1]]
len <- length(words)
if (len >= 2) {
w1 <- words[len - 1]
w2 <- words[len]
trigram_matches <- trigram_freq %>%
filter(w1 == !!w1, w2 == !!w2) %>%
arrange(desc(n)) %>%
head(top_n)
if (nrow(trigram_matches) > 0) return(trigram_matches$w3)
}
return(predict_bigram(input_text, top_n)) # fallback
}
# Evaluate each
accuracy_uni <- evaluate_accuracy(test_data, predict_unigram, top_n = 3)
accuracy_bi <- evaluate_accuracy(test_data, predict_bigram, top_n = 3)
accuracy_tri <- evaluate_accuracy(test_data, predict_next_word, top_n = 3)
print(list(Unigram = accuracy_uni, Bigram = accuracy_bi, Trigram = accuracy_tri))
## $Unigram
## $Unigram$top1_accuracy
## [1] 0.03489327
##
## $Unigram$top2_accuracy
## [1] 0.02996716
##
## $Unigram$top3_accuracy
## [1] 0.08456486
##
##
## $Bigram
## $Bigram$top1_accuracy
## [1] 0.5250411
##
## $Bigram$top2_accuracy
## [1] 0.1703612
##
## $Bigram$top3_accuracy
## [1] 0.7795567
##
##
## $Trigram
## $Trigram$top1_accuracy
## [1] 0.8887521
##
## $Trigram$top2_accuracy
## [1] 0.06280788
##
## $Trigram$top3_accuracy
## [1] 0.9671593
Trigram lookup is slower because:
Example timing:
library(microbenchmark)
timing_uni <- microbenchmark(predict_unigram("crude oil"), times = 100)
timing_bi <- microbenchmark(predict_bigram("crude oil"), times = 100)
timing_tri <- microbenchmark(predict_next_word("crude oil"), times = 100)
print(list(Unigram = summary(timing_uni), Bigram = summary(timing_bi), Trigram = summary(timing_tri)))
## $Unigram
## expr min lq mean median uq
## 1 predict_unigram("crude oil") 1.245101 1.301201 1.551658 1.376651 1.60265
## max neval
## 1 3.295101 100
##
## $Bigram
## expr min lq mean median uq
## 1 predict_bigram("crude oil") 1.731802 1.830302 2.274264 1.959401 2.26815
## max neval
## 1 9.437601 100
##
## $Trigram
## expr min lq mean median uq
## 1 predict_next_word("crude oil") 1.725501 1.806651 2.140748 1.920901 2.244951
## max neval
## 1 6.953701 100
Generally, lower perplexity = better accuracy, but correlation is not perfect. Example:
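# Note: both calls below use the same function (trigram backoff via ngram_probs), so the two values are identical.
# A genuine unigram perplexity would need a scorer that uses get_unigram_prob() only.
perplexity_uni <- calculate_perplexity(test_data) # labelled unigram, but actually trigram backoff
perplexity_tri <- calculate_perplexity(test_data) # trigram backoff probabilities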
print(list(Unigram_Perplexity = perplexity_uni, Trigram_Perplexity = perplexity_tri))
## $Unigram_Perplexity
## [1] 1.31944
##
## $Trigram_Perplexity
## [1] 1.31944
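Both calls above invoke calculate_perplexity(test_data), so the two printed values coincide. One way to compare models of different order is to pass the scoring function as an argument. The sketch below is self-contained and uses toy data; `perplexity_with`, `unigram_scorer`, and `oracle_scorer` are hypothetical names, and `oracle_scorer` is only a stand-in for a context-aware model.

```r
# Perplexity that takes the scoring function as an argument.
# prob_fun(w1, w2, w3) must return P(w3 | w1, w2) under the chosen model.
perplexity_with <- function(tokens, prob_fun) {
  n <- length(tokens) - 2
  log_probs <- vapply(seq_len(n), function(i) {
    log2(prob_fun(tokens[i], tokens[i + 1], tokens[i + 2]))
  }, numeric(1))
  2^(-mean(log_probs))
}

# Toy data and two hypothetical scorers
toy_tokens <- c("crude", "oil", "prices", "crude", "oil", "prices")
toy_uni    <- table(toy_tokens) / length(toy_tokens)

unigram_scorer <- function(w1, w2, w3) as.numeric(toy_uni[[w3]])  # ignores context
oracle_scorer  <- function(w1, w2, w3) 0.9                        # stand-in for a context-aware model

perplexity_with(toy_tokens, unigram_scorer)  # 3: each word has probability 1/3
perplexity_with(toy_tokens, oracle_scorer)   # 1 / 0.9
```

With the functions defined in this report, the same pattern would be `perplexity_with(test_data, function(w1, w2, w3) get_unigram_prob(w3))` for the unigram model and `perplexity_with(test_data, ngram_probs)` for trigram backoff.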
Prune low-frequency n-grams (e.g., keep only n-grams with count ≥ 2). This reduces memory and speeds up lookup with minimal accuracy loss.
bigram_freq <- bigram_freq %>% filter(n >= 2)
trigram_freq <- trigram_freq %>% filter(n >= 2)
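After pruning, accuracy typically drops slightly (e.g., from 40% to 38%), while model size shrinks significantly. Note that the two assignments above overwrite bigram_freq and trigram_freq, so any code run afterwards uses the pruned tables unless they are rebuilt.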
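The size-versus-coverage trade-off of pruning can be quantified before committing to a threshold. The sketch below uses a hypothetical five-row trigram table (`tri_toy`; the first three counts mirror the top trigrams printed earlier, the last two rows are invented singletons) and reports how many rows and how much count mass survive a threshold of 2.

```r
# Hypothetical trigram table with a tail of singletons
tri_toy <- data.frame(
  w1 = c("world", "sheikh", "barrels", "a", "b"),
  w2 = c("oil", "abdulaziz", "per", "c", "d"),
  w3 = c("prices", "said", "day", "e", "f"),
  n  = c(6, 4, 3, 1, 1),
  stringsAsFactors = FALSE
)

# Keep a copy instead of overwriting, so both versions can be compared
pruned <- tri_toy[tri_toy$n >= 2, ]

# Share of rows kept, and share of total count mass kept
kept <- c(rows_kept = nrow(pruned) / nrow(tri_toy),
          mass_kept = sum(pruned$n) / sum(tri_toy$n))
kept
```

Writing the result to a new object keeps the unpruned table available for a before-and-after accuracy comparison.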
To improve the model, we can explore:
Here are a few options beyond traditional n-grams:
To improve generalization, consider using:
Accuracy Metrics
\[ Perplexity = 2^{-\frac{1}{N} \sum_{i=1}^{N} \log_2 P(w_i | w_{i-2}, w_{i-1})} \]
Efficiency Metrics
Measured using microbenchmark:
Represent each document (or sentence) as a TF-IDF vector. For a given input phrase, find the most similar sentence using cosine similarity. Predict the next word from the most similar sentence.
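The retrieval idea can be illustrated without any packages. The sketch below builds TF-IDF vectors for three hypothetical documents and ranks them against a query by cosine similarity. The smoothed IDF formula used here is one common variant chosen for the example and is an assumption; it is not necessarily the weighting that text2vec applies.

```r
# Three tiny hypothetical "documents"
docs   <- c("crude oil prices rose", "opec oil minister said", "crude oil fell")
tokens <- strsplit(docs, " ", fixed = TRUE)
vocab  <- sort(unique(unlist(tokens)))

# Term-frequency matrix (documents x vocabulary)
tf <- t(vapply(tokens,
               function(tk) as.numeric(table(factor(tk, levels = vocab))),
               numeric(length(vocab))))
colnames(tf) <- vocab

# Smoothed inverse document frequency (assumed variant), then TF-IDF weights
idf       <- log((1 + nrow(tf)) / (1 + colSums(tf > 0))) + 1
tfidf_toy <- sweep(tf, 2, idf, `*`)

cosine <- function(x, y) sum(x * y) / (sqrt(sum(x^2)) * sqrt(sum(y^2)))

# Embed the query with the same vocabulary and weights, then rank documents
query <- as.numeric(table(factor(c("crude", "oil"), levels = vocab))) * idf
sims  <- apply(tfidf_toy, 1, cosine, y = query)
which.max(sims)  # index of the most similar document
```

The shortest document containing both query words wins here because cosine similarity normalises by vector length; longer documents dilute the match with unrelated terms.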
# Load libraries
library(tm)
library(tidytext)
library(dplyr)
library(tibble)
library(text2vec)
library(stringr)
# Load and clean data
data("crude")
corpus <- VCorpus(VectorSource(crude))
corpus <- tm_map(corpus, content_transformer(tolower))
corpus <- tm_map(corpus, removePunctuation)
corpus <- tm_map(corpus, removeNumbers)
corpus <- tm_map(corpus, removeWords, stopwords("english"))
corpus <- tm_map(corpus, stripWhitespace)
# Convert to tibble
text_df <- tibble(text = sapply(corpus, as.character))
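# Tokenize into sentences (punctuation was already removed above, so this split likely leaves each document as a single unit)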
sentences <- unlist(strsplit(text_df$text, split = "\\."))
sentences <- sentences[sentences != ""]
# Create TF-IDF matrix
it <- itoken(sentences, progress_bar = FALSE)
vectorizer <- vocab_vectorizer(create_vocabulary(it))
dtm <- create_dtm(it, vectorizer)
tfidf <- TfIdf$new()
tfidf_matrix <- tfidf$fit_transform(dtm)
# Cosine similarity function
cosine_sim <- function(x, y) {
sum(x * y) / (sqrt(sum(x^2)) * sqrt(sum(y^2)))
}
# Prediction function
predict_next_word_tfidf <- function(input_text) {
input_text <- tolower(input_text)
input_vec <- tfidf$transform(create_dtm(itoken(input_text), vectorizer))
sims <- apply(tfidf_matrix, 1, function(row) cosine_sim(row, input_vec))
best_match <- sentences[which.max(sims)]
words <- unlist(strsplit(best_match, "\\s+"))
idx <- which(words %in% tail(strsplit(input_text, "\\s+")[[1]], 1))
if (length(idx) > 0 && idx[1] < length(words)) {
return(words[idx[1] + 1])
} else {
return(words[1])
}
}
# Example
predict_next_word_tfidf("crude oil")
## 413
## "one"
library(microbenchmark)
library(dplyr)
# Top-N Accuracy Evaluation
evaluate_accuracy_tfidf <- function(test_data, model_func, top_n = 3) {
correct_top1 <- 0
correct_top2 <- 0
correct_top3 <- 0
total <- 0
for (i in 1:(length(test_data) - 2)) {
context <- paste(test_data[i], test_data[i + 1])
actual <- test_data[i + 2]
prediction <- model_func(context)
predictions <- c(prediction) # For TF-IDF, we return one word (extend if needed)
# Check Top-1
if (actual == predictions[1]) correct_top1 <- correct_top1 + 1
# Check Top-2
if (length(predictions) >= 2 && actual %in% predictions[1:2]) correct_top2 <- correct_top2 + 1
# Check Top-3
if (actual %in% predictions[1:min(top_n, length(predictions))]) correct_top3 <- correct_top3 + 1
total <- total + 1
}
list(
top1_accuracy = correct_top1 / total,
top2_accuracy = correct_top2 / total,
top3_accuracy = correct_top3 / total
)
}
# Efficiency: Inference Time
timing_tfidf <- microbenchmark(
predict_next_word_tfidf("crude oil"),
times = 100
)
# Perplexity placeholder: assumes a uniform distribution over the vocabulary, so the result is simply the vocabulary size (not a model-based perplexity)
calculate_perplexity_tfidf <- function(test_tokens) {
probs <- rep(1 / length(unique(test_tokens)), length(test_tokens)) # uniform approx
perplexity <- 2^(-mean(log2(probs)))
return(perplexity)
}
accuracy_tfidf <- evaluate_accuracy_tfidf(unigrams$word, predict_next_word_tfidf)
perplexity_tfidf <- calculate_perplexity_tfidf(unigrams$word)
print(accuracy_tfidf)
## $top1_accuracy
## [1] 0.5931856
##
## $top2_accuracy
## [1] 0
##
## $top3_accuracy
## [1] 0.5931856
print(summary(timing_tfidf))
## expr min lq mean median
## 1 predict_next_word_tfidf("crude oil") 2.844901 3.266552 4.093369 3.617251
## uq max neval
## 1 4.148651 14.0685 100
print(perplexity_tfidf)
## [1] 957
# Combine results
accuracy_tfidf <- evaluate_accuracy_tfidf(unigrams$word, predict_next_word_tfidf)
timing_tfidf <- microbenchmark(predict_next_word_tfidf("crude oil"), times = 100)
perplexity_tfidf <- calculate_perplexity_tfidf(unigrams$word)
print(accuracy_tfidf)
## $top1_accuracy
## [1] 0.5931856
##
## $top2_accuracy
## [1] 0
##
## $top3_accuracy
## [1] 0.5931856
print(timing_tfidf)
## Unit: milliseconds
## expr min lq mean median
## predict_next_word_tfidf("crude oil") 2.779101 3.134351 3.699965 3.395151
## uq max neval
## 3.796101 11.5463 100
print(perplexity_tfidf)
## [1] 957
accuracy_ngram <- evaluate_accuracy(test_data, predict_next_word, top_n = 3)
timing_ngram <- microbenchmark(predict_next_word("crude oil", top_n = 3), times = 100)
perplexity_ngram <- calculate_perplexity(test_data)
print(accuracy_ngram)
## $top1_accuracy
## [1] 0.2701149
##
## $top2_accuracy
## [1] 0.03899836
##
## $top3_accuracy
## [1] 0.3210181
print(timing_ngram)
## Unit: milliseconds
## expr min lq mean median
## predict_next_word("crude oil", top_n = 3) 1.700101 1.750601 1.912535 1.799202
## uq max neval
## 1.919501 3.537701 100
print(perplexity_ngram)
## [1] 582.2335
results <- data.frame(
Model = c("TF-IDF", "N-gram"),
Top1_Accuracy = c(accuracy_tfidf$top1_accuracy, accuracy_ngram$top1_accuracy),
Top2_Accuracy = c(accuracy_tfidf$top2_accuracy, accuracy_ngram$top2_accuracy),
Top3_Accuracy = c(accuracy_tfidf$top3_accuracy, accuracy_ngram$top3_accuracy),
Perplexity = c(perplexity_tfidf, perplexity_ngram),
Mean_Inference_Time_ms = c(mean(as.numeric(timing_tfidf$time)) / 1e6,
mean(as.numeric(timing_ngram$time)) / 1e6)
)
library(knitr)
kable(results, caption = "TF-IDF vs N-gram Model Evaluation")
| Model | Top1_Accuracy | Top2_Accuracy | Top3_Accuracy | Perplexity | Mean_Inference_Time_ms |
|---|---|---|---|---|---|
| TF-IDF | 0.5931856 | 0.0000000 | 0.5931856 | 957.0000 | 3.699965 |
| N-gram | 0.2701149 | 0.0389984 | 0.3210181 | 582.2335 | 1.912535 |
# Visualization
library(ggplot2)
ggplot(results, aes(x = Model, y = Top1_Accuracy, fill = Model)) +
geom_bar(stat = "identity") +
ggtitle("Top-1 Accuracy Comparison")
ggplot(results, aes(x = Model, y = Mean_Inference_Time_ms, fill = Model)) +
geom_bar(stat = "identity") +
ggtitle("Inference Time (ms)")
Wrap your prediction functions inside profvis:
library(profvis)
# Profile TF-IDF prediction
profvis({
for (i in 1:100) {
predict_next_word_tfidf("crude oil")
}
})
# Profile N-gram prediction
profvis({
for (i in 1:100) {
predict_next_word("crude oil")
}
})
Missed N-grams: Rare trigrams or bigrams that appear infrequently in training.
Reasons:
Fixes:
Uncertainty can be estimated using:
Probability distribution: if the predicted word has a low probability compared to the alternatives, uncertainty is high.
Entropy: \[ H = -\sum_i P(w_i) \log_2 P(w_i) \] Higher entropy → more uncertainty.
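A short numerical illustration of the entropy measure: the sketch below compares a peaked distribution with a uniform one. The helper name `entropy_bits` and the example probabilities are hypothetical.

```r
# Shannon entropy in bits; zero-probability entries contribute nothing
entropy_bits <- function(p) {
  p <- p / sum(p)   # normalise defensively
  p <- p[p > 0]     # drop zeros so log2(0) never occurs
  -sum(p * log2(p))
}

entropy_bits(c(0.97, 0.01, 0.01, 0.01))  # peaked distribution: low entropy
entropy_bits(rep(0.25, 4))               # uniform over 4 words: 2 bits
```

For k candidates the maximum is log2(k) bits, so an entropy close to that bound indicates the model has almost no preference among its candidates.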
# Calculate entropy for N-gram predictions
calculate_entropy_ngram <- function(context, candidates) {
# Compute probabilities for candidate words
probs <- sapply(candidates, function(w3) {
ngram_probs(strsplit(context, " ")[[1]][1],
strsplit(context, " ")[[1]][2], w3)
})
probs <- probs / sum(probs) # Normalize to sum = 1
entropy <- -sum(probs * log2(probs + 1e-12)) # Avoid log(0)
return(entropy)
}
# Example usage for N-gram
context <- "crude oil"
candidates <- unigram_freq$word[1:10] # Top 10 frequent words
entropy_ngram <- calculate_entropy_ngram(context, candidates)
print(paste("N-gram Uncertainty (Entropy):", round(entropy_ngram, 4)))
## [1] "N-gram Uncertainty (Entropy): 1.373"
# Calculate entropy for TF-IDF predictions
calculate_entropy_tfidf <- function(input_text, candidates) {
input_vec <- tfidf$transform(create_dtm(itoken(input_text), vectorizer))
sims <- sapply(candidates, function(word) {
phrase <- paste(input_text, word)
vec <- tfidf$transform(create_dtm(itoken(phrase), vectorizer))
cosine_sim(as.numeric(input_vec), as.numeric(vec))
})
probs <- sims / sum(sims) # Normalize similarities to probabilities
entropy <- -sum(probs * log2(probs + 1e-12))
return(entropy)
}
# Example usage for TF-IDF
entropy_tfidf <- calculate_entropy_tfidf(context, candidates)
print(paste("TF-IDF Uncertainty (Entropy):", round(entropy_tfidf, 4)))
## [1] "TF-IDF Uncertainty (Entropy): 3.3124"
uncertainty_results <- data.frame(
Model = c("N-gram", "TF-IDF"),
Entropy = c(entropy_ngram, entropy_tfidf)
)
library(knitr)
kable(uncertainty_results, caption = "Prediction Uncertainty (Entropy)")
| Model | Entropy |
|---|---|
| N-gram | 1.373047 |
| TF-IDF | 3.312429 |
library(ggplot2)
ggplot(uncertainty_results, aes(x = Model, y = Entropy, fill = Model)) +
geom_bar(stat = "identity", width = 0.6) +
ggtitle("Prediction Uncertainty (Entropy)") +
ylab("Entropy") +
theme_minimal()
This section demonstrates an interactive Shiny app that predicts the next word using an N-gram language model. The app also includes visualizations similar to the SwiftKey dashboard.
# ===========================
# Load Libraries
# ===========================
library(shiny)
library(ggplot2)
library(dplyr)
library(tidytext)
library(tibble)
library(tm)
library(wordcloud)
library(RColorBrewer)
library(flexdashboard) # For gauge

# ===========================
# Load and Prepare Data
# ===========================
data("crude")
corpus <- VCorpus(VectorSource(crude))
corpus <- tm_map(corpus, content_transformer(tolower))
corpus <- tm_map(corpus, removePunctuation)
corpus <- tm_map(corpus, removeNumbers)
corpus <- tm_map(corpus, removeWords, stopwords("english"))
corpus <- tm_map(corpus, stripWhitespace)

text_df <- tibble(text = sapply(corpus, as.character))

# Unigrams, Bigrams, Trigrams
unigrams <- text_df %>% unnest_tokens(word, text, token = "words") %>% count(word, sort = TRUE)
bigrams <- text_df %>% unnest_tokens(bigram, text, token = "ngrams", n = 2) %>% count(bigram, sort = TRUE)
trigrams <- text_df %>% unnest_tokens(trigram, text, token = "ngrams", n = 3) %>% count(trigram, sort = TRUE)

bigram_freq <- bigrams %>% tidyr::separate(bigram, into = c("w1", "w2"), sep = " ")
trigram_freq <- trigrams %>% tidyr::separate(trigram, into = c("w1", "w2", "w3"), sep = " ")

# ===========================
# Prediction Function (Backoff)
# ===========================
predict_next_word <- function(input_text, top_n = 3) {
  input_text <- tolower(input_text)
  words <- strsplit(input_text, "\\s+")[[1]]
  len <- length(words)

  # Try trigram
  if (len >= 2) {
    w1 <- words[len - 1]
    w2 <- words[len]
    trigram_matches <- trigram_freq %>%
      filter(w1 == !!w1, w2 == !!w2) %>%
      arrange(desc(n)) %>%
      head(top_n)
    if (nrow(trigram_matches) > 0) return(trigram_matches$w3)
  }

  # Backoff to bigram
  if (len >= 1) {
    w2 <- words[len]
    bigram_matches <- bigram_freq %>%
      filter(w1 == !!w2) %>%
      arrange(desc(n)) %>%
      head(top_n)
    if (nrow(bigram_matches) > 0) return(bigram_matches$w2)
  }

  # Backoff to unigram
  unigram_matches <- unigrams %>%
    arrange(desc(n)) %>%
    head(top_n)
  return(unigram_matches$word)
}

# ===========================
# Entropy Calculation
# ===========================
calculate_entropy <- function(predictions) {
  probs <- rep(1 / length(predictions), length(predictions)) # uniform since no probabilities
  entropy <- -sum(probs * log2(probs))
  return(entropy)
}

# ===========================
# Shiny UI
# ===========================
ui <- fluidPage(
  titlePanel("Next Word Predictor (N-gram Model)"),

  sidebarLayout(
    sidebarPanel(
      textInput("input_text", "Enter a phrase:", value = "crude oil"),
      actionButton("predict_btn", "Predict Next Word"),
      hr(),
      h4("Documentation"),
      p("This app predicts the next word using an N-gram language model with backoff strategy."),
      p("Visualizations include word frequency, bigrams, trigrams, word cloud, and uncertainty gauge.")
    ),

    mainPanel(
      h3("Predicted Next Words:"),
      verbatimTextOutput("prediction"),
      plotOutput("barplot"),
      hr(),
      #h4("Prediction Uncertainty (Entropy):"),
      #gaugeOutput("entropy_gauge"), # Gauge visualization
      h4("Prediction Uncertainty (Entropy):"),
      verbatimTextOutput("entropy_text"),
      hr(),
      tabsetPanel(
        tabPanel("Word Frequency", plotOutput("word_freq_plot")),
        tabPanel("Top Bigrams", plotOutput("bigram_plot")),
        tabPanel("Top Trigrams", plotOutput("trigram_plot")),
        tabPanel("Word Cloud", plotOutput("wordcloud_plot"))
      )
    )
  )
)

# ===========================
# Shiny Server
# ===========================
server <- function(input, output) {
  observeEvent(input$predict_btn, {
    context <- input$input_text
    predictions <- predict_next_word(context, top_n = 3)

    output$prediction <- renderText({
      paste("Top Predictions:", paste(predictions, collapse = ", "))
    })

    # Bar plot for predictions
    freq_data <- data.frame(word = predictions, freq = seq(length(predictions), 1))
    output$barplot <- renderPlot({
      ggplot(freq_data, aes(x = reorder(word, -freq), y = freq, fill = word)) +
        geom_bar(stat = "identity") +
        ggtitle("Top Predicted Words") +
        xlab("Word") + ylab("Rank") +
        theme_minimal()
    })

    # Entropy Gauge
    entropy_val <- calculate_entropy(predictions)
    output$entropy_text <- renderText({
      paste("Entropy of predictions:", round(entropy_val, 3))
    })
    #output$entropy_gauge <- renderGauge({
    #  gauge(entropy_val, min = 0, max = 2, symbol = "",
    #        sectors = gaugeSectors(success = c(0, 0.8), warning = c(0.8, 1.5), danger = c(1.5, 2)),
    #        label = "Entropy")
    # })
  })

  # Word Frequency Plot
  output$word_freq_plot <- renderPlot({
    ggplot(unigrams[1:20, ], aes(x = reorder(word, n), y = n)) +
      geom_bar(stat = "identity", fill = "steelblue") +
      coord_flip() +
      ggtitle("Top 20 Most Frequent Words")
  })

  # Bigram Plot
  output$bigram_plot <- renderPlot({
    ggplot(bigram_freq[1:20, ], aes(x = reorder(paste(w1, w2), n), y = n)) +
      geom_bar(stat = "identity", fill = "darkgreen") +
      coord_flip() +
      ggtitle("Top 20 Bigrams")
  })

  # Trigram Plot
  output$trigram_plot <- renderPlot({
    ggplot(trigram_freq[1:20, ], aes(x = reorder(paste(w1, w2, w3), n), y = n)) +
      geom_bar(stat = "identity", fill = "purple") +
      coord_flip() +
      ggtitle("Top 20 Trigrams")
  })

  # Word Cloud
  output$wordcloud_plot <- renderPlot({
    wordcloud(words = unigrams$word, freq = unigrams$n, min.freq = 2,
              max.words = 100, random.order = FALSE, colors = brewer.pal(8, "Dark2"))
  })
}

# ===========================
# Run App
# ===========================
shinyApp(ui = ui, server = server)
The objective of this project was to design and evaluate a Next Word Prediction System using N-gram models and TF-IDF approaches, and to showcase the results through an interactive Shiny application with rich visualizations.
Data Preprocessing & Feature Engineering
Model Development
Evaluation
Computed Top-1, Top-2, and Top-3 accuracy, perplexity, and inference time.
Observed that:
Added uncertainty estimation (entropy) to measure prediction confidence.
Visualization & Dashboard
Developed a Shiny app that:
#inspect(dtm[1:5, 1:5])
sessionInfo()
## R version 4.5.1 (2025-06-13 ucrt)
## Platform: x86_64-w64-mingw32/x64
## Running under: Windows 11 x64 (build 22631)
##
## Matrix products: default
## LAPACK version 3.12.1
##
## locale:
## [1] LC_COLLATE=English_United States.utf8
## [2] LC_CTYPE=English_United States.utf8
## [3] LC_MONETARY=English_United States.utf8
## [4] LC_NUMERIC=C
## [5] LC_TIME=English_United States.utf8
##
## time zone: Asia/Calcutta
## tzcode source: internal
##
## attached base packages:
## [1] stats graphics grDevices utils datasets methods base
##
## other attached packages:
## [1] flexdashboard_0.6.2 shiny_1.11.1 profvis_0.4.0
## [4] knitr_1.50 text2vec_0.6.4 microbenchmark_1.5.0
## [7] textdata_0.4.5 stringr_1.5.2 tidyr_1.3.1
## [10] tidytext_0.4.3 tibble_3.3.0 dplyr_1.1.4
## [13] ggplot2_4.0.0 wordcloud_2.6 RColorBrewer_1.1-3
## [16] SnowballC_0.7.1 tm_0.7-16 NLP_0.3-2
##
## loaded via a namespace (and not attached):
## [1] gtable_0.3.6 xfun_0.53 bslib_0.9.0
## [4] htmlwidgets_1.6.4 lattice_0.22-7 tzdb_0.5.0
## [7] vctrs_0.6.5 tools_4.5.1 generics_0.1.4
## [10] parallel_4.5.1 janeaustenr_1.0.0 pkgconfig_2.0.3
## [13] tokenizers_0.3.0 Matrix_1.7-3 data.table_1.17.8
## [16] S7_0.2.0 lifecycle_1.0.4 compiler_4.5.1
## [19] farver_2.1.2 RhpcBLASctl_0.23-42 httpuv_1.6.16
## [22] htmltools_0.5.8.1 sass_0.4.10 yaml_2.3.10
## [25] later_1.4.4 pillar_1.11.1 crayon_1.5.3
## [28] jquerylib_0.1.4 cachem_1.1.0 mime_0.13
## [31] rsparse_0.5.3 tidyselect_1.2.1 digest_0.6.37
## [34] stringi_1.8.7 slam_0.1-55 purrr_1.1.0
## [37] labeling_0.4.3 fastmap_1.2.0 grid_4.5.1
## [40] cli_3.6.5 magrittr_2.0.4 utf8_1.2.6
## [43] readr_2.1.5 withr_3.0.2 promises_1.3.3
## [46] scales_1.4.0 float_0.3-3 rmarkdown_2.30
## [49] mlapi_0.1.1 hms_1.1.3 evaluate_1.0.5
## [52] rlang_1.1.6 Rcpp_1.1.0 xtable_1.8-4
## [55] glue_1.8.0 xml2_1.4.0 rstudioapi_0.17.1
## [58] jsonlite_2.0.0 lgr_0.5.0 R6_2.6.1
## [61] fs_1.6.6