This report explores the SwiftKey text corpus and builds a simple predictive text model using n-grams (1- to 4-grams). A backoff strategy is implemented to handle unseen word sequences. The model is pruned to reduce memory usage and improve runtime efficiency for potential Shiny deployment.
This section loads the SwiftKey dataset. If the files are not found locally, the dataset is downloaded and unzipped automatically.
# Packages used throughout: stringi for text processing, data.table for n-gram tables
library(stringi)
library(data.table)

# Locate a SwiftKey file, checking a few likely directories before searching recursively
find_swiftkey_file <- function(filename) {
  candidates <- c(
    getwd(),
    file.path(getwd(), "final", "en_US"),
    file.path(path.expand("~"), "Downloads"),
    file.path(path.expand("~"), "Downloads", "capstone_data"),
    file.path(path.expand("~"), "Downloads", "capstone_data", "final", "en_US")
  )
  for (d in candidates) {
    p <- file.path(d, filename)
    if (file.exists(p)) return(normalizePath(p))
  }
  for (d in candidates) {
    if (dir.exists(d)) {
      hits <- list.files(
        d,
        pattern = paste0("^", gsub("\\.", "\\\\.", filename), "$"),
        recursive = TRUE,
        full.names = TRUE
      )
      if (length(hits) > 0) return(normalizePath(hits[1]))
    }
  }
  NA_character_
}
blogs_path <- find_swiftkey_file("en_US.blogs.txt")
news_path <- find_swiftkey_file("en_US.news.txt")
twitter_path <- find_swiftkey_file("en_US.twitter.txt")
if (any(is.na(c(blogs_path, news_path, twitter_path)))) {
  zip_url <- "https://d396qusza40orc.cloudfront.net/dsscapstone/dataset/Coursera-SwiftKey.zip"
  zip_file <- file.path(tempdir(), "Coursera-SwiftKey.zip")
  unzip_dir <- file.path(tempdir(), "swiftkey")
  download.file(zip_url, destfile = zip_file, mode = "wb")
  unzip(zip_file, exdir = unzip_dir)
  blogs_path   <- list.files(unzip_dir, pattern = "^en_US\\.blogs\\.txt$", recursive = TRUE, full.names = TRUE)[1]
  news_path    <- list.files(unzip_dir, pattern = "^en_US\\.news\\.txt$", recursive = TRUE, full.names = TRUE)[1]
  twitter_path <- list.files(unzip_dir, pattern = "^en_US\\.twitter\\.txt$", recursive = TRUE, full.names = TRUE)[1]
}
stopifnot(file.exists(blogs_path), file.exists(news_path), file.exists(twitter_path))
blogs <- readLines(blogs_path, encoding = "UTF-8", warn = FALSE)
news <- readLines(news_path, encoding = "UTF-8", warn = FALSE)
twitter <- readLines(twitter_path, encoding = "UTF-8", warn = FALSE)
text_all <- c(blogs, news, twitter)
# Sample for speed (adjust if you want more/less)
set.seed(123)
text <- sample(text_all, size = min(50000, length(text_all)))
stopifnot(is.character(text))
length(text)
## [1] 50000
Text is normalized by lowercasing, removing URLs, stripping every character that is not a letter, an apostrophe, or whitespace, and collapsing repeated whitespace. Lines that become empty after cleaning are dropped.
clean_text <- function(x) {
  x <- tolower(x)
  x <- stri_replace_all_regex(x, "http\\S+|www\\S+", " ")
  x <- stri_replace_all_regex(x, "[^a-z\\s']", " ")
  x <- stri_replace_all_regex(x, "\\s+", " ")
  x <- stri_trim_both(x)
  x[nchar(x) > 0]
}
text_clean <- clean_text(text)
length(text_clean)
## [1] 49976
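As a quick, made-up illustration of the cleaning rules (URLs dropped, apostrophes kept, everything else reduced to lowercase letters and single spaces):
# Illustrative input only; not part of the corpus
clean_text("Check out http://example.com -- it's GREAT!!!")
## [1] "check out it's great"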
This section presents basic summaries and plots of the training data.
eda_stats <- data.frame(
  Source = c("Blogs", "News", "Twitter"),
  Lines = c(length(blogs), length(news), length(twitter)),
  Words = c(
    sum(stri_count_words(blogs)),
    sum(stri_count_words(news)),
    sum(stri_count_words(twitter))
  ),
  Characters = c(
    sum(nchar(blogs)),
    sum(nchar(news)),
    sum(nchar(twitter))
  )
)
eda_stats
##    Source   Lines    Words Characters
## 1   Blogs  899288 37546250  206824505
## 2    News 1010242 34762395  203223159
## 3 Twitter 2360148 30093372  162096031
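One loading caveat worth noting: en_US.news.txt is often reported to contain a few embedded nul bytes, which on some platforms makes readLines() stop early. The News line count above suggests the full file was read here; if it ever comes out far too low, a common workaround is:
# Hypothetical fallback, only needed if readLines() truncates the news file:
# news <- readLines(news_path, encoding = "UTF-8", warn = FALSE, skipNul = TRUE)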
set.seed(123)
sample_lines <- sample(text, size = min(5000, length(text)))
words_per_line <- stri_count_words(sample_lines)
summary(words_per_line)
##    Min. 1st Qu.  Median    Mean 3rd Qu.    Max.
##     1.0     8.0    16.0    23.6    27.0   412.0
hist(
  words_per_line,
  breaks = 50,
  main = "Histogram of Words per Line",
  xlab = "Number of Words"
)
chars_per_line <- nchar(sample_lines)
summary(chars_per_line)
##    Min. 1st Qu.  Median    Mean 3rd Qu.    Max.
##     3.0    42.0    87.0   131.8   151.0  2115.0
hist(
  chars_per_line,
  breaks = 50,
  main = "Histogram of Characters per Line",
  xlab = "Number of Characters"
)
tokens <- unlist(strsplit(text_clean, " "))
word_freq <- data.table(word = tokens)[, .N, by = word][order(-N)]
head(word_freq, 10)
##       word     N
##     <char> <int>
##  1:    the 55272
##  2:     to 32083
##  3:    and 28256
##  4:      a 28072
##  5:     of 23374
##  6:      i 20015
##  7:     in 19204
##  8:    for 12854
##  9:     is 12459
## 10:   that 12035
barplot(
  word_freq$N[1:10],
  names.arg = word_freq$word[1:10],
  las = 2,
  main = "Top 10 Most Frequent Words",
  ylab = "Frequency"
)
The corpus is dominated by short entries, especially from Twitter: both the word and character distributions are strongly right-skewed. This supports an n-gram model with aggressive pruning to balance accuracy against memory use and lookup time.
make_ngrams <- function(lines, n) {
  grams <- vector("list", length(lines))
  k <- 1L
  for (i in seq_along(lines)) {
    w <- unlist(strsplit(lines[i], " ", fixed = TRUE))
    if (length(w) < n) next
    grams[[k]] <- vapply(
      seq_len(length(w) - n + 1),
      function(j) paste(w[j:(j + n - 1)], collapse = " "),
      character(1)
    )
    k <- k + 1L
  }
  grams <- unlist(grams, use.names = FALSE)
  dt <- data.table(ngram = grams)[, .N, by = ngram]
  setorder(dt, -N)
  dt
}
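As a small sanity check on made-up input, make_ngrams() returns a frequency table of overlapping n-grams:
# Two made-up lines; "the cat" occurs twice, every other bigram once
make_ngrams(c("the cat sat on the mat", "the cat ran"), 2)
# Returns a data.table with columns ngram and N, sorted by descending N:
# "the cat" has N = 2; "cat sat", "sat on", "on the", "the mat", "cat ran" each have N = 1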
uni <- make_ngrams(text_clean, 1)
bi <- make_ngrams(text_clean, 2)
tri <- make_ngrams(text_clean, 3)
quad <- make_ngrams(text_clean, 4)
list(unigrams=nrow(uni), bigrams=nrow(bi), trigrams=nrow(tri), quadgrams=nrow(quad))
## $unigrams
## [1] 55862
##
## $bigrams
## [1] 499501
##
## $trigrams
## [1] 894765
##
## $quadgrams
## [1] 995480
prune_ngram_table <- function(dt, n, min_count = 2, top_k = 3) {
  parts <- tstrsplit(dt$ngram, " ", fixed = TRUE)
  if (n == 1) {
    # Unigrams: keep frequent words and their relative frequencies
    out <- dt[N >= min_count]
    out[, prob := N / sum(N)]
    return(out[, .(next_word = ngram, prob)])
  }
  # Split each n-gram into an (n-1)-word prefix and its final word
  prefix <- do.call(paste, c(parts[1:(n - 1)], list(sep = " ")))
  nextw <- parts[[n]]
  out <- data.table(prefix = prefix, next_word = nextw, N = dt$N)
  out <- out[N >= min_count]
  # Conditional probability of each continuation, renormalised after pruning
  out[, prob := N / sum(N), by = prefix]
  setorder(out, prefix, -prob)
  # Keep only the top_k continuations per prefix
  out[, head(.SD, top_k), by = prefix]
}
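To make the pruning concrete, here is a toy example with made-up counts: min_count filters rare continuations, probabilities are renormalised within each prefix after filtering, and only the top_k continuations per prefix are kept.
toy <- make_ngrams(c("a b c", "a b d", "a b c"), 2)  # counts: "a b" = 3, "b c" = 2, "b d" = 1
prune_ngram_table(toy, 2, min_count = 1, top_k = 2)
# prefix "a": next_word "b" with prob 1
# prefix "b": next_word "c" (prob 2/3) and "d" (prob 1/3)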
uni_p <- prune_ngram_table(uni, 1, min_count = 5)
bi_p <- prune_ngram_table(bi, 2, min_count = 3)
tri_p <- prune_ngram_table(tri, 3, min_count = 2)
quad_p <- prune_ngram_table(quad, 4, min_count = 2)
alpha <- 0.4
predict_next <- function(input, top_n = 3) {
  w <- unlist(strsplit(clean_text(input), " "))
  # Look up a prefix in one n-gram table and discount its probabilities
  get <- function(dt, p, weight) {
    hit <- dt[prefix == p]
    if (nrow(hit) == 0) return(NULL)
    hit[, .(next_word, score = prob * weight)]
  }
  # Back off from quadgrams to trigrams to bigrams, discounting by alpha per level
  candidates <- rbindlist(list(
    if (length(w) >= 3) get(quad_p, paste(tail(w, 3), collapse = " "), 1),
    if (length(w) >= 2) get(tri_p, paste(tail(w, 2), collapse = " "), alpha),
    if (length(w) >= 1) get(bi_p, tail(w, 1), alpha^2)
  ), fill = TRUE)
  if (!is.null(candidates) && nrow(candidates) > 0) {
    # Keep the best score seen for each candidate word across levels
    candidates <- candidates[, .(score = max(score)), by = next_word]
    setorder(candidates, -score)
    return(head(candidates, top_n))
  }
  # Fall back to the most frequent unigrams when no prefix matches
  head(uni_p[, .(next_word, score = prob)], top_n)
}
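The weights implement a stupid-backoff-style scheme: a quadgram match keeps its conditional probability (weight 1), a trigram match is discounted by alpha = 0.4, and a bigram match by alpha^2 = 0.16, with the unigram table as a last resort. A quick arithmetic check of how the levels trade off:
# A quadgram candidate with prob 0.05 scores 0.05; a bigram candidate would need
# prob > 0.05 / 0.16 = 0.3125 to outrank it
c(quad = 1 * 0.05, tri = 0.4 * 0.10, bi = 0.16 * 0.25)  # 0.05, 0.04, 0.04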
predict_next("i love")
##    next_word      score
##       <char>      <num>
## 1:       you 0.11358025
## 2:       the 0.04641975
## 3:        it 0.02962963
predict_next("this is")
##    next_word      score
##       <char>      <num>
## 1:         a 0.07927273
## 2:       the 0.06981818
## 3:       not 0.02981818
predict_next("in the middle of")
##    next_word     score
##       <char>     <num>
## 1:       the 0.6904762
## 2:         a 0.1666667
## 3:   nowhere 0.0952381
system.time(predict_next("this is a simple test"))
##    user  system elapsed
##   0.002   0.000   0.003
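Lookups are already fast at this scale. If the pruned tables grew much larger, one optional tweak (not applied above, and purely an assumption about future needs) would be to key each table on its prefix column so the prefix == p subsets in predict_next() can use data.table's indexed search instead of a full scan:
# Optional: key the lookup tables on prefix (row order within a prefix does not
# matter here because predict_next() re-sorts candidates by score)
setkey(bi_p, prefix)
setkey(tri_p, prefix)
setkey(quad_p, prefix)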
set.seed(42)
test_lines <- sample(text_clean, size = floor(0.05 * length(text_clean)))
cases <- lapply(test_lines, function(x) {
  w <- unlist(strsplit(x, " "))
  if (length(w) < 4) return(NULL)
  i <- sample(2:(length(w) - 1), 1)
  list(input = paste(w[1:i], collapse = " "), actual = w[i + 1])
})
cases <- rbindlist(cases)
acc_at_k <- function(k) {
  mean(sapply(seq_len(nrow(cases)), function(i) {
    cases$actual[i] %in% predict_next(cases$input[i], top_n = k)$next_word
  }))
}
accuracy_top1 <- acc_at_k(1)
accuracy_top2 <- acc_at_k(2)
accuracy_top3 <- acc_at_k(3)
data.frame(
  Metric = c("Top-1 Accuracy", "Top-2 Accuracy", "Top-3 Accuracy"),
  Value = c(accuracy_top1, accuracy_top2, accuracy_top3)
)
##           Metric     Value
## 1 Top-1 Accuracy 0.1650151
## 2 Top-2 Accuracy 0.2210254
## 3 Top-3 Accuracy 0.2537699
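A caveat on these figures: test_lines are drawn from the same cleaned sample that the n-gram tables were built from, so they are closer to in-sample accuracy than to accuracy on genuinely unseen text. A stricter (hypothetical) protocol would hold lines out before counting n-grams:
# Sketch only: split first, build the tables from the training part,
# then run the same accuracy loop on the held-out part
set.seed(42)
hold_idx <- sample(seq_along(text_clean), size = floor(0.05 * length(text_clean)))
# train_lines   <- text_clean[-hold_idx]
# heldout_lines <- text_clean[hold_idx]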
data.table(
  model = c("unigram", "bigram", "trigram", "quadgram"),
  size_mb = c(
    object.size(uni_p),
    object.size(bi_p),
    object.size(tri_p),
    object.size(quad_p)
  ) / (1024^2)
)
##       model   size_mb
##      <char>     <num>
## 1:  unigram 0.9995651
## 2:   bigram 0.7919540
## 3:  trigram 3.0848312
## 4: quadgram 1.5997772
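For the Shiny deployment mentioned in the introduction, only the pruned tables, alpha, and predict_next() are needed at runtime. A minimal sketch of persisting them (the file name is an arbitrary choice):
# Save once here; a Shiny app can readRDS() this at startup
saveRDS(
  list(uni = uni_p, bi = bi_p, tri = tri_p, quad = quad_p, alpha = alpha),
  file = "ngram_model.rds"
)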
This milestone report provided exploratory summaries (line, word, and character counts, word-frequency tables, and basic plots) and implemented an n-gram predictive text model with backoff for unseen sequences. Pruning reduced memory usage and kept predictions fast while maintaining reasonable accuracy, which supports future deployment as a Shiny application.