TEXT CLASSIFICATION WITH TIDY DATA PRINCIPLES
I am an enthusiastic proponent of using tidy data principles for dealing with text data. This kind of approach offers a fluent and flexible option not just for exploratory data analysis, but also for machine learning for text, including both unsupervised and supervised machine learning. I haven’t written much about supervised machine learning for text, i.e. predictive modeling, using tidy data principles, so let’s walk through an example workflow for a text classification task.
This post lays out a workflow similar to the approach taken by Emil Hvitfeldt in predicting authorship of the Federalist Papers, so be sure to check out that post to see more examples. Also, I’ve been giving some workshops lately that included material on this, such as for IBM Community Day: AI and at the 2018 Deming Conference. I have slides and code available at those links. This material is also some of what we’ll cover in the short course I am teaching at the SDSS conference in 2019, so come on out to Bellevue if you are interested!
Let’s build a supervised machine learning model that learns the difference between text from Pride and Prejudice and text from The War of the Worlds. We can access the full texts of these works from Project Gutenberg via the gutenbergr package.
library(tidyverse)
## -- Attaching packages --------
## v ggplot2 3.3.2 v purrr 0.3.4
## v tibble 3.0.3 v dplyr 1.0.0
## v tidyr 1.1.0 v stringr 1.4.0
## v readr 1.3.1 v forcats 0.5.0
## -- Conflicts -----------------
## x dplyr::filter() masks stats::filter()
## x dplyr::lag() masks stats::lag()
library(gutenbergr)
library(ggthemes)
theme_set(theme_light())
titles <- c(
'The War of the Worlds',
'Pride and Prejudice'
)
books <- gutenberg_works(title %in% titles) %>%
gutenberg_download(meta_fields = 'title') %>%
mutate(document = row_number())
## Warning: `filter_()` is deprecated as of dplyr 0.7.0.
## Please use `filter()` instead.
## See vignette('programming') for more help
## This warning is displayed once every 8 hours.
## Call `lifecycle::last_warnings()` to see where this warning was generated.
## Warning: `distinct_()` is deprecated as of dplyr 0.7.0.
## Please use `distinct()` instead.
## See vignette('programming') for more help
## This warning is displayed once every 8 hours.
## Call `lifecycle::last_warnings()` to see where this warning was generated.
## Determining mirror for Project Gutenberg from http://www.gutenberg.org/robot/harvest
## Using mirror http://aleph.gutenberg.org
books
## # A tibble: 19,504 x 4
## gutenberg_id text title document
## <int> <chr> <chr> <int>
## 1 36 "The War of the Worlds" The War of th~ 1
## 2 36 "" The War of th~ 2
## 3 36 "by H. G. Wells [1898]" The War of th~ 3
## 4 36 "" The War of th~ 4
## 5 36 "" The War of th~ 5
## 6 36 " But who shall dwell in these worl~ The War of th~ 6
## 7 36 " inhabited? . . . Are we or the~ The War of th~ 7
## 8 36 " World? . . . And how are all t~ The War of th~ 8
## 9 36 " KEPLER (quoted in The Anatom~ The War of th~ 9
## 10 36 "" The War of th~ 10
## # ... with 19,494 more rows
We have the text data now, so let’s frame the kind of prediction problem we are going to work on. Imagine that we take each book and cut it up into lines, like strips of paper (confetti) with an individual line on each strip. Let’s train a model that can take an individual line and give us a probability that it comes from Pride and Prejudice vs. from The War of the Worlds. As a first step, let’s transform our text data into a tidy format.
library(tidytext)
tidy_books <- books %>%
unnest_tokens(word, text) %>%
group_by(word) %>%
filter(n()>10) %>%
ungroup()
tidy_books
## # A tibble: 159,707 x 4
## gutenberg_id title document word
## <int> <chr> <int> <chr>
## 1 36 The War of the Worlds 1 the
## 2 36 The War of the Worlds 1 war
## 3 36 The War of the Worlds 1 of
## 4 36 The War of the Worlds 1 the
## 5 36 The War of the Worlds 3 by
## 6 36 The War of the Worlds 6 but
## 7 36 The War of the Worlds 6 who
## 8 36 The War of the Worlds 6 shall
## 9 36 The War of the Worlds 6 in
## 10 36 The War of the Worlds 6 these
## # ... with 159,697 more rows
We’ve also removed the rarest words in that step, keeping only words in our dataset that occur more than 10 times total over both books.
The tidy data structure is a great fit for performing exploratory data analysis, making lots of plots, and deeply understanding what is in the dataset we would like to use for modeling. In the interest of space, let’s show just one example plot we could use for EDA, looking at the most frequent words in each book after removing stop words.
library(stopwords)
tidy_books %>%
count(title, word, sort = TRUE) %>%
anti_join(get_stopwords()) %>%
group_by(title) %>%
top_n(20) %>%
ungroup() %>%
ggplot(aes(reorder_within(word, n, title), n, fill = title)) +
geom_col(alpha = 0.8, show.legend = FALSE) +
scale_x_reordered() +
coord_flip() +
facet_wrap(~title, scales = 'free') +
scale_y_continuous(expand = c(0, 0)) +
labs(
x = NULL, y = 'Word count',
title = 'Most frequent words after stop words',
subtitle = 'Words like "said" occupy similar ranks but other words are quite different'
)
## Joining, by = "word"
## Selecting by n
We could perform other kinds of EDA like looking at tf-idf by book, but we’ll stop here for now and move on to building a classification model.
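If you did want a quick look at tf-idf, a minimal sketch using bind_tf_idf() from tidytext could start like this, surfacing the words most characteristic of each book:
tidy_books %>%
count(title, word, sort = TRUE) %>%
bind_tf_idf(word, title, n) %>%
arrange(desc(tf_idf))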
Let’s get this data ready for modeling. We want to split our data into training and testing sets, to use for building the model and evaluating the model. Here I use the rsample package to split the data; it works great with a tidy data workflow. Let’s go back to the books dataset (not the tidy_books dataset) because the lines of text are our individual observations.
library(rsample)
books_split <- books %>%
select(document) %>%
initial_split()
train_data <- training(books_split)
test_data <- testing(books_split)
You can also use functions from the rsample package to generate resampled datasets, but the specific modeling approach we’re going to use will do that for us, so we only need a simple train/test split.
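If you did want explicit resamples for a different modeling approach, a quick sketch with rsample's vfold_cv() might look like this (unnecessary here, since our modeling function cross-validates internally):
books %>%
select(document) %>%
vfold_cv(v = 10)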
Now we want to transform our training data from a tidy data structure to a sparse matrix to use for our machine learning algorithm.
sparse_words <- tidy_books %>%
count(document, word) %>%
inner_join(train_data) %>%
cast_sparse(document, word, n)
## Joining, by = "document"
class(sparse_words)
## [1] "dgCMatrix"
## attr(,"package")
## [1] "Matrix"
dim(sparse_words)
## [1] 12024 1652
We have 12,024 training observations and 1,652 features at this point; text feature space handled in this way is very high dimensional, so we need to take that into account when considering our modeling approach.
One reason this overall approach is flexible and wonderful is that you could at this point cbind() other columns, such as non-text numeric data, onto this sparse matrix. Then you can use this combination of text and non-text data as your predictors in the machine learning algorithm, and the regularized regression algorithm we are going to use will find which are important for your problem space. I’ve gotten great results with my real-world prediction problems using this approach.
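As a sketch of what that might look like, suppose we wanted to add each line’s character count as an extra numeric predictor (the line_lengths and n_chars names here are just for illustration):
# hypothetical non-text feature: number of characters per line
line_lengths <- books %>%
mutate(n_chars = nchar(text)) %>%
select(document, n_chars)
# align the feature with the rows of the sparse matrix, then bind it on
meta <- line_lengths$n_chars[match(as.integer(rownames(sparse_words)), line_lengths$document)]
sparse_plus_meta <- cbind(sparse_words, n_chars = meta)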
We also need to build a dataframe with a response variable to associate each of the rownames() of the sparse matrix with a title, to use as the quantity we will predict in the model.
word_rownames <- as.integer(rownames(sparse_words))
books_joined <- data_frame(document = word_rownames) %>%
left_join(books) %>%
select(document, title)
## Warning: `data_frame()` is deprecated as of tibble 1.1.0.
## Please use `tibble()` instead.
## This warning is displayed once every 8 hours.
## Call `lifecycle::last_warnings()` to see where this warning was generated.
## Joining, by = "document"
Now it’s time to train our classification model! Let’s use the glmnet package to fit a logistic regression with LASSO regularization. It’s a great fit for text classification because the variable selection that LASSO regularization performs can tell you which words are important for your prediction problem. The glmnet package also supports parallel processing with very little hassle, so we can train on multiple cores with cross-validation on the training set using cv.glmnet().
library(glmnet)
## Loading required package: Matrix
##
## Attaching package: 'Matrix'
## The following objects are masked from 'package:tidyr':
##
## expand, pack, unpack
## Loaded glmnet 4.0-2
library(doMC)
## Loading required package: foreach
##
## Attaching package: 'foreach'
## The following objects are masked from 'package:purrr':
##
## accumulate, when
## Loading required package: iterators
## Loading required package: parallel
registerDoMC(cores = 8)
is_jane <- books_joined$title == "Pride and Prejudice"
model <- cv.glmnet(sparse_words, is_jane,
family = "binomial",
parallel = TRUE, keep = TRUE
)
We did it! If you are used to looking at the default plot methods for glmnet’s output, here is what we’re dealing with.
plot(model)
plot(model$glmnet.fit)
Those default plots are helpful, but we want to dig more deeply into our model and understand it better. For starters, what predictors are driving the model? Let’s use broom to check out the coefficients of the model, for the largest value of lambda within 1 standard error of the minimum.
library(broom)
coefs <- model$glmnet.fit %>%
tidy() %>%
filter(lambda == model$lambda.1se)
Which coefficients are the largest in size, in each direction?
coefs %>%
group_by(estimate > 0) %>%
top_n(10, abs(estimate)) %>%
ungroup() %>%
ggplot(aes(fct_reorder(term, estimate), estimate, fill = estimate > 0)) +
geom_col(alpha = 0.8, show.legend = FALSE) +
coord_flip() +
labs(
x = NULL,
title = 'Coefficients that increase/decrease probability the most',
subtitle = 'A document mentioning Martians is unlikely to be written by Jane Austen'
)
Makes sense, if you ask me!
We want to evaluate how well this model is doing using the test data that we held out and did not use for training the model. There are a couple of steps to this, but we can deeply understand the performance using the model output and tidy data principles. Let’s create a dataframe that tells us, for each document in the test set, the probability of being written by Jane Austen. We sum the coefficient estimates for the words in each document, add the intercept to get the log-odds, and then convert to a probability with the logistic function.
intercept <- coefs %>%
filter(term =='(Intercept)') %>% pull(estimate)
classifications <- tidy_books %>%
inner_join(test_data) %>%
inner_join(coefs, by = c('word' = 'term')) %>%
group_by(document) %>%
summarize(score = sum(estimate)) %>%
mutate(probability = plogis(intercept + score))
## Joining, by = "document"
## `summarise()` ungrouping output (override with `.groups` argument)
classifications
## # A tibble: 3,999 x 3
## document score probability
## <int> <dbl> <dbl>
## 1 1 -2.22 0.133
## 2 6 1.98 0.911
## 3 19 0.284 0.652
## 4 25 -0.990 0.344
## 5 26 0.832 0.764
## 6 30 -4.24 0.0199
## 7 34 -4.27 0.0194
## 8 36 -0.641 0.426
## 9 42 -2.10 0.147
## 10 48 2.85 0.961
## # ... with 3,989 more rows
Now let’s use the yardstick package to calculate some model performance metrics. For example, what does the ROC curve look like?
library(yardstick)
## For binary classification, the first factor level is assumed to be the event.
## Use the argument `event_level = "second"` to alter this as needed.
##
## Attaching package: 'yardstick'
## The following object is masked from 'package:readr':
##
## spec
comment_classes <- classifications %>%
left_join(books %>% select(title, document), by = 'document') %>%
mutate(title = as.factor(title))
comment_classes %>%
roc_curve(title, probability) %>%
ggplot(aes(x = 1 - specificity, y = sensitivity)) +
geom_line(color = 'midnightblue', size = 1.5) +
geom_abline(lty = 2, alpha = 0.5, color = 'gray50', size = 1.2) +
labs(
title = 'ROC curve for text classification using regularized regression',
subtitle = 'Predicting whether text was written by Jane Austen or H.G. Wells'
)
Looks pretty nice. What is the AUC on the test data?
comment_classes %>%
roc_auc(title, probability)
## # A tibble: 1 x 3
## .metric .estimator .estimate
## <chr> <chr> <dbl>
## 1 roc_auc binary 0.972
Not too shabby.
What about a confusion matrix? Let’s use a probability of 0.5 as our cutoff point, for example.
comment_classes %>%
mutate(
prediction = case_when(
probability > 0.5 ~ 'Pride and Prejudice',
TRUE ~ 'The War of the Worlds'
),
prediction = as.factor(prediction)
) %>%
conf_mat(title, prediction)
## Truth
## Prediction Pride and Prejudice The War of the Worlds
## Pride and Prejudice 2514 189
## The War of the Worlds 151 1145
More text from ‘The War of the Worlds’ was misclassified with this particular cutoff point.
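If you want a single summary number at this cutoff, a quick sketch using yardstick's accuracy() would be:
comment_classes %>%
mutate(
prediction = case_when(
probability > 0.5 ~ 'Pride and Prejudice',
TRUE ~ 'The War of the Worlds'
),
prediction = as.factor(prediction)
) %>%
accuracy(title, prediction)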
Let’s talk about these misclassifications. In the real world, it’s usually worth my while to understand a bit about both false negatives and false positives for my models. Which documents here were incorrectly predicted to be written by Jane Austen, at the extreme probability end?
comment_classes %>%
filter(
probability > .9,
title == 'The War of the Worlds'
) %>%
sample_n(10) %>%
inner_join(books %>% select(document, text)) %>%
select(probability, text)
## Joining, by = "document"
## # A tibble: 10 x 2
## probability text
## <dbl> <chr>
## 1 0.943 "meteorites are rounded more or less completely. It was, howeve~
## 2 0.988 "\"After all, it may not be so much we may have to learn before-~
## 3 0.969 "did not know what to make of her. One shell, and they would ha~
## 4 0.988 "I cannot but regret, now that I am concluding my story, how lit~
## 5 0.993 "She explained that they had as much as thirty pounds in gold,"
## 6 0.962 "She was steaming at such a pace that in a minute she seemed hal~
## 7 0.915 "purpose, and taking possession of the conquered country. They ~
## 8 0.924 "\"You'll see, sir. They carry a kind of box, sir, that shoots ~
## 9 0.925 "of rage in our final tragedy, easy enough to blame; for they kn~
## 10 0.930 "they did not wish to destroy the country but only to crush and ~
Some of these are quite short, and some of these I would have difficulty classifying as a human reader quite familiar with these texts.
Which documents here were incorrectly predicted to not be written by Jane Austen?
comment_classes %>%
filter(
probability < .3,
title == 'Pride and Prejudice'
) %>%
sample_n(10) %>%
inner_join(books %>% select(document, text)) %>%
select(probability, text)
## Joining, by = "document"
## # A tibble: 10 x 2
## probability text
## <dbl> <chr>
## 1 0.150 in the room a near relation of my patroness. I happened to overh~
## 2 0.204 and telling again what had already been written; and when it clo~
## 3 0.0534 Newcastle, a place quite northward, it seems, and there they are~
## 4 0.00401 hot-pressed paper, well covered with a lady's fair, flowing hand~
## 5 0.162 all families within the reach of my influence; and on these grou~
## 6 0.221 was as sincere as her brother's in sending it. Four sides of pap~
## 7 0.233 then said, of my conduct, my manners, my expressions during the ~
## 8 0.148 shrubbery. They both set off, and the conjectures of the remaini~
## 9 0.163 to a window to enjoy its prospect. The hill, crowned with wood, ~
## 10 0.285 in the regiment since the preceding Wednesday; several of the of~
These are the texts that are from Pride and Prejudice but the model did not correctly identify as such.
This workflow demonstrates how tidy data principles can be used not just for data cleaning and munging, but for sophisticated machine learning as well. I used my own tidytext package, and also a couple of packages from the tidymodels metapackage, which provides lots of valuable functions and infrastructure for this kind of work. One thing I want to note is that my data was not in a tidy data structure the whole time during this process, and that is what I usually find myself doing in real-world situations. I use tidy tools to clean and prepare data, then transform to a data structure like a sparse matrix for modeling, then tidy() the output of the machine learning algorithm so I can visualize it and understand it in other ways as well. We talk about this workflow in our book, and it’s one that serves me well in the real world.