According to the World Health Organization (WHO), stroke is the second leading cause of death globally, responsible for approximately 11% of total deaths. This amounts to roughly 6.5 million deaths annually.
It is also estimated that there are over 12.2 million new cases of stroke each year and that one in four people over 25 years of age will experience a stroke in their lifetime.
The ten leading risk factors that increase the probability of a stroke, and their percentage weights, are the following:
Statistic source: World Stroke Organization
This project aims to build and deploy a model that predicts whether or not a patient will have a stroke based on a variety of patient data, encompassing the patient’s medical history and demographic information. Deploying such a model could support preventive care by focusing medical attention on high-risk patients.
The dataset used to build the prediction model comes from the open data website Kaggle. Below is a list of the attributes as they appear in the dataset, with their descriptions.
This model will use classification to predict the outcome variable (stroke). Classification was chosen because it deals with the prediction of categorical data (in this case, the label is whether or not a patient will have a stroke). The model is also supervised, because the outcome variable is pre-labeled: its predictions can be validated against the recorded stroke cases.
Since classification is being used, the model will be evaluated with the following metrics:
Below are the packages loaded for this machine learning project. They are listed in order of use and grouped by the step of the process in which they are used.
# Packages needed for data manipulation, visualization, model building, and model evaluation
library(tidyverse)
## ── Attaching core tidyverse packages ──────────────────────── tidyverse 2.0.0 ──
## ✔ dplyr 1.1.4 ✔ readr 2.1.5
## ✔ forcats 1.0.0 ✔ stringr 1.5.1
## ✔ ggplot2 3.5.1 ✔ tibble 3.2.1
## ✔ lubridate 1.9.3 ✔ tidyr 1.3.1
## ✔ purrr 1.0.2
## ── Conflicts ────────────────────────────────────────── tidyverse_conflicts() ──
## ✖ dplyr::filter() masks stats::filter()
## ✖ dplyr::lag() masks stats::lag()
## ℹ Use the conflicted package (<http://conflicted.r-lib.org/>) to force all conflicts to become errors
library(tidymodels)
## ── Attaching packages ────────────────────────────────────── tidymodels 1.2.0 ──
## ✔ broom 1.0.5 ✔ rsample 1.2.1
## ✔ dials 1.2.1 ✔ tune 1.2.1
## ✔ infer 1.0.7 ✔ workflows 1.1.4
## ✔ modeldata 1.3.0 ✔ workflowsets 1.1.0
## ✔ parsnip 1.2.1 ✔ yardstick 1.3.1
## ✔ recipes 1.0.10
## ── Conflicts ───────────────────────────────────────── tidymodels_conflicts() ──
## ✖ scales::discard() masks purrr::discard()
## ✖ dplyr::filter() masks stats::filter()
## ✖ recipes::fixed() masks stringr::fixed()
## ✖ dplyr::lag() masks stats::lag()
## ✖ yardstick::spec() masks readr::spec()
## ✖ recipes::step() masks stats::step()
## • Dig deeper into tidy modeling with R at https://www.tmwr.org
library(workflows)
library(themis)
library(tune)
library(ranger)
library(xgboost)
##
## Attaching package: 'xgboost'
##
## The following object is masked from 'package:dplyr':
##
## slice
library(kknn)
The first step in this process is to import the data and assign it to an object for manipulation.
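# check the current working directory, then set it to the project folder
# (this path is specific to the author's machine; adjust it to your own)
getwd()
setwd("/Users/nickwinters/desktop/DS Projects/R/strokePredictor")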
# assign csv file to an object
strk <- read.csv("healthcare-dataset-stroke-data.csv")
Before building the model, it is useful to become familiar with the data.
# display the first five rows
head(strk, 5)
## id gender age hypertension heart_disease ever_married work_type
## 1 9046 Male 67 0 1 Yes Private
## 2 51676 Female 61 0 0 Yes Self-employed
## 3 31112 Male 80 0 1 Yes Private
## 4 60182 Female 49 0 0 Yes Private
## 5 1665 Female 79 1 0 Yes Self-employed
## Residence_type avg_glucose_level bmi smoking_status stroke
## 1 Urban 228.69 36.6 formerly smoked 1
## 2 Rural 202.21 N/A never smoked 1
## 3 Rural 105.92 32.5 never smoked 1
## 4 Urban 171.23 34.4 smokes 1
## 5 Rural 174.12 24 never smoked 1
# size and type of data
dim(strk)
## [1] 5110 12
summary(strk)
## id gender age hypertension
## Min. : 67 Length:5110 Min. : 0.08 Min. :0.00000
## 1st Qu.:17741 Class :character 1st Qu.:25.00 1st Qu.:0.00000
## Median :36932 Mode :character Median :45.00 Median :0.00000
## Mean :36518 Mean :43.23 Mean :0.09746
## 3rd Qu.:54682 3rd Qu.:61.00 3rd Qu.:0.00000
## Max. :72940 Max. :82.00 Max. :1.00000
## heart_disease ever_married work_type Residence_type
## Min. :0.00000 Length:5110 Length:5110 Length:5110
## 1st Qu.:0.00000 Class :character Class :character Class :character
## Median :0.00000 Mode :character Mode :character Mode :character
## Mean :0.05401
## 3rd Qu.:0.00000
## Max. :1.00000
## avg_glucose_level bmi smoking_status stroke
## Min. : 55.12 Length:5110 Length:5110 Min. :0.00000
## 1st Qu.: 77.25 Class :character Class :character 1st Qu.:0.00000
## Median : 91.89 Mode :character Mode :character Median :0.00000
## Mean :106.15 Mean :0.04873
## 3rd Qu.:114.09 3rd Qu.:0.00000
## Max. :271.74 Max. :1.00000
# look specifically at the class of each column
sapply(strk, class)
## id gender age hypertension
## "integer" "character" "numeric" "integer"
## heart_disease ever_married work_type Residence_type
## "integer" "character" "character" "character"
## avg_glucose_level bmi smoking_status stroke
## "numeric" "character" "character" "integer"
# check for any missing data
sum(is.na(strk))
## [1] 0
Observation: The output above shows that the data has 5110 rows and 12 columns, and that it is a mix of integers (4), floats (2), and characters (6). Interestingly, the age column is a float with a minimum value of 0.08, which suggests that the ages of babies are recorded as fractions of a year (0.08 years is about one month). Secondly, bmi, a numerical measurement, is classed as a character. Finally, no entries are flagged as missing. This, however, contradicts what is seen in the first five rows of data, where the bmi column shows an “N/A”. This means that there is missing data within the bmi column, and that the presence of “N/A” strings is why the column is classed as a character. Likewise, the smoking_status column uses “Unknown” to represent unavailable information. These issues will have to be addressed in the data manipulation phase.
It is now time to prepare the data for machine learning.
First, the placeholder strings that stand in for missing data will be recoded as NA, and the total count of missing values will be computed.
# recode placeholder strings as NA ("N/A" was used in bmi, "Unknown" in smoking_status)
strk[strk == 'N/A'] <- NA
strk[strk == 'Unknown'] <- NA
# validate the above change
sum(is.na(strk))
## [1] 1745
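The same recoding can also be handled when the file is read. A minimal sketch, assuming the CSV uses exactly these two placeholder strings: the na.strings argument of read.csv() marks the listed strings as missing before column types are guessed, so bmi comes in as numeric without a separate conversion. The inline rows are a reduced, made-up excerpt in the dataset's layout (the first two echo the head() output above; the third row, including its id, is hypothetical).

```r
# Hypothetical reduced excerpt of the stroke CSV (third row is made up)
csv_text <- "id,bmi,smoking_status,stroke
9046,36.6,formerly smoked,1
51676,N/A,never smoked,1
99999,24,Unknown,0"

# Declare the placeholder strings as missing while reading the file
toy <- read.csv(text = csv_text, na.strings = c("N/A", "Unknown"))

sapply(toy, class)   # bmi is now numeric rather than character
colSums(is.na(toy))  # one NA in bmi, one NA in smoking_status
```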
Next, because this is a classification task, the categorical variables must be converted to factors (tidymodels also requires the outcome of a classification model to be a factor). This makes the model treat the 0s and 1s as category labels rather than as numerical quantities. In addition, bmi will be converted to a double (floating-point) data type.
# Change datatypes
strk <- transform(strk,
bmi = as.double(bmi),
gender = as.factor(gender),
ever_married = as.factor(ever_married),
work_type = as.factor(work_type),
Residence_type = as.factor(Residence_type),
smoking_status = as.factor(smoking_status),
hypertension = as.factor(hypertension),
heart_disease = as.factor(heart_disease),
stroke = as.factor(stroke))
# validate the above change
sapply(strk, class)
## id gender age hypertension
## "integer" "factor" "numeric" "factor"
## heart_disease ever_married work_type Residence_type
## "factor" "factor" "factor" "factor"
## avg_glucose_level bmi smoking_status stroke
## "numeric" "numeric" "factor" "factor"
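A small illustration of what the conversion does to a 0/1 column such as stroke, using a made-up vector: the factor keeps the values as labels, stores them internally as level codes, and sorts the levels so that “0” (no stroke) comes first. That level order matters later, because yardstick treats the first level of the outcome as the event by default.

```r
# A made-up 0/1 integer column, like stroke before conversion
stroke_raw <- c(1L, 0L, 0L, 1L, 0L)
stroke_fct <- as.factor(stroke_raw)

levels(stroke_fct)      # "0" "1": levels are sorted, so "0" (no stroke) comes first
as.integer(stroke_fct)  # 2 1 1 2 1: internal level codes, not the original 0/1 values
```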
To both simplify the model and limit the influence of outliers, the age, bmi, and avg_glucose_level columns will be binned into ordered classes based on common segmentation for disease analysis.
Sources used for choosing the classes:
# bin data into appropriate groupings
strk <- strk |>
mutate(age_group = cut(age,
breaks = c(0, 6, 19, 46, 65, 76, 100),
labels = c('0-5y', '6-18y', '19-45y', '46-64y', '65-75y', '76+'),
include.lowest = T,
right = F,
ordered_result = T),
bmi_group = cut(bmi,
breaks = c(0, 19, 25, 30, 35, 40, 100),
labels = c("underweight", "normal", "overweight", "class I obesity", "class II obesity", "class III obesity"),
include.lowest = T,
right = F,
ordered_result = T),
glucose_group = cut(avg_glucose_level,
breaks = c(0, 140, 200, 300),
labels = c("Normal", "Prediabetic", "Diabetic"),
include.lowest = T,
right = F,
ordered_result = T)
)
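Because right = F, each interval is closed on the left and open on the right, so a value sitting exactly on a boundary falls into the higher group. A quick check on a few hand-picked ages, sketched with the same breaks and labels as the age_group binning above:

```r
# Hand-picked ages on and around each boundary
ages <- c(5, 6, 18, 19, 45, 46, 64, 65, 75, 76)

age_bins <- cut(ages,
                breaks = c(0, 6, 19, 46, 65, 76, 100),
                labels = c('0-5y', '6-18y', '19-45y', '46-64y', '65-75y', '76+'),
                include.lowest = TRUE,
                right = FALSE,          # intervals are [0, 6), [6, 19), ..., [76, 100]
                ordered_result = TRUE)

data.frame(age = ages, age_group = age_bins)  # e.g. age 6 lands in "6-18y", not "0-5y"
```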
Once the values have been binned, the new columns will be validated.
# validate the bins
strk |>
select(age, age_group, bmi, bmi_group, avg_glucose_level, glucose_group) |>
head()
## age age_group bmi bmi_group avg_glucose_level glucose_group
## 1 67 65-75y 36.6 class II obesity 228.69 Diabetic
## 2 61 46-64y NA <NA> 202.21 Diabetic
## 3 80 76+ 32.5 class I obesity 105.92 Normal
## 4 49 46-64y 34.4 class I obesity 171.23 Prediabetic
## 5 79 76+ 24.0 normal 174.12 Prediabetic
## 6 81 76+ 29.0 overweight 186.21 Prediabetic
# check order
is.ordered(strk$age_group)
## [1] TRUE
levels(strk$age_group)
## [1] "0-5y" "6-18y" "19-45y" "46-64y" "65-75y" "76+"
is.ordered(strk$bmi_group)
## [1] TRUE
levels(strk$bmi_group)
## [1] "underweight" "normal" "overweight"
## [4] "class I obesity" "class II obesity" "class III obesity"
is.ordered(strk$glucose_group)
## [1] TRUE
levels(strk$glucose_group)
## [1] "Normal" "Prediabetic" "Diabetic"
Now that the data has been appropriately transformed, it will be visually explored to gain insight into which features seem most important for determining whether a patient may have a stroke. Based on these insights, unimportant features will be dropped from the data.
The outcome variable was graphed to show the distribution of patients who have had a stroke versus those who have not.
strk |>
ggplot(aes(x = stroke, fill = stroke)) +
geom_bar()
Observation: Patients who have never had a stroke severely outnumber patients who have had one; only about 4.9% of patients had a stroke (the mean of stroke in the summary above is 0.04873). This class imbalance will have to be addressed to prevent the model from achieving high accuracy simply by favoring the majority (no stroke) class.
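To put the imbalance in numbers, the sketch below reconstructs the class counts from the summary() output above (a mean of 0.04873 for stroke over 5110 rows, about 249 stroke cases) and computes the accuracy of a trivial rule that always predicts “0”. Any model's accuracy is best read against this baseline.

```r
# Class counts reconstructed from summary(strk): mean(stroke) = 0.04873 over 5110 rows
n_total  <- 5110
n_stroke <- round(0.04873 * n_total)  # 249 patients with a stroke
n_none   <- n_total - n_stroke        # 4861 patients without a stroke

# Share of each class
class_share <- c(no_stroke = n_none, stroke = n_stroke) / n_total
round(class_share, 3)

# Accuracy of a rule that always predicts "0" (no stroke)
baseline_accuracy <- n_none / n_total
round(baseline_accuracy, 3)
```

This baseline of about 0.951 is higher than the cross-validated accuracy of any of the models fit later in this project, which is why accuracy alone is not used to choose between them.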
strk |>
ggplot(aes(x = gender, fill = stroke)) +
geom_bar(position = "fill")
Observation: The proportion of stroke cases is about the same for males and females; therefore, gender does not appear to play a significant role in stroke prediction.
strk |>
ggplot(aes(x = age_group, fill = stroke)) +
geom_bar(position = "fill")
Observation: A clear trend shows that as age increases, the proportion of stroke cases also increases; therefore, age appears to play a significant role in stroke prediction.
strk |>
ggplot(aes(x = hypertension, fill = stroke)) +
geom_bar(position = "fill")
Observation: Stroke cases appear more frequently in the hypertension group; therefore, hypertension appears to play a significant role in stroke prediction. This is also supported by the WSO reporting that high blood pressure is a leading risk factor for stroke.
strk |>
ggplot(aes(x = heart_disease, fill = stroke)) +
geom_bar(position = "fill")
Observation: Stroke cases appear more frequently in the heart disease group; therefore, heart disease appears to play a significant role in stroke prediction.
strk |>
ggplot(aes(x = ever_married, fill = stroke)) +
geom_bar(position = "fill")
strk |>
ggplot(aes(x = ever_married, fill = age_group)) +
geom_bar(position = "fill") +
facet_wrap(~stroke)
Observation: The proportion of stroke cases is close between the “No” and “Yes” groups, with the latter being higher. This is likely because younger individuals make up more of the “No” group, whereas older, more at-risk individuals make up more of the “Yes” group. Marital status therefore does not appear to play a significant role in stroke prediction; its apparent effect is likely driven by age.
strk |>
ggplot(aes(x = work_type, fill = stroke)) +
geom_bar(position = "fill")
Observation: Stroke cases are more prevalent in the employed groups, with self-employment having the highest proportion. Just as with ever_married, work_type is likely related to age, so to avoid redundancy (collinearity) with age, work_type will likely be left out of the model.
strk |>
ggplot(aes(x = Residence_type, fill = stroke)) +
geom_bar(position = "fill")
Observation: The proportion of stroke cases is about the same for rural and urban residents; therefore, residence type does not appear to play a significant role in stroke prediction.
strk |>
ggplot(aes(x = glucose_group, fill = stroke)) +
geom_bar(position = "fill")
Observation: A clear trend shows that as glucose levels increase, the proportion of stroke cases also increases; therefore, glucose level appears to play a significant role in stroke prediction. This is also supported by the WSO reporting that high fasting plasma glucose is a leading risk factor for stroke.
strk |>
filter(!is.na(bmi_group)) |>
ggplot(aes(x = bmi_group, fill = stroke)) +
geom_bar(position = "fill")
strk |>
ggplot(aes(x = bmi_group, fill = stroke)) +
geom_bar(position = "fill")
Observation: Stroke cases are most prevalent among individuals in the higher weight classes, with the highest proportions in the overweight and class I obesity groups. BMI likely plays a role in stroke prediction, which is supported by the WSO reporting that high BMI is a leading risk factor for stroke. Notably, when the missing values are also graphed, the missing-BMI group has by far the highest proportion of stroke cases.
strk |>
ggplot(aes(x = smoking_status, fill = stroke)) +
geom_bar(position = "fill")
Observation: The never smoked and smokes groups display similar proportions of stroke cases, while the formerly smoked group has the highest. Although the WSO reports smoking as a leading risk factor, smoking status will be omitted to simplify the model, because the two opposite groups show similar stroke proportions and the column accounts for most of the 1,745 missing values counted earlier.
Based on the data exploration above, age, bmi, glucose levels, hypertension, and heart disease appear to have the strongest association with stroke. For this reason, these columns (in their binned form where applicable) will be selected as features for the model.
clean_strk <- strk |>
select(age_group,
bmi_group,
glucose_group,
hypertension,
heart_disease,
stroke)
One of the most important steps in the machine learning process is splitting the data into a training set and a test set. The model is built and trained on the training set and then evaluated on the test set, which shows how well the model generalizes to new data.
It is common practice to allocate 80% of the data to training and 20% of the data to testing.
# set random seed for replication
set.seed(42)
# split the data into training (80%) and testing (20%)
strk_split <- initial_split(clean_strk,
prop = 4/5)
strk_split
## <Training/Testing/Total>
## <4088/1022/5110>
# extracting training and test sets
train <- training(strk_split)
test <- testing(strk_split)
Cross-validation is integral to fine-tuning the machine learning process, as it is a good way to estimate how well the model will generalize to new data. It works by splitting the training data into multiple folds: in each iteration, one fold is held out as the validation set, whose actual outcomes the predictions are scored against, while the model is fit on the remaining folds. Every fold takes one turn as the validation set.
The cross-validation object below will be used to compare the candidate models and to fine-tune the selected one.
# set random seed for reproducibility
set.seed(42)
# create a CV object
stroke_folds <- vfold_cv(train, strata = stroke)
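How k-fold cross-validation partitions the data can be seen on a small example. This sketch uses 200 hypothetical rows and, unlike the code above, no stratification; it checks that each row is held out for scoring in exactly one of the 10 folds while the other 9 folds are used for fitting.

```r
library(rsample)

set.seed(42)
toy   <- data.frame(id = 1:200)   # 200 hypothetical patients
folds <- vfold_cv(toy, v = 10)    # 10 folds, the vfold_cv() default

# Rows used for fitting (analysis) vs. held out for scoring (assessment) in each fold
fold_sizes <- data.frame(
  fold       = folds$id,
  analysis   = sapply(folds$splits, function(s) nrow(analysis(s))),
  assessment = sapply(folds$splits, function(s) nrow(assessment(s)))
)
fold_sizes

# Across the 10 folds, every row is held out exactly once
held_out <- unlist(lapply(folds$splits, function(s) assessment(s)$id))
identical(sort(held_out), toy$id)
```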
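To start building the model, a recipe is created that imputes missing predictor values using k-nearest neighbors (step_impute_knn()), converts unordered factors into dummy variables (step_dummy()), converts ordered factors into numeric scores (step_ordinalscore()), and oversamples the minority stroke class with SMOTE (step_smote()) to address the class imbalance.
The metrics used to evaluate model performance are also defined.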
# recipe creation
stroke_recipe <- recipe(stroke ~
age_group +
bmi_group +
glucose_group +
hypertension +
heart_disease,
data = clean_strk) |>
step_impute_knn(all_predictors()) |>
step_dummy(all_unordered_predictors()) |>
step_ordinalscore(all_ordered_predictors()) |>
step_smote(stroke)
# Specify desired metrics
stroke_metrics <- metric_set(roc_auc, accuracy, sensitivity, specificity)
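To see what the two encoding steps produce, the sketch below applies step_dummy() and step_ordinalscore() to three hypothetical patients; the imputation and SMOTE steps are left out here. prep() estimates the steps and bake(new_data = NULL) returns the processed training data.

```r
library(recipes)

# Three hypothetical patients with one ordered and one unordered predictor
toy <- data.frame(
  glucose_group = factor(c("Normal", "Diabetic", "Prediabetic"),
                         levels = c("Normal", "Prediabetic", "Diabetic"),
                         ordered = TRUE),
  hypertension  = factor(c("0", "1", "0")),
  stroke        = factor(c("0", "1", "0"))
)

toy_recipe <- recipe(stroke ~ glucose_group + hypertension, data = toy) |>
  step_dummy(all_unordered_predictors()) |>     # hypertension -> hypertension_X1 (0/1)
  step_ordinalscore(all_ordered_predictors())   # Normal/Prediabetic/Diabetic -> 1/2/3

baked <- bake(prep(toy_recipe), new_data = NULL)
baked
```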
Next, the recipe is passed to a workflow, which allows different models to be added and fit.
# add recipe to workflow
stroke_workflow <- workflow() |>
add_recipe(stroke_recipe)
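The classification models used for this task are logistic regression, random forest, k-nearest neighbors, and boosted trees (XGBoost).
This step assigns the logistic regression specification (model type, engine, and mode) to an object: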
# logistic regression model
lr_model <-
logistic_reg() |>
set_engine("glm") |>
set_mode("classification")
This step adds the classification model created above to the previously created workflow. The model and workflow are then fit to the cross-validation folds of the training data, and each fit is evaluated with the desired metrics (accuracy, sensitivity, specificity, and roc_auc).
# logistic regression model
lr_rs <- stroke_workflow|>
add_model(lr_model) |>
fit_resamples(
resamples = stroke_folds,
metrics = stroke_metrics,
control = control_resamples(save_pred = TRUE)
)
This step collects the performance metrics from the fit conducted above.
lr_metric <- collect_metrics(lr_rs)
lr_metric
## # A tibble: 4 × 6
## .metric .estimator mean n std_err .config
## <chr> <chr> <dbl> <int> <dbl> <chr>
## 1 accuracy binary 0.777 10 0.00960 Preprocessor1_Model1
## 2 roc_auc binary 0.836 10 0.0137 Preprocessor1_Model1
## 3 sensitivity binary 0.780 10 0.0103 Preprocessor1_Model1
## 4 specificity binary 0.712 10 0.0441 Preprocessor1_Model1
collect_predictions(lr_rs) |>
conf_mat(truth = stroke, estimate = .pred_class)
## Truth
## Prediction 0 1
## 0 3031 58
## 1 852 147
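Which class counts as “positive” matters when reading these numbers: yardstick treats the first factor level as the event by default, and here that level is “0” (no stroke). The sketch below rebuilds the pooled logistic regression predictions from the confusion matrix above and computes the metrics both ways. The pooled values differ slightly from the fold-averaged ones returned by collect_metrics() (for example, 0.717 versus 0.712 for specificity).

```r
library(yardstick)

# Rebuild the pooled predictions from the confusion matrix:
# truth 0 -> 3031 predicted 0, 852 predicted 1; truth 1 -> 58 predicted 0, 147 predicted 1
lvls  <- c("0", "1")
preds <- data.frame(
  truth    = factor(rep(c("0", "0", "1", "1"), times = c(3031, 852, 58, 147)), levels = lvls),
  estimate = factor(rep(c("0", "1", "0", "1"), times = c(3031, 852, 58, 147)), levels = lvls)
)

# Default event level: the first level, "0" (no stroke), is the "positive" class
sens(preds, truth, estimate)  # 3031 / 3883: non-stroke patients correctly identified
spec(preds, truth, estimate)  # 147 / 205: stroke patients correctly identified

# Treat "1" (stroke) as the event instead: sensitivity becomes the stroke detection rate
sens(preds, truth, estimate, event_level = "second")  # 147 / 205
```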
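Observation: Averaged across the 10 folds, the logistic regression model yielded an accuracy of 0.777, an roc_auc of 0.836, a sensitivity of 0.780, and a specificity of 0.712.
This step assigns the random forest specification (model type, engine, and mode) to an object: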
# Random forest
rf_model <-
rand_forest() |>
set_engine("randomForest") |>
set_mode("classification")
This step adds the classification model created above to the previously created workflow. The model and workflow are then fit to the cross-validation folds of the training data, and each fit is evaluated with the desired metrics (accuracy, sensitivity, specificity, and roc_auc).
# random forest model
rf_rs <- stroke_workflow|>
add_model(rf_model) |>
fit_resamples(
resamples = stroke_folds,
metrics = stroke_metrics,
control = control_resamples(save_pred = TRUE)
)
This step collects the performance metrics from the fit conducted above.
rf_metric <- collect_metrics(rf_rs)
rf_metric
## # A tibble: 4 × 6
## .metric .estimator mean n std_err .config
## <chr> <chr> <dbl> <int> <dbl> <chr>
## 1 accuracy binary 0.849 10 0.00927 Preprocessor1_Model1
## 2 roc_auc binary 0.793 10 0.0167 Preprocessor1_Model1
## 3 sensitivity binary 0.870 10 0.00815 Preprocessor1_Model1
## 4 specificity binary 0.458 10 0.0468 Preprocessor1_Model1
collect_predictions(rf_rs) |>
conf_mat(truth = stroke, estimate = .pred_class)
## Truth
## Prediction 0 1
## 0 3379 112
## 1 504 93
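Observation: Averaged across the 10 folds, the random forest model yielded an accuracy of 0.849, an roc_auc of 0.793, a sensitivity of 0.870, and a specificity of 0.458.
This step assigns the k-nearest neighbors specification (model type, engine, and mode) to an object:
# k-nearest neighbors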
kn_model <- nearest_neighbor() |>
set_engine("kknn") |>
set_mode("classification")
This step adds the classification model created above to the previously created workflow. The model and workflow are then fit to the cross-validation folds of the training data, and each fit is evaluated with the desired metrics (accuracy, sensitivity, specificity, and roc_auc).
# k-nearest neighbors model
kn_rs <- stroke_workflow|>
add_model(kn_model) |>
fit_resamples(
resamples = stroke_folds,
metrics = stroke_metrics,
control = control_resamples(save_pred = TRUE)
)
This step collects the performance metrics from the fit conducted above.
kn_metric <- collect_metrics(kn_rs)
kn_metric
## # A tibble: 4 × 6
## .metric .estimator mean n std_err .config
## <chr> <chr> <dbl> <int> <dbl> <chr>
## 1 accuracy binary 0.934 10 0.00344 Preprocessor1_Model1
## 2 roc_auc binary 0.678 10 0.0224 Preprocessor1_Model1
## 3 sensitivity binary 0.978 10 0.00195 Preprocessor1_Model1
## 4 specificity binary 0.0884 10 0.0195 Preprocessor1_Model1
collect_predictions(kn_rs) |>
conf_mat(truth = stroke, estimate = .pred_class)
## Truth
## Prediction 0 1
## 0 3799 186
## 1 84 19
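Observation: Averaged across the 10 folds, the k-nearest neighbors model yielded an accuracy of 0.934, an roc_auc of 0.678, a sensitivity of 0.978, and a specificity of 0.0884.
This step assigns the boosted tree (XGBoost) specification (model type, engine, and mode) to an object:
# boosted tree (XGBoost)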
xg_model <- boost_tree() |>
set_engine("xgboost") |>
set_mode("classification")
This step adds the classification model created above to the previously created workflow. The model and workflow are then fit to the cross-validation folds of the training data, and each fit is evaluated with the desired metrics (accuracy, sensitivity, specificity, and roc_auc).
# boosted tree (XGBoost) model
xg_rs <- stroke_workflow|>
add_model(xg_model) |>
fit_resamples(
resamples = stroke_folds,
metrics = stroke_metrics,
control = control_resamples(save_pred = TRUE)
)
This step collects the performance metrics from the fit conducted above.
xg_metric <- collect_metrics(xg_rs)
xg_metric
## # A tibble: 4 × 6
## .metric .estimator mean n std_err .config
## <chr> <chr> <dbl> <int> <dbl> <chr>
## 1 accuracy binary 0.841 10 0.0102 Preprocessor1_Model1
## 2 roc_auc binary 0.783 10 0.0199 Preprocessor1_Model1
## 3 sensitivity binary 0.860 10 0.00883 Preprocessor1_Model1
## 4 specificity binary 0.492 10 0.0480 Preprocessor1_Model1
collect_predictions(xg_rs) |>
conf_mat(truth = stroke, estimate = .pred_class)
## Truth
## Prediction 0 1
## 0 3340 105
## 1 543 100
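Observation: Averaged across the 10 folds, the boosted tree (XGBoost) model yielded an accuracy of 0.841, an roc_auc of 0.783, a sensitivity of 0.860, and a specificity of 0.492.
Based on the captured metrics, the random forest model seems to be the best overall: it had the second-highest accuracy, roc_auc, and sensitivity, as well as the third-highest specificity.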
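In the context of disease detection and prevention, catching true stroke cases is given the highest priority, since this limits the number of false negatives (stroke patients classified as not at risk). Even so, it is also beneficial to perform reasonably on the other class. Note that yardstick treats the first factor level as the event by default, and the first level of stroke is “0” (no stroke); the sensitivity reported above is therefore the share of non-stroke patients correctly identified, and the specificity is the share of stroke patients correctly identified. This is why the K-nearest neighbors model was not chosen: despite having the highest accuracy and sensitivity, its specificity of 0.0884 means it correctly identified only 19 of the 205 stroke cases in cross-validation, so its high accuracy mostly reflects the dominance of the no-stroke class.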
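For a like-for-like comparison of how many stroke cases each model caught, the four cross-validation confusion matrices above can be condensed into a stroke-class recall (the share of the 205 stroke cases predicted as stroke) and a stroke-class precision (the share of stroke predictions that were correct). A sketch in base R, with the counts copied from those confusion matrices:

```r
# Pooled cross-validation confusion-matrix counts, with "1" (stroke) as the positive class
cv_counts <- data.frame(
  model = c("logistic regression", "random forest", "k-nearest neighbors", "xgboost"),
  tp = c(147, 93, 19, 100),       # stroke, predicted stroke
  fn = c(58, 112, 186, 105),      # stroke, predicted no stroke
  fp = c(852, 504, 84, 543),      # no stroke, predicted stroke
  tn = c(3031, 3379, 3799, 3340)  # no stroke, predicted no stroke
)

cv_counts$stroke_recall    <- cv_counts$tp / (cv_counts$tp + cv_counts$fn)
cv_counts$stroke_precision <- cv_counts$tp / (cv_counts$tp + cv_counts$fp)

# Models ordered from most to fewest stroke cases caught
ranked <- cv_counts[order(-cv_counts$stroke_recall), c("model", "stroke_recall", "stroke_precision")]
ranked
```

On these counts, logistic regression catches the most stroke cases and k-nearest neighbors the fewest, while stroke-class precision stays low for every model because of the class imbalance.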
Now that the random forest model has been chosen, its hyperparameters will be tuned to optimize its performance. The parameter being tuned is mtry, the number of predictors randomly sampled as split candidates at each split of a tree. The candidate values are stated below.
rf_grid <- expand.grid(mtry = c(3, 4, 5))
A new random forest specification is created in which mtry is marked for tuning with tune(); this time the ranger engine is used.
rf_model_2 <- rand_forest() |>
set_args(mtry = tune()) |>
set_engine("ranger") |>
set_mode("classification")
The newly created model will now be added to the workflow.
rf_workflow_2 <- workflow() |>
add_recipe(stroke_recipe) |>
add_model(rf_model_2)
The workflow will now run a grid search that evaluates performance for each candidate mtry value in the grid.
rf_tune <- rf_workflow_2 |>
tune_grid(
resamples = stroke_folds,
grid = rf_grid,
metrics = stroke_metrics,
control = control_grid(save_pred = TRUE)
)
The metrics were then collected across all the tested values, and the value that yielded the highest sensitivity score (sensitivity as computed here, with “0”, no stroke, as the event level) was selected.
# evaluation
rf_tune |>
collect_metrics()
## # A tibble: 12 × 7
## mtry .metric .estimator mean n std_err .config
## <dbl> <chr> <chr> <dbl> <int> <dbl> <chr>
## 1 3 accuracy binary 0.862 10 0.00863 Preprocessor1_Model1
## 2 3 roc_auc binary 0.751 10 0.0220 Preprocessor1_Model1
## 3 3 sensitivity binary 0.887 10 0.00794 Preprocessor1_Model1
## 4 3 specificity binary 0.388 10 0.0420 Preprocessor1_Model1
## 5 4 accuracy binary 0.864 10 0.00879 Preprocessor1_Model2
## 6 4 roc_auc binary 0.736 10 0.0245 Preprocessor1_Model2
## 7 4 sensitivity binary 0.891 10 0.00783 Preprocessor1_Model2
## 8 4 specificity binary 0.370 10 0.0420 Preprocessor1_Model2
## 9 5 accuracy binary 0.864 10 0.00883 Preprocessor1_Model3
## 10 5 roc_auc binary 0.713 10 0.0280 Preprocessor1_Model3
## 11 5 sensitivity binary 0.890 10 0.00743 Preprocessor1_Model3
## 12 5 specificity binary 0.365 10 0.0467 Preprocessor1_Model3
# selection
param_final <- rf_tune |>
select_best(metric = "sensitivity")
param_final
## # A tibble: 1 × 2
## mtry .config
## <dbl> <chr>
## 1 4 Preprocessor1_Model2
The selected parameter was then used to finalize the workflow.
rf_workflow_tuned <- finalize_workflow(rf_workflow_2, param_final)
rf_workflow_tuned
## ══ Workflow ════════════════════════════════════════════════════════════════════
## Preprocessor: Recipe
## Model: rand_forest()
##
## ── Preprocessor ────────────────────────────────────────────────────────────────
## 4 Recipe Steps
##
## • step_impute_knn()
## • step_dummy()
## • step_ordinalscore()
## • step_smote()
##
## ── Model ───────────────────────────────────────────────────────────────────────
## Random Forest Model Specification (classification)
##
## Main Arguments:
## mtry = 4
##
## Computational engine: ranger
The finalized workflow was fit to the full training set and evaluated on the test set using the previously designated metrics.
rf_fit <- rf_workflow_tuned |>
last_fit(split = strk_split,
metrics = stroke_metrics)
The metrics were collected for final assessment.
rf_fit_metric <- collect_metrics(rf_fit)
rf_fit_metric
## # A tibble: 4 × 4
## .metric .estimator .estimate .config
## <chr> <chr> <dbl> <chr>
## 1 accuracy binary 0.846 Preprocessor1_Model1
## 2 sensitivity binary 0.867 Preprocessor1_Model1
## 3 specificity binary 0.386 Preprocessor1_Model1
## 4 roc_auc binary 0.729 Preprocessor1_Model1
collect_predictions(rf_fit) |>
conf_mat(stroke, .pred_class)
## Truth
## Prediction 0 1
## 0 848 27
## 1 130 17
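Observation: On the test set, the tuned random forest model yielded an accuracy of 0.846, a sensitivity of 0.867, a specificity of 0.386, and an roc_auc of 0.729.
The tuned random forest workflow was then fit to the whole data set (training and test data combined) so that it is ready to make predictions on new data.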
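The test-set figures can be recomputed directly from the confusion matrix above, which also makes the direction of each metric explicit. A sketch in base R; precision_stroke (the share of stroke predictions that were correct) is not part of stroke_metrics and is added here only for illustration.

```r
# Test-set confusion matrix from collect_predictions(rf_fit), "1" (stroke) as the positive class
tn <- 848  # no stroke, predicted no stroke
fp <- 130  # no stroke, predicted stroke
fn <- 27   # stroke, predicted no stroke
tp <- 17   # stroke, predicted stroke

accuracy         <- (tp + tn) / (tp + tn + fp + fn)  # 865 / 1022
recall_no_stroke <- tn / (tn + fp)  # reported as "sensitivity" (event level "0"): 848 / 978
recall_stroke    <- tp / (tp + fn)  # reported as "specificity": 17 / 44
precision_stroke <- tp / (tp + fp)  # share of stroke predictions that were correct: 17 / 147

round(c(accuracy = accuracy, recall_no_stroke = recall_no_stroke,
        recall_stroke = recall_stroke, precision_stroke = precision_stroke), 3)
```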
final_model <- fit(rf_workflow_tuned, clean_strk)
final_model
## ══ Workflow [trained] ══════════════════════════════════════════════════════════
## Preprocessor: Recipe
## Model: rand_forest()
##
## ── Preprocessor ────────────────────────────────────────────────────────────────
## 4 Recipe Steps
##
## • step_impute_knn()
## • step_dummy()
## • step_ordinalscore()
## • step_smote()
##
## ── Model ───────────────────────────────────────────────────────────────────────
## Ranger result
##
## Call:
## ranger::ranger(x = maybe_data_frame(x), y = y, mtry = min_cols(~4, x), num.threads = 1, verbose = FALSE, seed = sample.int(10^5, 1), probability = TRUE)
##
## Type: Probability estimation
## Number of trees: 500
## Sample size: 9722
## Number of independent variables: 5
## Mtry: 4
## Target node size: 10
## Variable importance mode: none
## Splitrule: gini
## OOB prediction error (Brier s.): 0.1060317
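The sample size of 9722 reported by ranger is larger than the 5110 rows in clean_strk because step_smote() generates synthetic minority-class rows during training. 9722 is exactly 2 × 4861, consistent with the non-stroke count (5110 - 249) and with the default over_ratio = 1 of step_smote(), which grows the minority class to the size of the majority class. A sketch of the same effect using the standalone smote() function from themis on simulated data:

```r
library(themis)

# Simulated imbalanced data: 40 "0" rows and 8 "1" rows with two numeric predictors
set.seed(42)
toy <- data.frame(
  x1     = c(rnorm(40), rnorm(8, mean = 3)),
  x2     = c(rnorm(40), rnorm(8, mean = 3)),
  stroke = factor(rep(c("0", "1"), times = c(40, 8)))
)
table(toy$stroke)

# SMOTE interpolates between minority rows and their nearest minority neighbours;
# over_ratio = 1 grows the minority class to the size of the majority class
balanced <- smote(toy, var = "stroke", k = 5, over_ratio = 1)
table(balanced$stroke)
```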
The final model was used to predict whether a patient will have a stroke based on new data.
# new data for fictional patient
new_patient <- tribble(~age_group, ~bmi_group, ~glucose_group, ~hypertension, ~heart_disease,
"76+", "overweight", "Prediabetic", "1", "0")
# stroke prediction
prediction <- predict(final_model, new_data = new_patient)
prediction
## # A tibble: 1 × 1
## .pred_class
## <fct>
## 1 1
Observation: Based on the new input data, the model predicted class 1, meaning this fictional patient is classified as likely to have a stroke.
The purpose of this machine learning project was to implement a model that could predict whether or not a patient is at risk of having a stroke based on features relating to the patient’s medical history and demographics. The model created for this purpose was a random forest classifier with an accuracy of 0.8463796, a sensitivity of 0.8670757, a specificity of 0.3863636, and an roc_auc of 0.7292364.
An accuracy of 0.8463796 indicates that the model correctly classified 84.6% of the patients in the test set (865 of 1,022).
A sensitivity of 0.8670757 indicates that 86.7% of the patients in the event class were predicted to be in that class. Because yardstick treats the first factor level as the event by default, the event here is “0” (no stroke): 848 of the 978 test patients without a stroke were correctly classified.
A specificity of 0.3863636 indicates that 38.6% of the patients in the other class, “1” (stroke), were correctly classified: 17 of the 44 test patients who had a stroke were flagged by the model.
An ROC AUC of 0.7292364 measures how well the model’s predicted probabilities separate positive from negative cases, independent of any single classification threshold. It ranges from 0 to 1, where 0.5 indicates random guessing and 1 indicates perfect discrimination.
Given all these metrics, the one that should be given the most importance is the detection rate for the stroke class, because it measures how many true positive cases (patients who had a stroke) the model picks up. With yardstick’s default event level, that rate is the reported specificity of 0.3863636 rather than the sensitivity of 0.8670757: in this application, the model detected about 38.6% of the test patients who had a stroke, while correctly clearing 86.7% of those who did not. So although the overall accuracy of 84.6% looks strong, the stroke detection rate is the figure that most limits this model’s use for stroke prediction; recomputing the metrics with “1” as the event (event_level = "second" in yardstick) would make the reported sensitivity reflect this goal directly.
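The prediction above is a hard class label. Class probabilities convey the estimated risk more directly; for a fitted tidymodels workflow such as final_model, predict(..., type = "prob") returns one .pred_ column per outcome level. A self-contained sketch on simulated data, using a logistic regression fit with parsnip rather than the random forest used in this project:

```r
library(parsnip)

# Simulated data: one numeric predictor, outcome levels "0" and "1"
set.seed(42)
toy <- data.frame(
  x      = c(rnorm(50), rnorm(50, mean = 2)),
  stroke = factor(rep(c("0", "1"), each = 50))
)

lr_fit <- logistic_reg() |>
  set_engine("glm") |>
  set_mode("classification") |>
  fit(stroke ~ x, data = toy)

new_pt <- data.frame(x = 1.5)  # hypothetical new patient

cls   <- predict(lr_fit, new_data = new_pt)                 # hard class in .pred_class
probs <- predict(lr_fit, new_data = new_pt, type = "prob")  # .pred_0 and .pred_1
probs
```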
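The ranking interpretation of ROC AUC can be checked by hand: it equals the probability that a randomly chosen positive case receives a higher predicted probability than a randomly chosen negative case, with ties counted as half. A sketch with made-up probabilities:

```r
# Made-up predicted stroke probabilities
p_stroke    <- c(0.9, 0.7, 0.4)       # patients who had a stroke
p_no_stroke <- c(0.8, 0.3, 0.2, 0.1)  # patients who did not

# Compare every (stroke, no-stroke) pair: 1 if the stroke patient scores higher, 0.5 for a tie
pair_wins <- outer(p_stroke, p_no_stroke, ">") + 0.5 * outer(p_stroke, p_no_stroke, "==")
auc <- mean(pair_wins)
auc  # 10 of the 12 pairs are ranked correctly, so auc = 10 / 12
```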