2025 CHUKA UNIVERSITY-AMMNET WORKSHOP ON MALARIA MODELING
setwd("~/2025_CDAM_WORKSHOP_1")
library(rio) ## For easy data importation
SMdata = import("SMdata.csv") # Mock dataset(SMdata)
library(gganimate) ## Adds dynamic animations to your visualizations
# color point in the plot by Region
p1 = SMdata |> ggplot( aes(x = Total, y = Positive, color = Region)) +
geom_point(size=4) +
labs(title="Malaria Positive Cases vs Total by Region",
subtitle="Time: {frame_time}", # Display time as subtitle
x= "Total Cases",
y= "Number of Positive Cases",
caption = " Source: CDAM Experts, 2025") +
theme_bw() +
theme(plot.title = element_text(hjust = 0.5)) # Align the title to the center
# Add animation with gganimate
animated_plot <- p1 +
transition_time(Year) + # Animate over the Year variable
ease_aes('linear') # Smooth linear transitions
animated_plot
1 Introduction
The workshop is a hands-on training event designed to equip participants with skills in predicting malaria risk using AI and ML techniques. The training uses the R statistical environment for transparent and explainable predictive modeling of malaria data. Participants will gain practical experience in building, interpreting, and applying predictive models to support malaria control efforts in Kenya. The resulting models are intended to guide targeted intervention strategies and resource allocation by the government and malaria control partners in Kenya. The event also aims to create and strengthen a community among AMMnet Kenya members, researchers, policy makers and health professionals, and to foster collaboration across disciplines so that models are both scientifically robust and actionable for real-world public health challenges.
1.1 Objective of the Workshop:
The objective of this two-day workshop is to build capacity in malaria risk prediction in Kenya using Artificial Intelligence and Machine Learning, in order to improve early detection and intervention strategies.
1.2 Specific Objectives:
- Objective: Develop capacity on Malaria Risk Prediction using interpretable ML/AI models
- Significance: Improving early detection and intervention strategies
- Approach: Utilizing R for transparent and explainable predictive modeling
1.3 Target audience:
Malaria Modelers, Statisticians, Data Scientists, Public health researchers, Policy makers, Health workers
1.4 AMMnet’s mission of connecting people & exchanging knowledge
The workshop supports AMMnet’s mission by fostering a collaborative environment for sharing knowledge among researchers, health professionals, and policy makers. By engaging in practical training in AI/ML and R, participants will exchange knowledge, build clear predictive models, and apply these skills to strengthen malaria control methods, driving innovation and tangible solutions. New AMMnet members will gain hands-on experience with R, a widely used programming language in epidemiological modeling.
2 World Malaria 2024 Report
Malaria remains a significant public health threat globally. It is a life-threatening disease caused by parasites that are transmitted to people through the bites of infected female Anopheles mosquitoes. It is preventable and curable. Sub-Saharan Africa carries the heaviest malaria burden, accounting for an estimated 94% of the 263 million global cases and 95% of the 597,000 malaria-related deaths in 2023. In Africa, many people at risk still lack access to the services they need to prevent, detect and treat the disease. Significant strides have nevertheless been made towards elimination: Egypt and Cabo Verde were certified malaria-free by the WHO in 2024, bringing the total to 44 countries that have eliminated malaria (26 of them since 2000), and Georgia and Turkey have submitted requests for malaria-free certification (WHO World Malaria Report 2024).
2.1 Summary of WHO world Malaria 2024 Report
3 Data Sources (Synthetic data)
Complex interactions between:
- Patient-level factors
- Environmental factors
- Parasite information
3.1 Key Data Components
Patient Information
- Age (between 15 and 49)
- Gender (Male or Female)
- Symptoms (e.g., Fever, Chills, Headache, Nausea, Fatigue, Muscle_Aches)
- Malaria test result (Outcome: Positive/Negative)
Environmental Predictors
Temperature
Warm temperatures accelerate mosquito development and parasite replication.
In highland areas like Kericho and Nyamira, rising temperatures have been linked to increased malaria risk
Rainfall
Rain creates breeding sites for Anopheles mosquitoes.
In Kisumu (Lake Basin), malaria incidence rises with rainfall up to 160 mm, peaking after a 15-week lag.
In Ukunda (Coastal zone), rainfall up to 150 mm also correlates with increased cases, although temperature is negatively correlated with cases in the coastal zone.
Humidity
High humidity boosts mosquito survival and biting rates.
Coastal and lake regions, with consistently high humidity, show persistent transmission.
Endemic Zones
Lake Basin & Coastal zones: High transmission due to favorable climate and water bodies.
Highlands: Historically low risk, but warming trends are changing that.
Seasonal zones: Outbreaks align with rainy seasons.
Low-risk zones: Dry or elevated areas with minimal vector activity
NB: Malaria in Kenya is a climate-sensitive disease, and understanding these environmental links helps tailor interventions.
NB: Christiansen-Jucht et al. (2014) reported that temperature, rainfall and humidity influence the mosquito’s life cycle, growth and development. Vector larvae survive only when environmental conditions are conducive, and the rate at which mosquitoes bite humans increases when conditions favour their survival.
Reference: https://ij-healthgeographics.biomedcentral.com/articles/10.1186/s12942-024-00381-8
Parasite information
- Type of malaria parasite (Plasmodium species: vivax, falciparum, malariae, ovale)
- Mosquito Density
Target Variable
- Malaria test result (binary: positive/negative)
4 What is Machine Learning?
Machine Learning is a subset of Artificial Intelligence (AI) that allows computers to learn from data and make decisions or predictions without being explicitly programmed for every task
4.1 Supervised Learning
Definition: The model learns from labeled data — i.e., data with both input features and known outputs (Malaria test results).
Goal: Predict outcomes for new data based on learned patterns.
Common Tasks:
Classification (predict categories): e.g., spam detection, disease diagnosis
Regression (predict continuous values): e.g., house price prediction, temperature forecasting
Example: Training a model to predict malaria outcome (Positive/Negative) based on patient-level factors, environmental factors, parasite information, etc.
4.2 Unsupervised Learning
Definition: The model learns patterns and structures from unlabeled data — no known outputs provided.
Goal: Find hidden structure or groupings in data.
Common Tasks:
Clustering: group similar items (e.g., customer segmentation)
Dimensionality Reduction: simplify datasets (e.g., for visualization or preprocessing)
Example: Group malaria patient records into distinct types of symptom profiles without predefined labels.
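The clustering example can be sketched in base R as follows. This is only an illustration under stated assumptions: the symptom records are simulated here (they are not the workshop dataset), the two underlying groups and their symptom probabilities are invented, and k-means with two clusters is one possible choice of method.

```r
# Hypothetical symptom records (1 = present, 0 = absent); simulated, not the workshop data
set.seed(123)
n <- 60  # records per hypothetical group
draw <- function(p) as.numeric(rbinom(n, 1, p))
make_group <- function(p) {
  data.frame(Fever = draw(p[1]), Chills = draw(p[2]),
             Nausea = draw(p[3]), Fatigue = draw(p[4]))
}
symptoms <- rbind(make_group(c(0.9, 0.8, 0.2, 0.3)),   # mostly fever/chills
                  make_group(c(0.2, 0.1, 0.9, 0.8)))   # mostly nausea/fatigue

# k-means with k = 2; no outcome label is used at any point
km <- kmeans(symptoms, centers = 2, nstart = 25)

# Share of records with each symptom inside each discovered cluster
profiles <- aggregate(symptoms, by = list(cluster = km$cluster), FUN = mean)
print(round(profiles, 2))
```

Each row of `profiles` describes one discovered symptom profile; interpreting what a cluster means is left to the analyst.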
4.3 Reinforcement Learning (RL)
Definition: An agent learns by interacting with an environment and receiving rewards or penalties for actions.
Goal: Learn a sequence of actions that maximizes cumulative rewards over time.
Key Concepts:
Agent: the learner/decision-maker
Environment: where the agent operates
Actions: choices the agent can make
Reward: feedback signal for learning
Policy: strategy mapping states to actions
Example: Training a robot to navigate, or teaching an AI to play chess or a video game.
NB: Machine Learning offers the potential to analyze vast amounts of healthcare data, identify complex patterns, and generate predictions that can be more accurate and efficient than traditional diagnostic methods. However, these models, particularly deep learning approaches, are often criticized for being “black-box” systems, where the rationale behind a prediction is not easily understood by clinicians or patients.
5 What is an AI model?
An AI model is a program that has been trained on a set of data to recognize patterns and make decisions or predictions without additional human intervention. These models use algorithms to process data and accomplish the tasks they are designed for (https://www.ibm.com/think/topics/ai-model)
5.1 Breakdown of how AI models work:
Modeling: The first step is creating the model, which uses complex algorithms to analyze data and make decisions based on it
Training: The model is then “trained” by feeding it with large amounts of data. The more data the model processes, the more accurately it can perform its task
Inference: The final step is inference, where the trained model is used to make predictions or act on new, previously unseen data.
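The three steps above can be sketched with a small logistic regression in base R. The data, the coefficients used to simulate the labels, and the two new cases are all hypothetical and serve only to make the example runnable.

```r
set.seed(123)
n <- 200
train_df <- data.frame(Rainfall = runif(n, 50, 250),
                       Temperature = runif(n, 18, 32))
# Hypothetical rule used only to simulate labels for this sketch
p <- plogis(-6 + 0.02 * train_df$Rainfall + 0.1 * train_df$Temperature)
train_df$Positive <- rbinom(n, 1, p)

# 1. Modeling: choose the model form
model_formula <- Positive ~ Rainfall + Temperature

# 2. Training: estimate the parameters from labeled data
fit <- glm(model_formula, data = train_df, family = binomial)

# 3. Inference: predict for new, previously unseen cases
new_cases <- data.frame(Rainfall = c(60, 240), Temperature = c(20, 30))
pred <- predict(fit, newdata = new_cases, type = "response")
print(round(pred, 3))
```

The returned values are predicted probabilities of a positive result for the two new cases.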
5.2 Explainable AI (XAI)
Refers to a set of tools and techniques that make the behavior of machine learning models more interpretable and understandable to humans.
Goal: To clarify how and why AI models make certain predictions, ensuring that both experts and non-experts can trust and effectively use these insights in real-world applications.
Why do we need Explainable AI (XAI)?
In healthcare, the lack of transparency in machine learning models raises significant concerns. Healthcare professionals require models that not only predict accurately but also explain why a particular decision was made. This is particularly important in scenarios like malaria prediction, where medical decisions directly impact patients’ health outcomes. Without explanations, it becomes challenging for clinicians to trust the model’s recommendations or make informed decisions based on them.
Explainable AI (XAI) has emerged as a solution to this challenge, aiming to provide interpretability for complex machine learning algorithms. XAI techniques make the decision making process of models more transparent by providing understandable, human-readable explanations of how predictions are made. This is crucial in healthcare, where model interpretability can enhance clinicians’ trust in AI-driven predictions, improve patient outcomes through better communication, and ensure ethical accountability.
5.3 Importance of XAI in Malaria Prediction
XAI helps bridge the gap between complex AI technologies and actionable insights for policymakers, healthcare workers, and researchers, allowing them to use AI insights more confidently in planning and intervention.
In healthcare, the stakes are high: AI-driven systems are used to make critical decisions regarding diagnoses, treatment plans, and patient care. For healthcare professionals to trust these systems, they must be able to understand the reasoning behind the recommendations and predictions made by the AI. Key reasons for the importance of explainability in healthcare include:
Trust and Transparency
- Stakeholders (e.g., healthcare professionals and policymakers) are more likely to adopt models they understand.
- Clinicians and patients are more likely to trust AI-driven decisions if they can understand and verify the reasoning behind them.
Clinical Decision Support
- Explainable AI can enhance clinical decision-making by offering insights into why a certain diagnosis or treatment was suggested, allowing healthcare providers to incorporate AI recommendations into their decision-making process more effectively.
Actionable Insights
- Helps identify key factors driving malaria spread and outcomes, such as climatic conditions or patient demographics.
Regulatory Compliance
- Interpretability can meet requirements for explainability in healthcare systems.
- Healthcare institutions are subject to regulations and frameworks such as the GDPR and the NIST AI Risk Management Framework, which emphasize transparency and patient consent. AI must be able to justify its decisions, particularly when it comes to sensitive topics such as predicting the likelihood of developing a disease like malaria, to avoid reinforcing existing biases or discrimination.
GDPR: A comprehensive data privacy law in the European Union that gives individuals control over their personal data.
NIST AI Risk Management Framework: A voluntary framework developed by the U.S. National Institute of Standards and Technology to help organizations manage risks associated with AI systems.
Patient Safety
- Transparent models enable healthcare professionals to identify potential errors, biases, or flaws in the AI’s predictions, reducing the likelihood of harmful recommendations.
5.4 Basic Workflow with DALEX
DALEX is a powerful tool for interpreting machine learning models. It provides a unified framework for creating model explanations, making it easier to understand both the global and local behavior of machine learning models.
- Preparing the Malaria Dataset
- Before training models, ensure your data is clean and well-prepared.
- Training Machine Learning Models
- Use the caret package to train models such as logistic regression, SVM, random forest, etc.
- Model Interpretability Techniques
- Use the DALEX package for machine learning interpretations.
Feature Importance using iml (Interpretable Machine Learning)
Quantifies how much each feature contributes to the model predictions
Identify most critical predictors of malaria risk
Understand complex interactions between variables
SHAP (Shapley Additive Explanations)
SHAP has become one of the most widely used XAI methods due to its consistency and mathematical rigor. In malaria prediction models, SHAP can be used to quantify the contribution of each feature to the risk prediction, enabling a clear understanding of which factors are most influential.
Explains individual predictions using SHAP values, which are grounded in cooperative game theory.
Example: Shows how rainfall or temperature affects malaria risk predictions.
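To make the game-theory idea concrete, the sketch below computes exact Shapley values for one hypothetical patient by averaging each feature's marginal contribution over all orderings of three features. Everything here is an assumption for illustration: the data are simulated, the model is a plain `glm`, and the value function (average prediction with some features fixed to the patient's values) is one common choice. Dedicated packages use faster approximations.

```r
set.seed(123)
n <- 300
bg <- data.frame(Rainfall = runif(n, 50, 250),
                 Temperature = runif(n, 18, 32),
                 Humidity = runif(n, 40, 90))
# Hypothetical rule used only to simulate labels
bg$Positive <- rbinom(n, 1, plogis(-8 + 0.02 * bg$Rainfall +
                                     0.1 * bg$Temperature + 0.02 * bg$Humidity))
fit <- glm(Positive ~ Rainfall + Temperature + Humidity, data = bg, family = binomial)

features <- c("Rainfall", "Temperature", "Humidity")
x <- data.frame(Rainfall = 220, Temperature = 30, Humidity = 85)  # hypothetical patient

# v(S): average prediction when the features in S are fixed to the patient's values
value <- function(S) {
  d <- bg[, features]
  for (f in S) d[[f]] <- x[[f]]
  mean(predict(fit, newdata = d, type = "response"))
}

# All 3! = 6 orderings of the three features
orderings <- list(c(1, 2, 3), c(1, 3, 2), c(2, 1, 3),
                  c(2, 3, 1), c(3, 1, 2), c(3, 2, 1))
shap <- setNames(numeric(length(features)), features)
for (ord in orderings) {
  S <- character(0)
  for (j in ord) {
    f <- features[j]
    # Marginal contribution of f when added to the features already fixed
    shap[f] <- shap[f] + (value(c(S, f)) - value(S)) / length(orderings)
    S <- c(S, f)
  }
}

baseline <- value(character(0))                       # average prediction
prediction <- unname(predict(fit, newdata = x, type = "response"))
print(round(shap, 4))
print(c(baseline = baseline, prediction = prediction))
```

The Shapley values add up to the difference between the patient's prediction and the average prediction, which is the property that makes them attractive for explaining individual cases.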
Local Interpretable Model-Agnostic Explanations (LIME)
LIME can be applied to explain the reasons behind a specific prediction, such as why a particular patient is predicted to be at risk for malaria.
Focuses on explaining individual predictions.
Flexible enough to work with almost any supervised learning model
Generates plots to illustrate feature contributions
Example: Explains why a specific patient was classified as high risk
Partial Dependence Plots (PDP)
A PDP shows the relationship between a feature and the predicted outcome while averaging over the observed values of the other features. This can be particularly useful in malaria prediction models, as it allows clinicians to visualize how changes in a single factor (e.g., rainfall or temperature) affect the predicted likelihood of malaria while accounting for other variables.
Visualizes the relationship between features and the target prediction
Understand non-linear relationships
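A partial dependence curve can be computed by hand, which helps clarify what the plot shows. The sketch below is a minimal version under stated assumptions: the data are simulated, the model is a `glm`, and the rainfall grid is arbitrary. For each grid value, rainfall is set to that value in every row and the predictions are averaged.

```r
set.seed(123)
n <- 300
d <- data.frame(Rainfall = runif(n, 50, 250),
                Temperature = runif(n, 18, 32))
# Hypothetical rule used only to simulate labels
d$Positive <- rbinom(n, 1, plogis(-6 + 0.02 * d$Rainfall + 0.1 * d$Temperature))
fit <- glm(Positive ~ Rainfall + Temperature, data = d, family = binomial)

grid <- seq(50, 250, by = 25)   # rainfall values at which to evaluate the curve
pdp <- sapply(grid, function(r) {
  d_mod <- d
  d_mod$Rainfall <- r           # fix rainfall for every record
  mean(predict(fit, newdata = d_mod, type = "response"))  # average over the other features
})
print(data.frame(Rainfall = grid, Avg_Predicted_Prob = round(pdp, 3)))
```

Plotting `pdp` against `grid` (for example with `plot(grid, pdp, type = "b")`) gives the partial dependence plot for rainfall.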
5.5 Machine Learning Pipeline
6 Set a working directory
This is a default location where R looks for files and saves outputs
setwd("~/2025_CDAM_WORKSHOP_1")
7 Install and load necessary libraries
Loading libraries
library(caret) ## for training machine learning models
library(psych) ## for description of data
library(ggplot2) ## for data visualization
library(caretEnsemble) ## enables the creation of ensemble models
library(tidyverse) ## for data manipulation
library(mlbench) ## collection of benchmark datasets for machine learning
library(flextable) ## to create and style tables
library(mltools) ## helper functions for exploratory data analysis and ML data preparation
library(tictoc) ## for determining the time taken for a model to run
library(ROSE) ## for random oversampling
library(smotefamily) ## for smote sampling
library(ROCR) ## for ROC curve
library(pROC) ## for visualizing, smoothing, and comparing ROC curves
library(e1071) ## for statistical modeling and machine learning tasks(SVM)
library(class) ## for classification using k-Nearest Neighbors and other methods
library(caTools) ## for splitting data into training and testing sets
library(MASS) ## functions and datasets from "Modern Applied Statistics with S" (e.g., lda, qda)
library(ISLR) ## datasets from "An Introduction to Statistical Learning"
library(boot) ## useful for performing bootstrap resampling
library(cvTools) ## contains functions for cross-validation, bootstrapping, & other resampling methods
library(iml) ## provide tools to analyze and interpret machine learning models
library(lime) ## powerful tools for interpreting machine learning models
library(DALEX) ## powerful tool for interpreting machine learning models.
library(rio) ## for easy data import, export(saving) and conversion
library(esquisse) ## GUI tool that allows users to easily create ggplot2 plots interactively
8 Load and prepare data/Exploration of the dataset
library(rio) ## for easy data import, export(saving) and conversion
data = import("Malaria Dataset.csv")
#head(data) # for the 1st few rows in the dataset
#tail(data) # for the last few rows in the dataset
9 Exploratory Data Analysis (EDA)
Before we start visualizing our data, we need to understand the characteristics of our data. The goal is to get an idea of the data structure and to understand the relationships between variables.
Here are some functions that can help us understand the structure of our data:
#dim(data) # for dimensions of dataset
#summary(data) # for summary of descriptive statistics
#describe(data) # for descriptive statistics
9.1 Check for data structure
str(data)
'data.frame': 1000 obs. of 18 variables:
$ Age : int 25 49 44 40 24 48 26 27 26 33 ...
$ Gender : chr "Female" "Male" "Male" "Female" ...
$ Rainfall : num 218.4 62.3 230.6 145.8 250.8 ...
$ Temperature : num 26.1 22.9 31.5 20.5 31 26.9 29.3 29.3 25.8 27.2 ...
$ Humidity : num 78.4 54.4 88.3 41.6 49.7 82.9 60.5 84 85 63.3 ...
$ Endemic_Zone : chr "Low risk" "Coastal" "Lake Basin" "Lake Basin" ...
$ Mosquito_Density: num 66.5 167.7 156.7 72.3 46.1 ...
$ Fever : chr "no" "no" "no" "yes" ...
$ Chills : chr "no" "no" "no" "yes" ...
$ Headache : chr "no" "no" "no" "yes" ...
$ Nausea : chr "yes" "no" "no" "yes" ...
$ Muscle_Aches : chr "yes" "yes" "yes" "yes" ...
$ Fatigue : chr "no" "no" "no" "yes" ...
$ falciparum : chr "no" "no" "yes" "yes" ...
$ vivax : chr "no" "no" "no" "no" ...
$ malariae : chr "no" "no" "yes" "yes" ...
$ ovale : chr "no" "no" "no" "yes" ...
$ Malaria_Result : chr "Negative" "Negative" "Negative" "Positive" ...
9.2 Convert all character variables to factors
data[] <- lapply(data, function(x) if(is.character(x)) as.factor(x) else x)
9.3 Diagnose the data set
library(gtsummary) #create publication-ready summary tables for regression models and descriptive statistics
library(flextable) #Creates highly customizable tables suitable for reporting and publication
library(dlookr) #Analyzes the structure and quality of the dataset
diagnose(data) |> flextable()
variables | types | missing_count | missing_percent | unique_count | unique_rate |
|---|---|---|---|---|---|
Age | integer | 0 | 0 | 35 | 0.035 |
Gender | factor | 0 | 0 | 2 | 0.002 |
Rainfall | numeric | 0 | 0 | 809 | 0.809 |
Temperature | numeric | 0 | 0 | 151 | 0.151 |
Humidity | numeric | 0 | 0 | 480 | 0.480 |
Endemic_Zone | factor | 0 | 0 | 5 | 0.005 |
Mosquito_Density | numeric | 0 | 0 | 773 | 0.773 |
Fever | factor | 0 | 0 | 2 | 0.002 |
Chills | factor | 0 | 0 | 2 | 0.002 |
Headache | factor | 0 | 0 | 2 | 0.002 |
Nausea | factor | 0 | 0 | 2 | 0.002 |
Muscle_Aches | factor | 0 | 0 | 2 | 0.002 |
Fatigue | factor | 0 | 0 | 2 | 0.002 |
falciparum | factor | 0 | 0 | 2 | 0.002 |
vivax | factor | 0 | 0 | 2 | 0.002 |
malariae | factor | 0 | 0 | 2 | 0.002 |
ovale | factor | 0 | 0 | 2 | 0.002 |
Malaria_Result | factor | 0 | 0 | 2 | 0.002 |
#Explore individual columns/variables
#unique(data$Endemic_Zone) # unique values for a single column
#table(data$Malaria_Result) # frequency for a single column
#table(data$Endemic_Zone, data$Malaria_Result) # frequencies for two columns
9.4 Check for zero variance predictors:
nzv <- nearZeroVar(data[,-18], saveMetrics = TRUE)
print(nzv)
freqRatio percentUnique zeroVar nzv
Age 1.025641 3.5 FALSE FALSE
Gender 1.049180 0.2 FALSE FALSE
Rainfall 1.000000 80.9 FALSE FALSE
Temperature 1.000000 15.1 FALSE FALSE
Humidity 1.500000 48.0 FALSE FALSE
Endemic_Zone 1.196203 0.5 FALSE FALSE
Mosquito_Density 1.000000 77.3 FALSE FALSE
Fever 1.631579 0.2 FALSE FALSE
Chills 1.538071 0.2 FALSE FALSE
Headache 1.450980 0.2 FALSE FALSE
Nausea 1.481390 0.2 FALSE FALSE
Muscle_Aches 1.631579 0.2 FALSE FALSE
Fatigue 1.481390 0.2 FALSE FALSE
falciparum 2.533569 0.2 FALSE FALSE
vivax 2.597122 0.2 FALSE FALSE
malariae 2.257329 0.2 FALSE FALSE
ovale 2.424658 0.2 FALSE FALSE
##The results above show that there is no feature with zero variance
## Remove nzv
#data <- data[, !nzv$nzv]
#dim(data)
9.5 Visualizing the Target Variable (Malaria Test Results)
library(ggplot2) ## for data visualization
# Plot Target variable using ggplot2 function
# Sample dataset
dt <- data.frame(
Malaria_Result = c("Negative", "Positive"),
Respondent = c(820, 180)) # Replace with actual numbers
# Calculate percentages
dt1 <- dt %>%
mutate(Percentage = Respondent / sum(Respondent) * 100)
# Create the bar plot
dt1 |> ggplot(aes(x = Malaria_Result, y = Respondent, fill = Malaria_Result)) +
geom_bar(stat = "identity", show.legend = TRUE) +
geom_text(aes(label = paste0(Respondent, " (", round(Percentage, 1), "%)")),vjust = -0.5, size = 5) +
labs(title = "Imbalanced Malaria Data",
x = "Malaria Test Result",
y = "Respondent",
fill = "Results") +
theme_minimal() +
theme(plot.title = element_text(hjust = 0.5)) # Align the title to the center
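Rather than typing the class counts by hand, they can be derived from the outcome column. The sketch below assumes a factor like `data$Malaria_Result`; a simulated stand-in is used here so that the example runs on its own.

```r
set.seed(123)
# Simulated stand-in for data$Malaria_Result (not the workshop data)
result <- factor(sample(c("Negative", "Positive"), size = 1000,
                        replace = TRUE, prob = c(0.82, 0.18)))

counts <- table(result)
dt <- data.frame(Malaria_Result = names(counts),
                 Respondent = as.vector(counts),
                 Percentage = round(100 * as.vector(prop.table(counts)), 1))
print(dt)
```

A data frame built this way has the same columns as the hard-coded one, so it can be passed to the same plotting code and stays correct if the dataset changes.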
10 Data Partition for Machine Learning
library(caret)
set.seed(123)
# Create a partition: 75% for training, 25% for testing
index <- createDataPartition(data$Malaria_Result,p = 0.75, list = FALSE)
# Create training and testing sets
train <- data[index, ]
test <- data[-index, ]
# Get the dimensions of your train and test data
dim(train)
[1] 750 18
dim(test)
[1] 250 18
##frequency distribution of classes of the target variable in the train dataset
#table(train$Malaria_Result)
#table(test$Malaria_Result)
# Plot Target variable using ggplot2 function
# Sample dataset
dt1 <- data.frame(
Malaria_Result = c("Negative", "Positive"),
Respondent = c(615, 135)) # Replace with actual numbers
# Calculate percentages
dt1 <- dt1 |>
mutate(Percentage = Respondent / sum(Respondent) * 100)
# Create the bar plot
p1 = dt1 |> ggplot(aes(x = Malaria_Result, y = Respondent, fill = Malaria_Result)) +
geom_bar(stat = "identity", show.legend = TRUE) +
geom_text(aes(label = paste0(Respondent, " (", round(Percentage, 1), "%)")),vjust = -0.5, size = 5) +
labs(title = "Imbalanced Malaria Data (Training set)",
x = "Malaria Test Result",
y = "Respondent",
fill = "Results") +
theme_minimal() +
theme(plot.title = element_text(hjust = 0.5)) # Align the title to the center
print(p1)
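The training and test sets keep the same class proportions because the split is stratified by the outcome. The base-R sketch below illustrates that idea only; it is not how `createDataPartition()` is implemented, and the outcome vector is a stand-in with the class counts shown earlier.

```r
set.seed(123)
# Stand-in outcome vector with 820 negatives and 180 positives
y <- factor(rep(c("Negative", "Positive"), times = c(820, 180)))

# Sample 75% of the row indices separately within each class
idx_by_class <- split(seq_along(y), y)
train_idx <- unlist(lapply(idx_by_class, function(i) {
  sample(i, size = round(0.75 * length(i)))
}))

train_y <- y[train_idx]
test_y  <- y[-train_idx]

# Class proportions are the same in both sets
print(prop.table(table(train_y)))
print(prop.table(table(test_y)))
```

With these counts, the training set holds 615 negatives and 135 positives, and the test set holds 205 negatives and 45 positives.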
11 View the models in CARET
models= getModelInfo()
#names(models)
12 Resampling technique (Over-sampling)
Observations from the minority class are randomly resampled with replacement (i.e., duplicated) until both classes have roughly the same number of samples.
library(ROSE) ## used for balancing imbalanced datasets (e.g., in binary classification)
set.seed(123) ## ensures reproducibility: Every time you run this code, you get the same random over-sampling result
over<- ovun.sample(Malaria_Result~., data = train, method = "over", N = 1230 )$data # over-samples the minority class to balance the dataset
# Calculate counts and percentages
over_summary <- over |>
group_by(Malaria_Result) |>
summarise(Count = n()) |>
mutate(Percent = round(Count / sum(Count) * 100, 1))
# Plot
p2 = ggplot(over, aes(x = Malaria_Result, fill = Malaria_Result)) +
geom_bar() +
geom_text(data = over_summary, aes(x = Malaria_Result, y = Count,
label = paste0(Count, "(", Percent, "%)")),
vjust = -0.5, # moves text higher vertically
hjust = 0.5, # centers text horizontally
size = 5) +
labs(title ="Balanced Malaria Data(Training set)",
y = "Respondent",
x = "Malaria Test Result") +
theme_classic() +
theme(plot.title = element_text(hjust = 0.5)) # Align the title to the center
print(p2)
library(patchwork)
library(easystats)
# Display the two plots side by side
plots(p1, p2, n_columns = 2, tags = paste("Fig.", 1:2))
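Random over-sampling can also be written out in base R to show what happens to the rows. This is a sketch of the idea, not of the internals of `ovun.sample()`, which may differ in detail. The toy training set and its single predictor are hypothetical.

```r
set.seed(123)
# Toy training set with the same class counts as the training data above
train_toy <- data.frame(
  Rainfall = runif(750, 50, 250),
  Malaria_Result = factor(rep(c("Negative", "Positive"), times = c(615, 135)))
)

minority <- which(train_toy$Malaria_Result == "Positive")
n_extra  <- sum(train_toy$Malaria_Result == "Negative") - length(minority)

# Draw minority rows with replacement until both classes are the same size
extra_rows <- sample(minority, size = n_extra, replace = TRUE)
over_toy   <- rbind(train_toy, train_toy[extra_rows, ])

print(table(over_toy$Malaria_Result))
```

Over-sampling should be applied to the training set only, so that the test set still reflects the original class balance.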
13 Cross validation technique
Cross-validation (CV) is a resampling technique used to evaluate the performance of a machine learning model and test its ability to generalize to unseen data. The idea is to train and test the model on different subsets of the data to reduce bias and variance in performance estimation. It helps in detecting overfitting and selecting better hyperparameters.
Two key contributors to model performance are bias and variance.
Bias: We want predictions that are, on average, close to the true value of the outcome we are trying to predict.
Variance: We want predictions that do not vary too much when the model is trained on different samples of the data.
NB: A model that performs well on new data will balance these two.
Hyperparameters are the settings you specify before training a machine learning model. They control the learning process and influence how well the model performs, but they are not learned from the data.
#Creating the train-Control scheme to avoid overfitting & underfitting
library(caret)
control <- trainControl(
method = "repeatedcv", # Use repeated cross-validation
number = 10, # 10-fold cross-validation
repeats = 5, # Repeat the 10-fold CV 5 times
classProbs = TRUE, # Compute class probabilities
summaryFunction = twoClassSummary) # Use metrics like ROC, Sensitivity, Specificity
method = "repeatedcv": This tells caret::train() to use repeated k-fold cross-validation.
number = 10: Specifies 10 folds (i.e., 10-fold cross-validation).
repeats = 5: The entire 10-fold process will be repeated 5 times, each with a different random split.
This setup is useful for:
Reducing variance in the performance estimate.
Ensuring the model’s performance is stable across different splits of the data.
Getting a more reliable estimate of the model’s accuracy, precision, RMSE, or other metrics.
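The mechanics of repeated k-fold cross-validation can be sketched in base R. The example is illustrative only: the data are simulated, accuracy at a 0.5 cut-off is used as the metric, and fold assignment is purely random, whereas caret handles fold creation and metrics itself.

```r
set.seed(123)
n <- 200
d <- data.frame(Rainfall = runif(n, 50, 250),
                Temperature = runif(n, 18, 32))
# Hypothetical rule used only to simulate labels
d$Positive <- rbinom(n, 1, plogis(-6 + 0.02 * d$Rainfall + 0.1 * d$Temperature))

k <- 10        # number of folds
repeats <- 5   # number of times the whole k-fold procedure is repeated
acc <- matrix(NA_real_, nrow = repeats, ncol = k)

for (r in seq_len(repeats)) {
  folds <- sample(rep(seq_len(k), length.out = n))   # new random split for each repeat
  for (f in seq_len(k)) {
    fit <- glm(Positive ~ Rainfall + Temperature,
               data = d[folds != f, ], family = binomial)          # train on k - 1 folds
    p <- predict(fit, newdata = d[folds == f, ], type = "response") # predict the held-out fold
    acc[r, f] <- mean((p > 0.5) == (d$Positive[folds == f] == 1))
  }
}

cat("Mean accuracy:", round(mean(acc), 3), " SD:", round(sd(acc), 3), "\n")
```

The 50 held-out estimates are averaged into a single performance figure, and their spread indicates how stable that figure is across splits.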
14 Train Machine learning Model
15 Logistic Regression
# Load libraries
library(caret) ## for training machine learning models
library(tictoc) ## for determining the time taken for a model to run
tic()
set.seed(123) ## Ensures that results are reproducible
LRModel <- train(factor(Malaria_Result) ~ .,
                 data = over,
                 method = "glm",
                 metric = "ROC", # twoClassSummary reports ROC, Sens and Spec, so evaluate on ROC
                 trControl = control)
toc()
1.19 sec elapsed
LRModel
Generalized Linear Model
1230 samples
17 predictor
2 classes: 'Negative', 'Positive'
No pre-processing
Resampling: Cross-Validated (10 fold, repeated 5 times)
Summary of sample sizes: 1106, 1106, 1108, 1107, 1107, 1108, ...
Resampling results:
ROC Sens Spec
0.9944202 0.9593971 0.9736066
# Prediction using Logistic Regression model
LRpred = predict(LRModel,newdata = test)
# Evaluation of the Logistic Regression model performance metrics
LR_CM <- confusionMatrix(LRpred,as.factor(test$Malaria_Result), positive = "Positive", mode='everything')
LR_CM
Confusion Matrix and Statistics
Reference
Prediction Negative Positive
Negative 194 2
Positive 11 43
Accuracy : 0.948
95% CI : (0.9127, 0.972)
No Information Rate : 0.82
P-Value [Acc > NIR] : 1.796e-09
Kappa : 0.8366
Mcnemar's Test P-Value : 0.0265
Sensitivity : 0.9556
Specificity : 0.9463
Pos Pred Value : 0.7963
Neg Pred Value : 0.9898
Precision : 0.7963
Recall : 0.9556
F1 : 0.8687
Prevalence : 0.1800
Detection Rate : 0.1720
Detection Prevalence : 0.2160
Balanced Accuracy : 0.9509
'Positive' Class : Positive
# Combine the ground truth and the predictions into a data frame
Ground_truth <- test$Malaria_Result
Predicted <- LRpred
resultLR <- data.frame(Ground_truth, Predicted)
# Add a column for classification results (correct/incorrect)
resultLR$Correct <- resultLR$Ground_truth == resultLR$Predicted
# Attach the predictions and the correct/incorrect flag to the test set
resultLR <- data.frame(test, Predicted = LRpred, Correct = resultLR$Correct)
#print(resultLR)
15.1 Plot of Confusion Matrix of Logistic Regression
#load required packages
library(reshape2)
library(scales)
# Create the confusion matrix
conf_matrix <- matrix(c(194, 11, 2, 43), nrow = 2, byrow = TRUE)
# Name the rows and columns
rownames(conf_matrix) <- c("Negative","Positive")
colnames(conf_matrix) <- c("Negative","Positive")
# Melt the matrix for ggplot
conf_df <- melt(conf_matrix)
# Plot
ggplot(conf_df, aes(x = Var2, y = Var1, fill = value)) +
geom_tile(color = "white") +
geom_text(aes(label = value), color = "red", size = 6) +
scale_fill_gradient(low = "white", high = "purple") +
labs(title = "Logistic Regression Confusion Matrix", x = "Predicted label", y = "True label") +
theme_minimal() +
theme(
plot.title = element_text(hjust = 0.5, face = "bold"),
axis.title.x = element_text(face = "bold"),
axis.title.y = element_text(face = "bold"),
axis.text = element_text(face = "bold"))
## Model performance metrics for classification
Confusion matrix: A performance evaluation tool for classification models. It shows how well the model’s predicted classes match the actual classes.
Accuracy: The proportion of all predictions that are correct; it summarizes the overall performance of the algorithm.
\[ \text{Accuracy} = \frac{TP + TN}{TP + FP + FN + TN} =\frac{43 + 194}{43 + 11 + 2 + 194} = 0.948 \]
Interpretation:
- How often the classifier is correct overall.
-⚠️ Can be misleading if classes are imbalanced.
True Positive Rate (TPR): Also known as sensitivity or recall
- It is the ratio of correctly predicted positive observations to the actual positives.
\[ \text{Sensitivity} = \frac{TP}{TP + FN} = \frac{43}{43 + 2} = 0.956 \]
where TP is True Positives and FN is False Negatives.
Interpretation:
Of all the true positive cases, how many did the model catch?
📌 Important in medical diagnosis, fraud detection.
True Negative Rate (TNR): Also known as specificity
- It is the ratio of correctly predicted negative observations to the actual negatives. The False Positive Rate (FPR) is its complement: FPR = 1 − Specificity.
\[ \text{Specificity} = \frac{TN}{TN + FP} = \frac{194}{194 + 11} = 0.946 \]
where FP is False Positives and TN is True Negatives.
Interpretation:
Of all the true negative cases, how many did the model correctly identify?
📌 Important to minimize false alarms.
Precision: The proportion of predicted positives that are actually positive
used to test the correctness of the model when it gives a positive outcome
\[ \text{Precision} = \frac{TP}{TP + FP} = \frac{43}{43 + 11} = 0.796 \]
Interpretation:
- How trustworthy are the positive predictions?
-📌 Important when false positives are costly (e.g., spam filters, cancer screening, malaria detection).
F1 score: The harmonic mean of precision and recall. It balances the two when both are important.
It is low whenever either precision or recall is low, so a model cannot score well by optimizing one at the expense of the other.
\[ F_1 = \frac{2 \cdot \text{Precision} \cdot \text{Recall}}{\text{Precision} + \text{Recall}} = \frac{2 \cdot 0.796 \cdot 0.956}{0.796 + 0.956} = 0.869 \]
Interpretation:
F1 reaches its best value at 1 (perfect precision and recall) and worst at 0.
📌 Useful in imbalanced classification tasks.
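The metrics above can be recomputed directly from the four cells of the confusion matrix reported for the logistic regression model (TP = 43, FP = 11, FN = 2, TN = 194). The variable names in this short check are chosen for illustration.

```r
# Cells of the logistic regression confusion matrix ("Positive" is the positive class)
TP <- 43; FP <- 11; FN <- 2; TN <- 194

accuracy    <- (TP + TN) / (TP + FP + FN + TN)
sensitivity <- TP / (TP + FN)            # recall, true positive rate
specificity <- TN / (TN + FP)            # true negative rate
fpr         <- 1 - specificity           # false positive rate
precision   <- TP / (TP + FP)            # positive predictive value
f1          <- 2 * precision * sensitivity / (precision + sensitivity)
balanced_accuracy <- (sensitivity + specificity) / 2

print(round(c(Accuracy = accuracy, Sensitivity = sensitivity,
              Specificity = specificity, FPR = fpr, Precision = precision,
              F1 = f1, Balanced_Accuracy = balanced_accuracy), 4))
```

The rounded values agree with the `confusionMatrix()` output shown earlier (for example, accuracy 0.948 and F1 0.8687).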
15.2 Importance of features in a Logistic regression Model
# Show relative importance of features
# vip::vip(LRModel)
# Alternatively using ggplot function
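var_imp <- varImp(LRModel)
# Convert the varImp object to a data frame (scores are in the "Overall" column)
imp_df <- data.frame(Variable = rownames(var_imp$importance),
                     Importance = var_imp$importance$Overall)
ggplot(imp_df, aes(x = reorder(Variable, Importance), y = Importance)) +
  geom_bar(stat = "identity", fill = "tomato") +
  coord_flip() +
  xlab("Variable") +
  ylab("Importance") +
  ggtitle("Feature Importance Plot for LR Model") +
  theme(plot.title = element_text(hjust = 0.5)) # Align the title to the center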
15.3 Prepare data for explain() function
library(DALEX) ## powerful tool for interpreting machine learning models.
# Converts the target variable from categorical ("Positive"/"Negative") to numeric (1=for Positive, 0=otherwise)
# Required for many ML models and for DALEX to interpret the output as binary classification.
train$Malaria_Result <- ifelse(train$Malaria_Result == "Positive", 1, 0)
# Create the explainer Object
explainer_1 <- explain(model = LRModel,
data = train[, -which(names(train) =="Malaria_Result")], # Exclude the target column
y = train$Malaria_Result, # Target values as vector
label = "Local Explanation with DALEX for Logistics Regression")
Preparation of a new explainer is initiated
-> model label : Local Explanation with DALEX for Logistics Regression
-> data : 750 rows 17 cols
-> target variable : 750 values
-> predict function : yhat.train will be used ( default )
-> predicted values : No value for predict function target column. ( default )
-> model_info : package caret , ver. 7.0.1 , task classification ( default )
-> predicted values : numerical, min = 1.761836e-08 , mean = 0.2061251 , max = 0.9999998
-> residual function : difference between y and yhat ( default )
-> residuals : numerical, min = -0.9989722 , mean = -0.02612506 , max = 0.8216142
A new explainer has been created!
# This creates the explainer_1 object used by the DALEX and ingredients packages for interpretability methods
# Select an instance to explain from test set or new unseen data
set.seed(123)
new_observation_1 <- test[3, -which(names(test)=="Malaria_Result")] # Select a test instance
#new_observation_1
# Break Down explanation for the instance
local_explanation_1 <- predict_parts(explainer_1, new_observation_1)
# Plot local explanation
plot(local_explanation_1)
15.4 Overview
The graph presents a local explanation of the logistic regression (LR) model using the DALEX package. It visualizes how different predictor variables contribute to the prediction for a specific instance. The prediction is represented by the bar on the right, and the contributions of each variable are shown as horizontal bars.
15.5 Breakdown of Contributions
• Intercept: This baseline value is the model’s average prediction over the data supplied to the explainer, before any variable of the selected instance is taken into account. In this case, the intercept is xxxxxx
• Predictor Variables: Each predictor variable’s contribution is shown as a bar. The color indicates the direction of the contribution:
• Green: Positive contribution, meaning the variable increases the prediction.
• Red: Negative contribution, meaning the variable decreases the prediction.
• The length of the bar represents the magnitude of the contribution.
15.6 Overall Prediction
Summing up all the contributions (intercept + predictor variables), we arrive at the final prediction of xxxx. This value represents the model’s predicted probability of a positive outcome.
This graph provides a valuable tool for understanding how the LR model arrives at a specific prediction. It highlights the relative importance of different predictor variables and their impact on the final outcome.
However, it is essential to consider the limitations of the method and to interpret the results in conjunction with other model evaluation metrics.
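To make the additive logic of the break-down plot concrete, here is a minimal base R sketch under stated assumptions: the data are simulated, the column names are illustrative, and the variables are fixed in plain column order (DALEX chooses its own ordering). It is not the DALEX implementation, only the same idea: start from the mean prediction, fix one variable at a time to the instance's value, and record how the mean prediction moves.

```r
set.seed(1)
# Simulated data; the column names are illustrative only
n <- 200
dat <- data.frame(Temperature = rnorm(n, 25, 3),
                  Humidity    = rnorm(n, 60, 10),
                  Rainfall    = rnorm(n, 120, 30))
dat$y <- rbinom(n, 1, plogis(-8 + 0.2 * dat$Temperature + 0.05 * dat$Humidity))
fit <- glm(y ~ Temperature + Humidity + Rainfall, data = dat, family = binomial)

X <- dat[, c("Temperature", "Humidity", "Rainfall")]
new_obs <- X[1, ]

# Intercept of the break-down: mean prediction over the data
baseline <- mean(predict(fit, newdata = X, type = "response"))

# Fix one variable at a time to the instance's value and track the mean prediction
X_step  <- X
running <- baseline
contrib <- numeric(0)
for (v in names(X)) {
  X_step[[v]] <- new_obs[[v]]
  step_mean   <- mean(predict(fit, newdata = X_step, type = "response"))
  contrib[v]  <- step_mean - running
  running     <- step_mean
}

prediction <- unname(predict(fit, newdata = new_obs, type = "response"))
c(intercept = baseline, contrib, prediction = prediction)
```

Once every variable has been fixed, all rows equal the instance, so the intercept plus the contributions reproduces the instance's prediction exactly.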
15.7 SHAP (SHapley Additive exPlanations)
The Shapley value helps explain how much each feature contributes to the prediction made by a machine learning model. It provides a way to fairly distribute the “credit” for the model’s output across all input features. By visualizing the SHAP plot, you can understand not only which features are important, but also which specific feature values are driving predictions for individual cases.
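The "fair credit" idea can be illustrated with an exact calculation on a toy problem. The sketch below is an assumption-laden illustration in base R, not the DALEX code path: `toy_model`, the background data and the instance are all made up, and with only two features every ordering can be enumerated (with many features, an average over a sample of random orderings is the usual approximation).

```r
# Toy "model": any function of a data frame works here (illustrative only)
toy_model <- function(d) {
  0.1 + 0.02 * d$Temperature + 0.01 * d$Humidity +
    0.001 * d$Temperature * d$Humidity
}

set.seed(2)
background <- data.frame(Temperature = rnorm(50, 25, 3),
                         Humidity    = rnorm(50, 60, 10))
instance <- data.frame(Temperature = 30, Humidity = 80)

# Value of a coalition: mean prediction when the coalition's features are
# fixed to the instance's values and the rest keep their background values
coalition_value <- function(features) {
  d <- background
  for (f in features) d[[f]] <- instance[[f]]
  mean(toy_model(d))
}

# Average each feature's marginal contribution over all feature orderings
orders <- list(c("Temperature", "Humidity"), c("Humidity", "Temperature"))
shap <- c(Temperature = 0, Humidity = 0)
for (ord in orders) {
  for (i in seq_along(ord)) {
    before <- coalition_value(ord[seq_len(i - 1)])
    after  <- coalition_value(ord[seq_len(i)])
    shap[ord[i]] <- shap[ord[i]] + (after - before) / length(orders)
  }
}
shap
```

The two values add up to the gap between the instance's prediction and the average prediction, which is the property that makes the attribution "fair".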
15.8 SHAP explanation
shap_values_1 <- predict_parts(explainer_1, new_observation_1, type = "shap")
plot(shap_values_1)
15.9 Overview
Rightward (positive): Indicates the feature is pushing the model prediction towards the positive class.
Leftward (negative): Indicates the feature is pushing the model prediction towards the negative class.
Larger absolute SHAP values mean a feature has a stronger influence on the prediction.
Smaller SHAP values (close to zero) indicate that a feature has minimal influence on the model’s output for that instance.
15.10 Partial Dependence Plots (PDP)
PDPs show the relationship between a feature and the predicted outcome while averaging over the observed values of the other features. In a malaria prediction model, this lets analysts visualize how changes in a single factor (e.g., temperature or humidity) affect the predicted probability of a positive result while accounting for the other variables.
They are used to:
- Understand and visualize the relationship between features and the target prediction.
- Identify non-linear relationships.
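The averaging behind a partial dependence curve can be written out by hand. The following base R sketch uses simulated data and a plain `glm` fit as stand-ins (both are assumptions for illustration, not the workshop's data or models): for each grid value, the feature is overwritten in every row and the predictions are averaged.

```r
set.seed(3)
n <- 300
# Simulated data; variable names are illustrative only
dat <- data.frame(Temperature = runif(n, 15, 35),
                  Humidity    = runif(n, 30, 90))
dat$y <- rbinom(n, 1, plogis(-6 + 0.15 * dat$Temperature + 0.03 * dat$Humidity))
fit <- glm(y ~ Temperature + Humidity, data = dat, family = binomial)

# For each grid value, overwrite Temperature in every row and average the predictions
grid <- seq(15, 35, by = 5)
pd <- sapply(grid, function(t) {
  d <- dat
  d$Temperature <- t
  mean(predict(fit, newdata = d, type = "response"))
})
data.frame(Temperature = grid, mean_prediction = pd)
```

Plotting `mean_prediction` against `Temperature` gives the partial dependence curve for this toy model.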
## Partial Dependence Plot for Temperature
pdp <- model_profile(explainer_1, variables = "Temperature", type = "partial")
plot(pdp)
## Partial Dependence Plot for Humidity
pdp <- model_profile(explainer_1, variables = "Humidity", type = "partial")
plot(pdp)
16 Random Forest
This is an ensemble learning method that combines multiple decision trees to improve prediction accuracy and reduce variance.
mtry: This parameter controls the number of features randomly chosen as candidates for splitting a node in each tree.
Training the RF model
# Load libraries
library(caret)
set.seed(123)
tuneGrid_rf <- expand.grid(mtry = c(2, 4, 6, 8, 12))
tic()
RFModel <- train(factor(Malaria_Result)~.,
data=over,
method="rf",
trControl=control,
tuneGrid=tuneGrid_rf)
toc()
113.66 sec elapsed
RFModel
Random Forest
1230 samples
17 predictor
2 classes: 'Negative', 'Positive'
No pre-processing
Resampling: Cross-Validated (10 fold, repeated 5 times)
Summary of sample sizes: 1106, 1106, 1108, 1107, 1107, 1108, ...
Resampling results across tuning parameters:
mtry ROC Sens Spec
2 0.9997999 0.9873189 0.9960814
4 0.9997238 0.9860286 0.9957589
6 0.9997054 0.9821364 0.9957641
8 0.9995599 0.9772607 0.9957589
12 0.9993606 0.9701111 0.9957589
ROC was used to select the optimal model using the largest value.
The final value used for the model was mtry = 2.
#plot(RFModel)
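To make the role of mtry concrete, the short base R sketch below mimics the random draw of candidate predictors made at each split. It is only an illustration of the sampling step, not the randomForest algorithm itself; the predictor names are a subset of the columns shown in this document.

```r
set.seed(123)
# A subset of the predictor names used in this document, for illustration
predictors <- c("Age", "Rainfall", "Temperature", "Humidity",
                "Mosquito_Density", "Fever", "Chills", "Headache")
mtry <- 2

# At every split, a fresh random subset of mtry predictors is considered
candidates_split_1 <- sample(predictors, mtry)
candidates_split_2 <- sample(predictors, mtry)
candidates_split_1
candidates_split_2

# randomForest's default for classification is floor(sqrt(p))
default_mtry <- floor(sqrt(length(predictors)))
default_mtry
```

A small mtry makes the trees less correlated with one another, while a large mtry lets each tree use the strongest predictors more often; the grid search above compares these settings by cross-validated ROC.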
16.1 Prediction & Evaluation of the RF model performance metrics
# Prediction using RF model
RFpred=predict(RFModel,newdata = test)
# Evaluation of the RF model performance metrics
RF_CM<- confusionMatrix(RFpred,as.factor(test$Malaria_Result), positive = "Positive", mode='everything')
RF_CM
Confusion Matrix and Statistics
Reference
Prediction Negative Positive
Negative 201 5
Positive 4 40
Accuracy : 0.964
95% CI : (0.9328, 0.9834)
No Information Rate : 0.82
P-Value [Acc > NIR] : 3.662e-12
Kappa : 0.877
Mcnemar's Test P-Value : 1
Sensitivity : 0.8889
Specificity : 0.9805
Pos Pred Value : 0.9091
Neg Pred Value : 0.9757
Precision : 0.9091
Recall : 0.8889
F1 : 0.8989
Prevalence : 0.1800
Detection Rate : 0.1600
Detection Prevalence : 0.1760
Balanced Accuracy : 0.9347
'Positive' Class : Positive
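The headline statistics can be recomputed directly from the four cell counts, which is a useful check on how each metric is defined. The sketch below copies the counts from the RF confusion matrix above; the variable names are illustrative.

```r
# Counts copied from the RF confusion matrix above (positive class = "Positive")
TP <- 40   # predicted Positive, actually Positive
FP <- 4    # predicted Positive, actually Negative
FN <- 5    # predicted Negative, actually Positive
TN <- 201  # predicted Negative, actually Negative

accuracy    <- (TP + TN) / (TP + TN + FP + FN)
sensitivity <- TP / (TP + FN)   # recall
specificity <- TN / (TN + FP)
precision   <- TP / (TP + FP)   # positive predictive value
f1          <- 2 * precision * sensitivity / (precision + sensitivity)

round(c(Accuracy = accuracy, Sensitivity = sensitivity,
        Specificity = specificity, Precision = precision, F1 = f1), 4)
```

The rounded values should match the Accuracy, Sensitivity, Specificity, Precision and F1 lines in the caret output.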
16.2 Plot of Confusion Matrix of Random Forest
#load required packages
library(reshape2)
library(scales)
# Create the confusion matrix
conf_matrix <- matrix(c(201, 5, 4, 40), nrow = 2, byrow = TRUE)
# Name the rows and columns
rownames(conf_matrix) <- c("Negative","Positive")
colnames(conf_matrix) <- c("Negative","Positive")
# Melt the matrix for ggplot
conf_df <- melt(conf_matrix)
# Plot
ggplot(conf_df, aes(x = Var2, y = Var1, fill = value)) +
geom_tile(color = "white") +
geom_text(aes(label = value), color = "black", size = 6) +
scale_fill_gradient(low = "white", high = "purple") +
labs(title = "Random Forest Confusion Matrix", x = "True label", y = "Predicted label") +
theme_minimal() +
theme(
plot.title = element_text(hjust = 0.5, face = "bold"),
axis.title.x = element_text(face = "bold"),
axis.title.y = element_text(face = "bold"),
axis.text = element_text(face = "bold"))
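Typing the counts by hand is easy to get wrong, so it can help to build the plotting data from the predictions themselves. The base R sketch below uses two small made-up factor vectors as stand-ins for the predicted and true labels; it only shows the cross-tabulation and reshaping step.

```r
# Small made-up vectors standing in for the predicted and true labels
predicted <- factor(c("Negative", "Negative", "Positive", "Positive", "Negative"),
                    levels = c("Negative", "Positive"))
actual    <- factor(c("Negative", "Positive", "Positive", "Negative", "Negative"),
                    levels = c("Negative", "Positive"))

# Cross-tabulate, then convert to long format: one row per cell
conf_tab <- table(Prediction = predicted, Reference = actual)
conf_df  <- as.data.frame(conf_tab)   # columns: Prediction, Reference, Freq
conf_df
```

With caret, the table stored in the `confusionMatrix` result (for example `RF_CM$table`) can be passed to `as.data.frame()` in the same way, and the resulting `Prediction`, `Reference` and `Freq` columns can be mapped to the y, x and fill aesthetics.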
16.3 Importance of features in a Random Forest Model
# Show relative importance of features
# vip::vip(RFModel)
# Alternatively using ggplot function
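# Extract the importance scores into a plain data frame (one row per variable)
var_imp <- varImp(RFModel)$importance
var_imp_df <- data.frame(Variable = rownames(var_imp), Importance = var_imp[, 1])
ggplot(var_imp_df, aes(x = reorder(Variable, Importance), y = Importance)) +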
geom_bar(stat = "identity", fill = "tomato") +
coord_flip() +
xlab("Variable") +
ylab("Importance") +
ggtitle("Feature Importance Plot for RF Model") +
theme(plot.title = element_text(hjust = 0.5)) # Align the title to the center
16.4 Prepare data for explain() function
library(DALEX)
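# Converts the target variable from categorical ("Positive"/"Negative") to numeric (1 for Positive, 0 otherwise)
# Required for many ML models and for DALEX to interpret the output as binary classification.
# Guard: train$Malaria_Result was already recoded to 0/1 in Section 15.3; applying ifelse() again would turn
# every value into 0, so y would be all zeros and every residual would equal minus the predicted value.
if (!is.numeric(train$Malaria_Result)) train$Malaria_Result <- ifelse(train$Malaria_Result == "Positive", 1, 0)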
# Create the explainer Object
explainer_2 <- explain(model = RFModel,
data = train[, -which(names(train) =="Malaria_Result")], # Exclude the target column
y = train$Malaria_Result, # Target values as vector
label = "Local Explanation with DALEX for Random Forest")
Preparation of a new explainer is initiated
-> model label : Local Explanation with DALEX for Random Forest
-> data : 750 rows 17 cols
-> target variable : 750 values
-> predict function : yhat.train will be used ( default )
-> predicted values : No value for predict function target column. ( default )
-> model_info : package caret , ver. 7.0.1 , task classification ( default )
-> predicted values : numerical, min = 0 , mean = 0.2224133 , max = 1
-> residual function : difference between y and yhat ( default )
-> residuals : numerical, min = -1 , mean = -0.2224133 , max = 0
A new explainer has been created!
# This creates an explainer object used by the DALEX and ingredients packages for interpretability methods
# Select an instance to explain
new_observation_2 <- test[3, -which(names(test)=="Malaria_Result")] # Select a test instance
new_observation_2
Age Gender Rainfall Temperature Humidity Endemic_Zone Mosquito_Density Fever
4 40 Female 145.8 20.5 41.6 Lake Basin 72.3 yes
Chills Headache Nausea Muscle_Aches Fatigue falciparum vivax malariae ovale
4 yes yes yes yes yes yes no yes yes
# Break Down explanation for the instance
local_explanation_2 <- predict_parts(explainer_2, new_observation_2)
# Plot local explanation
plot(local_explanation_2)
Overview
The graph presents a local explanation of an RF model using the DALEX package. It visualizes how different predictor variables contribute to the prediction for a specific instance. The prediction is represented by the bar on the right, and the contributions of each variable are shown as horizontal bars.
16.5 Breakdown of Contributions
• Intercept: This baseline value is the model’s average prediction over the data supplied to the explainer, before any variable of the selected instance is taken into account. In this case, the intercept is 0.221.
• Predictor Variables: Each predictor variable’s contribution is shown as a bar. The color indicates the direction of the contribution:
• Green: Positive contribution, meaning the variable increases the prediction.
• Red: Negative contribution, meaning the variable decreases the prediction.
• The length of the bar represents the magnitude of the contribution.
16.6 Overall Prediction
Summing up all the contributions (intercept + predictor variables), we arrive at the final prediction of 0.95. This value represents the model’s predicted probability of a positive outcome.
This graph provides a valuable tool for understanding how an RF model arrives at a specific prediction. It highlights the relative importance of different predictor variables and their impact on the final outcome.
However, it is essential to consider the limitations of the method and to interpret the results in conjunction with other model evaluation metrics.
16.7 SHAP (SHapley Additive exPlanations)
The Shapley value helps explain how much each feature contributes to the prediction made by a machine learning model. It provides a way to fairly distribute the “credit” for the model’s output across all input features. By visualizing the SHAP plot, you can understand not only which features are important, but also which specific feature values are driving predictions for individual cases.
16.8 SHAP explanation
shap_values_2 <- predict_parts(explainer_2, new_observation_2, type = "shap")
plot(shap_values_2)
16.9 Overview
Rightward (positive): Indicates the feature is pushing the model prediction towards the positive class.
Leftward (negative): Indicates the feature is pushing the model prediction towards the negative class.
Larger absolute SHAP values mean a feature has a stronger influence on the prediction.
Smaller SHAP values (close to zero) indicate that a feature has minimal influence on the model’s output for that instance.
16.10 Partial Dependence Plots (PDP)
PDPs show the relationship between a feature and the predicted outcome while averaging over the observed values of the other features. In a malaria prediction model, this lets analysts visualize how changes in a single factor (e.g., temperature or humidity) affect the predicted probability of a positive result while accounting for the other variables.
They are used to:
- Understand and visualize the relationship between features and the target prediction.
- Identify non-linear relationships.
## Partial Dependence Plot for Humidity
pdp <- model_profile(explainer_2, variables = "Humidity", type = "partial")
plot(pdp)
16.11 Support Vector Machine (SVM)
#load library
library(caret)
#Tune the grid
svm_tunegrid <- expand.grid(sigma = c(0.01, 0.1, 0.2), C = c(0.1, 1, 10))
tic()
set.seed(123) ## Ensures that results are reproducible
# Training a model with standardization: this step is useful for models that are sensitive to the scale of the data, such as KNN and SVM
SVMModel <- train(factor(Malaria_Result) ~ .,
data = over,
method = "svmRadial",
trControl = control,
tuneGrid =svm_tunegrid,
preProcess= c("center", "scale")) ##function to normalize the predictors
toc()
92.72 sec elapsed
16.12 View the Model
SVMModel
Support Vector Machines with Radial Basis Function Kernel
1230 samples
17 predictor
2 classes: 'Negative', 'Positive'
Pre-processing: centered (20), scaled (20)
Resampling: Cross-Validated (10 fold, repeated 5 times)
Summary of sample sizes: 1106, 1106, 1108, 1107, 1107, 1108, ...
Resampling results across tuning parameters:
sigma C ROC Sens Spec
0.01 0.1 0.9953629 0.9684876 0.9693601
0.01 1.0 0.9958446 0.9687996 0.9687150
0.01 10.0 0.9972916 0.9778689 0.9895717
0.10 0.1 0.9982851 0.9792015 0.9846854
0.10 1.0 0.9996592 0.9925225 0.9937969
0.10 10.0 0.9995957 0.9915494 0.9931465
0.20 0.1 0.9996107 0.9990270 0.9729773
0.20 1.0 0.9995456 0.9990217 0.9921735
0.20 10.0 0.9995456 0.9993443 0.9921735
ROC was used to select the optimal model using the largest value.
The final values used for the model were sigma = 0.1 and C = 1.
16.13 View the Best Tune
SVMModel$bestTune
sigma C
5 0.1 1
16.14 View the Results
SVMModel$results
sigma C ROC Sens Spec ROCSD SensSD SpecSD
1 0.01 0.1 0.9953629 0.9684876 0.9693601 0.003822631 0.022073584 0.02232086
2 0.01 1.0 0.9958446 0.9687996 0.9687150 0.004190148 0.020157785 0.02335719
3 0.01 10.0 0.9972916 0.9778689 0.9895717 0.003600154 0.017037069 0.01390537
4 0.10 0.1 0.9982851 0.9792015 0.9846854 0.001962261 0.017669779 0.01936928
5 0.10 1.0 0.9996592 0.9925225 0.9937969 0.001098089 0.009959455 0.01138115
6 0.10 10.0 0.9995957 0.9915494 0.9931465 0.001174721 0.009973902 0.01194267
7 0.20 0.1 0.9996107 0.9990270 0.9729773 0.001392756 0.003890588 0.02753736
8 0.20 1.0 0.9995456 0.9990217 0.9921735 0.001960388 0.003911731 0.01287408
9 0.20 10.0 0.9995456 0.9993443 0.9921735 0.001960388 0.003245060 0.01287408
16.15 Plot the Best Model
plot(SVMModel)
16.16 Prediction & Evaluation of the SVM model performance metrics
SVMpred = predict(SVMModel, newdata = test)
SVM_CM <- confusionMatrix(SVMpred, as.factor(test$Malaria_Result), positive = "Positive", mode='everything')
SVM_CM
Confusion Matrix and Statistics
Reference
Prediction Negative Positive
Negative 202 8
Positive 3 37
Accuracy : 0.956
95% CI : (0.9226, 0.9778)
No Information Rate : 0.82
P-Value [Acc > NIR] : 9.723e-11
Kappa : 0.8442
Mcnemar's Test P-Value : 0.2278
Sensitivity : 0.8222
Specificity : 0.9854
Pos Pred Value : 0.9250
Neg Pred Value : 0.9619
Precision : 0.9250
Recall : 0.8222
F1 : 0.8706
Prevalence : 0.1800
Detection Rate : 0.1480
Detection Prevalence : 0.1600
Balanced Accuracy : 0.9038
'Positive' Class : Positive
16.17 Plot of Confusion Matrix of SVM
#load required packages
library(reshape2)
library(scales)
# Create the confusion matrix
conf_matrix <- matrix(c(202, 8, 3, 37), nrow = 2, byrow = TRUE) # counts from SVM_CM above
# Name the rows and columns
rownames(conf_matrix) <- c("Negative","Positive")
colnames(conf_matrix) <- c("Negative","Positive")
# Melt the matrix for ggplot
conf_df <- melt(conf_matrix)
# Plot
ggplot(conf_df, aes(x = Var2, y = Var1, fill = value)) +
geom_tile(color = "white") +
geom_text(aes(label = value), color = "black", size = 6) +
scale_fill_gradient(low = "white", high = "blue") +
labs(title = "SVM Confusion Matrix", x = "True label", y = "Predicted label") +
theme_minimal() +
theme(
plot.title = element_text(hjust = 0.5, face = "bold"),
axis.title.x = element_text(face = "bold"),
axis.title.y = element_text(face = "bold"),
axis.text = element_text(face = "bold")
)
16.18 Prepare data for explain() function
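# Converts the target variable from categorical ("Positive"/"Negative") to numeric (1 for Positive, 0 otherwise)
# Guard: skip if already recoded in an earlier section; recoding 0/1 values again would set them all to 0
if (!is.numeric(train$Malaria_Result)) train$Malaria_Result <- ifelse(train$Malaria_Result == "Positive", 1, 0)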
# Create the explainer Object
explainer_3 <- explain(model = SVMModel,
data = train[, -which(names(train) =="Malaria_Result")], # Exclude the target column
y = train$Malaria_Result, # Target values as vector
label = "Local Explanation with DALEX for SVM")
Preparation of a new explainer is initiated
-> model label : Local Explanation with DALEX for SVM
-> data : 750 rows 17 cols
-> target variable : 750 values
-> predict function : yhat.train will be used ( default )
-> predicted values : No value for predict function target column. ( default )
-> model_info : package caret , ver. 7.0.1 , task classification ( default )
-> predicted values : numerical, min = 5.030309e-07 , mean = 0.1796074 , max = 0.9999985
-> residual function : difference between y and yhat ( default )
-> residuals : numerical, min = -0.9999985 , mean = -0.1796074 , max = -5.030309e-07
A new explainer has been created!
# This creates an explainer object used by the DALEX and ingredients packages for interpretability methods
new_observation_3 <- test[3, -which(names(test)=="Malaria_Result")] # Select a test instance
new_observation_3
Age Gender Rainfall Temperature Humidity Endemic_Zone Mosquito_Density Fever
4 40 Female 145.8 20.5 41.6 Lake Basin 72.3 yes
Chills Headache Nausea Muscle_Aches Fatigue falciparum vivax malariae ovale
4 yes yes yes yes yes yes no yes yes
# Break Down explanation for the instance
local_explanation_3 <- predict_parts(explainer_3, new_observation_3)
# Plot local explanation
plot(local_explanation_3)
Overview
The graph presents a local explanation of the SVM model using the DALEX package. It visualizes how different predictor variables contribute to the prediction for a specific instance. The prediction is represented by the bar on the right, and the contributions of each variable are shown as horizontal bars.
16.19 Breakdown of Contributions
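• Intercept: This baseline value is the model’s average prediction over the data supplied to the explainer, before any variable of the selected instance is taken into account. For the SVM explainer, the output above reports a mean predicted value of about 0.180.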
• Predictor Variables: Each predictor variable’s contribution is shown as a bar. The color indicates the direction of the contribution:
• Green: Positive contribution, meaning the variable increases the prediction.
• Red: Negative contribution, meaning the variable decreases the prediction.
• The length of the bar represents the magnitude of the contribution.
16.20 Overall Prediction
Summing up all the contributions (intercept + predictor variables), we arrive at the final prediction of 0.95. This value represents the model’s predicted probability of a positive outcome.
This graph provides a valuable tool for understanding how the SVM model arrives at a specific prediction. It highlights the relative importance of different predictor variables and their impact on the final outcome.
However, it is essential to consider the limitations of the method and to interpret the results in conjunction with other model evaluation metrics.
16.21 SHAP (SHapley Additive exPlanations)
The Shapley value helps explain how much each feature contributes to the prediction made by a machine learning model. It provides a way to fairly distribute the “credit” for the model’s output across all input features. By visualizing the SHAP plot, you can understand not only which features are important, but also which specific feature values are driving predictions for individual cases.
16.22 SHAP explanation
shap_values_3 <- predict_parts(explainer_3, new_observation_3, type = "shap")
plot(shap_values_3)
16.23 Overview
Rightward (positive): Indicates the feature is pushing the model prediction towards the positive class.
Leftward (negative): Indicates the feature is pushing the model prediction towards the negative class.
Larger absolute SHAP values mean a feature has a stronger influence on the prediction.
Smaller SHAP values (close to zero) indicate that a feature has minimal influence on the model’s output for that instance.
17 k-Nearest Neighbors (KNN)
tic()
set.seed(123) ## Ensures that results are reproducible
# Training a model with standardization: this step is useful for models that are sensitive to the scale of the data, such as KNN and SVM
KNNModel <- train(factor(Malaria_Result) ~ .,
data = over,
method = "knn",
trControl = control,
preProcess= c("center", "scale")) ##function to normalize the predictors
toc()
2.39 sec elapsed
KNNModel
k-Nearest Neighbors
1230 samples
17 predictor
2 classes: 'Negative', 'Positive'
Pre-processing: centered (20), scaled (20)
Resampling: Cross-Validated (10 fold, repeated 5 times)
Summary of sample sizes: 1106, 1106, 1108, 1107, 1107, 1108, ...
Resampling results across tuning parameters:
k ROC Sens Spec
5 0.9869149 0.9157959 0.9840508
7 0.9865552 0.9180328 0.9876362
9 0.9864089 0.9190640 0.9701111
ROC was used to select the optimal model using the largest value.
The final value used for the model was k = 5.
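The reason KNN is trained with centering and scaling can be seen on a handful of points. The sketch below uses five made-up observations (the values are assumptions chosen for illustration): on the raw scale, Rainfall dominates the Euclidean distance, so the nearest neighbour of observation A changes once both variables are standardized.

```r
# Five made-up observations; Rainfall spans a much wider range than Temperature
obs <- data.frame(Temperature = c(20, 21, 34, 28, 30),
                  Rainfall    = c(100, 140, 105, 300, 10))
rownames(obs) <- c("A", "B", "C", "D", "E")

# Raw Euclidean distances are dominated by Rainfall
raw_dist <- as.matrix(dist(obs))

# After centering and scaling, both variables contribute comparably
scaled_dist <- as.matrix(dist(scale(obs)))

# Nearest neighbour of A (column 1, A itself, is dropped)
nearest_raw    <- names(which.min(raw_dist["A", -1]))
nearest_scaled <- names(which.min(scaled_dist["A", -1]))
c(raw = nearest_raw, scaled = nearest_scaled)
```

On the raw scale the 14-degree temperature gap between A and C barely registers next to rainfall differences, whereas after scaling B, which is close to A on both variables, becomes the nearest neighbour.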
17.1 View the Best Tune
KNNModel$bestTune
k
1 5
17.2 Plot the Model
plot(KNNModel)
17.3 Prediction & Evaluation of the KNN model performance metrics
KNNpred = predict(KNNModel, newdata = test)
KNN_CM <- confusionMatrix(KNNpred, as.factor(test$Malaria_Result), positive = "Positive", mode='everything')
KNN_CM
Confusion Matrix and Statistics
Reference
Prediction Negative Positive
Negative 183 1
Positive 22 44
Accuracy : 0.908
95% CI : (0.8652, 0.9408)
No Information Rate : 0.82
P-Value [Acc > NIR] : 7.049e-05
Kappa : 0.7364
Mcnemar's Test P-Value : 3.042e-05
Sensitivity : 0.9778
Specificity : 0.8927
Pos Pred Value : 0.6667
Neg Pred Value : 0.9946
Precision : 0.6667
Recall : 0.9778
F1 : 0.7928
Prevalence : 0.1800
Detection Rate : 0.1760
Detection Prevalence : 0.2640
Balanced Accuracy : 0.9352
'Positive' Class : Positive
17.4 Plot of Confusion Matrix of KNN
#load required packages
library(reshape2)
library(scales)
# Create the confusion matrix
conf_matrix <- matrix(c(183, 1, 22, 44), nrow = 2, byrow = TRUE)
# Name the rows and columns
rownames(conf_matrix) <- c("Negative","Positive")
colnames(conf_matrix) <- c("Negative","Positive")
# Melt the matrix for ggplot
conf_df <- melt(conf_matrix)
# Plot
ggplot(conf_df, aes(x = Var2, y = Var1, fill = value)) +
geom_tile(color = "white") +
geom_text(aes(label = value), color = "black", size = 6) +
scale_fill_gradient(low = "white", high = "blue") +
labs(title = "KNN Confusion Matrix", x = "True label", y = "Predicted label") +
theme_minimal() +
theme(
plot.title = element_text(hjust = 0.5, face = "bold"),
axis.title.x = element_text(face = "bold"),
axis.title.y = element_text(face = "bold"),
axis.text = element_text(face = "bold")
)
17.5 Prepare data for explain() function
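# Converts the target variable from categorical ("Positive"/"Negative") to numeric (1 for Positive, 0 otherwise)
# Guard: skip if already recoded in an earlier section; recoding 0/1 values again would set them all to 0
if (!is.numeric(train$Malaria_Result)) train$Malaria_Result <- ifelse(train$Malaria_Result == "Positive", 1, 0)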
# Create the explainer Object
explainer_4 <- explain(model = KNNModel,
data = train[, -which(names(train) =="Malaria_Result")], # Exclude the target column
y = train$Malaria_Result, # Target values as vector
label = "Local Explanation with DALEX for KNN")
Preparation of a new explainer is initiated
-> model label : Local Explanation with DALEX for KNN
-> data : 750 rows 17 cols
-> target variable : 750 values
-> predict function : yhat.train will be used ( default )
-> predicted values : No value for predict function target column. ( default )
-> model_info : package caret , ver. 7.0.1 , task classification ( default )
-> predicted values : numerical, min = 0 , mean = 0.2194729 , max = 1
-> residual function : difference between y and yhat ( default )
-> residuals : numerical, min = -1 , mean = -0.2194729 , max = 0
A new explainer has been created!
# This creates an explainer object used by the DALEX and ingredients packages for interpretability methods
new_observation_4 <- test[3, -which(names(test)=="Malaria_Result")] # Select a test instance
new_observation_4
Age Gender Rainfall Temperature Humidity Endemic_Zone Mosquito_Density Fever
4 40 Female 145.8 20.5 41.6 Lake Basin 72.3 yes
Chills Headache Nausea Muscle_Aches Fatigue falciparum vivax malariae ovale
4 yes yes yes yes yes yes no yes yes
# Break Down explanation for the instance
local_explanation_4 <- predict_parts(explainer_4, new_observation_4)
# Plot local explanation
plot(local_explanation_4)
17.6 Overview
The graph presents a local explanation of a KNN model using the DALEX package. It visualizes how different predictor variables contribute to the prediction for a specific instance. The prediction is represented by the bar on the right, and the contributions of each variable are shown as horizontal bars.
17.7 SHAP explanation
shap_values_4 <- predict_parts(explainer_4, new_observation_4, type = "shap")
plot(shap_values_4)
17.8 Decision Tree (DT)
set.seed(123) ## Ensures that results are reproducible
tic()
DTModel <- train(factor(Malaria_Result) ~ .,
data = over,
method = "rpart",
trControl = control,
preProcess= c("center", "scale"))
toc()
1.64 sec elapsed
17.9 View the Model
DTModel
CART
1230 samples
17 predictor
2 classes: 'Negative', 'Positive'
Pre-processing: centered (20), scaled (20)
Resampling: Cross-Validated (10 fold, repeated 5 times)
Summary of sample sizes: 1106, 1106, 1108, 1107, 1107, 1108, ...
Resampling results across tuning parameters:
cp ROC Sens Spec
0.02113821 0.8716873 0.8670333 0.8565944
0.07073171 0.8278307 0.7677261 0.8757747
0.60813008 0.6312374 0.7784981 0.4839767
ROC was used to select the optimal model using the largest value.
The final value used for the model was cp = 0.02113821.
17.10 View the Best Tune
DTModel$bestTune
cp
1 0.02113821
17.11 Plot the Model
plot(DTModel)
17.12 Fancy Rpartplot
library(rpart.plot)
rpart.plot(DTModel$finalModel,
main = "Decision Tree for Malaria Prediction",
extra = 104,
type = 3,
fallen.leaves = TRUE,
box.palette = "RdBu",
shadow.col = "gray",
nn = TRUE)
17.13 Prediction & Evaluation of the DT model performance metrics
DTpred = predict(DTModel, newdata = test)
DT_CM <- confusionMatrix(DTpred, as.factor(test$Malaria_Result), positive = "Positive", mode='everything')
DT_CM
Confusion Matrix and Statistics
Reference
Prediction Negative Positive
Negative 175 6
Positive 30 39
Accuracy : 0.856
95% CI : (0.8063, 0.8971)
No Information Rate : 0.82
P-Value [Acc > NIR] : 0.0780309
Kappa : 0.5962
Mcnemar's Test P-Value : 0.0001264
Sensitivity : 0.8667
Specificity : 0.8537
Pos Pred Value : 0.5652
Neg Pred Value : 0.9669
Precision : 0.5652
Recall : 0.8667
F1 : 0.6842
Prevalence : 0.1800
Detection Rate : 0.1560
Detection Prevalence : 0.2760
Balanced Accuracy : 0.8602
'Positive' Class : Positive
17.14 Plot of Confusion Matrix of DT
#load required packages
library(reshape2)
library(scales)
# Create the confusion matrix
conf_matrix <- matrix(c(175, 6, 30, 39), nrow = 2, byrow = TRUE)
# Name the rows and columns
rownames(conf_matrix) <- c("Negative","Positive")
colnames(conf_matrix) <- c("Negative","Positive")
# Melt the matrix for ggplot
conf_df <- melt(conf_matrix)
# Plot
ggplot(conf_df, aes(x = Var2, y = Var1, fill = value)) +
geom_tile(color = "white") +
geom_text(aes(label = value), color = "black", size = 6) +
scale_fill_gradient(low = "white", high = "blue") +
labs(title = "Decision Tree Confusion Matrix", x = "True label", y = "Predicted label") +
theme_minimal() +
theme(
plot.title = element_text(hjust = 0.5, face = "bold"),
axis.title.x = element_text(face = "bold"),
axis.title.y = element_text(face = "bold"),
axis.text = element_text(face = "bold")
)
17.15 Prepare data for explain() function
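# Converts the target variable from categorical ("Positive"/"Negative") to numeric (1 for Positive, 0 otherwise)
# Guard: skip if already recoded in an earlier section; recoding 0/1 values again would set them all to 0
if (!is.numeric(train$Malaria_Result)) train$Malaria_Result <- ifelse(train$Malaria_Result == "Positive", 1, 0)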
# Create the explainer Object
explainer_5 <- explain(model = DTModel,
data = train[, -which(names(train) =="Malaria_Result")], # Exclude the target column
y = train$Malaria_Result, # Target values as vector
label = "Local Explanation with DALEX for Decision Tree")
Preparation of a new explainer is initiated
-> model label : Local Explanation with DALEX for Decision Tree
-> data : 750 rows 17 cols
-> target variable : 750 values
-> predict function : yhat.train will be used ( default )
-> predicted values : No value for predict function target column. ( default )
-> model_info : package caret , ver. 7.0.1 , task classification ( default )
-> predicted values : numerical, min = 0.12749 , mean = 0.3173826 , max = 0.9013807
-> residual function : difference between y and yhat ( default )
-> residuals : numerical, min = -0.9013807 , mean = -0.3173826 , max = -0.12749
A new explainer has been created!
# This creates an explainer object used by the DALEX and ingredients packages for interpretability methods
new_observation_5 <- test[3, -which(names(test)=="Malaria_Result")] # Select a test instance
new_observation_5
Age Gender Rainfall Temperature Humidity Endemic_Zone Mosquito_Density Fever
4 40 Female 145.8 20.5 41.6 Lake Basin 72.3 yes
Chills Headache Nausea Muscle_Aches Fatigue falciparum vivax malariae ovale
4 yes yes yes yes yes yes no yes yes
# Break Down explanation for the instance
local_explanation_5 <- predict_parts(explainer_5, new_observation_5)
# Plot local explanation
plot(local_explanation_5)
17.16 SHAP explanation
shap_values_5 <- predict_parts(explainer_5, new_observation_5, type = "shap")
plot(shap_values_5)
17.17 Naive Bayes (NB)
set.seed(123) ## Ensures that results are reproducible
tic()
NBModel <- train(factor(Malaria_Result) ~ .,
data = over,
method = "naive_bayes",
trControl = control,
preProcess= c("center", "scale"))
toc()
2.72 sec elapsed
17.18 View the Model
NBModel
Naive Bayes
1230 samples
17 predictor
2 classes: 'Negative', 'Positive'
Pre-processing: centered (20), scaled (20)
Resampling: Cross-Validated (10 fold, repeated 5 times)
Summary of sample sizes: 1106, 1106, 1108, 1107, 1107, 1108, ...
Resampling results across tuning parameters:
usekernel ROC Sens Spec
FALSE 0.9939957 0.9658805 0.9391274
TRUE 0.9957419 0.9434479 0.9807827
Tuning parameter 'laplace' was held constant at a value of 0
Tuning
parameter 'adjust' was held constant at a value of 1
ROC was used to select the optimal model using the largest value.
The final values used for the model were laplace = 0, usekernel = TRUE
and adjust = 1.
17.19 View the Best Tune
NBModel$bestTune
laplace usekernel adjust
2 0 TRUE 1
17.20 View the Results
NBModel$results
usekernel laplace adjust ROC Sens Spec ROCSD SensSD
1 FALSE 0 1 0.9939957 0.9658805 0.9391274 0.004204189 0.02127463
2 TRUE 0 1 0.9957419 0.9434479 0.9807827 0.003101067 0.02500433
SpecSD
1 0.02930288
2 0.01672640
17.21 Plot the Model
plot(NBModel)
17.22 Prediction & Evaluation of the NB model performance metrics
NBpred = predict(NBModel, newdata = test)
NB_CM <- confusionMatrix(NBpred, as.factor(test$Malaria_Result), positive = "Positive", mode='everything')
NB_CM
Confusion Matrix and Statistics
Reference
Prediction Negative Positive
Negative 188 1
Positive 17 44
Accuracy : 0.928
95% CI : (0.8886, 0.9568)
No Information Rate : 0.82
P-Value [Acc > NIR] : 7.338e-07
Kappa : 0.7858
Mcnemar's Test P-Value : 0.000407
Sensitivity : 0.9778
Specificity : 0.9171
Pos Pred Value : 0.7213
Neg Pred Value : 0.9947
Precision : 0.7213
Recall : 0.9778
F1 : 0.8302
Prevalence : 0.1800
Detection Rate : 0.1760
Detection Prevalence : 0.2440
Balanced Accuracy : 0.9474
'Positive' Class : Positive
17.23 Plot of Confusion Matrix of NB
#load required packages
library(reshape2)
library(scales)
# Create the confusion matrix
conf_matrix <- matrix(c(188, 1, 17, 44), nrow = 2, byrow = TRUE)
# Name the rows and columns
rownames(conf_matrix) <- c("Negative","Positive")
colnames(conf_matrix) <- c("Negative","Positive")
# Melt the matrix for ggplot
conf_df <- melt(conf_matrix)
# Plot
ggplot(conf_df, aes(x = Var2, y = Var1, fill = value)) +
geom_tile(color = "white") +
geom_text(aes(label = value), color = "black", size = 6) +
scale_fill_gradient(low = "white", high = "blue") +
labs(title = "Naive Bayes Confusion Matrix", x = "True label", y = "Predicted label") +
theme_minimal() +
theme(
plot.title = element_text(hjust = 0.5, face = "bold"),
axis.title.x = element_text(face = "bold"),
axis.title.y = element_text(face = "bold"),
axis.text = element_text(face = "bold")
)
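Because melt() names the long-format columns Var1 and Var2, it is easy to lose track of which axis holds the predictions. A minimal base-R sketch of an alternative is shown below: it gives the matrix named dimensions so the long-format data frame carries Prediction and Reference columns explicitly. The object names are illustrative. In a live session the same conversion could presumably be applied to NB_CM$table instead of retyping the counts.

```r
# Same counts as above, but with named dimensions (rows = Prediction, columns = Reference)
conf_matrix_named <- matrix(
  c(188, 1, 17, 44), nrow = 2, byrow = TRUE,
  dimnames = list(Prediction = c("Negative", "Positive"),
                  Reference  = c("Negative", "Positive"))
)

# Convert to long format; the column names come from the dimnames
conf_long <- as.data.frame(as.table(conf_matrix_named), responseName = "value")
conf_long

# Look up one cell by name: cases predicted Positive that were truly Negative
false_pos <- conf_long$value[conf_long$Prediction == "Positive" &
                             conf_long$Reference == "Negative"]
false_pos
```

With this layout, the plot aesthetics can refer to Reference and Prediction by name, which keeps the axis titles and the data in agreement.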
17.24 Prepare data for explain() function
# Convert the target variable from categorical ("Positive"/"Negative") to numeric (1 for Positive, 0 otherwise).
# DALEX computes residuals as y - yhat, so y must be numeric.
train$Malaria_Result <- ifelse(train$Malaria_Result == "Positive", 1, 0)
# Create the explainer Object
explainer_6 <- explain(model = NBModel,
data = train[, -which(names(train) =="Malaria_Result")], # Exclude the target column
y = train$Malaria_Result, # Target values as vector
label = "Local Explanation with DALEX for Naive Bayes")
Preparation of a new explainer is initiated
-> model label : Local Explanation with DALEX for Naive Bayes
-> data : 750 rows 17 cols
-> target variable : 750 values
-> predict function : yhat.train will be used ( default )
-> predicted values : No value for predict function target column. ( default )
-> model_info : package caret , ver. 7.0.1 , task classification ( default )
-> predicted values : numerical, min = 2.431903e-06 , mean = 0.2270044 , max = 0.9999973
-> residual function : difference between y and yhat ( default )
-> residuals : numerical, min = -0.9999973 , mean = -0.2270044 , max = -2.431903e-06
A new explainer has been created!
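This creates an explainer object: a uniform wrapper around the fitted model, the predictor data, and the target values. DALEX and related packages such as ingredients use this object as the common input to their interpretability methods.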
new_observation_6 <- test[2, -which(names(test)=="Malaria_Result")] # Select a test instance
new_observation_6
Age Gender Rainfall Temperature Humidity Endemic_Zone Mosquito_Density Fever
3 44 Male 230.6 31.5 88.3 Lake Basin 156.7 no
Chills Headache Nausea Muscle_Aches Fatigue falciparum vivax malariae ovale
3 no no no yes no yes no yes no
# Break Down explanation for the instance
local_explanation_6 <- predict_parts(explainer_6, new_observation_6)
# Plot local explanation
plot(local_explanation_6)
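The idea behind a Break Down explanation can be illustrated without DALEX. The sketch below is a simplified, hand-rolled version on a toy logistic model fitted to the built-in mtcars data (not the workshop's malaria data, and not the exact algorithm DALEX uses to order variables). It starts from the average prediction, then fixes one variable at a time to the value of the chosen observation and records how much the average prediction moves. All names are illustrative.

```r
# Toy model: probability of manual transmission from weight and horsepower
fit <- glm(am ~ wt + hp, data = mtcars, family = binomial)
X   <- mtcars[, c("wt", "hp")]
obs <- X[1, ]   # the observation to explain

predict_prob <- function(d) as.numeric(predict(fit, newdata = d, type = "response"))

baseline <- mean(predict_prob(X))   # average prediction (the "intercept" bar)
contrib  <- setNames(numeric(ncol(X)), names(X))
current  <- X
previous <- baseline

for (v in names(X)) {
  current[[v]] <- obs[[v]]            # fix variable v at the observation's value
  now <- mean(predict_prob(current))  # average prediction after fixing v
  contrib[v] <- now - previous        # contribution attributed to v
  previous <- now
}

final <- predict_prob(obs)
baseline
contrib
final
```

By construction, the baseline plus the contributions equals the model's prediction for the observation, which is the additive property a Break Down plot displays. The contributions depend on the order in which variables are fixed.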
17.24.1 SHAP explanation
shap_values_6 <- predict_parts(explainer_6, new_observation_6, type = "shap")
plot(shap_values_6)
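Break Down contributions depend on the order in which variables are fixed; the SHAP-type explanation addresses this by averaging contributions over several orderings. The sketch below illustrates that averaging on a toy logistic model fitted to mtcars (an assumption for illustration, not the workshop data). With only two predictors both orderings can be enumerated; with many predictors a random sample of orderings is typically used instead. All names are illustrative.

```r
# Toy model and observation to explain
fit <- glm(am ~ wt + hp, data = mtcars, family = binomial)
X   <- mtcars[, c("wt", "hp")]
obs <- X[1, ]

predict_prob <- function(d) as.numeric(predict(fit, newdata = d, type = "response"))

# Break Down contributions for one ordering of the variables
break_down <- function(ordering) {
  current  <- X
  previous <- mean(predict_prob(X))
  out <- setNames(numeric(length(ordering)), ordering)
  for (v in ordering) {
    current[[v]] <- obs[[v]]
    now <- mean(predict_prob(current))
    out[v] <- now - previous
    previous <- now
  }
  out[names(X)]   # return in a fixed column order so orderings can be averaged
}

# Average the contributions over all orderings of the two variables
orderings <- list(c("wt", "hp"), c("hp", "wt"))
per_order <- lapply(orderings, break_down)
shap_like <- Reduce(`+`, per_order) / length(orderings)

baseline <- mean(predict_prob(X))
final    <- predict_prob(obs)
shap_like
```

The averaged contributions still sum, together with the baseline, to the prediction for the observation, but they no longer depend on a single arbitrary ordering.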