Objective. Show that the training data are loaded
and understood, and that a simple, working next‑word prediction
approach is on track.
What’s here. A short exploratory analysis (tables +
plots), a compact 1–4‑gram model with Stupid
Backoff smoothing, and a brief evaluation and plan.
Takeaway. Common words and word pairs dominate the
corpus; a pruned n‑gram model with backoff can already make reasonable
top‑k suggestions fast enough for an app. We will iterate on
quality/size trade‑offs and deploy a simple Shiny UI.
Optional: a live demo can be hosted separately on shinyapps.io and linked here.
Edit the paths below to match your local files (SwiftKey en_US).
library(data.table)
library(stringi)
library(quanteda)
library(ggplot2)
files <- c(
  "/Users/alicelaquerriere/Downloads/final/en_US/en_US.blogs.txt",
  "/Users/alicelaquerriere/Downloads/final/en_US/en_US.news.txt",
  "/Users/alicelaquerriere/Downloads/final/en_US/en_US.twitter.txt"
)
stopifnot(all(file.exists(files)))
names(files) <- c("blogs","news","twitter")
files
## blogs
## "/Users/alicelaquerriere/Downloads/final/en_US/en_US.blogs.txt"
## news
## "/Users/alicelaquerriere/Downloads/final/en_US/en_US.news.txt"
## twitter
## "/Users/alicelaquerriere/Downloads/final/en_US/en_US.twitter.txt"
read_txt <- function(path){
con <- file(path, open = "r", encoding = "UTF-8")
on.exit(close(con))
readLines(con, warn = FALSE, skipNul = TRUE)
}
raw <- lapply(files, read_txt)
summ <- rbindlist(lapply(names(raw), function(nm){
x <- raw[[nm]]
data.table(
source = nm,
lines = length(x),
words = sum(stringi::stri_count_boundaries(x, type = "word")),
chars = sum(nchar(x, type = "chars"))
)
}))
knitr::kable(summ, caption = "Corpus overview (lines / words / characters per source)")
source | lines | words | chars |
---|---|---|---|
blogs | 899288 | 79779789 | 206824505 |
news | 1010242 | 74316341 | 203223159 |
twitter | 2360148 | 65264908 | 162096241 |
Notes. This verifies the files were successfully read. Twitter has many short lines; blogs tend to have longer lines.
lengths <- rbindlist(lapply(names(raw), function(nm){
data.table(source = nm, length = nchar(raw[[nm]], type = "chars"))
}))
set.seed(2025)
ggplot(lengths[sample(.N, min(200000,.N))], aes(length)) +
geom_histogram(bins = 60) +
facet_wrap(~source, scales = "free_y") +
labs(title="Line lengths by source (sample)", x="Characters per line", y="Count")
Finding. Line lengths are heavy‑tailed; a small number of very long lines can skew statistics. This motivates sampling for faster prototyping.
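A quick numeric check of the tail, using the lengths table built above (a sketch; the quantile choices are illustrative):

lengths[, .(median = median(length),
            p99    = quantile(length, 0.99),
            max    = max(length)), by = source]

The gap between the median and the 99th percentile quantifies how heavy the tail is per source.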
We analyze a small random sample to keep memory/runtime reasonable while prototyping.
set.seed(2025)
sample_frac <- 0.02 # increase locally for deeper analysis
texts <- unlist(lapply(raw, function(x) x[rbinom(length(x),1,sample_frac)==1]), use.names = FALSE)
clean_text <- function(x){
  x <- stringi::stri_trans_tolower(x)                                       # lower-case first
  x <- stringi::stri_replace_all_regex(x, "https?://\\S+|www\\.\\S+", " ")  # drop URLs
  x <- stringi::stri_replace_all_regex(x, "[0-9]", " ")                     # drop digits
  x <- stringi::stri_replace_all_regex(x, "[^a-z'\\s]", " ")                # keep letters and apostrophes
  x <- stringi::stri_replace_all_regex(x, "\\s+", " ")                      # collapse whitespace
  stringi::stri_trim_both(x)
}
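A quick sanity check of the cleaner on a made‑up line (URL, digits, and punctuation should disappear; apostrophes survive):

clean_text("Visit https://example.com NOW!! It's 100% great :)")
## [1] "visit now it's great"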
texts_clean <- clean_text(texts)
toks <- tokens(texts_clean, remove_punct = TRUE, remove_symbols = TRUE, remove_numbers = TRUE)
make_ng <- function(toks, n){
ng <- tokens_ngrams(toks, n = n)
dfm_ng <- dfm(ng)
if (ndoc(dfm_ng)==0L || nfeat(dfm_ng)==0L) return(data.table(ngram=character(), count=integer()))
freq <- colSums(dfm_ng) # robust, fast
data.table(ngram = names(freq), count = as.integer(freq))
}
uni <- make_ng(toks, 1L)
bi <- make_ng(toks, 2L)
tri <- make_ng(toks, 3L)
quad <- make_ng(toks, 4L)
total_tokens <- sum(uni$count)
uni[, prob := count/total_tokens][]
plot_top <- function(dt, title, k=20){
top <- dt[order(-count)][1:min(k,.N)]
ggplot(top, aes(x=reorder(ngram, count), y=count)) +
geom_col() + coord_flip() +
labs(title=title, x=NULL, y="Count")
}
plot_top(uni, "Top 20 unigrams")
plot_top(bi, "Top 20 bigrams")
plot_top(tri, "Top 20 trigrams")
Findings.
- Very frequent function words dominate unigrams (the, to, and, of, …).
- Common short expressions dominate bigrams/trigrams (e.g., “of the”, “in the …”).

These patterns justify a backoff strategy where higher‑order n‑grams get priority but lower‑order statistics provide robust fallbacks.
We split each n‑gram “w1 … wn” into context = w1…w(n‑1) and target = wn, estimate \(p(\text{target} \mid \text{context})\), and prune to keep the model compact.
# Split "w1_w2_..._wn" into context = w1 ... w(n-1) and target = wn
split_ng <- function(dt, n){
toks <- strsplit(dt$ngram, "_", fixed = TRUE)
keep <- vapply(toks, length, integer(1)) >= n # only rows with at least n tokens
toks <- toks[keep]
cnt <- dt$count[keep]
context <- vapply(toks, function(x) paste(x[seq_len(n-1)], collapse = " "), character(1))
target <- vapply(toks, function(x) x[n], character(1))
data.table(context = context, target = target, count = cnt)
}
context_count <- function(dt) dt[, .(ctx_count = sum(count)), by = context]
bi_dt <- split_ng(bi, 2L)[context != ""]
tri_dt <- split_ng(tri, 3L)[context != ""]
quad_dt <- split_ng(quad, 4L)[context != ""]
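For reference, the MLE estimate \(p(\text{target} \mid \text{context}) = \text{count}/\text{context count}\) can be attached with the context_count() helper defined above. A sketch (the backoff model below works with raw counts instead):

tri_mle <- merge(tri_dt, context_count(tri_dt), by = "context")
tri_mle[, p_mle := count / ctx_count]  # per-context relative frequency
tri_mle[order(-count)][1:5]            # most frequent trigram continuations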
prune_table <- function(dt, min_count=2L, topM=7L){
  # drop rare continuations, then keep the top-M most frequent targets per context
  dt[count >= min_count][order(context, -count)][, head(.SD, topM), by=context]
}
bi_small <- prune_table(bi_dt, 2L, 7L)
tri_small <- prune_table(tri_dt, 2L, 7L)
quad_small <- prune_table(quad_dt, 2L, 7L)
uni_small <- uni[order(-count)][1:min(50000L,.N), .(ngram, count, prob)]
size_mb <- function(o) as.numeric(object.size(o)) / 1024^2  # bytes -> megabytes
sizes <- data.table(
object=c("uni_small","bi_small","tri_small","quad_small"),
MB=round(c(size_mb(uni_small), size_mb(bi_small), size_mb(tri_small), size_mb(quad_small)),2)
)
knitr::kable(sizes, caption="Model sizes after pruning (MB)")
Why pruning? It shrinks memory use and speeds up lookups while retaining the most predictive next‑word candidates per context.
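A toy illustration of the pruning rule, with hypothetical counts (not part of the pipeline):

toy <- data.table(
  context = "in the",
  target  = c("end", "world", "morning", "zzz"),
  count   = c(40L, 12L, 5L, 1L)
)
prune_table(toy, min_count = 2L, topM = 2L)  # drops the singleton "zzz", then keeps the top-2: "end", "world"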
Design. Prefer 4‑grams; if unseen, back off to 3‑grams (×α), then 2‑grams (×α²), then unigrams (×α³). This gives non‑zero probability to unseen sequences while keeping runtime small.
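For reference, the underlying score is the Stupid Backoff recursion of Brants et al. (2007):

\[
S(w \mid c) =
\begin{cases}
\dfrac{\text{count}(c\,w)}{\text{count}(c)} & \text{if } \text{count}(c\,w) > 0,\\[6pt]
\alpha\, S(w \mid c') & \text{otherwise,}
\end{cases}
\]

where \(c'\) drops the oldest word of the context and \(\alpha = 0.4\). Scores are not normalized probabilities, which is what keeps lookups cheap. The implementation below pools candidates from all available orders with weights \(1, \alpha, \alpha^2, \alpha^3\) and sums them, a blended variant of strict backoff rather than a stop‑at‑first‑match rule.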
alpha <- 0.4
setkey(bi_small, context)
setkey(tri_small, context)
setkey(quad_small, context)
predict_next <- function(history, k = 5L){
hist <- gsub("[^a-z'\\s]", " ", tolower(paste(history, collapse=" ")), perl=TRUE)  # lower-case before filtering
hist <- trimws(gsub("\\s+"," ", hist))
words <- strsplit(hist, " +", perl=TRUE)[[1]]
L <- length(words)
# helper: use prob if present, else normalize counts within the subset
with_scores <- function(dt) {
if (nrow(dt) == 0) return(data.table(target=character(), score=numeric()))
if ("prob" %in% names(dt)) dt[, .(target, score = prob)]
else dt[, .(target, score = count / sum(count))]
}
cand <- data.table(target = character(), score = numeric())
if (L >= 3) {
ctx3 <- paste(words[(L-2):L], collapse = " ")
q4 <- quad_small[.(ctx3), nomatch = NULL]  # 0 rows on a miss (the default would return one NA row)
if (nrow(q4)) cand <- rbind(cand, with_scores(q4), fill = TRUE)
}
if (L >= 2) {
ctx2 <- paste(words[(L-1):L], collapse = " ")
q3 <- tri_small[.(ctx2), nomatch = NULL]
if (nrow(q3)) {
tmp <- with_scores(q3); tmp[, score := alpha * score]
cand <- rbind(cand, tmp, fill = TRUE)
}
}
if (L >= 1) {
ctx1 <- words[L]
q2 <- bi_small[.(ctx1), nomatch = NULL]
if (nrow(q2)) {
tmp <- with_scores(q2); tmp[, score := alpha^2 * score]
cand <- rbind(cand, tmp, fill = TRUE)
}
}
# Unigram blend / fallback
uni_blk <- if ("prob" %in% names(uni_small)) {
uni_small[1:min(50L, .N), .(target = ngram, score = alpha^3 * prob)]
} else {
# normalize on the slice for stability
tmp <- uni_small[1:min(50L, .N), .(target = ngram, count)]
tmp[, score := alpha^3 * (count / sum(count))]
tmp[, .(target, score)]
}
if (nrow(cand) == 0) {
cand <- uni_blk
} else {
# a smaller unigram blend to diversify
uni_small_slice <- if ("prob" %in% names(uni_small)) {
uni_small[1:min(20L, .N), .(target = ngram, score = alpha^3 * prob)]
} else {
tmp <- uni_small[1:min(20L, .N), .(target = ngram, count)]
tmp[, score := alpha^3 * (count / sum(count))]
tmp[, .(target, score)]
}
cand <- rbind(cand, uni_small_slice, fill = TRUE)
}
cand[, .(score = sum(score)), by = target][order(-score)][1:min(k, .N)]
}
predict_next("in the middle of the", 5)
predict_next("the united", 5)
mk_pairs <- function(line, n_ctx=3L){
  line <- gsub("[^a-z'\\s]", " ", tolower(line), perl=TRUE)  # lower-case before filtering
  line <- trimws(gsub("\\s+"," ", line))
  w <- strsplit(line, " +", perl=TRUE)[[1]]
  n <- length(w)
  if (n < (n_ctx+1L)) return(data.table(context=character(0), target=character(0)))
idx <- seq_len(n - n_ctx)
ctxs <- vapply(idx, function(i) paste(w[i:(i+n_ctx-1L)], collapse=" "), character(1))
tgts <- vapply(idx, function(i) w[i+n_ctx], character(1))
data.table(context=ctxs, target=tgts)
}
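A quick illustration on a made‑up line (two context/target pairs are produced):

mk_pairs("One two, three four five!", n_ctx = 3L)
# -> context "one two three" / target "four"; context "two three four" / target "five"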
topk_accuracy <- function(pairs_dt, predict_fun, k=3L, max_cases=800L){
N <- min(nrow(pairs_dt), max_cases); if (N==0L) return(NA_real_)
idx <- if (nrow(pairs_dt)>N) sample.int(nrow(pairs_dt), N) else seq_len(N)
hits <- 0L
for (i in idx) {
preds <- predict_fun(pairs_dt$context[i], k=k)
if (pairs_dt$target[i] %in% preds$target) hits <- hits + 1L  # compare against the suggested words
}
hits / N
}
set.seed(123)
held_lines <- sample(texts_clean, min(2000L, length(texts_clean)))
pairs_dt <- unique(rbindlist(lapply(held_lines, mk_pairs, n_ctx=3L)))
acc <- c(
Top1 = topk_accuracy(pairs_dt, predict_next, k=1L),
Top3 = topk_accuracy(pairs_dt, predict_next, k=3L),
Top5 = topk_accuracy(pairs_dt, predict_next, k=5L)
)
round(acc, 3)
Interpretation. Top‑k accuracy measures how often the true next word appears among the top k suggestions. Note that these test lines are drawn from the same sample used to build the tables, so this is an in‑sample estimate. We will tune parameters to improve quality without sacrificing speed.
Planned improvements
- Hyperparameters: backoff α, minimum counts, Top‑M per context.
- Normalization: contractions (“don’t”), profanity filtering, domain stoplists.
- Efficiency: more pruning, hashing/context compression, persistence with fst/arrow (see the sketch below).
- Shiny app: simple UI with a text box + Top‑5 buttons; deploy on shinyapps.io.
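For the efficiency item above, a minimal persistence sketch (assumes the fst package is installed; file names are placeholders):

library(fst)
write_fst(uni_small,  "uni_small.fst")
write_fst(bi_small,   "bi_small.fst")
write_fst(tri_small,  "tri_small.fst")
write_fst(quad_small, "quad_small.fst")
# In the Shiny app, load once at startup and restore the join keys:
# tri_small <- setDT(read_fst("tri_small.fst")); setkey(tri_small, context)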
Feedback requested
- Is a lightweight, fast model acceptable for the initial app, with later quality upgrades?
- What is the preferred trade‑off between latency and prediction quality for the MVP?
sessionInfo()
## R version 4.5.1 (2025-06-13)
## Platform: aarch64-apple-darwin20
## Running under: macOS Sequoia 15.6
##
## Matrix products: default
## BLAS: /Library/Frameworks/R.framework/Versions/4.5-arm64/Resources/lib/libRblas.0.dylib
## LAPACK: /Library/Frameworks/R.framework/Versions/4.5-arm64/Resources/lib/libRlapack.dylib; LAPACK version 3.12.1
##
## locale:
## [1] en_US.UTF-8/en_US.UTF-8/en_US.UTF-8/C/en_US.UTF-8/en_US.UTF-8
##
## time zone: Europe/Berlin
## tzcode source: internal
##
## attached base packages:
## [1] stats graphics grDevices utils datasets methods base
##
## other attached packages:
## [1] ggplot2_3.5.2 stringi_1.8.7 quanteda_4.3.1 data.table_1.17.8
##
## loaded via a namespace (and not attached):
## [1] vctrs_0.6.5 cli_3.6.5 knitr_1.50 rlang_1.1.6
## [5] xfun_0.52 generics_0.1.4 jsonlite_2.0.0 labeling_0.4.3
## [9] glue_1.8.0 htmltools_0.5.8.1 sass_0.4.10 scales_1.4.0
## [13] rmarkdown_2.29 grid_4.5.1 tibble_3.3.0 evaluate_1.0.4
## [17] jquerylib_0.1.4 fastmap_1.2.0 yaml_2.3.10 lifecycle_1.0.4
## [21] compiler_4.5.1 dplyr_1.1.4 RColorBrewer_1.1-3 pkgconfig_2.0.3
## [25] Rcpp_1.1.0 fastmatch_1.1-6 rstudioapi_0.17.1 farver_2.1.2
## [29] lattice_0.22-7 digest_0.6.37 R6_2.6.1 tidyselect_1.2.1
## [33] pillar_1.11.0 stopwords_2.3 magrittr_2.0.3 bslib_0.9.0
## [37] Matrix_1.7-3 withr_3.0.2 gtable_0.3.6 tools_4.5.1
## [41] cachem_1.1.0