Malaria predictive modelling

Installing Packages & Loading libraries

setwd("/Users/m1/Downloads/Research/Malaria-Incidence-Prediction/")

# Loading packages (install any that are missing with install.packages() first)

library(readxl)
library(tidyverse)
library(ggplot2)
library(dplyr)
library(purrr)
library(stringi)
library(imputeTS)
library(data.table)
library(writexl)
library(sf)
library(gstat)
library(corrr)
library(lubridate)
library(car)

# Load every malaria Excel export in the working directory and stack them.
# list.files() expects a regular expression, not a shell glob, so the
# extension is anchored explicitly (.xls or .xlsx at the end of the name).
files <- list.files(pattern = "\\.xlsx?$")
merged_df <- files %>% map_dfr(read_excel)

# Drop organisational-unit levels and inpatient indicators that are not used.
unused_cols <- c(
  "orgunitlevel1", "orgunitlevel2", "orgunitlevel4",
  "NMCP IPD Confirmed Malaria Cases", "NMCP IPD Total Malaria Deaths"
)
merged_df <- merged_df[, !(names(merged_df) %in% unused_cols)]

# Split "periodname" into Month / Year and "orgunitlevel3" into
# District / facility.
dates_merged_df <- separate(merged_df, col = periodname,
                            into = c("Month", "Year"), sep = " ")
dates_merged_df <- separate(dates_merged_df, col = orgunitlevel3,
                            into = c("District", "facility"), sep = "-")

# Rename the outpatient case count and the reporting unit to short names.
v1_merged_df <- dates_merged_df %>%
  rename("Incidence" = "NMCP OPD Confirmed Malaria Cases",
         "cluster" = "organisationunitname")

# Keep the study period 2018-2024. separate() leaves Year as character, so it
# is converted to integer first to make the comparison numeric.
summarized_df <- v1_merged_df %>%
  mutate(Year = as.integer(Year)) %>%
  filter(Year >= 2018, Year <= 2024)

# Impute missing Incidence values with Kalman smoothing, district by district.
# NOTE(review): the series passed to na_kalman() is every row of a district in
# file order (all facilities, no sort by date) - confirm this is intended.
summarized_df <- summarized_df %>%
  group_by(District) %>%
  mutate(Incidence = round(na_kalman(Incidence))) %>%
  ungroup()
#importing climate data into workspace

# Read one climate CSV, skipping its metadata header when it has one.
# The header end (the line containing "-END HEADER-") is located separately
# for each file, so files with different header lengths - or with no header
# at all - are all parsed from their real column-name row.
read_climate_csv <- function(path) {
  header_end <- grep("-END HEADER-", readLines(path), fixed = TRUE)
  skip_n <- if (length(header_end) > 0) header_end[1] else 0
  read_csv(path, skip = skip_n)
}

wind_df <- read_climate_csv("Windspeed-2010-2025-Monthly.csv")
temp_df <- read_climate_csv("Temperature-2010-2025-Monthly.csv")
rain_df <- read_climate_csv("Rainfall-Monthly_2010_2025.csv")
hum_df <- read_climate_csv("RHum-2010-2025-Monthly.csv")
# Save the cleaned malaria table as CSV, then read it back into the workspace.
# The round trip lets readr re-parse the column types (Year comes back as
# <dbl>, as the glimpse() output shows).
write.csv(summarized_df, "malaria.csv", row.names = FALSE)
malaria_df <- read_csv("malaria.csv")
# Quick structural check of the re-imported table.
glimpse(malaria_df)
Rows: 63,654
Columns: 6
$ District  <chr> "Mangochi", "Mangochi", "Mangochi", "Mangochi", "Mangochi", …
$ facility  <chr> "DHO", "DHO", "DHO", "DHO", "DHO", "DHO", "DHO", "DHO", "DHO…
$ cluster   <chr> "Aa-salam Clinic", "Aa-salam Clinic", "Aa-salam Clinic", "Aa…
$ Month     <chr> "February", "March", "April", "May", "June", "July", "August…
$ Year      <dbl> 2019, 2019, 2019, 2019, 2019, 2019, 2019, 2019, 2019, 2019, …
$ Incidence <dbl> 451, 794, 308, 276, 291, 170, 133, 155, 128, 59, 62, 160, 25…
# Standardise column names to lower case: malaria table first ...
malaria_df <- setNames(malaria_df, tolower(names(malaria_df)))
# ... then each of the four climate tables
wind_df <- setNames(wind_df, tolower(names(wind_df)))
rain_df <- setNames(rain_df, tolower(names(rain_df)))
hum_df <- setNames(hum_df, tolower(names(hum_df)))
temp_df <- setNames(temp_df, tolower(names(temp_df)))
# Summarise total incidence per calendar month.
# `month` holds full month names, so it is ordered with month.name
# (levels = month.abb would turn every month except "May" into NA).
monthly_incidence <- malaria_df %>%
  group_by(month) %>%
  summarise(Total_Incidence = sum(incidence, na.rm = TRUE)) %>%
  mutate(month = factor(month, levels = month.name))

# Bar plot of total incidence per month, January to December
ggplot(monthly_incidence, aes(x = month, y = Total_Incidence)) +
  geom_col(fill = "steelblue") +
  theme_minimal() +
  labs(title = "Total Incidence by Month",
       x = "Month",
       y = "Total Incidence") +
  theme(
    axis.text.x = element_text(size = 8, angle = 45, hjust = 1)  # smaller and rotated
  )

# Put calendar months in chronological order (month.name = full English names)
malaria_df$month <- factor(malaria_df$month, levels = month.name)

# Total incidence for every district / month combination
monthly_district <- summarise(
  group_by(malaria_df, district, month),
  Total_Incidence = sum(incidence, na.rm = TRUE),
  .groups = "drop"
)

# District-by-month heatmap; tiles get white borders for separation
ggplot(monthly_district,
       aes(x = month, y = district, fill = Total_Incidence)) +
  geom_tile(color = "white") +
  scale_fill_gradient(low = "lightyellow", high = "steelblue") +
  theme_minimal(base_size = 12) +
  theme(
    plot.title = element_text(face = "bold", size = 14, hjust = 0.5),
    axis.title.x = element_text(size = 10, face = "bold"),
    axis.title.y = element_text(size = 10, face = "bold"),
    axis.text.x = element_text(angle = 45, hjust = 1, size = 8),
    axis.text.y = element_text(size = 8),
    legend.title = element_text(size = 10),
    legend.text = element_text(size = 8)
  ) +
  labs(
    title = "Monthly Incidence per District",
    x = "Month",
    y = "District",
    fill = "Incidence"
  )

# Restrict each climate table to the analysis period, 2018 to 2024 inclusive
wind_df <- filter(wind_df, year >= 2018 & year <= 2024)

rain_df <- filter(rain_df, year >= 2018 & year <= 2024)

hum_df <- filter(hum_df, year >= 2018 & year <= 2024)

temp_df <- filter(temp_df, year >= 2018 & year <= 2024)
# Ensure month is an ordered calendar factor (no-op if it already is)
malaria_df$month <- factor(malaria_df$month, levels = month.name)

# Summarise incidence per district, per year, per month
monthly_year_district <- malaria_df %>%
  group_by(district, year, month) %>%
  summarise(Total_Incidence = sum(incidence, na.rm = TRUE), .groups = "drop")

# Month-by-year heatmap, one facet per district.
# The malaria data are filtered to 2018-2024, so the title states that range.
ggplot(monthly_year_district, aes(x = month, y = factor(year), fill = Total_Incidence)) +
  geom_tile(color = "white") +
  scale_fill_gradient(low = "lightyellow", high = "steelblue") +
  facet_wrap(~district, scales = "free_y") +
  labs(title = "Monthly Incidence per District (2018–2024)",
       x = "Month",
       y = "Year",
       fill = "Incidence") +
  theme_minimal(base_size = 12) +
  theme(
    axis.text.x = element_text(angle = 45, hjust = 1, size = 5),
    axis.text.y = element_text(size = 8),
    axis.title = element_text(size = 10, face = "bold"),
    legend.title = element_text(size = 10),
    legend.text = element_text(size = 8),
    strip.text = element_text(face = "bold", size = 8),
    plot.title = element_text(face = "bold", size = 14, hjust = 0.5)
  )

# Drop the annual summary column ("ann") from each climate table, if present
temp_df <- select(temp_df, -any_of("ann"))
rain_df <- select(rain_df, -any_of("ann"))
hum_df <- select(hum_df, -any_of("ann"))
wind_df <- select(wind_df, -any_of("ann"))

# The humidity file names its parameter column differently; align it
hum_df <- rename(hum_df, parameter = `parameter-hum`)
# Lookup from central-hospital / truncated names to their parent district
parent_district <- c(
  "Queen Elizabeth Central Hospital" = "Blantyre",
  "Zomba Mental Hospital" = "Zomba",
  "Zomba Central Hospital" = "Zomba",
  "Nkhata" = "Nkhatabay",
  "Mzuzu Central Hospital" = "Mzuzu"
)

malaria_df <- malaria_df %>%
  mutate(
    # first-of-month Date built from month name and year, for ordered plots
    date = as.Date(paste("01", month, year), format = "%d %B %Y"),
    # names missing from the lookup come back NA and keep their own value
    district = coalesce(unname(parent_district[district]), district)
  )
unique(malaria_df$district)
 [1] "Mangochi"   "Lilongwe"   "Nkhotakota" "Chikwawa"   "Thyolo"    
 [6] "Kasungu"    "Karonga"    "Balaka"     "Blantyre"   "Dedza"     
[11] "Chiradzulu" "Ntcheu"     "Zomba"      "Rumphi"     "Mulanje"   
[16] "Dowa"       "Nkhatabay"  "Mzimba"     "Salima"     "Machinga"  
[21] "Chitipa"    "Nsanje"     "Neno"       "Mchinji"    "Ntchisi"   
[26] "Phalombe"   "Likoma"     "Mwanza"     "Mzuzu"     
# Turn the four climate tables into data.tables (in place)
setDT(rain_df)
setDT(wind_df)
setDT(temp_df)
setDT(hum_df)

# Month columns of the wide climate tables
months <- c("jan", "feb", "mar", "apr", "may", "jun",
            "jul", "aug", "sep", "oct", "nov", "dec")

# Wide -> long: one row per year / grid point / month, with the measurement
# stored under `value_col`
melt_by_month <- function(wide_dt, value_col) {
  melt(wide_dt,
       id.vars = c("year", "lat", "lon"),
       measure.vars = months,
       variable.name = "month",
       value.name = value_col)
}

rain_long <- melt_by_month(rain_df, "rainfall")
hum_long <- melt_by_month(hum_df, "humidity")
temp_long <- melt_by_month(temp_df, "temperature")
wind_long <- melt_by_month(wind_df, "windspeed")
# Add a numeric month and a first-of-month Date to a long climate table
add_month_date <- function(dt) {
  dt[, month_num := match(tolower(month), tolower(month.abb))]
  dt[, date := as.Date(paste(year, month_num, "01", sep = "-"))]
  dt
}

climate_list <- lapply(
  list(rain_long, hum_long, temp_long, wind_long),
  add_month_date
)

# Keep the identifying columns plus each table's own measurement
id_cols <- c("year", "lat", "lon", "month", "date")
rain_long <- climate_list[[1]][, c(id_cols, "rainfall"), with = FALSE]
hum_long <- climate_list[[2]][, c(id_cols, "humidity"), with = FALSE]
temp_long <- climate_list[[3]][, c(id_cols, "temperature"), with = FALSE]
wind_long <- climate_list[[4]][, c(id_cols, "windspeed"), with = FALSE]

# Full outer merge of the four long climate tables on their shared keys
climate_keys <- c("year", "lat", "lon", "month", "date")
climate_df <- Reduce(
  function(left, right) merge(left, right, by = climate_keys, all = TRUE),
  list(rain_long, hum_long, temp_long, wind_long)
)
head(climate_df)
Key: <year, lat, lon, month, date>
    year   lat    lon  month       date rainfall humidity temperature windspeed
   <num> <num>  <num> <fctr>     <Date>    <num>    <num>       <num>     <num>
1:  2018   -17 33.125    jan 2018-01-01     0.62    58.22       27.42      3.02
2:  2018   -17 33.125    feb 2018-02-01    11.33    83.60       24.88      2.08
3:  2018   -17 33.125    mar 2018-03-01     3.78    78.75       24.98      2.39
4:  2018   -17 33.125    apr 2018-04-01     0.81    70.54       23.97      2.83
5:  2018   -17 33.125    may 2018-05-01     0.24    62.94       23.01      2.64
6:  2018   -17 33.125    jun 2018-06-01     0.01    57.74       20.15      2.77
# Geolocating points in the climate data and linking them to districts for merging with the malaria data

# Load Malawi districts shapefile

districts <- st_read("mwi_adm_nso_hotosm_20230405_shp/mwi_admbnda_adm2_nso_hotosm_20230405.shp")
Reading layer `mwi_admbnda_adm2_nso_hotosm_20230405' from data source 
  `/Users/m1/Downloads/Research/Malaria-Incidence-Prediction/mwi_adm_nso_hotosm_20230405_shp/mwi_admbnda_adm2_nso_hotosm_20230405.shp' 
  using driver `ESRI Shapefile'
Simple feature collection with 32 features and 11 fields
Geometry type: MULTIPOLYGON
Dimension:     XY
Bounding box:  xmin: 32.67164 ymin: -17.12975 xmax: 35.91848 ymax: -9.367346
Geodetic CRS:  WGS 84
# Make sure the district polygons use WGS84 (EPSG:4326), the CRS assigned to
# the climate points below
districts <- st_transform(districts, crs = 4326)


# 2. Convert climate dataframe to sf

# lon / lat become the point geometry of each climate record
climate_sf <- st_as_sf(climate_df, coords = c("lon", "lat"), crs = 4326)


# 3. Spatial join with districts

# Each point gets the single nearest district polygon. A point inside a
# polygon is at distance zero from it, so interior points receive their own
# district and points just outside the border fall back to the closest one.
# This avoids st_buffer() on lon/lat data, where `dist` is metres under s2 and
# degrees otherwise, and it cannot produce duplicate rows the way a
# within-join against overlapping buffered polygons can.
points_with_district <- st_join(
  climate_sf,
  districts["ADM2_EN"],
  join = st_nearest_feature
)


# 4. Clean dataframe
colnames(points_with_district)
[1] "year"        "month"       "date"        "rainfall"    "humidity"   
[6] "temperature" "windspeed"   "ADM2_EN"     "geometry"   
head(points_with_district)
Simple feature collection with 6 features and 8 fields
Geometry type: POINT
Dimension:     XY
Bounding box:  xmin: 33.125 ymin: -17 xmax: 33.125 ymax: -17
Geodetic CRS:  WGS 84
  year month       date rainfall humidity temperature windspeed  ADM2_EN
1 2018   jan 2018-01-01     0.62    58.22       27.42      3.02 Chikwawa
2 2018   feb 2018-02-01    11.33    83.60       24.88      2.08 Chikwawa
3 2018   mar 2018-03-01     3.78    78.75       24.98      2.39 Chikwawa
4 2018   apr 2018-04-01     0.81    70.54       23.97      2.83 Chikwawa
5 2018   may 2018-05-01     0.24    62.94       23.01      2.64 Chikwawa
6 2018   jun 2018-06-01     0.01    57.74       20.15      2.77 Chikwawa
            geometry
1 POINT (33.125 -17)
2 POINT (33.125 -17)
3 POINT (33.125 -17)
4 POINT (33.125 -17)
5 POINT (33.125 -17)
6 POINT (33.125 -17)
# Back to a plain table: drop the geometry and expose ADM2_EN as `district`
climate_df <- select(
  st_drop_geometry(points_with_district),
  year, month, date, rainfall, humidity, temperature, windspeed,
  district = ADM2_EN
)

# Inspect the first rows
head(climate_df)
  year month       date rainfall humidity temperature windspeed district
1 2018   jan 2018-01-01     0.62    58.22       27.42      3.02 Chikwawa
2 2018   feb 2018-02-01    11.33    83.60       24.88      2.08 Chikwawa
3 2018   mar 2018-03-01     3.78    78.75       24.98      2.39 Chikwawa
4 2018   apr 2018-04-01     0.81    70.54       23.97      2.83 Chikwawa
5 2018   may 2018-05-01     0.24    62.94       23.01      2.64 Chikwawa
6 2018   jun 2018-06-01     0.01    57.74       20.15      2.77 Chikwawa
unique(climate_df$district)
 [1] "Chikwawa"      "Nsanje"        "Thyolo"        "Mulanje"      
 [5] "Dedza"         "Mwanza"        "Blantyre"      "Zomba"        
 [9] "Lilongwe"      "Ntcheu"        "Balaka"        "Machinga"     
[13] "Mangochi"      "Mchinji"       "Lilongwe City" "Dowa"         
[17] "Salima"        "Kasungu"       "Nkhotakota"    "Mzimba"       
[21] "Likoma"        "Nkhatabay"     "Rumphi"        "Chitipa"      
[25] "Karonga"      
# 5. Outline map of Malawi with each district labelled by name

ggplot(data = districts) +
  geom_sf(fill = NA, color = "black") +
  geom_sf_text(aes(label = ADM2_EN), size = 2, color = "blue") +
  labs(title = "Malawi Districts", caption = "Source: Malawi shapefile") +
  theme_minimal()

#merging climate with malaria data
setDT(malaria_df)
head(malaria_df)
   district facility         cluster    month  year incidence       date
     <char>   <char>          <char>   <fctr> <num>     <num>     <Date>
1: Mangochi      DHO Aa-salam Clinic February  2019       451 2019-02-01
2: Mangochi      DHO Aa-salam Clinic    March  2019       794 2019-03-01
3: Mangochi      DHO Aa-salam Clinic    April  2019       308 2019-04-01
4: Mangochi      DHO Aa-salam Clinic      May  2019       276 2019-05-01
5: Mangochi      DHO Aa-salam Clinic     June  2019       291 2019-06-01
6: Mangochi      DHO Aa-salam Clinic     July  2019       170 2019-07-01
head(climate_df)
  year month       date rainfall humidity temperature windspeed district
1 2018   jan 2018-01-01     0.62    58.22       27.42      3.02 Chikwawa
2 2018   feb 2018-02-01    11.33    83.60       24.88      2.08 Chikwawa
3 2018   mar 2018-03-01     3.78    78.75       24.98      2.39 Chikwawa
4 2018   apr 2018-04-01     0.81    70.54       23.97      2.83 Chikwawa
5 2018   may 2018-05-01     0.24    62.94       23.01      2.64 Chikwawa
6 2018   jun 2018-06-01     0.01    57.74       20.15      2.77 Chikwawa
# Prepare the join keys of both tables
# Dates: coerce to Date on both sides
malaria_df$date <- as.Date(malaria_df$date)
# Climate months: three-letter abbreviations -> full month names, matching
# the malaria table
climate_df <- mutate(
  climate_df,
  date = as.Date(date),
  month = month.name[match(tolower(month), tolower(month.abb))]
)

head(climate_df)
  year    month       date rainfall humidity temperature windspeed district
1 2018  January 2018-01-01     0.62    58.22       27.42      3.02 Chikwawa
2 2018 February 2018-02-01    11.33    83.60       24.88      2.08 Chikwawa
3 2018    March 2018-03-01     3.78    78.75       24.98      2.39 Chikwawa
4 2018    April 2018-04-01     0.81    70.54       23.97      2.83 Chikwawa
5 2018      May 2018-05-01     0.24    62.94       23.01      2.64 Chikwawa
6 2018     June 2018-06-01     0.01    57.74       20.15      2.77 Chikwawa
# Join datasets by district and month.
# climate_df holds one row per grid point, and a district usually contains
# several grid points. Joining it directly repeats every malaria row once per
# grid point, so the climate data are first averaged to one row per district
# and month.
mean_or_na <- function(x) {
  if (all(is.na(x))) NA_real_ else mean(x, na.rm = TRUE)
}

climate_monthly <- climate_df %>%
  group_by(district, year, month, date) %>%
  summarise(
    across(c(rainfall, humidity, temperature, windspeed), mean_or_na),
    .groups = "drop"
  )

# month is a factor in malaria_df and character in the climate table; convert
# it explicitly so the join keys share a type.
# NOTE: district names found in only one table (e.g. "Lilongwe City" in the
# climate data, "Mzuzu" in the malaria data) do not match and get NA climate.
merged_df <- malaria_df %>%
  mutate(month = as.character(month)) %>%
  left_join(climate_monthly, by = c("district", "date", "year", "month"))


# Checking result of merge
head(merged_df)
   district facility         cluster    month  year incidence       date
     <char>   <char>          <char>   <char> <num>     <num>     <Date>
1: Mangochi      DHO Aa-salam Clinic February  2019       451 2019-02-01
2: Mangochi      DHO Aa-salam Clinic February  2019       451 2019-02-01
3: Mangochi      DHO Aa-salam Clinic February  2019       451 2019-02-01
4: Mangochi      DHO Aa-salam Clinic February  2019       451 2019-02-01
5: Mangochi      DHO Aa-salam Clinic February  2019       451 2019-02-01
6: Mangochi      DHO Aa-salam Clinic February  2019       451 2019-02-01
   rainfall humidity temperature windspeed
      <num>    <num>       <num>     <num>
1:     7.04    82.88       24.38      1.75
2:     5.01    86.72       22.87      1.74
3:     8.42    76.47       25.42      2.67
4:     6.90    87.05       22.56      1.68
5:     9.40    79.18       24.44      2.43
6:     8.85    87.16       22.31      1.59
# Summarise incidence per district per year
yearly_district <- malaria_df %>%
  group_by(district, year) %>%
  summarise(Total_Incidence = sum(incidence, na.rm = TRUE), .groups = "drop")

# Make year a factor for proper ordering; the malaria data cover 2018-2024
yearly_district$year <- factor(yearly_district$year, levels = 2018:2024)

# District-by-year heatmap
ggplot(yearly_district, aes(x = year, y = district, fill = Total_Incidence)) +
  geom_tile(color = "white") +
  scale_fill_gradient(low = "lightyellow", high = "steelblue") +
  labs(title = "Annual Incidence per District (2018–2024)",
       x = "Year",
       y = "District",
       fill = "Incidence") +
  theme_minimal(base_size = 12) +
  theme(
    axis.title.x = element_text(size = 10, face = "bold"),
    axis.title.y = element_text(size = 10, face = "bold"),
    axis.text.x = element_text(size = 10),
    axis.text.y = element_text(size = 8),
    legend.title = element_text(size = 10),
    legend.text = element_text(size = 8),
    plot.title = element_text(face = "bold", size = 14, hjust = 0.5)
  )

# Choropleth maps of Malawi showing total incidence per district, by year
malaria_agg <- malaria_df %>%
  group_by(district, year) %>%
  summarise(total_incidence = sum(incidence, na.rm = TRUE), .groups = "drop")

# Attach the totals to the district polygons, then drop polygons that have
# no incidence value
malaria_map_data <- left_join(districts, malaria_agg,
                              by = c("ADM2_EN" = "district"))
malaria_map_data <- filter(malaria_map_data, !is.na(total_incidence))

ggplot(malaria_map_data) +
  geom_sf(aes(fill = total_incidence), color = "black", size = 0.2) +
  scale_fill_gradient(low = "#FFF7BC", high = "#D95F0E", name = "Incidence") +
  facet_wrap(~ year) +
  labs(title = "Malaria Incidence per District in Malawi (2018-2024)") +
  theme_minimal() +
  theme(
    panel.grid = element_blank(),
    axis.ticks = element_blank(),
    axis.text = element_blank()
  )

# Correlation of incidence with the climate variables
# Define climate variables
climate_vars <- c("rainfall", "humidity", "temperature", "windspeed")

# Pearson correlation and p-value between x and y, returned as a one-row
# tibble with columns `correlation` and `p_value`.
# cor.test() stops with an error when fewer than 3 complete pairs are
# available, so that case returns NA for both columns instead.
# cor.test() drops incomplete pairs itself; it has no `use` argument.
cor_test_safe <- function(x, y) {
  n_complete <- sum(!is.na(x) & !is.na(y))
  if (n_complete < 3) {
    return(tibble(correlation = NA_real_, p_value = NA_real_))
  }
  test <- cor.test(x, y)
  # unname(): the estimate is a named vector ("cor"); store a bare number
  tibble(correlation = unname(test$estimate), p_value = test$p.value)
}

# Compute correlation and p-value per district and variable.
# Steps:
#   1. summarise(): for each district, run cor_test_safe(incidence, <var>) on
#      every climate variable and keep the result in a list column.
#   2. mutate(): row-bind each list column into a data-frame column.
#   3. unnest(): spread those data-frame columns into plain columns whose
#      names are joined with "_".
correlation_per_district <- merged_df %>%
  group_by(district) %>%
  summarise(
    across(
      all_of(climate_vars),
      ~ list(cor_test_safe(incidence, .x)), 
      .names = "{.col}"
    ),
    .groups = "drop"
  ) %>%
  mutate(across(all_of(climate_vars), ~ map_dfr(.x, ~.x))) %>%
  unnest(cols = everything(), names_sep = "_")


# Reshape to long format: one row per district and climate variable.
# Columns are named "<variable>_correlation" and "<variable>_p_value".
# Splitting on every "_" would break "p_value" into two pieces, so the name is
# split at the FIRST underscore only. The p-value column is then renamed to
# `p`, the name the later steps use.
cor_long <- correlation_per_district %>%
  pivot_longer(
    cols = matches(paste0("^(", paste(climate_vars, collapse = "|"), ")_")),
    names_to = c("climate_variable", ".value"),
    names_pattern = "^([^_]+)_(.+)$"
  ) %>%
  rename(p = p_value) %>%
  filter(!is.na(correlation))


# Map p-values to significance markers:
# "***" below 0.001, "**" below 0.01, "*" below 0.05, "." below 0.1,
# and "" otherwise (including missing p-values)
p_to_stars <- function(p) {
  stars <- rep("", length(p))
  # apply thresholds from weakest to strongest so the strictest match wins;
  # which() ignores NA comparisons
  stars[which(p < 0.1)] <- "."
  stars[which(p < 0.05)] <- "*"
  stars[which(p < 0.01)] <- "**"
  stars[which(p < 0.001)] <- "***"
  stars
}

colnames(cor_long)
[1] "district"         "climate_variable" "correlation"      "p"               
# Attach significance markers derived from the p-values
cor_long <- mutate(cor_long, stars = p_to_stars(p))

# District-by-variable heatmap: fill = correlation, text = significance
ggplot(cor_long, aes(x = climate_variable, y = district, fill = correlation)) +
  geom_tile(color = "white") +
  geom_text(aes(label = stars), color = "black", size = 5) +
  scale_fill_gradient2(
    low = "blue", mid = "white", high = "red",
    midpoint = 0,
    name = "Correlation"
  ) +
  labs(
    title = "Correlation of Malaria Incidence vs Climate Variables per District",
    subtitle = "Significance indicated by stars: *** p<0.001, ** p<0.01, * p<0.05, . p<0.1",
    x = "Climate Variable",
    y = "District"
  ) +
  theme_minimal() +
  theme(axis.text.x = element_text(angle = 45, hjust = 1))

#creating month lags for climate variables vs incidence
#create monthly lags 1 - 5

# 1. Define helper functions

# Safe Pearson correlation test.
# Returns a one-row tibble (`correlation`, `p_value`). If cor.test() fails
# (e.g. too few complete pairs) both columns are NA. The fallback uses
# NA_real_ so the columns stay numeric even when every group fails.
cor_test_safe <- function(x, y) {
  tryCatch({
    res <- cor.test(x, y, method = "pearson")
    tibble(correlation = unname(res$estimate),
           p_value = res$p.value)
  }, error = function(e) tibble(correlation = NA_real_, p_value = NA_real_))
}

# Translate p-values into significance symbols; missing p-values give ""
p_to_stars <- function(p) {
  symbols <- c("***", "**", "*", ".", "")
  # band 1: p < 0.001, band 2: < 0.01, band 3: < 0.05, band 4: < 0.1,
  # band 5: everything else
  band <- findInterval(p, c(0.001, 0.01, 0.05, 0.1)) + 1L
  stars <- symbols[band]
  stars[is.na(p)] <- ""
  stars
}


# 2. Define climate variables

climate_vars <- c("rainfall", "humidity", "temperature", "windspeed")


# 3. Create lagged climate variables (lags of 1-6 months)

# merged_df has many rows per district and month (one per facility), so
# lagging its rows directly mostly returns another row of the SAME month.
# The lags are instead computed on a one-row-per-district-month climate
# series and joined back onto merged_df.
# NOTE(review): lag(k) shifts by k rows, so this assumes every district has a
# row for each consecutive month - confirm there are no gaps.
district_climate_lags <- merged_df %>%
  group_by(district, date) %>%
  summarise(
    across(
      all_of(climate_vars),
      ~ if (all(is.na(.x))) NA_real_ else mean(.x, na.rm = TRUE)
    ),
    .groups = "drop"
  ) %>%
  group_by(district) %>%
  arrange(date, .by_group = TRUE) %>%
  mutate(across(
    all_of(climate_vars),
    list(
      lag1 = ~ dplyr::lag(.x, 1),
      lag2 = ~ dplyr::lag(.x, 2),
      lag3 = ~ dplyr::lag(.x, 3),
      lag4 = ~ dplyr::lag(.x, 4),
      lag5 = ~ dplyr::lag(.x, 5),
      lag6 = ~ dplyr::lag(.x, 6)
    ),
    .names = "{.col}_{.fn}"
  )) %>%
  ungroup() %>%
  # keep only the keys and the lag columns; merged_df already has the
  # unlagged climate values
  select(-all_of(climate_vars))

lagged_df <- merged_df %>%
  left_join(district_climate_lags, by = c("district", "date")) %>%
  arrange(district, date)


# 4. Compute correlations for each district and lag

# Every column whose name contains a climate variable name: the unlagged
# variables plus all of their "_lagN" versions.
climate_lag_vars <- grep("rainfall|humidity|temperature|windspeed", names(lagged_df), value = TRUE)

# Per district: run cor_test_safe(incidence, <column>) for each selected
# column (kept in a list column), row-bind the results into data-frame
# columns, then unnest them into plain columns joined with "_".
correlation_lags <- lagged_df %>%
  group_by(district) %>%
  summarise(
    across(
      all_of(climate_lag_vars),
      ~ list(cor_test_safe(incidence, .x)), 
      .names = "{.col}"
    ),
    .groups = "drop"
  ) %>%
  mutate(across(all_of(climate_lag_vars), ~ map_dfr(.x, ~.x))) %>%
  unnest(cols = everything(), names_sep = "_")


# 5. Reshape to long format for plotting/analysis

# Only lags 1-5 are reshaped. The column selector and the name pattern now
# use the same lag range (the pattern previously allowed lag6, which the
# selector never passes through).
# Heatmap version: non-significant cells carry an empty label.
cor_long_lags_heat <- correlation_lags %>%
  pivot_longer(
    cols = matches("_(lag[1-5])_"),
    names_to = c("climate_variable", "lag", ".value"),
    names_pattern = "(.*)_(lag[1-5])_(.*)"
  ) %>%
  filter(!is.na(correlation)) %>%
  mutate(stars = p_to_stars(p_value))

# Analysis version: the same rows, with non-significant results labelled
# "none" instead of ""
cor_long_lags <- cor_long_lags_heat %>%
  mutate(stars = ifelse(stars == "", "none", stars))

# Lag-by-variable correlation heatmap, one panel per district
ggplot(
  cor_long_lags_heat,
  aes(x = lag, y = climate_variable, fill = correlation)
) +
  geom_tile(color = "white") +
  geom_text(aes(label = stars), color = "black", size = 4, fontface = "bold") +
  scale_fill_gradient2(
    name = "Correlation",
    low = "blue", mid = "white", high = "red",
    midpoint = 0,
    limits = c(-1, 1)
  ) +
  facet_wrap(~ district, ncol = 4) +
  labs(
    title = "Lagged Climate–Malaria Correlation Matrices by District (1–5 Month Lags)",
    x = "Lag (months)",
    y = "Climate Variable"
  ) +
  theme_minimal(base_size = 12) +
  theme(
    strip.text = element_text(face = "bold"),
    axis.text.x = element_text(size = 10),
    axis.text.y = element_text(size = 10),
    panel.grid = element_blank()
  )

# Lagged correlations per year
# Derive the calendar year from the date column (lubridate::year)
lagged_df <- lagged_df %>%
  mutate(year = year(date))

# Correlations per district AND year: same pipeline as the per-district
# version (list column -> data-frame column -> unnested columns), with year
# added to the grouping
correlation_lags_year <- lagged_df %>%
  group_by(district, year) %>%
  summarise(
    across(
      all_of(climate_lag_vars),
      ~ list(cor_test_safe(incidence, .x)),
      .names = "{.col}"
    ),
    .groups = "drop"
  ) %>%
  mutate(across(all_of(climate_lag_vars), ~ map_dfr(.x, ~.x))) %>%
  unnest(cols = everything(), names_sep = "_")

# Reshape the per-year correlations to long format (lags 1-5 only)
cor_long_lags_year <- correlation_lags_year %>%
  pivot_longer(
    cols = matches("_(lag[1-5])_"),
    names_to = c("climate_variable", "lag", ".value"),
    names_pattern = "(.*)_(lag[1-5])_(.*)"
  ) %>%
  filter(!is.na(correlation)) %>%
  mutate(
    # p_to_stars() already returns "" for non-significant results, so the
    # former "none" -> "" replacement never changed anything and is dropped
    stars = p_to_stars(p_value),
    # fix the lag order on the x-axis
    lag = factor(lag, levels = paste0("lag", 1:5))
  )
# Lag-by-variable correlation heatmaps in a grid:
# one row per year, one column per district
ggplot(
  cor_long_lags_year,
  aes(x = lag, y = climate_variable, fill = correlation)
) +
  geom_tile(color = "white") +
  geom_text(aes(label = stars), color = "black", size = 3, fontface = "bold") +
  scale_fill_gradient2(
    name = "Correlation",
    low = "blue", mid = "white", high = "red",
    midpoint = 0,
    limits = c(-1, 1)
  ) +
  facet_grid(year ~ district) +
  labs(
    title = "Lagged Climate–Malaria Correlation Matrices by District and Year",
    x = "Lag (months)",
    y = "Climate Variable"
  ) +
  theme_minimal(base_size = 12) +
  theme(
    strip.text = element_text(face = "bold"),
    axis.text.x = element_text(size = 8),
    axis.text.y = element_text(size = 10),
    panel.grid = element_blank()
  )