Overview

The Problem

According to the World Health Organization (WHO), stroke is the second leading cause of death globally, responsible for approximately 11% of total deaths. This comes out to roughly 6.5 million deaths annually.

It is also estimated that there are over 12.2 million new strokes each year and that one in four people over 25 years of age will have a stroke in their lifetime.

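The ten leading risk factors that can increase the probability of a stroke, with the percentage weight of each, are the following. Because a single stroke can involve several risk factors, the percentages overlap and do not sum to 100%.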

  1. Elevated systolic blood pressure (56%)
  2. High body mass index (24%)
  3. High fasting glucose (20%)
  4. Air pollution (20%)
  5. Smoking (18%)
  6. Poor diet (31%)
  7. High LDL cholesterol (10%)
  8. Kidney dysfunction (8%)
  9. Alcohol use (6%)
  10. Low physical activity (2%)

Statistic source: World Stroke Organization

This project aims to build and deploy a model that predicts whether or not a patient will have a stroke based on patient data, including the patient’s medical history and demographic information. Deploying such a model could support preventive care by focusing medical attention on high-risk patients.

The Dataset

The dataset used to build the prediction model comes from the open data website Kaggle. Below is a list of the attributes as they appear in the dataset, with their descriptions.

  • id: unique patient identifier
  • gender: patient’s gender (male, female, other)
  • age: age of the patient
  • hypertension: 0 if the patient doesn’t have hypertension, 1 if the patient has hypertension
  • heart_disease: 0 if the patient doesn’t have any heart diseases, 1 if the patient has a heart disease
  • ever_married: marital status; “No” they have not been married or “Yes” they have been married
  • work_type: “children”, “Govt_job”, “Never_worked”, “Private” or “Self-employed”
  • Residence_type: “Rural” or “Urban”
  • avg_glucose_level: average glucose level in blood
  • bmi: body mass index
  • smoking_status: “formerly smoked”, “never smoked”, “smokes” or “Unknown”
  • stroke: 1 if the patient had a stroke or 0 if not

Specific Machine Learning Task

This model will use classification to predict the outcome variable (stroke). Classification was chosen because the outcome is categorical: the label is whether or not a patient has a stroke, which makes this a binary classification task. The task is also supervised, because the outcome variable is already labeled, so the model’s predictions can be checked against the recorded stroke outcomes.

Metrics for Performance

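With classification being used, the model will be evaluated with the following metrics (treating a stroke as the positive class):

  • Specificity (true negative rate): the proportion of patients without a stroke whom the model correctly predicts as not having a stroke. This is distinct from precision, which is the proportion of predicted strokes that are actual strokes.
  • Recall (sensitivity, true positive rate): the proportion of patients who had a stroke whom the model correctly predicts as having a stroke.
  • Accuracy: the overall proportion of predictions that match the actual outcomes.
  • ROC_AUC: the area under the ROC curve, which plots the true positive rate against the false positive rate across classification thresholds; 0.5 corresponds to random guessing and 1 to perfect separation.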
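These definitions can be checked by hand. The base R sketch below uses two hypothetical helpers, class_metrics() and roc_auc_rank() (these are not yardstick functions), together with the pooled logistic regression confusion matrix reported later in this document. It also shows why the choice of positive class matters. yardstick’s sensitivity() and specificity() treat the first factor level as the event by default, and stroke is a factor with levels “0” and “1”, so swapping the event swaps the two values.

```r
# Hypothetical helper: metrics from a 2x2 confusion matrix whose rows are
# predictions and whose columns are the true classes. `event` names the
# class treated as "positive".
class_metrics <- function(cm, event) {
  other <- setdiff(rownames(cm), event)
  tp <- cm[event, event]  # predicted event, truly event
  fp <- cm[event, other]  # predicted event, truly the other class
  fn <- cm[other, event]  # predicted other class, truly event
  tn <- cm[other, other]  # predicted other class, truly the other class
  c(accuracy    = (tp + tn) / sum(cm),
    precision   = tp / (tp + fp),
    sensitivity = tp / (tp + fn),  # recall, true positive rate
    specificity = tn / (tn + fp))  # true negative rate
}

# Hypothetical helper: ROC AUC via the rank (Mann-Whitney) formulation, i.e.
# the probability that a random event case scores higher than a non-event case
roc_auc_rank <- function(score, is_event) {
  r <- rank(score)  # ties receive average ranks
  n1 <- sum(is_event)
  n0 <- sum(!is_event)
  (sum(r[is_event]) - n1 * (n1 + 1) / 2) / (n1 * n0)
}

# Pooled logistic regression confusion matrix (rows = prediction, cols = truth)
lr_cm <- matrix(c(3031,  58,
                   852, 147),
                nrow = 2, byrow = TRUE,
                dimnames = list(Prediction = c("0", "1"), Truth = c("0", "1")))

round(class_metrics(lr_cm, event = "1"), 3)  # stroke as the positive class
round(class_metrics(lr_cm, event = "0"), 3)  # yardstick default: first level
```

With “1” as the event, sensitivity is 147/205 (about 0.717) and specificity is 3031/3883 (about 0.781). With “0” as the event, the two values trade places. This closely matches the sensitivity of 0.780 reported for the logistic regression model.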

Get the Data

Load packages

Below are the packages loaded for this machine learning project, listed roughly in the order they are used.

# Packages needed for data manipulation, visualization, model building, and model evaluation
library(tidyverse)
## ── Attaching core tidyverse packages ──────────────────────── tidyverse 2.0.0 ──
## ✔ dplyr     1.1.4     ✔ readr     2.1.5
## ✔ forcats   1.0.0     ✔ stringr   1.5.1
## ✔ ggplot2   3.5.1     ✔ tibble    3.2.1
## ✔ lubridate 1.9.3     ✔ tidyr     1.3.1
## ✔ purrr     1.0.2     
## ── Conflicts ────────────────────────────────────────── tidyverse_conflicts() ──
## ✖ dplyr::filter() masks stats::filter()
## ✖ dplyr::lag()    masks stats::lag()
## ℹ Use the conflicted package (<http://conflicted.r-lib.org/>) to force all conflicts to become errors
library(tidymodels)
## ── Attaching packages ────────────────────────────────────── tidymodels 1.2.0 ──
## ✔ broom        1.0.5      ✔ rsample      1.2.1 
## ✔ dials        1.2.1      ✔ tune         1.2.1 
## ✔ infer        1.0.7      ✔ workflows    1.1.4 
## ✔ modeldata    1.3.0      ✔ workflowsets 1.1.0 
## ✔ parsnip      1.2.1      ✔ yardstick    1.3.1 
## ✔ recipes      1.0.10     
## ── Conflicts ───────────────────────────────────────── tidymodels_conflicts() ──
## ✖ scales::discard() masks purrr::discard()
## ✖ dplyr::filter()   masks stats::filter()
## ✖ recipes::fixed()  masks stringr::fixed()
## ✖ dplyr::lag()      masks stats::lag()
## ✖ yardstick::spec() masks readr::spec()
## ✖ recipes::step()   masks stats::step()
## • Dig deeper into tidy modeling with R at https://www.tmwr.org
library(workflows)
library(themis)
library(tune)
library(ranger)
library(xgboost)
## 
## Attaching package: 'xgboost'
## 
## The following object is masked from 'package:dplyr':
## 
##     slice
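library(kknn)
# the randomForest package must also be installed: parsnip uses it as the engine for the first random forest model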

Import the data

The first step in this process is to import the data and assign it to an object for manipulation.

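# check the current working directory
getwd()
# set the working directory to the folder containing the CSV file (adjust this path for your machine)
setwd("/Users/nickwinters/desktop/DS Projects/R/strokePredictor")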

# assign csv file to an object
strk <- read.csv("healthcare-dataset-stroke-data.csv")

Data Exploration

Taking a peek at the data

Before diving into building the model, it is worth getting familiar with the data.

# display the first five rows
head(strk, 5)
##      id gender age hypertension heart_disease ever_married     work_type
## 1  9046   Male  67            0             1          Yes       Private
## 2 51676 Female  61            0             0          Yes Self-employed
## 3 31112   Male  80            0             1          Yes       Private
## 4 60182 Female  49            0             0          Yes       Private
## 5  1665 Female  79            1             0          Yes Self-employed
##   Residence_type avg_glucose_level  bmi  smoking_status stroke
## 1          Urban            228.69 36.6 formerly smoked      1
## 2          Rural            202.21  N/A    never smoked      1
## 3          Rural            105.92 32.5    never smoked      1
## 4          Urban            171.23 34.4          smokes      1
## 5          Rural            174.12   24    never smoked      1
# size and type of data
dim(strk)
## [1] 5110   12
summary(strk)
##        id           gender               age         hypertension    
##  Min.   :   67   Length:5110        Min.   : 0.08   Min.   :0.00000  
##  1st Qu.:17741   Class :character   1st Qu.:25.00   1st Qu.:0.00000  
##  Median :36932   Mode  :character   Median :45.00   Median :0.00000  
##  Mean   :36518                      Mean   :43.23   Mean   :0.09746  
##  3rd Qu.:54682                      3rd Qu.:61.00   3rd Qu.:0.00000  
##  Max.   :72940                      Max.   :82.00   Max.   :1.00000  
##  heart_disease     ever_married        work_type         Residence_type    
##  Min.   :0.00000   Length:5110        Length:5110        Length:5110       
##  1st Qu.:0.00000   Class :character   Class :character   Class :character  
##  Median :0.00000   Mode  :character   Mode  :character   Mode  :character  
##  Mean   :0.05401                                                           
##  3rd Qu.:0.00000                                                           
##  Max.   :1.00000                                                           
##  avg_glucose_level     bmi            smoking_status         stroke       
##  Min.   : 55.12    Length:5110        Length:5110        Min.   :0.00000  
##  1st Qu.: 77.25    Class :character   Class :character   1st Qu.:0.00000  
##  Median : 91.89    Mode  :character   Mode  :character   Median :0.00000  
##  Mean   :106.15                                          Mean   :0.04873  
##  3rd Qu.:114.09                                          3rd Qu.:0.00000  
##  Max.   :271.74                                          Max.   :1.00000
# look specifically at the class of each column
sapply(strk, class)
##                id            gender               age      hypertension 
##         "integer"       "character"         "numeric"         "integer" 
##     heart_disease      ever_married         work_type    Residence_type 
##         "integer"       "character"       "character"       "character" 
## avg_glucose_level               bmi    smoking_status            stroke 
##         "numeric"       "character"       "character"         "integer"
# check for any missing data
sum(is.na(strk))
## [1] 0

Observation: The data has 5110 rows and 12 columns and is a mix of integers (4), floats (2), and characters (6). Interestingly, the age column is a float with a minimum value of 0.08, which suggests that infants’ ages are recorded as fractions of a year (0.08 years is roughly one month). Secondly, bmi, a numerical measurement, is classed as character. Finally, no entries are flagged as missing. This contradicts the first five rows of data, however, where an “N/A” appears in the bmi column. The bmi column therefore does contain missing data, recorded as the string “N/A” rather than NA, and this is also why the column is classed as character. Similarly, the smoking_status column uses “Unknown” to represent unavailable information. These issues will be addressed in the data manipulation phase.

Data Manipulation

It is now time to prep the data for machine learning.

Addressing missing data

First, the placeholder strings used for unavailable data will be recoded as NA, and the total count of missing values will be computed.

# recode missing-value placeholders ("N/A" in bmi, "Unknown" in smoking_status) as NA
strk[strk == 'N/A'] <- NA
strk[strk == 'Unknown'] <- NA

#validate the above change
sum(is.na(strk))
## [1] 1745
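Another option is to declare the placeholder strings when the file is read, using the na.strings argument of read.csv(). The sketch below tries this on a small made-up CSV string (not the real file). It shows that the default read keeps “N/A” and “Unknown” as ordinary text. Declaring the placeholders turns them into NA and lets bmi be read as numeric immediately. Whether “Unknown” should count as missing is the same modeling decision made above.

```r
# Small made-up CSV snippet using the same placeholders as the stroke data
csv_text <- "id,bmi,smoking_status
1,36.6,formerly smoked
2,N/A,never smoked
3,32.5,Unknown"

# Default reading: "N/A" and "Unknown" are ordinary strings, not NA
raw <- read.csv(text = csv_text)

# Declaring the placeholders as missing values while reading
clean <- read.csv(text = csv_text, na.strings = c("N/A", "Unknown"))

sum(is.na(raw))    # no NA values are detected
class(raw$bmi)     # bmi is read as character
sum(is.na(clean))  # the two placeholders are now NA
class(clean$bmi)   # bmi is read as numeric
```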

Changing the Datatypes of Columns

Next, because this is a classification task, the categorical variables must be converted to factors. This ensures the model treats codes such as 0 and 1 as category labels rather than as numerical values. In addition, bmi will be converted to a numeric (double) type.

# Change datatypes
strk <- transform(strk,
                       bmi = as.double(bmi),
                       gender = as.factor(gender),
                       ever_married = as.factor(ever_married),
                       work_type = as.factor(work_type),
                       Residence_type = as.factor(Residence_type),
                       smoking_status = as.factor(smoking_status),
                       hypertension = as.factor(hypertension),
                       heart_disease = as.factor(heart_disease),
                       stroke = as.factor(stroke))

# validate the above change
sapply(strk, class)
##                id            gender               age      hypertension 
##         "integer"          "factor"         "numeric"          "factor" 
##     heart_disease      ever_married         work_type    Residence_type 
##          "factor"          "factor"          "factor"          "factor" 
## avg_glucose_level               bmi    smoking_status            stroke 
##         "numeric"         "numeric"          "factor"          "factor"

Binning Data

To both simplify the model and address any outliers present in the data, the age, bmi, and avg_glucose_level columns will be binned and ordered into classes based on common segmentation for disease analysis.

Sources used for choosing the classes:

# bin data into appropriate groupings
strk <- strk |> 
  mutate(age_group = cut(age, 
                         breaks = c(0, 6, 19, 46, 65, 76, 100),
                         labels = c('0-5y', '6-18y', '19-45y', '46-64y', '65-75y', '76+'),
                         include.lowest = T, 
                         right = F,
                         ordered_result = T),
         bmi_group = cut(bmi,
                         breaks = c(0, 19, 25, 30, 35, 40, 100),
                         labels = c("underweight", "normal", "overweight", "class I obesity", "class II obesity", "class III obesity"),
                         include.lowest = T,
                         right = F,
                         ordered_result = T),
         glucose_group = cut(avg_glucose_level,
                             breaks = c(0, 140, 200, 300),
                             labels = c("Normal", "Prediabetic", "Diabetic"),
                             include.lowest = T,
                             right = F,
                             ordered_result = T)
         )
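With right = FALSE, each interval includes its lower break and excludes its upper one, so a value that falls exactly on a break goes into the higher group. With include.lowest = TRUE, the last interval is also closed on the right. The base R check below bins some made-up boundary ages with the same breaks and labels as the age bins above and shows where each one lands.

```r
# Made-up ages chosen to sit on either side of each break
ages <- c(0.08, 5, 6, 18, 19, 45, 46, 64, 65, 75, 76, 82)

age_group <- cut(ages,
                 breaks = c(0, 6, 19, 46, 65, 76, 100),
                 labels = c('0-5y', '6-18y', '19-45y', '46-64y', '65-75y', '76+'),
                 include.lowest = TRUE,
                 right = FALSE,          # intervals are [a, b): age 6 goes to "6-18y"
                 ordered_result = TRUE)

data.frame(ages, age_group)
```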

Once the values have been binned, the new columns will be validated.

# validate the bins
strk |> 
  select(age, age_group, bmi, bmi_group, avg_glucose_level, glucose_group) |> 
  head()
##   age age_group  bmi        bmi_group avg_glucose_level glucose_group
## 1  67    65-75y 36.6 class II obesity            228.69      Diabetic
## 2  61    46-64y   NA             <NA>            202.21      Diabetic
## 3  80       76+ 32.5  class I obesity            105.92        Normal
## 4  49    46-64y 34.4  class I obesity            171.23   Prediabetic
## 5  79       76+ 24.0           normal            174.12   Prediabetic
## 6  81       76+ 29.0       overweight            186.21   Prediabetic
# check order
is.ordered(strk$age_group)
## [1] TRUE
levels(strk$age_group)
## [1] "0-5y"   "6-18y"  "19-45y" "46-64y" "65-75y" "76+"
is.ordered(strk$bmi_group)
## [1] TRUE
levels(strk$bmi_group)
## [1] "underweight"       "normal"            "overweight"       
## [4] "class I obesity"   "class II obesity"  "class III obesity"
is.ordered(strk$glucose_group)
## [1] TRUE
levels(strk$glucose_group)
## [1] "Normal"      "Prediabetic" "Diabetic"

Data Visualization

Now that the data has been appropriately transformed, it will be explored visually to find out which features seem most important for determining whether a patient may have a stroke. Based on these insights, features that appear unimportant will be dropped from the data.

stroke

The outcome variable was graphed to get a sense of the distribution between patients that have had a stroke versus those who have never had a stroke.

strk |> 
  ggplot(aes(x = stroke, fill = stroke)) +
  geom_bar()

Observation: Within the data, patients who have never had a stroke severely outnumber patients who have had a stroke. This class imbalance will have to be addressed to prevent the model from achieving high accuracy simply by favoring the negative (no stroke) outcome.
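The arithmetic below shows why accuracy alone is misleading here. It uses only the row count from dim(strk) and the mean of the 0/1 stroke column from summary(strk), and computes the accuracy of a trivial classifier that predicts “no stroke” for every patient. That accuracy comes out at roughly 95%, which is higher than any model evaluated later in this document.

```r
# Values taken from dim(strk) and summary(strk)
n_total <- 5110
stroke_rate <- 0.04873  # mean of the 0/1 stroke column

n_stroke <- round(n_total * stroke_rate)  # approximate count of stroke cases
n_no_stroke <- n_total - n_stroke

# Accuracy of a "model" that predicts no stroke for every patient
baseline_accuracy <- n_no_stroke / n_total
baseline_accuracy
```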

gender

strk |> 
  ggplot(aes(x = gender, fill = stroke)) +
  geom_bar(position = "fill")

Observation: The proportion of stroke cases is about the same for males and females; therefore, gender does not appear to play a significant role in stroke prediction.
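Because position = "fill" scales every bar to a height of 1, each plot shows the proportion of stroke cases within a group rather than raw counts. The same proportions can also be printed as numbers, which makes judgments like “about the same” easier to check. Below is a base R sketch that uses a hypothetical helper, stroke_rate_by(), on made-up data.

```r
# Hypothetical helper: proportion of stroke cases within each group
stroke_rate_by <- function(group, stroke) {
  tab <- table(group, stroke)
  prop.table(tab, margin = 1)[, "1"]  # row-wise proportions, stroke column
}

# Made-up example data (not from the stroke dataset)
toy <- data.frame(
  gender = c("Female", "Female", "Female", "Female",
             "Male", "Male", "Male", "Male", "Male"),
  stroke = factor(c(0, 0, 0, 1, 0, 0, 0, 0, 1), levels = c(0, 1))
)

round(stroke_rate_by(toy$gender, toy$stroke), 3)
```

On the real data, a call such as stroke_rate_by(strk$gender, strk$stroke) would print the values behind the bar heights.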

age

strk |> 
  ggplot(aes(x = age_group, fill = stroke)) +
  geom_bar(position = "fill")

Observation: A clear trend shows that as age increases, the proportion of stroke cases also increases; therefore, age appears to play a significant role in stroke prediction.

hypertension

strk |> 
  ggplot(aes(x = hypertension, fill = stroke)) +
  geom_bar(position = "fill")

Observation: Stroke cases are proportionally more common in the hypertension group; therefore, hypertension appears to play a significant role in stroke prediction. This is also supported by the WSO reporting that elevated blood pressure is a leading risk factor for stroke.

heart_disease

strk |> 
  ggplot(aes(x = heart_disease, fill = stroke)) +
  geom_bar(position = "fill")

Observation: Stroke cases are proportionally more common in the heart disease group; therefore, heart disease appears to play a significant role in stroke prediction.

ever_married

strk |> 
  ggplot(aes(x = ever_married, fill = stroke)) +
  geom_bar(position = "fill")

strk |> 
  ggplot(aes(x = ever_married, fill = age_group)) +
  geom_bar(position = "fill") +
  facet_wrap(~stroke)

Observation: The proportion of stroke cases is similar in the “No” and “Yes” groups, though somewhat higher in the “Yes” group. This is likely because the “No” group contains more younger individuals, whereas the more at-risk older individuals fall mostly in the “Yes” group, as the second plot shows. Marital status therefore does not appear to play a significant role in stroke prediction beyond its association with age.

work_type

strk |> 
  ggplot(aes(x = work_type, fill = stroke)) +
  geom_bar(position = "fill")

Observation: Stroke cases are proportionally more common in the employed groups, and the self-employed group has the highest proportion. As with ever_married, work_type is likely related to age (the “children” category is an obvious example). To avoid redundancy with the age feature, work_type will be left out of the model.

Residence_type

strk |> 
  ggplot(aes(x = Residence_type, fill = stroke)) +
  geom_bar(position = "fill")

Observation: The proportion of stroke cases is about the same for rural and urban residents; therefore, residence type does not appear to play a significant role in stroke prediction.

avg_glucose_level

strk |> 
  ggplot(aes(x = glucose_group, fill = stroke)) +
  geom_bar(position = "fill")

Observation: A clear trend shows that as glucose levels increase, the proportion of stroke cases also increases; therefore, glucose level appears to play a significant role in stroke prediction. This is also supported by the WSO reporting that high fasting glucose is a leading risk factor for stroke.

bmi

strk |> 
  filter(!is.na(bmi_group)) |> 
  ggplot(aes(x = bmi_group, fill = stroke)) +
    geom_bar(position = "fill")

strk |> 
  ggplot(aes(x = bmi_group, fill = stroke)) +
    geom_bar(position = "fill")

Observation: Stroke cases are proportionally most common among individuals in the higher weight classes, with the highest proportions in the overweight and class I obesity groups. BMI likely plays a role in stroke prediction, which is supported by the WSO reporting that high BMI is a leading risk factor for stroke. Note that when the missing values are also graphed, the missing-BMI group has by far the highest proportion of stroke cases.

smoking_status

strk |> 
  ggplot(aes(x = smoking_status, fill = stroke)) +
  geom_bar(position = "fill")

Observation: The “never smoked” and “smokes” groups show similar proportions of stroke cases, while the “formerly smoked” group has the highest proportion. The WSO reports smoking as a leading risk factor, but smoking status will be omitted to simplify the model. In this data, current smokers and never-smokers show similar stroke rates, and the column contains a large amount of missing data: the “Unknown” entries account for most of the 1,745 missing values counted earlier.

Feature Selection

Based on the data exploration above, age, BMI, glucose level, hypertension, and heart disease appear to have the strongest association with stroke. For this reason, these columns (using the binned versions of age, BMI, and glucose level) will be selected as features for the model.

clean_strk <- strk |> 
  select(age_group, 
         bmi_group,
         glucose_group,
         hypertension, 
         heart_disease,
         stroke)

Building the prediction model

Splitting the data

One of the most important steps in the machine learning process is splitting the data into a training set and a test set. The model is built and trained on the training set and then evaluated on the test set, which it has not seen, to estimate how well it generalizes to new data.

It is common practice to allocate 80% of the data to training and 20% of the data to testing.

# set random seed for replication
set.seed(42)

# split the data into training (80%) and testing (20%)
strk_split <- initial_split(clean_strk, 
                            prop = 4/5)
strk_split
## <Training/Testing/Total>
## <4088/1022/5110>
# extracting training and test sets
train <- training(strk_split)
test <- testing(strk_split)
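The idea behind initial_split() can be sketched in base R: shuffle the row indices and send 80% of them to training, which reproduces the 4088/1022 sizes reported above. This sketch will not pick the same rows as rsample. Like the call above, it also makes no attempt to preserve the stroke proportion; initial_split() accepts a strata argument for that, just as vfold_cv() does for the folds.

```r
set.seed(42)
n <- 5110                     # rows in the stroke data
n_train <- floor(n * 4 / 5)   # 80% of the rows go to training
train_idx <- sample(seq_len(n), n_train)
test_idx <- setdiff(seq_len(n), train_idx)

c(train = length(train_idx), test = length(test_idx))
```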

Cross fold analysis

Cross-fold analysis, usually called k-fold cross-validation, is integral to fine-tuning the machine learning process because it estimates how well the model will generalize to new data without touching the test set. It works by splitting the training data into k folds (vfold_cv() uses 10 by default). In each of k iterations, the model is trained on k - 1 folds and evaluated on the remaining fold, so that every fold serves once as the validation set. The k scores are then averaged. Setting strata = stroke keeps the proportion of stroke cases similar across the folds.

The folds created below are used both to compare the candidate models and to fine-tune the selected one.

# set random seed for reproducibility
set.seed(42)

# create a CV object
stroke_folds <- vfold_cv(train, strata = stroke)
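A base R sketch of stratified fold assignment follows. stratified_folds() is a hypothetical helper, and rsample’s own implementation differs in its details. The class counts (3883 patients without a stroke and 205 with one) are read off the pooled cross-validation confusion matrices reported later. Each training row is assessed exactly once across the folds, so those pooled counts equal the training-set counts.

```r
# Hypothetical helper: assign each observation to one of v folds, shuffling
# within each class so that every fold receives a similar class mix
stratified_folds <- function(y, v = 10) {
  fold <- integer(length(y))
  for (cls in unique(y)) {
    idx <- which(y == cls)
    # deal this class's rows evenly across folds 1..v, in random order
    fold[idx] <- sample(rep_len(seq_len(v), length(idx)))
  }
  fold
}

set.seed(42)
y <- c(rep(0, 3883), rep(1, 205))  # training-set class counts
folds <- stratified_folds(y, v = 10)

# Each fold is held out once for validation while the other nine train the model
for (k in 1:3) {  # only the first three iterations, to keep the output short
  cat("fold", k, "- train:", sum(folds != k), "rows, validate:", sum(folds == k),
      "rows, strokes in validation:", sum(y[folds == k] == 1), "\n")
}
```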

Recipe

To start building the model, a recipe is created that does the following:

  • specifies the outcome and predictor variables
  • imputes missing values using the k-nearest neighbors method
  • converts all unordered factor predictors into binary dummy variables
  • converts the ordered factor predictors into numeric scores
  • synthetically oversamples the minority class (positive stroke cases) with SMOTE to resolve the class imbalance.
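The imputation step can be illustrated with a simplified base R sketch. recipes’ step_impute_knn() uses 5 neighbors by default and Gower distance, which can handle factor columns. Numeric values are imputed with the neighbors’ mean and factors with their most common value. In this recipe, the only column with missing values is the binned bmi_group factor. The hypothetical helper below replaces Gower distance with plain Euclidean distance on numeric columns and imputes a numeric column with the neighbors’ mean. In practice the predictors would also be scaled so that no single variable dominates the distance.

```r
# Hypothetical helper: fill missing values of one numeric column with the mean
# of that column over the k nearest complete rows (Euclidean distance on the
# given numeric predictor columns, which must themselves have no NA)
impute_knn_column <- function(df, target, predictors, k = 5) {
  x <- as.matrix(df[predictors])
  missing_rows <- which(is.na(df[[target]]))
  donor_rows <- which(!is.na(df[[target]]))
  for (i in missing_rows) {
    # distance from row i to every donor row
    d <- sqrt(colSums((t(x[donor_rows, , drop = FALSE]) - x[i, ])^2))
    nearest <- donor_rows[order(d)[seq_len(min(k, length(donor_rows)))]]
    df[[target]][i] <- mean(df[[target]][nearest])
  }
  df
}

# Made-up example: bmi is missing for the last patient
toy <- data.frame(age     = c(60, 62, 20, 22, 61),
                  glucose = c(200, 210, 80, 85, 205),
                  bmi     = c(30, 32, 20, 22, NA))

impute_knn_column(toy, target = "bmi", predictors = c("age", "glucose"), k = 2)
```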

The metrics used to evaluate model performance are also defined here.

# recipe creation
stroke_recipe <- recipe(stroke ~ 
                          age_group + 
                          bmi_group + 
                          glucose_group + 
                          hypertension + 
                          heart_disease, 
                        data = clean_strk) |> 
  step_impute_knn(all_predictors()) |> 
  step_dummy(all_unordered_predictors()) |> 
  step_ordinalscore(all_ordered_predictors()) |> 
  step_smote(stroke)

# specify the desired metrics (note: yardstick treats the first factor level, "0", as the event by default)
stroke_metrics <- metric_set(roc_auc, accuracy, sensitivity, specificity)
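step_smote() comes last because SMOTE needs numeric predictors with no missing values. That is why imputation, dummy coding and ordinal scoring come before it. By default themis uses 5 neighbors and oversamples the minority class until it matches the majority (over_ratio = 1). The step is also skipped when the recipe is applied to new data, so assessment folds and the test set are not oversampled. The core of the technique is sketched in base R below; smote_sketch() is a hypothetical helper, not the themis implementation.

```r
# Hypothetical helper: generate n_new synthetic minority rows by interpolating
# between a randomly chosen minority row and one of its k nearest minority
# neighbors (the core idea of SMOTE)
smote_sketch <- function(x_min, n_new, k = 5) {
  x_min <- as.matrix(x_min)
  n <- nrow(x_min)
  k <- min(k, n - 1)
  d <- as.matrix(dist(x_min))  # pairwise Euclidean distances
  diag(d) <- Inf               # a row is never its own neighbor
  synthetic <- matrix(NA_real_, nrow = n_new, ncol = ncol(x_min),
                      dimnames = list(NULL, colnames(x_min)))
  for (s in seq_len(n_new)) {
    i <- sample.int(n, 1)                 # random minority row
    j <- order(d[i, ])[sample.int(k, 1)]  # one of its k nearest neighbors
    gap <- runif(1)                       # random point along the segment
    synthetic[s, ] <- x_min[i, ] + gap * (x_min[j, ] - x_min[i, ])
  }
  synthetic
}

set.seed(42)
# Made-up minority-class rows with two numeric predictors
minority <- cbind(age_score = c(4, 5, 5, 6), glucose_score = c(2, 3, 1, 3))
smote_sketch(minority, n_new = 6, k = 2)
```

Because interpolation produces values between existing rows, synthetic patients can end up with fractional ordinal scores or dummy values that fall between two original categories.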

Workflow

Next, the recipe is added to a workflow so that different models can be attached to the same preprocessing steps.

# add recipe to workflow
stroke_workflow <- workflow() |> 
  add_recipe(stroke_recipe)

Model Creation and Evaluation

The classification models used for this task are logistic regression, random forest, k-nearest neighbors, and XGBoost.

Logistic Regression

Model Creation

This step assigns to an object:

  • the specific model for computation
  • the engine R will run the model through
  • the specific machine learning task (classification) that the model will perform

# logistic regression model
lr_model <- 
  logistic_reg() |> 
  set_engine("glm") |> 
  set_mode("classification")
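With set_engine("glm"), parsnip fits the logistic regression with stats::glm() and a binomial family. The base R sketch below, on made-up data, shows the two pieces involved. First, glm() models the probability of the second factor level (“1”). Second, a hard class prediction comes from thresholding that probability at 0.5, which is essentially how a two-class .pred_class is derived.

```r
# Made-up data: one numeric predictor and a 0/1 outcome (not perfectly separable)
toy <- data.frame(x = 1:10,
                  y = factor(c(0, 0, 0, 1, 0, 1, 0, 1, 1, 1), levels = c(0, 1)))

# Logistic regression; for a factor outcome, glm() models P(y = second level)
fit <- glm(y ~ x, data = toy, family = binomial)

prob_1 <- predict(fit, newdata = data.frame(x = c(2, 9)), type = "response")
pred_class <- ifelse(prob_1 >= 0.5, "1", "0")  # 0.5 threshold for the class label

coef(fit)
data.frame(x = c(2, 9), prob_1, pred_class)
```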

Fitting the training data

This step adds the classification model created above to the previously created workflow. The combined workflow is then fit and evaluated on each cross-validation fold, using the chosen metrics (accuracy, sensitivity, specificity, roc_auc).

# fit the logistic regression workflow on each cross-validation fold
lr_rs <- stroke_workflow |> 
  add_model(lr_model) |> 
  fit_resamples(
    resamples = stroke_folds,
    metrics = stroke_metrics,
    control = control_resamples(save_pred = TRUE)
  )

Evaluating the training data

This step collects the performance metrics from the fit conducted above.

lr_metric <- collect_metrics(lr_rs)
lr_metric
## # A tibble: 4 × 6
##   .metric     .estimator  mean     n std_err .config             
##   <chr>       <chr>      <dbl> <int>   <dbl> <chr>               
## 1 accuracy    binary     0.777    10 0.00960 Preprocessor1_Model1
## 2 roc_auc     binary     0.836    10 0.0137  Preprocessor1_Model1
## 3 sensitivity binary     0.780    10 0.0103  Preprocessor1_Model1
## 4 specificity binary     0.712    10 0.0441  Preprocessor1_Model1
collect_predictions(lr_rs) |> 
  conf_mat(truth = stroke, estimate = .pred_class)
##           Truth
## Prediction    0    1
##          0 3031   58
##          1  852  147
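collect_metrics() summarizes each metric across the 10 folds. mean is the average of the per-fold estimates, n is the number of folds, and std_err is the standard error of that mean (the standard deviation of the fold estimates divided by the square root of n). A base R sketch with made-up fold estimates follows.

```r
# Made-up per-fold estimates for one metric (not results from this analysis)
fold_estimates <- c(0.76, 0.79, 0.78, 0.77, 0.80, 0.75, 0.78, 0.79, 0.77, 0.78)

summary_row <- c(mean    = mean(fold_estimates),
                 n       = length(fold_estimates),
                 std_err = sd(fold_estimates) / sqrt(length(fold_estimates)))
summary_row
```

The confusion matrix from collect_predictions() instead pools the predictions from all folds, so rates computed from it can differ slightly from the fold averages in the table.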

Observation: The logistic regression model yielded:

  • an average accuracy of 0.7773845
  • an average roc_auc of 0.8361093
  • an average sensitivity of 0.7804192
  • an average specificity of 0.7121668

Random forest

Model Creation

This step assigns to an object:

  • the specific model for computation
  • the engine R will run the model through
  • the specific machine learning task (classification) that the model will perform

# random forest model
rf_model <-
  rand_forest() |> 
  set_engine("randomForest") |> 
  set_mode("classification")

Fitting the training data

This step adds the classification model created above to the previously created workflow. The combined workflow is then fit and evaluated on each cross-validation fold, using the chosen metrics (accuracy, sensitivity, specificity, roc_auc).

# fit the random forest workflow on each cross-validation fold
rf_rs <- stroke_workflow |> 
  add_model(rf_model) |> 
  fit_resamples(
    resamples = stroke_folds,
    metrics = stroke_metrics,
    control = control_resamples(save_pred = TRUE)
  )

Evaluating the training data

This step collects the performance metrics from the fit conducted above.

rf_metric <- collect_metrics(rf_rs)
rf_metric
## # A tibble: 4 × 6
##   .metric     .estimator  mean     n std_err .config             
##   <chr>       <chr>      <dbl> <int>   <dbl> <chr>               
## 1 accuracy    binary     0.849    10 0.00927 Preprocessor1_Model1
## 2 roc_auc     binary     0.793    10 0.0167  Preprocessor1_Model1
## 3 sensitivity binary     0.870    10 0.00815 Preprocessor1_Model1
## 4 specificity binary     0.458    10 0.0468  Preprocessor1_Model1
collect_predictions(rf_rs) |> 
  conf_mat(truth = stroke, estimate = .pred_class)
##           Truth
## Prediction    0    1
##          0 3379  112
##          1  504   93

Observation: The random forest model yielded:

  • an average accuracy of 0.8493079
  • an average roc_auc of 0.7926797
  • an average sensitivity of 0.8700754
  • an average specificity of 0.4584445

K-nearest neighbor

Model Creation

This step assigns to an object:

  • the specific model for computation
  • the engine R will run the model through
  • the specific machine learning task (classification) that the model will perform

# k-nearest neighbors model
kn_model <- nearest_neighbor() |>
  set_engine("kknn") |> 
  set_mode("classification")
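The kknn engine classifies a patient by a distance-weighted vote among the nearest training rows. A plain, unweighted majority-vote version is sketched below in base R on made-up data. knn_predict() is a hypothetical helper, not part of kknn or parsnip.

```r
# Hypothetical helper: unweighted k-nearest-neighbor classification by
# majority vote, using Euclidean distance
knn_predict <- function(train_x, train_y, new_x, k = 5) {
  train_x <- as.matrix(train_x)
  new_x <- as.matrix(new_x)
  apply(new_x, 1, function(row) {
    d <- sqrt(colSums((t(train_x) - row)^2))       # distance to each training row
    votes <- table(train_y[order(d)[seq_len(k)]])  # class counts among the k nearest
    names(votes)[which.max(votes)]                 # ties go to the first level
  })
}

# Made-up training data: one predictor, two well-separated classes
train_x <- cbind(score = c(0, 0, 1, 5, 5, 6))
train_y <- factor(c(0, 0, 0, 1, 1, 1), levels = c(0, 1))

knn_predict(train_x, train_y, new_x = cbind(score = c(0.5, 5.5)), k = 3)
```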

Fitting the training data

This step adds the classification model created above to the previously created workflow. The combined workflow is then fit and evaluated on each cross-validation fold, using the chosen metrics (accuracy, sensitivity, specificity, roc_auc).

# fit the k-nearest neighbors workflow on each cross-validation fold
kn_rs <- stroke_workflow |> 
  add_model(kn_model) |> 
  fit_resamples(
    resamples = stroke_folds,
    metrics = stroke_metrics,
    control = control_resamples(save_pred = TRUE)
  )

Evaluating the training data

This step collects the performance metrics from the fit conducted above.

kn_metric <- collect_metrics(kn_rs)
kn_metric
## # A tibble: 4 × 6
##   .metric     .estimator   mean     n std_err .config             
##   <chr>       <chr>       <dbl> <int>   <dbl> <chr>               
## 1 accuracy    binary     0.934     10 0.00344 Preprocessor1_Model1
## 2 roc_auc     binary     0.678     10 0.0224  Preprocessor1_Model1
## 3 sensitivity binary     0.978     10 0.00195 Preprocessor1_Model1
## 4 specificity binary     0.0884    10 0.0195  Preprocessor1_Model1
collect_predictions(kn_rs) |> 
  conf_mat(truth = stroke, estimate = .pred_class)
##           Truth
## Prediction    0    1
##          0 3799  186
##          1   84   19

Observation: The k-nearest neighbors model yielded:

  • an average accuracy of 0.9339524
  • an average roc_auc of 0.6779787
  • an average sensitivity of 0.978343
  • an average specificity of 0.0883937

XGBoost

Model Creation

This step assigns to an object:

  • the specific model for computation
  • the engine R will run the model through
  • the specific machine learning task (classification) that the model will perform

# XGBoost model
xg_model <- boost_tree() |>
  set_engine("xgboost") |> 
  set_mode("classification")

Fitting the training data

This step adds the classification model created above to the previously created workflow. The combined workflow is then fit and evaluated on each cross-validation fold, using the chosen metrics (accuracy, sensitivity, specificity, roc_auc).

# fit the XGBoost workflow on each cross-validation fold
xg_rs <- stroke_workflow |> 
  add_model(xg_model) |> 
  fit_resamples(
    resamples = stroke_folds,
    metrics = stroke_metrics,
    control = control_resamples(save_pred = TRUE)
  )

Evaluating the training data

This step collects the performance metrics from the fit conducted above.

xg_metric <- collect_metrics(xg_rs)
xg_metric
## # A tibble: 4 × 6
##   .metric     .estimator  mean     n std_err .config             
##   <chr>       <chr>      <dbl> <int>   <dbl> <chr>               
## 1 accuracy    binary     0.841    10 0.0102  Preprocessor1_Model1
## 2 roc_auc     binary     0.783    10 0.0199  Preprocessor1_Model1
## 3 sensitivity binary     0.860    10 0.00883 Preprocessor1_Model1
## 4 specificity binary     0.492    10 0.0480  Preprocessor1_Model1
collect_predictions(xg_rs) |> 
  conf_mat(truth = stroke, estimate = .pred_class)
##           Truth
## Prediction    0    1
##          0 3340  105
##          1  543  100

Observation: The XGBoost model yielded:

  • an average accuracy of 0.8414779
  • an average roc_auc of 0.7831615
  • an average sensitivity of 0.8599866
  • an average specificity of 0.4916759

Model Selection

Based on the captured metrics, the random forest model seems to be the best overall. It had the second-highest accuracy, roc_auc, and sensitivity, as well as the third-highest specificity.

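In the context of disease detection and prevention, the priority is to limit false negatives (stroke patients predicted to be stroke-free) while keeping false positives reasonable. Note that yardstick treats the first factor level as the event by default, and stroke has the levels “0” and “1”. The metric reported as sensitivity above therefore measures how often stroke-free patients are correctly identified, and the metric reported as specificity measures how often stroke patients are correctly identified. This is why the K-nearest neighbors model was not chosen. It had the highest accuracy and reported sensitivity, but its reported specificity was poor. Its confusion matrix shows that it identified only 19 of the 205 stroke cases in the resamples, so it produced a large number of false negatives.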

Model Fine-Tuning

Now that the random forest model has been chosen, its parameters will be adjusted to optimize its performance. The parameter being adjusted is mtry (the number of predictors randomly sampled as split candidates at each split).

Stating the desired parameters

rf_grid <- expand.grid(mtry = c(3, 4, 5))
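After the recipe runs, the model sees five predictors: the three ordinal scores, plus one dummy column each for hypertension and heart disease. mtry controls how many of these are randomly offered as split candidates at each split, so the grid spans 3, 4, or all 5 of them. With mtry = 5 every split can use any predictor. This removes the random feature selection and makes the forest equivalent to bagged trees. The base R snippet below only illustrates the sampling and is not how ranger works internally. The dummy column names assume recipes’ default naming.

```r
# Predictor columns after the recipe's dummy and ordinal-score steps
# (dummy names assume recipes' default naming convention)
predictors <- c("age_group", "bmi_group", "glucose_group",
                "hypertension_X1", "heart_disease_X1")

set.seed(42)
# At every split, a random forest considers only mtry randomly chosen predictors
for (mtry in c(3, 4, 5)) {
  cat("mtry =", mtry, ":", sample(predictors, mtry), "\n")
}
```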

Random Forest Model with flexible tuning

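A new random forest model is created in which mtry is marked for tuning with tune(). Note that this specification uses the ranger engine, whereas the random forest compared above used the randomForest engine.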

rf_model_2 <- rand_forest() |> 
  set_args(mtry = tune()) |> 
  set_engine("ranger") |> 
  set_mode("classification")

Workflow

The newly created model will now be added to the workflow.

rf_workflow_2 <- workflow() |> 
  add_recipe(stroke_recipe) |> 
  add_model(rf_model_2) 

Parameter Evaluation

The workflow was then tuned across the candidate mtry values using the cross-validation folds, the metrics were collected for each value, and the value that yielded the highest sensitivity score was selected.

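# tune mtry over the candidate grid using the cross-validation folds
rf_tune <- rf_workflow_2 |> 
  tune_grid(resamples = stroke_folds,
            grid = rf_grid,
            metrics = stroke_metrics)
# evaluation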
rf_tune |> 
  collect_metrics()
## # A tibble: 12 × 7
##     mtry .metric     .estimator  mean     n std_err .config             
##    <dbl> <chr>       <chr>      <dbl> <int>   <dbl> <chr>               
##  1     3 accuracy    binary     0.862    10 0.00863 Preprocessor1_Model1
##  2     3 roc_auc     binary     0.751    10 0.0220  Preprocessor1_Model1
##  3     3 sensitivity binary     0.887    10 0.00794 Preprocessor1_Model1
##  4     3 specificity binary     0.388    10 0.0420  Preprocessor1_Model1
##  5     4 accuracy    binary     0.864    10 0.00879 Preprocessor1_Model2
##  6     4 roc_auc     binary     0.736    10 0.0245  Preprocessor1_Model2
##  7     4 sensitivity binary     0.891    10 0.00783 Preprocessor1_Model2
##  8     4 specificity binary     0.370    10 0.0420  Preprocessor1_Model2
##  9     5 accuracy    binary     0.864    10 0.00883 Preprocessor1_Model3
## 10     5 roc_auc     binary     0.713    10 0.0280  Preprocessor1_Model3
## 11     5 sensitivity binary     0.890    10 0.00743 Preprocessor1_Model3
## 12     5 specificity binary     0.365    10 0.0467  Preprocessor1_Model3
# selection
param_final <- rf_tune |> 
  select_best(metric = "sensitivity")
param_final
## # A tibble: 1 × 2
##    mtry .config             
##   <dbl> <chr>               
## 1     4 Preprocessor1_Model2
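For a metric where larger is better, selecting the best candidate amounts to keeping that metric's rows and taking the one with the highest mean. The base-R sketch below mimics that step on the values printed above; pick_best() is a hypothetical helper, and tune's own select_best() does more (for example, it also handles metrics where smaller is better).

```r
# Cross-validated means copied from the collect_metrics() output above
tune_results <- data.frame(
  mtry    = rep(c(3, 4, 5), each = 4),
  .metric = rep(c("accuracy", "roc_auc", "sensitivity", "specificity"), times = 3),
  mean    = c(0.862, 0.751, 0.887, 0.388,
              0.864, 0.736, 0.891, 0.370,
              0.864, 0.713, 0.890, 0.365)
)

# Hypothetical helper: row with the highest mean for one larger-is-better metric
pick_best <- function(results, metric) {
  rows <- results[results$.metric == metric, ]
  rows[which.max(rows$mean), ]
}

pick_best(tune_results, "sensitivity")$mtry  # 4
pick_best(tune_results, "roc_auc")$mtry      # 3
```

The sensitivity means for mtry = 3, 4 and 5 (0.887, 0.891, 0.890) differ by less than their standard errors (about 0.008), and ROC AUC would have favored mtry = 3, so mtry = 4 is chosen by a very small margin.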

Finalized Workflow

The selected parameter was then used to finalize the workflow.

rf_workflow_tuned <- finalize_workflow(rf_workflow_2, param_final)

rf_workflow_tuned
## ══ Workflow ════════════════════════════════════════════════════════════════════
## Preprocessor: Recipe
## Model: rand_forest()
## 
## ── Preprocessor ────────────────────────────────────────────────────────────────
## 4 Recipe Steps
## 
## • step_impute_knn()
## • step_dummy()
## • step_ordinalscore()
## • step_smote()
## 
## ── Model ───────────────────────────────────────────────────────────────────────
## Random Forest Model Specification (classification)
## 
## Main Arguments:
##   mtry = 4
## 
## Computational engine: ranger

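Final fit and test evaluation

The finalized workflow was passed to last_fit(), which fits it on the training portion of strk_split and evaluates it on the held-out testing portion using the previously defined metric set, stroke_metrics.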

rf_fit <- rf_workflow_tuned |> 
  last_fit(split = strk_split,
           metrics = stroke_metrics)

Evaluate Testing

The metrics were collected for final assessment.

rf_fit_metric <- collect_metrics(rf_fit)
rf_fit_metric
## # A tibble: 4 × 4
##   .metric     .estimator .estimate .config             
##   <chr>       <chr>          <dbl> <chr>               
## 1 accuracy    binary         0.846 Preprocessor1_Model1
## 2 sensitivity binary         0.867 Preprocessor1_Model1
## 3 specificity binary         0.386 Preprocessor1_Model1
## 4 roc_auc     binary         0.729 Preprocessor1_Model1
collect_predictions(rf_fit) |> 
  conf_mat(stroke, .pred_class)
##           Truth
## Prediction   0   1
##          0 848  27
##          1 130  17

Observation: On the testing data, the tuned random forest model yielded:

  • an accuracy of 0.8463796
  • a sensitivity of 0.8670757
  • a specificity of 0.3863636
  • an ROC AUC of 0.7292364
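The test-set metrics can be reproduced by hand from the confusion matrix. The base-R sketch below does so and also shows the effect of yardstick's default event level, which is the first factor level ("0", no stroke) unless event_level = "second" is set. rate() is a hypothetical helper written for this illustration.

```r
# Test-set confusion matrix from conf_mat() above (rows = prediction, columns = truth)
cm <- matrix(c(848,  27,
               130,  17),
             nrow = 2, byrow = TRUE,
             dimnames = list(Prediction = c("0", "1"), Truth = c("0", "1")))

# Hypothetical helper: accuracy, sensitivity and specificity for a chosen event class
rate <- function(cm, event) {
  other <- setdiff(colnames(cm), event)
  c(accuracy    = sum(diag(cm)) / sum(cm),
    sensitivity = cm[event, event] / sum(cm[, event]),  # true events caught
    specificity = cm[other, other] / sum(cm[, other]))  # true non-events caught
}

round(rate(cm, event = "0"), 4)  # accuracy 0.8464, sensitivity 0.8671, specificity 0.3864
round(rate(cm, event = "1"), 4)  # accuracy 0.8464, sensitivity 0.3864, specificity 0.8671
```

The first call reproduces the reported values, so the 0.867 sensitivity measures how well patients without a stroke are recognized; the share of actual stroke patients detected is 17 of 44, or 38.6%.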

Final Model

The tuned workflow was then refit on the whole cleaned data set (clean_strk) so that it can be used to predict outcomes for new patients.

final_model <- fit(rf_workflow_tuned, clean_strk)
final_model
## ══ Workflow [trained] ══════════════════════════════════════════════════════════
## Preprocessor: Recipe
## Model: rand_forest()
## 
## ── Preprocessor ────────────────────────────────────────────────────────────────
## 4 Recipe Steps
## 
## • step_impute_knn()
## • step_dummy()
## • step_ordinalscore()
## • step_smote()
## 
## ── Model ───────────────────────────────────────────────────────────────────────
## Ranger result
## 
## Call:
##  ranger::ranger(x = maybe_data_frame(x), y = y, mtry = min_cols(~4,      x), num.threads = 1, verbose = FALSE, seed = sample.int(10^5,      1), probability = TRUE) 
## 
## Type:                             Probability estimation 
## Number of trees:                  500 
## Sample size:                      9722 
## Number of independent variables:  5 
## Mtry:                             4 
## Target node size:                 10 
## Variable importance mode:         none 
## Splitrule:                        gini 
## OOB prediction error (Brier s.):  0.1060317
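For a probability forest, ranger reports the out-of-bag error as a Brier score: the mean squared difference between the predicted probability and the observed 0/1 outcome. A minimal base-R version, with brier_score() as a hypothetical helper and toy inputs, is:

```r
# Hypothetical helper: binary Brier score (lower is better, 0 is perfect)
brier_score <- function(p_event, outcome) {
  stopifnot(length(p_event) == length(outcome), all(outcome %in% c(0, 1)))
  mean((p_event - outcome)^2)
}

# Toy example: two good forecasts and one poor one
brier_score(c(0.9, 0.2, 0.6), c(1, 0, 0))

# Reference point: always predicting 0.5 scores 0.25 on any 0/1 outcomes
brier_score(rep(0.5, 4), c(0, 1, 0, 1))  # 0.25
```

Because the recipe ends with step_smote(), the 9722 training rows include synthetic stroke cases, so the OOB Brier score of 0.106 describes the rebalanced data rather than the original class mix. Assuming SMOTE balanced the classes (its default behavior), the 0.25 of a constant 0.5 forecast is a natural baseline for comparison.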

Prediction based on new data

The final model was used to predict whether a patient will have a stroke based on new data.

# new data for fictional patient
new_patient <- tribble(~age_group, ~bmi_group, ~glucose_group, ~hypertension, ~heart_disease,
                     "76+", "overweight", "Prediabetic", "1", "0")

# stroke prediction
prediction <- predict(final_model, new_data = new_patient)
prediction
## # A tibble: 1 × 1
##   .pred_class
##   <fct>      
## 1 1

Observation: For the fictional patient, the model predicted class 1, where:

  • 1 = stroke
  • 0 = no stroke
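Because ranger was fit as a probability forest (probability = TRUE in the call above), the single 0/1 label is derived from predicted class probabilities. The base-R sketch below assumes the label is the level with the highest probability; probs_to_class() and the toy probabilities are hypothetical.

```r
# Hypothetical helper: turn a matrix of class probabilities into class labels,
# assuming the label is the level with the highest probability
probs_to_class <- function(probs) {
  idx <- max.col(probs, ties.method = "first")
  factor(colnames(probs)[idx], levels = colnames(probs))
}

# Toy probabilities for two patients, columns named after the stroke levels
probs <- matrix(c(0.35, 0.65,
                  0.80, 0.20),
                nrow = 2, byrow = TRUE,
                dimnames = list(NULL, c("0", "1")))
probs_to_class(probs)  # first patient "1", second patient "0"
```

Calling predict(final_model, new_data = new_patient, type = "prob") returns the .pred_0 and .pred_1 columns, which show how close the fictional patient is to that cut-off, something a single label hides.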

Findings and Conclusions

The purpose of this machine learning project was to implement a model that could predict whether or not a patient is at risk of having a stroke based on features relating to the patient’s medical history and demographics. The model created to serve this purpose was a random forest classifier that, on the testing data, had an accuracy of 0.8463796, a sensitivity of 0.8670757, a specificity of 0.3863636, and an ROC AUC of 0.7292364.

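  • an accuracy of 0.8463796 means the model’s prediction was correct for 84.6% of the test patients (865 of 1022 in the confusion matrix). Because most test patients did not have a stroke, this figure is dominated by the no-stroke class.

  • a specificity of 0.3863636: yardstick treats the first factor level ("0", no stroke) as the event by default, so specificity here is the share of actual stroke patients that the model predicted as having a stroke, 17 of 44 (38.6%).

  • a sensitivity of 0.8670757 is, under the same default, the share of patients without a stroke that the model correctly predicted as having no stroke, 848 of 978 (86.7%).

  • an ROC AUC of 0.7292364 measures how well the predicted probabilities separate patients with and without a stroke: it is the probability that a randomly chosen stroke patient receives a higher stroke score than a randomly chosen patient without a stroke. It ranges from 0 to 1, where 0.5 corresponds to random guessing and 1 to perfect separation.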
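The ROC AUC can be computed without drawing the curve: it equals the fraction of (stroke, no-stroke) pairs in which the stroke patient receives the higher stroke score, with ties counted as half. A base-R sketch with a hypothetical auc_rank() helper and toy scores:

```r
# Hypothetical helper: ROC AUC via pairwise comparison (Mann-Whitney form)
auc_rank <- function(score, truth) {
  pos <- score[truth == 1]  # scores of actual positives (stroke)
  neg <- score[truth == 0]  # scores of actual negatives (no stroke)
  cmp <- outer(pos, neg, "-")  # every positive-negative pair
  mean((cmp > 0) + 0.5 * (cmp == 0))  # correct order = 1, tie = 0.5
}

auc_rank(c(0.9, 0.8, 0.3, 0.2), c(1, 1, 0, 0))  # 1: perfect separation
auc_rank(c(0.9, 0.3, 0.8, 0.2), c(1, 1, 0, 0))  # 0.75: one pair out of order
auc_rank(rep(0.5, 4), c(1, 0, 1, 0))            # 0.5: no discrimination
```

The pairwise form uses memory proportional to the number of positive-negative pairs, which is fine for a test set of about a thousand patients.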

Given all these metrics, the most important one for this application is how well the model detects patients who go on to have a stroke, since a missed stroke patient is the costliest error. The sensitivity of 0.8670757 appears strong, but because it was computed with "0" as the event class it describes patients without a stroke, not patients at risk of one. The share of actual stroke patients the model detected is the specificity, 0.3863636 (17 of 44), and the 84.6% accuracy is carried mainly by the much larger no-stroke group. The model therefore separates the classes better than chance (ROC AUC of 0.729), but its detection rate for stroke patients should be improved, for example by tuning with "1" set as the event class, before it is relied on for stroke prediction.