Before knitting (do this once):

  1. Install every package the notebook uses — run this line in the RStudio Console once. The install.packages() chunks inside the document are set to eval=FALSE, so they will not reinstall on each knit:

    # Every package installed anywhere in the notebook (themis included)
    install.packages(c(
      "tidyverse", "patchwork", "moments", "reshape2", "scales",
      "ranger", "rsample", "yardstick", "tune", "dials",
      "parsnip", "workflows", "recipes", "doParallel", "caret",
      "themis"
    ))
  2. Set the data path — in the Section 1 data-loading chunk, change the read.csv() path to point to the SPARCS 2021 CSV on your machine.

  3. Be patient — the dataset has ~2.1M rows; the Random Forest models (Section 4) and the linear model on the full feature set (Section 5) can each take several minutes to knit.

1. Data Collection

  • The dataset used in this project is the Hospital Inpatient Discharges (SPARCS De-Identified): 2021, published by the New York State Department of Health.
  • Basic inspection was performed using head(), dim(), and glimpse() to examine the structure and composition of the dataset.
# Install any required packages that are not installed yet.
# Installing only what is missing makes this safe to re-run. It also guarantees
# installation happens before the library() calls below, instead of reinstalling
# tidyverse after it has already been attached.
required_pkgs <- c(
  "tidyverse", "patchwork", "moments", "reshape2",
  "ranger", "rsample", "yardstick", "tune", "dials",
  "parsnip", "workflows", "recipes", "doParallel", "themis",
  "hms", "readr"
)
missing_pkgs <- setdiff(required_pkgs, rownames(installed.packages()))
if (length(missing_pkgs) > 0) {
  install.packages(missing_pkgs)
}

# Load packages
library(tidyverse)
library(patchwork)
library(moments)
library(reshape2)
library(readr)
# # 1. Does the folder exist?
# dir.exists("C:/Users/kashi/OneDrive/Desktop/Project/WQD7003 DATA ANALYTICS")

# # 2. List what's actually in it (shows the EXACT filename)
# list.files("C:/Users/kashi/OneDrive/Desktop/Project/WQD7003 DATA ANALYTICS")

# # 3. Or just hunt for it anywhere under Project (returns the real full path)
# list.files("C:/Users/kashi/OneDrive/Desktop/Project",
#            pattern = "Hospital", recursive = TRUE, full.names = TRUE)
print(R.version.string)  # expected to report R 4.5.3
## [1] "R version 4.5.3 (2026-03-11 ucrt)"
# # Authenticate Google Drive access
# library(googledrive)
# drive_auth()
# # Find the file
# file_info <- drive_get("Hospital_Inpatient_Discharges_2021.csv")

# # Download it locally to Colab
# drive_download(file_info, path = "Hospital_Inpatient_Discharges.csv", overwrite = TRUE)
# raw_df <- read.csv("C:/Users/kashi/OneDrive/Desktop/Project/WQD7003 DATA ANALYTICS/Hospital_Inpatient_Discharges_(SPARCS_De-Identified)__2021_20260420.csv",
#                    stringsAsFactors = FALSE, check.names = FALSE)
# Read the file.
# Set `data_path` to the location of the SPARCS 2021 CSV on your machine.
data_path <- "C:/Users/kashi/OneDrive/Desktop/Project/WQD7003 DATA ANALYTICS/Hospital_Inpatient_Discharges_(SPARCS_De-Identified)__2021_20260420.csv"
if (!file.exists(data_path)) {
  stop("SPARCS data file not found: ", data_path,
       "\nUpdate `data_path` to point to the CSV on this machine.",
       call. = FALSE)
}
# read.csv's default check.names = TRUE turns the CSV headers into syntactic
# names (e.g. Length.of.Stay). All later code refers to those dotted names.
raw_df <- read.csv(data_path)
# First rows and overall dimensions of the raw extract
raw_df |> head()
raw_df |> dim()
## [1] 2135260      33
raw_df |> glimpse()  # column names, storage types and leading values
## Rows: 2,135,260
## Columns: 33
## $ Hospital.Service.Area               <chr> "New York City", "New York City", …
## $ Hospital.County                     <chr> "Bronx", "Bronx", "Bronx", "Bronx"…
## $ Operating.Certificate.Number        <int> 7000006, 7000006, 7000006, 7000006…
## $ Permanent.Facility.Id               <int> 1169, 1169, 1168, 3058, 1169, 3058…
## $ Facility.Name                       <chr> "Montefiore Medical Center - Henry…
## $ Age.Group                           <chr> "70 or Older", "50 to 69", "18 to …
## $ Zip.Code...3.digits                 <chr> "104", "104", "104", "104", "104",…
## $ Gender                              <chr> "M", "F", "F", "M", "F", "M", "F",…
## $ Race                                <chr> "Other Race", "White", "Other Race…
## $ Ethnicity                           <chr> "Spanish/Hispanic", "Not Span/Hisp…
## $ Length.of.Stay                      <chr> "27", "4", "2", "5", "3", "6", "1"…
## $ Type.of.Admission                   <chr> "Emergency", "Emergency", "Emergen…
## $ Patient.Disposition                 <chr> "Home w/ Home Health Services", "H…
## $ Discharge.Year                      <int> 2021, 2021, 2021, 2021, 2021, 2021…
## $ CCSR.Diagnosis.Code                 <chr> "INF012", "NVS005", "PRG016", "GEN…
## $ CCSR.Diagnosis.Description          <chr> "COVID-19", "Multiple sclerosis", …
## $ CCSR.Procedure.Code                 <chr> "OTR004", "", "PGN003", "ADM017", …
## $ CCSR.Procedure.Description          <chr> "ISOLATION PROCEDURES", "", "CESAR…
## $ APR.DRG.Code                        <int> 137, 43, 540, 463, 58, 813, 55, 64…
## $ APR.DRG.Description                 <chr> "MAJOR RESPIRATORY INFECTIONS AND …
## $ APR.MDC.Code                        <int> 4, 1, 14, 11, 1, 21, 1, 15, 18, 6,…
## $ APR.MDC.Description                 <chr> "DISEASES AND DISORDERS OF THE RES…
## $ APR.Severity.of.Illness.Code        <int> 3, 2, 1, 3, 2, 3, 1, 2, 2, 3, 3, 3…
## $ APR.Severity.of.Illness.Description <chr> "Major", "Moderate", "Minor", "Maj…
## $ APR.Risk.of.Mortality               <chr> "Extreme", "Minor", "Minor", "Majo…
## $ APR.Medical.Surgical.Description    <chr> "Medical", "Medical", "Surgical", …
## $ Payment.Typology.1                  <chr> "Medicare", "Private Health Insura…
## $ Payment.Typology.2                  <chr> "Medicaid", "", "", "Medicaid", "M…
## $ Payment.Typology.3                  <chr> "", "", "", "", "", "", "", "", ""…
## $ Birth.Weight                        <chr> "", "", "", "", "", "", "", "03100…
## $ Emergency.Department.Indicator      <chr> "Y", "Y", "N", "Y", "Y", "Y", "Y",…
## $ Total.Charges                       <chr> "$320,922.43", "$61,665.22", "$42,…
## $ Total.Costs                         <chr> "$60,241.34", "$9,180.69", "$11,36…
cat("Memory usage:", raw_df |> object.size() |> format(units = "MB"), "\n")
## Memory usage: 696.7 Mb

Results:

Dimensions: 2,135,260 rows × 33 columns

Data types: 2 types detected via glimpse()

  • chr — categorical and descriptive variables (e.g. Gender, Race, Patient.Disposition)
  • int — numeric codes and identifiers (e.g. APR.DRG.Code, Discharge.Year)

Missing Value Analysis

A missing value analysis was conducted to evaluate the extent of incomplete data across variables. This helps in:

  • Identifying which columns contain missing values
  • Assessing the proportion of missingness per variable
  • Informing preprocessing decisions in later stages
# Replace empty strings, "N/A" and "NULL" with NA.
# These placeholder tokens can only match in character columns, so each
# character column is recoded on its own. The result is the same as comparing
# the whole data frame against each token. It avoids coercing every cell of the
# 2.1M x 33 frame to character and allocating three full-size logical matrices.
na_tokens <- c("", "N/A", "NULL")
raw_df[] <- lapply(raw_df, function(col) {
  if (is.character(col)) {
    col[col %in% na_tokens] <- NA
  }
  col
})

# Missing values per column, sorted from most to least missing.
# is.na() is evaluated once and reused for both the count and the percentage.
# The column names become the row names of the summary.
na_counts <- colSums(is.na(raw_df))
missing_value_summary <- data.frame(
  missing_count   = na_counts,                      # total missing per column
  missing_percent = na_counts / nrow(raw_df) * 100  # missing percentage per column
) |>
  arrange(desc(missing_count))

print(missing_value_summary)
##                                     missing_count missing_percent
## Birth.Weight                              1925333     90.16855090
## Payment.Typology.3                        1803108     84.44442363
## Payment.Typology.2                        1097001     51.37552336
## CCSR.Procedure.Code                        583187     27.31222427
## CCSR.Procedure.Description                 583187     27.31222427
## Zip.Code...3.digits                         40246      1.88482901
## Operating.Certificate.Number                 6663      0.31204631
## Hospital.Service.Area                        5214      0.24418572
## Hospital.County                              5214      0.24418572
## Permanent.Facility.Id                        5214      0.24418572
## APR.Severity.of.Illness.Description           589      0.02758446
## APR.Risk.of.Mortality                         589      0.02758446
## Facility.Name                                   0      0.00000000
## Age.Group                                       0      0.00000000
## Gender                                          0      0.00000000
## Race                                            0      0.00000000
## Ethnicity                                       0      0.00000000
## Length.of.Stay                                  0      0.00000000
## Type.of.Admission                               0      0.00000000
## Patient.Disposition                             0      0.00000000
## Discharge.Year                                  0      0.00000000
## CCSR.Diagnosis.Code                             0      0.00000000
## CCSR.Diagnosis.Description                      0      0.00000000
## APR.DRG.Code                                    0      0.00000000
## APR.DRG.Description                             0      0.00000000
## APR.MDC.Code                                    0      0.00000000
## APR.MDC.Description                             0      0.00000000
## APR.Severity.of.Illness.Code                    0      0.00000000
## APR.Medical.Surgical.Description                0      0.00000000
## Payment.Typology.1                              0      0.00000000
## Emergency.Department.Indicator                  0      0.00000000
## Total.Charges                                   0      0.00000000
## Total.Costs                                     0      0.00000000
# Horizontal bar chart of the missing percentage for every column, largest on
# top. Column names are taken from the row names of the summary table.
missing_value_summary |>
  mutate(column = rownames(missing_value_summary)) |>
  ggplot(aes(x = reorder(column, missing_percent), y = missing_percent)) +
  geom_col(fill = "steelblue") +
  coord_flip() +
  labs(
    title = "Missing Values bar chart",
    x = "Column",
    y = "Missing (%)"
  ) +
  theme_minimal() +
  theme(panel.grid = element_blank())  # no gridlines at all

Key Observations on Missing Data

  • Some variables contain extremely high missingness:

    • Birth.Weight (~90%)
    • Payment.Typology.3 (~84%)
    • Payment.Typology.2 (~51%)
  • CCSR.Procedure.Code and CCSR.Procedure.Description are moderately missing (~27% each)

  • 21 out of 33 columns have no missing values and are reliable for modelling

  • These observations will inform the preprocessing strategy in later stages

2. Exploratory Data Analysis (EDA)

EDA is conducted prior to preprocessing to understand the data in its most complete form.

This EDA focuses on:

  • Distribution of both target variables
  • Missingness patterns and their structural causes
  • Relationship between features and each target variable
  • Feature redundancy and correlation
  • Geographic distribution

All analysis is performed on eda_df, a lightly cleaned version of raw_df with type conversions only.

2.1 Setup

An eda_df was created from raw_df with minimal type conversions only.

During type conversion, two data quality issues were identified and resolved:

  • Length of Stay: 1,575 rows (0.07%) contained the value "120+", representing stays of 120 or more days. The "+" suffix was removed and values were retained as 120.

  • Birth Weight: 122 rows contained the value "UNKN". These were coerced to NA as the column will be dropped in the preprocessing stage due to ~90% missingness.

# Columns converted to factor when building eda_df.
# NOTE(review): "CCSR.Procedure.Description" is not in this list, so it stays
# character while the other code/description columns become factors.
# Confirm this is intended.
categorical_cols <- c(
  # facility and location
  "Hospital.Service.Area", "Hospital.County",
  "Operating.Certificate.Number", "Permanent.Facility.Id", "Facility.Name",
  # patient demographics
  "Age.Group", "Zip.Code...3.digits", "Gender", "Race", "Ethnicity",
  # admission and discharge
  "Type.of.Admission", "Patient.Disposition",
  # clinical classification codes and descriptions
  "APR.DRG.Code", "APR.DRG.Description",
  "APR.MDC.Code", "APR.MDC.Description",
  "CCSR.Diagnosis.Code", "CCSR.Diagnosis.Description",
  "CCSR.Procedure.Code",
  "APR.Severity.of.Illness.Code", "APR.Severity.of.Illness.Description",
  "APR.Risk.of.Mortality", "APR.Medical.Surgical.Description",
  # payment and emergency-department flag
  "Payment.Typology.1", "Payment.Typology.2", "Payment.Typology.3",
  "Emergency.Department.Indicator"
)

# Columns treated as numeric
numeric_cols <- c("Length.of.Stay", "Birth.Weight", "Total.Charges", "Total.Costs")
# Create EDA dataframe with type conversion only (no rows are removed)
eda_df <- raw_df %>%
  mutate(

    # Convert categorical columns to factors
    across(all_of(categorical_cols), as.factor),

    # Length of Stay is read as text because long stays are coded "120+".
    # Strip the "+" (plain fixed-string match, no regex needed) so those stays
    # become 120, then convert to integer.
    Length.of.Stay = as.integer(gsub("+", "", Length.of.Stay, fixed = TRUE)),

    # Birth Weight uses the token "UNKN" for unknown weights. Map it to NA
    # explicitly rather than suppressing every coercion warning, so any other
    # unexpected non-numeric value is still reported.
    Birth.Weight = as.integer(na_if(Birth.Weight, "UNKN")),

    # Remove $ and commas, then convert charges to numeric
    Total.Charges = as.double(gsub("[$,]", "", Total.Charges)),

    # Remove $ and commas, then convert costs to numeric
    Total.Costs = as.double(gsub("[$,]", "", Total.Costs)),

    # Encode the discharge year as a Date (1 January of that year)
    Discharge.Year = as.Date(paste0(Discharge.Year, "-01-01"))
  )

# Re-inspect column types to confirm the type conversions took effect
eda_df |> glimpse()
## Rows: 2,135,260
## Columns: 33
## $ Hospital.Service.Area               <fct> New York City, New York City, New …
## $ Hospital.County                     <fct> Bronx, Bronx, Bronx, Bronx, Bronx,…
## $ Operating.Certificate.Number        <fct> 7000006, 7000006, 7000006, 7000006…
## $ Permanent.Facility.Id               <fct> 1169, 1169, 1168, 3058, 1169, 3058…
## $ Facility.Name                       <fct> "Montefiore Medical Center - Henry…
## $ Age.Group                           <fct> 70 or Older, 50 to 69, 18 to 29, 7…
## $ Zip.Code...3.digits                 <fct> 104, 104, 104, 104, 104, 105, 104,…
## $ Gender                              <fct> M, F, F, M, F, M, F, M, M, F, F, M…
## $ Race                                <fct> Other Race, White, Other Race, Oth…
## $ Ethnicity                           <fct> Spanish/Hispanic, Not Span/Hispani…
## $ Length.of.Stay                      <int> 27, 4, 2, 5, 3, 6, 1, 3, 21, 2, 3,…
## $ Type.of.Admission                   <fct> Emergency, Emergency, Emergency, E…
## $ Patient.Disposition                 <fct> Home w/ Home Health Services, Home…
## $ Discharge.Year                      <date> 2021-01-01, 2021-01-01, 2021-01-0…
## $ CCSR.Diagnosis.Code                 <fct> INF012, NVS005, PRG016, GEN004, NV…
## $ CCSR.Diagnosis.Description          <fct> "COVID-19", "Multiple sclerosis", …
## $ CCSR.Procedure.Code                 <fct> OTR004, NA, PGN003, ADM017, CNS002…
## $ CCSR.Procedure.Description          <chr> "ISOLATION PROCEDURES", NA, "CESAR…
## $ APR.DRG.Code                        <fct> 137, 43, 540, 463, 58, 813, 55, 64…
## $ APR.DRG.Description                 <fct> "MAJOR RESPIRATORY INFECTIONS AND …
## $ APR.MDC.Code                        <fct> 4, 1, 14, 11, 1, 21, 1, 15, 18, 6,…
## $ APR.MDC.Description                 <fct> "DISEASES AND DISORDERS OF THE RES…
## $ APR.Severity.of.Illness.Code        <fct> 3, 2, 1, 3, 2, 3, 1, 2, 2, 3, 3, 3…
## $ APR.Severity.of.Illness.Description <fct> Major, Moderate, Minor, Major, Mod…
## $ APR.Risk.of.Mortality               <fct> Extreme, Minor, Minor, Major, Mino…
## $ APR.Medical.Surgical.Description    <fct> Medical, Medical, Surgical, Medica…
## $ Payment.Typology.1                  <fct> "Medicare", "Private Health Insura…
## $ Payment.Typology.2                  <fct> "Medicaid", NA, NA, "Medicaid", "M…
## $ Payment.Typology.3                  <fct> NA, NA, NA, NA, NA, NA, NA, NA, NA…
## $ Birth.Weight                        <int> NA, NA, NA, NA, NA, NA, NA, 3100, …
## $ Emergency.Department.Indicator      <fct> Y, Y, N, Y, Y, Y, Y, N, Y, Y, Y, Y…
## $ Total.Charges                       <dbl> 320922.43, 61665.22, 42705.34, 727…
## $ Total.Costs                         <dbl> 60241.34, 9180.69, 11366.50, 12111…
cat("Memory usage:", eda_df |> object.size() |> format(units = "MB"), "\n")
## Memory usage: 301.6 Mb

Result: eda_df retains all 2,135,260 rows and 33 columns. After type conversion:

  • 27 categorical variables converted to factor
  • 4 numerical variables correctly typed as integer / double
  • 1 date variable (Discharge.Year) converted to Date
  • CCSR.Procedure.Description remains character, as it is not listed in categorical_cols

Memory usage reduced from 696.7 MB → 301.6 MB.

2.2 Target Variable Analysis

Before analysing individual features, we first examine the distribution of the primary target variable — Length of Stay (LoS) — to understand its characteristics and inform modelling decisions.

2.2.1 LoS Statistical Summary - Regression target

Summary statistics are computed to understand the central tendency, spread, and shape of the LoS distribution.

# Summary statistics for Length of Stay
los <- eda_df$Length.of.Stay
# One quantile() call for all three quartiles gives the same (type 7) values
# as one call per probability
los_quartiles <- quantile(los, c(0.25, 0.50, 0.75), na.rm = TRUE)

cat(sprintf("count  %12.2f\n", sum(!is.na(los))))
## count    2135260.00
cat(sprintf("mean   %12.2f\n", mean(los, na.rm = TRUE)))
## mean           5.75
cat(sprintf("std    %12.2f\n", sd(los, na.rm = TRUE)))
## std            8.41
cat(sprintf("min    %12.2f\n", min(los, na.rm = TRUE)))
## min            1.00
cat(sprintf("25%%    %12.2f\n", los_quartiles[[1]]))
## 25%            2.00
cat(sprintf("50%%    %12.2f\n", los_quartiles[[2]]))
## 50%            3.00
cat(sprintf("75%%    %12.2f\n", los_quartiles[[3]]))
## 75%            6.00
cat(sprintf("max    %12.2f\n", max(los, na.rm = TRUE)))
## max          120.00
cat(sprintf("\nSkewness : %.2f\n", skewness(los, na.rm = TRUE)))
## 
## Skewness : 5.82
# moments::kurtosis() returns Pearson kurtosis (3 for a normal distribution).
# Subtracting 3 gives *excess* kurtosis, so label the value accordingly.
cat(sprintf("Excess kurtosis : %.2f\n", kurtosis(los, na.rm = TRUE) - 3))
## Excess kurtosis : 52.96

2.2.2 Outlier Analysis

Given the high skewness observed, an IQR-based outlier analysis is conducted to identify extreme values and determine whether they should be removed or retained.

# Horizontal boxplot of Length of Stay with the outlying points drawn.
# NOTE(review): repr.plot.* options are presumably only honoured by
# Jupyter/IRkernel and have no effect when knitting. Confirm.
options(repr.plot.width = 10, repr.plot.height = 5)

boxplot(
  eda_df$Length.of.Stay,
  main = "Length of Stay - Outlier Boxplot",
  xlab = "Days",
  col = "lightcoral",
  horizontal = TRUE,
  outline = TRUE,
  boxwex = 1.5,
  ylim = c(0, 130)
)

# Tukey upper fence for Length of Stay: Q3 + 1.5 * IQR
q1    <- quantile(eda_df$Length.of.Stay, probs = 0.25, na.rm = TRUE)
q3    <- quantile(eda_df$Length.of.Stay, probs = 0.75, na.rm = TRUE)
iqr   <- q3 - q1
upper <- q3 + 1.5 * iqr

# Number of stays strictly above the fence (NA stays are not counted)
n_out <- length(which(eda_df$Length.of.Stay > upper))

cat(sprintf("Q1 = %.0f, Q3 = %.0f, IQR = %.0f\n", q1, q3, iqr))
## Q1 = 2, Q3 = 6, IQR = 4
cat(sprintf("Upper fence = %.1f\n", upper))
## Upper fence = 12.0
cat(sprintf(
  "Records above upper fence: %s (%.2f%%)\n",
  format(n_out, big.mark = ","),
  n_out / nrow(eda_df) * 100
))
## Records above upper fence: 212,255 (9.94%)

With Q1 = 2, Q3 = 6 and IQR = 4, the upper fence is 12 days. A total of 212,255 records (9.94%) exceed this threshold. However, these represent genuinely complex cases such as chronic or critically ill patients. Removing them would introduce bias, so they are retained.

2.2.3 Log Transformation

As LoS is heavily right-skewed with genuine extreme values, a log1p transformation is applied rather than removing outliers. This reduces skewness while preserving all records.

# Plot size hint (repr.plot.* presumably only affects Jupyter output)
options(repr.plot.width = 12, repr.plot.height = 4)

# Two panels side by side: raw LoS on the left, log1p(LoS) on the right
par(mfrow = c(1, 2))

hist(
  eda_df$Length.of.Stay,
  main = "Length of Stay — raw", xlab = "Days",
  breaks = 50, col = "#4682B4", border = "white"
)

hist(
  log1p(eda_df$Length.of.Stay),
  main = "Length of Stay — after log1p", xlab = "log(1 + Days)",
  breaks = 50, col = "#008080", border = "white",
  xlim = c(0, 5)
)

par(mfrow = c(1, 1))

# Skewness of the raw and the log1p-transformed values
cat(sprintf("Skewness before: %.2f\n",
            skewness(eda_df$Length.of.Stay, na.rm = TRUE)))
## Skewness before: 5.82
cat(sprintf("Skewness after:  %.2f\n",
            skewness(log1p(eda_df$Length.of.Stay), na.rm = TRUE)))
## Skewness after:  0.97

Skewness reduced significantly from 5.82 to 0.97, producing a more symmetric distribution suitable for modelling. The transformed variable log1p(Length.of.Stay) will be used as the target variable in the modelling stage.

2.2.4 Patient Disposition (Classification Target)

Patient Disposition represents the discharge destination of each patient and serves as the classification target variable. Its distribution is examined to understand class balance before modelling.

# Number and share (%) of records per discharge destination, most frequent first
eda_df %>%
  count(Patient.Disposition, sort = TRUE) %>%
  mutate(percentage = round(n / sum(n) * 100, 2))
options(repr.plot.width = 10, repr.plot.height = 5)

# Bar chart of Patient Disposition distribution (largest class on top)
eda_df %>%
  count(Patient.Disposition) %>%
  mutate(Patient.Disposition = fct_reorder(
    as.character(Patient.Disposition), n)) %>%
  ggplot(aes(x = n, y = Patient.Disposition)) +
  geom_col(fill = "#2980b9") +
  geom_text(aes(label = scales::comma(n)), hjust = -0.1, size = 3) +
  # Keep everything about the x axis in this one scale. A separate xlim()
  # would replace it and drop the comma labels. The right-hand expansion
  # leaves room for the count labels without hard-coding an axis maximum.
  scale_x_continuous(
    labels = scales::comma,
    expand = expansion(mult = c(0, 0.15))
  ) +
  labs(title = "Patient Disposition Distribution",
       x = "Count", y = NULL) +
  theme_minimal() +
  theme(
    plot.title = element_text(face = "bold", size = 13),
    panel.grid.major.y = element_blank()
  )

options(repr.plot.width = 12, repr.plot.height = 6)

age_order <- c("0 to 17", "18 to 29", "30 to 49", "50 to 69", "70 or Older")

# 100% stacked bars: the mix of discharge destinations within each age group.
# position = "fill" rescales every bar to a total of 1, so the raw counts give
# exactly the same proportions as per-group percentages would.
eda_df %>%
  filter(!is.na(Age.Group), !is.na(Patient.Disposition)) %>%
  count(Age.Group, Patient.Disposition) %>%
  mutate(Age.Group = factor(Age.Group, levels = age_order)) %>%
  ggplot(aes(x = Age.Group, y = n, fill = Patient.Disposition)) +
  geom_col(position = "fill") +
  scale_y_continuous(labels = scales::percent) +
  labs(
    title = "Patient Disposition by Age Group",
    x = "Age Group",
    y = "Proportion",
    fill = "Patient Disposition"
  ) +
  theme_minimal() +
  theme(
    axis.text.x = element_text(angle = 45, hjust = 1),
    legend.position = "right",
    legend.text = element_text(size = 8)
  )

Observation:

  • Severe class imbalance: The dataset contains 19 disposition categories. Home or Self Care dominates with 1,397,511 cases (65.45%), followed by Home w/ Home Health Services (14.23%) and Skilled Nursing Home (8.56%).

  • Minority classes: Several categories have very few cases, including Critical Access Hospital (169 cases, 0.01%) and Federal Health Care Facility (520 cases, 0.02%), making them difficult for a model to learn.

  • Modelling implication: The severe class imbalance suggests that direct multiclass classification across all 19 categories is not appropriate. The classification strategy will be determined in the modelling stage.

2.3 Missingness Pattern Analysis

Further analysis is conducted to determine whether missing values occur independently or follow a structural pattern. Understanding the nature of missingness helps in applying more informed preprocessing strategies rather than blindly imputing or removing data.

# Do the hospital-location fields go missing on the same rows?
# Each output row is one combination of missing (TRUE) / present (FALSE)
# flags, with the number of records showing that combination.
cols <- c("Permanent.Facility.Id", "Hospital.County", "Hospital.Service.Area")

eda_df %>%
  transmute(across(all_of(cols), is.na)) %>%
  group_by(across(all_of(cols))) %>%
  summarise(count = n(), .groups = "drop")

Finding - Hospital-related fields: Permanent Facility ID, Hospital County, and Hospital Service Area tend to be missing together, suggesting these records originate from facilities that did not report location data.

# Do missing Risk of Mortality and missing Severity description co-occur?
# One output row per combination of missing flags, with its record count.
cols <- c("APR.Risk.of.Mortality", "APR.Severity.of.Illness.Description")

eda_df %>%
  transmute(across(all_of(cols), is.na)) %>%
  group_by(across(all_of(cols))) %>%
  summarise(count = n(), .groups = "drop")
# Verify which Severity Code values have missing descriptions
sev_desc_check <- eda_df %>%
  group_by(APR.Severity.of.Illness.Code) %>%
  summarise(
    na_description = sum(is.na(APR.Severity.of.Illness.Description)),
    not_na_description = sum(!is.na(APR.Severity.of.Illness.Description))
  )
sev_desc_check

# Derive the result line from the table above instead of hard-coding it, so
# the reported codes and row count always match the data.
na_codes <- sev_desc_check %>% filter(na_description > 0)
if (nrow(na_codes) == 0) {
  print("Result: No severity code has missing descriptions")
} else {
  if (any(na_codes$not_na_description > 0)) {
    warning("Some severity codes mix missing and non-missing descriptions.",
            call. = FALSE)
  }
  print(sprintf(
    "Result: Only Code = %s has all NA descriptions (%d rows)",
    paste(na_codes$APR.Severity.of.Illness.Code, collapse = ", "),
    sum(na_codes$na_description)
  ))
}
## [1] "Result: Only Code = 0 has all NA descriptions (589 rows)"
# DRG groups behind Severity Code = 0 and behind missing Risk of Mortality,
# each listed from most to least frequent
filter(eda_df, APR.Severity.of.Illness.Code == 0) %>%
  count(APR.DRG.Code, sort = TRUE)
filter(eda_df, is.na(APR.Risk.of.Mortality)) %>%
  count(APR.DRG.Code, sort = TRUE)

Conclusion
Missingness appears to be structural rather than random, as records with missing APR Risk of Mortality and Severity Code = 0 are concentrated entirely within DRG groups 955 and 956. Since these DRG groups contain only missing values, DRG-based imputation is not feasible. Therefore, group-based mode imputation using MDC categories will be applied during preprocessing.

2.4 Categorical Features vs Target Variables

This section explores how each categorical feature relates to Length of Stay, helping to identify which variables are strong predictors for modelling.

2.4.1 Feature Overview

Before examining individual features, a summary of all categorical and numerical variables is generated to understand their cardinality and distribution.

# Summarize categorical variables: category count, sample values, dominant
# category and its percentage.
#
# Args:
#   df          data frame; only its factor columns are summarised
#   max_display maximum number of example values listed before " ..."
# Returns:
#   data.frame with one row per factor column (Column, N_Categories,
#   Example_Values, Top_Category, Top_Pct), or NULL when `df` has no factor
#   columns. N_Categories counts every entry table() reports (NA excluded).
#   Top_Pct is the top category's share (%) of the non-missing values. A
#   factor with no non-missing values gets NA for Top_Category and Top_Pct.
summarize_categorical <- function(df, max_display = 5) {
  # Select all factor columns (vapply always yields a logical vector, even
  # for a zero-column data frame)
  cat_cols <- names(df)[vapply(df, is.factor, logical(1))]

  # Compute summary statistics for each categorical column
  do.call(rbind, lapply(cat_cols, function(col) {
    vc <- sort(table(df[[col]], useNA = "no"), decreasing = TRUE)
    n_unique <- length(vc)
    n_present <- sum(vc)

    if (n_present > 0) {
      top_category <- names(vc)[1]
      # Plain (unnamed) number so data.frame() cannot pick the category label
      # up as a row name
      top_pct <- round(vc[[1]] / n_present * 100, 2)
    } else {
      # Nothing observed (all values NA): report NA instead of deriving a
      # top category and a 0/0 percentage
      top_category <- NA_character_
      top_pct <- NA_real_
    }

    # Show all values if fewer than max_display, otherwise show top values only
    sample_vals <- if (n_unique <= max_display) {
      paste(names(vc), collapse = ", ")
    } else {
      paste0(paste(names(vc)[seq_len(max_display)], collapse = ", "), " ...")
    }

    data.frame(
      Column         = col,
      N_Categories   = n_unique,
      Example_Values = sample_vals,
      Top_Category   = top_category,
      Top_Pct        = top_pct,
      stringsAsFactors = FALSE
    )
  }))
}
# Summarize numerical variables: count, mean, std, min, quartiles, max.
#
# Args:
#   df  data frame; only its numeric (integer/double) columns are summarised
# Returns:
#   data.frame with one row per numeric column and the columns
#   count, mean, std, min, 25%, 50%, 75%, max (NA values ignored).
#   Returns a zero-row data.frame with those columns when `df` has no
#   numeric columns.
summarize_numerical <- function(df) {
  # Deliberately global: disables scientific notation so the returned table,
  # which is printed after this function returns, shows plain numbers. The
  # setting stays in effect for the rest of the session.
  options(scipen = 999)

  num_df <- df[vapply(df, is.numeric, logical(1))]
  stat_names <- c("count", "mean", "std", "min", "25%", "50%", "75%", "max")

  # vapply with a named, typed template: every column must yield 8 doubles,
  # and the template supplies the row names of the result matrix
  stats_mat <- vapply(num_df, function(x) {
    # One quantile() call for all three quartiles gives the same type-7
    # values as separate calls
    qs <- unname(quantile(x, c(0.25, 0.50, 0.75), na.rm = TRUE))
    c(
      sum(!is.na(x)),
      round(mean(x, na.rm = TRUE), 6),
      round(sd(x, na.rm = TRUE), 6),
      min(x, na.rm = TRUE),
      qs,
      max(x, na.rm = TRUE)
    )
  }, setNames(numeric(length(stat_names)), stat_names))

  as.data.frame(t(stats_mat))
}
eda_df |> summarize_categorical()  # one row per factor column
eda_df |> summarize_numerical()    # one row per numeric column

2.4.2 Categorical Features vs Length of Stay

The following plots compare the distribution of Length of Stay across categories of each key feature, using mean, median, Q1 and Q3 as summary statistics.

# Define age order
age_order <- c("0 to 17", "18 to 29", "30 to 49", "50 to 69", "70 or Older")

# Per-age-group LoS summary: mean, quartiles and number of patients
stats <- eda_df %>%
  filter(!is.na(Age.Group), !is.na(Length.of.Stay)) %>%
  group_by(Age.Group) %>%
  summarise(
    mean   = mean(Length.of.Stay),
    q1     = quantile(Length.of.Stay, 0.25),
    median = median(Length.of.Stay),
    q3     = quantile(Length.of.Stay, 0.75),
    count  = n(),
    .groups = "drop"
  ) %>%
  mutate(Age.Group = factor(Age.Group, levels = age_order)) %>%
  arrange(Age.Group)

options(repr.plot.width = 11, repr.plot.height = 6)

# Widen the margins for the group labels; keep the old setting so it can be
# restored once the plot is finished
old_par <- par(mar = c(5, 10, 6, 5))

# Layer 1: patient counts as horizontal bars (bottom axis)
bp <- barplot(
  stats$count,
  horiz = TRUE,
  names.arg = stats$Age.Group,
  col = "lightsteelblue",
  border = "steelblue",
  xlab = "Count (number of patients)",
  ylab = "",
  las = 1,
  xlim = c(0, max(stats$count) * 1.05)
)

mtext("Age Group", side = 2, line = 7)
title(main = "Length of Stay by Age Group", line = 5)

# Layer 2: LoS markers over the same bars, on their own top axis
par(new = TRUE)

# The LoS axis runs from 0 to 120% of the largest mean or Q3. A data-derived
# range means no marker can be pushed outside the plot region.
plot(
  stats$mean,
  bp,
  type = "n",
  axes = FALSE,
  xlab = "",
  ylab = "",
  xlim = c(0, max(stats$mean, stats$q3) * 1.2),
  ylim = range(bp) + c(-0.5, 0.5)
)

axis(3)
mtext("Length of Stay (days)", side = 3, line = 2.5)

points(stats$mean,   bp, pch = 19, col = "#c0392b", cex = 1.4)
points(stats$median, bp, pch = 19, col = "#27ae60", cex = 1.4)
points(stats$q1,     bp, pch = 19, col = "#2980b9", cex = 1.1)
points(stats$q3,     bp, pch = 19, col = "#8e44ad", cex = 1.1)

text(
  stats$median,
  bp,
  labels = sprintf("%.2f", stats$median),
  pos = 1,
  cex = 0.8,
  col = "#27ae60"
)

legend(
  "bottomright",
  legend = c("Mean", "Median", "Q1 (25th %)", "Q3 (75th %)"),
  col = c("#c0392b", "#27ae60", "#2980b9", "#8e44ad"),
  pch = 19,
  bty = "o"
)

par(old_par)

stats %>%
  mutate(across(where(is.numeric), \(x) round(x, 2)))

Observation: LoS increases consistently with age. Patients aged 70 or older have the longest stays (median = 4 days, mean = 6.71), while patients aged 0 to 17 have the shortest (median = 2 days, mean = 4.02). This confirms age group as a strong predictor of LoS.

# library(dplyr)

sev_order <- c("Unknown", "Minor", "Moderate", "Major", "Extreme")

# Per-severity LoS summary. A missing severity description is kept as its own
# "Unknown" group, so only rows with a missing LoS are dropped.
stats_sev <- eda_df %>%
  mutate(APR.Severity.of.Illness.Description = coalesce(
    as.character(APR.Severity.of.Illness.Description), "Unknown"
  )) %>%
  filter(!is.na(Length.of.Stay)) %>%
  group_by(APR.Severity.of.Illness.Description) %>%
  summarise(
    mean   = mean(Length.of.Stay),
    q1     = quantile(Length.of.Stay, 0.25),
    median = median(Length.of.Stay),
    q3     = quantile(Length.of.Stay, 0.75),
    count  = n(),
    .groups = "drop"
  ) %>%
  mutate(APR.Severity.of.Illness.Description =
           factor(APR.Severity.of.Illness.Description, levels = sev_order)) %>%
  arrange(APR.Severity.of.Illness.Description)

options(repr.plot.width = 11, repr.plot.height = 6)

# Widen the margins for the labels; keep the old setting for restoring later
old_par <- par(mar = c(5, 10, 7, 5))

# Layer 1: patient counts as horizontal bars (bottom axis)
bp <- barplot(
  stats_sev$count,
  horiz = TRUE,
  names.arg = stats_sev$APR.Severity.of.Illness.Description,
  col = "lightsteelblue",
  border = "steelblue",
  xlab = "Count (number of patients)",
  ylab = "",
  las = 1,
  xlim = c(0, max(stats_sev$count) * 1.05)
)

mtext("Severity of Illness", side = 2, line = 5)
title(main = "Length of Stay by Severity of Illness", line = 5)

# Layer 2: LoS markers over the same bars, on their own top axis
par(new = TRUE)

# Include the means as well as Q3 in the axis range. In a right-skewed group
# the mean can exceed Q3 and would otherwise be drawn off the plot.
plot(
  stats_sev$mean,
  bp,
  type = "n",
  axes = FALSE,
  xlab = "",
  ylab = "",
  xlim = c(0, max(stats_sev$q3, stats_sev$mean, na.rm = TRUE) * 1.2),
  ylim = range(bp) + c(-0.5, 0.5)
)

axis(3)
mtext("Length of Stay (days)", side = 3, line = 3)

points(stats_sev$mean,   bp, pch = 19, col = "#c0392b", cex = 1.4)
points(stats_sev$median, bp, pch = 19, col = "#27ae60", cex = 1.4)
points(stats_sev$q1,     bp, pch = 19, col = "#2980b9", cex = 1.1)
points(stats_sev$q3,     bp, pch = 19, col = "#8e44ad", cex = 1.1)

text(
  stats_sev$median,
  bp,
  labels = sprintf("%.2f", stats_sev$median),
  pos = 1,
  cex = 0.8,
  col = "#27ae60"
)

legend(
  "bottomright",
  legend = c("Mean", "Median", "Q1", "Q3"),
  col = c("#c0392b", "#27ae60", "#2980b9", "#8e44ad"),
  pch = 19,
  bty = "o",
  x.intersp = 0.4,
  y.intersp = 0.8,
  text.width = 2
)

par(old_par)

stats_sev %>%
  mutate(across(where(is.numeric), \(x) round(x, 2)))

Observation: Severity of Illness shows the strongest gradient across all features. Median LoS increases from 2 days (Minor) to 9 days (Extreme), with Extreme cases averaging 13.60 days — nearly 5 times longer than Minor cases (2.87 days). This confirms Severity of Illness as the most influential predictor of LoS.

Observation:

  • Patient volume: New York City accounts for the highest discharge volume, reflecting its population size, followed by Long Island and Hudson Valley.

  • Average LoS: Finger Lakes has the highest mean LoS (6.09 days), followed by Hudson Valley (5.99 days) and New York City (5.90 days). The lowest is Central NY (5.24 days).

  • Conclusion: The difference in mean LoS across all regions is less than 1 day (5.24 to 6.09), indicating that geographic location has minimal impact on LoS. Clinical factors such as severity and diagnosis type are likely more influential predictors.

Note: NA records (5,214 rows, 0.24% of the dataset) are excluded from this analysis as they represent facilities that did not report location data. These records are subsequently imputed in the preprocessing stage (Section 3.3.1).

2.6 Top Diagnoses by Mean Length of Stay

To understand how diagnosis type relates to Length of Stay, the top 10 diagnoses with the highest mean LoS are identified, filtered to diagnoses with at least 1,000 cases to avoid small-sample distortion.

# Top 10 diagnoses by mean LoS among diagnoses with at least 1,000 admissions.
# The threshold keeps rare diagnoses with a handful of very long stays out of
# the ranking.
# NOTE: repr.plot.* options only affect Jupyter/IRkernel output; in a knitted
# R Markdown document the figure size comes from the chunk's fig.width /
# fig.height options instead.
options(repr.plot.width = 11, repr.plot.height = 6)

eda_df %>%
  group_by(CCSR.Diagnosis.Description) %>%
  summarise(mean_LoS = mean(Length.of.Stay, na.rm = TRUE),
            count = n()) %>%
  # ">=" so the filter matches the stated "at least 1,000 cases"
  filter(count >= 1000) %>%
  arrange(desc(mean_LoS)) %>%
  head(10) %>%
  # Order factor levels by mean LoS so the longest bar is drawn at the top
  mutate(CCSR.Diagnosis.Description = fct_reorder(
    as.character(CCSR.Diagnosis.Description), mean_LoS)) %>%
  ggplot(aes(x = mean_LoS, y = CCSR.Diagnosis.Description)) +
  geom_col(fill = "#2980b9", width = 0.6) +
  geom_text(aes(label = round(mean_LoS, 1)), hjust = -0.1, size = 3.5) +
  labs(title = "Top 10 Diagnoses by Mean Length of Stay",
       subtitle = "Filtered to diagnoses with at least 1,000 cases",
       x = "Mean Length of Stay (days)",
       y = NULL) +
  theme_minimal() +
  theme(
    plot.title = element_text(face = "bold", size = 13),
    plot.subtitle = element_text(size = 10, color = "gray40"),
    axis.text.y = element_text(size = 10),
    panel.grid.major.y = element_blank()
  ) +
  # Axis starts at 0 and its upper limit comes from the data, plus ~15%
  # headroom for the value labels. Because nothing is hard-coded, a bar longer
  # than some fixed limit can never be silently dropped.
  scale_x_continuous(limits = c(0, NA), expand = expansion(mult = c(0, 0.15)))

Observation:

  • Haematological conditions: acute myeloid leukaemia (AML) ranks highest at 22.8 days, reflecting prolonged chemotherapy cycles
  • Neurological & cerebrovascular conditions: traumatic brain injury (TBI, 17.1 days), cerebrovascular sequelae (18.7 days), and cerebral infarction sequelae (14.5 days) all feature prominently, driven by long rehabilitation periods
  • Psychiatric conditions: Schizophrenia spectrum disorders average 16.3 days with a large patient volume (38,030 cases), indicating psychiatric admissions are both common and prolonged
  • Other chronic conditions: Pressure ulcer (14.1 days), multiple myeloma (13.6 days), and hip fracture (13.1 days) reflect the complexity of managing long-term or post-surgical cases

These findings confirm that diagnosis type is a meaningful predictor of LoS and support the inclusion of CCSR Diagnosis Code in the modelling stage.

2.7 Code vs Description Mapping

Some features appear in pairs — a code (e.g. INF012 or 137) and its text description. If each code maps uniquely to one description and vice versa, the two columns carry identical information and one can be safely removed.

# Code/description column pairs to test for a one-to-one (bijective) mapping.
pairs <- list(
  c("CCSR.Diagnosis.Code",  "CCSR.Diagnosis.Description"),
  c("CCSR.Procedure.Code",  "CCSR.Procedure.Description"),
  c("APR.DRG.Code",         "APR.DRG.Description"),
  c("APR.MDC.Code",         "APR.MDC.Description")
)

for (pair in pairs) {
  a <- pair[1]
  b <- pair[2]

  # a_to_b: every value of `a` co-occurs with exactly one value of `b`
  # (b_to_a is the reverse). Rows whose grouping value is NA are dropped by
  # tapply(). na.rm = TRUE ignores the NA cells tapply() returns for factor
  # levels with no rows; such empty cells are no evidence against a
  # one-to-one mapping.
  a_to_b <- all(tapply(eda_df[[b]], eda_df[[a]], function(x) length(unique(x)) == 1), na.rm = TRUE)
  b_to_a <- all(tapply(eda_df[[a]], eda_df[[b]], function(x) length(unique(x)) == 1), na.rm = TRUE)

  # Both flags are single logicals, so use the scalar && operator.
  cat(sprintf("(%s, %s): %s\n", a, b, a_to_b && b_to_a))
}
## (CCSR.Diagnosis.Code, CCSR.Diagnosis.Description): TRUE
## (CCSR.Procedure.Code, CCSR.Procedure.Description): TRUE
## (APR.DRG.Code, APR.DRG.Description): TRUE
## (APR.MDC.Code, APR.MDC.Description): TRUE

Observation: All four pairs show a one-to-one mapping, confirming that the description columns are fully redundant. Only the code columns will be retained in the preprocessing stage.

2.8 Severity vs Mortality Redundancy Check

APR Severity of Illness and APR Risk of Mortality both describe patient condition severity. A cross-tabulation heatmap is used to examine whether the two features carry overlapping information.

# Define severity level order
sev_order <- c("Minor", "Moderate", "Major", "Extreme")

# Select and filter relevant columns. Map severity codes 1-4 to the same labels
# Risk of Mortality uses, so both axes share one vocabulary. Any other code
# (e.g. 0) becomes NA via case_when() and is then ignored by table().
heatmap_df <- eda_df %>%
  select(APR.Risk.of.Mortality, APR.Severity.of.Illness.Code) %>%
  filter(!is.na(APR.Risk.of.Mortality), !is.na(APR.Severity.of.Illness.Code)) %>%
  mutate(APR.Severity.of.Illness.Code = case_when(
    APR.Severity.of.Illness.Code == 1 ~ "Minor",
    APR.Severity.of.Illness.Code == 2 ~ "Moderate",
    APR.Severity.of.Illness.Code == 3 ~ "Major",
    APR.Severity.of.Illness.Code == 4 ~ "Extreme"
  ))

# Cross-tabulation: Risk of Mortality (rows) vs Severity of Illness (columns),
# reordered Minor -> Extreme on both axes.
ct <- table(heatmap_df$APR.Risk.of.Mortality, heatmap_df$APR.Severity.of.Illness.Code)
ct <- ct[sev_order, sev_order]

# Reshape for ggplot (one row per Risk x Severity cell)
melted <- as.data.frame(ct)
colnames(melted) <- c("Risk", "Severity", "Count")

# Plot heatmap
# NOTE: repr.plot.* only affects Jupyter/IRkernel; knitr uses chunk fig.width/fig.height.
options(repr.plot.width = 7, repr.plot.height = 5)

ggplot(melted, aes(x = Severity, y = Risk, fill = Count)) +
  geom_tile() +
  # trim = TRUE: by default format() left-pads every label to a common width,
  # which pushes the shorter counts off-centre inside their tiles.
  geom_text(aes(label = format(Count, big.mark = ",", trim = TRUE)), size = 3) +
  scale_fill_gradient(low = "white", high = "#d73027") +
  scale_x_discrete(limits = sev_order) +
  scale_y_discrete(limits = sev_order) +
  labs(title = "Risk of Mortality vs Severity of Illness") +
  theme_minimal()

Observation: The heatmap shows a strong diagonal pattern — patients with Minor Risk of Mortality are predominantly Minor severity (599,991), and Extreme Risk patients are mostly Extreme severity (191,901). The two features are highly correlated, confirming significant redundancy. Retaining both in the model would introduce multicollinearity. APR Severity of Illness Code will be retained as the primary severity indicator. The decision to drop APR Risk of Mortality will be finalised in the modelling stage.

2.9 Correlation Analysis

A correlation heatmap is generated using integer-encoded features to explore linear relationships with log-transformed Length of Stay. Note that this encoding is for exploration purposes only and will not be used in modelling.

# install.packages("reshape2")
# library(reshape2)
# Put APR.Risk.of.Mortality in clinical order (Minor < Moderate < Major <
# Extreme) so its integer codes below increase with risk. as.factor() on a
# character column would otherwise order the levels alphabetically. Values
# outside these four levels become NA.
eda_df <- eda_df %>%
  mutate(APR.Risk.of.Mortality = factor(
    APR.Risk.of.Mortality,
    levels = c("Minor", "Moderate", "Major", "Extreme")
  ))

# Features to correlate with LoS (exploration only — not used for modelling)
candidate_cols <- c(
  "Age.Group", "Gender", "Type.of.Admission",
  "APR.Severity.of.Illness.Code", "APR.Risk.of.Mortality",
  "APR.Medical.Surgical.Description", "Emergency.Department.Indicator",
  "CCSR.Diagnosis.Code", "CCSR.Procedure.Code"
)

# Keep only columns that exist
candidate_cols <- candidate_cols[candidate_cols %in% colnames(eda_df)]

# Add log1p(Length.of.Stay) as LoS_log, then replace every non-numeric column
# by its factor level index. For nominal features that index order is
# arbitrary, so their coefficients should be read with caution.
corr_df <- eda_df %>%
  select(all_of(candidate_cols), Length.of.Stay) %>%
  mutate(
    LoS_log = log1p(Length.of.Stay),
    across(where(~ !is.numeric(.)), ~ as.numeric(as.factor(.)))
  ) %>%
  select(-Length.of.Stay)

# Correlation matrix with pairwise deletion. CCSR.Procedure.Code has many
# missing values (583,187 in the preprocessed data). "complete.obs" would drop
# every one of those admissions from EVERY coefficient, not only the
# coefficients involving that column.
# NOTE: this can shift the coefficients quoted in the observations that
# follow — re-check them after re-knitting.
corr_matrix <- round(cor(corr_df, use = "pairwise.complete.obs"), 2)

# Plot heatmap
options(repr.plot.width = 9, repr.plot.height = 7)
melted <- melt(corr_matrix)
ggplot(melted, aes(Var1, Var2, fill = value)) +
  geom_tile() +
  geom_text(aes(label = value), size = 3) +
  scale_fill_gradient2(low = "blue", mid = "white", high = "red", midpoint = 0) +
  labs(title = "Correlation Heatmap (encoded features)") +
  theme_minimal() +
  theme(axis.text.x = element_text(angle = 45, hjust = 1))

# Sorted correlations with LoS_log
cat("\nCorrelation with LoS_log (sorted):\n")
## 
## Correlation with LoS_log (sorted):
# Drop LoS_log's self-correlation with a logical mask by name
corr_los <- sort(corr_matrix["LoS_log", colnames(corr_matrix) != "LoS_log"],
                 decreasing = TRUE)
print(data.frame(Feature = names(corr_los), Correlation = corr_los), row.names = FALSE)
##                           Feature Correlation
##      APR.Severity.of.Illness.Code        0.53
##             APR.Risk.of.Mortality        0.48
##                         Age.Group        0.25
##    Emergency.Department.Indicator        0.25
##                            Gender        0.09
##                 Type.of.Admission        0.02
##               CCSR.Procedure.Code       -0.05
##  APR.Medical.Surgical.Description       -0.06
##               CCSR.Diagnosis.Code       -0.17

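Observation: APR Severity of Illness Code shows the strongest positive correlation with LoS_log (0.53), followed closely by APR Risk of Mortality (0.48), then Age Group (0.25) and Emergency Department Indicator (0.25). Gender (0.09) and Type of Admission (0.02) show weak correlations, suggesting limited predictive value. Because nominal features are integer-encoded in arbitrary (alphabetical) order, their coefficients should be read with caution. APR Severity of Illness Code and APR Risk of Mortality are also strongly correlated with each other (reported as -0.60; since both features are now ordered Minor → Extreme and their cross-tabulation in Section 2.8 is strongly diagonal, the coefficient is expected to be positive — confirm the value in the heatmap above). This further confirms their redundancy — only one should be retained for modelling.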

EDA Summary

The following key findings from EDA will guide the preprocessing and modelling stages:

Target Variable:

  • Length of Stay is heavily right-skewed (skewness = 5.82). A log1p transformation will be applied to form the regression target.
  • For classification, Patient Disposition is used as the target variable. The classification strategy will be determined in the modelling stage.

Strong Predictors:

  • APR Severity of Illness Code (correlation = 0.53) is the strongest predictor of LoS; APR Risk of Mortality (0.48) is nearly as strong but largely redundant with it (see Redundant Features).
  • Age Group and Emergency Department Indicator both show moderate correlation (0.25).
  • Type of Admission and Severity of Illness show clear gradients across categories.

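Weak Predictors:

  • Gender (0.09) and Type of Admission (0.02) show weak linear correlation with LoS. Because these nominal features were integer-encoded in arbitrary order, a low correlation does not rule out category-level differences (such as the Type of Admission gradient noted above).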

Redundant Features:

  • Code and Description column pairs are one-to-one mapped — description columns will be dropped.
  • APR Severity of Illness and Risk of Mortality are highly correlated (a mutual correlation of magnitude ≈ 0.60, and a strongly diagonal cross-tabulation). APR Severity of Illness Code will be retained as the primary severity indicator. The decision to drop APR Risk of Mortality will be finalised in the modelling stage.

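Missingness:

  • Missing values are structural, not random. For APR Severity of Illness and APR Risk of Mortality, group-based mode imputation by MDC group will be applied, with a global-mode fallback for groups with no valid values. The hospital location fields will be imputed with their global mode.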

3. Preprocessing

3.1 Setup

A fresh copy of raw_df is created as preprocessed_df. All preprocessing steps are applied to this copy, keeping raw_df intact for reference. The preprocessing decisions below are guided by the findings from EDA.

# Create preprocessed_df from raw_df
# All cleaning and transformation will be applied to this copy.
# R's copy-on-modify semantics mean later changes to preprocessed_df never
# alter raw_df, so raw_df stays available as an untouched reference.
preprocessed_df <- raw_df

3.2 Feature Reduction

Based on the findings from EDA, the following columns are removed before modelling:

  • Redundant description columns: Code and Description pairs are one-to-one mapped — only the code columns are retained.
  • Excessive missing values: Birth Weight (~90% missing) is dropped.
  • Data leakage: Total Charges and Total Costs are only known after discharge and scale closely with Length of Stay. Including them would leak target information into the features, artificially inflating model performance.
  • Administrative fields: Facility Name, Operating Certificate Number, Permanent Facility ID, Zip Code, and Discharge Year carry no predictive value for LoS.
# Feature reduction: remove columns that duplicate a retained code, are mostly
# missing, leak the target, or are purely administrative. any_of() silently
# skips names that are not present.
preprocessed_df <- select(
  preprocessed_df,
  -any_of(c(
    # Description twins of the retained code columns
    "APR.MDC.Description",
    "APR.DRG.Description",
    "CCSR.Diagnosis.Description",
    "CCSR.Procedure.Description",
    "APR.Severity.of.Illness.Description",

    # ~90% missing
    "Birth.Weight",

    # Known only after discharge -> target leakage
    "Total.Charges",
    "Total.Costs",

    # Administrative identifiers with no predictive value
    "Facility.Name",
    "Operating.Certificate.Number",
    "Permanent.Facility.Id",
    "Zip.Code...3.digits",
    "Discharge.Year"
  ))
)

# Verify: list the columns that survived feature reduction and count them.
cat("Remaining columns:\n")
## Remaining columns:
print(colnames(preprocessed_df))
##  [1] "Hospital.Service.Area"            "Hospital.County"                 
##  [3] "Age.Group"                        "Gender"                          
##  [5] "Race"                             "Ethnicity"                       
##  [7] "Length.of.Stay"                   "Type.of.Admission"               
##  [9] "Patient.Disposition"              "CCSR.Diagnosis.Code"             
## [11] "CCSR.Procedure.Code"              "APR.DRG.Code"                    
## [13] "APR.MDC.Code"                     "APR.Severity.of.Illness.Code"    
## [15] "APR.Risk.of.Mortality"            "APR.Medical.Surgical.Description"
## [17] "Payment.Typology.1"               "Payment.Typology.2"              
## [19] "Payment.Typology.3"               "Emergency.Department.Indicator"
# Note: this message has no trailing newline.
cat("\nTotal columns remaining:", ncol(preprocessed_df))
## 
## Total columns remaining: 20

3.3 Handling Missing Values

Based on the missingness patterns identified in EDA, imputation is applied rather than removing rows to preserve the dataset size. Mode imputation is used for categorical variables.

3.3.1 Hospital Location Fields — Global Mode Imputation

Hospital.Service.Area and Hospital.County are imputed using the global mode, as their missingness is linked to facilities that did not report location data and shows no relationship with any clinical grouping.

# Global-mode imputation for the two hospital-location fields. Each mode is the
# most frequent non-missing value (table() drops NA by default). The imputed
# columns are stored as character.
preprocessed_df <- preprocessed_df %>%
  mutate(
    Hospital.Service.Area = replace(
      as.character(Hospital.Service.Area),
      is.na(Hospital.Service.Area),
      names(which.max(table(Hospital.Service.Area)))
    ),
    Hospital.County = replace(
      as.character(Hospital.County),
      is.na(Hospital.County),
      names(which.max(table(Hospital.County)))
    )
  )

# Verify
cat(sprintf("Hospital.Service.Area NA: %d \n", sum(is.na(preprocessed_df$Hospital.Service.Area))))
## Hospital.Service.Area NA: 0
cat(sprintf("Hospital.County NA: %d \n", sum(is.na(preprocessed_df$Hospital.County))))
## Hospital.County NA: 0

3.3.2 APR Severity of Illness Code — MDC Group Mode Imputation

Missing values are first identified by MDC group to determine the appropriate imputation strategy.

# Per MDC group: how many Severity codes are missing (code 0 is treated as
# missing) and what the group's mode would be. Only groups with at least one
# missing value are shown.
eda_df %>%
  mutate(APR.Severity.of.Illness.Code = na_if(
    as.numeric(as.character(APR.Severity.of.Illness.Code)), 0
  )) %>%
  group_by(APR.MDC.Code) %>%
  summarise(
    total       = n(),
    na_count    = sum(is.na(APR.Severity.of.Illness.Code)),
    valid_count = sum(!is.na(APR.Severity.of.Illness.Code)),
    # valid_count is one value per group, so a scalar if/else is the right
    # tool here (ifelse() is meant for vectors).
    group_mode  = if (valid_count > 0) {
      names(which.max(table(APR.Severity.of.Illness.Code, useNA = "no")))
    } else {
      "NO VALID VALUES — fallback needed"
    }
  ) %>%
  filter(na_count > 0)

Missing values are found in three MDC groups only:

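| MDC | Description | NA Count | Valid Count | Strategy |
|-----|-------------|---------:|------------:|----------|
| 14 | Pregnancy, Childbirth and the Puerperium | 467 | 223,791 | Group mode (Minor) |
| 15 | Newborns and Neonates | 6 | 206,541 | Group mode (Minor) |
| 0 | Pre MDC | 116 | 0 | Global mode fallback |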

MDC 14 and 15 contain sufficient valid values for group-mode imputation. MDC 0 (Pre MDC) has no valid values, so its missing values are filled with the global mode instead; all other missing values are imputed with their MDC group mode. The same pattern applies to APR Risk of Mortality.

# Step 1: Severity code 0 is treated as missing — convert it to NA
preprocessed_df <- preprocessed_df %>%
  mutate(APR.Severity.of.Illness.Code = na_if(
    as.numeric(as.character(APR.Severity.of.Illness.Code)), 0
  ))

# Step 2: global mode of the observed codes. This is the fallback for any MDC
# group with no valid value at all (MDC 0 / Pre MDC in this data).
global_mode_sev <- as.numeric(names(which.max(
  table(preprocessed_df$APR.Severity.of.Illness.Code, useNA = "no")
)))

# Step 3: fill each remaining NA with its MDC group's mode, or with the global
# mode when the whole group is missing. The fallback is decided per group
# rather than hard-coded to MDC 0, so an all-missing group can never leave NAs
# behind.
preprocessed_df <- preprocessed_df %>%
  group_by(APR.MDC.Code) %>%
  mutate(APR.Severity.of.Illness.Code = coalesce(
    APR.Severity.of.Illness.Code,
    if (all(is.na(APR.Severity.of.Illness.Code))) {
      global_mode_sev
    } else {
      as.numeric(names(which.max(table(APR.Severity.of.Illness.Code, useNA = "no"))))
    }
  )) %>%
  ungroup()

# Verify
cat("Remaining NA:", sum(is.na(preprocessed_df$APR.Severity.of.Illness.Code)), "\n")
## Remaining NA: 0

3.3.3 APR Risk of Mortality — MDC Group Mode Imputation

The same missingness pattern is confirmed for APR Risk of Mortality. Missing values are verified by MDC group before imputation.

# Per MDC group: how many Risk of Mortality values are missing (the string "0"
# is also treated as missing) and what the group's mode would be. Only groups
# with at least one missing value are shown.
eda_df %>%
  mutate(APR.Risk.of.Mortality = na_if(
    as.character(APR.Risk.of.Mortality), "0"
  )) %>%
  group_by(APR.MDC.Code) %>%
  summarise(
    total       = n(),
    na_count    = sum(is.na(APR.Risk.of.Mortality)),
    valid_count = sum(!is.na(APR.Risk.of.Mortality)),
    # valid_count is one value per group, so use a scalar if/else rather than
    # the vectorised ifelse().
    group_mode  = if (valid_count > 0) {
      names(which.max(table(APR.Risk.of.Mortality, useNA = "no")))
    } else {
      "NO VALID VALUES — fallback needed"
    }
  ) %>%
  filter(na_count > 0)

The same three MDC groups are affected:

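| MDC | Description | NA Count | Valid Count | Strategy |
|-----|-------------|---------:|------------:|----------|
| 14 | Pregnancy, Childbirth and the Puerperium | 467 | 223,791 | Group mode (Minor) |
| 15 | Newborns and Neonates | 6 | 206,541 | Group mode (Minor) |
| 0 | Pre MDC | 116 | 0 | Global mode fallback |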
# Step 1: global mode of the observed Risk of Mortality values. This is the
# fallback for any MDC group with no valid value (MDC 0 / Pre MDC in this data).
global_mode_rom <- names(which.max(
  table(preprocessed_df$APR.Risk.of.Mortality, useNA = "no")
))

# Step 2: fill each NA with its MDC group's mode, or with the global mode when
# the whole group is missing (decided per group, not hard-coded to MDC 0).
# The column is handled as character so that coalesce() combines like types;
# all(is.na()) is used because table() on a factor would also report
# zero-count levels.
preprocessed_df <- preprocessed_df %>%
  mutate(APR.Risk.of.Mortality = as.character(APR.Risk.of.Mortality)) %>%
  group_by(APR.MDC.Code) %>%
  mutate(APR.Risk.of.Mortality = coalesce(
    APR.Risk.of.Mortality,
    if (all(is.na(APR.Risk.of.Mortality))) {
      global_mode_rom
    } else {
      names(which.max(table(APR.Risk.of.Mortality, useNA = "no")))
    }
  )) %>%
  ungroup()
#Verify
cat("Risk of Mortality Remaining NA:", sum(is.na(preprocessed_df$APR.Risk.of.Mortality)), "\n")
## Risk of Mortality Remaining NA: 0

3.3.4 Missing Value Imputation Summary

All missing values in the four imputed columns have been successfully filled; the final verification below confirms that none remain. CCSR.Procedure.Code is not imputed here, so its missing values remain (see Section 3.6).

# Summary table: one row per imputed column with the number of NAs still
# present. All counts should be 0 after Sections 3.3.1-3.3.3.
data.frame(
  Column = c("Hospital.Service.Area", "Hospital.County",
             "APR.Severity.of.Illness.Code", "APR.Risk.of.Mortality")
) |>
  transform(Remaining_NA = vapply(
    Column,
    function(col_name) sum(is.na(preprocessed_df[[col_name]])),
    integer(1),
    USE.NAMES = FALSE
  ))

3.4 Data Type Transformation

Following feature reduction, column type definitions are updated to reflect only the remaining columns. Categorical variables are converted to factors and numeric variables are correctly typed to ensure compatibility with the modelling stage.

# Update column lists to only include remaining columns.
# NOTE(review): categorical_cols / numeric_cols are defined earlier in the
# notebook (not visible here). This drops names removed by feature reduction
# (3.2). numeric_cols is not used again in this chunk.
categorical_cols <- categorical_cols[categorical_cols %in% colnames(preprocessed_df)]
numeric_cols <- numeric_cols[numeric_cols %in% colnames(preprocessed_df)]

# Convert categorical columns to factor
preprocessed_df[categorical_cols] <- lapply(preprocessed_df[categorical_cols], as.factor)

# Convert numeric columns: remove every literal "+" before converting to integer.
# NOTE(review): presumably handles a top-coded value such as "120+" — confirm.
# Any value that still is not a number would become NA with a warning.
preprocessed_df <- preprocessed_df %>%
  mutate(
    Length.of.Stay = as.integer(gsub("\\+", "", Length.of.Stay))
  )

# Verify
glimpse(preprocessed_df)
## Rows: 2,135,260
## Columns: 20
## $ Hospital.Service.Area            <fct> New York City, New York City, New Yor…
## $ Hospital.County                  <fct> Bronx, Bronx, Bronx, Bronx, Bronx, Br…
## $ Age.Group                        <fct> 70 or Older, 50 to 69, 18 to 29, 70 o…
## $ Gender                           <fct> M, F, F, M, F, M, F, M, M, F, F, M, M…
## $ Race                             <fct> Other Race, White, Other Race, Other …
## $ Ethnicity                        <fct> Spanish/Hispanic, Not Span/Hispanic, …
## $ Length.of.Stay                   <int> 27, 4, 2, 5, 3, 6, 1, 3, 21, 2, 3, 3,…
## $ Type.of.Admission                <fct> Emergency, Emergency, Emergency, Emer…
## $ Patient.Disposition              <fct> Home w/ Home Health Services, Home or…
## $ CCSR.Diagnosis.Code              <fct> INF012, NVS005, PRG016, GEN004, NVS00…
## $ CCSR.Procedure.Code              <fct> OTR004, NA, PGN003, ADM017, CNS002, I…
## $ APR.DRG.Code                     <fct> 137, 43, 540, 463, 58, 813, 55, 640, …
## $ APR.MDC.Code                     <fct> 4, 1, 14, 11, 1, 21, 1, 15, 18, 6, 4,…
## $ APR.Severity.of.Illness.Code     <fct> 3, 2, 1, 3, 2, 3, 1, 2, 2, 3, 3, 3, 1…
## $ APR.Risk.of.Mortality            <fct> Extreme, Minor, Minor, Major, Minor, …
## $ APR.Medical.Surgical.Description <fct> Medical, Medical, Surgical, Medical, …
## $ Payment.Typology.1               <fct> "Medicare", "Private Health Insurance…
## $ Payment.Typology.2               <fct> "Medicaid", NA, NA, "Medicaid", "Medi…
## $ Payment.Typology.3               <fct> NA, NA, NA, NA, NA, NA, NA, NA, NA, N…
## $ Emergency.Department.Indicator   <fct> Y, Y, N, Y, Y, Y, Y, N, Y, Y, Y, Y, Y…
# In-memory size of the converted data frame
cat("Memory usage:", format(object.size(preprocessed_df), units = "MB"), "\n")
## Memory usage: 163 Mb

3.5 Feature Engineering: Payment Typology

The three payment columns are combined into a single feature, coverage_count: the number of payer slots (primary, secondary, tertiary) that are filled with something other than "Self-Pay".

# Create coverage_count: number of payer slots (primary, secondary, tertiary)
# that are present and not "Self-Pay" (range 0-3). A missing payer counts as
# no coverage in every slot. Without the !is.na() guard an NA primary payer
# would make coverage_count NA.
preprocessed_df <- preprocessed_df %>%
  mutate(
    Has_Coverage_in_Primary   = ifelse(!is.na(Payment.Typology.1) & Payment.Typology.1 != "Self-Pay", 1, 0),
    Has_Coverage_in_Secondary = ifelse(!is.na(Payment.Typology.2) & Payment.Typology.2 != "Self-Pay", 1, 0),
    Has_Coverage_in_Tertiary  = ifelse(!is.na(Payment.Typology.3) & Payment.Typology.3 != "Self-Pay", 1, 0),
    coverage_count = Has_Coverage_in_Primary + Has_Coverage_in_Secondary + Has_Coverage_in_Tertiary
  ) %>%
  # Drop the helper flags and the original payment columns
  select(-c(Has_Coverage_in_Primary, Has_Coverage_in_Secondary, Has_Coverage_in_Tertiary)) %>%
  select(-any_of(c("Payment.Typology.1", "Payment.Typology.2", "Payment.Typology.3")))

# Verify coverage_count distribution (useNA = "ifany" would reveal any NA)
result <- table(preprocessed_df$coverage_count, useNA = "ifany")
data.frame(
  Coverage_Count = names(result),
  Count = as.numeric(result)
)

3.6 Target Variable Separation

The two target variables are separated from the feature dataframe to prevent data leakage during modelling. A log1p transformation is applied to Length of Stay as the regression target.

# Split the two targets into their own data frame. LoS_log = log(1 + LoS) is
# the regression target.
target_df <- transmute(
  preprocessed_df,
  Length.of.Stay,
  Patient.Disposition,
  LoS_log = log1p(Length.of.Stay)
)

# Drop both targets from the feature data frame so they cannot leak into the
# predictors
preprocessed_df <- select(preprocessed_df, -Length.of.Stay, -Patient.Disposition)

# Verify dimensions
cat("Target dataframe dimensions:", dim(target_df), "\n")
## Target dataframe dimensions: 2135260 3
cat("Feature dataframe dimensions:", dim(preprocessed_df), "\n")
## Feature dataframe dimensions: 2135260 16
head(target_df)
# Feature columns that still contain NAs, with their counts
Filter(function(na_n) na_n > 0, colSums(is.na(preprocessed_df)))
## CCSR.Procedure.Code 
##              583187

3.7 Preprocessing Summary

The preprocessing pipeline is now complete. The final dimensions are verified below.

# Overall verification: print the shape of the feature and target data frames
# and the names of their columns.
cat("=== Preprocessing Complete ===\n\n")
## === Preprocessing Complete ===
cat("Feature dataframe:\n")
## Feature dataframe:
cat("  Rows:", nrow(preprocessed_df), "\n")
##   Rows: 2135260
cat("  Cols:", ncol(preprocessed_df), "\n\n")
##   Cols: 16
cat("Target dataframe:\n")
## Target dataframe:
cat("  Rows:", nrow(target_df), "\n")
##   Rows: 2135260
cat("  Cols:", ncol(target_df), "\n\n")
##   Cols: 3
cat("Remaining features:\n")
## Remaining features:
print(colnames(preprocessed_df))
##  [1] "Hospital.Service.Area"            "Hospital.County"                 
##  [3] "Age.Group"                        "Gender"                          
##  [5] "Race"                             "Ethnicity"                       
##  [7] "Type.of.Admission"                "CCSR.Diagnosis.Code"             
##  [9] "CCSR.Procedure.Code"              "APR.DRG.Code"                    
## [11] "APR.MDC.Code"                     "APR.Severity.of.Illness.Code"    
## [13] "APR.Risk.of.Mortality"            "APR.Medical.Surgical.Description"
## [15] "Emergency.Department.Indicator"   "coverage_count"
cat("\nTarget variables:\n")
## 
## Target variables:
print(colnames(target_df))
## [1] "Length.of.Stay"      "Patient.Disposition" "LoS_log"

The dataset has been reduced from 33 columns to a clean, model-ready format:

  • Feature dataframe (preprocessed_df): 2,135,260 rows × 16 columns
  • Target dataframe (target_df): 2,135,260 rows × 3 columns (Length.of.Stay, Patient.Disposition, LoS_log)

Note: A consolidated Disposition class column will be added to target_df once the classification strategy is finalised in the modelling stage.

The preprocessed dataset is ready for the modelling stage.

4. Classification — Predicting Patient Disposition

This section predicts a patient’s discharge disposition class using the cleaned features created in Sections 1–3. The original Patient.Disposition variable has 19 discharge categories, so it is consolidated into four clinically meaningful classes:

  • Home
  • Facility
  • Hospice_Expired
  • Other

The classification workflow is designed for an imbalanced multi-class target. Therefore, the section reports not only accuracy, but also balanced accuracy, macro-F1, macro precision, macro recall, Cohen’s Kappa, per-class precision/recall/F1, confusion matrices, and feature importance.

Three Random Forest variants are compared:

  1. Baseline weighted Random Forest
  2. Balanced-accuracy tuned weighted Random Forest
  3. Downsampled Random Forest for minority-class recall improvement

The final model is selected programmatically using balanced accuracy first, then macro-F1 as the tie-breaker.

4.1 Setup & Target Consolidation

This chunk loads the required packages, checks that the preprocessing objects exist, attaches the classification target, and consolidates the 19 raw discharge labels into four explicit classes. The guard stops execution if an unmapped label is found, so no category is silently pushed into the wrong class.

# Packages used by Section 4; any that are missing are installed once.
required_pkgs <- c(
  "tidyverse", "ranger", "rsample", "yardstick", "tune", "dials",
  "parsnip", "workflows", "recipes", "doParallel", "ggplot2"
)

# requireNamespace() checks each package directly and is much cheaper than
# scanning installed.packages(), which R's documentation warns can be slow.
is_installed <- vapply(required_pkgs, requireNamespace, logical(1), quietly = TRUE)
missing_pkgs <- required_pkgs[!is_installed]

if (length(missing_pkgs) > 0) {
  # A non-interactive knit may have no CRAN mirror configured ("@CRAN@"),
  # which makes install.packages() fail, so fall back to the cloud mirror.
  cran_repos <- getOption("repos")
  if (is.null(cran_repos) || identical(unname(cran_repos["CRAN"]), "@CRAN@")) {
    cran_repos <- c(CRAN = "https://cloud.r-project.org")
  }
  install.packages(missing_pkgs, repos = cran_repos, quiet = TRUE)
}

invisible(lapply(required_pkgs, library, character.only = TRUE))

# Fixed seed so the stratified sample and train/test split below are reproducible
set.seed(42)

# NOTE(review): resets a legacy yardstick option; the metric helpers pass
# event_level explicitly instead.
options(yardstick.event_first = NULL)

# Leave one core free for the OS. detectCores() can return NA on some
# platforms; na.rm = TRUE guarantees at least one thread.
n_cores <- max(1L, parallel::detectCores(logical = TRUE) - 1L, na.rm = TRUE)

# Guard: Sections 1-3 must be run before Section 4.
if (!exists("preprocessed_df")) {
  stop("preprocessed_df not found. Please run Sections 1-3 before running Section 4.")
}

if (!exists("target_df")) {
  stop("target_df not found. Please run Sections 1-3 before running Section 4.")
}

# `!` binds more loosely than %in%, so this reads !( ... %in% ... ).
if (!"Patient.Disposition" %in% names(target_df)) {
  stop("Patient.Disposition is missing from target_df.")
}

# 19-to-4 mapping of raw Patient.Disposition labels. The strings must match the
# dataset's labels exactly (including its own abbreviations), so do not edit
# their spelling.
home_labels <- c(
  "Home or Self Care",
  "Home w/ Home Health Services"
)

facility_labels <- c(
  "Skilled Nursing Home",
  "Medicaid Cert Nursing Facility",
  "Inpatient Rehabilitation Facility",
  "Short-term Hospital",
  "Federal Health Care Facility",
  "Psychiatric Hospital or Unit of Hosp",
  "Cancer Center or Children's Hospital",
  "Facility w/ Custodial/Supportive Care",
  "Hosp Basd Medicare Approved Swing Bed",
  "Medicare Cert Long Term Care Hospital",
  "Critical Access Hospital"
)

hospice_expired_labels <- c(
  "Hospice - Medical Facility",
  "Hospice - Home",
  "Expired"
)

other_labels <- c(
  "Left Against Medical Advice",
  "Court/Law Enforcement",
  "Another Type Not Listed"
)

mapped_labels <- c(home_labels, facility_labels, hospice_expired_labels, other_labels)

# Distinct non-missing labels actually present in the data, sorted
observed_labels <- target_df %>%
  transmute(Patient.Disposition = as.character(Patient.Disposition)) %>%
  filter(!is.na(Patient.Disposition)) %>%
  distinct(Patient.Disposition) %>%
  arrange(Patient.Disposition) %>%
  pull(Patient.Disposition)

# Fail fast if the data holds a label the mapping does not cover
unmapped_labels <- setdiff(observed_labels, mapped_labels)

if (length(unmapped_labels) > 0) {
  stop(
    "Unmapped Patient.Disposition label(s) found. Add these to the 19-to-4 mapping before modelling: ",
    paste(unmapped_labels, collapse = ", ")
  )
}

# Class order used for every factor, table and metric in Section 4
disposition_levels <- c("Home", "Facility", "Hospice_Expired", "Other")

# bind_cols() pairs rows by position. This is valid because target_df and
# preprocessed_df were split from the same data frame in Section 3.6 without
# reordering.
model_df <- bind_cols(
  preprocessed_df,
  target_df %>% select(Patient.Disposition)
) %>%
  mutate(
    Disposition_Class = case_when(
      Patient.Disposition %in% home_labels ~ "Home",
      Patient.Disposition %in% facility_labels ~ "Facility",
      Patient.Disposition %in% hospice_expired_labels ~ "Hospice_Expired",
      Patient.Disposition %in% other_labels ~ "Other"
    ),
    Disposition_Class = factor(Disposition_Class, levels = disposition_levels)
  ) %>%
  select(-Patient.Disposition)

# Rows with a missing Patient.Disposition are not covered by the unmapped check
# above (NAs were filtered out); they surface here as NA classes.
if (any(is.na(model_df$Disposition_Class))) {
  stop("NA values found in Disposition_Class after consolidation. Check the mapping.")
}

# Report the size of the model-ready data and the 4-class target distribution
cat("Original Patient.Disposition categories:", length(observed_labels), "\n")
## Original Patient.Disposition categories: 19
cat("Model-ready rows:", nrow(model_df), "| columns:", ncol(model_df), "\n\n")
## Model-ready rows: 2135260 | columns: 17
# Count and percentage (2 decimals) of rows per consolidated class
class_distribution <- model_df %>%
  count(Disposition_Class, name = "n") %>%
  mutate(percent = round(100 * n / sum(n), 2))

cat("Consolidated 4-class distribution:\n")
## Consolidated 4-class distribution:
print(class_distribution)
## # A tibble: 4 × 3
##   Disposition_Class       n percent
##   <fct>               <int>   <dbl>
## 1 Home              1701462   79.7 
## 2 Facility           278073   13.0 
## 3 Hospice_Expired     86782    4.06
## 4 Other               68943    3.23

4.2 Stratified Sampling, Train/Test Split & Class Weights

The full dataset is large, so a stratified sample is used to keep Random Forest training practical while preserving the target distribution. The split is also stratified. Inverse-frequency class weights are created from the training set for the weighted Random Forest models.

# Fraction of rows kept per class for the stratified sample
sample_prop <- 0.10

# Sampling within each Disposition_Class preserves the class distribution.
# Per-class counts are rounded down (e.g. 170,146 of 1,701,462 Home rows).
sample_df <- model_df %>%
  group_by(Disposition_Class) %>%
  slice_sample(prop = sample_prop) %>%
  ungroup()

cat("Stratified sample size:", nrow(sample_df), "\n")
## Stratified sample size: 213525
cat("Sample class distribution:\n")
## Sample class distribution:
print(
  sample_df %>%
    count(Disposition_Class, name = "n") %>%
    mutate(percent = round(100 * n / sum(n), 2))
)
## # A tibble: 4 × 3
##   Disposition_Class      n percent
##   <fct>              <int>   <dbl>
## 1 Home              170146   79.7 
## 2 Facility           27807   13.0 
## 3 Hospice_Expired     8678    4.06
## 4 Other               6894    3.23
# 80/20 train/test split, stratified on the target.
# NOTE(review): the object name `split` masks base::split() for lookups of
# non-function values only; calls to split() still resolve to the function.
split <- initial_split(sample_df, prop = 0.80, strata = Disposition_Class)
train_df <- training(split)
test_df  <- testing(split)

cat("\nTraining rows:", nrow(train_df), "\n")
## 
## Training rows: 170819
cat("Testing rows :", nrow(test_df), "\n")
## Testing rows : 42706
# Inverse-frequency weights: w_k = N / (K * n_k), where N = training rows,
# K = number of classes and n_k = rows in class k. Weights are in factor-level
# order (from table()) and carry the class names.
class_counts <- table(train_df$Disposition_Class)
class_wts <- as.numeric(sum(class_counts) / (length(class_counts) * class_counts))
names(class_wts) <- names(class_counts)

cat("\nInverse-frequency class weights used for weighted RF:\n")
## 
## Inverse-frequency class weights used for weighted RF:
print(round(class_wts, 3))
##            Home        Facility Hospice_Expired           Other 
##           0.314           1.920           6.160           7.790

4.3 Evaluation Helper Functions

Accuracy alone is misleading for this problem because the Home class dominates the data. These helper functions produce:

  • overall metrics,
  • per-class metrics,
  • a clean classification report,
  • model-ranking tables,
  • confusion matrices,
  • and feature-importance tables.
# Build the (truth, estimate) tibble that yardstick expects. Both columns are
# coerced to the shared disposition_levels, so every class keeps a level even
# when it is absent from one side.
make_results <- function(truth, estimate) {
  to_disposition <- function(values) factor(values, levels = disposition_levels)
  tibble(
    truth    = to_disposition(truth),
    estimate = to_disposition(estimate)
  )
}

# Element-wise division that yields NA (instead of Inf/NaN) wherever the
# denominator is zero. An NA denominator also yields NA.
safe_divide <- function(numerator, denominator) {
  nonzero <- denominator != 0
  ifelse(nonzero, numerator / denominator, NA_real_)
}

# Overall (whole-test-set) metrics for one model.
#
# results:    tibble with factor columns `truth` and `estimate` (see make_results()).
# model_name: label written into the leading `model` column.
# Returns one row per metric (accuracy, balanced accuracy, macro F1, macro
# precision, macro recall, Cohen's kappa) in yardstick's
# .metric/.estimator/.estimate layout.
# NOTE(review): per the yardstick docs, event_level only affects binary
# outcomes; with four classes it should have no effect — confirm.
overall_metrics <- function(results, model_name) {
  bind_rows(
    accuracy(results,     truth, estimate),
    bal_accuracy(results, truth, estimate, event_level = "first"),
    f_meas(results,       truth, estimate, estimator = "macro", event_level = "first"),
    precision(results,    truth, estimate, estimator = "macro", event_level = "first"),
    recall(results,       truth, estimate, estimator = "macro", event_level = "first"),
    kap(results,          truth, estimate)
  ) %>%
    mutate(
      model = model_name,
      .before = .metric
    )
}

# One-vs-rest metrics for each class, computed from the confusion matrix.
#
# results:    tibble with `truth` and `estimate` columns.
# model_name: label written into the `model` column.
# Returns one row per class in disposition_levels order, with support (true
# count), predicted count, and precision / recall / F1 / specificity rounded to
# 4 decimals. Any metric whose denominator is zero is NA (via safe_divide()).
per_class_metrics <- function(results, model_name) {

  # Rows = truth, columns = estimate; fixed levels keep a 4x4 matrix even when
  # a class is never observed or never predicted.
  cm <- table(
    truth    = factor(results$truth,    levels = disposition_levels),
    estimate = factor(results$estimate, levels = disposition_levels)
  )

  cm <- as.matrix(cm)

  # One-vs-rest counts per class
  tp <- diag(cm)
  fp <- colSums(cm) - tp   # predicted as the class but truly another class
  fn <- rowSums(cm) - tp   # truly the class but predicted as another class
  tn <- sum(cm) - tp - fp - fn

  class_precision <- safe_divide(tp, tp + fp)
  class_recall    <- safe_divide(tp, tp + fn)
  class_f1        <- safe_divide(2 * class_precision * class_recall, class_precision + class_recall)
  class_specificity <- safe_divide(tn, tn + fp)

  tibble(
    model = model_name,
    class = disposition_levels,
    support = as.integer(rowSums(cm)),
    predicted = as.integer(colSums(cm)),
    precision = round(class_precision, 4),
    recall = round(class_recall, 4),
    f1 = round(class_f1, 4),
    specificity = round(class_specificity, 4)
  )
}

# Top-`n` entries of a named importance vector as a tibble with columns
# (model, feature, importance), sorted from most to least important.
feature_importance_tbl <- function(importance_vector, model_name, n = 15) {
  ranked <- tibble(
    model = model_name,
    feature = names(importance_vector),
    importance = as.numeric(importance_vector)
  ) %>%
    arrange(desc(importance))

  slice_head(ranked, n = n)
}

# Print a plain-text report for one model: a title banner, overall metrics,
# per-class metrics (without the model column) and the yardstick confusion
# matrix. Returns the confusion matrix invisibly (the value of the last print).
print_model_report <- function(results, model_name) {
  # Print a heading, then the value; `value` is only evaluated here, so any
  # messages it emits appear after its heading.
  show_section <- function(heading, value) {
    cat(heading)
    print(value)
  }

  cat("\n============================================================\n")
  cat(model_name, "\n")
  cat("============================================================\n\n")

  show_section("Overall metrics:\n",
               select(overall_metrics(results, model_name), -model))
  show_section("\nPer-class metrics:\n",
               select(per_class_metrics(results, model_name), -model))
  show_section("\nConfusion matrix:\n",
               conf_mat(results, truth, estimate))
}

4.4 Baseline Weighted Random Forest

The baseline model uses Random Forest with inverse-frequency class weights. This is the clean benchmark: no tuning, but imbalance is handled through class weights.

# Wall-clock time of the ranger() fit only.
start_time <- Sys.time()

# Classification forest (ranger's default probability = FALSE) with the
# inverse-frequency class weights; mtry and min.node.size are left at ranger's
# defaults because this is the untuned benchmark.
rf_baseline <- ranger(
  formula       = Disposition_Class ~ .,
  data          = train_df,
  num.trees     = 500,
  importance    = "impurity",
  class.weights = class_wts,
  num.threads   = n_cores,
  seed          = 42
)

baseline_time <- difftime(Sys.time(), start_time, units = "secs")

# Predicted classes for the held-out test set, paired with the true labels.
baseline_preds <- predict(rf_baseline, data = test_df)$predictions
results_baseline <- make_results(test_df$Disposition_Class, baseline_preds)

# Metric tables and top-15 impurity importances, kept for the comparison in 4.8.
baseline_metrics <- overall_metrics(results_baseline, "Baseline weighted RF")
baseline_class_metrics <- per_class_metrics(results_baseline, "Baseline weighted RF")
baseline_importance <- feature_importance_tbl(
  rf_baseline$variable.importance,
  "Baseline weighted RF"
)

cat("Baseline training time:", round(baseline_time, 2), "seconds\n")
## Baseline training time: 24.85 seconds
# Full test-set report (overall, per-class, confusion matrix) for the baseline.
print_model_report(results_baseline, "Baseline weighted RF")
## 
## ============================================================
## Baseline weighted RF 
## ============================================================
## 
## Overall metrics:
## # A tibble: 6 × 3
##   .metric      .estimator .estimate
##   <chr>        <chr>          <dbl>
## 1 accuracy     multiclass     0.791
## 2 bal_accuracy macro          0.673
## 3 f_meas       macro          0.507
## 4 precision    macro          0.543
## 5 recall       macro          0.503
## 6 kap          multiclass     0.365
## 
## Per-class metrics:
## # A tibble: 4 × 7
##   class           support predicted precision recall    f1 specificity
##   <chr>             <int>     <int>     <dbl>  <dbl> <dbl>       <dbl>
## 1 Home              33979     35150     0.872  0.902 0.887       0.486
## 2 Facility           5570      4499     0.409  0.331 0.366       0.928
## 3 Hospice_Expired    1745      2396     0.401  0.550 0.464       0.965
## 4 Other              1412       661     0.489  0.229 0.312       0.992
## 
## Confusion matrix:
##                  Truth
## Prediction         Home Facility Hospice_Expired Other
##   Home            30664     3077             434   975
##   Facility         2241     1842             339    77
##   Hospice_Expired   776      623             960    37
##   Other             298       28              12   323
# Top-15 impurity importances of the baseline forest.
cat("\nTop 15 baseline feature importances:\n")
## 
## Top 15 baseline feature importances:
print(baseline_importance)
## # A tibble: 15 × 3
##    model                feature                        importance
##    <chr>                <chr>                               <dbl>
##  1 Baseline weighted RF APR.DRG.Code                       16852.
##  2 Baseline weighted RF CCSR.Procedure.Code                12617.
##  3 Baseline weighted RF Hospital.County                    11598.
##  4 Baseline weighted RF CCSR.Diagnosis.Code                11570.
##  5 Baseline weighted RF APR.Risk.of.Mortality               6706.
##  6 Baseline weighted RF APR.Severity.of.Illness.Code        6544.
##  7 Baseline weighted RF Hospital.Service.Area               6366.
##  8 Baseline weighted RF Age.Group                           5895.
##  9 Baseline weighted RF APR.MDC.Code                        5504.
## 10 Baseline weighted RF coverage_count                      4476.
## 11 Baseline weighted RF Race                                4281.
## 12 Baseline weighted RF Ethnicity                           3647.
## 13 Baseline weighted RF Gender                              3058.
## 14 Baseline weighted RF Type.of.Admission                   2011.
## 15 Baseline weighted RF Emergency.Department.Indicator      1714.

4.5 Hyperparameter Tuning Using Balanced Accuracy

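This cell tunes mtry and min_n using 5-fold cross-validation on a stratified 30% sample of the training data. The key correction is that the best configuration is selected by balanced accuracy rather than by plain accuracy or macro-F1. Balanced accuracy suits this imbalanced problem because each class contributes equally regardless of its frequency. Note that for multiclass outcomes yardstick computes it as the macro-average over classes of (sensitivity + specificity) / 2, so it always reads somewhat higher than macro recall.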

# Tunable weighted RF: mtry and min_n are tuned; trees, seed, importance and
# class weights match the baseline ranger() call.
#
# parsnip fits ranger classification models as probability forests unless
# `probability` is given explicitly. ranger documents that in a classification
# forest class.weights are applied to the splitting rule AND to the
# terminal-node majority vote; a probability forest uses them for splitting
# only. probability = FALSE keeps the tuned model weighted the same way as the
# baseline, so the model comparison in 4.8 is like-for-like.
rf_tune_spec <- rand_forest(
  trees = 500,
  mtry  = tune(),
  min_n = tune()
) %>%
  set_engine(
    "ranger",
    num.threads   = n_cores,
    seed          = 42,
    importance    = "impurity",
    class.weights = class_wts,
    probability   = FALSE
  ) %>%
  set_mode("classification")

# No preprocessing steps: the data reaches ranger as-is, as in the baseline.
rf_recipe <- recipe(Disposition_Class ~ ., data = train_df)

rf_wf <- workflow() %>%
  add_recipe(rf_recipe) %>%
  add_model(rf_tune_spec)

# Tuning on a stratified 30% training sample keeps runtime reasonable.
tune_prop <- 0.30

# Seed so the tuning subsample and the CV folds are reproducible across knits.
set.seed(42)

tune_df <- train_df %>%
  group_by(Disposition_Class) %>%
  slice_sample(prop = tune_prop) %>%
  ungroup()

cv_folds <- vfold_cv(tune_df, v = 5, strata = Disposition_Class)

# Number of predictors (every column except the outcome).
p <- ncol(train_df) - 1

# 4 x 4 grid: mtry from floor(sqrt(p)) to floor(p / 2), min_n from 2 to 20.
rf_grid <- grid_regular(
  mtry(range = c(max(1, floor(sqrt(p))), max(2, floor(p / 2)))),
  min_n(range = c(2, 20)),
  levels = c(4, 4)
)

# Report the tuning-sample size, predictor count and candidate grid.
cat("Tuning rows:", nrow(tune_df), "\n")
## Tuning rows: 51244
cat("Predictor count:", p, "\n")
## Predictor count: 16
cat("Tuning grid:\n")
## Tuning grid:
print(rf_grid)
## # A tibble: 16 × 2
##     mtry min_n
##    <int> <int>
##  1     4     2
##  2     5     2
##  3     6     2
##  4     8     2
##  5     4     8
##  6     5     8
##  7     6     8
##  8     8     8
##  9     4    14
## 10     5    14
## 11     6    14
## 12     8    14
## 13     4    20
## 14     5    20
## 15     6    20
## 16     8    20
# yardstick.event_first is a deprecated global option; setting it to NULL just
# clears it so each metric's event_level argument is used as written.
options(yardstick.event_first = NULL)

# Metrics computed on every resample; bal_accuracy drives the selection below.
rf_metrics <- metric_set(
  accuracy,
  bal_accuracy,
  f_meas,
  precision,
  recall
)

# NOTE(review): every ranger fit already uses n_cores threads. If tune_grid()
# also distributes resamples over this foreach backend, the total thread count
# can approach n_cores^2. Recent tune releases parallelise via future/mirai
# rather than foreach, so confirm which backend is actually in effect.
registerDoParallel(cores = n_cores)
start_time <- Sys.time()

# finally= stops the implicit cluster even if tuning errors, so a failed knit
# does not leave worker processes running; the error itself still propagates.
tune_results <- tryCatch(
  tune_grid(
    rf_wf,
    resamples = cv_folds,
    grid      = rf_grid,
    metrics   = rf_metrics,
    control   = control_grid(verbose = TRUE, save_pred = FALSE)
  ),
  finally = stopImplicitCluster()
)

tune_time <- difftime(Sys.time(), start_time, units = "secs")

best_params <- select_best(tune_results, metric = "bal_accuracy")

# Tuning cost, the winning (mtry, min_n) pair and the leaderboard by the
# selection metric.
cat("\nTuning time:", round(tune_time, 2), "seconds\n")
## 
## Tuning time: 790.33 seconds
cat("\nBest hyperparameters selected by balanced accuracy:\n")
## 
## Best hyperparameters selected by balanced accuracy:
print(best_params)
## # A tibble: 1 × 3
##    mtry min_n .config         
##   <int> <int> <chr>           
## 1     8     2 pre0_mod13_post0
cat("\nTop 10 tuning results by balanced accuracy:\n")
## 
## Top 10 tuning results by balanced accuracy:
print(show_best(tune_results, metric = "bal_accuracy", n = 10))
## # A tibble: 10 × 8
##     mtry min_n .metric      .estimator  mean     n std_err .config         
##    <int> <int> <chr>        <chr>      <dbl> <int>   <dbl> <chr>           
##  1     8     2 bal_accuracy macro      0.618     5 0.00174 pre0_mod13_post0
##  2     6     2 bal_accuracy macro      0.616     5 0.00218 pre0_mod09_post0
##  3     8     8 bal_accuracy macro      0.616     5 0.00214 pre0_mod14_post0
##  4     6     8 bal_accuracy macro      0.614     5 0.00232 pre0_mod10_post0
##  5     8    14 bal_accuracy macro      0.614     5 0.00256 pre0_mod15_post0
##  6     5     2 bal_accuracy macro      0.614     5 0.00229 pre0_mod05_post0
##  7     8    20 bal_accuracy macro      0.612     5 0.00341 pre0_mod16_post0
##  8     5     8 bal_accuracy macro      0.612     5 0.00299 pre0_mod06_post0
##  9     6    14 bal_accuracy macro      0.612     5 0.00276 pre0_mod11_post0
## 10     4     2 bal_accuracy macro      0.612     5 0.00235 pre0_mod01_post0
# Macro-F1 leaderboard, shown for reference only; it does not affect selection.
cat("\nTop 10 tuning results by macro-F1 for comparison only:\n")
## 
## Top 10 tuning results by macro-F1 for comparison only:
print(show_best(tune_results, metric = "f_meas", n = 10))
## # A tibble: 10 × 8
##     mtry min_n .metric .estimator  mean     n std_err .config         
##    <int> <int> <chr>   <chr>      <dbl> <int>   <dbl> <chr>           
##  1     8     8 f_meas  macro      0.466     5 0.00493 pre0_mod14_post0
##  2     8    14 f_meas  macro      0.465     5 0.00523 pre0_mod15_post0
##  3     6     8 f_meas  macro      0.464     5 0.00493 pre0_mod10_post0
##  4     6     2 f_meas  macro      0.464     5 0.00509 pre0_mod09_post0
##  5     8    20 f_meas  macro      0.463     5 0.00678 pre0_mod16_post0
##  6     8     2 f_meas  macro      0.463     5 0.00419 pre0_mod13_post0
##  7     6    14 f_meas  macro      0.463     5 0.00567 pre0_mod11_post0
##  8     5     8 f_meas  macro      0.462     5 0.00597 pre0_mod06_post0
##  9     6    20 f_meas  macro      0.462     5 0.00627 pre0_mod12_post0
## 10     5     2 f_meas  macro      0.462     5 0.00491 pre0_mod05_post0

4.6 Final Tuned Weighted Random Forest

The tuned workflow is finalised with the best parameters selected by balanced accuracy, then refit on the full training set and evaluated on the untouched test set.

# Plug the selected mtry/min_n into the workflow and refit once on all
# training rows.
final_wf <- finalize_workflow(rf_wf, best_params)

start_time <- Sys.time()
final_fit <- fit(final_wf, data = train_df)
final_time <- difftime(Sys.time(), start_time, units = "secs")

# Hard class predictions on the held-out test set.
tuned_preds <- predict(final_fit, new_data = test_df)$.pred_class
results_tuned <- make_results(test_df$Disposition_Class, tuned_preds)

tuned_metrics <- overall_metrics(results_tuned, "Tuned weighted RF")
tuned_class_metrics <- per_class_metrics(results_tuned, "Tuned weighted RF")

# The underlying ranger object carries the impurity importances.
tuned_engine <- extract_fit_engine(final_fit)
tuned_importance <- feature_importance_tbl(
  tuned_engine$variable.importance,
  "Tuned weighted RF"
)

cat("Final tuned training time:", round(final_time, 2), "seconds\n")
## Final tuned training time: 69.46 seconds
# Full test-set report for the tuned model.
print_model_report(results_tuned, "Tuned weighted RF")
## 
## ============================================================
## Tuned weighted RF 
## ============================================================
## 
## Overall metrics:
## # A tibble: 6 × 3
##   .metric      .estimator .estimate
##   <chr>        <chr>          <dbl>
## 1 accuracy     multiclass     0.813
## 2 bal_accuracy macro          0.635
## 3 f_meas       macro          0.492
## 4 precision    macro          0.629
## 5 recall       macro          0.446
## 6 kap          multiclass     0.333
## 
## Per-class metrics:
## # A tibble: 4 × 7
##   class           support predicted precision recall    f1 specificity
##   <chr>             <int>     <int>     <dbl>  <dbl> <dbl>       <dbl>
## 1 Home              33979     38014     0.852  0.953 0.900       0.356
## 2 Facility           5570      2935     0.460  0.242 0.318       0.957
## 3 Hospice_Expired    1745      1348     0.496  0.383 0.432       0.983
## 4 Other              1412       409     0.707  0.205 0.317       0.997
## 
## Confusion matrix:
##                  Truth
## Prediction         Home Facility Hospice_Expired Other
##   Home            32393     3859             713  1049
##   Facility         1171     1351             357    56
##   Hospice_Expired   319      343             668    18
##   Other              96       17               7   289
# Top-15 impurity importances of the tuned forest.
cat("\nTop 15 tuned feature importances:\n")
## 
## Top 15 tuned feature importances:
print(tuned_importance)
## # A tibble: 15 × 3
##    model             feature                        importance
##    <chr>             <chr>                               <dbl>
##  1 Tuned weighted RF APR.DRG.Code                       18150.
##  2 Tuned weighted RF Hospital.County                    14330.
##  3 Tuned weighted RF CCSR.Procedure.Code                14124.
##  4 Tuned weighted RF CCSR.Diagnosis.Code                12316.
##  5 Tuned weighted RF APR.Risk.of.Mortality               7829.
##  6 Tuned weighted RF Hospital.Service.Area               7186.
##  7 Tuned weighted RF APR.Severity.of.Illness.Code        6308.
##  8 Tuned weighted RF coverage_count                      6003.
##  9 Tuned weighted RF Age.Group                           5846.
## 10 Tuned weighted RF Race                                5309.
## 11 Tuned weighted RF APR.MDC.Code                        5022.
## 12 Tuned weighted RF Ethnicity                           4371.
## 13 Tuned weighted RF Gender                              3909.
## 14 Tuned weighted RF Type.of.Admission                   2294.
## 15 Tuned weighted RF Emergency.Department.Indicator      1857.

4.7 Downsampled Random Forest for Minority-Class Recall

This is the minority-recall lever. The majority classes are downsampled before training so that the model pays more attention to rare classes such as Hospice_Expired and Other. This model may reduce overall accuracy, but it can improve minority-class recall and balanced accuracy.

Adjust downsample_majority_multiplier if needed:

  • 1 = fully balanced classes
  • 2 = keeps up to twice the minority-class count per class
  • higher values = less aggressive downsampling
# Cap each class at multiplier x (size of the smallest training class);
# classes already at or below the cap are kept whole.
downsample_majority_multiplier <- 2

minority_n <- min(table(train_df$Disposition_Class))
max_n_per_class <- minority_n * downsample_majority_multiplier

# Seed so the per-class row sample is reproducible.
set.seed(42)

train_down_df <- train_df %>%
  group_by(Disposition_Class) %>%
  group_modify(~ {
    slice_sample(.x, n = min(nrow(.x), max_n_per_class))
  }) %>%
  ungroup()

# Safeguard: re-apply the training-set level order to the outcome factor.
train_down_df$Disposition_Class <- factor(
  train_down_df$Disposition_Class,
  levels = levels(train_df$Disposition_Class)
)

cat("Downsampled training rows:", nrow(train_down_df), "\n")
## Downsampled training rows: 34343
cat("Downsampled class distribution:\n")
## Downsampled class distribution:
print(
  train_down_df %>%
    count(Disposition_Class, name = "n") %>%
    mutate(percent = round(100 * n / sum(n), 2))
)
## # A tibble: 4 × 3
##   Disposition_Class     n percent
##   <fct>             <int>   <dbl>
## 1 Home              10964    31.9
## 2 Facility          10964    31.9
## 3 Hospice_Expired    6933    20.2
## 4 Other              5482    16.0
# Wall-clock time of the ranger() fit only.
start_time <- Sys.time()

# No class.weights here: class balance comes from the downsampled rows.
rf_downsampled <- ranger(
  formula     = Disposition_Class ~ .,
  data        = train_down_df,
  num.trees   = 500,
  importance  = "impurity",
  num.threads = n_cores,
  seed        = 42
)

downsampled_time <- difftime(Sys.time(), start_time, units = "secs")

# Evaluated on the same untouched (non-downsampled) test set as the others.
downsampled_preds <- predict(rf_downsampled, data = test_df)$predictions

results_downsampled <- make_results(
  test_df$Disposition_Class,
  downsampled_preds
)

downsampled_metrics <- overall_metrics(
  results_downsampled,
  "Downsampled RF"
)

downsampled_class_metrics <- per_class_metrics(
  results_downsampled,
  "Downsampled RF"
)

downsampled_importance <- feature_importance_tbl(
  rf_downsampled$variable.importance,
  "Downsampled RF"
)

cat("Downsampled RF training time:", round(downsampled_time, 2), "seconds\n")
## Downsampled RF training time: 6.75 seconds
print_model_report(
  results_downsampled,
  "Downsampled RF"
)
## 
## ============================================================
## Downsampled RF 
## ============================================================
## 
## Overall metrics:
## # A tibble: 6 × 3
##   .metric      .estimator .estimate
##   <chr>        <chr>          <dbl>
## 1 accuracy     multiclass     0.632
## 2 bal_accuracy macro          0.745
## 3 f_meas       macro          0.462
## 4 precision    macro          0.430
## 5 recall       macro          0.620
## 6 kap          multiclass     0.305
## 
## Per-class metrics:
## # A tibble: 4 × 7
##   class           support predicted precision recall    f1 specificity
##   <chr>             <int>     <int>     <dbl>  <dbl> <dbl>       <dbl>
## 1 Home              33979     22948     0.943  0.637 0.760       0.851
## 2 Facility           5570     11489     0.290  0.599 0.391       0.780
## 3 Hospice_Expired    1745      3985     0.308  0.704 0.429       0.933
## 4 Other              1412      4284     0.178  0.541 0.268       0.915
## 
## Confusion matrix:
##                  Truth
## Prediction         Home Facility Hospice_Expired Other
##   Home            21644      905              66   333
##   Facility         7492     3334             426   237
##   Hospice_Expired  1624     1055            1228    78
##   Other            3219      276              25   764
# Top-15 impurity importances of the downsampled forest.
cat("\nTop 15 downsampled feature importances:\n")
## 
## Top 15 downsampled feature importances:
print(downsampled_importance)
## # A tibble: 15 × 3
##    model          feature                        importance
##    <chr>          <chr>                               <dbl>
##  1 Downsampled RF APR.DRG.Code                        2790.
##  2 Downsampled RF CCSR.Diagnosis.Code                 2537.
##  3 Downsampled RF CCSR.Procedure.Code                 2354.
##  4 Downsampled RF Hospital.County                     2151.
##  5 Downsampled RF Age.Group                           2044.
##  6 Downsampled RF APR.Risk.of.Mortality               1935.
##  7 Downsampled RF APR.Severity.of.Illness.Code        1787.
##  8 Downsampled RF APR.MDC.Code                        1380.
##  9 Downsampled RF Hospital.Service.Area               1269.
## 10 Downsampled RF coverage_count                       869.
## 11 Downsampled RF Race                                 834.
## 12 Downsampled RF Ethnicity                            644.
## 13 Downsampled RF Gender                               617.
## 14 Downsampled RF Emergency.Department.Indicator       494.
## 15 Downsampled RF Type.of.Admission                    447.

4.8 Model Comparison & Programmatic Final Selection

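The final model is selected using this logic:

  1. maximise balanced accuracy,
  2. if balanced accuracy ties exactly (after rounding to four decimal places), prefer higher macro-F1,
  3. then prefer higher macro recall,
  4. then higher accuracy.

This avoids manually choosing a model based only on accuracy. Because the comparison is made on the test set, the selected model's test metrics are slightly optimistic. A separate validation split or a cross-validated comparison would give an unbiased estimate.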

# Stack the per-model metric tables of the three candidates.
all_overall_metrics <- list(
  baseline_metrics,
  tuned_metrics,
  downsampled_metrics
) %>%
  bind_rows()

all_class_metrics <- list(
  baseline_class_metrics,
  tuned_class_metrics,
  downsampled_class_metrics
) %>%
  bind_rows()

# Wall-clock cost per model; the tuned model is charged for the CV search
# plus its final refit.
training_time_tbl <- tribble(
  ~model,                 ~training_or_tuning_time_sec,
  "Baseline weighted RF", as.numeric(baseline_time),
  "Tuned weighted RF",    as.numeric(tune_time) + as.numeric(final_time),
  "Downsampled RF",       as.numeric(downsampled_time)
)

# One row per model with metrics as columns (rounded to 4 dp), ordered by the
# selection rule. Each later sort key only breaks exact ties of the earlier ones.
comparison_tbl <- all_overall_metrics %>%
  mutate(.estimate = round(.estimate, 4)) %>%
  select(model, .metric, .estimate) %>%
  pivot_wider(names_from = .metric, values_from = .estimate) %>%
  left_join(training_time_tbl, by = "model") %>%
  arrange(desc(bal_accuracy), desc(f_meas), desc(recall), desc(accuracy))

cat("Overall model comparison:\n")
## Overall model comparison:
print(comparison_tbl)
## # A tibble: 3 × 8
##   model                accuracy bal_accuracy f_meas precision recall   kap
##   <chr>                   <dbl>        <dbl>  <dbl>     <dbl>  <dbl> <dbl>
## 1 Downsampled RF          0.632        0.745  0.462     0.43   0.620 0.305
## 2 Baseline weighted RF    0.791        0.673  0.507     0.543  0.503 0.364
## 3 Tuned weighted RF       0.813        0.635  0.492     0.629  0.446 0.333
## # ℹ 1 more variable: training_or_tuning_time_sec <dbl>
# Per-class metrics of all candidates, grouped by model then class.
cat("\nPer-class classification report:\n")
## 
## Per-class classification report:
print(
  all_class_metrics %>%
    arrange(model, class)
)
## # A tibble: 12 × 8
##    model              class support predicted precision recall    f1 specificity
##    <chr>              <chr>   <int>     <int>     <dbl>  <dbl> <dbl>       <dbl>
##  1 Baseline weighted… Faci…    5570      4499     0.409  0.331 0.366       0.928
##  2 Baseline weighted… Home    33979     35150     0.872  0.902 0.887       0.486
##  3 Baseline weighted… Hosp…    1745      2396     0.401  0.550 0.464       0.965
##  4 Baseline weighted… Other    1412       661     0.489  0.229 0.312       0.992
##  5 Downsampled RF     Faci…    5570     11489     0.290  0.599 0.391       0.780
##  6 Downsampled RF     Home    33979     22948     0.943  0.637 0.760       0.851
##  7 Downsampled RF     Hosp…    1745      3985     0.308  0.704 0.429       0.933
##  8 Downsampled RF     Other    1412      4284     0.178  0.541 0.268       0.915
##  9 Tuned weighted RF  Faci…    5570      2935     0.460  0.242 0.318       0.957
## 10 Tuned weighted RF  Home    33979     38014     0.852  0.953 0.900       0.356
## 11 Tuned weighted RF  Hosp…    1745      1348     0.496  0.383 0.432       0.983
## 12 Tuned weighted RF  Other    1412       409     0.707  0.205 0.317       0.997
# The first row of comparison_tbl is the winner under the selection rule.
selected_model_name <- head(comparison_tbl$model, 1)

# Look-up tables from model name to its test-set results and importances.
candidate_results <- setNames(
  list(results_baseline, results_tuned, results_downsampled),
  c("Baseline weighted RF", "Tuned weighted RF", "Downsampled RF")
)

candidate_importance <- setNames(
  list(baseline_importance, tuned_importance, downsampled_importance),
  c("Baseline weighted RF", "Tuned weighted RF", "Downsampled RF")
)

selected_results <- candidate_results[[selected_model_name]]
selected_importance <- candidate_importance[[selected_model_name]]

cat("\n============================================================\n")
## 
## ============================================================
cat("FINAL SELECTED MODEL:", selected_model_name, "\n")
## FINAL SELECTED MODEL: Downsampled RF
cat("Selection rule: highest balanced accuracy, then macro-F1, then macro recall.\n")
## Selection rule: highest balanced accuracy, then macro-F1, then macro recall.
cat("============================================================\n\n")
## ============================================================
## ============================================================
# Keep only the selected model's rows from the combined metric tables.
selected_overall_metrics <- filter(all_overall_metrics, model == selected_model_name)
selected_class_metrics   <- filter(all_class_metrics, model == selected_model_name)

cat("Selected model overall metrics:\n")
## Selected model overall metrics:
print(selected_overall_metrics)
## # A tibble: 6 × 4
##   model          .metric      .estimator .estimate
##   <chr>          <chr>        <chr>          <dbl>
## 1 Downsampled RF accuracy     multiclass     0.632
## 2 Downsampled RF bal_accuracy macro          0.745
## 3 Downsampled RF f_meas       macro          0.462
## 4 Downsampled RF precision    macro          0.430
## 5 Downsampled RF recall       macro          0.620
## 6 Downsampled RF kap          multiclass     0.305
# Per-class metrics of the selected model only.
cat("\nSelected model per-class metrics:\n")
## 
## Selected model per-class metrics:
print(selected_class_metrics)
## # A tibble: 4 × 8
##   model          class      support predicted precision recall    f1 specificity
##   <chr>          <chr>        <int>     <int>     <dbl>  <dbl> <dbl>       <dbl>
## 1 Downsampled RF Home         33979     22948     0.943  0.637 0.760       0.851
## 2 Downsampled RF Facility      5570     11489     0.290  0.599 0.391       0.780
## 3 Downsampled RF Hospice_E…    1745      3985     0.308  0.704 0.429       0.933
## 4 Downsampled RF Other         1412      4284     0.178  0.541 0.268       0.915
# Confusion matrix of the selected model (rows = Prediction, columns = Truth).
cat("\nSelected model confusion matrix:\n")
## 
## Selected model confusion matrix:
print(conf_mat(selected_results, truth, estimate))
##                  Truth
## Prediction         Home Facility Hospice_Expired Other
##   Home            21644      905              66   333
##   Facility         7492     3334             426   237
##   Hospice_Expired  1624     1055            1228    78
##   Other            3219      276              25   764

4.9 Final Visualisations

The selected model is visualised using a confusion-matrix heatmap and top feature importances. These visuals should be used for the report discussion because they reflect the model selected by the programmatic comparison above.

# Confusion-matrix heatmap for the selected model. yardstick's heatmap draws
# Truth on the x-axis and Prediction on the y-axis, so the axis titles follow
# that layout.
autoplot(conf_mat(selected_results, truth, estimate), type = "heatmap") +
  scale_fill_gradient(low = "#E8F1FA", high = "#1F4E79") +
  labs(
    title = paste("Confusion Matrix -", selected_model_name),
    x = "Actual Class",
    y = "Predicted Class"
  ) +
  theme_minimal()

# Top impurity importances of the selected model, largest at the top.
ggplot(selected_importance, aes(x = reorder(feature, importance), y = importance)) +
  geom_col(fill = "#2E86AB") +
  coord_flip() +
  labs(
    title = paste("Top Feature Importance -", selected_model_name),
    x = NULL,
    y = "Importance"
  ) +
  theme_minimal()

5. Regression Modelling

library(caret)

# Predictors plus two target columns: LoS_log (the modelled response) and the
# raw Length.of.Stay (used to score back-transformed predictions).
feature_df <- preprocessed_df
lr_df <- cbind(feature_df, LoS_log = target_df$LoS_log, Length.of.Stay = target_df$Length.of.Stay)

# Names of the categorical (factor or character) columns. vapply() always
# returns a logical vector, unlike sapply(), whose return type depends on input.
cat_cols <- names(lr_df)[vapply(
  lr_df,
  function(x) is.factor(x) || is.character(x),
  logical(1)
)]

# Collapse infrequent categories into "Other": a category is kept only if its
# count is at least 30% of the most frequent category in that column.
# NOTE(review): table() ignores NA, so missing values are also recoded to
# "Other"; and the counts come from the full data, before the train/test split.
for (col in cat_cols) {
  counts <- table(lr_df[[col]])

  max_count <- max(counts)
  keep_categories <- names(counts[counts / max_count >= 0.3])

  # convert to character first
  lr_df[[col]] <- as.character(lr_df[[col]])

  lr_df[[col]][!(lr_df[[col]] %in% keep_categories)] <- "Other"

  # convert back to factor
  lr_df[[col]] <- as.factor(lr_df[[col]])
}

5.1 Train/Test Split (70/30)

# 70/30 split on the log target. For a numeric outcome createDataPartition()
# samples within percentile groups, so both sets span the LoS range.
# Seeded so the split, and every regression metric below, is reproducible.
set.seed(42)
train_idx <- createDataPartition(lr_df$LoS_log, p = 0.7, list = FALSE)

train <- lr_df[train_idx, ]
test  <- lr_df[-train_idx, ]

# Predictor frames use only the original feature columns; the targets are
# kept separately on the log scale (y_*) and the raw scale (y_test_real).
X_train <- train[, colnames(feature_df)]
y_train <- train$LoS_log

X_test <- test[, colnames(feature_df)]
y_test <- test$LoS_log
y_test_real <- test$Length.of.Stay
cat("Train set:", nrow(train), "\n")
## Train set: 1494684
cat("Test set:", nrow(test), "\n")
## Test set: 640576

5.2 Training

# OLS on the log target using every predictor column in X_train.
model_lm <- lm(y_train ~ ., data = X_train)
pred_test <- predict(model_lm, newdata = X_test)
# convert log predictions to real scale
# NOTE(review): exp() assumes LoS_log = log(Length.of.Stay); if it was built
# with log1p(), expm1() is the correct inverse. exp() of a predicted mean log
# also tends to under-predict the mean stay (retransformation bias) - confirm.
pred_test_real <- exp(pred_test)

# postResample() returns RMSE, Rsquared and MAE.
log_metrics  <- postResample(pred = pred_test, obs = y_test)
real_metrics <- postResample(pred = pred_test_real, obs = y_test_real)

# First row of the model-comparison table; later feature-subset models append.
results <- data.frame(
  Model = "Full Model",
  RMSE_log = log_metrics["RMSE"],
  R2_log = log_metrics["Rsquared"],
  RMSE_real = real_metrics["RMSE"],
  R2_real = real_metrics["Rsquared"]
)

cat("=== RESULT IN LOG SCALE ===\n")
## === RESULT IN LOG SCALE ===
print(log_metrics)
##      RMSE  Rsquared       MAE 
## 0.5859208 0.3485062 0.4454313
cat("\n=== RESULT IN REAL SCALE ===\n")
## 
## === RESULT IN REAL SCALE ===
print(real_metrics)
##      RMSE  Rsquared       MAE 
## 7.4508278 0.2420682 3.5233430
# Get coefficients (remove intercept). Aliased terms appear as NA and are
# ignored by the na.rm = TRUE sum below.
coefs <- coef(model_lm)[-1]

# Feature names from model (one per dummy / numeric column)
feature_names <- names(coefs)

# Original feature names BEFORE encoding
orig_features <- colnames(X_train)

# Map each coefficient back to the predictor that produced it. lm() keeps the
# model matrix's "assign" vector (0 = intercept, i = i-th term label), which is
# an exact mapping; substring matching on names is not, because a dummy's name
# can contain the name of a different predictor. Backticks that terms() adds
# around non-syntactic names are stripped so labels match colnames(X_train).
term_labels <- gsub("^`|`$", "", attr(terms(model_lm), "term.labels"))
mapped_features <- setNames(term_labels[model_lm$assign[-1]], feature_names)

# Group importance (sum of absolute coefficients per predictor).
# NOTE(review): this favours predictors with many dummy columns and depends on
# the units of numeric predictors, so treat it as a rough ranking only.
group_importance <- aggregate(
  abs(coefs),
  by = list(mapped_features),
  FUN = function(x) sum(x, na.rm = TRUE)
)

colnames(group_importance) <- c("Feature", "Importance")

group_importance <- group_importance[!is.na(group_importance$Feature), ]
group_importance <- group_importance[is.finite(group_importance$Importance), ]

# Sort descending
group_importance <- group_importance[order(group_importance$Importance, decreasing = TRUE), ]
# Plot
ggplot(group_importance, aes(x = reorder(Feature, Importance), y = Importance)) +
  geom_col(fill = "steelblue") +
  coord_flip() +
  labs(
    title = "Feature Importance (Linear Regression)",
    x = "Feature",
    y = "Sum of Absolute Coefficients"
  ) +
  theme_minimal()

# Refit using only the 10 predictors with the largest grouped |coefficient|.
top10_features <- head(group_importance$Feature, 10)

# drop = FALSE keeps a data frame even if only one feature is selected;
# lm(y ~ ., data = ...) needs a data frame, not a bare vector.
X_train_10 <- X_train[, top10_features, drop = FALSE]
X_test_10  <- X_test[, top10_features, drop = FALSE]

model_lm_10 <- lm(y_train ~ ., data = X_train_10)
pred_test_10 <- predict(model_lm_10, newdata = X_test_10)
# convert log predictions to real scale
pred_test_real_10 <- exp(pred_test_10)

log_metrics_top10  <- postResample(pred = pred_test_10, obs = y_test)
real_metrics_top10 <- postResample(pred = pred_test_real_10, obs = y_test_real)

results <- rbind(results, data.frame(
  Model = "Top 10 Model",
  RMSE_log = log_metrics_top10["RMSE"],
  R2_log = log_metrics_top10["Rsquared"],
  RMSE_real = real_metrics_top10["RMSE"],
  R2_real = real_metrics_top10["Rsquared"]
))
# Same procedure with only the top 5 predictors.
top5_features <- head(group_importance$Feature, 5)

X_train_5 <- X_train[, top5_features, drop = FALSE]
X_test_5  <- X_test[, top5_features, drop = FALSE]

model_lm_5 <- lm(y_train ~ ., data = X_train_5)
pred_test_5 <- predict(model_lm_5, newdata = X_test_5)
# convert log predictions to real scale
pred_test_real_5 <- exp(pred_test_5)

log_metrics_top5  <- postResample(pred = pred_test_5, obs = y_test)
real_metrics_top5 <- postResample(pred = pred_test_real_5, obs = y_test_real)

results <- rbind(results, data.frame(
  Model = "Top 5 Model",
  RMSE_log = log_metrics_top5["RMSE"],
  R2_log = log_metrics_top5["Rsquared"],
  RMSE_real = real_metrics_top5["RMSE"],
  R2_real = real_metrics_top5["Rsquared"]
))
cat("=== COMPARISON W/ FEATURE SELECTION ===\n")
## === COMPARISON W/ FEATURE SELECTION ===
print(results, row.names = FALSE)
##         Model  RMSE_log    R2_log RMSE_real   R2_real
##    Full Model 0.5859208 0.3485062  7.450828 0.2420682
##  Top 10 Model 0.5864480 0.3473334  7.455907 0.2407064
##   Top 5 Model 0.5937702 0.3309335  7.496754 0.2317092