# Load required libraries
library(tidyverse)
## ── Attaching core tidyverse packages ──────────────────────── tidyverse 2.0.0 ──
## ✔ dplyr 1.1.4 ✔ readr 2.1.5
## ✔ forcats 1.0.0 ✔ stringr 1.5.1
## ✔ ggplot2 3.5.2 ✔ tibble 3.3.0
## ✔ lubridate 1.9.4 ✔ tidyr 1.3.1
## ✔ purrr 1.0.4
## ── Conflicts ────────────────────────────────────────── tidyverse_conflicts() ──
## ✖ dplyr::filter() masks stats::filter()
## ✖ dplyr::lag() masks stats::lag()
## ℹ Use the conflicted package (<http://conflicted.r-lib.org/>) to force all conflicts to become errors
library(tidytext)
library(tokenizers)
library(data.table)
##
## Attaching package: 'data.table'
##
## The following objects are masked from 'package:lubridate':
##
## hour, isoweek, mday, minute, month, quarter, second, wday, week,
## yday, year
##
## The following objects are masked from 'package:dplyr':
##
## between, first, last
##
## The following object is masked from 'package:purrr':
##
## transpose
# Load text data ----
# Directory holding the en_US corpus files. Set the CORPUS_DATA_DIR
# environment variable to read them from another location; otherwise the
# original hard-coded directory is used.
data_dir <- Sys.getenv("CORPUS_DATA_DIR", unset = "C:/Users/vashu/Documents")
# skipNul = TRUE skips embedded NUL bytes; encoding = "UTF-8" only marks the
# strings as UTF-8, it does not re-encode them.
twitter <- readLines(file.path(data_dir, "en_US.twitter.txt"),
                     skipNul = TRUE, encoding = "UTF-8")
blogs <- readLines(file.path(data_dir, "en_US.blogs.txt"),
                   skipNul = TRUE, encoding = "UTF-8")
news <- readLines(file.path(data_dir, "en_US.news.txt"),
                  skipNul = TRUE, encoding = "UTF-8")
all_text <- c(twitter, blogs, news)
# Sample data ----
# Draw a reproducible random sample containing this fraction of all lines.
sample_fraction <- 0.01
set.seed(123)
# floor() makes the integer sample size explicit; sample() would truncate a
# fractional size the same way.
sample_text <- sample(all_text,
                      size = floor(sample_fraction * length(all_text)))
# seq_along() stays correct for an empty sample, where 1:length() would
# produce c(1, 0) and break tibble().
text_df <- tibble(line = seq_along(sample_text), text = sample_text)
# Clean text ----
# Lower-case every line, then replace any character that is not a letter
# a-z or whitespace with a single space.
clean_text <- text_df %>%
  mutate(
    text = text %>%
      str_to_lower() %>%
      str_replace_all("[^a-z\\s]", " ")
  )
# Create n-grams ----
# Tokenise the `text` column of `df` into n-grams of length `size` and count
# them, most frequent first, in a column named `gram`.
count_ngrams <- function(df, size) {
  df %>%
    unnest_tokens(gram, text, token = "ngrams", n = size) %>%
    count(gram, sort = TRUE)
}

# Single words use the word tokenizer rather than 1-grams.
unigrams <- clean_text %>%
  unnest_tokens(word, text, token = "words") %>%
  count(word, sort = TRUE)
bigrams <- rename(count_ngrams(clean_text, 2), bigram = gram)
trigrams <- rename(count_ngrams(clean_text, 3), trigram = gram)
# Separate bigrams and trigrams ----
# Split each n-gram on its space delimiter into one column per word, then
# drop rows with any missing word (presumably the NA n-grams produced for
# lines too short to contain one -- confirm against tidytext behaviour).
bigram_split <- bigrams %>%
  separate(bigram, into = c("word1", "word2"), sep = " ") %>%
  drop_na(word1, word2)
trigram_split <- trigrams %>%
  separate(trigram, into = paste0("word", 1:3), sep = " ") %>%
  drop_na(word1, word2, word3)
# Prediction function ----
#' Predict the next word with a simple back-off n-gram model
#'
#' Normalises `input` the same way the corpus was cleaned (lower-case, every
#' character other than a-z or whitespace replaced by a space, split on runs
#' of whitespace). It then returns, in order of preference:
#' 1. the most frequent trigram continuation of the last two words;
#' 2. the most frequent bigram continuation of the last word;
#' 3. the first word of `uni_df`.
#'
#' @param input Character; the text typed so far.
#' @param uni_df Unigram counts with a `word` column, assumed sorted by
#'   decreasing frequency (as `count(sort = TRUE)` produces).
#' @param bi_df Bigram counts with `word1`, `word2` and `n` columns.
#' @param tri_df Trigram counts with `word1`, `word2`, `word3` and `n` columns.
#' @return A length-one character vector holding the predicted word.
predict_next_word <- function(input, uni_df, bi_df, tri_df) {
  # Mirror the corpus cleaning so that capitals, punctuation ("you!") and
  # repeated or leading spaces do not produce tokens that can never match.
  cleaned <- gsub("[^a-z\\s]", " ", tolower(input), perl = TRUE)
  words <- unlist(strsplit(cleaned, "\\s+"))
  words <- words[!is.na(words) & nzchar(words)]
  n_words <- length(words)

  if (n_words >= 2L) {
    w1 <- words[n_words - 1L]
    w2 <- words[n_words]
    hits <- which(tri_df$word1 == w1 & tri_df$word2 == w2)
    if (length(hits) > 0L) {
      return(tri_df$word3[hits[which.max(tri_df$n[hits])]])
    }
  }
  if (n_words >= 1L) {
    w1 <- words[n_words]
    hits <- which(bi_df$word1 == w1)
    if (length(hits) > 0L) {
      return(bi_df$word2[hits[which.max(bi_df$n[hits])]])
    }
  }
  # No n-gram context matched: fall back to the most frequent word.
  uni_df$word[1]
}
# Test predictions ----
# Print the predicted next word for a few sample phrases. The original run
# printed "not", "you" and "for" respectively.
walk(c("i am", "how are", "thank you"), function(phrase) {
  print(predict_next_word(phrase, unigrams, bigram_split, trigram_split))
})