Question of interest

Parkinson’s disease (PD) is a progressive neurodegenerative disorder affecting millions of people worldwide. It is characterized by motor symptoms such as tremor, rigidity, and bradykinesia, as well as non-motor manifestations including cognitive decline and autonomic dysfunction. A major clinical challenge in PD is the lack of a reliable, objective set of biomarkers for early diagnosis. Patients are usually diagnosed only after significant neurodegeneration has already occurred, on the basis of physical examination or imaging that shows damage in specific brain regions.

This project aims to investigate whether cerebrospinal fluid (CSF) proteomics data can be used to accurately predict Parkinson’s disease status. Specifically: can protein expression profiles from CSF samples reliably distinguish Parkinson’s disease patients from healthy controls, and which proteins are most predictive of disease status?

Data

The data used in this project comes from the Parkinson’s Progression Markers Initiative (PPMI), a landmark longitudinal observational study sponsored by the Michael J. Fox Foundation. PPMI was designed to identify biomarkers of Parkinson’s disease onset and progression by recruiting both PD patients and healthy controls and collecting biological samples, imaging, and clinical assessments over time. https://www.ppmi-info.org/

Clinical Data:

library(tidyverse)  # readr, dplyr, tidyr, purrr, ggplot2
clinical_dir <- "/Users/ramazanegesolak/PPMI_STRING"
age_data           <- read_csv(file.path(clinical_dir, "Age_at_visit_12Apr2026.csv"),                 show_col_types = FALSE)
clinical_dx        <- read_csv(file.path(clinical_dir, "Clinical_Diagnosis_12Apr2026.csv"),            show_col_types = FALSE)
demographics       <- read_csv(file.path(clinical_dir, "Demographics_12Apr2026.csv"),                  show_col_types = FALSE)
genetics           <- read_csv(file.path(clinical_dir, "iu_genetic_consensus_20251025_12Apr2026.csv"), show_col_types = FALSE)
## Warning: One or more parsing issues, call `problems()` on your data frame for details,
## e.g.:
##   dat <- vroom(...)
##   problems(dat)
participant_status <- read_csv(file.path(clinical_dir, "Participant_Status_12Apr2026.csv"),            show_col_types = FALSE)

# Quick check all loaded correctly
cat("age_data:          ", nrow(age_data),           "rows\n")
## age_data:           46227 rows
cat("clinical_dx:       ", nrow(clinical_dx),        "rows\n")
## clinical_dx:        15440 rows
cat("demographics:      ", nrow(demographics),       "rows\n")
## demographics:       8351 rows
cat("genetics:          ", nrow(genetics),           "rows\n")
## genetics:           6265 rows
cat("participant_status:", nrow(participant_status), "rows\n")
## participant_status: 8417 rows

Variables of interest

Directed acyclic graph (DAG)

library(ggdag)
library(ggplot2)

dag <- dagify(
  Outcome   ~ Predictor + SEX + AGE + Genetics,
  Predictor ~ SEX + AGE + Genetics + Batch,
  coords = list(
    x = c(SEX      = 1,   AGE      = 1,    Genetics = 1,
          Batch    = 1,   Predictor = 2.8,  Outcome  = 4.5),
    y = c(SEX      = 1.5, AGE      = 0.5,  Genetics = -0.5,
          Batch    = -1.5, Predictor = 0,   Outcome  = 0)
  )
)

ggdag(dag, text = FALSE) +
  geom_dag_point(color = "#4472C4", size = 22) +
  geom_dag_text(color = "white", size = 3.2, fontface = "bold") +
  theme_dag() +
  labs(
    title   = "DAG: CSF Proteomics and Parkinson's Disease Prediction",
    caption = paste(
      "Outcome   = PD Diagnosis (COHORT)",
      "Predictor = CSF Protein Expression (TESTVALUE)",
      "SEX       = Biological sex — affects PD prevalence & protein levels",
      "AGE       = Age at visit — risk factor for PD & protein expression",
      "Genetics  = Mutation carrier status (LRRK2, GBA, SNCA, PRKN, APOE-e4)",
      "Batch     = Assay plate (PLATEID) — technical measurement confounder",
      sep = "\n"
    )
  ) +
  theme(
    plot.title   = element_text(face = "bold", size = 13, hjust = 0.5),
    plot.caption = element_text(size = 8.5, color = "gray30", hjust = 0,
                                margin = margin(t = 12), family = "mono")
  )
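As a check on the adjustment strategy, the DAG can be queried for the covariates that block every back-door path from the protein measurements to PD status. The sketch below assumes the `dagitty` package (which ggdag builds on) and restates the same DAG in dagitty syntax. For this structure it should return a single minimal set, {AGE, Genetics, SEX}. Batch does not appear because, as drawn, it points only into Predictor, so adjusting for it is not required to block back-door paths (it may still be worth adjusting for to reduce measurement noise).

```r
library(dagitty)

# Same structure as the dagify() call above, written in dagitty syntax
dag_check <- dagitty("dag {
  SEX -> Outcome
  AGE -> Outcome
  Genetics -> Outcome
  Predictor -> Outcome
  SEX -> Predictor
  AGE -> Predictor
  Genetics -> Predictor
  Batch -> Predictor
}")

# Minimal covariate sets that close every back-door path
# from Predictor (CSF protein expression) to Outcome (PD diagnosis)
adj_sets <- adjustmentSets(dag_check, exposure = "Predictor", outcome = "Outcome")
print(adj_sets)
```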

Statistical analysis plan

Because the outcome is binary (PD vs Control), three classification approaches will be built and compared to answer the question of interest: standard logistic regression, LASSO-penalized logistic regression, and a random forest. Comparing these models lets the project evaluate both predictive accuracy and the interpretability of the selected biomarkers.
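Predictive performance will be summarised with AUC-ROC, accuracy, and a confusion matrix. A minimal sketch of how these could be computed with the `yardstick` package (part of tidymodels), using six hypothetical held-out predictions rather than real model output. Because `PD` is the second factor level, it is declared as the event with `event_level = "second"`.

```r
library(yardstick)
library(tibble)

# Hypothetical held-out predictions for six participants
preds <- tibble(
  truth    = factor(c("Control", "Control", "PD", "PD", "PD", "Control"),
                    levels = c("Control", "PD")),
  .pred_PD = c(0.2, 0.6, 0.7, 0.9, 0.4, 0.1)
)
# Hard class predictions at a 0.5 probability threshold
preds$.pred_class <- factor(ifelse(preds$.pred_PD >= 0.5, "PD", "Control"),
                            levels = c("Control", "PD"))

# "PD" is the second level of truth, so it is the event of interest
auc_tbl <- roc_auc(preds, truth, .pred_PD, event_level = "second")
acc_tbl <- accuracy(preds, truth, .pred_class)
cm      <- conf_mat(preds, truth, .pred_class)

auc_tbl  # 8 of the 9 PD/Control pairs are ranked correctly: AUC = 8/9
acc_tbl  # 4 of 6 hard predictions are correct
cm
```

AUC uses the predicted probabilities and is threshold-free, whereas accuracy and the confusion matrix depend on the 0.5 cutoff chosen here.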

To control for potential confounding, sex will be included as a biological covariate in all three models, because protein expression levels are known to vary by sex and Parkinson’s disease prevalence differs between males and females. Age at visit and genetic carrier status (LRRK2, GBA, SNCA, PRKN, APOE-e4), which also appear as confounders in the DAG, are included as clinical covariates. PLATEID will be examined and adjusted for as a technical confounder, since batch-to-batch variability across assay plates can introduce systematic differences in protein measurements that are unrelated to disease status. Protein expression values will also be normalized (centered and scaled) prior to modeling. Model performance will be evaluated using the area under the ROC curve (AUC-ROC), accuracy, and a confusion matrix.
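The plan does not specify how PLATEID will be adjusted for. One common, simple option is to center each protein on its per-plate median; the sketch below assumes that approach and uses hypothetical long-format toy data with the same column names as `raw_data`. It is only safe when PD and Control samples are spread across plates: if a plate held mostly one cohort, centering would also remove genuine disease signal.

```r
library(dplyr)
library(tibble)

# Hypothetical toy data: one protein on two plates,
# with plate P2 reading systematically 3 units higher
toy <- tibble(
  PLATEID   = c("P1", "P1", "P1", "P2", "P2", "P2"),
  TESTNAME  = "protA",
  TESTVALUE = c(10, 11, 12, 13, 14, 15)
)

# Subtract each plate's median, separately for each protein
toy_centered <- toy %>%
  group_by(TESTNAME, PLATEID) %>%
  mutate(value_centered = TESTVALUE - median(TESTVALUE, na.rm = TRUE)) %>%
  ungroup()

toy_centered$value_centered  # -1 0 1 -1 0 1: the plate offset is gone
```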

Exploratory data analysis:

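# `file_list` (paths to the proteomics CSV files) and `data_dir` are assumed
# to be defined in an earlier chunk
wanted_cols <- c(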
  "PATNO", "SEX", "COHORT", "TESTNAME",
  "TESTVALUE", "PLATEID"
)
raw_data <- file_list %>%
  map_dfr(~read_csv(.x, show_col_types = FALSE)) %>%
  select(any_of(wanted_cols))
glimpse(raw_data)
## Rows: 5,541,030
## Columns: 6
## $ PATNO     <dbl> 53595, 53595, 53595, 53595, 53595, 53595, 53595, 53595, 5359…
## $ SEX       <chr> "Female", "Female", "Female", "Female", "Female", "Female", …
## $ COHORT    <chr> "PD", "PD", "PD", "PD", "PD", "PD", "PD", "PD", "PD", "PD", …
## $ TESTNAME  <chr> "5632-6_3", "5631-83_3", "5630-48_3", "5629-58_3", "5628-21_…
## $ TESTVALUE <dbl> 10.182013, 6.346285, 12.943830, 5.891402, 12.836350, 7.03480…
## $ PLATEID   <chr> "P0022899", "P0022899", "P0022899", "P0022899", "P0022899", …
head(raw_data, 20)
## # A tibble: 20 × 6
##    PATNO SEX    COHORT TESTNAME  TESTVALUE PLATEID 
##    <dbl> <chr>  <chr>  <chr>         <dbl> <chr>   
##  1 53595 Female PD     5632-6_3      10.2  P0022899
##  2 53595 Female PD     5631-83_3      6.35 P0022899
##  3 53595 Female PD     5630-48_3     12.9  P0022899
##  4 53595 Female PD     5629-58_3      5.89 P0022899
##  5 53595 Female PD     5628-21_3     12.8  P0022899
##  6 53595 Female PD     5627-53_3      7.03 P0022899
##  7 53595 Female PD     5626-20_3      5.89 P0022899
##  8 53595 Female PD     5624-66_3      7.66 P0022899
##  9 53595 Female PD     5623-11_3      6.13 P0022899
## 10 53595 Female PD     5621-64_3      5.95 P0022899
## 11 53595 Female PD     5620-13_3     10.1  P0022899
## 12 53595 Female PD     5618-50_3     13.9  P0022899
## 13 53595 Female PD     5617-41_3      8.84 P0022899
## 14 53595 Female PD     5615-62_3      6.25 P0022899
## 15 53595 Female PD     5614-44_3      6.73 P0022899
## 16 53595 Female PD     5613-75_3      5.87 P0022899
## 17 53595 Female PD     5612-16_3      6.20 P0022899
## 18 53595 Female PD     5611-56_3      6.02 P0022899
## 19 53595 Female PD     5610-32_3      6.38 P0022899
## 20 53595 Female PD     5609-92_3     14.1  P0022899
raw_data %>% count(COHORT)
## # A tibble: 3 × 2
##   COHORT          n
##   <chr>       <int>
## 1 Control    890010
## 2 PD        2952345
## 3 Prodromal 1698675
raw_data %>% summarise(
  n_patients = n_distinct(PATNO),
  n_proteins = n_distinct(TESTNAME),
  n_rows     = n()
)
## # A tibble: 1 × 3
##   n_patients n_proteins  n_rows
##        <int>      <int>   <int>
## 1       1158       4785 5541030
# Check for duplicate patient-protein combinations
raw_data %>%
  count(PATNO, TESTNAME) %>%
  filter(n > 1) %>%
  nrow()
## [1] 0
# The annotation file has many unnamed columns; readr renames them
# (...3, ...8 through ...61) and prints one message per column,
# so those messages are suppressed here
annotations <- suppressMessages(read_csv(
  file.path(data_dir, "PPMI_Project_151_pqtl_Analysis_Annotations_20210210.csv"),
  show_col_types = FALSE
))
protein_lookup <- annotations %>%
  select(SOMA_SEQ_ID, TARGET_GENE_SYMBOL) %>%
  distinct() %>%
  filter(!is.na(TARGET_GENE_SYMBOL), TARGET_GENE_SYMBOL != "")
head(protein_lookup)
## # A tibble: 6 × 2
##   SOMA_SEQ_ID TARGET_GENE_SYMBOL
##   <chr>       <chr>             
## 1 10000-28_3  CRYBB2            
## 2 10001-7_3   RAF1              
## 3 10003-15_3  ZNF41             
## 4 10006-25_3  ELK1              
## 5 10008-43_3  GUCA1A            
## 6 10009-2_3   IRF1
raw_data %>%
  distinct(PATNO, COHORT) %>%
  count(COHORT) %>%
  ggplot(aes(x = reorder(COHORT, -n), y = n, fill = COHORT)) +
  geom_col(width = 0.5, show.legend = FALSE) +
  geom_text(aes(label = n), vjust = -0.5, fontface = "bold", size = 5) +
  scale_fill_manual(values = c("Control" = "#4472C4", "PD" = "#C00000",
                               "Prodromal" = "#7F7F7F")) +
  labs(title = "Number of Patients by Cohort", x = "", y = "Count") +
  theme_minimal() +
  theme(plot.title = element_text(face = "bold", size = 13))

raw_data %>%
  filter(COHORT %in% c("PD", "Control")) %>%
  ggplot(aes(x = TESTVALUE, fill = COHORT)) +
  geom_density(alpha = 0.5) +
  scale_fill_manual(values = c("Control" = "#4472C4", "PD" = "#C00000")) +
  labs(title = "Distribution of Protein Expression Values",
       x = "Expression Value (log2)", y = "Density", fill = "") +
  theme_minimal() +
  theme(plot.title = element_text(face = "bold", size = 13),
        legend.position = "bottom")

raw_data %>%
  filter(COHORT %in% c("PD", "Control")) %>%
  group_by(PATNO, COHORT) %>%
  summarise(mean_expr = mean(TESTVALUE, na.rm = TRUE), .groups = "drop") %>%
  ggplot(aes(x = COHORT, y = mean_expr, fill = COHORT)) +
  geom_boxplot(alpha = 0.7, show.legend = FALSE) +
  geom_jitter(width = 0.15, alpha = 0.2, size = 0.8) +
  scale_fill_manual(values = c("Control" = "#4472C4", "PD" = "#C00000")) +
  labs(title = "Mean Protein Expression per Patient by Cohort",
       x = "", y = "Mean Expression Value") +
  theme_minimal() +
  theme(plot.title = element_text(face = "bold", size = 13))

raw_data %>%
  filter(COHORT %in% c("PD", "Control")) %>%
  group_by(COHORT) %>%
  summarise(
    mean_expr   = round(mean(TESTVALUE,   na.rm = TRUE), 3),
    median_expr = round(median(TESTVALUE, na.rm = TRUE), 3),
    sd_expr     = round(sd(TESTVALUE,     na.rm = TRUE), 3)
  ) %>%
  print()
## # A tibble: 2 × 4
##   COHORT  mean_expr median_expr sd_expr
##   <chr>       <dbl>       <dbl>   <dbl>
## 1 Control      7.69        6.89    2.38
## 2 PD           7.69        6.88    2.38
# Top 20 most variable proteins, labelled with gene symbols
raw_data %>%
  filter(COHORT %in% c("PD", "Control")) %>%
  group_by(TESTNAME) %>%
  summarise(variance = var(TESTVALUE, na.rm = TRUE), .groups = "drop") %>%
  arrange(desc(variance)) %>%
  slice_head(n = 20) %>%
  left_join(protein_lookup, by = c("TESTNAME" = "SOMA_SEQ_ID")) %>%
  mutate(protein_label = ifelse(is.na(TARGET_GENE_SYMBOL), TESTNAME, TARGET_GENE_SYMBOL)) %>%
  ggplot(aes(x = reorder(protein_label, variance), y = variance)) +
  geom_col(fill = "#4472C4", width = 0.7) +
  coord_flip() +
  labs(title = "Top 20 Most Variable Proteins",
       x = "", y = "Variance") +
  theme_minimal() +
  theme(plot.title  = element_text(face = "bold", size = 13),
        axis.text.y = element_text(size = 9))

glimpse(age_data)
## Rows: 46,227
## Columns: 3
## $ PATNO        <dbl> 3000, 3000, 3000, 3000, 3000, 3000, 3000, 3000, 3000, 300…
## $ EVENT_ID     <chr> "BL", "R17", "R18", "R19", "R20", "SC", "V01", "V02", "V0…
## $ AGE_AT_VISIT <dbl> 69.1, 80.5, 81.4, 84.0, 84.0, 69.1, 69.4, 69.6, 69.9, 70.…
age_col <- intersect(c("AGE", "AGE_AT_VISIT", "PATVISITAGE"), names(age_data))[1]
cat("Using age column:", age_col, "\n")
## Using age column: AGE_AT_VISIT
# Note: age_data has one row per visit, so the counts below are visits, not patients
age_cohort <- age_data %>%
  inner_join(raw_data %>% distinct(PATNO, COHORT), by = "PATNO") %>%
  rename(age = all_of(age_col)) %>%
  filter(!is.na(age), COHORT %in% c("PD", "Control"))
ggplot(age_cohort, aes(x = age, fill = COHORT)) +
  geom_histogram(bins = 30, alpha = 0.7, position = "identity", color = "white") +
  scale_fill_manual(values = c("Control" = "#4472C4", "PD" = "#C00000")) +
  labs(title = "Age Distribution by Cohort",
       x = "Age", y = "Count", fill = "") +
  theme_minimal() +
  theme(plot.title = element_text(face = "bold", size = 13),
        legend.position = "bottom")

ggplot(age_cohort, aes(x = COHORT, y = age, fill = COHORT)) +
  geom_boxplot(alpha = 0.7, show.legend = FALSE) +
  geom_jitter(width = 0.15, alpha = 0.3, size = 1) +
  scale_fill_manual(values = c("Control" = "#4472C4", "PD" = "#C00000")) +
  labs(title = "Age by Disease Status", x = "", y = "Age") +
  theme_minimal() +
  theme(plot.title = element_text(face = "bold", size = 13))

age_cohort %>%
  group_by(COHORT) %>%
  summarise(
    n          = n(),
    mean_age   = round(mean(age, na.rm = TRUE), 1),
    median_age = median(age, na.rm = TRUE),
    sd_age     = round(sd(age, na.rm = TRUE), 1),
    min_age    = min(age, na.rm = TRUE),
    max_age    = max(age, na.rm = TRUE)
  ) %>%
  print()
## # A tibble: 2 × 7
##   COHORT      n mean_age median_age sd_age min_age max_age
##   <chr>   <int>    <dbl>      <dbl>  <dbl>   <dbl>   <dbl>
## 1 Control  3582     65.4       66.5   12      30.6    98.1
## 2 PD      10961     66         66.8    9.8    32.3    93.6
glimpse(demographics)
## Rows: 8,351
## Columns: 29
## $ REC_ID       <chr> "IA86904", "IA86905", "IA86906", "IA86907", "IA86908", "2…
## $ PATNO        <dbl> 3000, 3001, 3002, 3003, 3004, 3005, 3006, 3007, 3008, 300…
## $ EVENT_ID     <chr> "TRANS", "TRANS", "TRANS", "TRANS", "TRANS", "TRANS", "TR…
## $ PAG_NAME     <chr> "SCREEN", "SCREEN", "SCREEN", "SCREEN", "SCREEN", "SCREEN…
## $ INFODT       <chr> "01/2011", "02/2011", "03/2011", "03/2011", "03/2011", "0…
## $ AFICBERB     <dbl> 0, 0, 0, 0, 0, NA, NA, NA, 0, 0, 0, NA, NA, 0, NA, NA, 0,…
## $ ASHKJEW      <dbl> 0, 0, 0, 0, 0, NA, NA, NA, 0, 0, 0, NA, NA, 0, NA, NA, 0,…
## $ BASQUE       <dbl> 0, 0, 0, 0, 0, NA, NA, NA, 0, 0, 0, NA, NA, 0, NA, NA, 0,…
## $ BIRTHDT      <chr> "12/1941", "01/1946", "08/1943", "07/1954", "11/1951", "1…
## $ SEX          <dbl> 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 1, 1, 1, 0, 1, 1, 1, 1, 0, …
## $ CHLDBEAR     <dbl> 0, NA, 0, 0, NA, 0, 0, NA, 0, 0, NA, NA, NA, 0, NA, NA, N…
## $ HOWLIVE      <dbl> NA, NA, NA, NA, NA, NA, NA, NA, NA, NA, NA, NA, NA, NA, N…
## $ GAYLES       <dbl> NA, NA, NA, NA, NA, NA, NA, NA, NA, NA, NA, NA, NA, NA, N…
## $ HETERO       <dbl> NA, NA, NA, NA, NA, NA, NA, NA, NA, NA, NA, NA, NA, NA, N…
## $ BISEXUAL     <dbl> NA, NA, NA, NA, NA, NA, NA, NA, NA, NA, NA, NA, NA, NA, N…
## $ PANSEXUAL    <dbl> NA, NA, NA, NA, NA, NA, NA, NA, NA, NA, NA, NA, NA, NA, N…
## $ ASEXUAL      <dbl> NA, NA, NA, NA, NA, NA, NA, NA, NA, NA, NA, NA, NA, NA, N…
## $ OTHSEXUALITY <dbl> NA, NA, NA, NA, NA, NA, NA, NA, NA, NA, NA, NA, NA, NA, N…
## $ HANDED       <dbl> 1, 2, 1, 1, 1, 3, 1, 2, 1, 1, 3, 1, 1, 3, 1, 1, 1, 1, 1, …
## $ HISPLAT      <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ RAASIAN      <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ RABLACK      <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, …
## $ RAHAWOPI     <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ RAINDALS     <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ RANOS        <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ RAWHITE      <dbl> 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 0, 1, …
## $ RAUNKNOWN    <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ ORIG_ENTRY   <chr> "01/2011", "02/2011", "03/2011", "03/2011", "03/2011", "0…
## $ LAST_UPDATE  <dttm> 2022-11-07 00:00:00, 2022-11-07 00:00:00, 2022-11-07 00:…
# SEX coding: 0 = Female, 1 = Male
demo_cohort <- demographics %>%
  inner_join(raw_data %>% distinct(PATNO, COHORT), by = "PATNO") %>%
  distinct(PATNO, .keep_all = TRUE) %>%
  filter(COHORT %in% c("PD", "Control")) %>%
  mutate(SEX = factor(SEX, levels = c(0, 1),
                      labels = c("Female", "Male")))
demo_cohort %>%
  count(COHORT, SEX) %>%
  ggplot(aes(x = COHORT, y = n, fill = SEX)) +
  geom_col(position = "dodge", width = 0.6) +
  geom_text(aes(label = n), position = position_dodge(0.6),
            vjust = -0.5, size = 3.5, fontface = "bold") +
  scale_fill_manual(values = c("Male"   = "#4472C4",
                               "Female" = "#ED7D31")) +
  labs(title = "Sex Distribution by Cohort",
       x = "", y = "Count", fill = "") +
  theme_minimal() +
  theme(plot.title = element_text(face = "bold", size = 13),
        legend.position = "bottom")

demo_cohort %>%
  count(COHORT, SEX) %>%
  group_by(COHORT) %>%
  mutate(pct = round(100 * n / sum(n), 1)) %>%
  print()
## # A tibble: 4 × 4
## # Groups:   COHORT [2]
##   COHORT  SEX        n   pct
##   <chr>   <fct>  <int> <dbl>
## 1 Control Female    69  37.1
## 2 Control Male     117  62.9
## 3 PD      Female   244  39.5
## 4 PD      Male     373  60.5
mutation_cols <- c("LRRK2", "GBA", "SNCA", "PRKN", "PINK1")  # VPS35/PARK7 all 0
gen_cohort <- genetics %>%
  inner_join(raw_data %>% distinct(PATNO, COHORT), by = "PATNO") %>%
  distinct(PATNO, .keep_all = TRUE) %>%
  filter(COHORT %in% c("PD", "Control")) %>%
  mutate(across(all_of(mutation_cols), as.character))
gen_cohort %>%
  pivot_longer(cols = all_of(mutation_cols),
               names_to  = "gene",
               values_to = "status") %>%
  mutate(carrier = status != "0" & !is.na(status)) %>%
  group_by(COHORT, gene) %>%
  summarise(n_carriers = sum(carrier, na.rm = TRUE), .groups = "drop") %>%
  ggplot(aes(x = reorder(gene, -n_carriers), y = n_carriers, fill = COHORT)) +
  geom_col(position = "dodge", width = 0.6) +
  geom_text(aes(label = n_carriers), position = position_dodge(0.6),
            vjust = -0.5, size = 3.5, fontface = "bold") +
  scale_fill_manual(values = c("Control" = "#4472C4", "PD" = "#C00000")) +
  labs(title = "Number of Mutation Carriers by Gene and Cohort",
       x = "Gene", y = "Number of Carriers", fill = "") +
  theme_minimal() +
  theme(plot.title = element_text(face = "bold", size = 13),
        legend.position = "bottom")

For now I will only use PD and Control participants; the Prodromal cohort is excluded from the modelling.

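# Baseline age: arrange() puts "BL" before "SC", so distinct() keeps the
# BL visit when both exist and falls back to screening (SC) otherwise
age_baseline <- age_data %>%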
  filter(EVENT_ID %in% c("BL", "SC")) %>%
  arrange(PATNO, EVENT_ID) %>%
  distinct(PATNO, .keep_all = TRUE) %>%
  select(PATNO, AGE_AT_VISIT)
genetics_features <- genetics %>%
  mutate(
    LRRK2_carrier = as.integer(LRRK2 != "0" & !is.na(LRRK2)),
    GBA_carrier   = as.integer(GBA   != "0" & !is.na(GBA)),
    SNCA_carrier  = as.integer(as.character(SNCA) != "0" & !is.na(SNCA)),
    PRKN_carrier  = as.integer(as.character(PRKN) != "0" & !is.na(PRKN)),
    APOE_e4       = as.integer(grepl("E4", APOE, ignore.case = TRUE))
  ) %>%
  select(PATNO, LRRK2_carrier, GBA_carrier,
         SNCA_carrier, PRKN_carrier, APOE_e4)
cat("Patients with baseline age:", nrow(age_baseline), "\n")
## Patients with baseline age: 8118
cat("Patients with genetics:    ", nrow(genetics_features), "\n")
## Patients with genetics:     6265
data_binary <- raw_data %>%
  filter(COHORT %in% c("PD", "Control")) %>%
  mutate(outcome = factor(ifelse(COHORT == "PD", 1, 0),
                          levels = c(0, 1),
                          labels = c("Control", "PD")))
data_binary %>% count(outcome)
## # A tibble: 2 × 2
##   outcome       n
##   <fct>     <int>
## 1 Control  890010
## 2 PD      2952345
wide_data <- data_binary %>%
  select(PATNO, SEX, outcome, TESTNAME, TESTVALUE) %>%
  pivot_wider(
    names_from  = TESTNAME,
    values_from = TESTVALUE
  ) %>%
  left_join(age_baseline,      by = "PATNO") %>%
  left_join(genetics_features, by = "PATNO")
dim(wide_data)
## [1]  803 4794
glimpse(wide_data %>% select(PATNO, SEX, outcome,
                              AGE_AT_VISIT, LRRK2_carrier,
                              GBA_carrier, SNCA_carrier,
                              PRKN_carrier, APOE_e4))
## Rows: 803
## Columns: 9
## $ PATNO         <dbl> 53595, 3029, 3963, 3316, 3124, 3175, 4103, 3467, 3168, 5…
## $ SEX           <chr> "Female", "Female", "Female", "Male", "Male", "Female", …
## $ outcome       <fct> PD, Control, PD, Control, PD, PD, PD, PD, PD, PD, PD, PD…
## $ AGE_AT_VISIT  <dbl> 65.2, 66.3, 58.4, 74.7, 57.2, 57.2, 59.0, 66.9, 63.1, 67…
## $ LRRK2_carrier <int> 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0,…
## $ GBA_carrier   <int> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,…
## $ SNCA_carrier  <int> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,…
## $ PRKN_carrier  <int> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,…
## $ APOE_e4       <int> 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0,…
cat("Missing AGE_AT_VISIT:  ", sum(is.na(wide_data$AGE_AT_VISIT)),  "\n")
## Missing AGE_AT_VISIT:   0
cat("Missing LRRK2_carrier: ", sum(is.na(wide_data$LRRK2_carrier)), "\n")
## Missing LRRK2_carrier:  0
cat("LRRK2 carriers:        ", sum(wide_data$LRRK2_carrier == 1, na.rm = TRUE), "\n")
## LRRK2 carriers:         156
cat("GBA carriers:          ", sum(wide_data$GBA_carrier   == 1, na.rm = TRUE), "\n")
## GBA carriers:           77
cat("APOE E4 carriers:      ", sum(wide_data$APOE_e4       == 1, na.rm = TRUE), "\n")
## APOE E4 carriers:       197

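Next, I calculate the variance of each protein across all PD and Control participants, plot the variance distribution to pick a cutoff, and keep the 500 most variable proteins. This filter does not use the outcome labels, so it does not leak disease-status information, although it is computed on all 803 participants rather than on the training rows only.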
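The variance filter can be sketched as a small helper that ranks columns by variance and keeps the top k. The function name `top_k_by_variance` is hypothetical, and the toy data frame stands in for `protein_cols`. Applying the same helper to the training rows only would keep the test rows out of the filter entirely.

```r
# Hypothetical helper: names of the k columns with the largest variance
top_k_by_variance <- function(df, k) {
  v <- vapply(df, var, numeric(1), na.rm = TRUE)   # per-column variance
  names(sort(v, decreasing = TRUE))[seq_len(min(k, length(v)))]
}

# Toy data: p1 is constant, p3 varies most
toy <- data.frame(p1 = c(1, 1, 1, 1),
                  p2 = c(0, 2, 0, 2),
                  p3 = c(0, 4, 0, 4))
top_k_by_variance(toy, 2)  # "p3" "p2"
```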

# Define clinical variables (excluded from protein variance calculation)
clinical_vars <- c("AGE_AT_VISIT", "LRRK2_carrier", "GBA_carrier",
                   "SNCA_carrier", "PRKN_carrier", "APOE_e4")

protein_cols <- wide_data %>%
  select(-PATNO, -SEX, -outcome, -any_of(clinical_vars))

protein_vars <- protein_cols %>%
  summarise(across(everything(), \(x) var(x, na.rm = TRUE))) %>%
  pivot_longer(everything(),
               names_to  = "protein",
               values_to = "variance") %>%
  arrange(desc(variance))
summary(protein_vars$variance)
##      Min.   1st Qu.    Median      Mean   3rd Qu.      Max. 
##  0.003911  0.025173  0.052631  0.122038  0.120528 12.827399
ggplot(protein_vars, aes(x = variance)) +
  geom_histogram(bins = 100, fill = "steelblue", color = "white") +
  geom_vline(xintercept = protein_vars$variance[500],
             color = "red", linetype = "dashed", linewidth = 1) +
  annotate("text",
           x     = protein_vars$variance[500],
           y     = Inf,
           label = "Top 500 cutoff",
           vjust = 2, hjust = -0.1, color = "red") +
  labs(title = "Protein Variance Distribution",
       x     = "Variance",
       y     = "Count") +
  theme_minimal()

top500_proteins <- protein_vars %>%
  slice_head(n = 500) %>%
  pull(protein)

data_filtered <- wide_data %>%
  select(PATNO, SEX, outcome,
         all_of(clinical_vars),
         all_of(top500_proteins))

dim(data_filtered)
## [1] 803 509
cat("Clinical vars present:",
    sum(clinical_vars %in% names(data_filtered)),
    "of", length(clinical_vars), "\n")
## Clinical vars present: 6 of 6

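I split the data into training (80%) and test (20%) sets, stratified by outcome, before running the t-tests. Because the t-test feature selection uses the outcome labels, performing it on the training set only prevents test-set information from leaking into the choice of proteins.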
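Stratifying on `outcome` keeps the Control:PD ratio roughly equal in both partitions. A minimal sketch on hypothetical toy data with a 25:75 imbalance (rsample samples `floor(n * prop)` rows from each stratum, which is consistent with the 148/493 training counts above):

```r
library(rsample)

# Hypothetical data with 25 Control and 75 PD rows
set.seed(1)
toy <- data.frame(
  outcome = factor(rep(c("Control", "PD"), times = c(25, 75)),
                   levels = c("Control", "PD")),
  x = rnorm(100)
)

split_toy <- initial_split(toy, prop = 0.80, strata = outcome)
train_toy <- training(split_toy)   # feature selection would use only these rows
test_toy  <- testing(split_toy)

table(train_toy$outcome)  # 20 Control, 60 PD
table(test_toy$outcome)   # 5 Control, 15 PD
```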

set.seed(42)
data_split_pre <- initial_split(
  data_filtered %>% select(-PATNO),
  prop   = 0.80,
  strata = outcome
)
train_pre <- training(data_split_pre)
test_pre  <- testing(data_split_pre)

cat("Pre-split training set:\n")
## Pre-split training set:
train_pre %>% count(outcome) %>% print()
## # A tibble: 2 × 2
##   outcome     n
##   <fct>   <int>
## 1 Control   148
## 2 PD        493
cat("\nPre-split test set:\n")
## 
## Pre-split test set:
test_pre %>% count(outcome) %>% print()
## # A tibble: 2 × 2
##   outcome     n
##   <fct>   <int>
## 1 Control    38
## 2 PD        124

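For each of the 500 proteins, I run a Welch two-sample t-test (R's default) comparing PD and Control on the training data only (no leakage), adjust the p-values with the Benjamini-Hochberg (BH) procedure, and show the results in a volcano plot of PD vs Control.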
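The BH adjustment applied with `p.adjust(p_value, method = "BH")` in the next chunk can be made concrete with a short hand-rolled version. Each p-value, ranked from smallest (rank 1) to largest (rank m), is scaled by m / rank, and a running minimum from the largest p-value downwards keeps the adjusted values monotone. The function name `bh_manual` is hypothetical; the check confirms it matches base R.

```r
# Hand-rolled Benjamini-Hochberg adjustment, for comparison with p.adjust()
bh_manual <- function(p) {
  m     <- length(p)
  o     <- order(p, decreasing = TRUE)          # largest p-value first
  ranks <- m:1                                  # rank of each sorted p-value
  adj   <- pmin(1, cummin(p[o] * m / ranks))    # scale, then enforce monotonicity
  adj[order(o)]                                 # restore the original order
}

p <- c(0.01, 0.04, 0.03, 0.20)
bh_manual(p)                 # 0.0400 0.0533 0.0533 0.2000
p.adjust(p, method = "BH")   # same values
```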

ttest_results <- train_pre %>%
  select(outcome, all_of(top500_proteins)) %>%
  pivot_longer(
    cols      = all_of(top500_proteins),
    names_to  = "protein",
    values_to = "value"
  ) %>%
  group_by(protein) %>%
  summarise(
    p_value   = t.test(value ~ outcome)$p.value,
    mean_PD   = mean(value[outcome == "PD"],      na.rm = TRUE),
    mean_ctrl = mean(value[outcome == "Control"],  na.rm = TRUE),
    log2FC    = mean_PD - mean_ctrl
  ) %>%
  mutate(
    p_adj = p.adjust(p_value, method = "BH"),
    sig   = p_adj < 0.05
  ) %>%
  arrange(p_adj)

ttest_results %>% count(sig)
## # A tibble: 2 × 2
##   sig       n
##   <lgl> <int>
## 1 FALSE   454
## 2 TRUE     46
head(ttest_results, 20)
## # A tibble: 20 × 7
##    protein           p_value mean_PD mean_ctrl log2FC      p_adj sig  
##    <chr>               <dbl>   <dbl>     <dbl>  <dbl>      <dbl> <lgl>
##  1 3504-58_2   0.00000000789    9.33      8.95  0.375 0.00000395 TRUE 
##  2 7757-5_3    0.0000000423     9.41      9.11  0.300 0.0000106  TRUE 
##  3 5080-131_3  0.00000437       9.68      9.48  0.203 0.000728   TRUE 
##  4 6521-35_3   0.00000716      11.7      11.9  -0.228 0.000895   TRUE 
##  5 14125-5_3   0.0000162        8.99      8.77  0.219 0.00162    TRUE 
##  6 13722-105_3 0.0000431       10.4      10.2   0.202 0.00214    TRUE 
##  7 2292-17_4   0.0000401        9.66      9.45  0.205 0.00214    TRUE 
##  8 2743-5_2    0.0000267       12.9      12.7   0.288 0.00214    TRUE 
##  9 2900-53_3   0.0000418        9.47      9.24  0.234 0.00214    TRUE 
## 10 4141-79_1   0.0000471       10.6      10.3   0.265 0.00214    TRUE 
## 11 8240-207_3  0.0000430       11.6      11.4   0.221 0.00214    TRUE 
## 12 6520-87_3   0.0000540       13.4      13.2   0.228 0.00225    TRUE 
## 13 3535-84_1   0.0000771        8.55      8.34  0.204 0.00297    TRUE 
## 14 13416-8_3   0.0000942       12.1      12.3  -0.233 0.00314    TRUE 
## 15 3060-43_2   0.0000911       11.2      11.0   0.187 0.00314    TRUE 
## 16 9900-36_3   0.000206         9.14      8.92  0.221 0.00642    TRUE 
## 17 3415-61_2   0.000238         8.70      8.23  0.469 0.00662    TRUE 
## 18 9796-4_3    0.000226        15.2      15.0   0.215 0.00662    TRUE 
## 19 8890-9_3    0.000352        11.7      11.9  -0.197 0.00926    TRUE 
## 20 9986-14_3   0.000414         9.36      9.16  0.206 0.0104     TRUE
ggplot(ttest_results, aes(x = log2FC, y = -log10(p_adj), color = sig)) +
  geom_point(alpha = 0.6, size = 1.2) +
  scale_color_manual(values = c("FALSE" = "grey60", "TRUE" = "#C00000"),
                     labels = c("Not significant", "Significant (BH < 0.05)")) +
  geom_hline(yintercept = -log10(0.05), linetype = "dashed", color = "black") +
  labs(title = "Volcano Plot: PD vs Control (t-test, training set)",
       x = "Log2 Fold Change (PD - Control)",
       y = "-log10(Adjusted p-value)",
       color = "") +
  theme_minimal() +
  theme(plot.title = element_text(face = "bold", size = 13),
        legend.position = "bottom")

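I select the proteins that remain significant after BH adjustment and apply the same feature set to the training and test sets. Because the training rows are stacked above the test rows in data_model, make_splits() can rebuild the original train/test split object for last_fit().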
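A small sketch of why `make_splits()` reproduces the original partition: when the training rows come first in the combined data, the analysis indices are simply `1:n_train` and the assessment indices are the remaining rows. The toy data frame and `n_train` below are hypothetical stand-ins for `data_model` and `nrow(train_data)`.

```r
library(rsample)

# Hypothetical stand-in for data_model: 8 training rows above 2 test rows
toy_model <- data.frame(id = 1:10, y = rnorm(10))
n_train   <- 8L

toy_split <- make_splits(
  list(analysis   = seq_len(n_train),
       assessment = seq(n_train + 1L, nrow(toy_model))),
  data = toy_model
)

training(toy_split)$id  # 1 2 3 4 5 6 7 8
testing(toy_split)$id   # 9 10
```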

sig_proteins <- ttest_results %>%
  filter(sig == TRUE) %>%
  pull(protein)
cat("Significant proteins:", length(sig_proteins), "\n")
## Significant proteins: 46
# Apply the same feature set to train and test
train_data <- train_pre %>%
  select(SEX, outcome, all_of(clinical_vars), all_of(sig_proteins))
test_data  <- test_pre  %>%
  select(SEX, outcome, all_of(clinical_vars), all_of(sig_proteins))

# Combine into one data frame and reconstruct the rsample split object
data_model <- bind_rows(train_data, test_data)
dim(data_model)
## [1] 803  54
data_model %>% count(outcome)
## # A tibble: 2 × 2
##   outcome     n
##   <fct>   <int>
## 1 Control   186
## 2 PD        617
data_model %>%
  select(all_of(clinical_vars)) %>%
  summarise(across(everything(), ~sum(is.na(.)))) %>%
  pivot_longer(everything(),
               names_to  = "variable",
               values_to = "n_missing") %>%
  mutate(pct_missing = round(100 * n_missing / nrow(data_model), 1)) %>%
  print()
## # A tibble: 6 × 3
##   variable      n_missing pct_missing
##   <chr>             <int>       <dbl>
## 1 AGE_AT_VISIT          0           0
## 2 LRRK2_carrier         0           0
## 3 GBA_carrier           0           0
## 4 SNCA_carrier          0           0
## 5 PRKN_carrier          0           0
## 6 APOE_e4               0           0
# Reconstruct rsample split object so last_fit() uses the correct rows
data_split <- make_splits(
  list(analysis   = seq_len(nrow(train_data)),
       assessment = seq(nrow(train_data) + 1, nrow(data_model))),
  data = data_model
)

saveRDS(data_model, file.path(data_dir, "data_model.rds"))
set.seed(42)
cv_folds <- vfold_cv(
  train_data,
  v      = 10,
  strata = outcome
)
cv_folds
## #  10-fold cross-validation using stratification 
## # A tibble: 10 × 2
##    splits           id    
##    <list>           <chr> 
##  1 <split [576/65]> Fold01
##  2 <split [576/65]> Fold02
##  3 <split [576/65]> Fold03
##  4 <split [577/64]> Fold04
##  5 <split [577/64]> Fold05
##  6 <split [577/64]> Fold06
##  7 <split [577/64]> Fold07
##  8 <split [577/64]> Fold08
##  9 <split [578/63]> Fold09
## 10 <split [578/63]> Fold10
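library(themis)  # provides step_downsample()

# Recipe: remove SEX, impute missing clinical + protein values (the first
# imputation step is already covered by the second), normalize, and downsample
# the majority class to address class imbalance (493 PD vs 148 Control in training)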
model_recipe <- recipe(outcome ~ ., data = train_data) %>%
  step_rm(SEX) %>%
  step_impute_median(all_of(clinical_vars)) %>%
  step_impute_median(all_predictors()) %>%
  step_normalize(all_predictors()) %>%
  step_downsample(outcome, under_ratio = 1)   # fix class imbalance

summary(model_recipe)
## # A tibble: 54 × 4
##    variable      type      role      source  
##    <chr>         <list>    <chr>     <chr>   
##  1 SEX           <chr [3]> predictor original
##  2 AGE_AT_VISIT  <chr [2]> predictor original
##  3 LRRK2_carrier <chr [2]> predictor original
##  4 GBA_carrier   <chr [2]> predictor original
##  5 SNCA_carrier  <chr [2]> predictor original
##  6 PRKN_carrier  <chr [2]> predictor original
##  7 APOE_e4       <chr [2]> predictor original
##  8 3504-58_2     <chr [2]> predictor original
##  9 7757-5_3      <chr [2]> predictor original
## 10 5080-131_3    <chr [2]> predictor original
## # ℹ 44 more rows
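`step_downsample()` is skipped by default when new data are baked, so only the data used to fit the model are balanced while the test set keeps its natural PD:Control ratio. A minimal sketch on hypothetical toy data (30 PD, 10 Control) showing both behaviours:

```r
library(recipes)
library(themis)

# Hypothetical imbalanced training data
set.seed(42)
toy_train <- data.frame(
  outcome = factor(rep(c("Control", "PD"), times = c(10, 30)),
                   levels = c("Control", "PD")),
  x = rnorm(40)
)

toy_rec <- recipe(outcome ~ x, data = toy_train) %>%
  step_normalize(all_predictors()) %>%
  step_downsample(outcome, under_ratio = 1) %>%
  prep()

# Processed training data: PD is downsampled to match Control
tr <- table(bake(toy_rec, new_data = NULL)$outcome)
tr   # 10 Control, 10 PD

# Baking new data (e.g. a test set) skips the downsampling step
te <- table(bake(toy_rec, new_data = toy_train)$outcome)
te   # 10 Control, 30 PD
```

One consequence worth noting: with `under_ratio = 1`, each resample fit discards roughly two thirds of the PD training rows.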
log_spec <- logistic_reg() %>%
  set_engine("glm") %>%
  set_mode("classification")
log_workflow <- workflow() %>%
  add_recipe(model_recipe) %>%
  add_model(log_spec)

lasso_spec <- logistic_reg(
  penalty = tune(),
  mixture = 1
) %>%
  set_engine("glmnet") %>%
  set_mode("classification")
lasso_workflow <- workflow() %>%
  add_recipe(model_recipe) %>%
  add_model(lasso_spec)

rf_spec <- rand_forest(
  mtry  = tune(),
  trees = 500,
  min_n = tune()
) %>%
  set_engine("ranger", importance = "permutation") %>%
  set_mode("classification")
rf_workflow <- workflow() %>%
  add_recipe(model_recipe) %>%
  add_model(rf_spec)

log_workflow
## ══ Workflow ════════════════════════════════════════════════════════════════════
## Preprocessor: Recipe
## Model: logistic_reg()
## 
## ── Preprocessor ────────────────────────────────────────────────────────────────
## 5 Recipe Steps
## 
## • step_rm()
## • step_impute_median()
## • step_impute_median()
## • step_normalize()
## • step_downsample()
## 
## ── Model ───────────────────────────────────────────────────────────────────────
## Logistic Regression Model Specification (classification)
## 
## Computational engine: glm
lasso_workflow
## ══ Workflow ════════════════════════════════════════════════════════════════════
## Preprocessor: Recipe
## Model: logistic_reg()
## 
## ── Preprocessor ────────────────────────────────────────────────────────────────
## 5 Recipe Steps
## 
## • step_rm()
## • step_impute_median()
## • step_impute_median()
## • step_normalize()
## • step_downsample()
## 
## ── Model ───────────────────────────────────────────────────────────────────────
## Logistic Regression Model Specification (classification)
## 
## Main Arguments:
##   penalty = tune()
##   mixture = 1
## 
## Computational engine: glmnet
rf_workflow
## ══ Workflow ════════════════════════════════════════════════════════════════════
## Preprocessor: Recipe
## Model: rand_forest()
## 
## ── Preprocessor ────────────────────────────────────────────────────────────────
## 5 Recipe Steps
## 
## • step_rm()
## • step_impute_median()
## • step_impute_median()
## • step_normalize()
## • step_downsample()
## 
## ── Model ───────────────────────────────────────────────────────────────────────
## Random Forest Model Specification (classification)
## 
## Main Arguments:
##   mtry = tune()
##   trees = 500
##   min_n = tune()
## 
## Engine-Specific Arguments:
##   importance = permutation
## 
## Computational engine: ranger
# ── Tune LASSO ────────────────────────────────────────────────────────────────
lasso_grid <- grid_regular(
  penalty(range = c(-4, 0)),
  levels = 50
)
set.seed(42)
lasso_tuned <- tune_grid(
  lasso_workflow,
  resamples = cv_folds,
  grid      = lasso_grid,
  metrics   = metric_set(roc_auc, accuracy, sens, spec),
  control   = control_grid(save_pred = TRUE)
)
## Warning: Using `all_of()` outside of a selecting function was deprecated in tidyselect
## 1.2.0.
## ℹ See details at
##   <https://tidyselect.r-lib.org/reference/faq-selection-context.html>
## This warning is displayed once every 8 hours.
## Call `lifecycle::last_lifecycle_warnings()` to see where this warning was
## generated.
lasso_best <- select_best(lasso_tuned, metric = "roc_auc")
cat("Best LASSO penalty:", lasso_best$penalty, "\n")
## Best LASSO penalty: 0.009102982
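`penalty()` in dials is defined on a log10 scale, so `range = c(-4, 0)` with 50 levels should span 10^-4 to 1 in evenly spaced log steps. Assuming that spacing, a base-R reconstruction places the selected penalty at the 25th of the 50 candidates. The optimum therefore sits near the middle of the searched range rather than at an edge, which suggests the range was wide enough.

```r
# Sketch: reconstruct the LASSO penalty grid assuming even spacing on log10.
penalty_grid <- 10^seq(-4, 0, length.out = 50)

range(penalty_grid)
which.min(abs(penalty_grid - 0.009102982))   # 25
```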
# ── Tune Random Forest ────────────────────────────────────────────────────────
rf_grid <- grid_regular(
  mtry(range  = c(2, 20)),
  min_n(range = c(2, 20)),
  levels = 5
)
set.seed(42)
rf_tuned <- tune_grid(
  rf_workflow,
  resamples = cv_folds,
  grid      = rf_grid,
  metrics   = metric_set(roc_auc, accuracy, sens, spec),
  control   = control_grid(save_pred = TRUE)
)
rf_best <- select_best(rf_tuned, metric = "roc_auc")
cat("Best RF mtry:", rf_best$mtry, "min_n:", rf_best$min_n, "\n")
## Best RF mtry: 11 min_n: 2
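`grid_regular()` crosses the levels of each parameter, so 5 levels of `mtry` and 5 of `min_n` give 25 candidate settings, each fit on all 10 folds. The values below are evenly spaced placeholders; dials rounds integer parameters and its exact grid values may differ, but the counts do not depend on that. Note also that the selected `min_n = 2` lies on the lower boundary of the searched range.

```r
# Sketch: size of the random-forest tuning run (placeholder grid values).
rf_candidates <- expand.grid(
  mtry  = round(seq(2, 20, length.out = 5)),
  min_n = round(seq(2, 20, length.out = 5))
)
n_folds <- 10
n_trees <- 500

nrow(rf_candidates)                      # 25 candidate settings
nrow(rf_candidates) * n_folds            # 250 model fits during tuning
nrow(rf_candidates) * n_folds * n_trees  # 125000 trees grown in total
```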
# ── Logistic Regression CV ────────────────────────────────────────────────────
set.seed(42)
log_cv_results <- fit_resamples(
  log_workflow,
  resamples = cv_folds,
  metrics   = metric_set(roc_auc, accuracy, sens, spec),
  control   = control_resamples(save_pred = TRUE)
)
## → A | warning: prediction from rank-deficient fit; attr(*, "non-estim") has doubtful cases
## There were issues with some computations   A: x4
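The "rank-deficient fit" warning from the unpenalised logistic regression means the design matrix was not of full rank in some folds. Possible causes include collinear predictors or too few rows per predictor after downsampling; the output above does not show which applies here. The minimal base-R sketch below, using hypothetical variables `x1` and `x2`, shows the mechanism: `glm()` cannot separate an exactly collinear column and reports its coefficient as `NA`. Penalised fits such as the LASSO do not fail in this way.

```r
# Sketch: an exactly collinear predictor makes glm() rank-deficient.
set.seed(1)
n  <- 100
x1 <- rnorm(n)
x2 <- 2 * x1                       # exact linear copy of x1 (hypothetical)
y  <- rbinom(n, size = 1, prob = plogis(x1))

fit <- glm(y ~ x1 + x2, family = binomial)
coef(fit)                          # the x2 coefficient is NA (aliased with x1)
```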
# ── Tuning plots ──────────────────────────────────────────────────────────────
autoplot(lasso_tuned, metric = "roc_auc") +
  labs(title = "LASSO Tuning: ROC AUC vs Penalty") +
  theme_minimal() +
  theme(plot.title = element_text(face = "bold", size = 13))

autoplot(rf_tuned, metric = "roc_auc") +
  labs(title = "Random Forest Tuning: ROC AUC vs Hyperparameters") +
  theme_minimal() +
  theme(plot.title = element_text(face = "bold", size = 13))

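# ── Cross-validation comparison ───────────────────────────────────────────────
# Outcome levels are c("Control", "PD"), so yardstick treats Control as the event:
# "sens" is the recall of Controls and "spec" is the recall of PD cases.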
log_metrics   <- collect_metrics(log_cv_results) %>% mutate(model = "Logistic")
lasso_metrics <- collect_metrics(lasso_tuned) %>%
  filter(penalty == lasso_best$penalty) %>%
  mutate(model = "LASSO")
rf_metrics    <- collect_metrics(rf_tuned) %>%
  filter(mtry == rf_best$mtry, min_n == rf_best$min_n) %>%
  mutate(model = "Random Forest")

cv_comparison <- bind_rows(log_metrics, lasso_metrics, rf_metrics) %>%
  select(model, .metric, mean, std_err) %>%
  filter(.metric %in% c("roc_auc", "accuracy", "sens", "spec")) %>%
  arrange(.metric, desc(mean))
print(cv_comparison, n = 50)
## # A tibble: 12 × 4
##    model         .metric   mean std_err
##    <chr>         <chr>    <dbl>   <dbl>
##  1 Random Forest accuracy 0.705  0.0165
##  2 Logistic      accuracy 0.694  0.0191
##  3 LASSO         accuracy 0.671  0.0219
##  4 LASSO         roc_auc  0.810  0.0221
##  5 Random Forest roc_auc  0.795  0.0195
##  6 Logistic      roc_auc  0.777  0.0162
##  7 LASSO         sens     0.811  0.0386
##  8 Random Forest sens     0.736  0.0504
##  9 Logistic      sens     0.729  0.0326
## 10 Random Forest spec     0.696  0.0190
## 11 Logistic      spec     0.683  0.0252
## 12 LASSO         spec     0.629  0.0245
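The LASSO has the highest mean cross-validated ROC AUC, but the gaps between models are small relative to their standard errors. A rough check using the values printed above (mean plus or minus 2 SE) shows that all three intervals overlap. Because every model was evaluated on the same folds, this check ignores the correlation between models; a paired, fold-by-fold comparison would be sharper.

```r
# Rough overlap check on cross-validated ROC AUC (values copied from the table above).
cv_auc <- data.frame(
  model   = c("LASSO", "Random Forest", "Logistic"),
  mean    = c(0.810, 0.795, 0.777),
  std_err = c(0.0221, 0.0195, 0.0162)
)
cv_auc$lower <- cv_auc$mean - 2 * cv_auc$std_err
cv_auc$upper <- cv_auc$mean + 2 * cv_auc$std_err
cv_auc

# TRUE when every interval overlaps every other one
max(cv_auc$lower) < min(cv_auc$upper)
```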
cv_comparison %>%
  mutate(.metric = recode(.metric,
    "roc_auc"  = "ROC AUC",
    "accuracy" = "Accuracy",
    "sens"     = "Sensitivity (Control recall)",
    "spec"     = "Specificity (PD recall)"
  )) %>%
  ggplot(aes(x = model, y = mean, fill = model)) +
  geom_col(width = 0.6, show.legend = FALSE) +
  geom_errorbar(aes(ymin = mean - std_err, ymax = mean + std_err),
                width = 0.2) +
  geom_text(aes(label = round(mean, 3)), vjust = -0.5, size = 3.5) +
  facet_wrap(~ .metric, scales = "free_y") +
  scale_fill_manual(values = c("Logistic"      = "#4472C4",
                               "LASSO"         = "#ED7D31",
                               "Random Forest" = "#70AD47")) +
  labs(title = "Cross-Validation Performance by Model",
       x = "", y = "Mean Score") +
  theme_minimal() +
  theme(plot.title  = element_text(face = "bold", size = 13),
        axis.text.x = element_text(angle = 20, hjust = 1))

final_log_workflow   <- log_workflow
final_lasso_workflow <- lasso_workflow %>% finalize_workflow(lasso_best)
final_rf_workflow    <- rf_workflow    %>% finalize_workflow(rf_best)

set.seed(42)
log_final   <- last_fit(final_log_workflow,   data_split,
                        metrics = metric_set(roc_auc, accuracy, sens, spec))
## → A | warning: prediction from rank-deficient fit; attr(*, "non-estim") has doubtful cases
## There were issues with some computations   A: x2
lasso_final <- last_fit(final_lasso_workflow, data_split,
                        metrics = metric_set(roc_auc, accuracy, sens, spec))
rf_final    <- last_fit(final_rf_workflow,    data_split,
                        metrics = metric_set(roc_auc, accuracy, sens, spec))

# ── Test-set performance table ────────────────────────────────────────────────
test_metrics <- bind_rows(
  collect_metrics(log_final)   %>% mutate(model = "Logistic"),
  collect_metrics(lasso_final) %>% mutate(model = "LASSO"),
  collect_metrics(rf_final)    %>% mutate(model = "Random Forest")
) %>%
  select(model, .metric, .estimate) %>%
  arrange(.metric, desc(.estimate))
print(test_metrics, n = 50)
## # A tibble: 12 × 3
##    model         .metric  .estimate
##    <chr>         <chr>        <dbl>
##  1 Logistic      accuracy     0.722
##  2 Random Forest accuracy     0.716
##  3 LASSO         accuracy     0.704
##  4 Logistic      roc_auc      0.812
##  5 LASSO         roc_auc      0.808
##  6 Random Forest roc_auc      0.778
##  7 LASSO         sens         0.842
##  8 Logistic      sens         0.789
##  9 Random Forest sens         0.711
## 10 Random Forest spec         0.718
## 11 Logistic      spec         0.702
## 12 LASSO         spec         0.661
test_metrics %>%
  mutate(.metric = recode(.metric,
    "roc_auc"  = "ROC AUC",
    "accuracy" = "Accuracy",
    "sens"     = "Sensitivity (Control recall)",
    "spec"     = "Specificity (PD recall)"
  )) %>%
  ggplot(aes(x = model, y = .estimate, fill = model)) +
  geom_col(width = 0.6, show.legend = FALSE) +
  geom_text(aes(label = round(.estimate, 3)), vjust = -0.5, size = 3.5) +
  facet_wrap(~ .metric, scales = "free_y") +
  scale_fill_manual(values = c("Logistic"      = "#4472C4",
                               "LASSO"         = "#ED7D31",
                               "Random Forest" = "#70AD47")) +
  labs(title = "Test Set Performance by Model", x = "", y = "Score") +
  theme_minimal() +
  theme(plot.title  = element_text(face = "bold", size = 13),
        axis.text.x = element_text(angle = 20, hjust = 1))

# ── Confusion matrices ────────────────────────────────────────────────────────
cat("=== Logistic Regression ===\n")
## === Logistic Regression ===
collect_predictions(log_final) %>%
  conf_mat(truth = outcome, estimate = .pred_class) %>%
  print()
##           Truth
## Prediction Control PD
##    Control      30 37
##    PD            8 87
cat("\n=== LASSO ===\n")
## 
## === LASSO ===
collect_predictions(lasso_final) %>%
  conf_mat(truth = outcome, estimate = .pred_class) %>%
  print()
##           Truth
## Prediction Control PD
##    Control      32 42
##    PD            6 82
cat("\n=== Random Forest ===\n")
## 
## === Random Forest ===
collect_predictions(rf_final) %>%
  conf_mat(truth = outcome, estimate = .pred_class) %>%
  print()
##           Truth
## Prediction Control PD
##    Control      27 35
##    PD           11 89
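Recomputing the test-set metrics from these confusion matrices confirms which class each metric describes. The first factor level, Control, is the event, so `sens` is the fraction of Controls correctly identified and `spec` is the fraction of PD cases correctly identified. Read this way, the LASSO's top "sensitivity" (0.842) is the best Control recall. For recognising PD cases the Random Forest is highest (89/124 = 0.718) and the LASSO lowest (82/124 = 0.661).

```r
# Recompute test-set metrics from the confusion matrices printed above.
make_cm <- function(cc, pc, cp, pp) {
  # cc = pred Control / truth Control, pc = pred PD / truth Control,
  # cp = pred Control / truth PD,      pp = pred PD / truth PD
  matrix(c(cc, pc, cp, pp), nrow = 2,
         dimnames = list(Prediction = c("Control", "PD"),
                         Truth      = c("Control", "PD")))
}
cms <- list(
  Logistic        = make_cm(30, 8, 37, 87),
  LASSO           = make_cm(32, 6, 42, 82),
  `Random Forest` = make_cm(27, 11, 35, 89)
)

cm_metrics <- function(m) {
  c(sens_control = m["Control", "Control"] / sum(m[, "Control"]),  # Control recall
    spec_pd      = m["PD", "PD"] / sum(m[, "PD"]),                 # PD recall
    accuracy     = sum(diag(m)) / sum(m))
}
res <- t(sapply(cms, cm_metrics))
round(res, 3)
```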
# ── ROC curves ────────────────────────────────────────────────────────────────
log_roc   <- collect_predictions(log_final)   %>% mutate(model = "Logistic")
lasso_roc <- collect_predictions(lasso_final) %>% mutate(model = "LASSO")
rf_roc    <- collect_predictions(rf_final)    %>% mutate(model = "Random Forest")

bind_rows(log_roc, lasso_roc, rf_roc) %>%
  group_by(model) %>%
  roc_curve(truth = outcome, .pred_Control) %>%
  ggplot(aes(x = 1 - specificity, y = sensitivity, color = model)) +
  geom_line(linewidth = 1.2) +
  geom_abline(linetype = "dashed", color = "grey50") +
  scale_color_manual(values = c("Logistic"      = "#4472C4",
                                "LASSO"         = "#ED7D31",
                                "Random Forest" = "#70AD47")) +
  labs(title    = "ROC Curves on Test Set (Control as event class)",
       x        = "1 - Specificity (FPR: PD cases predicted as Control)",
       y        = "Sensitivity (TPR: Controls correctly identified)",
       color    = "Model") +
  theme_minimal() +
  theme(plot.title      = element_text(face = "bold", size = 13),
        legend.position = "bottom")

# ── LASSO: Top 20 coefficients with gene symbols ──────────────────────────────
lasso_final %>%
  extract_fit_parsnip() %>%
  tidy() %>%
  filter(term != "(Intercept)", estimate != 0) %>%
  arrange(desc(abs(estimate))) %>%
  slice_head(n = 20) %>%
  left_join(protein_lookup, by = c("term" = "SOMA_SEQ_ID")) %>%
  mutate(protein_label = ifelse(is.na(TARGET_GENE_SYMBOL), term, TARGET_GENE_SYMBOL)) %>%
  ggplot(aes(x = reorder(protein_label, abs(estimate)), y = estimate,
             fill = estimate > 0)) +
  geom_col(show.legend = FALSE) +
  coord_flip() +
  scale_fill_manual(values = c("TRUE" = "#C00000", "FALSE" = "#4472C4")) +
  labs(title    = "LASSO: Top 20 Coefficients",
       subtitle = "Standardized predictors; positive (red) = higher odds of PD",
       x = "", y = "Coefficient (log-odds per SD)") +
  theme_minimal() +
  theme(plot.title  = element_text(face = "bold", size = 13),
        axis.text.y = element_text(size = 8))
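Because `step_normalize()` z-scores every predictor, each LASSO coefficient is a change in log-odds per one standard deviation of that predictor, and `exp(coefficient)` is an odds ratio per SD. With outcome levels Control then PD, `glm()` models the probability of the second level. glmnet's binomial family, to our understanding of its documentation, likewise targets the last level. A positive coefficient therefore points toward PD. The sketch below checks the direction on synthetic data with a hypothetical protein that is higher in PD. It uses an unpenalised fit; LASSO estimates are shrunk toward zero, so their odds ratios are conservative.

```r
# Sketch with synthetic data: glm() on a Control/PD factor models P(PD).
set.seed(7)
n         <- 1000
outcome   <- factor(sample(c("Control", "PD"), n, replace = TRUE),
                    levels = c("Control", "PD"))
protein   <- rnorm(n, mean = ifelse(outcome == "PD", 1, 0))  # hypothetical protein, higher in PD
protein_z <- (protein - mean(protein)) / sd(protein)          # z-score, as step_normalize() does

fit  <- glm(outcome ~ protein_z, family = binomial)
beta <- coef(fit)[["protein_z"]]
c(beta = beta, odds_ratio_per_sd = exp(beta))   # positive beta, odds ratio > 1
```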

# ── Random Forest: Top 20 variable importance with gene symbols ───────────────
rf_imp <- rf_final %>%
  extract_fit_parsnip() %>%
  vi() %>%
  slice_head(n = 20) %>%
  left_join(protein_lookup, by = c("Variable" = "SOMA_SEQ_ID")) %>%
  mutate(protein_label = ifelse(is.na(TARGET_GENE_SYMBOL), Variable, TARGET_GENE_SYMBOL))

ggplot(rf_imp, aes(x = reorder(protein_label, Importance), y = Importance)) +
  geom_col(fill = "#70AD47", show.legend = FALSE) +
  coord_flip() +
  labs(title = "Random Forest: Top 20 Variable Importance",
       x = "", y = "Permutation Importance") +
  theme_minimal() +
  theme(plot.title  = element_text(face = "bold", size = 13),
        axis.text.y = element_text(size = 8))
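ranger's `importance = "permutation"` measures how much prediction error rises when a predictor's values are shuffled in each tree's out-of-bag samples. The model-agnostic sketch below, with a hypothetical helper `perm_importance()`, is a simplified stand-in. It uses a `glm()` fit and accuracy on the full data rather than per-tree out-of-bag error. One caveat carries over: when proteins are strongly correlated, shuffling one of them leaves its partners intact, so their importance can look smaller than their joint contribution.

```r
# Simplified, model-agnostic permutation importance (not ranger's OOB version):
# shuffle one predictor at a time and record the mean drop in accuracy.
perm_importance <- function(fit, data, outcome, vars, n_rep = 20, seed = 1) {
  set.seed(seed)
  lv  <- levels(data[[outcome]])
  acc <- function(d) {
    p    <- predict(fit, newdata = d, type = "response")   # P(second level)
    pred <- ifelse(p > 0.5, lv[2], lv[1])
    mean(pred == as.character(d[[outcome]]))
  }
  base_acc <- acc(data)
  sapply(vars, function(v) {
    drops <- replicate(n_rep, {
      d      <- data
      d[[v]] <- sample(d[[v]])   # break the link between v and the outcome
      base_acc - acc(d)
    })
    mean(drops)
  })
}

# Hypothetical data: 'signal' shifts with class, 'noise' does not
set.seed(3)
n   <- 600
dat <- data.frame(outcome = factor(sample(c("Control", "PD"), n, replace = TRUE),
                                   levels = c("Control", "PD")))
dat$signal <- rnorm(n, mean = ifelse(dat$outcome == "PD", 1.5, 0))
dat$noise  <- rnorm(n)

fit <- glm(outcome ~ signal + noise, data = dat, family = binomial)
perm_importance(fit, dat, "outcome", c("signal", "noise"))
```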