This RMarkdown file contains the report of the data analysis done for the project on building and deploying a stroke prediction model in R. It contains analysis such as data exploration, summary statistics and building the prediction models. The final report was completed on Sat Mar 14 21:57:33 2026.
Data Description:
According to the World Health Organization (WHO) stroke is the 2nd leading cause of death globally, responsible for approximately 11% of total deaths.
This data set is used to predict whether a patient is likely to get stroke based on the input parameters like gender, age, various diseases, and smoking status. Each row in the data provides relevant information about the patient.
packages <- c(
"tidyverse", "tidymodels", "skimr", "naniar",
"themis", "ranger", "xgboost", "corrplot",
"vetiver", "plumber", "pins", "gt"
)
install_if_missing <- packages[!packages %in% installed.packages()[, "Package"]]
if (length(install_if_missing) > 0) {
install.packages(install_if_missing, repos = "https://cloud.r-project.org")
}
library(tidyverse)
library(tidymodels)
library(skimr)
library(naniar)
library(themis)
library(ranger)
library(xgboost)
library(corrplot)
library(vetiver)
library(plumber)
library(pins)
library(gt)
set.seed(42)
stroke_raw <- read_csv(
"healthcare-dataset-stroke-data.csv",
na = c("N/A", ""),
show_col_types = FALSE
)
cat("Dataset dimensions:", nrow(stroke_raw), "rows x", ncol(stroke_raw), "columns\n")## Dataset dimensions: 5110 rows x 12 columns
## Rows: 5,110
## Columns: 12
## $ id <dbl> 9046, 51676, 31112, 60182, 1665, 56669, 53882, 10434…
## $ gender <chr> "Male", "Female", "Male", "Female", "Female", "Male"…
## $ age <dbl> 67, 61, 80, 49, 79, 81, 74, 69, 59, 78, 81, 61, 54, …
## $ hypertension <dbl> 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 1…
## $ heart_disease <dbl> 1, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 1, 1, 0, 1, 0…
## $ ever_married <chr> "Yes", "Yes", "Yes", "Yes", "Yes", "Yes", "Yes", "No…
## $ work_type <chr> "Private", "Self-employed", "Private", "Private", "S…
## $ Residence_type <chr> "Urban", "Rural", "Rural", "Urban", "Rural", "Urban"…
## $ avg_glucose_level <dbl> 228.69, 202.21, 105.92, 171.23, 174.12, 186.21, 70.0…
## $ bmi <dbl> 36.6, NA, 32.5, 34.4, 24.0, 29.0, 27.4, 22.8, NA, 24…
## $ smoking_status <chr> "formerly smoked", "never smoked", "never smoked", "…
## $ stroke <dbl> 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1…
| Name | stroke_raw |
| Number of rows | 5110 |
| Number of columns | 12 |
| _______________________ | |
| Column type frequency: | |
| character | 5 |
| numeric | 7 |
| ________________________ | |
| Group variables | None |
Variable type: character
| skim_variable | n_missing | complete_rate | min | max | empty | n_unique | whitespace |
|---|---|---|---|---|---|---|---|
| gender | 0 | 1 | 4 | 6 | 0 | 3 | 0 |
| ever_married | 0 | 1 | 2 | 3 | 0 | 2 | 0 |
| work_type | 0 | 1 | 7 | 13 | 0 | 5 | 0 |
| Residence_type | 0 | 1 | 5 | 5 | 0 | 2 | 0 |
| smoking_status | 0 | 1 | 6 | 15 | 0 | 4 | 0 |
Variable type: numeric
| skim_variable | n_missing | complete_rate | mean | sd | p0 | p25 | p50 | p75 | p100 | hist |
|---|---|---|---|---|---|---|---|---|---|---|
| id | 0 | 1.00 | 36517.83 | 21161.72 | 67.00 | 17741.25 | 36932.00 | 54682.00 | 72940.00 | ▇▇▇▇▇ |
| age | 0 | 1.00 | 43.23 | 22.61 | 0.08 | 25.00 | 45.00 | 61.00 | 82.00 | ▅▆▇▇▆ |
| hypertension | 0 | 1.00 | 0.10 | 0.30 | 0.00 | 0.00 | 0.00 | 0.00 | 1.00 | ▇▁▁▁▁ |
| heart_disease | 0 | 1.00 | 0.05 | 0.23 | 0.00 | 0.00 | 0.00 | 0.00 | 1.00 | ▇▁▁▁▁ |
| avg_glucose_level | 0 | 1.00 | 106.15 | 45.28 | 55.12 | 77.24 | 91.88 | 114.09 | 271.74 | ▇▃▁▁▁ |
| bmi | 201 | 0.96 | 28.89 | 7.85 | 10.30 | 23.50 | 28.10 | 33.10 | 97.60 | ▇▇▁▁▁ |
| stroke | 0 | 1.00 | 0.05 | 0.22 | 0.00 | 0.00 | 0.00 | 0.00 | 1.00 | ▇▁▁▁▁ |
vis_miss(stroke_raw) +
labs(title = "Missing Value Map") +
theme_minimal() +
theme(
axis.text.x = element_text(angle = 45, hjust = 1, size = 9)
)stroke_raw %>%
summarise(across(everything(), ~ sum(is.na(.)))) %>%
pivot_longer(everything(), names_to = "variable", values_to = "n_missing") %>%
filter(n_missing > 0) %>%
mutate(pct_missing = round(n_missing / nrow(stroke_raw) * 100, 2)) %>%
gt() %>%
tab_header(title = "Missing Values by Column")| Missing Values by Column | ||
| variable | n_missing | pct_missing |
|---|---|---|
| bmi | 201 | 3.93 |
stroke_raw %>%
count(stroke) %>%
mutate(
label = if_else(stroke == 1, "Stroke", "No Stroke"),
pct = round(n / sum(n) * 100, 1)
) %>%
ggplot(aes(x = label, y = n, fill = label)) +
geom_col(width = 0.5, show.legend = FALSE) +
geom_text(aes(label = paste0(n, " (", pct, "%)")), vjust = -0.4, size = 4) +
scale_fill_manual(values = c("No Stroke" = "#2196F3", "Stroke" = "#F44336")) +
labs(
title = "Distribution of Stroke Cases",
subtitle = "Severe class imbalance — SMOTE will be applied during modelling",
x = NULL, y = "Count"
) +
theme_minimal(base_size = 13)stroke_raw %>%
mutate(stroke_label = if_else(stroke == 1, "Stroke", "No Stroke")) %>%
ggplot(aes(x = age, fill = stroke_label)) +
geom_histogram(binwidth = 5, alpha = 0.7, position = "identity") +
scale_fill_manual(values = c("No Stroke" = "#2196F3", "Stroke" = "#F44336")) +
labs(
title = "Age Distribution by Stroke Outcome",
x = "Age", y = "Count", fill = NULL
) +
theme_minimal(base_size = 13) +
facet_wrap(~stroke_label, ncol = 1)stroke_raw %>%
mutate(stroke_label = if_else(stroke == 1, "Stroke", "No Stroke")) %>%
ggplot(aes(x = avg_glucose_level, y = bmi, colour = stroke_label)) +
geom_point(alpha = 0.4, size = 1.5) +
scale_colour_manual(values = c("No Stroke" = "#2196F3", "Stroke" = "#F44336")) +
labs(
title = "Glucose Level vs BMI by Stroke Outcome",
x = "Average Glucose Level", y = "BMI", colour = NULL
) +
theme_minimal(base_size = 13)cat_vars <- c("gender", "hypertension", "heart_disease",
"ever_married", "work_type", "smoking_status")
stroke_raw %>%
select(all_of(cat_vars), stroke) %>%
mutate(across(all_of(cat_vars), as.character)) %>%
pivot_longer(-stroke, names_to = "variable", values_to = "value") %>%
group_by(variable, value) %>%
summarise(stroke_rate = mean(stroke, na.rm = TRUE), .groups = "drop") %>%
ggplot(aes(x = reorder(value, stroke_rate), y = stroke_rate, fill = variable)) +
geom_col(show.legend = FALSE) +
coord_flip() +
facet_wrap(~variable, scales = "free_y") +
scale_y_continuous(labels = scales::percent_format()) +
labs(
title = "Stroke Rate by Categorical Variable",
x = NULL, y = "Stroke Rate"
) +
theme_minimal(base_size = 11)stroke_raw %>%
select(age, avg_glucose_level, bmi, stroke) %>%
drop_na() %>%
cor() %>%
corrplot(
method = "color",
type = "upper",
addCoef.col = "black",
tl.col = "black",
title = "Correlation Matrix of Numeric Features",
mar = c(0, 0, 2, 0)
)stroke_clean <- stroke_raw %>%
select(-id) %>%
filter(gender != "Other") %>%
mutate(
stroke = factor(stroke, levels = c("1", "0"),
labels = c("stroke", "no_stroke")),
hypertension = factor(hypertension),
heart_disease = factor(heart_disease),
gender = factor(gender),
ever_married = factor(ever_married),
work_type = factor(work_type),
Residence_type = factor(Residence_type),
smoking_status = factor(smoking_status)
)
cat("Cleaned dataset:", nrow(stroke_clean), "rows\n")## Cleaned dataset: 5109 rows
## Stroke prevalence: 4.87 %
stroke_split <- initial_split(stroke_clean, prop = 0.80, strata = stroke)
stroke_train <- training(stroke_split)
stroke_test <- testing(stroke_split)
cat("Training rows:", nrow(stroke_train), "\n")## Training rows: 4087
## Testing rows : 1022
## # 5-fold cross-validation using stratification
## # A tibble: 5 × 2
## splits id
## <list> <chr>
## 1 <split [3269/818]> Fold1
## 2 <split [3269/818]> Fold2
## 3 <split [3270/817]> Fold3
## 4 <split [3270/817]> Fold4
## 5 <split [3270/817]> Fold5
stroke_recipe <- recipe(stroke ~ ., data = stroke_train) %>%
step_impute_median(bmi) %>%
step_dummy(all_nominal_predictors(), one_hot = TRUE) %>%
step_normalize(all_numeric_predictors()) %>%
step_smote(stroke, over_ratio = 1, seed = 42)
stroke_recipe %>% prep() %>% bake(new_data = NULL) %>% head()## # A tibble: 6 × 23
## age avg_glucose_level bmi gender_Female gender_Male hypertension_X0
## <dbl> <dbl> <dbl> <dbl> <dbl> <dbl>
## 1 1.06 2.74 0.997 -1.18 1.18 0.321
## 2 0.796 2.15 -0.108 0.847 -0.847 0.321
## 3 1.63 0.00162 0.471 -1.18 1.18 0.321
## 4 0.266 1.46 0.715 0.847 -0.847 0.321
## 5 1.59 1.52 -0.622 0.847 -0.847 -3.11
## 6 1.68 1.79 0.0209 -1.18 1.18 0.321
## # ℹ 17 more variables: hypertension_X1 <dbl>, heart_disease_X0 <dbl>,
## # heart_disease_X1 <dbl>, ever_married_No <dbl>, ever_married_Yes <dbl>,
## # work_type_children <dbl>, work_type_Govt_job <dbl>,
## # work_type_Never_worked <dbl>, work_type_Private <dbl>,
## # work_type_Self.employed <dbl>, Residence_type_Rural <dbl>,
## # Residence_type_Urban <dbl>, smoking_status_formerly.smoked <dbl>,
## # smoking_status_never.smoked <dbl>, smoking_status_smokes <dbl>, …
logistic_spec <- logistic_reg(penalty = tune(), mixture = 1) %>%
set_engine("glmnet") %>%
set_mode("classification")
logistic_wf <- workflow() %>%
add_recipe(stroke_recipe) %>%
add_model(logistic_spec)
logistic_grid <- grid_regular(penalty(range = c(-4, 0)), levels = 10)
logistic_results <- tune_grid(
logistic_wf,
resamples = stroke_folds,
grid = logistic_grid,
metrics = metric_set(roc_auc, accuracy, sensitivity, specificity),
control = control_grid(save_pred = TRUE, verbose = FALSE)
)
best_logistic <- select_best(logistic_results, metric = "roc_auc")
cat("Best logistic penalty:", best_logistic$penalty, "\n")## Best logistic penalty: 0.01668101
rf_spec <- rand_forest(mtry = tune(), trees = tune(), min_n = tune()) %>%
set_engine("ranger", importance = "impurity") %>%
set_mode("classification")
rf_wf <- workflow() %>%
add_recipe(stroke_recipe) %>%
add_model(rf_spec)
rf_grid <- grid_regular(
mtry(range = c(2, 8)),
trees(range = c(100, 400)),
min_n(range = c(5, 20)),
levels = 3
)
rf_results <- tune_grid(
rf_wf,
resamples = stroke_folds,
grid = rf_grid,
metrics = metric_set(roc_auc, accuracy, sensitivity, specificity),
control = control_grid(save_pred = TRUE, verbose = FALSE)
)
best_rf <- select_best(rf_results, metric = "roc_auc")
cat("Best RF params — mtry:", best_rf$mtry,
"| trees:", best_rf$trees,
"| min_n:", best_rf$min_n, "\n")## Best RF params — mtry: 8 | trees: 400 | min_n: 20
xgb_spec <- boost_tree(
trees = tune(),
learn_rate = tune(),
tree_depth = tune(),
loss_reduction = tune()
) %>%
set_engine("xgboost") %>%
set_mode("classification")
xgb_wf <- workflow() %>%
add_recipe(stroke_recipe) %>%
add_model(xgb_spec)
xgb_grid <- grid_latin_hypercube(
trees(range = c(100, 500)),
learn_rate(range = c(-3, -1)),
tree_depth(range = c(3, 8)),
loss_reduction(),
size = 15
)
xgb_results <- tune_grid(
xgb_wf,
resamples = stroke_folds,
grid = xgb_grid,
metrics = metric_set(roc_auc, accuracy, sensitivity, specificity),
control = control_grid(save_pred = TRUE, verbose = FALSE)
)
best_xgb <- select_best(xgb_results, metric = "roc_auc")
cat("Best XGBoost params:\n")## Best XGBoost params:
## # A tibble: 1 × 5
## trees tree_depth learn_rate loss_reduction .config
## <int> <int> <dbl> <dbl> <chr>
## 1 293 4 0.00441 0.00000000630 pre0_mod08_post0
model_comparison <- bind_rows(
collect_metrics(logistic_results) %>%
filter(.metric == "roc_auc") %>%
slice_max(mean, n = 1) %>%
mutate(model = "Logistic Regression"),
collect_metrics(rf_results) %>%
filter(.metric == "roc_auc") %>%
slice_max(mean, n = 1) %>%
mutate(model = "Random Forest"),
collect_metrics(xgb_results) %>%
filter(.metric == "roc_auc") %>%
slice_max(mean, n = 1) %>%
mutate(model = "XGBoost")
) %>%
select(model, mean, std_err) %>%
rename(cv_roc_auc = mean)
model_comparison %>%
gt() %>%
tab_header(title = "Cross-Validation ROC AUC Comparison") %>%
fmt_number(columns = c(cv_roc_auc, std_err), decimals = 4)| Cross-Validation ROC AUC Comparison | ||
| model | cv_roc_auc | std_err |
|---|---|---|
| Logistic Regression | 0.8418 | 0.0074 |
| Random Forest | 0.8056 | 0.0190 |
| XGBoost | 0.8247 | 0.0135 |
model_comparison %>%
ggplot(aes(x = reorder(model, cv_roc_auc), y = cv_roc_auc, fill = model)) +
geom_col(width = 0.5, show.legend = FALSE) +
geom_errorbar(aes(ymin = cv_roc_auc - std_err,
ymax = cv_roc_auc + std_err), width = 0.15) +
coord_flip() +
ylim(0, 1) +
scale_fill_brewer(palette = "Set2") +
labs(
title = "Cross-Validation ROC AUC by Model",
x = NULL, y = "Mean ROC AUC (5-fold CV)"
) +
theme_minimal(base_size = 13)final_rf_wf <- finalize_workflow(rf_wf, best_rf)
final_rf_fit <- last_fit(final_rf_wf, stroke_split)
test_metrics <- collect_metrics(final_rf_fit)
test_metrics %>%
gt() %>%
tab_header(title = "Random Forest — Test Set Metrics")| Random Forest — Test Set Metrics | |||
| .metric | .estimator | .estimate | .config |
|---|---|---|---|
| accuracy | binary | 0.8864971 | pre0_mod0_post0 |
| roc_auc | binary | 0.7513057 | pre0_mod0_post0 |
| brier_class | binary | 0.0841723 | pre0_mod0_post0 |
test_preds <- collect_predictions(final_rf_fit)
extended <- bind_rows(
sensitivity(test_preds, truth = stroke, estimate = .pred_class,
event_level = "first"),
specificity(test_preds, truth = stroke, estimate = .pred_class,
event_level = "first"),
f_meas(test_preds, truth = stroke, estimate = .pred_class,
event_level = "first"),
accuracy(test_preds, truth = stroke, estimate = .pred_class),
roc_auc(test_preds, truth = stroke, .pred_stroke,
event_level = "first")
)
extended %>%
select(.metric, .estimate) %>%
rename(metric = .metric, value = .estimate) %>%
mutate(value = round(value, 4)) %>%
gt() %>%
tab_header(title = "Extended Test-Set Performance — Random Forest")| Extended Test-Set Performance — Random Forest | |
| metric | value |
|---|---|
| sensitivity | 0.1224 |
| specificity | 0.9250 |
| f_meas | 0.0938 |
| accuracy | 0.8865 |
| roc_auc | 0.7513 |
test_preds %>%
conf_mat(truth = stroke, estimate = .pred_class) %>%
autoplot(type = "heatmap") +
scale_fill_gradient(low = "#FFFFFF", high = "#1565C0") +
labs(title = "Confusion Matrix — Random Forest (Test Set)")test_preds %>%
roc_curve(truth = stroke, .pred_stroke, event_level = "first") %>%
autoplot() +
labs(
title = "ROC Curve — Random Forest (Test Set)",
subtitle = paste0("AUC = ", round(
roc_auc(test_preds, truth = stroke,
.pred_stroke, event_level = "first")$.estimate, 4))
) +
theme_minimal(base_size = 13)final_rf_wf %>%
fit(data = stroke_train) %>%
extract_fit_parsnip() %>%
vip::vi() %>%
slice_max(Importance, n = 15) %>%
ggplot(aes(x = reorder(Variable, Importance), y = Importance)) +
geom_col(fill = "#1565C0") +
coord_flip() +
labs(
title = "Top 15 Variable Importances — Random Forest",
x = NULL, y = "Mean Decrease in Impurity"
) +
theme_minimal(base_size = 12)final_model_fit <- fit(final_rf_wf, data = stroke_clean)
v <- vetiver_model(
model = final_model_fit,
model_name = "stroke_prediction_rf",
description = "Random Forest model to predict stroke probability"
)
print(v)##
## ── stroke_prediction_rf ─ <bundled_workflow> model for deployment
## Random Forest model to predict stroke probability using 10 features
model_board <- board_folder("model_board", versioned = FALSE)
vetiver_pin_write(model_board, v)
vetiver_write_plumber(model_board, "stroke_prediction_rf", file = "plumber.R")
cat("plumber.R has been written.\n")## plumber.R has been written.
## To start the API locally, run:
## pr <- plumber::plumb("plumber.R")
## pr$run(port = 8000)
## Then test predictions with:
## library(httr2)
## req <- request("http://127.0.0.1:8000/predict") |>
## req_body_json(list(
## gender = "Male", age = 67, hypertension = "0",
## heart_disease = "1", ever_married = "Yes",
## work_type = "Private", Residence_type = "Urban",
## avg_glucose_level = 228.69, bmi = 36.6,
## smoking_status = "formerly smoked"
## )) |>
## req_perform()
## resp_body_json(req)
library(shiny)
ui <- fluidPage(
titlePanel("Stroke Risk Prediction"),
sidebarLayout(
sidebarPanel(
numericInput("age", "Age", value = 50, min = 1, max = 110),
numericInput("bmi", "BMI", value = 25, min = 10, max = 60),
numericInput("glucose", "Avg Glucose Level", value = 100, min = 40, max = 300),
selectInput("gender", "Gender", choices = c("Male", "Female")),
selectInput("hypertension","Hypertension", choices = c("0", "1")),
selectInput("heart_disease","Heart Disease", choices = c("0", "1")),
selectInput("ever_married","Ever Married", choices = c("Yes", "No")),
selectInput("work_type", "Work Type",
choices = c("Private","Self-employed","Govt_job","children","Never_worked")),
selectInput("Residence_type","Residence Type", choices = c("Urban","Rural")),
selectInput("smoking_status","Smoking Status",
choices = c("never smoked","formerly smoked","smokes","Unknown")),
actionButton("predict_btn", "Predict", class = "btn-primary")
),
mainPanel(
h3("Stroke Probability"),
verbatimTextOutput("result"),
plotOutput("gauge_plot")
)
)
)
server <- function(input, output, session) {
observeEvent(input$predict_btn, {
new_patient <- tibble(
gender = input$gender,
age = input$age,
hypertension = input$hypertension,
heart_disease = input$heart_disease,
ever_married = input$ever_married,
work_type = input$work_type,
Residence_type = input$Residence_type,
avg_glucose_level = input$glucose,
bmi = input$bmi,
smoking_status = input$smoking_status
) %>%
mutate(
hypertension = factor(hypertension, levels = levels(stroke_clean$hypertension)),
heart_disease = factor(heart_disease, levels = levels(stroke_clean$heart_disease)),
gender = factor(gender, levels = levels(stroke_clean$gender)),
ever_married = factor(ever_married, levels = levels(stroke_clean$ever_married)),
work_type = factor(work_type, levels = levels(stroke_clean$work_type)),
Residence_type = factor(Residence_type, levels = levels(stroke_clean$Residence_type)),
smoking_status = factor(smoking_status, levels = levels(stroke_clean$smoking_status))
)
prob <- predict(final_model_fit, new_patient, type = "prob")$.pred_stroke
risk_label <- if_else(prob > 0.5, "HIGH RISK", "LOW RISK")
output$result <- renderText({
paste0(
"Stroke Probability: ", round(prob * 100, 1), "%\n",
"Risk Assessment: ", risk_label
)
})
output$gauge_plot <- renderPlot({
ggplot(tibble(x = "Risk", y = prob), aes(x = x, y = y, fill = y)) +
geom_col(width = 0.4) +
scale_y_continuous(limits = c(0, 1), labels = scales::percent_format()) +
scale_fill_gradient(low = "#4CAF50", high = "#F44336", limits = c(0, 1)) +
labs(title = "Stroke Probability", x = NULL, y = NULL) +
theme_minimal(base_size = 14) +
theme(legend.position = "none")
})
})
}
shinyApp(ui, server)This analysis built and evaluated three classification models to predict stroke risk from clinical and demographic patient data. Key steps included:
Data Exploration: The dataset contained 5,110 patient records with 11 features. Significant class imbalance was identified — only ~4.9% of patients experienced a stroke, making standard accuracy misleading as a sole metric.
Data Preprocessing:
Model Building: Three classification models were trained and tuned via 5-fold stratified cross-validation:
Model Evaluation: Models were compared by cross-validation ROC AUC, then the best model was evaluated on the held-out test set using accuracy, sensitivity, specificity, F1, and ROC AUC.