Почему это важно?

Почему так важно уделять внимание не только метрикам качества, но и интерпретации полученных результатов? С одной стороны, мы хотим получать максимально точные предсказания, но с другой — было бы неплохо понимать, почему получен именно такой результат и никакой другой.

Один из примеров — эксперимент, проведенный учеными из Вашингтонского университета. Классификатор, различающей фотографии хаски и волков, достигал 90% доли правильных ответов. Как алгоритму удалось получить такой хороший результат? Всё оказалось довольно просто: модель принимала решения на основе фона картинки, а не характеристик животных. На заднем плане фотографий с волками в большинстве случаев присутствовал снег, тогда как у хаски — нет. Так, вместо одного классификатора, получился алгоритм, детектирующий снег. Это было бы сложно заметить без интерпретации модели (линк на статью).

Поскольку данная модель была обучена для эксперимента, ошибочные предсказания не привели бы к появлению серьезных проблем. Но что, если мы строим модель для её дальнейшего использования в реальной жизни? В некоторых сферах цена ошибки может быть очень велика. Иными словами, необходимость интерпретации появляется тогда, когда присутствуют определенные риски, например, финансовые или социальные, что может быть весьма актуально в нашем случае.

Далее мы рассмотрим один из алгоритмов для локальной интерпретации результатов — lime.

данные

library(tidyverse)
library(GGally)
library(caret)

В качестве данных используем небольшой датасет о кликах на рекламные объявления.

ads = read_csv("advertising.csv")
colnames(ads) = c('Daily_Time_Spent_on_Site', 'Age', 'Area_Income', 'Daily_Internet_Usage', 'Ad_Topic_Line', 'City', 'Male', 'Country', 'Timestamp', 'Clicked_on_Ad')
ads$City = as.factor(ads$City)
ads$Country = as.factor(ads$Country)
ads$Clicked_on_Ad = ifelse(ads$Clicked_on_Ad == 1, "Yes", "No")
ads$Clicked_on_Ad = as.factor(ads$Clicked_on_Ad)
kableExtra::kable(head(ads), format = "markdown")
Daily_Time_Spent_on_Site Age Area_Income Daily_Internet_Usage Ad_Topic_Line City Male Country Timestamp Clicked_on_Ad
68.95 35 61833.90 256.09 Cloned 5thgeneration orchestration Wrightburgh 0 Tunisia 2016-03-27 00:53:11 No
80.23 31 68441.85 193.77 Monitored national standardization West Jodi 1 Nauru 2016-04-04 01:39:02 No
69.47 26 59785.94 236.50 Organic bottom-line service-desk Davidton 0 San Marino 2016-03-13 20:35:42 No
74.15 29 54806.18 245.89 Triple-buffered reciprocal time-frame West Terrifurt 1 Italy 2016-01-10 02:31:19 No
68.37 35 73889.99 225.58 Robust logistical utilization South Manuel 0 Iceland 2016-06-03 03:36:18 No
59.99 23 59761.56 226.74 Sharable client-driven software Jamieberg 1 Norway 2016-05-19 14:30:17 No

Можно посмотреть на взаимосвязь части предикторов с помощью ggpairs. Фиолетовые точки – те, кто кликнул на рекламу, оранжевые – нет (обычно в жизни всё сложнее).

ggpairs(ads, progress = FALSE, columns = c('Daily_Time_Spent_on_Site', 'Age', 'Area_Income', 'Daily_Internet_Usage'), legend = c(1,1),
        aes(color=Clicked_on_Ad, alpha=0.5)) + 
  theme_modern(base_size = 10) + theme(legend.position = "bottom", axis.text = element_text(size=8), legend.text = element_text(size = 8), legend.title= element_text(size = 8)) 

строим модель

Теперь разделяем данные на обучающую (80%) и тестовую (20%) выборки…

# здесь в качестве предикторов используем только часть переменных, чтобы модель могла сделать больше ошибок для дальнейшей интерпретации
X = ads %>% dplyr::select(Daily_Time_Spent_on_Site, Age, Area_Income, Male, Clicked_on_Ad)
# делим
set.seed(17)
split = createDataPartition(X$Clicked_on_Ad, p=0.8, list = FALSE)
X_train = X[split, ]
X_valid = X[-split, ]

… и обучаем логистическую регрессию.

set.seed(17)
lr = train(Clicked_on_Ad ~ .,  # указываем таргет и предикторы (. - все)
           data = X_train,  # данные
           method = "glm", family = "binomial")  # метод 

Делаем предсказание для тестовой части и строим confusion matrix. Получилось весьма хорошо. В следующей секции посмотрим, почему же модель могла ошибиться.

# предсказываем
pred_lr_valid = predict(lr, newdata = X_valid, type="prob")
# если вероятность 1 класса больше порога 0.5, присваиваем нужный класс
pred_lr_valid_05 = as.factor(ifelse(pred_lr_valid$Yes > 0.5, "Yes", "No"))

# матрица: передаем предсказания и референс (правильные ответы)
confusionMatrix(data = pred_lr_valid_05, reference = X_valid$Clicked_on_Ad, positive = "Yes", mode="prec_recall")
## Confusion Matrix and Statistics
## 
##           Reference
## Prediction No Yes
##        No  99  13
##        Yes  1  87
##                                           
##                Accuracy : 0.93            
##                  95% CI : (0.8853, 0.9612)
##     No Information Rate : 0.5             
##     P-Value [Acc > NIR] : < 2.2e-16       
##                                           
##                   Kappa : 0.86            
##                                           
##  Mcnemar's Test P-Value : 0.003283        
##                                           
##               Precision : 0.9886          
##                  Recall : 0.8700          
##                      F1 : 0.9255          
##              Prevalence : 0.5000          
##          Detection Rate : 0.4350          
##    Detection Prevalence : 0.4400          
##       Balanced Accuracy : 0.9300          
##                                           
##        'Positive' Class : Yes             
## 

еще пара шагов, почти у цели

Создаем датасет с предсказаниями, и помечаем их либо как правильные, либо нет.

Note: вообще это не обязательно, со звездочкой, но позволяет рассмотреть отдельно правильные и неправильные предсказания

pred_df = data.frame(
  sample_id = 1:nrow(X_valid), # создаем колонку с номером наблюдения
  predict(lr, X_valid, type = "prob"), # добавляем предсказания (две колонки - вероятность Yes & No)
  actual = X_valid$Clicked_on_Ad)  # и правильные ответы
# с помощью пары манипуляций получаем итоговую принадлежность к классу по мнению модели
pred_df$prediction = colnames(pred_df)[2:3][apply(pred_df[, 2:3], 1, which.max)]
# и помечаем как правильное, если предикт совпадает с правильным ответом
pred_df$correct = ifelse(pred_df$actual == pred_df$prediction, "correct", "wrong")
head(pred_df)
##   sample_id           No        Yes actual prediction correct
## 1         1 0.0040881260 0.99591187    Yes        Yes correct
## 2         2 0.0003524952 0.99964750    Yes        Yes correct
## 3         3 0.0002552578 0.99974474    Yes        Yes correct
## 4         4 0.6811129016 0.31888710    Yes         No   wrong
## 5         5 0.9927374500 0.00726255     No         No correct
## 6         6 0.0043769305 0.99562307    Yes        Yes correct

Разделяем данные на две части в зависимости от правильности результата:

pred_cor = filter(pred_df, correct == "correct")
pred_wrong = filter(pred_df, correct == "wrong")

Случайным образом выбираем 3 наблюдения из каждой группы

set.seed(17)
test_data_cor = X_valid %>% mutate(sample_id = 1:nrow(X_valid)) %>%
  dplyr::filter(sample_id %in% pred_cor$sample_id) %>%
  sample_n(size = 3) %>%  
  remove_rownames() %>%
  tibble::column_to_rownames(var = "sample_id") %>% dplyr::select(-Clicked_on_Ad)

test_data_wrong = X_valid %>% mutate(sample_id = 1:nrow(X_valid)) %>%
  dplyr::filter(sample_id %in% pred_wrong$sample_id) %>%
  sample_n(size = 3) %>%  
  remove_rownames() %>%
  tibble::column_to_rownames(var = "sample_id") %>% dplyr::select(-Clicked_on_Ad)

lime

Загружаем пакет:

library(lime)

Создаем “объясняющий” объект explainer_caret с помощью функции lime(), которая принимает на вход

explainer_caret = lime(x=X_train, model=lr, bin_continuous = TRUE)

“Объясняем” предсказания с помощью функции explain(). В качестве основных параметров передаем

Также можно указать feature_select – способ отбора указанного числа признаков и ряд других аргументов, подробнее тут и тут.

explanation_cor = explain(test_data_cor, explainer_caret, n_labels = 1, n_features=4,
                          feature_select = "lasso_path",  n_permutations = 5000)
explanation_wrong = explain(test_data_wrong, explainer_caret, n_labels = 1, n_features=4,
                          feature_select = "lasso_path",  n_permutations = 5000)

Сначала посмотрим на правильное предсказание. Для 190 наблюдения значение предиктора Daily_Time_Spent_on_Site (ежедневное время пребывания на сайте) больше 78.8, и связано с принадлежностью к классу тех, кто не кликнет на рекламу. Остальные переменные наоборот “поддержали” отнесение данного случая к положительному (клик) классу.

plot_features(explanation_cor, ncol = 2) + labs(title = "LIME Feature Importance, correctly predicted")

В качестве примера, когда модель сделала неверное предсказание, возьмем кейс 129. Видим, что значение переменной Daily_Time_Spent_on_Site было больше 78.8, и пользователь был мужчиной (1). Так, несмотря на поддержку лейбла “кликнет” переменными Area_Income и Age, модель предсказала неверный класс – “не кликнет”.

plot_features(explanation_wrong, ncol = 2) + labs(title = "LIME Feature Importance", subtitle = "Incorrectly predicted cases")

картинки

library(magick)
explanation = .load_image_example()  # загрузим пример чтобы не обучать тут классификаторы картинок

lime также можно использовать при классификации текстов и картинок!

В первом ряду цветом выделены те части изображения, которые “поддерживают” решение модели, а во втором – наоборот, противоречат. Например, в первом примере модель определила, что на картинке есть клубника, и это действительно так (подсвечено синим). Но против данного решения выступили части помидора и яблочка.
Во втором случае модель нашла на картинке свечку, вероятно из-за похожести части яблока на её фитиль. Против такого решения снова выступила часть подмидора с:

plot_image_explanation(explanation, show_negative = TRUE)