This RMarkdown file contains the report of the data analysis done for the project on building and deploying a stroke prediction model in R. It contains analysis such as data exploration, summary statistics and building the prediction models. The final report was completed on Tue Nov 5 05:49:00 2024.
Data Description:
According to the World Health Organization (WHO), stroke is the second leading cause of death globally, responsible for approximately 11% of total deaths.
This data set is used to predict whether a patient is likely to have a stroke based on input parameters such as gender, age, various diseases, and smoking status. Each row in the data provides relevant information about one patient.
# Install Package
# Install only the packages that are not already present, so re-running the
# report does not download and rebuild everything each time.
# "gridExtra" is attached further down for the plot grid. "themis" is what
# recent caret versions use for trainControl(sampling = "smote")
# (NOTE(review): older caret releases used DMwR instead - confirm against the
# installed caret version).
required_packages <- c(
  "pROC", "tidyverse", "caret", "corrplot", "mice", "randomForest",
  "gridExtra", "themis"
)
missing_packages <- setdiff(required_packages, rownames(installed.packages()))
if (length(missing_packages) > 0) {
  install.packages(missing_packages)
}
# Load required libraries
library(tidyverse)
## ── Attaching core tidyverse packages ──────────────────────── tidyverse 2.0.0 ──
## ✔ dplyr 1.1.2 ✔ readr 2.1.4
## ✔ forcats 1.0.0 ✔ stringr 1.5.0
## ✔ ggplot2 3.4.2 ✔ tibble 3.2.1
## ✔ lubridate 1.9.2 ✔ tidyr 1.3.0
## ✔ purrr 1.0.1
## ── Conflicts ────────────────────────────────────────── tidyverse_conflicts() ──
## ✖ dplyr::filter() masks stats::filter()
## ✖ dplyr::lag() masks stats::lag()
## ℹ Use the conflicted package (<http://conflicted.r-lib.org/>) to force all conflicts to become errors
library(caret)
## Loading required package: lattice
##
## Attaching package: 'caret'
##
## The following object is masked from 'package:purrr':
##
## lift
library(dplyr)
library(tidyr)
library(ggplot2)
library(corrplot)
## corrplot 0.92 loaded
library(mice)
##
## Attaching package: 'mice'
##
## The following object is masked from 'package:stats':
##
## filter
##
## The following objects are masked from 'package:base':
##
## cbind, rbind
library(randomForest)
## randomForest 4.7-1.1
## Type rfNews() to see new features/changes/bug fixes.
##
## Attaching package: 'randomForest'
##
## The following object is masked from 'package:dplyr':
##
## combine
##
## The following object is masked from 'package:ggplot2':
##
## margin
library(pROC)
## Type 'citation("pROC")' for a citation.
##
## Attaching package: 'pROC'
##
## The following objects are masked from 'package:stats':
##
## cov, smooth, var
getwd()
## [1] "/home/rstudio/Build-deploy-stroke-prediction-model-R"
# Read the dataset
# Abort immediately when the file is missing: every later step needs
# `stroke_data`, so carrying on would only fail further down with a less
# helpful "object 'stroke_data' not found" error.
data_path <- "healthcare-dataset-stroke-data.csv"
if (!file.exists(data_path)) {
  stop("File not found in current directory: ", data_path, call. = FALSE)
}
stroke_data <- read.csv(data_path)
print("Data loaded successfully")
## [1] "Data loaded successfully"
# Display the first few rows and basic information
head(stroke_data)
## id gender age hypertension heart_disease ever_married work_type
## 1 9046 Male 67 0 1 Yes Private
## 2 51676 Female 61 0 0 Yes Self-employed
## 3 31112 Male 80 0 1 Yes Private
## 4 60182 Female 49 0 0 Yes Private
## 5 1665 Female 79 1 0 Yes Self-employed
## 6 56669 Male 81 0 0 Yes Private
## Residence_type avg_glucose_level bmi smoking_status stroke
## 1 Urban 228.69 36.6 formerly smoked 1
## 2 Rural 202.21 N/A never smoked 1
## 3 Rural 105.92 32.5 never smoked 1
## 4 Urban 171.23 34.4 smokes 1
## 5 Rural 174.12 24 never smoked 1
## 6 Urban 186.21 29 formerly smoked 1
str(stroke_data)
## 'data.frame': 5110 obs. of 12 variables:
## $ id : int 9046 51676 31112 60182 1665 56669 53882 10434 27419 60491 ...
## $ gender : chr "Male" "Female" "Male" "Female" ...
## $ age : num 67 61 80 49 79 81 74 69 59 78 ...
## $ hypertension : int 0 0 0 0 1 0 1 0 0 0 ...
## $ heart_disease : int 1 0 1 0 0 0 1 0 0 0 ...
## $ ever_married : chr "Yes" "Yes" "Yes" "Yes" ...
## $ work_type : chr "Private" "Self-employed" "Private" "Private" ...
## $ Residence_type : chr "Urban" "Rural" "Rural" "Urban" ...
## $ avg_glucose_level: num 229 202 106 171 174 ...
## $ bmi : chr "36.6" "N/A" "32.5" "34.4" ...
## $ smoking_status : chr "formerly smoked" "never smoked" "never smoked" "smokes" ...
## $ stroke : int 1 1 1 1 1 1 1 1 1 1 ...
summary(stroke_data)
## id gender age hypertension
## Min. : 67 Length:5110 Min. : 0.08 Min. :0.00000
## 1st Qu.:17741 Class :character 1st Qu.:25.00 1st Qu.:0.00000
## Median :36932 Mode :character Median :45.00 Median :0.00000
## Mean :36518 Mean :43.23 Mean :0.09746
## 3rd Qu.:54682 3rd Qu.:61.00 3rd Qu.:0.00000
## Max. :72940 Max. :82.00 Max. :1.00000
## heart_disease ever_married work_type Residence_type
## Min. :0.00000 Length:5110 Length:5110 Length:5110
## 1st Qu.:0.00000 Class :character Class :character Class :character
## Median :0.00000 Mode :character Mode :character Mode :character
## Mean :0.05401
## 3rd Qu.:0.00000
## Max. :1.00000
## avg_glucose_level bmi smoking_status stroke
## Min. : 55.12 Length:5110 Length:5110 Min. :0.00000
## 1st Qu.: 77.25 Class :character Class :character 1st Qu.:0.00000
## Median : 91.89 Mode :character Mode :character Median :0.00000
## Mean :106.15 Mean :0.04873
## 3rd Qu.:114.09 3rd Qu.:0.00000
## Max. :271.74 Max. :1.00000
# Basic data exploration
# Convert categorical variables (and the binary outcome) to factors
factor_columns <- c(
  "gender", "ever_married", "work_type",
  "Residence_type", "smoking_status", "stroke"
)
stroke_data[factor_columns] <- lapply(stroke_data[factor_columns], as.factor)

# BMI is read as text because missing entries are coded as the string "N/A".
# Recode those to NA explicitly before converting, so that any coercion
# warning that still appears points at a genuinely malformed value rather
# than at the expected missing-value marker.
bmi_raw <- as.character(stroke_data$bmi)
bmi_raw[bmi_raw %in% "N/A"] <- NA_character_
stroke_data$bmi <- as.numeric(bmi_raw)
# Get distribution of stroke cases
# One row per outcome level: number of patients and share of all patients.
n_patients <- nrow(stroke_data)
stroke_distribution <- stroke_data %>%
  group_by(stroke) %>%
  summarise(
    count = n(),
    percentage = n() / n_patients * 100,
    .groups = "drop"
  )
print("Distribution of Stroke Cases:")
## [1] "Distribution of Stroke Cases:"
print(stroke_distribution)
## # A tibble: 2 × 3
## stroke count percentage
## <fct> <int> <dbl>
## 1 0 4861 95.1
## 2 1 249 4.87
# Calculate summary statistics by stroke status
# Means of the numeric predictors and percentage rates of the two binary
# conditions, computed separately for each outcome level. BMI still contains
# NAs at this point, hence na.rm = TRUE for that column only.
patients_by_stroke <- group_by(stroke_data, stroke)
stroke_summary <- summarise(
  patients_by_stroke,
  avg_age = mean(age),
  avg_glucose = mean(avg_glucose_level),
  avg_bmi = mean(bmi, na.rm = TRUE),
  hypertension_rate = 100 * mean(hypertension),
  heart_disease_rate = 100 * mean(heart_disease),
  .groups = "drop"
)
print("\nSummary Statistics by Stroke Status:")
## [1] "\nSummary Statistics by Stroke Status:"
print(stroke_summary)
## # A tibble: 2 × 6
## stroke avg_age avg_glucose avg_bmi hypertension_rate heart_disease_rate
## <fct> <dbl> <dbl> <dbl> <dbl> <dbl>
## 1 0 42.0 105. 28.8 8.89 4.71
## 2 1 67.7 133. 30.5 26.5 18.9
# Create age groups for better visualization
# cut() uses right-closed intervals by default, so the bins are
# (0, 20], (20, 40], (40, 60], (60, 80] and (80, 100].
age_breaks <- c(0, 20, 40, 60, 80, 100)
age_labels <- c("0-20", "21-40", "41-60", "61-80", "80+")
stroke_data[["age_group"]] <- cut(
  stroke_data[["age"]],
  breaks = age_breaks,
  labels = age_labels
)
# Age distribution analysis
# Count patients per (age group, outcome) pair, then spread the outcome
# levels into the columns stroke_0 / stroke_1.
age_distribution <- stroke_data %>%
  group_by(age_group, stroke) %>%
  tally(name = "count") %>%
  ungroup() %>%
  pivot_wider(
    names_from = stroke,
    names_prefix = "stroke_",
    values_from = count
  )
print("\nAge Distribution by Stroke Status:")
## [1] "\nAge Distribution by Stroke Status:"
print(age_distribution)
## # A tibble: 5 × 3
## age_group stroke_0 stroke_1
## <fct> <int> <int>
## 1 0-20 1023 2
## 2 21-40 1213 6
## 3 41-60 1498 64
## 4 61-80 1034 154
## 5 80+ 93 23
# Check missing values
# Named vector with the number of NA entries in each column.
missing_values <- vapply(
  stroke_data,
  function(column) sum(is.na(column)),
  numeric(1)
)
print("\nMissing Values in Each Column:")
## [1] "\nMissing Values in Each Column:"
print(missing_values)
## id gender age hypertension
## 0 0 0 0
## heart_disease ever_married work_type Residence_type
## 0 0 0 0
## avg_glucose_level bmi smoking_status stroke
## 0 201 0 0
## age_group
## 0
# Create multiple visualizations
library(gridExtra)
##
## Attaching package: 'gridExtra'
## The following object is masked from 'package:randomForest':
##
## combine
## The following object is masked from 'package:dplyr':
##
## combine
# 1. Age Distribution Plot
# Overlaid (not stacked) histograms of age, one per outcome level.
age_plot <- ggplot(data = stroke_data, mapping = aes(x = age, fill = stroke)) +
  geom_histogram(position = "identity", bins = 30, alpha = 0.6) +
  ggtitle("Age Distribution by Stroke Status") +
  xlab("Age") +
  ylab("Count") +
  theme_minimal()

# 2. BMI vs Glucose Level Plot
# Rows with a missing BMI cannot be drawn and are dropped by geom_point().
bmi_glucose_plot <- ggplot(
  data = stroke_data,
  mapping = aes(x = bmi, y = avg_glucose_level, color = stroke)
) +
  geom_point(alpha = 0.5) +
  ggtitle("BMI vs Average Glucose Level") +
  xlab("BMI") +
  ylab("Average Glucose Level") +
  theme_minimal()

# 3. Stroke Distribution by Gender
# position = "fill" rescales each bar to 1 so bars show proportions.
gender_plot <- ggplot(data = stroke_data, mapping = aes(x = gender, fill = stroke)) +
  geom_bar(position = "fill") +
  ggtitle("Stroke Distribution by Gender") +
  xlab("Gender") +
  ylab("Proportion") +
  theme_minimal()
# 4. Medical Conditions Impact
# Reshape the two binary condition columns to long form: one row per
# (patient, condition) with the 0/1 flag in `status`.
# pivot_longer() replaces the superseded gather(); row and column order of
# `medical_data` differ from gather()'s output, which does not affect the plot.
medical_data <- stroke_data %>%
  pivot_longer(
    cols = c(hypertension, heart_disease),
    names_to = "condition",
    values_to = "status"
  ) %>%
  mutate(status = as.factor(status))

# The bars must be split by `status`; mapping only `condition` to x would give
# two identical bars (each just the overall stroke proportion), because every
# patient contributes one row to each condition regardless of their flag.
medical_plot <- ggplot(medical_data, aes(x = status, fill = stroke)) +
  geom_bar(position = "fill") +
  facet_wrap(~ condition) +
  labs(
    title = "Stroke Distribution by Medical Conditions",
    x = "Condition present (0 = no, 1 = yes)",
    y = "Proportion"
  ) +
  theme_minimal()
# Arrange all plots in a grid
grid.arrange(age_plot, bmi_glucose_plot, gender_plot, medical_plot, ncol = 2)
## Warning: Removed 201 rows containing missing values (`geom_point()`).
# Create summary statistics table
# All values end up as character because numbers and "min - max" strings are
# combined into one vector.
age_range <- range(stroke_data$age)
glucose_range <- round(range(stroke_data$avg_glucose_level), 2)
summary_values <- c(
  "Total Observations" = nrow(stroke_data),
  "Missing BMI Values" = sum(is.na(stroke_data$bmi)),
  "Age Range" = paste(age_range[1], "-", age_range[2]),
  "Average Glucose Level Range" = paste(glucose_range[1], "-", glucose_range[2])
)
summary_stats <- data.frame(
  Metric = names(summary_values),
  Value = unname(summary_values)
)
print("Summary Statistics:")
## [1] "Summary Statistics:"
print(summary_stats)
## Metric Value
## 1 Total Observations 5110
## 2 Missing BMI Values 201
## 3 Age Range 0.08 - 82
## 4 Average Glucose Level Range 55.12 - 271.74
# 1. First, let's handle missing values in BMI
# Work on a copy so the raw `stroke_data` stays untouched.
# Impute missing BMI values using the median of the patient's
# (age group, gender) cell. BMI was already converted to numeric during the
# type-conversion step, so no further coercion is needed here.
stroke_processed <- stroke_data %>%
  group_by(age_group, gender) %>%
  mutate(bmi = coalesce(bmi, median(bmi, na.rm = TRUE))) %>%
  ungroup() %>%
  # Fallback: a cell with no observed BMI at all has an NA median, which would
  # leave its rows unimputed; fill any such leftovers with the overall median.
  mutate(bmi = coalesce(bmi, median(bmi, na.rm = TRUE)))
# Check if any missing values remain
print("Remaining missing values after imputation:")
## [1] "Remaining missing values after imputation:"
print(colSums(is.na(stroke_processed)))
## id gender age hypertension
## 0 0 0 0
## heart_disease ever_married work_type Residence_type
## 0 0 0 0
## avg_glucose_level bmi smoking_status stroke
## 0 0 0 0
## age_group
## 0
# 2. Create derived features
# case_when() takes the first matching branch, so each upper bound implies
# the lower bound of the branch before it. The final branch is written as an
# explicit comparison (not TRUE) so that NA inputs stay NA.
stroke_processed <- stroke_processed %>%
  mutate(
    # BMI Category
    bmi_category = case_when(
      bmi < 18.5 ~ "Underweight",
      bmi < 25 ~ "Normal",
      bmi < 30 ~ "Overweight",
      bmi >= 30 ~ "Obese"
    ),
    # Glucose Category
    glucose_category = case_when(
      avg_glucose_level < 70 ~ "Low",
      avg_glucose_level < 100 ~ "Normal",
      avg_glucose_level < 126 ~ "Pre-diabetes",
      avg_glucose_level >= 126 ~ "Diabetes"
    ),
    # Combined health risk score: number of the two conditions present (0-2)
    health_risk_score = hypertension + heart_disease,
    # Age categories: a finer split than the 20-year `age_group` bins
    age_category = case_when(
      age < 13 ~ "Child",
      age < 20 ~ "Teen",
      age < 40 ~ "Young Adult",
      age < 60 ~ "Middle Aged",
      age >= 60 ~ "Senior"
    )
  )
# 3. Convert categorical variables to factors
categorical_vars <- c("bmi_category", "glucose_category", "age_category")
stroke_processed <- stroke_processed %>%
  mutate(across(all_of(categorical_vars), as.factor))

# 4. Create interaction features
# interaction() yields one factor level per combination of its inputs.
stroke_processed <- mutate(
  stroke_processed,
  hypertension_heart = interaction(hypertension, heart_disease),
  age_hypertension = interaction(age_category, hypertension)
)
# 5. Print summary of new features
print("\nSummary of new derived features:")
## [1] "\nSummary of new derived features:"
summary(stroke_processed[c("bmi_category", "glucose_category", "health_risk_score", "age_category")])
## bmi_category glucose_category health_risk_score age_category
## Normal :1268 Diabetes : 981 Min. :0.0000 Child : 588
## Obese :1951 Low : 754 1st Qu.:0.0000 Middle Aged:1564
## Overweight :1554 Normal :2377 Median :0.0000 Senior :1376
## Underweight: 337 Pre-diabetes: 998 Mean :0.1515 Teen : 378
## 3rd Qu.:0.0000 Young Adult:1204
## Max. :2.0000
# 6. Verify the structure of processed dataset
str(stroke_processed)
## tibble [5,110 × 19] (S3: tbl_df/tbl/data.frame)
## $ id : int [1:5110] 9046 51676 31112 60182 1665 56669 53882 10434 27419 60491 ...
## $ gender : Factor w/ 3 levels "Female","Male",..: 2 1 2 1 1 2 2 1 1 1 ...
## $ age : num [1:5110] 67 61 80 49 79 81 74 69 59 78 ...
## $ hypertension : int [1:5110] 0 0 0 0 1 0 1 0 0 0 ...
## $ heart_disease : int [1:5110] 1 0 1 0 0 0 1 0 0 0 ...
## $ ever_married : Factor w/ 2 levels "No","Yes": 2 2 2 2 2 2 2 1 2 2 ...
## $ work_type : Factor w/ 5 levels "children","Govt_job",..: 4 5 4 4 5 4 4 4 4 4 ...
## $ Residence_type : Factor w/ 2 levels "Rural","Urban": 2 1 1 2 1 2 1 2 1 2 ...
## $ avg_glucose_level : num [1:5110] 229 202 106 171 174 ...
## $ bmi : num [1:5110] 36.6 29.2 32.5 34.4 24 29 27.4 22.8 29.8 24.2 ...
## $ smoking_status : Factor w/ 4 levels "formerly smoked",..: 1 2 2 3 2 1 2 2 4 4 ...
## $ stroke : Factor w/ 2 levels "0","1": 2 2 2 2 2 2 2 2 2 2 ...
## $ age_group : Factor w/ 5 levels "0-20","21-40",..: 4 4 4 3 4 5 4 4 3 4 ...
## $ bmi_category : Factor w/ 4 levels "Normal","Obese",..: 2 3 2 2 1 3 3 1 3 1 ...
## $ glucose_category : Factor w/ 4 levels "Diabetes","Low",..: 1 1 4 1 1 1 3 3 3 2 ...
## $ health_risk_score : int [1:5110] 1 0 1 0 1 0 2 0 0 0 ...
## $ age_category : Factor w/ 5 levels "Child","Middle Aged",..: 3 3 3 2 3 3 3 3 2 3 ...
## $ hypertension_heart: Factor w/ 4 levels "0.0","1.0","0.1",..: 3 1 3 1 2 1 4 1 1 1 ...
## $ age_hypertension : Factor w/ 10 levels "Child.0","Middle Aged.0",..: 3 3 3 2 8 3 8 3 2 3 ...
# 7. Create a correlation matrix for numeric variables
# where(is.numeric) replaces the superseded select_if(); the ID column is
# dropped because it is an identifier, not a measurement.
numeric_vars <- stroke_processed %>%
  select(where(is.numeric), -id)
# "complete.obs" restricts the computation to rows without any NA.
correlation_matrix <- cor(numeric_vars, use = "complete.obs")
print("\nCorrelation matrix of numeric variables:")
## [1] "\nCorrelation matrix of numeric variables:"
print(round(correlation_matrix, 2))
## age hypertension heart_disease avg_glucose_level bmi
## age 1.00 0.28 0.26 0.24 0.33
## hypertension 0.28 1.00 0.11 0.17 0.16
## heart_disease 0.26 0.11 1.00 0.16 0.04
## avg_glucose_level 0.24 0.17 0.16 1.00 0.17
## bmi 0.33 0.16 0.04 0.17 1.00
## health_risk_score 0.36 0.82 0.66 0.23 0.15
## health_risk_score
## age 0.36
## hypertension 0.82
## heart_disease 0.66
## avg_glucose_level 0.23
## bmi 0.15
## health_risk_score 1.00
# 8. Build the modelling dataset
# Note: this only creates the in-memory object `processed_data`; nothing is
# written to disk at this step.
processed_data <- stroke_processed %>%
select(-id) # Remove ID column as it's not needed for modeling
# Display the first few rows of the processed dataset
print("\nFirst few rows of processed dataset:")
## [1] "\nFirst few rows of processed dataset:"
head(processed_data)
## # A tibble: 6 × 18
## gender age hypertension heart_disease ever_married work_type Residence_type
## <fct> <dbl> <int> <int> <fct> <fct> <fct>
## 1 Male 67 0 1 Yes Private Urban
## 2 Female 61 0 0 Yes Self-empl… Rural
## 3 Male 80 0 1 Yes Private Rural
## 4 Female 49 0 0 Yes Private Urban
## 5 Female 79 1 0 Yes Self-empl… Rural
## 6 Male 81 0 0 Yes Private Urban
## # ℹ 11 more variables: avg_glucose_level <dbl>, bmi <dbl>,
## # smoking_status <fct>, stroke <fct>, age_group <fct>, bmi_category <fct>,
## # glucose_category <fct>, health_risk_score <int>, age_category <fct>,
## # hypertension_heart <fct>, age_hypertension <fct>
# 1. First, let's check and clean the stroke variable
print("Initial stroke value counts:")
## [1] "Initial stroke value counts:"
table(stroke_data$stroke)
##
## 0 1
## 4861 249
# 2. Clean and prepare the data
# Start from `stroke_processed` (BMI already imputed) rather than the raw
# `stroke_data`. Dropping incomplete rows from the raw data throws away every
# patient with a missing BMI - 201 rows, including 40 of the 249 stroke cases -
# which is costly for an outcome this rare and ignores the imputation above.
# Printed counts and metrics produced by an earlier complete-case run will
# change once this is re-run.
stroke_clean <- stroke_processed %>%
  # Plain data.frame: createDataPartition() returns a one-column matrix of
  # row indices, which tibbles do not accept as a row subscript.
  as.data.frame() %>%
  # Safety net: remove any rows that still contain NA values
  na.omit() %>%
  # Ensure stroke is binary (0 or 1)
  filter(stroke %in% c("0", "1")) %>%
  # Convert stroke to a factor whose labels are valid R names, as required by
  # caret when classProbs = TRUE
  mutate(
    stroke = factor(
      stroke,
      levels = c("0", "1"),
      labels = c("No_Stroke", "Stroke")
    )
  )
# 3. Verify the cleaning
print("\nCleaned stroke value counts:")
## [1] "\nCleaned stroke value counts:"
table(stroke_clean$stroke)
##
## No_Stroke Stroke
## 4700 209
# 4. Create the model with clean data
# Split the data: stratified 80/20 split on the outcome, seeded so the
# partition (and the resampling below) is reproducible.
set.seed(123)
split_index <- createDataPartition(stroke_clean$stroke, p = 0.8, list = FALSE)
train_data <- stroke_clean[split_index, ]
test_data <- stroke_clean[-split_index, ]

# Set up cross-validation: 5-fold CV scored with twoClassSummary (which needs
# class probabilities), with SMOTE applied as the sampling method.
ctrl <- trainControl(
  method = "cv",
  number = 5,
  classProbs = TRUE,
  summaryFunction = twoClassSummary,
  sampling = "smote"
)

# Select features for modeling
features <- c(
  "age", "gender", "hypertension", "heart_disease",
  "ever_married", "work_type", "Residence_type",
  "avg_glucose_level", "bmi", "smoking_status"
)

# Train the model
# Build the formula from `features` so the predictor list lives in one place
# instead of being typed a second time as a string.
model_formula <- reformulate(features, response = "stroke")
rf_model <- train(
  model_formula,
  data = train_data,
  method = "rf",
  metric = "ROC",
  trControl = ctrl,
  na.action = na.omit
)
## Loading required package: recipes
##
## Attaching package: 'recipes'
## The following object is masked from 'package:stringr':
##
## fixed
## The following object is masked from 'package:stats':
##
## step
# Make predictions
# Hard class labels and class probabilities for the held-out test set.
rf_pred <- predict(rf_model, test_data)
rf_pred_prob <- predict(rf_model, test_data, type = "prob")

# Calculate performance metrics
# confusionMatrix() treats the first factor level ("No_Stroke") as the
# positive class unless told otherwise, which makes "Sensitivity" describe the
# majority class. Name "Stroke" explicitly so sensitivity, PPV etc. refer to
# the event being predicted.
conf_matrix <- confusionMatrix(
  data = rf_pred,
  reference = test_data$stroke,
  positive = "Stroke"
)
# Print results
print("\nModel Performance:")
## [1] "\nModel Performance:"
print(conf_matrix)
## Confusion Matrix and Statistics
##
## Reference
## Prediction No_Stroke Stroke
## No_Stroke 935 41
## Stroke 5 0
##
## Accuracy : 0.9531
## 95% CI : (0.9379, 0.9655)
## No Information Rate : 0.9582
## P-Value [Acc > NIR] : 0.8115
##
## Kappa : -0.0092
##
## Mcnemar's Test P-Value : 2.463e-07
##
## Sensitivity : 0.9947
## Specificity : 0.0000
## Pos Pred Value : 0.9580
## Neg Pred Value : 0.0000
## Prevalence : 0.9582
## Detection Rate : 0.9531
## Detection Prevalence : 0.9949
## Balanced Accuracy : 0.4973
##
## 'Positive' Class : No_Stroke
##
# Calculate and plot ROC curve
# State the control/case levels and the comparison direction explicitly
# rather than relying on pROC's auto-detection: this matches what was being
# auto-selected (control = No_Stroke, case = Stroke, controls < cases) and
# guarantees the AUC cannot be silently flipped on other data.
roc_obj <- roc(
  response = test_data$stroke,
  predictor = rf_pred_prob[, "Stroke"],
  levels = c("No_Stroke", "Stroke"),
  direction = "<"
)
# print() shows escape sequences literally, so no leading "\n" in the label.
print(paste("AUC-ROC:", auc(roc_obj)))
## [1] "\nAUC-ROC: 0.856188375713544"
# Plot ROC curve
plot(roc_obj, main = "ROC Curve for Stroke Prediction")
# Feature importance
importance <- varImp(rf_model)
print("\nFeature Importance:")
## [1] "\nFeature Importance:"
print(importance)
## rf variable importance
##
## Overall
## age 100.00000
## hypertension 55.62782
## ever_marriedYes 43.38329
## genderMale 38.19801
## heart_disease 36.93665
## avg_glucose_level 36.05987
## smoking_statusUnknown 35.80096
## smoking_statusnever smoked 35.54423
## work_typePrivate 34.40080
## work_typeSelf-employed 33.31519
## Residence_typeUrban 28.82592
## smoking_statussmokes 20.35180
## work_typeGovt_job 18.61763
## bmi 18.36958
## work_typeNever_worked 0.07495
## genderOther 0.00000
# Save the model
saveRDS(rf_model, "stroke_prediction_model.rds")
# Enhance ROC curve visualization
plot(
  roc_obj,
  main = "ROC Curve for Stroke Prediction",
  col = "blue",
  lwd = 2,
  print.auc = TRUE,
  print.thres = TRUE,
  auc.polygon = TRUE,
  grid = TRUE,
  legacy.axes = TRUE
)

# Add confidence intervals
# ci() returns an interval for the AUC (three numbers); plotting that opens a
# new, unrelated plot. A band around the curve needs the bootstrap confidence
# interval of sensitivity at a grid of specificities, drawn as a shape on top
# of the existing ROC plot. Seeded because the interval is bootstrapped.
# NOTE(review): confirm the band aligns with the curve under
# legacy.axes = TRUE on the installed pROC version.
set.seed(123)
ci.roc <- ci.se(roc_obj, specificities = seq(0, 1, 0.05))
plot(ci.roc, type = "shape", col = "#1c61b6AA")

# Add legend
legend(
  "bottomright",
  legend = c(paste("AUC =", round(auc(roc_obj), 3))),
  col = "blue",
  lwd = 2
)
# Calculate and print optimal threshold
# coords(..., "best") picks the cut-off on `roc_obj`, which was built from the
# test-set predictions, so this value is tuned on the same data used for
# evaluation. It is only reported here: the deployment function further down
# classifies with its own 0.5 cut-off.
# NOTE(review): if several cut-offs tie for "best", more than one value may be
# returned and printed - confirm.
optimal_threshold <- coords(roc_obj, "best", ret = "threshold")
print(paste("Optimal threshold:", round(optimal_threshold, 3)))
## [1] "Optimal threshold: 0.095"
# 1. Create a deployment-ready model interface
#' Build the stroke-risk prediction function
#'
#' @param model_path Path to the RDS file holding the fitted caret model.
#'   Defaults to the file written by this script.
#' @param threshold Probability cut-off above which a patient is labelled
#'   "High Risk". Defaults to 0.5, the value that was previously hard-coded;
#'   pass a lower value (e.g. one derived from the ROC analysis) to trade
#'   specificity for sensitivity.
#' @return A function `predict_stroke_risk(new_data)` that returns a list with
#'   `risk_score` (probability of the "Stroke" class), `classification`
#'   ("High Risk" / "Low Risk") and `probability` (the full probability table).
deploy_stroke_model <- function(model_path = "stroke_prediction_model.rds",
                                threshold = 0.5) {
  # Load required libraries
  library(caret)
  library(randomForest)

  stopifnot(
    is.numeric(threshold),
    length(threshold) == 1L,
    !is.na(threshold),
    threshold >= 0,
    threshold <= 1
  )

  # Columns the model formula refers to, and the subset that is categorical.
  required_columns <- c(
    "age", "gender", "hypertension", "heart_disease",
    "ever_married", "work_type", "Residence_type",
    "avg_glucose_level", "bmi", "smoking_status"
  )
  factor_columns <- c(
    "gender", "ever_married", "work_type",
    "Residence_type", "smoking_status"
  )

  # Create a prediction function
  predict_stroke_risk <- function(new_data) {
    # Ensure input data has correct format: fail with a clear message instead
    # of an opaque error from inside predict() when a column is absent.
    missing_columns <- setdiff(required_columns, names(new_data))
    if (length(missing_columns) > 0) {
      stop(
        "new_data is missing required column(s): ",
        paste(missing_columns, collapse = ", "),
        call. = FALSE
      )
    }

    # Load the saved model (read on every call, so a re-saved model file is
    # picked up without rebuilding the prediction function)
    model <- readRDS(model_path)

    # Convert categorical inputs to factors
    new_data[factor_columns] <- lapply(new_data[factor_columns], as.factor)

    # Make prediction
    pred_prob <- predict(model, new_data, type = "prob")

    # Return risk score and classification
    risk_score <- pred_prob[, "Stroke"]
    classification <- ifelse(risk_score > threshold, "High Risk", "Low Risk")
    list(
      risk_score = risk_score,
      classification = classification,
      probability = pred_prob
    )
  }

  predict_stroke_risk
}
# 2. Example usage function
#' Score one hard-coded example patient and print the assessment
#'
#' @return The list produced by the prediction function (risk score,
#'   classification and probability table) for the example patient.
example_prediction <- function() {
  # Example patient: a single-row data frame with every model input
  example_patient <- data.frame(
    age = 65,
    gender = "Male",
    hypertension = 1,
    heart_disease = 0,
    ever_married = "Yes",
    work_type = "Private",
    Residence_type = "Urban",
    avg_glucose_level = 120,
    bmi = 28,
    smoking_status = "formerly smoked"
  )

  # Build the prediction function and score the patient
  score_patient <- deploy_stroke_model()
  assessment <- score_patient(example_patient)

  # Report the outcome on the console
  cat("Stroke Risk Assessment:\n")
  cat("Risk Score:", round(assessment$risk_score, 3), "\n")
  cat("Classification:", assessment$classification, "\n")

  assessment
}
# 3. Save deployment files
# The fitted model is written to the same path that the prediction function
# reads with readRDS(). The same saveRDS() call also appears earlier, right
# after the feature-importance step, so this rewrites that file.
saveRDS(rf_model, "stroke_prediction_model.rds")
# Only `deploy_stroke_model` is stored in the .RData file; `example_prediction`
# is not included.
save(deploy_stroke_model, file = "deploy_functions.RData")
# 4. Create documentation
# Fill the performance section from the objects computed above instead of
# leaving placeholders. Sensitivity and specificity are taken straight from
# the confusion table (rows = prediction, columns = reference) for the
# "Stroke" class, so they do not depend on which class confusionMatrix()
# treated as positive.
confusion_table <- conf_matrix$table
stroke_sensitivity <- confusion_table["Stroke", "Stroke"] /
  sum(confusion_table[, "Stroke"])
stroke_specificity <- confusion_table["No_Stroke", "No_Stroke"] /
  sum(confusion_table[, "No_Stroke"])

# Leading and trailing "" keep the blank first line and final newline of the
# text. The gender entry lists 'Other' because that level is present in the
# training data.
model_documentation <- paste(
  c(
    "",
    "Stroke Prediction Model Documentation",
    "Input Features Required:",
    "- age: numeric (years)",
    "- gender: factor ('Male', 'Female' or 'Other')",
    "- hypertension: binary (0 or 1)",
    "- heart_disease: binary (0 or 1)",
    "- ever_married: factor ('Yes' or 'No')",
    "- work_type: factor ('Private', 'Self-employed', 'Govt_job', 'children', 'Never_worked')",
    "- Residence_type: factor ('Urban' or 'Rural')",
    "- avg_glucose_level: numeric (mg/dL)",
    "- bmi: numeric",
    "- smoking_status: factor ('formerly smoked', 'never smoked', 'smokes', 'Unknown')",
    "Output:",
    "- risk_score: probability of stroke (0-1)",
    "- classification: 'High Risk' or 'Low Risk'",
    "- probability: full probability distribution",
    "Model Performance (held-out test set, 'Stroke' as the positive class):",
    paste("- ROC-AUC:", round(as.numeric(auc(roc_obj)), 3)),
    paste("- Sensitivity:", round(stroke_sensitivity, 3)),
    paste("- Specificity:", round(stroke_specificity, 3)),
    ""
  ),
  collapse = "\n"
)

# Save documentation
writeLines(model_documentation, "model_documentation.txt")
# 5. Test deployment
cat("Testing deployment with example patient...\n")
## Testing deployment with example patient...
test_result <- example_prediction()
## Stroke Risk Assessment:
## Risk Score: 0.334
## Classification: Low Risk
print(test_result)
## $risk_score
## [1] 0.334
##
## $classification
## [1] "Low Risk"
##
## $probability
## No_Stroke Stroke
## 1 0.666 0.334