# Comparative Analysis of ARIMA, SARIMA, Holt-Winters & Facebook Prophet

# Load required libraries
library(tidyverse)
## ── Attaching core tidyverse packages ──────────────────────── tidyverse 2.0.0 ──
## ✔ dplyr     1.1.4     ✔ readr     2.1.5
## ✔ forcats   1.0.0     ✔ stringr   1.5.1
## ✔ ggplot2   3.5.1     ✔ tibble    3.2.1
## ✔ lubridate 1.9.3     ✔ tidyr     1.3.1
## ✔ purrr     1.0.2     
## ── Conflicts ────────────────────────────────────────── tidyverse_conflicts() ──
## ✖ dplyr::filter() masks stats::filter()
## ✖ dplyr::lag()    masks stats::lag()
## ℹ Use the conflicted package (<http://conflicted.r-lib.org/>) to force all conflicts to become errors
library(lubridate)

# Import the raw COVID-19 records from the CSV file in the working directory
covid_data <- readr::read_csv(file = "df.csv")
## Rows: 1387 Columns: 6
## ── Column specification ────────────────────────────────────────────────────────
## Delimiter: ","
## chr (2): location, date
## dbl (4): total_cases, new_cases, total_deaths, new_deaths
## 
## ℹ Use `spec()` to retrieve the full column specification for this data.
## ℹ Specify the column types or set `show_col_types = FALSE` to quiet this message.
# Parse the day-month-year strings in `date` into Date values, then put the
# rows in chronological order
covid_data <- mutate(covid_data, date = dmy(date))
covid_data <- arrange(covid_data, date)

# Quick look at the first rows
head(covid_data)
## # A tibble: 6 × 6
##   location date       total_cases new_cases total_deaths new_deaths
##   <chr>    <date>           <dbl>     <dbl>        <dbl>      <dbl>
## 1 Kenya    2020-03-15           1         1            0          0
## 2 Kenya    2020-03-16           1         0            0          0
## 3 Kenya    2020-03-17           1         0            0          0
## 4 Kenya    2020-03-18           1         0            0          0
## 5 Kenya    2020-03-19           1         0            0          0
## 6 Kenya    2020-03-20           1         0            0          0
# Build the full daily calendar between the first and last observed dates so
# that any gap in the source data becomes an explicit row (optional, for
# completeness)
all_dates <- seq(
  from = min(covid_data$date),
  to = max(covid_data$date),
  by = "day"
)

# Attach the observations to the calendar. Days without a record get 0 new
# cases; the other joined columns are left as NA on those days.
covid_daily <- left_join(tibble(date = all_dates), covid_data, by = "date")
covid_daily <- mutate(covid_daily, new_cases = replace_na(new_cases, 0))

# Convert to a time series object that starts at (year, day-of-year) of the
# first date and has 365 observations per cycle
ts_daily <- ts(
  data = covid_daily$new_cases,
  start = c(year(min(covid_daily$date)), yday(min(covid_daily$date))),
  frequency = 365
)

# Preview
# print(ts_daily) %>%  head()
# Load required libraries
library(forecast)
## Registered S3 method overwritten by 'quantmod':
##   method            from
##   as.zoo.data.frame zoo
library(Metrics)
## 
## Attaching package: 'Metrics'
## The following object is masked from 'package:forecast':
## 
##     accuracy
# === Train-Test Split ===
# Chronological 80/20 split: the first 80% of days train the models and the
# remaining days are held out for evaluation.
n <- length(ts_daily)
train_len <- floor(0.8 * n)

# Subsetting a ts object with `[` returns a bare numeric vector, which silently
# discards the series frequency. The training series is therefore rebuilt as a
# ts with a weekly period (7 observations per cycle). Without a seasonal
# frequency, auto.arima(seasonal = TRUE) and ets() have no seasonal period to
# work with, and the "seasonal" fit collapses to the non-seasonal one.
train_ts <- ts(as.numeric(ts_daily)[seq_len(train_len)], frequency = 7)

# The hold-out stays a plain numeric vector: it is only compared element-wise
# with the point forecasts and used as a data-frame column.
test_ts <- as.numeric(ts_daily)[train_len + seq_len(n - train_len)]
# === 1. Non-seasonal ARIMA ===
# Automatic order selection restricted to non-seasonal models, followed by a
# forecast that covers every held-out observation
model_arima <- auto.arima(y = train_ts, seasonal = FALSE)
forecast_arima <- forecast(object = model_arima, h = length(test_ts))

# Score the point forecasts against the held-out values
rmse_arima <- rmse(actual = test_ts, predicted = forecast_arima[["mean"]])
mae_arima <- mae(actual = test_ts, predicted = forecast_arima[["mean"]])
# === 2. Seasonal ARIMA (SARIMA) ===
# Same automatic search, this time with seasonal terms allowed.
# NOTE(review): seasonal terms can only be selected when `train_ts` is a ts
# object with a frequency greater than 1 -- confirm that it is.
model_sarima <- auto.arima(y = train_ts, seasonal = TRUE)
forecast_sarima <- forecast(object = model_sarima, h = length(test_ts))

# Score the point forecasts against the held-out values
rmse_sarima <- rmse(actual = test_ts, predicted = forecast_sarima[["mean"]])
mae_sarima <- mae(actual = test_ts, predicted = forecast_sarima[["mean"]])
# # === Print Results ===
# cat("ARIMA vs SARIMA Model Performance on Daily COVID-19 Cases:\n")
# cat(sprintf("Non-seasonal ARIMA:  RMSE = %.2f | MAE = %.2f\n", rmse_arima, mae_arima))
# cat(sprintf("Seasonal SARIMA:     RMSE = %.2f | MAE = %.2f\n", rmse_sarima, mae_sarima))
# === 3. Holt-Winters Exponential Smoothing ===
# Fitted through ets() with its default arguments, so the exponential-smoothing
# specification is chosen automatically rather than forced to the classical
# Holt-Winters form.
# NOTE(review): a seasonal component can only be chosen when `train_ts` carries
# a seasonal frequency -- confirm.
model_hw <- ets(y = train_ts)
forecast_hw <- forecast(object = model_hw, h = length(test_ts))

# Score the point forecasts against the held-out values
rmse_hw <- rmse(actual = test_ts, predicted = forecast_hw[["mean"]])
mae_hw <- mae(actual = test_ts, predicted = forecast_hw[["mean"]])
# === 4. Facebook Prophet ===
library(prophet)
## Loading required package: Rcpp
## Loading required package: rlang
## 
## Attaching package: 'rlang'
## The following object is masked from 'package:Metrics':
## 
##     ll
## The following objects are masked from 'package:purrr':
## 
##     %@%, flatten, flatten_chr, flatten_dbl, flatten_int, flatten_lgl,
##     flatten_raw, invoke, splice
# Prepare data for Prophet, which expects a data frame with the columns `ds`
# (date) and `y` (value to forecast)
df_prophet <- covid_daily %>%
  rename(ds = date, y = new_cases)

# Prophet train-test split: same chronological cut as the ts-based models
df_train <- df_prophet[seq_len(train_len), ]
df_test <- df_prophet[train_len + seq_len(nrow(df_prophet) - train_len), ]

# Fit Prophet. The series has exactly one observation per day, so there is no
# within-day pattern to learn: daily seasonality is switched off and the weekly
# cycle is requested explicitly, alongside the yearly one.
model_prophet <- prophet(
  df_train,
  daily.seasonality = FALSE,
  weekly.seasonality = TRUE,
  yearly.seasonality = TRUE
)

# Forecast: extend the training calendar by one row per held-out day and
# predict over the whole range (history + future)
future <- make_future_dataframe(model_prophet, periods = nrow(df_test))
forecast_prophet <- predict(model_prophet, future)

# Keep only the predictions for the test period (the last rows of the output)
preds_prophet <- tail(forecast_prophet$yhat, nrow(df_test))

# Evaluate
rmse_prophet <- rmse(df_test$y, preds_prophet)
mae_prophet <- mae(df_test$y, preds_prophet)
# === Report RMSE and MAE for every model ===
cat("\n📊 Final RMSE and MAE for All Models (Daily COVID-19 Cases - Kenya):\n")
## 
## 📊 Final RMSE and MAE for All Models (Daily COVID-19 Cases - Kenya):
# sprintf() is vectorised over its format strings, so a single call renders the
# four report lines in model order; sep = "" keeps them exactly as formatted
cat(
  sprintf(
    c(
      "1. Non-seasonal ARIMA:   RMSE = %.2f | MAE = %.2f\n",
      "2. Seasonal SARIMA:      RMSE = %.2f | MAE = %.2f\n",
      "3. Holt-Winters (ETS):   RMSE = %.2f | MAE = %.2f\n",
      "4. Facebook Prophet:     RMSE = %.2f | MAE = %.2f\n"
    ),
    c(rmse_arima, rmse_sarima, rmse_hw, rmse_prophet),
    c(mae_arima, mae_sarima, mae_hw, mae_prophet)
  ),
  sep = ""
)
## 1. Non-seasonal ARIMA:   RMSE = 24.70 | MAE = 5.09
## 2. Seasonal SARIMA:      RMSE = 24.70 | MAE = 5.09
## 3. Holt-Winters (ETS):   RMSE = 24.75 | MAE = 5.25
## 4. Facebook Prophet:     RMSE = 818.40 | MAE = 731.77
# Load ggplot2 for visualization
library(ggplot2)

# Dates of the held-out period, shared by every forecast below
test_dates <- covid_daily$date[seq(train_len + 1, n)]

# Helper: one long-format frame (date, model label, point forecast) per model
make_forecast_df <- function(label, values) {
  data.frame(date = test_dates, model = label, forecast = values)
}

df_arima <- make_forecast_df("ARIMA", as.numeric(forecast_arima$mean))
df_sarima <- make_forecast_df("SARIMA", as.numeric(forecast_sarima$mean))
df_hw <- make_forecast_df("Holt-Winters", as.numeric(forecast_hw$mean))
# NOTE: this reuses the name `df_prophet`, replacing the Prophet input frame
# created earlier in the script
df_prophet <- make_forecast_df("Prophet", preds_prophet)

# Observed values over the same dates
actual_df <- data.frame(date = test_dates, actual = test_ts)

# Stack the four forecast frames into one long table
forecast_combined <- bind_rows(list(df_arima, df_sarima, df_hw, df_prophet))

# Attach the observed value to every forecast row by date
plot_data <- forecast_combined %>%
  left_join(actual_df, by = "date")

# Plot the held-out actuals (black, dashed) against each model's forecast
# (one coloured line per model). Line thickness is set with `linewidth`;
# `size` for lines is deprecated since ggplot2 3.4.0.
ggplot(plot_data, aes(x = date)) +
  geom_line(
    aes(y = actual),
    color = "black",
    linewidth = 1.2,
    linetype = "dashed"
  ) +
  geom_line(aes(y = forecast, color = model), linewidth = 1) +
  labs(
    title = "COVID-19 Daily New Cases in Kenya: Forecast vs Actual",
    subtitle = "Test Set Forecasts by ARIMA, SARIMA, Holt-Winters, Prophet",
    x = "Date", y = "New Daily Cases", color = "Model"
  ) +
  theme_minimal(base_size = 14)
## Warning: Using `size` aesthetic for lines was deprecated in ggplot2 3.4.0.
## ℹ Please use `linewidth` instead.
## This warning is displayed once every 8 hours.
## Call `lifecycle::last_lifecycle_warnings()` to see where this warning was
## generated.