limeПочему так важно уделять внимание не только метрикам качества, но и интерпретации полученных результатов? С одной стороны, мы хотим получать максимально точные предсказания, но с другой — было бы неплохо понимать, почему получен именно такой результат и никакой другой.
Один из примеров — эксперимент, проведенный учеными из Вашингтонского университета. Классификатор, различающей фотографии хаски и волков, достигал 90% доли правильных ответов. Как алгоритму удалось получить такой хороший результат? Всё оказалось довольно просто: модель принимала решения на основе фона картинки, а не характеристик животных. На заднем плане фотографий с волками в большинстве случаев присутствовал снег, тогда как у хаски — нет. Так, вместо одного классификатора, получился алгоритм, детектирующий снег. Это было бы сложно заметить без интерпретации модели (линк на статью).
Поскольку данная модель была обучена для эксперимента, ошибочные предсказания не привели бы к появлению серьезных проблем. Но что, если мы строим модель для её дальнейшего использования в реальной жизни? В некоторых сферах цена ошибки может быть очень велика. Иными словами, необходимость интерпретации появляется тогда, когда присутствуют определенные риски, например, финансовые или социальные, что может быть весьма актуально в нашем случае.
Далее мы рассмотрим один из алгоритмов для локальной интерпретации результатов — lime.
library(tidyverse)
library(GGally)
library(caret)
В качестве данных используем небольшой датасет о кликах на рекламные объявления.
ads = read_csv("advertising.csv")
colnames(ads) = c('Daily_Time_Spent_on_Site', 'Age', 'Area_Income', 'Daily_Internet_Usage', 'Ad_Topic_Line', 'City', 'Male', 'Country', 'Timestamp', 'Clicked_on_Ad')
ads$City = as.factor(ads$City)
ads$Country = as.factor(ads$Country)
ads$Clicked_on_Ad = ifelse(ads$Clicked_on_Ad == 1, "Yes", "No")
ads$Clicked_on_Ad = as.factor(ads$Clicked_on_Ad)
kableExtra::kable(head(ads), format = "markdown")
| Daily_Time_Spent_on_Site | Age | Area_Income | Daily_Internet_Usage | Ad_Topic_Line | City | Male | Country | Timestamp | Clicked_on_Ad |
|---|---|---|---|---|---|---|---|---|---|
| 68.95 | 35 | 61833.90 | 256.09 | Cloned 5thgeneration orchestration | Wrightburgh | 0 | Tunisia | 2016-03-27 00:53:11 | No |
| 80.23 | 31 | 68441.85 | 193.77 | Monitored national standardization | West Jodi | 1 | Nauru | 2016-04-04 01:39:02 | No |
| 69.47 | 26 | 59785.94 | 236.50 | Organic bottom-line service-desk | Davidton | 0 | San Marino | 2016-03-13 20:35:42 | No |
| 74.15 | 29 | 54806.18 | 245.89 | Triple-buffered reciprocal time-frame | West Terrifurt | 1 | Italy | 2016-01-10 02:31:19 | No |
| 68.37 | 35 | 73889.99 | 225.58 | Robust logistical utilization | South Manuel | 0 | Iceland | 2016-06-03 03:36:18 | No |
| 59.99 | 23 | 59761.56 | 226.74 | Sharable client-driven software | Jamieberg | 1 | Norway | 2016-05-19 14:30:17 | No |
Можно посмотреть на взаимосвязь части предикторов с помощью ggpairs. Фиолетовые точки – те, кто кликнул на рекламу, оранжевые – нет (обычно в жизни всё сложнее).
ggpairs(ads, progress = FALSE, columns = c('Daily_Time_Spent_on_Site', 'Age', 'Area_Income', 'Daily_Internet_Usage'), legend = c(1,1),
aes(color=Clicked_on_Ad, alpha=0.5)) +
theme_modern(base_size = 10) + theme(legend.position = "bottom", axis.text = element_text(size=8), legend.text = element_text(size = 8), legend.title= element_text(size = 8))
Теперь разделяем данные на обучающую (80%) и тестовую (20%) выборки…
# здесь в качестве предикторов используем только часть переменных, чтобы модель могла сделать больше ошибок для дальнейшей интерпретации
X = ads %>% dplyr::select(Daily_Time_Spent_on_Site, Age, Area_Income, Male, Clicked_on_Ad)
# делим
set.seed(17)
split = createDataPartition(X$Clicked_on_Ad, p=0.8, list = FALSE)
X_train = X[split, ]
X_valid = X[-split, ]
… и обучаем логистическую регрессию.
set.seed(17)
lr = train(Clicked_on_Ad ~ ., # указываем таргет и предикторы (. - все)
data = X_train, # данные
method = "glm", family = "binomial") # метод
Делаем предсказание для тестовой части и строим confusion matrix. Получилось весьма хорошо. В следующей секции посмотрим, почему же модель могла ошибиться.
# предсказываем
pred_lr_valid = predict(lr, newdata = X_valid, type="prob")
# если вероятность 1 класса больше порога 0.5, присваиваем нужный класс
pred_lr_valid_05 = as.factor(ifelse(pred_lr_valid$Yes > 0.5, "Yes", "No"))
# матрица: передаем предсказания и референс (правильные ответы)
confusionMatrix(data = pred_lr_valid_05, reference = X_valid$Clicked_on_Ad, positive = "Yes", mode="prec_recall")
## Confusion Matrix and Statistics
##
## Reference
## Prediction No Yes
## No 99 13
## Yes 1 87
##
## Accuracy : 0.93
## 95% CI : (0.8853, 0.9612)
## No Information Rate : 0.5
## P-Value [Acc > NIR] : < 2.2e-16
##
## Kappa : 0.86
##
## Mcnemar's Test P-Value : 0.003283
##
## Precision : 0.9886
## Recall : 0.8700
## F1 : 0.9255
## Prevalence : 0.5000
## Detection Rate : 0.4350
## Detection Prevalence : 0.4400
## Balanced Accuracy : 0.9300
##
## 'Positive' Class : Yes
##
Создаем датасет с предсказаниями, и помечаем их либо как правильные, либо нет.
Note: вообще это не обязательно, со звездочкой, но позволяет рассмотреть отдельно правильные и неправильные предсказания
pred_df = data.frame(
sample_id = 1:nrow(X_valid), # создаем колонку с номером наблюдения
predict(lr, X_valid, type = "prob"), # добавляем предсказания (две колонки - вероятность Yes & No)
actual = X_valid$Clicked_on_Ad) # и правильные ответы
# с помощью пары манипуляций получаем итоговую принадлежность к классу по мнению модели
pred_df$prediction = colnames(pred_df)[2:3][apply(pred_df[, 2:3], 1, which.max)]
# и помечаем как правильное, если предикт совпадает с правильным ответом
pred_df$correct = ifelse(pred_df$actual == pred_df$prediction, "correct", "wrong")
head(pred_df)
## sample_id No Yes actual prediction correct
## 1 1 0.0040881260 0.99591187 Yes Yes correct
## 2 2 0.0003524952 0.99964750 Yes Yes correct
## 3 3 0.0002552578 0.99974474 Yes Yes correct
## 4 4 0.6811129016 0.31888710 Yes No wrong
## 5 5 0.9927374500 0.00726255 No No correct
## 6 6 0.0043769305 0.99562307 Yes Yes correct
Разделяем данные на две части в зависимости от правильности результата:
pred_cor = filter(pred_df, correct == "correct")
pred_wrong = filter(pred_df, correct == "wrong")
Случайным образом выбираем 3 наблюдения из каждой группы
set.seed(17)
test_data_cor = X_valid %>% mutate(sample_id = 1:nrow(X_valid)) %>%
dplyr::filter(sample_id %in% pred_cor$sample_id) %>%
sample_n(size = 3) %>%
remove_rownames() %>%
tibble::column_to_rownames(var = "sample_id") %>% dplyr::select(-Clicked_on_Ad)
test_data_wrong = X_valid %>% mutate(sample_id = 1:nrow(X_valid)) %>%
dplyr::filter(sample_id %in% pred_wrong$sample_id) %>%
sample_n(size = 3) %>%
remove_rownames() %>%
tibble::column_to_rownames(var = "sample_id") %>% dplyr::select(-Clicked_on_Ad)
Загружаем пакет:
library(lime)
Создаем “объясняющий” объект explainer_caret с помощью функции lime(), которая принимает на вход
x: данныеmodel: модельexplainer_caret = lime(x=X_train, model=lr, bin_continuous = TRUE)
“Объясняем” предсказания с помощью функции explain(). В качестве основных параметров передаем
explainer_caret: объект, объясняющий модельn_features: число признаков, используемых для объясненияn_permutations: число генерируемых искусственных наблюденийТакже можно указать feature_select – способ отбора указанного числа признаков и ряд других аргументов, подробнее тут и тут.
explanation_cor = explain(test_data_cor, explainer_caret, n_labels = 1, n_features=4,
feature_select = "lasso_path", n_permutations = 5000)
explanation_wrong = explain(test_data_wrong, explainer_caret, n_labels = 1, n_features=4,
feature_select = "lasso_path", n_permutations = 5000)
Сначала посмотрим на правильное предсказание. Для 190 наблюдения значение предиктора Daily_Time_Spent_on_Site (ежедневное время пребывания на сайте) больше 78.8, и связано с принадлежностью к классу тех, кто не кликнет на рекламу. Остальные переменные наоборот “поддержали” отнесение данного случая к положительному (клик) классу.
plot_features(explanation_cor, ncol = 2) + labs(title = "LIME Feature Importance, correctly predicted")
В качестве примера, когда модель сделала неверное предсказание, возьмем кейс 129. Видим, что значение переменной Daily_Time_Spent_on_Site было больше 78.8, и пользователь был мужчиной (1). Так, несмотря на поддержку лейбла “кликнет” переменными Area_Income и Age, модель предсказала неверный класс – “не кликнет”.
plot_features(explanation_wrong, ncol = 2) + labs(title = "LIME Feature Importance", subtitle = "Incorrectly predicted cases")
library(magick)
explanation = .load_image_example() # загрузим пример чтобы не обучать тут классификаторы картинок
lime также можно использовать при классификации текстов и картинок!
В первом ряду цветом выделены те части изображения, которые “поддерживают” решение модели, а во втором – наоборот, противоречат. Например, в первом примере модель определила, что на картинке есть клубника, и это действительно так (подсвечено синим). Но против данного решения выступили части помидора и яблочка.
Во втором случае модель нашла на картинке свечку, вероятно из-за похожести части яблока на её фитиль. Против такого решения снова выступила часть подмидора с:
plot_image_explanation(explanation, show_negative = TRUE)