data("Boston", package = "MASS")
BostonInterpretable Machine Learning
Introduction
Machine learning has great potential to improve the quality of products, processes, and research. However, systems that adopt machine learning models are usually unable to explain the predictions they produce.
Several simple classical models that can be interpreted (interpretable models) are already very popular among machine learning practitioners, data scientists, statisticians, and researchers, namely linear regression, decision trees, and association rules. As computing technology has advanced, more complex models have been developed that can learn from large, unstructured data in a short time.
Random forests, ensemble models, neural networks and their variants (deep learning) are complex models whose predictions or decisions often come without further explanation or interpretation. These models are known as black box models.
Interpretability
There is no mathematical definition of interpretability. A (non-mathematical) definition according to Miller (2017)1:
1 Miller, Tim. 2017. “Explanation in artificial intelligence: Insights from the social sciences.” arXiv preprint arXiv:1706.07269.
Interpretability is the degree to which a human can understand the cause of a decision.
Another definition is:
… the degree to which a human can consistently predict the model’s result2.
2 Kim, Been, Rajiv Khanna, and Oluwasanmi O. Koyejo. 2016. “Examples are not enough, learn to criticize! Criticism for interpretability.” Advances in Neural Information Processing Systems.
The higher the interpretability of a machine learning model, the easier it is for someone to understand why a particular decision or prediction was made. One model is more interpretable than another if its decisions are easier for a human to comprehend.
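The case studies below rely on a number of R packages. A minimal setup sketch, assuming all of them are installed (attach only what you need for a given case study):

library(rpart)        # CART models
library(rpart.plot)   # plotting rpart trees
library(iml)          # Predictor, Shapley, FeatureImp
library(neuralnet)    # artificial neural networks (Case Study 3)
library(caret)        # data partitioning and random forest tuning (Case Study 4)
library(DALEX)        # explanatory model analysis (Case Study 4)
library(tidyverse)    # read_csv and dplyr verbs (Case Study 4)
library(skimr)        # skim() data summaries (Case Study 4)
library(knitr)        # kable() tables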
Case Study 1: Boston Housing Dataset

data("Boston", package = "MASS")
Boston

| crim | zn | indus | chas | nox | rm | age | dis | rad | tax | ptratio | black | lstat | medv |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0.00632 | 18 | 2.31 | 0 | 0.538 | 6.575 | 65.2 | 4.0900 | 1 | 296 | 15.3 | 396.90 | 4.98 | 24.0 |
| 0.02731 | 0 | 7.07 | 0 | 0.469 | 6.421 | 78.9 | 4.9671 | 2 | 242 | 17.8 | 396.90 | 9.14 | 21.6 |
| 0.02729 | 0 | 7.07 | 0 | 0.469 | 7.185 | 61.1 | 4.9671 | 2 | 242 | 17.8 | 392.83 | 4.03 | 34.7 |
| 0.03237 | 0 | 2.18 | 0 | 0.458 | 6.998 | 45.8 | 6.0622 | 3 | 222 | 18.7 | 394.63 | 2.94 | 33.4 |
| 0.06905 | 0 | 2.18 | 0 | 0.458 | 7.147 | 54.2 | 6.0622 | 3 | 222 | 18.7 | 396.90 | 5.33 | 36.2 |
| 0.02985 | 0 | 2.18 | 0 | 0.458 | 6.430 | 58.7 | 6.0622 | 3 | 222 | 18.7 | 394.12 | 5.21 | 28.7 |
CART
# First we fit a machine learning model on the Boston housing data
rf <- rpart(medv ~ ., data = Boston)
rpart.plot::rpart.plot(rf)
X <- Boston[-which(names(Boston) == "medv")]
mod <- Predictor$new(rf, data = X)

SHAP
# Then we explain the first instance of the dataset with the Shapley method:
x.interest <- X[1, ]
shapley <- Shapley$new(mod, x.interest = x.interest)
shapley
## Interpretation method: Shapley
## Predicted value: 27.427273, Average prediction: 22.532806 (diff = 4.894466)
##
## Analysed predictor:
## Prediction task: unknown
##
##
## Analysed data:
## Sampling from data.frame with 506 rows and 13 columns.
##
##
## Head of results:
## feature phi phi.var feature.value
## 1 crim 0.6707019 3.040883 crim=0.00632
## 2 zn 0.0000000 0.000000 zn=18
## 3 indus 0.0000000 0.000000 indus=2.31
## 4 chas 0.0000000 0.000000 chas=0
## 5 nox 0.0000000 0.000000 nox=0.538
## 6      rm 0.5648295 46.013896        rm=6.575

# Look at the results in a table
kable(shapley$results)

| feature | phi | phi.var | feature.value |
|---|---|---|---|
| crim | 0.6707019 | 3.040884 | crim=0.00632 |
| zn | 0.0000000 | 0.000000 | zn=18 |
| indus | 0.0000000 | 0.000000 | indus=2.31 |
| chas | 0.0000000 | 0.000000 | chas=0 |
| nox | 0.0000000 | 0.000000 | nox=0.538 |
| rm | 0.5648295 | 46.013896 | rm=6.575 |
| age | 0.0000000 | 0.000000 | age=65.2 |
| dis | -0.4806171 | 5.852113 | dis=4.09 |
| rad | 0.0000000 | 0.000000 | rad=1 |
| tax | 0.0000000 | 0.000000 | tax=296 |
| ptratio | 0.0000000 | 0.000000 | ptratio=15.3 |
| black | 0.0000000 | 0.000000 | black=396.9 |
| lstat | 3.2864649 | 25.089114 | lstat=4.98 |
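As a sanity check on the Shapley decomposition, the phi contributions should add up (approximately, since they are estimated by sampling) to the difference between the prediction for this instance and the average prediction reported above (diff = 4.894466). A minimal check using the shapley object created earlier:

# The Shapley contributions approximately sum to (prediction - average prediction)
sum(shapley$results$phi)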
# Or as a plot
plot(shapley)

# Explain another instance
shapley$explain(X[2, ])
plot(shapley)

Case Study 2: Iris Dataset
The following illustration applies the same approach to a multiclass problem.
CART
rf_iris <- rpart(Species ~ ., data = iris)
rpart.plot::rpart.plot(rf_iris)
X <- iris[-which(names(iris) == "Species")]
mod <- Predictor$new(rf_iris, data = X, type = "prob")

SHAP
# Then we explain the first instance of the dataset with the Shapley() method:
shapley <- Shapley$new(mod, x.interest = X[1, ])
kable(shapley$results)

| feature | class | phi | phi.var | feature.value |
|---|---|---|---|---|
| Sepal.Length | setosa | 0.0000000 | 0.0000000 | Sepal.Length=5.1 |
| Sepal.Width | setosa | 0.0000000 | 0.0000000 | Sepal.Width=3.5 |
| Petal.Length | setosa | 0.6400000 | 0.2327273 | Petal.Length=1.4 |
| Petal.Width | setosa | 0.0000000 | 0.0000000 | Petal.Width=0.2 |
| Sepal.Length | versicolor | 0.0000000 | 0.0000000 | Sepal.Length=5.1 |
| Sepal.Width | versicolor | 0.0000000 | 0.0000000 | Sepal.Width=3.5 |
| Petal.Length | versicolor | -0.5276006 | 0.2012435 | Petal.Length=1.4 |
| Petal.Width | versicolor | 0.1505636 | 0.1117980 | Petal.Width=0.2 |
| Sepal.Length | virginica | 0.0000000 | 0.0000000 | Sepal.Length=5.1 |
| Sepal.Width | virginica | 0.0000000 | 0.0000000 | Sepal.Width=3.5 |
| Petal.Length | virginica | -0.1123994 | 0.0502612 | Petal.Length=1.4 |
| Petal.Width | virginica | -0.1505636 | 0.1117980 | Petal.Width=0.2 |
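For a multiclass model the contributions are estimated separately for each class probability. A minimal, purely illustrative check sums the contributions within each class; each total approximates the difference between the predicted probability of that class for this flower and the average predicted probability:

# Total Shapley contribution per class for the explained instance
aggregate(phi ~ class, data = shapley$results, FUN = sum)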
plot(shapley)

# You can also focus on one class
mod <- Predictor$new(rf_iris, data = X, type = "prob", class = "setosa")
shapley <- Shapley$new(mod, x.interest = X[1, ])
shapley$results
## feature phi phi.var feature.value
## 1 Sepal.Length 0.00 0.0000000 Sepal.Length=5.1
## 2 Sepal.Width 0.00 0.0000000 Sepal.Width=3.5
## 3 Petal.Length 0.78 0.1733333 Petal.Length=1.4
## 4  Petal.Width 0.00 0.0000000  Petal.Width=0.2

plot(shapley)

Case Study 3: Bank Loans
A banking company studied 75 loan schemes that have been rated by its customers, recorded in the file loan.csv below.
# Read the Data
loan <- read.csv("loan.csv", header=T)
loan

| besar.pinjaman | lama.pembayaran | bunga | pembayaran.per.bulan | banyak.cash.back | rating |
|---|---|---|---|---|---|
| 70 | 4 | 1 | 130 | 10.0 | 68.40297 |
| 120 | 3 | 5 | 15 | 2.0 | 33.98368 |
| 70 | 4 | 1 | 260 | 9.0 | 59.42551 |
| 50 | 4 | 0 | 140 | 14.0 | 93.70491 |
| 110 | 2 | 2 | 180 | 1.5 | 29.50954 |
| 110 | 2 | 0 | 125 | 1.0 | 33.17409 |
The variables are:
- besar.pinjaman: loan amount (in millions of rupiah)
- lama.pembayaran: repayment period (in years)
- bunga: additional interest charged (in %)
- pembayaran.per.bulan: monthly payment (in units of 10,000)
- banyak.cash.back: amount of cash back offered in the scheme
The goal of the analysis is to predict the rating of a loan scheme from these variables. Which variable is the most important?
Data preparation
# Random sampling
samplesize <- 0.60 * nrow(loan)
set.seed(80)
index <- sample(seq_len(nrow(loan)), size = samplesize)

# Create training and test set
datatrain <- loan[ index, ]
datatest <- loan[-index, ]

Scaling
The first step is to scale the dataset. This is important because otherwise a variable may dominate the predictions simply because of its scale, and unscaled variables often produce results that are hard to interpret meaningfully.
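The transformation used in the code below is min-max normalisation, which maps every variable onto the [0, 1] interval:

$$x' = \frac{x - \min(x)}{\max(x) - \min(x)}$$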
## Scale data for neural network
max <- apply(loan , 2 , max)
min <- apply(loan, 2 , min)
scaled <- as.data.frame(scale(loan, center = min, scale = max - min))
kable(head(scaled))

| besar.pinjaman | lama.pembayaran | bunga | pembayaran.per.bulan | banyak.cash.back | rating |
|---|---|---|---|---|---|
| 0.1818182 | 0.6 | 0.2 | 0.406250 | 0.7142857 | 0.6655928 |
| 0.6363636 | 0.4 | 1.0 | 0.046875 | 0.1428571 | 0.2106846 |
| 0.1818182 | 0.6 | 0.2 | 0.812500 | 0.6428571 | 0.5469406 |
| 0.0000000 | 0.6 | 0.0 | 0.437500 | 1.0000000 | 1.0000000 |
| 0.5454545 | 0.2 | 0.4 | 0.562500 | 0.1071429 | 0.1515514 |
| 0.5454545 | 0.2 | 0.0 | 0.390625 | 0.0714286 | 0.1999845 |
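A quick, purely illustrative check that every scaled column now lies within [0, 1]:

# Minimum and maximum of each scaled variable; they should be 0 and 1
apply(scaled, 2, range)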
ANN with neuralnet
In addition to the keras and tensorflow packages, the R community has also implemented artificial neural network (ANN) models in the neuralnet package.
Data partitioning
# creating training and test set
trainNN <- scaled[index , ]
testNN <- scaled[-index, ]

Building the model
# fit neural network
set.seed(2)
NN <- neuralnet(
  rating ~ besar.pinjaman + lama.pembayaran + bunga + pembayaran.per.bulan + banyak.cash.back,
  data = trainNN,
  hidden = 3,
  linear.output = TRUE
)

ANN visualization
# plot neural network
plot(NN)

Prediction
## Prediction using neural network
predict_testNN <- compute(NN, testNN[,c(1:5)])
predict_testNN <- (predict_testNN$net.result * (max(loan$rating) - min(loan$rating))) + min(loan$rating)

plot(datatest$rating, predict_testNN, col = 'blue', pch = 16, ylab = "predicted rating NN", xlab = "real rating")
abline(0, 1)

# Calculate Root Mean Square Error (RMSE)
RMSE.NN <- (sum((datatest$rating - predict_testNN)^2) / nrow(datatest)) ^ 0.5
cat("Metrik RMSE: ", RMSE.NN)
## Metrik RMSE: 6.945017

Feature importance
X <- loan[which(names(loan) != "rating")]
predictor <- Predictor$new(NN, data = X, y = loan$rating)
imp <- FeatureImp$new(predictor, loss = "mae")
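FeatureImp estimates importance by permuting one feature at a time and comparing the loss (here MAE) after permutation with the original loss; by default it reports the ratio of the two, so values near 1 mean that shuffling a feature barely hurts the model. The following is a minimal, purely illustrative sketch of that idea for a single feature, using the same unscaled data as the Predictor above (FeatureImp itself repeats the shuffling and also reports 5% and 95% quantiles):

mae <- function(actual, predicted) mean(abs(actual - predicted))
original_loss <- mae(loan$rating, compute(NN, X)$net.result)
X_shuffled <- X
X_shuffled$besar.pinjaman <- sample(X_shuffled$besar.pinjaman)  # break the link with the target
permuted_loss <- mae(loan$rating, compute(NN, X_shuffled)$net.result)
permuted_loss / original_loss  # a ratio clearly above 1 would mark an important feature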
plot(imp)

imp$results
## feature importance.05 importance importance.95 permutation.error
## 1 besar.pinjaman 0.9993735 1.0006164 1.0011689 42.45750
## 2 banyak.cash.back 0.9996252 1.0001133 1.0006106 42.43615
## 3 lama.pembayaran 0.9992501 0.9997796 1.0004272 42.42199
## 4 bunga 0.9994564 0.9996783 0.9999468 42.41770
## 5 pembayaran.per.bulan 0.9976793 0.9982261 1.0000772 42.35607

SHAP
shapley <- Shapley$new(predictor, x.interest = X[1, ])
shapley$plot()

Case Study 4: Stroke Prediction
According to the World Health Organization (WHO), stroke is the 2nd leading cause of death globally, responsible for approximately 11% of total deaths.3
This dataset is used to predict whether a patient is likely to have a stroke based on input parameters such as gender, age, various diseases, and smoking status. Each row in the data provides relevant information about the patient.4
Attribute Information:
- id: unique identifier
- gender: “Male”, “Female” or “Other”
- age: age of the patient
- hypertension: 0 if the patient doesn’t have hypertension, 1 if the patient has hypertension
- heart_disease: 0 if the patient doesn’t have any heart diseases, 1 if the patient has a heart disease
- ever_married: “No” or “Yes”
- work_type: “children”, “Govt_jov”, “Never_worked”, “Private” or “Self-employed”
- Residence_type: “Rural” or “Urban”
- avg_glucose_level: average glucose level in blood
- bmi: body mass index
- smoking_status: “formerly smoked”, “never smoked”, “smokes” or “Unknown”*
- stroke: 1 if the patient had a stroke or 0 if not
Note: Unknown in smoking_status means that the information is unavailable for this patient
stroke <- read_csv("healthcare-dataset-stroke-data.csv") %>%
mutate(hypertension = recode_factor(hypertension, `1` = "Yes", `0` = "No"),
heart_disease = recode_factor(heart_disease, `1` = "Yes", `0` = "No"),
stroke = factor(stroke),
bmi = as.numeric(as.character(bmi))) %>%
mutate_if(is.character, as.factor)
## Rows: 5110 Columns: 12
## ── Column specification ────────────────────────────────────────────────────────
## Delimiter: ","
## chr (6): gender, ever_married, work_type, Residence_type, bmi, smoking_status
## dbl (6): id, age, hypertension, heart_disease, avg_glucose_level, stroke
##
## ℹ Use `spec()` to retrieve the full column specification for this data.
## ℹ Specify the column types or set `show_col_types = FALSE` to quiet this message.
## Warning in eval(cols[[col]], .data, parent.frame()): NAs introduced by coercion
stroke$id <- NULL

Note: We will use the DALEX package for Explanatory Model Analysis (EMA). To use some of the functions in this package, categorical predictor variables need to be converted to factors.
skim(stroke)

| Name | stroke |
| Number of rows | 5110 |
| Number of columns | 11 |
| _______________________ | |
| Column type frequency: | |
| factor | 8 |
| numeric | 3 |
| ________________________ | |
| Group variables | None |
Variable type: factor
| skim_variable | n_missing | complete_rate | ordered | n_unique | top_counts |
|---|---|---|---|---|---|
| gender | 0 | 1 | FALSE | 3 | Fem: 2994, Mal: 2115, Oth: 1 |
| hypertension | 0 | 1 | FALSE | 2 | No: 4612, Yes: 498 |
| heart_disease | 0 | 1 | FALSE | 2 | No: 4834, Yes: 276 |
| ever_married | 0 | 1 | FALSE | 2 | Yes: 3353, No: 1757 |
| work_type | 0 | 1 | FALSE | 5 | Pri: 2925, Sel: 819, chi: 687, Gov: 657 |
| Residence_type | 0 | 1 | FALSE | 2 | Urb: 2596, Rur: 2514 |
| smoking_status | 0 | 1 | FALSE | 4 | nev: 1892, Unk: 1544, for: 885, smo: 789 |
| stroke | 0 | 1 | FALSE | 2 | 0: 4861, 1: 249 |
Variable type: numeric
| skim_variable | n_missing | complete_rate | mean | sd | p0 | p25 | p50 | p75 | p100 | hist |
|---|---|---|---|---|---|---|---|---|---|---|
| age | 0 | 1.00 | 43.23 | 22.61 | 0.08 | 25.00 | 45.00 | 61.00 | 82.00 | ▅▆▇▇▆ |
| avg_glucose_level | 0 | 1.00 | 106.15 | 45.28 | 55.12 | 77.24 | 91.88 | 114.09 | 271.74 | ▇▃▁▁▁ |
| bmi | 201 | 0.96 | 28.89 | 7.85 | 10.30 | 23.50 | 28.10 | 33.10 | 97.60 | ▇▇▁▁▁ |
kable(head(stroke))| gender | age | hypertension | heart_disease | ever_married | work_type | Residence_type | avg_glucose_level | bmi | smoking_status | stroke |
|---|---|---|---|---|---|---|---|---|---|---|
| Male | 67 | No | Yes | Yes | Private | Urban | 228.69 | 36.6 | formerly smoked | 1 |
| Female | 61 | No | No | Yes | Self-employed | Rural | 202.21 | NA | never smoked | 1 |
| Male | 80 | No | Yes | Yes | Private | Rural | 105.92 | 32.5 | never smoked | 1 |
| Female | 49 | No | No | Yes | Private | Urban | 171.23 | 34.4 | smokes | 1 |
| Female | 79 | Yes | No | Yes | Self-employed | Rural | 174.12 | 24.0 | never smoked | 1 |
| Male | 81 | No | No | Yes | Private | Urban | 186.21 | 29.0 | formerly smoked | 1 |
Missing data
About 30% of observations have smoking_status recorded as Unknown, and a small fraction (< 5%) have missing bmi values. We omit the rows with missing data; observations with missing values were also omitted in the previously mentioned article published on arxiv.org.
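The figures above can be verified directly; a minimal check (bmi has 201 missing values, and roughly 30% of patients have an Unknown smoking status):

colSums(is.na(stroke))                    # NA counts per column
mean(stroke$smoking_status == "Unknown")  # share of "Unknown" smoking status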
stroke <- na.omit(stroke)

Imbalanced data
It is worth checking how many patients are classified as having had a stroke (1) and how many have not (0), since this is the target variable for our models.
table(stroke$stroke)
##
## 0 1
## 4700  209

prop.table(table(stroke$stroke))
##
## 0 1
## 0.95742514 0.04257486

Based on the results above, 4,700 patients, or about 96% of the data, did not have a stroke.
barplot(prop.table(table(stroke$stroke, stroke$gender), margin = 2), col = c(2, 4))

The dataset is highly unbalanced: only about 4% of the people in the dataset have had a stroke. This makes it difficult to train a decision tree (or, indeed, any machine learning model) directly.
tree <- rpart(stroke ~ ., data = stroke, method = "class")
tree
## n= 4909
##
## node), split, n, loss, yval, (yprob)
## * denotes terminal node
##
## 1) root 4909 209 0 (0.95742514 0.04257486) *

The tree simply predicts stroke = “No” for every observation in the dataset.
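This is not surprising: a trivial model that always predicts 0 already achieves about 95.7% accuracy on these data, which is why accuracy alone is a poor guide here. A quick check:

# Accuracy of always predicting "no stroke"
mean(stroke$stroke == "0")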
Downsample data
We employ a random downsampling technique to reduce the adverse effect of the unbalanced dataset.
set.seed(123)
minority_obs <- stroke %>% filter(stroke == 1)
majority_obs <- stroke %>% filter(stroke == 0) %>% sample_n(nrow(minority_obs))
balanced_data <- bind_rows(minority_obs, majority_obs)

prop.table(table(balanced_data$stroke))
##
## 0 1
## 0.5 0.5

Data partitioning
set.seed(123)
training_samples <- as.vector(caret::createDataPartition(balanced_data$stroke, p = 0.7, list = FALSE))
train_data <- balanced_data[ training_samples, ]
test_data <- balanced_data[-training_samples, ]

Building the model
CART
cart <- rpart(stroke ~ ., data = train_data, method = "class")
rpart.plot(cart, extra = 4)

Model evaluation
Evaluate model performance on the test dataset.
caret::confusionMatrix(predict(cart, test_data, type = "class"), as.factor(test_data$stroke), positive = "1", mode = "prec_recall")
## Confusion Matrix and Statistics
##
## Reference
## Prediction 0 1
## 0 39 9
## 1 23 53
##
## Accuracy : 0.7419
## 95% CI : (0.6557, 0.8163)
## No Information Rate : 0.5
## P-Value [Acc > NIR] : 3.264e-08
##
## Kappa : 0.4839
##
## Mcnemar's Test P-Value : 0.02156
##
## Precision : 0.6974
## Recall : 0.8548
## F1 : 0.7681
## Prevalence : 0.5000
## Detection Rate : 0.4274
## Detection Prevalence : 0.6129
## Balanced Accuracy : 0.7419
##
## 'Positive' Class : 1
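As a quick check of how these metrics follow from the confusion matrix above: precision = 53 / (53 + 23) ≈ 0.697, recall = 53 / (53 + 9) ≈ 0.855, and F1 is their harmonic mean, 2 × 0.697 × 0.855 / (0.697 + 0.855) ≈ 0.768.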
Random Forest
trControl <- trainControl(method = "cv", number = 10, search = "grid")
set.seed(1234)
rf_mtry<- caret::train(as.factor(stroke) ~ .,
data = train_data,
method = "rf",
metric = "Accuracy",
trControl = trControl)
best_mtry <- rf_mtry$bestTune$mtry
best_mtry
## [1] 9

max(rf_mtry$results$Accuracy)
## [1] 0.7850575

store_maxnode <- list()
tuneGrid <- expand.grid(.mtry = best_mtry)
for (maxnodes in c(5: 15)) {
set.seed(1234)
rf_maxnode <- caret::train(as.factor(stroke) ~ .,
data = train_data,
method = "rf",
metric = "Accuracy",
tuneGrid = tuneGrid,
trControl = trControl,
importance = TRUE,
nodesize = 14,
maxnodes = maxnodes,
ntree = 300)
current_iteration <- toString(maxnodes)
store_maxnode[[current_iteration]] <- rf_maxnode
}
results_mtry <- resamples(store_maxnode)
summary(results_mtry)
##
## Call:
## summary.resamples(object = results_mtry)
##
## Models: 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15
## Number of resamples: 10
##
## Accuracy
## Min. 1st Qu. Median Mean 3rd Qu. Max. NA's
## 5 0.6551724 0.7347701 0.7931034 0.7783908 0.8318966 0.8333333 0
## 6 0.6896552 0.7606322 0.8103448 0.7887356 0.8318966 0.8333333 0
## 7 0.7241379 0.7347701 0.8103448 0.7920690 0.8318966 0.8666667 0
## 8 0.7241379 0.7347701 0.8103448 0.7920690 0.8318966 0.8666667 0
## 9 0.7241379 0.7347701 0.8103448 0.7954023 0.8318966 0.8666667 0
## 10 0.5862069 0.7413793 0.8000000 0.7781609 0.8275862 0.8666667 0
## 11 0.6206897 0.7672414 0.8000000 0.7816092 0.8275862 0.8666667 0
## 12 0.5862069 0.7327586 0.7833333 0.7645977 0.8275862 0.8666667 0
## 13 0.5862069 0.7586207 0.7666667 0.7647126 0.8189655 0.8666667 0
## 14 0.5517241 0.7586207 0.7798851 0.7611494 0.8206897 0.8666667 0
## 15 0.5517241 0.7586207 0.7798851 0.7577011 0.8206897 0.8666667 0
##
## Kappa
## Min. 1st Qu. Median Mean 3rd Qu. Max. NA's
## 5 0.3028846 0.4767913 0.5827314 0.5569665 0.6643026 0.6666667 0
## 6 0.3741007 0.5267760 0.6180051 0.5778075 0.6643026 0.6666667 0
## 7 0.4449761 0.4748826 0.6180051 0.5846426 0.6643026 0.7333333 0
## 8 0.4449761 0.4748826 0.6180051 0.5846426 0.6643026 0.7333333 0
## 9 0.4449761 0.4748826 0.6180051 0.5913093 0.6643026 0.7333333 0
## 10 0.1753555 0.4855557 0.6000000 0.5568120 0.6559773 0.7333333 0
## 11 0.2458629 0.5360039 0.6000000 0.5635369 0.6559773 0.7333333 0
## 12 0.1753555 0.4685009 0.5666667 0.5296044 0.6559773 0.7333333 0
## 13 0.1753555 0.5166585 0.5333333 0.5297633 0.6361281 0.7333333 0
## 14 0.1045131 0.5166585 0.5605055 0.5226206 0.6392086 0.7333333 0
## 15 0.1045131 0.5166585 0.5605055 0.5159273 0.6392086 0.7333333    0

store_maxtrees <- list()
for (ntree in c(250, 300, 350, 400, 450, 500, 550, 600, 800, 1000, 2000)) {
set.seed(5678)
rf_maxtrees <- caret::train(as.factor(stroke) ~ .,
data = train_data,
method = "rf",
metric = "Accuracy",
tuneGrid = tuneGrid,
trControl = trControl,
importance = TRUE,
nodesize = 14,
maxnodes = 14,
ntree = ntree)
key <- toString(ntree)
store_maxtrees[[key]] <- rf_maxtrees
}
results_tree <- resamples(store_maxtrees)
summary(results_tree)
##
## Call:
## summary.resamples(object = results_tree)
##
## Models: 250, 300, 350, 400, 450, 500, 550, 600, 800, 1000, 2000
## Number of resamples: 10
##
## Accuracy
## Min. 1st Qu. Median Mean 3rd Qu. Max. NA's
## 250 0.6785714 0.7396552 0.7965517 0.7953284 0.8534483 0.9000000 0
## 300 0.6785714 0.7396552 0.7931034 0.7884319 0.8206897 0.9000000 0
## 350 0.6785714 0.7396552 0.7931034 0.7952135 0.8465517 0.9333333 0
## 400 0.6785714 0.7396552 0.7931034 0.7952135 0.8465517 0.9333333 0
## 450 0.6428571 0.7396552 0.7931034 0.7848604 0.8206897 0.9000000 0
## 500 0.6428571 0.7396552 0.7931034 0.7848604 0.8206897 0.9000000 0
## 550 0.6428571 0.7396552 0.7931034 0.7881938 0.8206897 0.9333333 0
## 600 0.6428571 0.7396552 0.7931034 0.7881938 0.8206897 0.9333333 0
## 800 0.6785714 0.7396552 0.7931034 0.7850985 0.8206897 0.9000000 0
## 1000 0.6785714 0.7396552 0.7931034 0.7884319 0.8206897 0.9000000 0
## 2000 0.6785714 0.7396552 0.7931034 0.7884319 0.8206897 0.9333333 0
##
## Kappa
## Min. 1st Qu. Median Mean 3rd Qu. Max. NA's
## 250 0.3571429 0.4811475 0.5928571 0.5911511 0.7069084 0.8000000 0
## 300 0.3571429 0.4811475 0.5837225 0.5773055 0.6429078 0.8000000 0
## 350 0.3571429 0.4811475 0.5837225 0.5907630 0.6938389 0.8666667 0
## 400 0.3571429 0.4811475 0.5837225 0.5907630 0.6938389 0.8666667 0
## 450 0.2857143 0.4811475 0.5837225 0.5701627 0.6429078 0.8000000 0
## 500 0.2857143 0.4811475 0.5837225 0.5701627 0.6429078 0.8000000 0
## 550 0.2857143 0.4811475 0.5837225 0.5768293 0.6429078 0.8666667 0
## 600 0.2857143 0.4811475 0.5837225 0.5768293 0.6429078 0.8666667 0
## 800 0.3571429 0.4811475 0.5837225 0.5706388 0.6429078 0.8000000 0
## 1000 0.3571429 0.4811475 0.5837225 0.5773055 0.6429078 0.8000000 0
## 2000 0.3571429 0.4811475 0.5837225 0.5773055 0.6429078 0.8666667    0

Calibration plot
set.seed(1234)
model_forest <- caret::train(as.factor(stroke) ~ ., data = train_data, method = "rf", metric = "Accuracy", tuneGrid = data.frame(mtry = 2), ntree = 300, maxnodes = 14)
pred <- predict(model_forest, test_data, type = "prob")
prob <- pred[, "1"]
calCurve <- caret::calibration(as.factor(test_data$stroke) ~ prob, data = test_data, class = '1')
calCurve
##
## Call:
## calibration.formula(x = as.factor(test_data$stroke) ~ prob, data =
## test_data, class = "1")
##
## Models: prob
## Event: 1
## Cuts: 11

Explanatory Model Analysis using DALEX
model_tree <- rpart(stroke ~ ., data = train_data, method = "class")
exp_tree <- DALEX::explain(model_tree, data = train_data[,-11], y = as.numeric(train_data$stroke)-1, label = "Decision Tree", type = "classification")
## Preparation of a new explainer is initiated
## -> model label : Decision Tree
## -> data : 294 rows 10 cols
## -> data : tibble converted into a data.frame
## -> target variable : 294 values
## -> predict function : yhat.rpart will be used ( default )
## -> predicted values : No value for predict function target column. ( default )
## -> model_info : package rpart , ver. 4.1.16 , task classification ( default )
## -> model_info : type set to classification
## -> predicted values : numerical, min = 0.05208333 , mean = 0.5 , max = 0.9
## -> residual function : difference between y and yhat ( default )
## -> residuals : numerical, min = -0.9 , mean = 7.983914e-18 , max = 0.9479167
## A new explainer has been created!

exp_tree_updated <- update_data(exp_tree, data = test_data[,-11], y = as.numeric(test_data$stroke)-1)
## -> data : 124 rows 10 cols
## -> target variable : 124 values
## An explainer has been updated!

model_forest <- caret::train(as.factor(stroke) ~ ., data = train_data, method = "rf", metric = "Accuracy", tuneGrid = data.frame(mtry = 2), ntree = 300, maxnodes = 14)
exp_forest <- DALEX::explain(model_forest, data = train_data[,-11], y = as.numeric(train_data$stroke)-1, label = "Random forest", type = "classification")
## Preparation of a new explainer is initiated
## -> model label : Random forest
## -> data : 294 rows 10 cols
## -> data : tibble converted into a data.frame
## -> target variable : 294 values
## -> predict function : yhat.train will be used ( default )
## -> predicted values : No value for predict function target column. ( default )
## -> model_info : package caret , ver. 6.0.93 , task classification ( default )
## -> model_info : type set to classification
## -> predicted values : numerical, min = 0 , mean = 0.5494898 , max = 1
## -> residual function : difference between y and yhat ( default )
## -> residuals : numerical, min = -0.9133333 , mean = -0.0494898 , max = 0.6533333
## A new explainer has been created!

exp_forest_updated <- update_data(exp_forest, data = test_data[,-11], y = as.numeric(test_data$stroke)-1)
## -> data : 124 rows 10 cols
## -> target variable : 124 values
## An explainer has been updated!

Model performance
CART
perf_tree <- model_performance(exp_tree_updated)
perf_tree
## Measures for: classification
## recall : 0.8548387
## precision : 0.6973684
## f1 : 0.7681159
## accuracy : 0.7419355
## auc : 0.7814776
##
## Residuals:
## 0% 10% 20% 30% 40% 50%
## -0.90000000 -0.80900749 -0.42307692 -0.05208333 -0.05208333 0.02395833
## 60% 70% 80% 90% 100%
##  0.16831683  0.16831683  0.16831683  0.24390244  0.94791667

Random Forest
perf_forest <- model_performance(exp_forest_updated)
perf_forest
## Measures for: classification
## recall : 0.9032258
## precision : 0.7179487
## f1 : 0.8
## accuracy : 0.7741935
## auc : 0.8433923
##
## Residuals:
## 0% 10% 20% 30% 40% 50%
## -0.993333333 -0.736000000 -0.455333333 -0.041000000 -0.006000000 0.001666667
## 60% 70% 80% 90% 100%
##  0.039333333  0.086666667  0.213333333  0.359666667  0.970000000

Plots
plot(perf_tree, perf_forest, geom = "roc")
plot(perf_tree, perf_forest, geom = "boxplot")

Feature attribution - prediction level
henry <- test_data[1,]
predict(exp_tree, henry)
## [1] 0.7560976

predict(exp_forest, henry)
## [1] 0.9866667

henry$stroke
## [1] 1
## Levels: 0 1

sh_forest <- predict_parts(exp_forest, henry, type = "shap", B = 1)
plot(sh_forest, show_boxplots = FALSE) +
ggtitle("Shapley values for Henry","")bd_forest <- predict_parts(exp_forest, henry, type = "break_down_interactions")
bd_forest
## contribution
## Random forest: intercept 0.549
## Random forest: age = 67 0.224
## Random forest: avg_glucose_level = 228.69 0.051
## Random forest: bmi = 36.6 0.115
## Random forest: heart_disease = 1 0.015
## Random forest: ever_married = 2 0.009
## Random forest: hypertension = 2 -0.001
## Random forest: smoking_status = 1 0.023
## Random forest: Residence_type = 2 0.000
## Random forest: work_type = 4 0.001
## Random forest: gender = 2 0.000
## Random forest: prediction                        0.987
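Note that the break-down contributions add up to the prediction: 0.549 (intercept) + 0.224 + 0.051 + 0.115 + 0.015 + 0.009 - 0.001 + 0.023 + 0.001 ≈ 0.987, the random forest's predicted stroke probability for Henry.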
ggtitle("Break down values for Henry","") # + # scale_y_continuous("",limits = c(0.09,0.33))Feature importance - model level
Visualize feature importance for the decision tree and random forest model.
mp_tree <- model_parts(exp_tree, type = "difference")
mp_forest<- model_parts(exp_forest, type = "difference")
plot(mp_tree, mp_forest, show_boxplots = FALSE)

Ceteris-paribus profiles (What if?)
Visualize how the model response would change for Henry if one of the coordinates in his observation was changed while leaving all other coordinates unchanged.
cp_forest <- predict_profile(exp_forest, henry)

cp_forest
## Top profiles :
## gender age hypertension heart_disease ever_married work_type
## 1 Female 67.0000 No Yes Yes Private
## 2 Male 67.0000 No Yes Yes Private
## 3 Male 0.4800 No Yes Yes Private
## 4 Male 1.2952 No Yes Yes Private
## 5 Male 2.1104 No Yes Yes Private
## 6 Male 2.9256 No Yes Yes Private
## Residence_type avg_glucose_level bmi smoking_status _yhat_ _vname_ _ids_
## 1 Urban 228.69 36.6 formerly smoked 0.9866667 gender 1
## 2 Urban 228.69 36.6 formerly smoked 0.9866667 gender 1
## 3 Urban 228.69 36.6 formerly smoked 0.1133333 age 1
## 4 Urban 228.69 36.6 formerly smoked 0.1133333 age 1
## 5 Urban 228.69 36.6 formerly smoked 0.1133333 age 1
## 6 Urban 228.69 36.6 formerly smoked 0.1133333 age 1
## _label_
## 1 Random forest
## 2 Random forest
## 3 Random forest
## 4 Random forest
## 5 Random forest
## 6 Random forest
##
##
## Top observations:
## gender age hypertension heart_disease ever_married work_type Residence_type
## 1 Male 67 No Yes Yes Private Urban
## avg_glucose_level bmi smoking_status _yhat_ _label_ _ids_
## 1            228.69 36.6 formerly smoked 0.9866667 Random forest     1

plot(cp_forest, variables = c("age", "avg_glucose_level"))

Partial dependence profiles
Create a partial dependence profile for age for the random forest model. Partial dependence profiles are averages of ceteris-paribus (CP) profiles over all observations.
mp_forest <- model_profile(exp_forest)
plot(mp_forest, variables = "age")
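Because a partial dependence profile is simply the average of the ceteris-paribus profiles, it can be instructive to draw the individual profiles behind the average. A minimal sketch, assuming the geom = "profiles" option of DALEX's plot.model_profile (available in recent DALEX versions):

mp_forest_cp <- model_profile(exp_forest, variables = "age", N = 100)
plot(mp_forest_cp, geom = "profiles") +
  ggtitle("Ceteris-paribus and partial dependence profiles for age", "")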