Parkinson’s disease (PD) is a progressive neurodegenerative disorder affecting millions of people worldwide, characterized by motor symptoms such as tremor, rigidity, and bradykinesia, as well as non-motor manifestations including cognitive decline and autonomic dysfunction. One of the major clinical challenges in Parkinson’s disease is the lack of a reliable, objective set of biomarkers for early diagnosis. Patients are usually diagnosed only after significant neurodegeneration has already occurred, based on physical examination or on lesions observed in specific brain regions.
This project aims to investigate whether cerebrospinal fluid (CSF) proteomics data can be used to accurately predict Parkinson’s disease status. Specifically: can protein expression profiles from CSF samples reliably distinguish Parkinson’s disease patients from healthy controls, and which proteins are most predictive of disease status?
The data used in this project comes from the Parkinson’s Progression Markers Initiative (PPMI), a landmark longitudinal observational study sponsored by the Michael J. Fox Foundation. PPMI was designed to identify biomarkers of Parkinson’s disease onset and progression by recruiting both PD patients and healthy controls and collecting biological samples, imaging, and clinical assessments over time. https://www.ppmi-info.org/
Clinical Data:
library(tidyverse) # read_csv(), dplyr/tidyr verbs, purrr::map_dfr(), ggplot2
clinical_dir <- "/Users/ramazanegesolak/PPMI_STRING"
age_data <- read_csv(file.path(clinical_dir, "Age_at_visit_12Apr2026.csv"), show_col_types = FALSE)
clinical_dx <- read_csv(file.path(clinical_dir, "Clinical_Diagnosis_12Apr2026.csv"), show_col_types = FALSE)
demographics <- read_csv(file.path(clinical_dir, "Demographics_12Apr2026.csv"), show_col_types = FALSE)
genetics <- read_csv(file.path(clinical_dir, "iu_genetic_consensus_20251025_12Apr2026.csv"), show_col_types = FALSE)
## Warning: One or more parsing issues, call `problems()` on your data frame for details,
## e.g.:
## dat <- vroom(...)
## problems(dat)
participant_status <- read_csv(file.path(clinical_dir, "Participant_Status_12Apr2026.csv"), show_col_types = FALSE)
# Quick check all loaded correctly
cat("age_data: ", nrow(age_data), "rows\n")
## age_data: 46227 rows
cat("clinical_dx: ", nrow(clinical_dx), "rows\n")
## clinical_dx: 15440 rows
cat("demographics: ", nrow(demographics), "rows\n")
## demographics: 8351 rows
cat("genetics: ", nrow(genetics), "rows\n")
## genetics: 6265 rows
cat("participant_status:", nrow(participant_status), "rows\n")
## participant_status: 8417 rows
Outcome variable: Parkinson’s disease diagnosis (COHORT): binary, PD patient vs. control
Primary predictor variable: CSF protein expression levels (TESTVALUE), indexed by protein name (TESTNAME)
Possible confounders: SEX (protein expression levels are known to vary by sex, and PD prevalence differs between males and females); PLATEID (batch-to-batch variability across assay plates can introduce systematic differences in protein measurements unrelated to disease status); AGE_AT_VISIT (age is strongly associated with both PD risk and protein expression levels); genetic mutation status from iu_genetic_consensus (e.g. LRRK2 and GBA, which may independently affect protein expression)
Potential effect modifiers: SEX may act as an effect modifier if the relationship between specific protein expression levels and PD diagnosis differs between males and females
library(ggdag)
library(ggplot2)
dag <- dagify(
Outcome ~ Predictor + SEX + AGE + Genetics,
Predictor ~ SEX + AGE + Genetics + Batch,
coords = list(
x = c(SEX = 1, AGE = 1, Genetics = 1,
Batch = 1, Predictor = 2.8, Outcome = 4.5),
y = c(SEX = 1.5, AGE = 0.5, Genetics = -0.5,
Batch = -1.5, Predictor = 0, Outcome = 0)
)
)
ggdag(dag, text = FALSE) +
geom_dag_point(color = "#4472C4", size = 22) +
geom_dag_text(color = "white", size = 3.2, fontface = "bold") +
theme_dag() +
labs(
title = "DAG: CSF Proteomics and Parkinson's Disease Prediction",
caption = paste(
"Outcome = PD Diagnosis (COHORT)",
"Predictor = CSF Protein Expression (TESTVALUE)",
"SEX = Biological sex — affects PD prevalence & protein levels",
"AGE = Age at visit — risk factor for PD & protein expression",
"Genetics = Mutation carrier status (LRRK2, GBA, SNCA, PRKN, APOE-e4)",
"Batch = Assay plate (PLATEID) — technical measurement confounder",
sep = "\n"
)
) +
theme(
plot.title = element_text(face = "bold", size = 13, hjust = 0.5),
plot.caption = element_text(size = 8.5, color = "gray30", hjust = 0,
margin = margin(t = 12), family = "mono")
)
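The adjustment set implied by a DAG can be read off programmatically. The sketch below is an illustration rather than part of the original analysis: it assumes the dagitty package is installed (ggdag builds on it) and re-types the same edges in dagitty syntax.

```r
library(dagitty)

# Same edges as the DAG above, written in dagitty syntax
g <- dagitty("dag {
  SEX -> Outcome
  AGE -> Outcome
  Genetics -> Outcome
  Predictor -> Outcome
  SEX -> Predictor
  AGE -> Predictor
  Genetics -> Predictor
  Batch -> Predictor
}")

# Minimal set(s) of covariates that close all back-door paths
sets <- adjustmentSets(g, exposure = "Predictor", outcome = "Outcome")
print(sets)
```

Under the graph as drawn, the minimal set contains SEX, AGE, and Genetics. Batch is not part of it because its only arrow points into the predictor.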
Because our outcome variable is binary, to answer the question of interest, three classification approaches will be built and compared: Logistic Regression, LASSO, and Random Forest. By comparing these models, this project will evaluate both predictive accuracy and the interpretability of selected biomarkers.
To control for potential confounding, sex will be included as a biological covariate across all three models, as protein expression levels are known to vary by sex and Parkinson’s disease prevalence differs between males and females. PLATEID will be examined and adjusted for as a technical confounder, since batch-to-batch variability across assay plates can introduce systematic differences in protein measurements unrelated to disease status. Protein expression values will also be normalized prior to modeling to further reduce technical variation. Model performance will be evaluated using AUC-ROC, accuracy and a confusion matrix.
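As a reminder of what the evaluation metrics measure, here is a minimal base R sketch on six made-up predictions. The values are invented for illustration. The AUC is computed through the rank-sum (Mann-Whitney) identity rather than with the yardstick functions used later.

```r
# Toy data: truth (1 = PD, 0 = Control) and predicted probability of PD
truth <- c(1, 1, 1, 0, 0, 0)
prob <- c(0.90, 0.80, 0.35, 0.60, 0.20, 0.10)

# Confusion matrix and accuracy at a 0.5 threshold
pred <- as.integer(prob >= 0.5)
conf_mat <- table(truth = truth, predicted = pred)
accuracy <- sum(diag(conf_mat)) / sum(conf_mat)

# AUC: probability that a random PD case is scored above a random control
n_pos <- sum(truth == 1)
n_neg <- sum(truth == 0)
auc <- (sum(rank(prob)[truth == 1]) - n_pos * (n_pos + 1) / 2) / (n_pos * n_neg)

print(conf_mat)
cat("accuracy:", round(accuracy, 3), " AUC:", round(auc, 3), "\n")
```

Accuracy depends on the chosen threshold, whereas AUC summarizes ranking quality across all thresholds, which is why both are reported.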
Exploratory data analysis:
wanted_cols <- c(
"PATNO", "SEX", "COHORT", "TESTNAME",
"TESTVALUE", "PLATEID"
)
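# file_list: character vector of paths to the CSF proteomics CSV files (defined elsewhere, not shown here)
raw_data <- file_list %>%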
map_dfr(~read_csv(.x, show_col_types = FALSE)) %>%
select(any_of(wanted_cols))
glimpse(raw_data)
## Rows: 5,541,030
## Columns: 6
## $ PATNO <dbl> 53595, 53595, 53595, 53595, 53595, 53595, 53595, 53595, 5359…
## $ SEX <chr> "Female", "Female", "Female", "Female", "Female", "Female", …
## $ COHORT <chr> "PD", "PD", "PD", "PD", "PD", "PD", "PD", "PD", "PD", "PD", …
## $ TESTNAME <chr> "5632-6_3", "5631-83_3", "5630-48_3", "5629-58_3", "5628-21_…
## $ TESTVALUE <dbl> 10.182013, 6.346285, 12.943830, 5.891402, 12.836350, 7.03480…
## $ PLATEID <chr> "P0022899", "P0022899", "P0022899", "P0022899", "P0022899", …
head(raw_data, 20)
## # A tibble: 20 × 6
## PATNO SEX COHORT TESTNAME TESTVALUE PLATEID
## <dbl> <chr> <chr> <chr> <dbl> <chr>
## 1 53595 Female PD 5632-6_3 10.2 P0022899
## 2 53595 Female PD 5631-83_3 6.35 P0022899
## 3 53595 Female PD 5630-48_3 12.9 P0022899
## 4 53595 Female PD 5629-58_3 5.89 P0022899
## 5 53595 Female PD 5628-21_3 12.8 P0022899
## 6 53595 Female PD 5627-53_3 7.03 P0022899
## 7 53595 Female PD 5626-20_3 5.89 P0022899
## 8 53595 Female PD 5624-66_3 7.66 P0022899
## 9 53595 Female PD 5623-11_3 6.13 P0022899
## 10 53595 Female PD 5621-64_3 5.95 P0022899
## 11 53595 Female PD 5620-13_3 10.1 P0022899
## 12 53595 Female PD 5618-50_3 13.9 P0022899
## 13 53595 Female PD 5617-41_3 8.84 P0022899
## 14 53595 Female PD 5615-62_3 6.25 P0022899
## 15 53595 Female PD 5614-44_3 6.73 P0022899
## 16 53595 Female PD 5613-75_3 5.87 P0022899
## 17 53595 Female PD 5612-16_3 6.20 P0022899
## 18 53595 Female PD 5611-56_3 6.02 P0022899
## 19 53595 Female PD 5610-32_3 6.38 P0022899
## 20 53595 Female PD 5609-92_3 14.1 P0022899
raw_data %>% count(COHORT)
## # A tibble: 3 × 2
## COHORT n
## <chr> <int>
## 1 Control 890010
## 2 PD 2952345
## 3 Prodromal 1698675
raw_data %>% summarise(
n_patients = n_distinct(PATNO),
n_proteins = n_distinct(TESTNAME),
n_rows = n()
)
## # A tibble: 1 × 3
## n_patients n_proteins n_rows
## <int> <int> <int>
## 1 1158 4785 5541030
# Check for duplicate patient-protein combinations
raw_data %>%
count(PATNO, TESTNAME) %>%
filter(n > 1) %>%
nrow()
## [1] 0
annotations <- read_csv(
file.path(data_dir, "PPMI_Project_151_pqtl_Analysis_Annotations_20210210.csv"),
show_col_types = FALSE
)
## New names:
## • `` -> `...3`
## • `` -> `...8`
## • `` -> `...9`
## • … (unnamed columns 10 through 60 renamed in the same way)
## • `` -> `...61`
protein_lookup <- annotations %>%
select(SOMA_SEQ_ID, TARGET_GENE_SYMBOL) %>%
distinct() %>%
filter(!is.na(TARGET_GENE_SYMBOL), TARGET_GENE_SYMBOL != "")
head(protein_lookup)
## # A tibble: 6 × 2
## SOMA_SEQ_ID TARGET_GENE_SYMBOL
## <chr> <chr>
## 1 10000-28_3 CRYBB2
## 2 10001-7_3 RAF1
## 3 10003-15_3 ZNF41
## 4 10006-25_3 ELK1
## 5 10008-43_3 GUCA1A
## 6 10009-2_3 IRF1
raw_data %>%
distinct(PATNO, COHORT) %>%
count(COHORT) %>%
ggplot(aes(x = reorder(COHORT, -n), y = n, fill = COHORT)) +
geom_col(width = 0.5, show.legend = FALSE) +
geom_text(aes(label = n), vjust = -0.5, fontface = "bold", size = 5) +
scale_fill_manual(values = c("Control" = "#4472C4", "PD" = "#C00000", "Prodromal" = "#7F7F7F")) +
labs(title = "Number of Patients by Cohort", x = "", y = "Count") +
theme_minimal() +
theme(plot.title = element_text(face = "bold", size = 13))
raw_data %>%
filter(COHORT %in% c("PD", "Control")) %>%
ggplot(aes(x = TESTVALUE, fill = COHORT)) +
geom_density(alpha = 0.5) +
scale_fill_manual(values = c("Control" = "#4472C4", "PD" = "#C00000")) +
labs(title = "Distribution of Protein Expression Values",
x = "Expression Value (log2)", y = "Density", fill = "") +
theme_minimal() +
theme(plot.title = element_text(face = "bold", size = 13),
legend.position = "bottom")
raw_data %>%
filter(COHORT %in% c("PD", "Control")) %>%
group_by(PATNO, COHORT) %>%
summarise(mean_expr = mean(TESTVALUE, na.rm = TRUE), .groups = "drop") %>%
ggplot(aes(x = COHORT, y = mean_expr, fill = COHORT)) +
geom_boxplot(alpha = 0.7, show.legend = FALSE) +
geom_jitter(width = 0.15, alpha = 0.2, size = 0.8) +
scale_fill_manual(values = c("Control" = "#4472C4", "PD" = "#C00000")) +
labs(title = "Mean Protein Expression per Patient by Cohort",
x = "", y = "Mean Expression Value") +
theme_minimal() +
theme(plot.title = element_text(face = "bold", size = 13))
raw_data %>%
filter(COHORT %in% c("PD", "Control")) %>%
group_by(COHORT) %>%
summarise(
mean_expr = round(mean(TESTVALUE, na.rm = TRUE), 3),
median_expr = round(median(TESTVALUE, na.rm = TRUE), 3),
sd_expr = round(sd(TESTVALUE, na.rm = TRUE), 3)
) %>%
print()
## # A tibble: 2 × 4
## COHORT mean_expr median_expr sd_expr
## <chr> <dbl> <dbl> <dbl>
## 1 Control 7.69 6.89 2.38
## 2 PD 7.69 6.88 2.38
# Top 20 most variable proteins, labelled with gene symbols
raw_data %>%
filter(COHORT %in% c("PD", "Control")) %>%
group_by(TESTNAME) %>%
summarise(variance = var(TESTVALUE, na.rm = TRUE), .groups = "drop") %>%
arrange(desc(variance)) %>%
slice_head(n = 20) %>%
left_join(protein_lookup, by = c("TESTNAME" = "SOMA_SEQ_ID")) %>%
mutate(protein_label = ifelse(is.na(TARGET_GENE_SYMBOL), TESTNAME, TARGET_GENE_SYMBOL)) %>%
ggplot(aes(x = reorder(protein_label, variance), y = variance)) +
geom_col(fill = "#4472C4", width = 0.7) +
coord_flip() +
labs(title = "Top 20 Most Variable Proteins",
x = "", y = "Variance") +
theme_minimal() +
theme(plot.title = element_text(face = "bold", size = 13),
axis.text.y = element_text(size = 9))
glimpse(age_data)
## Rows: 46,227
## Columns: 3
## $ PATNO <dbl> 3000, 3000, 3000, 3000, 3000, 3000, 3000, 3000, 3000, 300…
## $ EVENT_ID <chr> "BL", "R17", "R18", "R19", "R20", "SC", "V01", "V02", "V0…
## $ AGE_AT_VISIT <dbl> 69.1, 80.5, 81.4, 84.0, 84.0, 69.1, 69.4, 69.6, 69.9, 70.…
age_col <- intersect(c("AGE", "AGE_AT_VISIT", "PATVISITAGE"), names(age_data))[1]
cat("Using age column:", age_col, "\n")
## Using age column: AGE_AT_VISIT
# age_data has one row per visit, so the plots and summary below count visits, not unique patients
age_cohort <- age_data %>%
inner_join(raw_data %>% distinct(PATNO, COHORT), by = "PATNO") %>%
rename(age = all_of(age_col)) %>%
filter(!is.na(age), COHORT %in% c("PD", "Control"))
ggplot(age_cohort, aes(x = age, fill = COHORT)) +
geom_histogram(bins = 30, alpha = 0.7, position = "identity", color = "white") +
scale_fill_manual(values = c("Control" = "#4472C4", "PD" = "#C00000")) +
labs(title = "Age Distribution by Cohort",
x = "Age", y = "Count", fill = "") +
theme_minimal() +
theme(plot.title = element_text(face = "bold", size = 13),
legend.position = "bottom")
ggplot(age_cohort, aes(x = COHORT, y = age, fill = COHORT)) +
geom_boxplot(alpha = 0.7, show.legend = FALSE) +
geom_jitter(width = 0.15, alpha = 0.3, size = 1) +
scale_fill_manual(values = c("Control" = "#4472C4", "PD" = "#C00000")) +
labs(title = "Age by Disease Status", x = "", y = "Age") +
theme_minimal() +
theme(plot.title = element_text(face = "bold", size = 13))
age_cohort %>%
group_by(COHORT) %>%
summarise(
n = n(),
mean_age = round(mean(age, na.rm = TRUE), 1),
median_age = median(age, na.rm = TRUE),
sd_age = round(sd(age, na.rm = TRUE), 1),
min_age = min(age, na.rm = TRUE),
max_age = max(age, na.rm = TRUE)
) %>%
print()
## # A tibble: 2 × 7
## COHORT n mean_age median_age sd_age min_age max_age
## <chr> <int> <dbl> <dbl> <dbl> <dbl> <dbl>
## 1 Control 3582 65.4 66.5 12 30.6 98.1
## 2 PD 10961 66 66.8 9.8 32.3 93.6
glimpse(demographics)
## Rows: 8,351
## Columns: 29
## $ REC_ID <chr> "IA86904", "IA86905", "IA86906", "IA86907", "IA86908", "2…
## $ PATNO <dbl> 3000, 3001, 3002, 3003, 3004, 3005, 3006, 3007, 3008, 300…
## $ EVENT_ID <chr> "TRANS", "TRANS", "TRANS", "TRANS", "TRANS", "TRANS", "TR…
## $ PAG_NAME <chr> "SCREEN", "SCREEN", "SCREEN", "SCREEN", "SCREEN", "SCREEN…
## $ INFODT <chr> "01/2011", "02/2011", "03/2011", "03/2011", "03/2011", "0…
## $ AFICBERB <dbl> 0, 0, 0, 0, 0, NA, NA, NA, 0, 0, 0, NA, NA, 0, NA, NA, 0,…
## $ ASHKJEW <dbl> 0, 0, 0, 0, 0, NA, NA, NA, 0, 0, 0, NA, NA, 0, NA, NA, 0,…
## $ BASQUE <dbl> 0, 0, 0, 0, 0, NA, NA, NA, 0, 0, 0, NA, NA, 0, NA, NA, 0,…
## $ BIRTHDT <chr> "12/1941", "01/1946", "08/1943", "07/1954", "11/1951", "1…
## $ SEX <dbl> 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 1, 1, 1, 0, 1, 1, 1, 1, 0, …
## $ CHLDBEAR <dbl> 0, NA, 0, 0, NA, 0, 0, NA, 0, 0, NA, NA, NA, 0, NA, NA, N…
## $ HOWLIVE <dbl> NA, NA, NA, NA, NA, NA, NA, NA, NA, NA, NA, NA, NA, NA, N…
## $ GAYLES <dbl> NA, NA, NA, NA, NA, NA, NA, NA, NA, NA, NA, NA, NA, NA, N…
## $ HETERO <dbl> NA, NA, NA, NA, NA, NA, NA, NA, NA, NA, NA, NA, NA, NA, N…
## $ BISEXUAL <dbl> NA, NA, NA, NA, NA, NA, NA, NA, NA, NA, NA, NA, NA, NA, N…
## $ PANSEXUAL <dbl> NA, NA, NA, NA, NA, NA, NA, NA, NA, NA, NA, NA, NA, NA, N…
## $ ASEXUAL <dbl> NA, NA, NA, NA, NA, NA, NA, NA, NA, NA, NA, NA, NA, NA, N…
## $ OTHSEXUALITY <dbl> NA, NA, NA, NA, NA, NA, NA, NA, NA, NA, NA, NA, NA, NA, N…
## $ HANDED <dbl> 1, 2, 1, 1, 1, 3, 1, 2, 1, 1, 3, 1, 1, 3, 1, 1, 1, 1, 1, …
## $ HISPLAT <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ RAASIAN <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ RABLACK <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, …
## $ RAHAWOPI <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ RAINDALS <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ RANOS <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ RAWHITE <dbl> 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 0, 1, …
## $ RAUNKNOWN <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ ORIG_ENTRY <chr> "01/2011", "02/2011", "03/2011", "03/2011", "03/2011", "0…
## $ LAST_UPDATE <dttm> 2022-11-07 00:00:00, 2022-11-07 00:00:00, 2022-11-07 00:…
# SEX coding: 0 = Female, 1 = Male
demo_cohort <- demographics %>%
inner_join(raw_data %>% distinct(PATNO, COHORT), by = "PATNO") %>%
distinct(PATNO, .keep_all = TRUE) %>%
filter(COHORT %in% c("PD", "Control")) %>%
mutate(SEX = factor(SEX, levels = c(0, 1),
labels = c("Female", "Male")))
demo_cohort %>%
count(COHORT, SEX) %>%
ggplot(aes(x = COHORT, y = n, fill = SEX)) +
geom_col(position = "dodge", width = 0.6) +
geom_text(aes(label = n), position = position_dodge(0.6),
vjust = -0.5, size = 3.5, fontface = "bold") +
scale_fill_manual(values = c("Male" = "#4472C4",
"Female" = "#ED7D31")) +
labs(title = "Sex Distribution by Cohort",
x = "", y = "Count", fill = "") +
theme_minimal() +
theme(plot.title = element_text(face = "bold", size = 13),
legend.position = "bottom")
demo_cohort %>%
count(COHORT, SEX) %>%
group_by(COHORT) %>%
mutate(pct = round(100 * n / sum(n), 1)) %>%
print()
## # A tibble: 4 × 4
## # Groups: COHORT [2]
## COHORT SEX n pct
## <chr> <fct> <int> <dbl>
## 1 Control Female 69 37.1
## 2 Control Male 117 62.9
## 3 PD Female 244 39.5
## 4 PD Male 373 60.5
mutation_cols <- c("LRRK2", "GBA", "SNCA", "PRKN", "PINK1") # VPS35/PARK7 all 0
gen_cohort <- genetics %>%
inner_join(raw_data %>% distinct(PATNO, COHORT), by = "PATNO") %>%
distinct(PATNO, .keep_all = TRUE) %>%
filter(COHORT %in% c("PD", "Control")) %>%
mutate(across(all_of(mutation_cols), as.character))
gen_cohort %>%
pivot_longer(cols = all_of(mutation_cols),
names_to = "gene",
values_to = "status") %>%
mutate(carrier = status != "0" & !is.na(status)) %>%
group_by(COHORT, gene) %>%
summarise(n_carriers = sum(carrier, na.rm = TRUE), .groups = "drop") %>%
ggplot(aes(x = reorder(gene, -n_carriers), y = n_carriers, fill = COHORT)) +
geom_col(position = "dodge", width = 0.6) +
geom_text(aes(label = n_carriers), position = position_dodge(0.6),
vjust = -0.5, size = 3.5, fontface = "bold") +
scale_fill_manual(values = c("Control" = "#4472C4", "PD" = "#C00000")) +
labs(title = "Number of Mutation Carriers by Gene and Cohort",
x = "Gene", y = "Number of Carriers", fill = "") +
theme_minimal() +
theme(plot.title = element_text(face = "bold", size = 13),
legend.position = "bottom")
For now I will only use PD and Control patients.
age_baseline <- age_data %>%
filter(EVENT_ID %in% c("BL", "SC")) %>%
arrange(PATNO, EVENT_ID) %>%
distinct(PATNO, .keep_all = TRUE) %>%
select(PATNO, AGE_AT_VISIT)
genetics_features <- genetics %>%
mutate(
LRRK2_carrier = as.integer(LRRK2 != "0" & !is.na(LRRK2)),
GBA_carrier = as.integer(GBA != "0" & !is.na(GBA)),
SNCA_carrier = as.integer(as.character(SNCA) != "0" & !is.na(SNCA)),
PRKN_carrier = as.integer(as.character(PRKN) != "0" & !is.na(PRKN)),
APOE_e4 = as.integer(grepl("E4", APOE, ignore.case = TRUE))
) %>%
select(PATNO, LRRK2_carrier, GBA_carrier,
SNCA_carrier, PRKN_carrier, APOE_e4)
cat("Patients with baseline age:", nrow(age_baseline), "\n")
## Patients with baseline age: 8118
cat("Patients with genetics: ", nrow(genetics_features), "\n")
## Patients with genetics: 6265
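The carrier flags above combine `!= "0"` with `!is.na()`. A small base R sketch with made-up genotype codes shows the consequence: a missing genotype is recorded as a non-carrier (0) instead of staying missing. The alternative coding that keeps `NA` is shown only for comparison and is not what the analysis uses.

```r
# Made-up genotype codes: "0" = no variant, other values = variant, NA = not genotyped
genotype <- c("0", "1", NA, "2")

# Coding used above: NA != "0" is NA, but NA & FALSE evaluates to FALSE
carrier <- as.integer(genotype != "0" & !is.na(genotype))

# Alternative (for comparison only): propagate missingness
carrier_keep_na <- ifelse(is.na(genotype), NA_integer_, as.integer(genotype != "0"))

print(carrier)
print(carrier_keep_na)
```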
data_binary <- raw_data %>%
filter(COHORT %in% c("PD", "Control")) %>%
mutate(outcome = factor(ifelse(COHORT == "PD", 1, 0),
levels = c(0, 1),
labels = c("Control", "PD")))
data_binary %>% count(outcome)
## # A tibble: 2 × 2
## outcome n
## <fct> <int>
## 1 Control 890010
## 2 PD 2952345
wide_data <- data_binary %>%
select(PATNO, SEX, outcome, TESTNAME, TESTVALUE) %>%
pivot_wider(
names_from = TESTNAME,
values_from = TESTVALUE
) %>%
left_join(age_baseline, by = "PATNO") %>%
left_join(genetics_features, by = "PATNO")
dim(wide_data)
## [1] 803 4794
glimpse(wide_data %>% select(PATNO, SEX, outcome,
AGE_AT_VISIT, LRRK2_carrier,
GBA_carrier, SNCA_carrier,
PRKN_carrier, APOE_e4))
## Rows: 803
## Columns: 9
## $ PATNO <dbl> 53595, 3029, 3963, 3316, 3124, 3175, 4103, 3467, 3168, 5…
## $ SEX <chr> "Female", "Female", "Female", "Male", "Male", "Female", …
## $ outcome <fct> PD, Control, PD, Control, PD, PD, PD, PD, PD, PD, PD, PD…
## $ AGE_AT_VISIT <dbl> 65.2, 66.3, 58.4, 74.7, 57.2, 57.2, 59.0, 66.9, 63.1, 67…
## $ LRRK2_carrier <int> 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0,…
## $ GBA_carrier <int> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,…
## $ SNCA_carrier <int> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,…
## $ PRKN_carrier <int> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,…
## $ APOE_e4 <int> 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0,…
cat("Missing AGE_AT_VISIT: ", sum(is.na(wide_data$AGE_AT_VISIT)), "\n")
## Missing AGE_AT_VISIT: 0
cat("Missing LRRK2_carrier: ", sum(is.na(wide_data$LRRK2_carrier)), "\n")
## Missing LRRK2_carrier: 0
cat("LRRK2 carriers: ", sum(wide_data$LRRK2_carrier == 1, na.rm = TRUE), "\n")
## LRRK2 carriers: 156
cat("GBA carriers: ", sum(wide_data$GBA_carrier == 1, na.rm = TRUE), "\n")
## GBA carriers: 77
cat("APOE E4 carriers: ", sum(wide_data$APOE_e4 == 1, na.rm = TRUE), "\n")
## APOE E4 carriers: 197
The variance of each protein is calculated across patients, and the variance distribution is plotted to choose a cutoff. The 500 most variable proteins are retained.
# Define clinical variables (excluded from protein variance calculation)
clinical_vars <- c("AGE_AT_VISIT", "LRRK2_carrier", "GBA_carrier",
"SNCA_carrier", "PRKN_carrier", "APOE_e4")
protein_cols <- wide_data %>%
select(-PATNO, -SEX, -outcome, -any_of(clinical_vars))
protein_vars <- protein_cols %>%
summarise(across(everything(), \(x) var(x, na.rm = TRUE))) %>%
pivot_longer(everything(),
names_to = "protein",
values_to = "variance") %>%
arrange(desc(variance))
summary(protein_vars$variance)
## Min. 1st Qu. Median Mean 3rd Qu. Max.
## 0.003911 0.025173 0.052631 0.122038 0.120528 12.827399
ggplot(protein_vars, aes(x = variance)) +
geom_histogram(bins = 100, fill = "steelblue", color = "white") +
geom_vline(xintercept = protein_vars$variance[500],
color = "red", linetype = "dashed", linewidth = 1) +
annotate("text",
x = protein_vars$variance[500],
y = Inf,
label = "Top 500 cutoff",
vjust = 2, hjust = -0.1, color = "red") +
labs(title = "Protein Variance Distribution",
x = "Variance",
y = "Count") +
theme_minimal()
top500_proteins <- protein_vars %>%
slice_head(n = 500) %>%
pull(protein)
data_filtered <- wide_data %>%
select(PATNO, SEX, outcome,
all_of(clinical_vars),
all_of(top500_proteins))
dim(data_filtered)
## [1] 803 509
cat("Clinical vars present:",
sum(clinical_vars %in% names(data_filtered)),
"of", length(clinical_vars), "\n")
## Clinical vars present: 6 of 6
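The variance filter can be illustrated on a toy table in base R. The three column names and their spreads are invented for the example. The filter ranks features by variance alone and never looks at the outcome.

```r
set.seed(1)

# Three hypothetical proteins with clearly different spread
toy <- data.frame(
  p_low = rnorm(50, sd = 0.1),
  p_mid = rnorm(50, sd = 1),
  p_high = rnorm(50, sd = 3)
)

# Rank columns by variance and keep the top k
k <- 2
vars <- sort(sapply(toy, var), decreasing = TRUE)
top_k <- names(vars)[seq_len(k)]
toy_filtered <- toy[, top_k]

print(round(vars, 3))
print(top_k)
```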
The data are split before running the t-tests to avoid data leakage: t-test feature selection is performed on the training set only.
library(tidymodels) # rsample, recipes, parsnip, workflows, tune, yardstick
set.seed(42)
data_split_pre <- initial_split(
data_filtered %>% select(-PATNO),
prop = 0.80,
strata = outcome
)
train_pre <- training(data_split_pre)
test_pre <- testing(data_split_pre)
cat("Pre-split training set:\n")
## Pre-split training set:
train_pre %>% count(outcome) %>% print()
## # A tibble: 2 × 2
## outcome n
## <fct> <int>
## 1 Control 148
## 2 PD 493
cat("\nPre-split test set:\n")
##
## Pre-split test set:
test_pre %>% count(outcome) %>% print()
## # A tibble: 2 × 2
## outcome n
## <fct> <int>
## 1 Control 38
## 2 PD 124
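The idea behind a stratified split is to sample the same fraction within each outcome class so that both sets keep the original class ratio. The base R sketch below uses a made-up data frame of 20 controls and 60 PD cases. It is a simplified illustration of the concept, not a reimplementation of `initial_split()`, whose exact row allocation may differ.

```r
set.seed(42)

# Hypothetical data: 20 controls and 60 PD cases
toy <- data.frame(
  id = 1:80,
  outcome = factor(rep(c("Control", "PD"), times = c(20, 60)))
)

# Draw 80% of the row indices within each class
by_class <- split(seq_len(nrow(toy)), toy$outcome)
train_idx <- unlist(lapply(by_class, function(idx) {
  idx[sample.int(length(idx), size = round(0.8 * length(idx)))]
}), use.names = FALSE)

train <- toy[train_idx, ]
test <- toy[-train_idx, ]

print(table(train$outcome))
print(table(test$outcome))
```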
The t-tests are run on the training data only (no leakage), and the PD vs. Control results are shown as a volcano plot.
ttest_results <- train_pre %>%
select(outcome, all_of(top500_proteins)) %>%
pivot_longer(
cols = all_of(top500_proteins),
names_to = "protein",
values_to = "value"
) %>%
group_by(protein) %>%
summarise(
p_value = t.test(value ~ outcome)$p.value,
mean_PD = mean(value[outcome == "PD"], na.rm = TRUE),
mean_ctrl = mean(value[outcome == "Control"], na.rm = TRUE),
log2FC = mean_PD - mean_ctrl
) %>%
mutate(
p_adj = p.adjust(p_value, method = "BH"),
sig = p_adj < 0.05
) %>%
arrange(p_adj)
ttest_results %>% count(sig)
## # A tibble: 2 × 2
## sig n
## <lgl> <int>
## 1 FALSE 454
## 2 TRUE 46
head(ttest_results, 20)
## # A tibble: 20 × 7
## protein p_value mean_PD mean_ctrl log2FC p_adj sig
## <chr> <dbl> <dbl> <dbl> <dbl> <dbl> <lgl>
## 1 3504-58_2 0.00000000789 9.33 8.95 0.375 0.00000395 TRUE
## 2 7757-5_3 0.0000000423 9.41 9.11 0.300 0.0000106 TRUE
## 3 5080-131_3 0.00000437 9.68 9.48 0.203 0.000728 TRUE
## 4 6521-35_3 0.00000716 11.7 11.9 -0.228 0.000895 TRUE
## 5 14125-5_3 0.0000162 8.99 8.77 0.219 0.00162 TRUE
## 6 13722-105_3 0.0000431 10.4 10.2 0.202 0.00214 TRUE
## 7 2292-17_4 0.0000401 9.66 9.45 0.205 0.00214 TRUE
## 8 2743-5_2 0.0000267 12.9 12.7 0.288 0.00214 TRUE
## 9 2900-53_3 0.0000418 9.47 9.24 0.234 0.00214 TRUE
## 10 4141-79_1 0.0000471 10.6 10.3 0.265 0.00214 TRUE
## 11 8240-207_3 0.0000430 11.6 11.4 0.221 0.00214 TRUE
## 12 6520-87_3 0.0000540 13.4 13.2 0.228 0.00225 TRUE
## 13 3535-84_1 0.0000771 8.55 8.34 0.204 0.00297 TRUE
## 14 13416-8_3 0.0000942 12.1 12.3 -0.233 0.00314 TRUE
## 15 3060-43_2 0.0000911 11.2 11.0 0.187 0.00314 TRUE
## 16 9900-36_3 0.000206 9.14 8.92 0.221 0.00642 TRUE
## 17 3415-61_2 0.000238 8.70 8.23 0.469 0.00662 TRUE
## 18 9796-4_3 0.000226 15.2 15.0 0.215 0.00662 TRUE
## 19 8890-9_3 0.000352 11.7 11.9 -0.197 0.00926 TRUE
## 20 9986-14_3 0.000414 9.36 9.16 0.206 0.0104 TRUE
ggplot(ttest_results, aes(x = log2FC, y = -log10(p_adj), color = sig)) +
geom_point(alpha = 0.6, size = 1.2) +
scale_color_manual(values = c("FALSE" = "grey60", "TRUE" = "#C00000"),
labels = c("Not significant", "Significant (BH < 0.05)")) +
geom_hline(yintercept = -log10(0.05), linetype = "dashed", color = "black") +
labs(title = "Volcano Plot: PD vs Control (t-test, training set)",
x = "Log2 Fold Change (PD - Control)",
y = "-log10(Adjusted p-value)",
color = "") +
theme_minimal() +
theme(plot.title = element_text(face = "bold", size = 13),
legend.position = "bottom")
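The Benjamini-Hochberg adjustment used above can be worked through by hand on a few p-values. The five values below are invented for illustration. Each sorted p-value is multiplied by n / rank, and a running minimum from the largest p-value downward keeps the adjusted values monotone.

```r
# Five made-up raw p-values, already sorted
p <- c(0.001, 0.008, 0.039, 0.041, 0.2)
n <- length(p)

# Manual BH: scale by n / rank, then enforce monotonicity from the right
scaled <- p * n / seq_len(n)
manual_bh <- rev(cummin(rev(scaled)))

# Built-in equivalent
builtin_bh <- p.adjust(p, method = "BH")

print(manual_bh)
print(builtin_bh)
```

Note that the third p-value (0.039) is below 0.05 before adjustment but not after, which is the behaviour that reduces the 500 tested proteins to a smaller significant set.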
Selecting the significant proteins and assembling the modelling dataset. make_splits() is used to reconstruct the train/test split object for last_fit().
sig_proteins <- ttest_results %>%
filter(sig == TRUE) %>%
pull(protein)
cat("Significant proteins:", length(sig_proteins), "\n")
## Significant proteins: 46
# Apply the same feature set to train and test
train_data <- train_pre %>%
select(SEX, outcome, all_of(clinical_vars), all_of(sig_proteins))
test_data <- test_pre %>%
select(SEX, outcome, all_of(clinical_vars), all_of(sig_proteins))
# Combine into one data frame and reconstruct the rsample split object
data_model <- bind_rows(train_data, test_data)
dim(data_model)
## [1] 803 54
data_model %>% count(outcome)
## # A tibble: 2 × 2
## outcome n
## <fct> <int>
## 1 Control 186
## 2 PD 617
data_model %>%
select(all_of(clinical_vars)) %>%
summarise(across(everything(), ~sum(is.na(.)))) %>%
pivot_longer(everything(),
names_to = "variable",
values_to = "n_missing") %>%
mutate(pct_missing = round(100 * n_missing / nrow(data_model), 1)) %>%
print()
## # A tibble: 6 × 3
## variable n_missing pct_missing
## <chr> <int> <dbl>
## 1 AGE_AT_VISIT 0 0
## 2 LRRK2_carrier 0 0
## 3 GBA_carrier 0 0
## 4 SNCA_carrier 0 0
## 5 PRKN_carrier 0 0
## 6 APOE_e4 0 0
# Reconstruct rsample split object so last_fit() uses the correct rows
data_split <- make_splits(
list(analysis = seq_len(nrow(train_data)),
assessment = seq(nrow(train_data) + 1, nrow(data_model))),
data = data_model
)
saveRDS(data_model, file.path(data_dir, "data_model.rds"))
set.seed(42)
cv_folds <- vfold_cv(
train_data,
v = 10,
strata = outcome
)
cv_folds
## # 10-fold cross-validation using stratification
## # A tibble: 10 × 2
## splits id
## <list> <chr>
## 1 <split [576/65]> Fold01
## 2 <split [576/65]> Fold02
## 3 <split [576/65]> Fold03
## 4 <split [577/64]> Fold04
## 5 <split [577/64]> Fold05
## 6 <split [577/64]> Fold06
## 7 <split [577/64]> Fold07
## 8 <split [577/64]> Fold08
## 9 <split [578/63]> Fold09
## 10 <split [578/63]> Fold10
library(themis) # provides step_downsample()
# Recipe: remove SEX, impute missing clinical + protein values, normalize,
# and downsample the majority class to address class imbalance (617 PD vs 186 Control overall)
model_recipe <- recipe(outcome ~ ., data = train_data) %>%
step_rm(SEX) %>%
step_impute_median(all_of(clinical_vars)) %>%
step_impute_median(all_predictors()) %>%
step_normalize(all_predictors()) %>%
step_downsample(outcome, under_ratio = 1) # fix class imbalance
summary(model_recipe)
## # A tibble: 54 × 4
## variable type role source
## <chr> <list> <chr> <chr>
## 1 SEX <chr [3]> predictor original
## 2 AGE_AT_VISIT <chr [2]> predictor original
## 3 LRRK2_carrier <chr [2]> predictor original
## 4 GBA_carrier <chr [2]> predictor original
## 5 SNCA_carrier <chr [2]> predictor original
## 6 PRKN_carrier <chr [2]> predictor original
## 7 APOE_e4 <chr [2]> predictor original
## 8 3504-58_2 <chr [2]> predictor original
## 9 7757-5_3 <chr [2]> predictor original
## 10 5080-131_3 <chr [2]> predictor original
## # ℹ 44 more rows
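Downsampling to a 1:1 ratio means randomly discarding majority-class rows until both classes are the same size. The base R sketch below illustrates this using the training-set class counts reported earlier (148 Control, 493 PD) and a made-up predictor `x`. It is a conceptual illustration, not the internals of `step_downsample()`.

```r
set.seed(42)

# Hypothetical training set with the class counts from the split above
train <- data.frame(
  outcome = factor(rep(c("Control", "PD"), times = c(148, 493))),
  x = rnorm(641)
)

# Keep all minority rows and an equally sized random subset of the majority
n_min <- min(table(train$outcome))
by_class <- split(seq_len(nrow(train)), train$outcome)
keep <- unlist(lapply(by_class, function(idx) {
  idx[sample.int(length(idx), size = n_min)]
}), use.names = FALSE)

balanced <- train[sort(keep), ]
print(table(balanced$outcome))
```

In themis, `step_downsample()` is skipped by default when the recipe is applied to new data, so, as far as the defaults go, the held-out rows used for assessment keep their original class ratio.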
log_spec <- logistic_reg() %>%
set_engine("glm") %>%
set_mode("classification")
log_workflow <- workflow() %>%
add_recipe(model_recipe) %>%
add_model(log_spec)
lasso_spec <- logistic_reg(
penalty = tune(),
mixture = 1
) %>%
set_engine("glmnet") %>%
set_mode("classification")
lasso_workflow <- workflow() %>%
add_recipe(model_recipe) %>%
add_model(lasso_spec)
rf_spec <- rand_forest(
mtry = tune(),
trees = 500,
min_n = tune()
) %>%
set_engine("ranger", importance = "permutation") %>%
set_mode("classification")
rf_workflow <- workflow() %>%
add_recipe(model_recipe) %>%
add_model(rf_spec)
log_workflow
## ══ Workflow ════════════════════════════════════════════════════════════════════
## Preprocessor: Recipe
## Model: logistic_reg()
##
## ── Preprocessor ────────────────────────────────────────────────────────────────
## 5 Recipe Steps
##
## • step_rm()
## • step_impute_median()
## • step_impute_median()
## • step_normalize()
## • step_downsample()
##
## ── Model ───────────────────────────────────────────────────────────────────────
## Logistic Regression Model Specification (classification)
##
## Computational engine: glm
lasso_workflow
## ══ Workflow ════════════════════════════════════════════════════════════════════
## Preprocessor: Recipe
## Model: logistic_reg()
##
## ── Preprocessor ────────────────────────────────────────────────────────────────
## 5 Recipe Steps
##
## • step_rm()
## • step_impute_median()
## • step_impute_median()
## • step_normalize()
## • step_downsample()
##
## ── Model ───────────────────────────────────────────────────────────────────────
## Logistic Regression Model Specification (classification)
##
## Main Arguments:
## penalty = tune()
## mixture = 1
##
## Computational engine: glmnet
rf_workflow
## ══ Workflow ════════════════════════════════════════════════════════════════════
## Preprocessor: Recipe
## Model: rand_forest()
##
## ── Preprocessor ────────────────────────────────────────────────────────────────
## 5 Recipe Steps
##
## • step_rm()
## • step_impute_median()
## • step_impute_median()
## • step_normalize()
## • step_downsample()
##
## ── Model ───────────────────────────────────────────────────────────────────────
## Random Forest Model Specification (classification)
##
## Main Arguments:
## mtry = tune()
## trees = 500
## min_n = tune()
##
## Engine-Specific Arguments:
## importance = permutation
##
## Computational engine: ranger
# ── Tune LASSO ────────────────────────────────────────────────────────────────
lasso_grid <- grid_regular(
penalty(range = c(-4, 0)),
levels = 50
)
set.seed(42)
lasso_tuned <- tune_grid(
lasso_workflow,
resamples = cv_folds,
grid = lasso_grid,
metrics = metric_set(roc_auc, accuracy, sens, spec),
control = control_grid(save_pred = TRUE)
)
## Warning: Using `all_of()` outside of a selecting function was deprecated in tidyselect
## 1.2.0.
## ℹ See details at
## <https://tidyselect.r-lib.org/reference/faq-selection-context.html>
## This warning is displayed once every 8 hours.
## Call `lifecycle::last_lifecycle_warnings()` to see where this warning was
## generated.
lasso_best <- select_best(lasso_tuned, metric = "roc_auc")
cat("Best LASSO penalty:", lasso_best$penalty, "\n")
## Best LASSO penalty: 0.009102982
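The selected penalty can be located within the search grid with a few lines of base R. This is a minimal sketch that assumes `grid_regular()` spaces the 50 candidates evenly on the log10 scale between the stated range limits; it does not call dials itself, and the variable names are illustrative.

```r
# Assumed equivalent of grid_regular(penalty(range = c(-4, 0)), levels = 50):
# 50 values equally spaced on the log10 scale
penalty_grid <- 10^seq(-4, 0, length.out = 50)

# Penalty reported by select_best() above
best_penalty <- 0.009102982

# Position of the selected value within the grid
best_index <- which.min(abs(penalty_grid - best_penalty))
best_index

# Smallest and largest candidates
range(penalty_grid)
```

Under that assumption the selected value is the 25th of 50 candidates, so it sits in the interior of the grid rather than at either boundary, which suggests the search range was wide enough.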
# ── Tune Random Forest ────────────────────────────────────────────────────────
rf_grid <- grid_regular(
  mtry(range = c(2, 20)),
  min_n(range = c(2, 20)),
  levels = 5
)
set.seed(42)
rf_tuned <- tune_grid(
  rf_workflow,
  resamples = cv_folds,
  grid = rf_grid,
  metrics = metric_set(roc_auc, accuracy, sens, spec),
  control = control_grid(save_pred = TRUE)
)
rf_best <- select_best(rf_tuned, metric = "roc_auc")
cat("Best RF mtry:", rf_best$mtry, "min_n:", rf_best$min_n, "\n")
## Best RF mtry: 11 min_n: 2
# ── Logistic Regression CV ────────────────────────────────────────────────────
set.seed(42)
log_cv_results <- fit_resamples(
  log_workflow,
  resamples = cv_folds,
  metrics = metric_set(roc_auc, accuracy, sens, spec),
  control = control_resamples(save_pred = TRUE)
)
## → A | warning: prediction from rank-deficient fit; attr(*, "non-estim") has doubtful cases
## There were issues with some computations   A: x4
# ── Tuning plots ──────────────────────────────────────────────────────────────
autoplot(lasso_tuned, metric = "roc_auc") +
  labs(title = "LASSO Tuning: ROC AUC vs Penalty") +
  theme_minimal() +
  theme(plot.title = element_text(face = "bold", size = 13))
autoplot(rf_tuned, metric = "roc_auc") +
  labs(title = "Random Forest Tuning: ROC AUC vs Hyperparameters") +
  theme_minimal() +
  theme(plot.title = element_text(face = "bold", size = 13))
# ── Cross-validation comparison ───────────────────────────────────────────────
log_metrics <- collect_metrics(log_cv_results) %>% mutate(model = "Logistic")
# Match the selected candidates by .config instead of testing tuned values
# for floating-point equality
lasso_metrics <- collect_metrics(lasso_tuned) %>%
  filter(.config == lasso_best$.config) %>%
  mutate(model = "LASSO")
rf_metrics <- collect_metrics(rf_tuned) %>%
  filter(.config == rf_best$.config) %>%
  mutate(model = "Random Forest")
cv_comparison <- bind_rows(log_metrics, lasso_metrics, rf_metrics) %>%
  select(model, .metric, mean, std_err) %>%
  filter(.metric %in% c("roc_auc", "accuracy", "sens", "spec")) %>%
  arrange(.metric, desc(mean))
print(cv_comparison, n = 50)
## # A tibble: 12 × 4
##    model         .metric   mean std_err
##    <chr>         <chr>    <dbl>   <dbl>
##  1 Random Forest accuracy 0.705  0.0165
##  2 Logistic      accuracy 0.694  0.0191
##  3 LASSO         accuracy 0.671  0.0219
##  4 LASSO         roc_auc  0.810  0.0221
##  5 Random Forest roc_auc  0.795  0.0195
##  6 Logistic      roc_auc  0.777  0.0162
##  7 LASSO         sens     0.811  0.0386
##  8 Random Forest sens     0.736  0.0504
##  9 Logistic      sens     0.729  0.0326
## 10 Random Forest spec     0.696  0.0190
## 11 Logistic      spec     0.683  0.0252
## 12 LASSO         spec     0.629  0.0245
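To judge whether the ROC AUC differences between models are larger than the resampling noise, the reported standard errors can be turned into rough intervals. The sketch below copies the ROC AUC rows from the table above and applies a normal approximation, which is an assumption: fold-level estimates are not fully independent, so the intervals are indicative only. The object names are illustrative.

```r
# ROC AUC means and standard errors copied from the table above
cv_auc <- data.frame(
  model   = c("LASSO", "Random Forest", "Logistic"),
  mean    = c(0.810, 0.795, 0.777),
  std_err = c(0.0221, 0.0195, 0.0162)
)

# Rough 95% intervals under a normal approximation (assumption)
cv_auc$lower <- cv_auc$mean - 1.96 * cv_auc$std_err
cv_auc$upper <- cv_auc$mean + 1.96 * cv_auc$std_err
print(cv_auc)

# Does the interval of the highest-ranked model overlap that of the lowest?
overlap <- cv_auc$lower[cv_auc$model == "LASSO"] <
  cv_auc$upper[cv_auc$model == "Logistic"]
overlap
```

If the intervals overlap, the ranking of the three models by cross-validated ROC AUC should be read as tentative rather than as a clear separation.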
cv_comparison %>%
  mutate(.metric = recode(.metric,
    "roc_auc" = "ROC AUC",
    "accuracy" = "Accuracy",
    "sens" = "Sensitivity",
    "spec" = "Specificity"
  )) %>%
  ggplot(aes(x = model, y = mean, fill = model)) +
  geom_col(width = 0.6, show.legend = FALSE) +
  geom_errorbar(aes(ymin = mean - std_err, ymax = mean + std_err),
                width = 0.2) +
  geom_text(aes(label = round(mean, 3)), vjust = -0.5, size = 3.5) +
  facet_wrap(~ .metric, scales = "free_y") +
  scale_fill_manual(values = c("Logistic" = "#4472C4",
                               "LASSO" = "#ED7D31",
                               "Random Forest" = "#70AD47")) +
  labs(title = "Cross-Validation Performance by Model",
       x = "", y = "Mean Score") +
  theme_minimal() +
  theme(plot.title = element_text(face = "bold", size = 13),
        axis.text.x = element_text(angle = 20, hjust = 1))
final_log_workflow <- log_workflow
final_lasso_workflow <- lasso_workflow %>% finalize_workflow(lasso_best)
final_rf_workflow <- rf_workflow %>% finalize_workflow(rf_best)
set.seed(42)
log_final <- last_fit(final_log_workflow, data_split,
                      metrics = metric_set(roc_auc, accuracy, sens, spec))
## → A | warning: prediction from rank-deficient fit; attr(*, "non-estim") has doubtful cases
## There were issues with some computations   A: x2
lasso_final <- last_fit(final_lasso_workflow, data_split,
                        metrics = metric_set(roc_auc, accuracy, sens, spec))
rf_final <- last_fit(final_rf_workflow, data_split,
                     metrics = metric_set(roc_auc, accuracy, sens, spec))
# ── Test-set performance table ────────────────────────────────────────────────
test_metrics <- bind_rows(
  collect_metrics(log_final) %>% mutate(model = "Logistic"),
  collect_metrics(lasso_final) %>% mutate(model = "LASSO"),
  collect_metrics(rf_final) %>% mutate(model = "Random Forest")
) %>%
  select(model, .metric, .estimate) %>%
  arrange(.metric, desc(.estimate))
print(test_metrics, n = 50)
## # A tibble: 12 × 3
##    model         .metric  .estimate
##    <chr>         <chr>        <dbl>
##  1 Logistic      accuracy     0.722
##  2 Random Forest accuracy     0.716
##  3 LASSO         accuracy     0.704
##  4 Logistic      roc_auc      0.812
##  5 LASSO         roc_auc      0.808
##  6 Random Forest roc_auc      0.778
##  7 LASSO         sens         0.842
##  8 Logistic      sens         0.789
##  9 Random Forest sens         0.711
## 10 Random Forest spec         0.718
## 11 Logistic      spec         0.702
## 12 LASSO         spec         0.661
test_metrics %>%
  mutate(.metric = recode(.metric,
    "roc_auc" = "ROC AUC",
    "accuracy" = "Accuracy",
    "sens" = "Sensitivity",
    "spec" = "Specificity"
  )) %>%
  ggplot(aes(x = model, y = .estimate, fill = model)) +
  geom_col(width = 0.6, show.legend = FALSE) +
  geom_text(aes(label = round(.estimate, 3)), vjust = -0.5, size = 3.5) +
  facet_wrap(~ .metric, scales = "free_y") +
  scale_fill_manual(values = c("Logistic" = "#4472C4",
                               "LASSO" = "#ED7D31",
                               "Random Forest" = "#70AD47")) +
  labs(title = "Test Set Performance by Model", x = "", y = "Score") +
  theme_minimal() +
  theme(plot.title = element_text(face = "bold", size = 13),
        axis.text.x = element_text(angle = 20, hjust = 1))
# ── Confusion matrices ────────────────────────────────────────────────────────
cat("=== Logistic Regression ===\n")
## === Logistic Regression ===
collect_predictions(log_final) %>%
  conf_mat(truth = outcome, estimate = .pred_class) %>%
  print()
##           Truth
## Prediction Control PD
##    Control      30 37
##    PD            8 87
cat("\n=== LASSO ===\n")
##
## === LASSO ===
collect_predictions(lasso_final) %>%
  conf_mat(truth = outcome, estimate = .pred_class) %>%
  print()
##           Truth
## Prediction Control PD
##    Control      32 42
##    PD            6 82
cat("\n=== Random Forest ===\n")
##
## === Random Forest ===
collect_predictions(rf_final) %>%
  conf_mat(truth = outcome, estimate = .pred_class) %>%
  print()
##           Truth
## Prediction Control PD
##    Control      27 35
##    PD           11 89
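The confusion matrices make it possible to check which class the reported sensitivity and specificity refer to. The sketch below recomputes the logistic regression test metrics from the counts printed above using base R only. The numbers match the test-set table when Control is treated as the event, so the reported "sens" is the proportion of controls correctly identified, and the proportion of PD patients correctly identified is the reported "spec". The object names are illustrative; yardstick metrics also accept an `event_level` argument if PD should be the event instead.

```r
# Counts copied from the logistic regression confusion matrix above
# (rows = prediction, columns = truth; matrix() fills column by column)
cm <- matrix(c(30, 8, 37, 87), nrow = 2,
             dimnames = list(Prediction = c("Control", "PD"),
                             Truth = c("Control", "PD")))

# Overall accuracy: correct predictions over all test samples
accuracy <- sum(diag(cm)) / sum(cm)

# Control as the event: these reproduce the reported sens and spec
sens_control <- cm["Control", "Control"] / sum(cm[, "Control"])
spec_control <- cm["PD", "PD"] / sum(cm[, "PD"])

# PD as the event: the two quantities swap roles
sens_pd <- spec_control
spec_pd <- sens_control

round(c(accuracy = accuracy,
        sens_control = sens_control,
        spec_control = spec_control,
        sens_pd = sens_pd,
        spec_pd = spec_pd), 3)
```

Read this way, the logistic model detects about 70% of PD cases and about 79% of controls on the test set.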
# ── ROC curves ────────────────────────────────────────────────────────────────
# yardstick treats the first factor level as the event by default. The sens and
# spec values above are consistent with "Control" being that level, so
# .pred_Control is the matching probability column here.
log_roc <- collect_predictions(log_final) %>% mutate(model = "Logistic")
lasso_roc <- collect_predictions(lasso_final) %>% mutate(model = "LASSO")
rf_roc <- collect_predictions(rf_final) %>% mutate(model = "Random Forest")
bind_rows(log_roc, lasso_roc, rf_roc) %>%
  group_by(model) %>%
  roc_curve(truth = outcome, .pred_Control) %>%
  ggplot(aes(x = 1 - specificity, y = sensitivity, color = model)) +
  geom_line(linewidth = 1.2) +
  geom_abline(linetype = "dashed", color = "grey50") +
  scale_color_manual(values = c("Logistic" = "#4472C4",
                                "LASSO" = "#ED7D31",
                                "Random Forest" = "#70AD47")) +
  labs(title = "ROC Curves on Test Set",
       x = "1 - Specificity (False Positive Rate)",
       y = "Sensitivity (True Positive Rate)",
       color = "Model") +
  theme_minimal() +
  theme(plot.title = element_text(face = "bold", size = 13),
        legend.position = "bottom")
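As a reminder of what the area under these curves measures, ROC AUC equals the probability that a randomly chosen event case receives a higher score than a randomly chosen non-event case. The base-R sketch below illustrates this on a small toy vector that is made up for illustration only and is not PPMI data; the variable names are hypothetical.

```r
# Toy data, made up for illustration only (not PPMI results)
truth <- c(1, 1, 0, 1, 0, 0)              # 1 = event, 0 = non-event
score <- c(0.9, 0.8, 0.7, 0.6, 0.4, 0.3)  # predicted probability of the event

n_pos <- sum(truth == 1)
n_neg <- sum(truth == 0)

# Rank-sum (Mann-Whitney) form of the AUC
auc_rank <- (sum(rank(score)[truth == 1]) - n_pos * (n_pos + 1) / 2) /
  (n_pos * n_neg)

# Direct form: share of (event, non-event) pairs ranked correctly
# (no tied scores in this toy example, so ties need no special handling)
auc_pairs <- mean(outer(score[truth == 1], score[truth == 0], ">"))

c(auc_rank = auc_rank, auc_pairs = auc_pairs)
```

Both forms give the same value, which is why AUC does not depend on the 0.5 classification threshold used for the accuracy, sensitivity, and specificity figures.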
# ── LASSO: Top 20 coefficients with gene symbols ──────────────────────────────
lasso_final %>%
  extract_fit_parsnip() %>%
  tidy() %>%
  filter(term != "(Intercept)", estimate != 0) %>%
  arrange(desc(abs(estimate))) %>%
  slice_head(n = 20) %>%
  left_join(protein_lookup, by = c("term" = "SOMA_SEQ_ID")) %>%
  mutate(protein_label = ifelse(is.na(TARGET_GENE_SYMBOL), term, TARGET_GENE_SYMBOL)) %>%
  ggplot(aes(x = reorder(protein_label, abs(estimate)), y = estimate,
             fill = estimate > 0)) +
  geom_col(show.legend = FALSE) +
  coord_flip() +
  scale_fill_manual(values = c("TRUE" = "#C00000", "FALSE" = "#4472C4")) +
  labs(title = "LASSO: Top 20 Coefficients",
       x = "", y = "Coefficient") +
  theme_minimal() +
  theme(plot.title = element_text(face = "bold", size = 13),
        axis.text.y = element_text(size = 8))
# ── Random Forest: Top 20 variable importance with gene symbols ───────────────
rf_imp <- rf_final %>%
  extract_fit_parsnip() %>%
  vi() %>%
  slice_head(n = 20) %>%
  left_join(protein_lookup, by = c("Variable" = "SOMA_SEQ_ID")) %>%
  mutate(protein_label = ifelse(is.na(TARGET_GENE_SYMBOL), Variable, TARGET_GENE_SYMBOL))
ggplot(rf_imp, aes(x = reorder(protein_label, Importance), y = Importance)) +
  geom_col(fill = "#70AD47", show.legend = FALSE) +
  coord_flip() +
  labs(title = "Random Forest: Top 20 Variable Importance",
       x = "", y = "Permutation Importance") +
  theme_minimal() +
  theme(plot.title = element_text(face = "bold", size = 13),
        axis.text.y = element_text(size = 8))