This is the code accompanying an article that will appear in a special edition on HR data sets. The article is about explaining black-box machine learning models. In that article, I’m showcasing a practical example: explaining text classification models with xgboost and lime.
Below, you will find the code for the third part: Text classification with lime.
Here, I am using an HR dataset of employee feedback reviews. The data contains a text feedback entry for each employee, as well as some additional information, like department, job role, attrition, etc.
In this example, I will use the feedback text in order to classify whether or not an employee has left the company.
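The packages loaded below are the ones this post relies on; they match the sessionInfo() output at the very end:

library(tidyverse)  # data handling and ggplot2
library(ggthemes)   # scale_fill_tableau()
library(text2vec)   # itoken(), create_dtm(), vocabulary handling, TfIdf
library(xgboost)    # gradient boosting models
library(caret)      # createDataPartition(), confusionMatrix()
library(lime)       # model explanations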
# read the data and create a binary label: did the employee leave (attrition)?
emp_reviews <- read_csv("~/Data_set_HR.csv") %>%
  mutate(Left = as.factor(ifelse(Attrition == "Yes", 1, 0)))
## Parsed with column specification:
## cols(
## .default = col_double(),
## Attrition = col_character(),
## BusinessTravel = col_character(),
## Department = col_character(),
## Education = col_character(),
## EducationField = col_character(),
## Gender = col_character(),
## JobRole = col_character(),
## MaritalStatus = col_character(),
## OverTime = col_character(),
## Feedback = col_character()
## )
## See spec(...) for full column specifications.
glimpse(emp_reviews)
## Observations: 500
## Variables: 35
## $ Age <dbl> 49, 33, 27, 32, 59, 30, 38, 36, 35, 2...
## $ Attrition <chr> "No", "No", "No", "No", "No", "No", "...
## $ BusinessTravel <chr> "Travel_Frequently", "Travel_Frequent...
## $ DailyRate <dbl> 279, 1392, 591, 1005, 1324, 1358, 216...
## $ Department <chr> "Research & Development", "Research &...
## $ DistanceFromHome <dbl> 8, 3, 2, 2, 3, 24, 23, 27, 16, 15, 26...
## $ Education <chr> "Others", "Master", "Others", "Colleg...
## $ EducationField <chr> "Life Sciences", "Life Sciences", "Me...
## $ EmployeeCount <dbl> 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1...
## $ EnvironmentSatisfaction <dbl> 3, 4, 1, 4, 3, 4, 4, 3, 1, 4, 1, 2, 2...
## $ Gender <chr> "Male", "Female", "Male", "Male", "Fe...
## $ HourlyRate <dbl> 61, 56, 40, 79, 81, 67, 44, 94, 84, 4...
## $ JobInvolvement <dbl> 2, 3, 3, 3, 4, 3, 2, 3, 4, 2, 3, 3, 4...
## $ JobLevel <dbl> 2, 1, 1, 1, 1, 1, 3, 2, 1, 2, 1, 1, 3...
## $ JobRole <chr> "Research Scientist", "Research Scien...
## $ JobSatisfaction <dbl> 2, 3, 2, 4, 1, 3, 3, 3, 2, 3, 3, 4, 1...
## $ MaritalStatus <chr> "Married", "Married", "Married", "Sin...
## $ MonthlyIncome <dbl> 5130, 2909, 3468, 3068, 2670, 2693, 9...
## $ MonthlyRate <dbl> 24907, 23159, 16632, 11864, 9964, 133...
## $ NumCompaniesWorked <dbl> 1, 1, 9, 0, 4, 1, 0, 6, 0, 0, 1, 0, 1...
## $ OverTime <chr> "No", "Yes", "No", "No", "Yes", "No",...
## $ PercentSalaryHike <dbl> 23, 11, 12, 13, 20, 22, 21, 13, 13, 1...
## $ PerformanceRating <dbl> 4, 3, 3, 3, 4, 4, 4, 3, 3, 3, 3, 3, 3...
## $ RelationshipSatisfaction <dbl> 4, 3, 4, 3, 1, 2, 2, 2, 3, 4, 4, 3, 3...
## $ StandardHours <dbl> 80, 80, 80, 80, 80, 80, 80, 80, 80, 8...
## $ StockOptionLevel <dbl> 1, 0, 1, 0, 3, 1, 0, 2, 1, 0, 1, 1, 1...
## $ TotalWorkingYears <dbl> 10, 8, 6, 8, 12, 1, 10, 17, 6, 10, 5,...
## $ TrainingTimesLastYear <dbl> 3, 3, 3, 2, 3, 2, 2, 3, 5, 3, 1, 2, 1...
## $ WorkLifeBalance <dbl> 3, 3, 3, 2, 2, 3, 3, 2, 3, 3, 2, 3, 3...
## $ YearsAtCompany <dbl> 10, 8, 2, 7, 1, 1, 9, 7, 5, 9, 5, 2, ...
## $ YearsInCurrentRole <dbl> 7, 7, 2, 7, 0, 0, 7, 7, 4, 5, 2, 2, 9...
## $ YearsSinceLastPromotion <dbl> 1, 3, 2, 3, 0, 0, 1, 7, 0, 0, 4, 1, 8...
## $ YearsWithCurrManager <dbl> 7, 0, 2, 6, 0, 0, 8, 7, 3, 8, 3, 2, 8...
## $ Feedback <chr> "People are willing to share knowledg...
## $ Left <fct> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0...
Whether or not an employee left will thus be my response variable, or label, for classification.
emp_reviews %>%
  ggplot(aes(x = Attrition, fill = Attrition)) +
  geom_bar(alpha = 0.8) +
  scale_fill_tableau(palette = "Classic Purple-Gray 6") +
  guides(fill = FALSE)
Let’s split the data into train and test sets:
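A minimal sketch of the split with caret’s createDataPartition() (the seed and the exact call are my assumptions; a stratified 80/20 split matches the 400/100 train/test row counts seen in the matrices below):

set.seed(42)  # assumed seed; any value works, it only fixes the sampling
train_idx <- createDataPartition(emp_reviews$Left, p = 0.8, list = FALSE)
emp_reviews_train <- emp_reviews[train_idx, ]
emp_reviews_test  <- emp_reviews[-train_idx, ]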
The first text model I’m looking at has been built similarly to the example model in the help for lime::interactive_text_explanations().
First, we need to prepare the data for modeling: we need to convert the text to a document-term matrix (dtm). There are different ways to do this. One is with the text2vec package.
“Because of R’s copy-on-modify semantics, it is not easy to iteratively grow a DTM. Thus constructing a DTM, even for a small collections of documents, can be a serious bottleneck for analysts and researchers. It involves reading the whole collection of text documents into RAM and processing it as single vector, which can easily increase memory use by a factor of 2 to 4. The text2vec package solves this problem by providing a better way of constructing a document-term matrix.” https://cran.r-project.org/web/packages/text2vec/vignettes/text-vectorization.html
Alternatives to text2vec would be tm + SnowballC or you could work with the tidytext package.
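Just to sketch the tm route (this post sticks with text2vec; the control options below are illustrative assumptions, with stemming handled by SnowballC under the hood):

corpus <- tm::VCorpus(tm::VectorSource(emp_reviews$Feedback))
dtm_tm <- tm::DocumentTermMatrix(corpus,
                                 control = list(removePunctuation = TRUE,
                                                stopwords = TRUE,
                                                stemming = TRUE))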
Back to text2vec: the itoken() function creates an iterator over the tokens (later on, we will use stemmed words), from which we can create the dtm with the create_dtm() function.
All preprocessing steps, starting from the raw text, need to be wrapped in a function that can then be passed to the lime::lime() function; this is only necessary if you want to use your model with lime.
Now, this preprocessing function can be applied to both training and test data.
# wrap all preprocessing (tokenizing + dtm creation) in one function for lime
get_matrix <- function(text) {
  it <- itoken(text, progressbar = FALSE)
  create_dtm(it, vectorizer = hash_vectorizer())
}
dtm_train <- get_matrix(emp_reviews_train$Feedback)
str(dtm_train)
## Formal class 'dgCMatrix' [package "Matrix"] with 6 slots
## ..@ i : int [1:13080] 124 97 27 45 28 48 178 214 227 274 ...
## ..@ p : int [1:262145] 0 0 0 0 0 0 0 0 0 0 ...
## ..@ Dim : int [1:2] 400 262144
## ..@ Dimnames:List of 2
## .. ..$ : chr [1:400] "1" "2" "3" "4" ...
## .. ..$ : NULL
## ..@ x : num [1:13080] 1 1 1 1 1 1 1 1 1 1 ...
## ..@ factors : list()
dtm_test <- get_matrix(emp_reviews_test$Feedback)
str(dtm_test)
## Formal class 'dgCMatrix' [package "Matrix"] with 6 slots
## ..@ i : int [1:3322] 18 82 86 43 55 59 92 6 36 19 ...
## ..@ p : int [1:262145] 0 0 0 0 0 0 0 0 0 0 ...
## ..@ Dim : int [1:2] 100 262144
## ..@ Dimnames:List of 2
## .. ..$ : chr [1:100] "1" "2" "3" "4" ...
## .. ..$ : NULL
## ..@ x : num [1:3322] 1 1 1 1 1 1 1 1 1 1 ...
## ..@ factors : list()
And we use it to train a model with the xgboost package (just as in the example of the lime package).
xgb_model <- xgb.train(params = list(max_depth = 7,
                                     eta = 0.1,
                                     objective = "binary:logistic",
                                     eval_metric = "error",
                                     nthread = 1),
                       data = xgb.DMatrix(dtm_train,
                                          label = emp_reviews_train$Left == "1"),
                       nrounds = 50)
Let’s try it on the test data and see how it performs:
pred <- predict(xgb_model, dtm_test)
confusionMatrix(emp_reviews_test$Left,
                as.factor(round(pred, digits = 0)))
## Confusion Matrix and Statistics
##
## Reference
## Prediction 0 1
## 0 54 6
## 1 19 21
##
## Accuracy : 0.75
## 95% CI : (0.6534, 0.8312)
## No Information Rate : 0.73
## P-Value [Acc > NIR] : 0.3737
##
## Kappa : 0.4493
##
## Mcnemar's Test P-Value : 0.0164
##
## Sensitivity : 0.7397
## Specificity : 0.7778
## Pos Pred Value : 0.9000
## Neg Pred Value : 0.5250
## Prevalence : 0.7300
## Detection Rate : 0.5400
## Detection Prevalence : 0.6000
## Balanced Accuracy : 0.7588
##
## 'Positive' Class : 0
##
Okay, not a perfect score, but good enough for me - right now, I’m more interested in the explanations of the model’s predictions. For this, we need to run the lime() function and give it the training texts, the trained model, and the preprocessing function:
explainer <- lime(emp_reviews_train$Feedback,
                  xgb_model,
                  preprocess = get_matrix)
With this, we can right away call the interactive explainer Shiny app, where we can type any text we want into the field on the left and see the explanation on the right: words that are highlighted in green support the classification, red words contradict it.
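Launching the app is a one-liner:

# opens the Shiny app for interactively exploring text explanations
interactive_text_explanations(explainer)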
What happens in the background in the app, we can also do explicitly by calling the explain() function and giving it the texts we want explained, the explainer object, and the number of labels and features to return.
We can plot the explanations either with the plot_text_explanations() function, which gives output like in the Shiny app, or with the regular plot_features() function.
explanations <- lime::explain(emp_reviews_test$Feedback[15:18], explainer,
                              n_labels = 1, n_features = 5)
plot_text_explanations(explanations)
plot_features(explanations)
As we can see, our explanations contain a lot of stop-words that don’t really make much sense as features in our model. So…
## Let’s try a more complex example
Okay, our model above works, but there are still common words and stop words in it that LIME picks up on. Ideally, we would want to remove them before modeling and keep only relevant words. We can accomplish this with additional steps and options in our preprocessing function.
Important to know is that whatever preprocessing we do with our text corpus, train and test data have to have the same features (i.e. words)! If we were to incorporate all the steps shown below into one function and call it separately on train and test data, we would end up with different words in our dtms and the predict() function wouldn’t work any more. In the simple example above, it works because we have been using the hash_vectorizer().
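We can verify this for the simple model above: hash_vectorizer() hashes every token into the same fixed 2^18-dimensional feature space, so the train and test dtms automatically share their columns:

# both dtms have 2^18 = 262144 columns (see the str() output above)
ncol(dtm_train) == ncol(dtm_test)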
At the same time, the lime::explain() function expects a preprocessing function that takes a character vector as input, so all steps still need to be wrapped into a single function.
How do we go about this? First, we will need to create the vocabulary from the training data alone. To reduce the number of words to only the most relevant, I am performing the following steps:

- stem all words
- remove stop words
- prune the vocabulary
- transform into a vector space
# tokenizer that stems each word with SnowballC
stem_tokenizer <- function(x) {
  lapply(word_tokenizer(x),
         SnowballC::wordStem,
         language = "en")
}

stop_words <- tm::stopwords(kind = "en")
# create pruned vocabulary from the training data only
vocab_train <- itoken(emp_reviews_train$Feedback,
                      preprocess_function = tolower,
                      tokenizer = stem_tokenizer,
                      progressbar = FALSE)

v <- create_vocabulary(vocab_train,
                       stopwords = stop_words)

pruned_vocab <- prune_vocabulary(v,
                                 doc_proportion_max = 0.99,
                                 doc_proportion_min = 0.01)

vectorizer_train <- vocab_vectorizer(pruned_vocab)
This vectorizer can now be plugged into the preprocessing function, which we can then apply to both train and test data. Here, I am also transforming the word counts to tf-idf values.
# preprocessing function: tokenize and stem, build the dtm with the
# vectorizer fitted on the training data, then weight the counts by tf-idf
create_dtm_mat <- function(text, vectorizer = vectorizer_train) {
  vocab <- itoken(text,
                  preprocess_function = tolower,
                  tokenizer = stem_tokenizer,
                  progressbar = FALSE)
  dtm <- create_dtm(vocab,
                    vectorizer = vectorizer)
  tfidf <- TfIdf$new()
  fit_transform(dtm, tfidf)
}
dtm_train2 <- create_dtm_mat(emp_reviews_train$Feedback)
str(dtm_train2)
## Formal class 'dgCMatrix' [package "Matrix"] with 6 slots
## ..@ i : int [1:6320] 29 80 334 335 6 7 83 182 96 288 ...
## ..@ p : int [1:516] 0 4 8 12 16 20 24 28 32 36 ...
## ..@ Dim : int [1:2] 400 515
## ..@ Dimnames:List of 2
## .. ..$ : chr [1:400] "1" "2" "3" "4" ...
## .. ..$ : chr [1:515] "abov" "Everi" "achiev" "option" ...
## ..@ x : num [1:6320] 0.209 0.258 0.157 0.292 0.122 ...
## ..@ factors : list()
dtm_test2 <- create_dtm_mat(emp_reviews_test$Feedback)
str(dtm_test2)
## Formal class 'dgCMatrix' [package "Matrix"] with 6 slots
## ..@ i : int [1:1484] 21 36 28 66 91 73 74 25 32 91 ...
## ..@ p : int [1:516] 0 1 1 1 2 3 3 5 5 5 ...
## ..@ Dim : int [1:2] 100 515
## ..@ Dimnames:List of 2
## .. ..$ : chr [1:100] "1" "2" "3" "4" ...
## .. ..$ : chr [1:515] "abov" "Everi" "achiev" "option" ...
## ..@ x : num [1:1484] 0.326 0.356 0.489 0.152 0.351 ...
## ..@ factors : list()
And we will train another gradient boosting model:
xgb_model2 <- xgb.train(params = list(max_depth = 10,
                                      eta = 0.2,
                                      objective = "binary:logistic",
                                      eval_metric = "error",
                                      nthread = 1),
                        data = xgb.DMatrix(dtm_train2,
                                           label = emp_reviews_train$Left == "1"),
                        nrounds = 500)
pred2 <- predict(xgb_model2, dtm_test2)
confusionMatrix(emp_reviews_test$Left,
                as.factor(round(pred2, digits = 0)))
## Confusion Matrix and Statistics
##
## Reference
## Prediction 0 1
## 0 47 13
## 1 14 26
##
## Accuracy : 0.73
## 95% CI : (0.632, 0.8139)
## No Information Rate : 0.61
## P-Value [Acc > NIR] : 0.008104
##
## Kappa : 0.4351
##
## Mcnemar's Test P-Value : 1.000000
##
## Sensitivity : 0.7705
## Specificity : 0.6667
## Pos Pred Value : 0.7833
## Neg Pred Value : 0.6500
## Prevalence : 0.6100
## Detection Rate : 0.4700
## Detection Prevalence : 0.6000
## Balanced Accuracy : 0.7186
##
## 'Positive' Class : 0
##
Unfortunately, this didn’t really improve the classification accuracy, but let’s look at the explanations again:
explainer2 <- lime(emp_reviews_train$Feedback,
                   xgb_model2,
                   preprocess = create_dtm_mat)

explanations2 <- lime::explain(emp_reviews_test$Feedback[15:18], explainer2,
                               n_labels = 1, n_features = 4)
plot_text_explanations(explanations2)
plot_features(explanations2)
The words that get picked up now make much more sense! So, even though making my model more complex didn’t improve “the numbers”, this second model is likely to generalize much better to new reviews, because it picks up on words that make intuitive sense.
That’s why I’m sold on the benefits of adding explainer functions to most machine learning workflows - and why I love the lime package in R!
sessionInfo()
## R version 3.5.2 (2018-12-20)
## Platform: x86_64-w64-mingw32/x64 (64-bit)
## Running under: Windows 10 x64 (build 17763)
##
## Matrix products: default
##
## locale:
## [1] LC_COLLATE=English_Indonesia.1252 LC_CTYPE=English_Indonesia.1252
## [3] LC_MONETARY=English_Indonesia.1252 LC_NUMERIC=C
## [5] LC_TIME=English_Indonesia.1252
##
## attached base packages:
## [1] stats graphics grDevices utils datasets methods base
##
## other attached packages:
## [1] tm_0.7-6 NLP_0.2-0 qdap_2.3.2
## [4] RColorBrewer_1.1-2 qdapTools_1.3.3 qdapRegex_0.7.2
## [7] qdapDictionaries_1.0.7 lime_0.5.0 xgboost_0.90.0.2
## [10] caret_6.0-84 lattice_0.20-38 text2vec_0.5.1
## [13] ggthemes_4.2.0 forcats_0.4.0 stringr_1.4.0
## [16] dplyr_0.8.3 purrr_0.3.2 readr_1.3.1
## [19] tidyr_0.8.3 tibble_2.1.3 ggplot2_3.2.1
## [22] tidyverse_1.2.1
##
## loaded via a namespace (and not attached):
## [1] openNLPdata_1.5.3-4 colorspace_1.4-1 class_7.3-14
## [4] futile.logger_1.4.3 rstudioapi_0.10 SnowballC_0.6.0
## [7] prodlim_2018.04.18 fansi_0.4.0 lubridate_1.7.4
## [10] xml2_1.2.2 codetools_0.2-15 splines_3.5.2
## [13] knitr_1.24 shinythemes_1.1.2 zeallot_0.1.0
## [16] mlapi_0.1.0 jsonlite_1.6 venneuler_1.1-0
## [19] rJava_0.9-10 broom_0.5.2 shiny_1.3.2
## [22] compiler_3.5.2 httr_1.4.1 backports_1.1.4
## [25] assertthat_0.2.1 Matrix_1.2-14 lazyeval_0.2.2
## [28] cli_1.1.0 later_0.8.0 formatR_1.7
## [31] htmltools_0.3.6 tools_3.5.2 igraph_1.2.4.1
## [34] gtable_0.3.0 glue_1.3.1 reshape2_1.4.3
## [37] Rcpp_1.0.2 slam_0.1-45 cellranger_1.1.0
## [40] vctrs_0.2.0 gdata_2.18.0 nlme_3.1-137
## [43] iterators_1.0.12 timeDate_3043.102 gender_0.5.2
## [46] gower_0.2.1 xfun_0.8 xlsxjars_0.6.1
## [49] rvest_0.3.4 mime_0.7 gtools_3.8.1
## [52] xlsx_0.6.1 XML_3.98-1.20 MASS_7.3-51.4
## [55] scales_1.0.0 ipred_0.9-9 hms_0.5.0
## [58] promises_1.0.1 parallel_3.5.2 lambda.r_1.2.3
## [61] yaml_2.2.0 gridExtra_2.3 rpart_4.1-13
## [64] stringi_1.4.3 plotrix_3.7-6 foreach_1.4.7
## [67] e1071_1.7-2 openNLP_0.2-6 lava_1.6.6
## [70] chron_2.3-53 rlang_0.4.0 pkgconfig_2.0.3
## [73] bitops_1.0-6 evaluate_0.14 recipes_0.1.6
## [76] htmlwidgets_1.3 labeling_0.3 tidyselect_0.2.5
## [79] plyr_1.8.4 magrittr_1.5 R6_2.4.0
## [82] generics_0.0.2 pillar_1.4.2 haven_2.1.1
## [85] withr_2.1.2 survival_2.43-3 RCurl_1.95-4.12
## [88] nnet_7.3-12 modelr_0.1.5 crayon_1.3.4
## [91] futile.options_1.0.1 wordcloud_2.6 utf8_1.1.4
## [94] rmarkdown_1.14 grid_3.5.2 readxl_1.3.1
## [97] data.table_1.12.2 reports_0.1.4 ModelMetrics_1.2.2
## [100] digest_0.6.21 xtable_1.8-4 httpuv_1.5.1
## [103] RcppParallel_4.4.3 stats4_3.5.2 munsell_0.5.0
## [106] glmnet_2.0-18