Predict whether a patient will have a stroke based on a set of given attributes, using a decision tree and a random forest model.
Also improve the explainability of the models and their predictions using the methods provided by the DALEX package in R.
library(readr)
library(dplyr)
library(magrittr)
library(tibble)
library(skimr)
library(rpart)
library(rpart.plot)
library(randomForest)
library(caret)
library(e1071)
library(DALEX)
data <- read_csv("data/healthcare-dataset-stroke-data/train_2v.csv") %>%
mutate(hypertension = recode(hypertension, `1` = "Yes", `0` = "No"),
heart_disease = recode(heart_disease, `1` = "Yes", `0` = "No")) %>%
mutate_if(is.character, as.factor) %>%
select(-id)
glimpse(data)
## Rows: 43,400
## Columns: 11
## $ gender <fct> Male, Male, Female, Female, Male, Female, Female,...
## $ age <dbl> 3, 58, 8, 70, 14, 47, 52, 75, 32, 74, 79, 79, 37,...
## $ hypertension <fct> No, Yes, No, No, No, No, No, No, No, Yes, No, No,...
## $ heart_disease <fct> No, No, No, No, No, No, No, Yes, No, No, No, Yes,...
## $ ever_married <fct> No, Yes, No, Yes, No, Yes, Yes, Yes, Yes, Yes, Ye...
## $ work_type <fct> children, Private, Private, Private, Never_worked...
## $ Residence_type <fct> Rural, Urban, Urban, Rural, Rural, Urban, Urban, ...
## $ avg_glucose_level <dbl> 95.12, 87.96, 110.89, 69.04, 161.28, 210.95, 77.5...
## $ bmi <dbl> 18.0, 39.2, 17.6, 35.9, 19.1, 50.1, 17.7, 27.0, 3...
## $ smoking_status <fct> NA, never smoked, NA, formerly smoked, NA, NA, fo...
## $ stroke <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0...
skim(data[, -1])
| Data summary | |
|---|---|
| Name | data[, -1] |
| Number of rows | 43400 |
| Number of columns | 10 |
| Column type frequency: | |
| factor | 6 |
| numeric | 4 |
| Group variables | None |
Variable type: factor
| skim_variable | n_missing | complete_rate | ordered | n_unique | top_counts |
|---|---|---|---|---|---|
| hypertension | 0 | 1.00 | FALSE | 2 | No: 39339, Yes: 4061 |
| heart_disease | 0 | 1.00 | FALSE | 2 | No: 41338, Yes: 2062 |
| ever_married | 0 | 1.00 | FALSE | 2 | Yes: 27938, No: 15462 |
| work_type | 0 | 1.00 | FALSE | 5 | Pri: 24834, Sel: 6793, chi: 6156, Gov: 5440 |
| Residence_type | 0 | 1.00 | FALSE | 2 | Urb: 21756, Rur: 21644 |
| smoking_status | 13292 | 0.69 | FALSE | 3 | nev: 16053, for: 7493, smo: 6562 |
Variable type: numeric
| skim_variable | n_missing | complete_rate | mean | sd | p0 | p25 | p50 | p75 | p100 | hist |
|---|---|---|---|---|---|---|---|---|---|---|
| age | 0 | 1.00 | 42.22 | 22.52 | 0.08 | 24.00 | 44.00 | 60.00 | 82.00 | ▆▆▇▇▆ |
| avg_glucose_level | 0 | 1.00 | 104.48 | 43.11 | 55.00 | 77.54 | 91.58 | 112.07 | 291.05 | ▇▂▁▁▁ |
| bmi | 1462 | 0.97 | 28.61 | 7.77 | 10.10 | 23.20 | 27.70 | 32.90 | 97.60 | ▇▇▁▁▁ |
| stroke | 0 | 1.00 | 0.02 | 0.13 | 0.00 | 0.00 | 0.00 | 0.00 | 1.00 | ▇▁▁▁▁ |
Note: We will use the DALEX package for Explanatory Model Analysis (EMA). To use some of the functions in this package, categorical predictor variables need to be converted to factors.
Over 30% of the observations are missing smoking_status, and a few (< 5%) are missing bmi. We omit rows with missing data; the observations with missing values were also omitted in the previously mentioned article published on arxiv.org.
data <- na.omit(data)
The dataset is highly imbalanced: only about 1.9% of the people in the dataset have had a stroke. This poses a difficult problem for training a decision tree (or, indeed, any machine learning model).
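A quick check of the class proportions (after removing rows with missing values):
round(prop.table(table(data$stroke)), 3)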
tree <- rpart(stroke ~ ., data = data, method = "class")
tree
## n= 29072
##
## node), split, n, loss, yval, (yprob)
## * denotes terminal node
##
## 1) root 29072 548 0 (0.98115025 0.01884975) *
The tree simply predicts stroke = 0 (no stroke) for every observation in the dataset.
We employ random downsampling of the majority class to reduce the adverse effect of the imbalanced dataset.
set.seed(123)
minority_obs <- data %>% filter(stroke == 1)
majority_obs <- data %>% filter(stroke == 0) %>% sample_n(nrow(minority_obs))
balanced_data <- bind_rows(minority_obs, majority_obs)
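A quick check that the two classes are now of equal size:
table(balanced_data$stroke)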
We use 70% of the dataset for training the machine learning models and 30% of the dataset for testing their performance.
set.seed(123)
training_samples <- as.vector(caret::createDataPartition(balanced_data$stroke, p = 0.7, list = FALSE))
train <- balanced_data[ training_samples, ]
test <- balanced_data[-training_samples, ]
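createDataPartition() attempts to preserve the outcome distribution across the split, so the 50/50 class balance of balanced_data should carry over to both subsets; a quick check:
prop.table(table(train$stroke))
prop.table(table(test$stroke))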
Train a decision tree model.
cart <- rpart(stroke ~ ., data = train, method = "class")
rpart.plot(cart, extra = 4)
Evaluate model performance on the test dataset.
confusionMatrix(predict(cart, test, type = "class"), as.factor(test$stroke), positive = "1", mode = "prec_recall")
## Confusion Matrix and Statistics
##
## Reference
## Prediction 0 1
## 0 120 37
## 1 44 127
##
## Accuracy : 0.753
## 95% CI : (0.7027, 0.7988)
## No Information Rate : 0.5
## P-Value [Acc > NIR] : <2e-16
##
## Kappa : 0.5061
##
## Mcnemar's Test P-Value : 0.505
##
## Precision : 0.7427
## Recall : 0.7744
## F1 : 0.7582
## Prevalence : 0.5000
## Detection Rate : 0.3872
## Detection Prevalence : 0.5213
## Balanced Accuracy : 0.7530
##
## 'Positive' Class : 1
##
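For reference, the precision, recall, and F1 values reported above can be recomputed by hand from the confusion matrix counts (tp, fp, fn are hypothetical helper names used only here):
tp <- 127; fp <- 44; fn <- 37
precision <- tp / (tp + fp)                          # 127 / 171 ≈ 0.743
recall    <- tp / (tp + fn)                          # 127 / 164 ≈ 0.774
f1 <- 2 * precision * recall / (precision + recall)  # ≈ 0.758
c(precision = precision, recall = recall, f1 = f1)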
For background on how decision trees produce class probabilities, see https://rpmcruz.github.io/machine%20learning/2018/02/09/probabilities-trees.html
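With rpart, these class probabilities can be extracted directly instead of hard class labels (a small sketch using the cart model and test set from above; cart_prob is a new helper name):
# One column of class probabilities per class level ("0" and "1")
cart_prob <- predict(cart, test, type = "prob")
head(cart_prob[, "1"])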
Train a random forest model.
trControl <- trainControl(method = "cv", number = 10, search = "grid")
set.seed(1234)
rf_mtry <- train(as.factor(stroke) ~ .,
data = train,
method = "rf",
metric = "Accuracy",
trControl = trControl)
best_mtry <- rf_mtry$bestTune$mtry
best_mtry
## [1] 2
max(rf_mtry$results$Accuracy)
## [1] 0.7278612
store_maxnode <- list()
tuneGrid <- expand.grid(.mtry = best_mtry)
for (maxnodes in c(5: 15)) {
set.seed(1234)
rf_maxnode <- train(as.factor(stroke) ~ .,
data = train,
method = "rf",
metric = "Accuracy",
tuneGrid = tuneGrid,
trControl = trControl,
importance = TRUE,
nodesize = 14,
maxnodes = maxnodes,
ntree = 300)
current_iteration <- toString(maxnodes)
store_maxnode[[current_iteration]] <- rf_maxnode
}
results_mtry <- resamples(store_maxnode)
summary(results_mtry)
##
## Call:
## summary.resamples(object = results_mtry)
##
## Models: 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15
## Number of resamples: 10
##
## Accuracy
## Min. 1st Qu. Median Mean 3rd Qu. Max. NA's
## 5 0.6363636 0.7074176 0.7142857 0.7175395 0.7344498 0.7763158 0
## 6 0.6623377 0.6983510 0.7255639 0.7239326 0.7574334 0.7692308 0
## 7 0.6710526 0.6893107 0.7077922 0.7161720 0.7434211 0.7922078 0
## 8 0.6623377 0.6993084 0.7189850 0.7214185 0.7272727 0.7922078 0
## 9 0.6883117 0.7114662 0.7189850 0.7214023 0.7272727 0.7763158 0
## 10 0.6623377 0.7012987 0.7189850 0.7213515 0.7508325 0.7631579 0
## 11 0.7012987 0.7166353 0.7272727 0.7330735 0.7476274 0.7792208 0
## 12 0.6753247 0.7166353 0.7320574 0.7213685 0.7394053 0.7435897 0
## 13 0.7105263 0.7166353 0.7272727 0.7343722 0.7378871 0.7922078 0
## 14 0.6753247 0.7245813 0.7272727 0.7330735 0.7378871 0.8051948 0
## 15 0.6842105 0.7175325 0.7337662 0.7343214 0.7508325 0.7763158 0
##
## Kappa
## Min. 1st Qu. Median Mean 3rd Qu. Max. NA's
## 5 0.2711291 0.4146180 0.4278959 0.4348248 0.4691979 0.5526316 0
## 6 0.3227334 0.3964244 0.4506935 0.4475685 0.5150199 0.5384615 0
## 7 0.3421053 0.3782268 0.4153661 0.4321548 0.4868421 0.5840648 0
## 8 0.3241053 0.3986167 0.4373421 0.4426355 0.4552793 0.5846258 0
## 9 0.3748309 0.4226184 0.4386909 0.4427125 0.4550043 0.5526316 0
## 10 0.3254717 0.4013854 0.4375356 0.4425304 0.5012275 0.5263158 0
## 11 0.4039044 0.4320383 0.4546352 0.4660805 0.4948172 0.5585160 0
## 12 0.3503206 0.4323289 0.4637925 0.4425245 0.4776903 0.4871795 0
## 13 0.4210526 0.4326191 0.4544519 0.4686235 0.4762337 0.5849057 0
## 14 0.3481206 0.4488168 0.4546352 0.4658416 0.4762337 0.6104553 0
## 15 0.3684211 0.4335791 0.4682521 0.4682948 0.5009770 0.5526316 0
The highest maximum accuracy across resamples is obtained with maxnodes = 14, so we use that value from here on.
store_maxtrees <- list()
for (ntree in c(250, 300, 350, 400, 450, 500, 550, 600, 800, 1000, 2000)) {
set.seed(5678)
rf_maxtrees <- train(as.factor(stroke) ~ .,
data = train,
method = "rf",
metric = "Accuracy",
tuneGrid = tuneGrid,
trControl = trControl,
importance = TRUE,
nodesize = 14,
maxnodes = 14,
ntree = ntree)
key <- toString(ntree)
store_maxtrees[[key]] <- rf_maxtrees
}
results_tree <- resamples(store_maxtrees)
summary(results_tree)
##
## Call:
## summary.resamples(object = results_tree)
##
## Models: 250, 300, 350, 400, 450, 500, 550, 600, 800, 1000, 2000
## Number of resamples: 10
##
## Accuracy
## Min. 1st Qu. Median Mean 3rd Qu. Max. NA's
## 250 0.6753247 0.7272727 0.7451299 0.7422345 0.7654648 0.7894737 0
## 300 0.6883117 0.7329545 0.7516234 0.7512755 0.7662338 0.8205128 0
## 350 0.6753247 0.7199248 0.7467532 0.7422345 0.7654648 0.8026316 0
## 400 0.6753247 0.7296651 0.7402597 0.7409529 0.7705485 0.8026316 0
## 450 0.6753247 0.7296651 0.7402597 0.7474302 0.7869105 0.8076923 0
## 500 0.6753247 0.7296651 0.7402597 0.7461486 0.7902854 0.8076923 0
## 550 0.6753247 0.7296651 0.7402597 0.7474473 0.7935321 0.8076923 0
## 600 0.6753247 0.7199248 0.7532468 0.7474302 0.7836637 0.8076923 0
## 800 0.6753247 0.7305195 0.7516234 0.7474473 0.7804170 0.8076923 0
## 1000 0.6753247 0.7199248 0.7532468 0.7500110 0.7836637 0.8205128 0
## 2000 0.6623377 0.7199248 0.7467532 0.7474306 0.7836637 0.8205128 0
##
## Kappa
## Min. 1st Qu. Median Mean 3rd Qu. Max. NA's
## 250 0.3511965 0.4564706 0.4898649 0.4845893 0.5308704 0.5789474 0
## 300 0.3744076 0.4673529 0.5042828 0.5026920 0.5321515 0.6410256 0
## 350 0.3511965 0.4412186 0.4940515 0.4845858 0.5308704 0.6052632 0
## 400 0.3481206 0.4607740 0.4816545 0.4820354 0.5408677 0.6052632 0
## 450 0.3481206 0.4607740 0.4811283 0.4949542 0.5737650 0.6153846 0
## 500 0.3481206 0.4607740 0.4811283 0.4923883 0.5805083 0.6153846 0
## 550 0.3481206 0.4607740 0.4811283 0.4950028 0.5870445 0.6153846 0
## 600 0.3481206 0.4412186 0.5070709 0.4949524 0.5673077 0.6153846 0
## 800 0.3481206 0.4629847 0.5027881 0.4949730 0.5607714 0.6153846 0
## 1000 0.3485618 0.4412186 0.5069835 0.5001593 0.5673077 0.6410256 0
## 2000 0.3222749 0.4412186 0.4943853 0.4949315 0.5670706 0.6410256 0
The final model is a random forest with the following parameters:

- ntree = 300: 300 trees are trained
- mtry = 2: 2 candidate features are sampled at each split
- maxnodes = 14: each tree has at most 14 terminal nodes (leaves)

A direct randomForest() call with these settings is sketched below.
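For reference, the final model can also be fit directly with randomForest() using these settings (a minimal sketch; rf_final is a new object name, and the result will differ slightly from the caret fit used below because caret resamples internally):
set.seed(1234)
rf_final <- randomForest(as.factor(stroke) ~ ., data = train,
                         mtry = 2, ntree = 300, maxnodes = 14)
rf_final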
Plot the distribution of the predicted stroke probabilities for each true outcome (prob is the vector of predicted stroke probabilities computed in the calibration step below).
hist(prob[test$stroke == 1], xlab = "probability of stroke condition", main = "True Outcome: Stroke condition")
hist(prob[test$stroke == 0], xlab = "probability of stroke condition", main = "True Outcome: No stroke condition")
The estimated class probabilities should reflect the true underlying probability for each sample; that is, the predicted class probabilities need to be well-calibrated, effectively mirroring the true likelihood of a stroke condition.
One way to assess the quality of the class probabilities is a calibration plot, which shows some measure of the observed probability of an event versus the predicted probability. One approach for creating this visualization is to score a collection of samples with known outcomes using the classification model, bin the data into groups based on their class probabilities, and determine the observed event rate for each bin. The calibration plot displays the midpoint of each bin on the x-axis and the observed event rate on the y-axis. If the points fall along a 45-degree line, the model has produced well-calibrated probabilities.
set.seed(1234)
model_forest <- train(as.factor(stroke) ~ ., data = train, method = "rf", metric = "Accuracy",
                      tuneGrid = expand.grid(.mtry = 2), ntree = 300, maxnodes = 14)
pred <- predict(model_forest, test, type = "prob")
prob <- pred[, "1"]
calCurve <- caret::calibration(as.factor(test$stroke) ~ prob, data = test, class = '1')
calCurve
##
## Call:
## calibration.formula(x = as.factor(test$stroke) ~ prob, data = test, class = "1")
##
## Models: prob
## Event: 1
## Cuts: 11
xyplot(calCurve)
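For intuition, the binning logic behind the calibration plot can be sketched by hand with dplyr (bin boundaries and summaries will differ slightly from caret's defaults; calib_manual is a hypothetical helper name):
# Bin the predicted probabilities and compute the observed stroke rate per bin
calib_manual <- tibble(prob = prob, stroke = test$stroke) %>%
  mutate(bin = cut(prob, breaks = seq(0, 1, by = 0.1), include.lowest = TRUE)) %>%
  group_by(bin) %>%
  summarise(mean_predicted = mean(prob),   # average predicted probability in the bin
            observed_rate  = mean(stroke), # observed fraction of strokes in the bin
            n = n())
calib_manual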
Create an explainer for the decision tree.
model_tree <- rpart(stroke ~ ., data = train, method = "class")
exp_tree <- explain(model_tree, data = train[, -11], y = train$stroke, label = "Decision Tree", type = "classification")
## Preparation of a new explainer is initiated
## -> model label : Decision Tree
## -> data : 768 rows 10 cols
## -> data : tibble converted into a data.frame
## -> target variable : 768 values
## -> model_info : package rpart , ver. 4.1.15 , task regression ( default )
## -> model_info : type set to classification
## -> predict function : yhat.rpart will be used ( default )
## -> predicted values : numerical, min = 0.1290323 , mean = 0.5 , max = 0.8082192
## -> residual function : difference between y and yhat ( default )
## -> residuals : numerical, min = -0.8082192 , mean = 1.690801e-18 , max = 0.8709677
## A new explainer has been created!
exp_tree_updated <- update_data(exp_tree, data = test[, -11], y = test$stroke)
## -> data : 328 rows 10 cols
## -> target variable : 328 values
## An explainer has been updated!
Create an explainer for the random forest.
model_forest <- train(as.factor(stroke) ~ ., data = train, method = "rf", metric = "Accuracy",
                      tuneGrid = expand.grid(.mtry = 2), ntree = 300, maxnodes = 14)
exp_forest <- explain(model_forest, data = train[, -11], y = train$stroke, label = "Random forest", type = "classification")
## Preparation of a new explainer is initiated
## -> model label : Random forest
## -> data : 768 rows 10 cols
## -> data : tibble converted into a data.frame
## -> target variable : 768 values
## -> model_info : package caret , ver. 6.0.86 , task Classification ( default )
## -> model_info : type set to classification
## -> predict function : yhat.train will be used ( default )
## -> predicted values : numerical, min = 0 , mean = 0.5529514 , max = 1
## -> residual function : difference between y and yhat ( default )
## -> residuals : numerical, min = -0.9833333 , mean = -0.05295139 , max = 0.9933333
## A new explainer has been created!
exp_forest_updated <- update_data(exp_forest, data = test[, -11], y = test$stroke)
## -> data : 328 rows 10 cols
## -> target variable : 328 values
## An explainer has been updated!
For classification, commonly used measures of model performance are accuracy, precision, recall, and the F1 score.
cat("Decision tree\n")
## Decision tree
cat("===============\n")
## ===============
perf_tree <- model_performance(exp_tree_updated)
perf_tree
## Measures for: classification
## recall : 0.7743902
## precision: 0.7426901
## f1 : 0.758209
## accuracy : 0.7530488
## auc : 0.8770821
##
## Residuals:
## 0% 10% 20% 30% 40% 50%
## -0.80821918 -0.76136364 -0.39007092 -0.12903226 -0.12903226 0.03137428
## 60% 70% 80% 90% 100%
## 0.19178082 0.19178082 0.23863636 0.60992908 0.87096774
cat("\n")
cat("Random forest\n")
## Random forest
cat("=============\n")
## =============
perf_forest <- model_performance(exp_forest_updated)
perf_forest
## Measures for: classification
## recall : 0.9085366
## precision: 0.6962617
## f1 : 0.7883598
## accuracy : 0.7560976
## auc : 0.8280042
##
## Residuals:
## 0% 10% 20% 30% 40% 50%
## -0.996666667 -0.762666667 -0.488666667 -0.103000000 -0.003333333 0.000000000
## 60% 70% 80% 90% 100%
## 0.023333333 0.053000000 0.138666667 0.349666667 0.993333333
The random forest model has a higher recall but a lower precision than the decision tree. It also has a higher F1 score.
plot(perf_tree, perf_forest, geom = "roc")
plot(perf_tree, perf_forest, geom = "boxplot")
Calculate the prediction of the decision tree and random forest model for the first person in the test set. Let’s call him Henry.
henry <- test[1,]
predict(exp_tree, henry)
## [1] 0.7613636
predict(exp_forest, henry)
## [1] 0.9833333
henry$stroke
## [1] 1
Visualize feature attributions for Henry’s prediction using Shapley values.
sh_forest <- predict_parts(exp_forest, henry, type = "shap", B = 1)
plot(sh_forest, show_boxplots = FALSE) +
ggtitle("Shapley values for Henry","")
Visualize feature attributions for Henry’s prediction using break down values.
bd_forest <- predict_parts(exp_forest, henry, type = "break_down_interactions")
bd_forest
plot(bd_forest, show_boxplots = FALSE) +
ggtitle("Break down values for Henry","") # +
# scale_y_continuous("",limits = c(0.09,0.33))
Visualize feature importance for the decision tree and random forest model.
mp_tree <- model_parts(exp_tree, type = "difference")
mp_forest <- model_parts(exp_forest, type = "difference")
plot(mp_tree, mp_forest, show_boxplots = FALSE)
Visualize how the model response would change for Henry if a single feature value in his observation were changed while all other features were held fixed (ceteris-paribus profiles).
cp_forest <- predict_profile(exp_forest, henry)
## Warning in if (class(new_observation) != "data.frame") {: the condition has
## length > 1 and only the first element will be used
cp_forest
plot(cp_forest, variables = c("heart_disease", "hypertension"))
## 'variable_type' changed to 'categorical' due to lack of numerical variables.
## 'variable_type' changed to 'categorical' due to lack of numerical variables.
plot(cp_forest, variables = c("age", "avg_glucose_level"))
Create a partial dependence profile for age for the random forest model. Partial dependence profiles are averages of ceteris-paribus (CP) profiles over all observations.
pdp_forest <- model_profile(exp_forest)  # renamed from mp_forest to avoid clobbering the model_parts object above
plot(pdp_forest, variables = "age")
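Grouped partial dependence profiles can also be informative; a small sketch assuming the groups argument of model_profile (available in recent DALEX versions), which averages the CP profiles separately within each level of a factor:
pdp_grouped <- model_profile(exp_forest, variables = "age", groups = "hypertension")
plot(pdp_grouped)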