knitr::opts_chunk$set(echo = TRUE, message = FALSE, warning = FALSE)
set.seed(123)

# 0. Packages
packages <- c(
  "dplyr", "tidyr", "ggplot2", "patchwork",
  "limma", "sva", "caret", "e1071", "glmnet",
  "pROC", "purrr", "tibble"
)

for (p in packages) {
  if (!requireNamespace(p, quietly = TRUE)) {
    install.packages(p)
  }
  library(p, character.only = TRUE)
}

use_lumi_mapping <- FALSE
if (requireNamespace("lumi", quietly = TRUE) &&
    requireNamespace("lumiHumanIDMapping", quietly = TRUE)) {
  use_lumi_mapping <- TRUE
  library(lumi)
  library(lumiHumanIDMapping)
}

1 1. Load data

data_dir <- "."

clinical_path <- file.path(data_dir, "ANMerge_clinical_under_90.csv")
gene_path     <- file.path(data_dir, "ANMerge_gene_expression_normalized_under_90.csv")
prot_path     <- file.path(data_dir, "ANMerge_Proteomics_under_90.csv")

clinical_raw <- read.csv(clinical_path, row.names = 1, stringsAsFactors = FALSE, check.names = FALSE)
gene_expr_raw <- read.csv(gene_path, row.names = 1, stringsAsFactors = FALSE, check.names = FALSE)
proteomics_raw <- read.csv(prot_path, row.names = 1, stringsAsFactors = FALSE, check.names = FALSE)

cat("Clinical dim:   ", dim(clinical_raw), "\n")
## Clinical dim:    4448 44
cat("Gene expr dim:  ", dim(gene_expr_raw), "\n")
## Gene expr dim:   691 5221
cat("Proteomics dim: ", dim(proteomics_raw), "\n")
## Proteomics dim:  898 1024
head(clinical_raw[, 1:min(10, ncol(clinical_raw))])
##            Subject_ID Sadman_ID Site Month Visit Max_Visit    Sex Diagnosis
## DCR00001_1   DCR00001  DCR00001  DCR     0     1         2 Female        AD
## DCR00001_2   DCR00001  DCR00001  DCR    12     2         2 Female        AD
## DCR00003_1   DCR00003  DCR00003  DCR     0     1         7 Female        AD
## DCR00003_2   DCR00003  DCR00003  DCR    12     2         7 Female        AD
## DCR00003_3   DCR00003  DCR00003  DCR    24     3         7 Female        AD
## DCR00003_5   DCR00003  DCR00003  DCR    48     5         7 Female        AD
##            BL_Diagnosis Final_Diagnosis
## DCR00001_1           AD              AD
## DCR00001_2           AD              AD
## DCR00003_1           AD              AD
## DCR00003_2           AD              AD
## DCR00003_3           AD              AD
## DCR00003_5           AD              AD

2 2. Basic checks and transformations

check_range <- function(data, name) {
  numeric_data <- data[, sapply(data, is.numeric), drop = FALSE]
  cat("\n====", name, "====\n")
  cat("Min:   ", min(as.matrix(numeric_data), na.rm = TRUE), "\n")
  cat("Q1:    ", quantile(as.matrix(numeric_data), 0.25, na.rm = TRUE), "\n")
  cat("Median:", median(as.matrix(numeric_data), na.rm = TRUE), "\n")
  cat("Mean:  ", mean(as.matrix(numeric_data), na.rm = TRUE), "\n")
  cat("Q3:    ", quantile(as.matrix(numeric_data), 0.75, na.rm = TRUE), "\n")
  cat("Max:   ", max(as.matrix(numeric_data), na.rm = TRUE), "\n")
}

check_range(clinical_raw, "Clinical")
## 
## ==== Clinical ====
## Min:    0 
## Q1:     1 
## Median: 6 
## Mean:   81.18859 
## Q3:     20 
## Max:    2014
check_range(gene_expr_raw, "Gene Expression")
## 
## ==== Gene Expression ====
## Min:    0 
## Q1:     7.501708 
## Median: 8.504594 
## Mean:   8.752086 
## Q3:     9.725001 
## Max:    89
check_range(proteomics_raw, "Proteomics raw")
## 
## ==== Proteomics raw ====
## Min:    0 
## Q1:     927.7 
## Median: 1924.2 
## Mean:   8237.674 
## Q3:     5523.4 
## Max:    537370.2
# Proteomics is intensity-like, so log2(x + 1).
proteomics_numeric <- proteomics_raw[, sapply(proteomics_raw, is.numeric), drop = FALSE]
proteomics_log <- proteomics_numeric
proteomics_log[] <- log2(proteomics_numeric + 1)
check_range(proteomics_log, "Proteomics log2(x + 1)")
## 
## ==== Proteomics log2(x + 1) ====
## Min:    0 
## Q1:     9.859069 
## Median: 10.91079 
## Mean:   11.27747 
## Q3:     12.4316 
## Max:    19.03556

3 3. Remove embedded clinical columns and prepare omics matrices

clin_cols <- c("Visit", "Month", "Site", "Diagnosis", "Sex", "Age", "APOE", "MMSE", "Gexp_batch")

batch_vec_all <- gene_expr_raw$Gexp_batch
names(batch_vec_all) <- rownames(gene_expr_raw)

site_vec_all <- gene_expr_raw$Site
names(site_vec_all) <- rownames(gene_expr_raw)

# Gene expression: remove clinical columns.
gene_expr_clean <- gene_expr_raw[, !colnames(gene_expr_raw) %in% clin_cols, drop = FALSE]

# Keep numeric gene columns only.
gene_expr_clean <- gene_expr_clean[, sapply(gene_expr_clean, is.numeric), drop = FALSE]
cat("Gene expression features after removing clinical cols:", ncol(gene_expr_clean), "\n")
## Gene expression features after removing clinical cols: 5212
# Proteomics: remove clinical columns if present.
proteomics_clean <- proteomics_log[, !colnames(proteomics_log) %in% clin_cols, drop = FALSE]
proteomics_clean <- proteomics_clean[, sapply(proteomics_clean, is.numeric), drop = FALSE]
cat("Proteomics features:", ncol(proteomics_clean), "\n")
## Proteomics features: 1016
cat("\nGene batch distribution:\n")
## 
## Gene batch distribution:
print(table(batch_vec_all))
## batch_vec_all
##   1   2 
## 341 350
cat("\nGene site × batch:\n")
## 
## Gene site × batch:
print(table(gene_expr_raw$Site, gene_expr_raw$Gexp_batch))
##               
##                  1   2
##   ART            0   8
##   DCR            0 115
##   Kuopio        93  43
##   Lodz          44  44
##   London        44  42
##   Perugia       83  63
##   Thessaloniki  43  24
##   Toulouse      34  11
cat("\nGene site × diagnosis:\n")
## 
## Gene site × diagnosis:
print(table(gene_expr_raw$Site, gene_expr_raw$Diagnosis))
##               
##                AD CTL MCI
##   ART           0   2   6
##   DCR          30  60  25
##   Kuopio       44  45  47
##   Lodz         39  25  24
##   London       21  41  24
##   Perugia      49  46  51
##   Thessaloniki 25  11  31
##   Toulouse     15  19  11

4 4. Probe to gene symbol mapping / averaging

if (use_lumi_mapping) {
  cat("Using lumiHumanIDMapping for probe-to-gene-symbol conversion.\n")
  probe_ids <- colnames(gene_expr_clean)
  probe_ids_no_x <- sub("^X", "", probe_ids)
  mapped <- tryCatch(
    lumi::nuID2IlluminaID(probe_ids_no_x, lib.mapping = "lumiHumanIDMapping"),
    error = function(e) NULL
  )
  
  if (!is.null(mapped)) {
    gene_anno <- data.frame(
      nuID = rownames(mapped),
      ILMN_ID = mapped[, "Probe_Id"],
      Gene_Symbol = mapped[, "Symbol"],
      Accession = mapped[, "Accession"],
      stringsAsFactors = FALSE
    )
    matched_gene <- match(probe_ids_no_x, gene_anno$nuID)
    keep_probe <- !is.na(matched_gene) &
      gene_anno$Gene_Symbol[matched_gene] != "" &
      !is.na(gene_anno$Gene_Symbol[matched_gene])
    
    cat("Probes total:       ", length(probe_ids_no_x), "\n")
    cat("Probes with symbol: ", sum(keep_probe), "\n")
    cat("Dropped no symbol:  ", sum(!keep_probe), "\n")
    
    gene_expr_clean <- gene_expr_clean[, keep_probe, drop = FALSE]
    colnames(gene_expr_clean) <- make.names(gene_anno$Gene_Symbol[matched_gene[keep_probe]], unique = FALSE)
    gene_expr_clean <- as.data.frame(t(limma::avereps(t(as.matrix(gene_expr_clean)))))
  }
} else {
  cat("lumi mapping package not available. Keeping probe IDs as gene-expression features.\n")
  colnames(gene_expr_clean) <- make.names(colnames(gene_expr_clean), unique = TRUE)
}
## lumi mapping package not available. Keeping probe IDs as gene-expression features.
protein_mapping_file <- "updated_proteomics_mapping.csv"

if (file.exists(protein_mapping_file)) {
  protein_map <- read.csv(protein_mapping_file, check.names = FALSE)
  
  # Make column names R-friendly for easier access.
  colnames(protein_map) <- make.names(colnames(protein_map))
  
  # Expected columns after make.names(): UniProt, EntrezGeneSymbol, Public.Name, Target
  protein_map <- protein_map %>%
    mutate(
      Protein_Label = dplyr::case_when(
        !is.na(EntrezGeneSymbol) & EntrezGeneSymbol != "" ~ EntrezGeneSymbol,
        !is.na(Public.Name) & Public.Name != "" ~ Public.Name,
        !is.na(Target) & Target != "" ~ Target,
        TRUE ~ UniProt
      )
    ) %>%
    distinct(UniProt, .keep_all = TRUE)
  
  matched_protein <- match(colnames(proteomics_clean), protein_map$UniProt)
  new_protein_names <- ifelse(
    is.na(matched_protein),
    colnames(proteomics_clean),
    protein_map$Protein_Label[matched_protein]
  )
  
  colnames(proteomics_clean) <- make.names(new_protein_names, unique = FALSE)
  

  proteomics_clean <- as.data.frame(t(limma::avereps(t(as.matrix(proteomics_clean)))))
  colnames(proteomics_clean) <- make.names(colnames(proteomics_clean), unique = TRUE)
  
  cat("Proteomics mapping file found. Proteins mapped:", sum(!is.na(matched_protein)), "of", length(matched_protein), "\n")
} else {
  
  colnames(proteomics_clean) <- make.names(colnames(proteomics_clean), unique = TRUE)
  cat("Proteomics mapping file not found. Keeping UniProt IDs as protein feature names.\n")
}
## Proteomics mapping file found. Proteins mapped: 974 of 1016
cat("Final gene-expression feature count:", ncol(gene_expr_clean), "\n")
## Final gene-expression feature count: 5212
cat("Final proteomics feature count:     ", ncol(proteomics_clean), "\n")
## Final proteomics feature count:      1015

5 5. Align samples and build clinical feature matrix

common_ids <- Reduce(intersect, list(
  rownames(clinical_raw),
  rownames(gene_expr_clean),
  rownames(proteomics_clean)
))

clinical_common   <- clinical_raw[common_ids, , drop = FALSE]
gene_expr_common  <- gene_expr_clean[common_ids, , drop = FALSE]
proteomics_common <- proteomics_clean[common_ids, , drop = FALSE]

cat("Common samples across three datasets:", length(common_ids), "\n")
## Common samples across three datasets: 339
cat("Visit distribution:\n")
## Visit distribution:
print(table(clinical_common$Visit))
## 
##   1 
## 339
# Baseline visit only.
baseline_ids <- rownames(clinical_common)[clinical_common$Visit == 1]
clinical_final   <- clinical_common[baseline_ids, , drop = FALSE]
gene_expr_final  <- gene_expr_common[baseline_ids, , drop = FALSE]
proteomics_final <- proteomics_common[baseline_ids, , drop = FALSE]

# Keep AD/MCI/CTL only.
keep <- clinical_final$Diagnosis %in% c("CTL", "MCI", "AD")
clinical_final   <- clinical_final[keep, , drop = FALSE]
gene_expr_final  <- gene_expr_final[keep, , drop = FALSE]
proteomics_final <- proteomics_final[keep, , drop = FALSE]

status_final <- factor(clinical_final$Diagnosis, levels = c("CTL", "MCI", "AD"))

# Clinical feature matrix. 
sex_lower <- tolower(trimws(clinical_final$Sex))
clinical_features <- data.frame(
  age = as.numeric(clinical_final$Age),
  sex_male = ifelse(is.na(sex_lower), NA, ifelse(sex_lower == "female", 0L, 1L)),
  row.names = rownames(clinical_final)
)

if ("APOE" %in% colnames(clinical_final)) {
  clinical_features$apoe_e4 <- ifelse(
    is.na(clinical_final$APOE), NA,
    as.integer(grepl("E4", toupper(clinical_final$APOE)))
  )
}

if ("Fulltime_Education_Years" %in% colnames(clinical_final)) {
  clinical_features$education <- as.numeric(clinical_final$Fulltime_Education_Years)
}

if ("Father_Dementia" %in% colnames(clinical_final)) {
  clinical_features$father_dementia <- as.integer(tolower(clinical_final$Father_Dementia) == "yes")
}

if ("Mother_Dementia" %in% colnames(clinical_final)) {
  clinical_features$mother_dementia <- as.integer(tolower(clinical_final$Mother_Dementia) == "yes")
}

# Remove samples with missing clinical features. Omics missing values are handled inside CV.
na_clin <- rowSums(is.na(clinical_features)) > 0
if (any(na_clin)) {
  clinical_features <- clinical_features[!na_clin, , drop = FALSE]
  gene_expr_final <- gene_expr_final[!na_clin, , drop = FALSE]
  proteomics_final <- proteomics_final[!na_clin, , drop = FALSE]
  status_final <- status_final[!na_clin]
}

final_ids <- rownames(clinical_features)
batch_final <- batch_vec_all[final_ids]
site_final <- site_vec_all[final_ids]

cat("\nFinal sample distribution:\n")
## 
## Final sample distribution:
print(table(status_final))
## status_final
## CTL MCI  AD 
##  87  97 140
cat("Clinical features:", dim(clinical_features), "\n")
## Clinical features: 324 6
cat("Gene expression:  ", dim(gene_expr_final), "\n")
## Gene expression:   324 5212
cat("Proteomics:       ", dim(proteomics_final), "\n")
## Proteomics:        324 1015
cat("Batch distribution:\n")
## Batch distribution:
print(table(batch_final))
## batch_final
##   1   2 
## 263  61
stopifnot(all(rownames(clinical_features) == rownames(gene_expr_final)))
stopifnot(all(rownames(clinical_features) == rownames(proteomics_final)))
stopifnot(length(status_final) == nrow(clinical_features))

6 6. Exploratory PCA plots

# Gene expression PCA before ComBat, final aligned samples.
gene_mat <- as.matrix(gene_expr_final)
gv <- apply(gene_mat, 2, var, na.rm = TRUE)
gene_mat <- gene_mat[, gv > 0, drop = FALSE]

pca_g <- prcomp(gene_mat, scale. = TRUE)
ve_g <- round(summary(pca_g)$importance[2, 1:2] * 100, 1)

df_g <- data.frame(
  PC1 = pca_g$x[, 1],
  PC2 = pca_g$x[, 2],
  Batch = factor(batch_final),
  Site = factor(site_final),
  Diagnosis = status_final
)

p_g_batch <- ggplot(df_g, aes(PC1, PC2, color = Batch)) +
  geom_point(size = 2, alpha = 0.75) +
  theme_bw() +
  labs(title = "Gene expression PCA by batch",
       x = paste0("PC1 (", ve_g[1], "%)"),
       y = paste0("PC2 (", ve_g[2], "%)"))

p_g_diag <- ggplot(df_g, aes(PC1, PC2, color = Diagnosis)) +
  geom_point(size = 2, alpha = 0.75) +
  theme_bw() +
  labs(title = "Gene expression PCA by diagnosis",
       x = paste0("PC1 (", ve_g[1], "%)"),
       y = paste0("PC2 (", ve_g[2], "%)"))

print(p_g_batch + p_g_diag)

ggsave("pca_gene_final.png", p_g_batch + p_g_diag, width = 12, height = 5, dpi = 150)

# Proteomics PCA.
prot_mat <- as.matrix(proteomics_final)
pv <- apply(prot_mat, 2, var, na.rm = TRUE)
prot_mat <- prot_mat[, pv > 0, drop = FALSE]

pca_p <- prcomp(prot_mat, scale. = TRUE)
ve_p <- round(summary(pca_p)$importance[2, 1:2] * 100, 1)

df_p <- data.frame(
  PC1 = pca_p$x[, 1],
  PC2 = pca_p$x[, 2],
  Site = factor(site_final),
  Diagnosis = status_final
)

p_p_site <- ggplot(df_p, aes(PC1, PC2, color = Site)) +
  geom_point(size = 2, alpha = 0.75) +
  theme_bw() +
  labs(title = "Proteomics PCA by site",
       x = paste0("PC1 (", ve_p[1], "%)"),
       y = paste0("PC2 (", ve_p[2], "%)"))

p_p_diag <- ggplot(df_p, aes(PC1, PC2, color = Diagnosis)) +
  geom_point(size = 2, alpha = 0.75) +
  theme_bw() +
  labs(title = "Proteomics PCA by diagnosis",
       x = paste0("PC1 (", ve_p[1], "%)"),
       y = paste0("PC2 (", ve_p[2], "%)"))

print(p_p_site + p_p_diag)

ggsave("pca_proteomics_final.png", p_p_site + p_p_diag, width = 12, height = 5, dpi = 150)

7 7. Nested 5-fold CV: multiclass SVM

This is the main modelling section.

remove_zero_var <- function(x) {
  x <- as.data.frame(x)
  v <- apply(x, 2, var, na.rm = TRUE)
  x[, is.finite(v) & v > 0, drop = FALSE]
}

median_impute_fit <- function(x) {
  x <- as.data.frame(x)
  meds <- sapply(x, function(z) median(z, na.rm = TRUE))
  meds[is.na(meds)] <- 0
  meds
}

median_impute_apply <- function(x, meds) {
  x <- as.data.frame(x)
  common <- intersect(colnames(x), names(meds))
  x <- x[, common, drop = FALSE]
  meds <- meds[common]
  for (j in seq_along(common)) {
    idx <- is.na(x[[j]])
    if (any(idx)) x[[j]][idx] <- meds[j]
  }
  x
}

scale_fit <- function(x) {
  x <- as.data.frame(x)
  mu <- sapply(x, mean, na.rm = TRUE)
  sdv <- sapply(x, sd, na.rm = TRUE)
  sdv[is.na(sdv) | sdv == 0] <- 1
  list(mean = mu, sd = sdv)
}

scale_apply <- function(x, scaler) {
  x <- as.data.frame(x)
  common <- intersect(colnames(x), names(scaler$mean))
  x <- x[, common, drop = FALSE]
  out <- sweep(as.matrix(x), 2, scaler$mean[common], "-")
  out <- sweep(out, 2, scaler$sd[common], "/")
  as.data.frame(out)
}

# Multinomial LASSO feature selection for gene expression.
select_genes_lasso <- function(x_train_gene, y_train, max_genes = 80) {
  x_train_gene <- remove_zero_var(x_train_gene)
  if (ncol(x_train_gene) < 2) return(colnames(x_train_gene))
  
  meds <- median_impute_fit(x_train_gene)
  x_imp <- median_impute_apply(x_train_gene, meds)
  sc <- scale_fit(x_imp)
  x_scaled <- scale_apply(x_imp, sc)
  
  fit <- tryCatch({
    cv.glmnet(
      x = as.matrix(x_scaled),
      y = y_train,
      family = "multinomial",
      alpha = 1,
      type.measure = "class",
      nfolds = 5,
      standardize = FALSE
    )
  }, error = function(e) NULL)
  
  if (is.null(fit)) {
    v <- sort(apply(x_train_gene, 2, var, na.rm = TRUE), decreasing = TRUE)
    return(names(v)[1:min(max_genes, length(v))])
  }
  
  coefs <- coef(fit, s = "lambda.min")
  selected <- unique(unlist(lapply(coefs, function(m) rownames(m)[as.vector(m != 0)])))
  selected <- setdiff(selected, "(Intercept)")
  
  if (length(selected) == 0) {
    v <- sort(apply(x_train_gene, 2, var, na.rm = TRUE), decreasing = TRUE)
    selected <- names(v)[1:min(max_genes, length(v))]
  }
  
  selected[1:min(max_genes, length(selected))]
}

make_feature_set <- function(feature_name, clinical_x, prot_x, gene_x, selected_genes = NULL) {
  if (feature_name == "clinical") {
    out <- clinical_x
  } else if (feature_name == "protein") {
    out <- prot_x
  } else if (feature_name == "gene") {
    out <- gene_x[, selected_genes, drop = FALSE]
  } else if (feature_name == "combined") {
    out <- cbind(
      clinical_x,
      prot_x,
      gene_x[, selected_genes, drop = FALSE]
    )
  } else {
    stop("Unknown feature set")
  }
  out <- as.data.frame(out)
  colnames(out) <- make.names(colnames(out), unique = TRUE)
  out
}

prepare_train_valid <- function(x_train, x_valid) {
  x_train <- remove_zero_var(x_train)
  x_valid <- x_valid[, colnames(x_train), drop = FALSE]
  meds <- median_impute_fit(x_train)
  x_train_imp <- median_impute_apply(x_train, meds)
  x_valid_imp <- median_impute_apply(x_valid, meds)
  sc <- scale_fit(x_train_imp)
  x_train_scaled <- scale_apply(x_train_imp, sc)
  x_valid_scaled <- scale_apply(x_valid_imp, sc)
  list(train = x_train_scaled, valid = x_valid_scaled, medians = meds, scaler = sc)
}

macro_f1 <- function(pred, truth) {
  lv <- levels(truth)
  f1s <- sapply(lv, function(cl) {
    tp <- sum(pred == cl & truth == cl)
    fp <- sum(pred == cl & truth != cl)
    fn <- sum(pred != cl & truth == cl)
    precision <- ifelse(tp + fp == 0, NA, tp / (tp + fp))
    recall <- ifelse(tp + fn == 0, NA, tp / (tp + fn))
    if (is.na(precision) || is.na(recall) || precision + recall == 0) return(NA)
    2 * precision * recall / (precision + recall)
  })
  mean(f1s, na.rm = TRUE)
}

balanced_accuracy <- function(pred, truth) {
  lv <- levels(truth)
  sens <- sapply(lv, function(cl) {
    tp <- sum(pred == cl & truth == cl)
    fn <- sum(pred != cl & truth == cl)
    ifelse(tp + fn == 0, NA, tp / (tp + fn))
  })
  mean(sens, na.rm = TRUE)
}

per_class_metrics <- function(pred, truth) {
  lv <- levels(truth)
  do.call(rbind, lapply(lv, function(cl) {
    tp <- sum(pred == cl & truth == cl)
    tn <- sum(pred != cl & truth != cl)
    fp <- sum(pred == cl & truth != cl)
    fn <- sum(pred != cl & truth == cl)
    data.frame(
      class = cl,
      sensitivity = ifelse(tp + fn == 0, NA, tp / (tp + fn)),
      specificity = ifelse(tn + fp == 0, NA, tn / (tn + fp)),
      precision = ifelse(tp + fp == 0, NA, tp / (tp + fp)),
      f1 = {
        prec <- ifelse(tp + fp == 0, NA, tp / (tp + fp))
        rec <- ifelse(tp + fn == 0, NA, tp / (tp + fn))
        ifelse(is.na(prec) | is.na(rec) | prec + rec == 0, NA, 2 * prec * rec / (prec + rec))
      }
    )
  }))
}

evaluate_pred <- function(pred, truth) {
  data.frame(
    accuracy = mean(pred == truth),
    macro_f1 = macro_f1(pred, truth),
    balanced_accuracy = balanced_accuracy(pred, truth)
  )
}
# Main nested CV settings
feature_sets <- c("clinical", "protein", "gene", "combined")
svm_grid <- expand.grid(
  cost = c(0.1, 1, 10),
  gamma = c(0.001, 0.01, 0.1)
)

outer_k <- 5
inner_k <- 5
max_genes <- 80

outer_folds <- caret::createFolds(status_final, k = outer_k, returnTrain = FALSE)

all_outer_predictions <- list()
all_outer_metrics <- list()
all_inner_results <- list()

for (outer_i in seq_along(outer_folds)) {
  cat("\n==============================\n")
  cat("Outer fold", outer_i, "of", outer_k, "\n")
  cat("==============================\n")
  
  test_idx <- outer_folds[[outer_i]]
  dev_idx <- setdiff(seq_along(status_final), test_idx)
  
  y_dev <- status_final[dev_idx]
  y_test <- status_final[test_idx]
  
  clinical_dev <- clinical_features[dev_idx, , drop = FALSE]
  clinical_test <- clinical_features[test_idx, , drop = FALSE]
  
  prot_dev <- proteomics_final[dev_idx, , drop = FALSE]
  prot_test <- proteomics_final[test_idx, , drop = FALSE]
  
  gene_dev_raw <- gene_expr_final[dev_idx, , drop = FALSE]
  gene_test_raw <- gene_expr_final[test_idx, , drop = FALSE]
  
  # B. ComBat on dev + test jointly, mod = NULL, no diagnosis labels.
  gene_outer <- rbind(gene_dev_raw, gene_test_raw)
  batch_outer <- batch_final[c(dev_idx, test_idx)]
  
  if (length(unique(batch_outer)) > 1) {
    gene_outer_cb <- t(sva::ComBat(
      dat = t(as.matrix(gene_outer)),
      batch = batch_outer,
      mod = NULL,
      par.prior = TRUE,
      prior.plots = FALSE
    ))
  } else {
    gene_outer_cb <- as.matrix(gene_outer)
  }
  
  gene_dev <- as.data.frame(gene_outer_cb[seq_along(dev_idx), , drop = FALSE])
  gene_test <- as.data.frame(gene_outer_cb[(length(dev_idx) + 1):nrow(gene_outer_cb), , drop = FALSE])
  rownames(gene_dev) <- rownames(gene_dev_raw)
  rownames(gene_test) <- rownames(gene_test_raw)
  colnames(gene_dev) <- colnames(gene_expr_final)
  colnames(gene_test) <- colnames(gene_expr_final)
  
  inner_folds <- caret::createFolds(y_dev, k = inner_k, returnTrain = FALSE)
  
  for (fs in feature_sets) {
    cat("\nFeature set:", fs, "\n")
    
    inner_scores <- list()
    
    for (inner_i in seq_along(inner_folds)) {
      val_local <- inner_folds[[inner_i]]
      train_local <- setdiff(seq_along(y_dev), val_local)
      
      y_inner_train <- y_dev[train_local]
      y_inner_val <- y_dev[val_local]
      
      # C2. LASSO on inner_train only for gene-containing models.
      selected_genes <- NULL
      if (fs %in% c("gene", "combined")) {
        selected_genes <- select_genes_lasso(
          x_train_gene = gene_dev[train_local, , drop = FALSE],
          y_train = y_inner_train,
          max_genes = max_genes
        )
      }
      
      x_inner_train <- make_feature_set(
        fs,
        clinical_x = clinical_dev[train_local, , drop = FALSE],
        prot_x = prot_dev[train_local, , drop = FALSE],
        gene_x = gene_dev[train_local, , drop = FALSE],
        selected_genes = selected_genes
      )
      
      x_inner_val <- make_feature_set(
        fs,
        clinical_x = clinical_dev[val_local, , drop = FALSE],
        prot_x = prot_dev[val_local, , drop = FALSE],
        gene_x = gene_dev[val_local, , drop = FALSE],
        selected_genes = selected_genes
      )
      
      prep <- prepare_train_valid(x_inner_train, x_inner_val)
      
      for (g in seq_len(nrow(svm_grid))) {
        model <- e1071::svm(
          x = prep$train,
          y = y_inner_train,
          kernel = "radial",
          cost = svm_grid$cost[g],
          gamma = svm_grid$gamma[g],
          probability = TRUE
        )
        pred <- predict(model, prep$valid)
        met <- evaluate_pred(factor(pred, levels = levels(status_final)), y_inner_val)
        
        inner_scores[[length(inner_scores) + 1]] <- data.frame(
          outer_fold = outer_i,
          inner_fold = inner_i,
          feature_set = fs,
          cost = svm_grid$cost[g],
          gamma = svm_grid$gamma[g],
          n_features = ncol(prep$train),
          n_selected_genes = ifelse(is.null(selected_genes), 0, length(selected_genes)),
          met
        )
      }
    }
    
    inner_df <- bind_rows(inner_scores)
    all_inner_results[[length(all_inner_results) + 1]] <- inner_df
    
    # C5. Average over inner folds and select best hyperparameters.
    best_row <- inner_df %>%
      group_by(feature_set, cost, gamma) %>%
      summarise(
        mean_macro_f1 = mean(macro_f1, na.rm = TRUE),
        mean_accuracy = mean(accuracy, na.rm = TRUE),
        .groups = "drop"
      ) %>%
      arrange(desc(mean_macro_f1), desc(mean_accuracy)) %>%
      slice(1)
    
    best_cost <- best_row$cost[1]
    best_gamma <- best_row$gamma[1]
    cat("Best cost =", best_cost, "| best gamma =", best_gamma, "\n")
    
    # D. Re-run LASSO on full dev only.
    selected_genes_final <- NULL
    if (fs %in% c("gene", "combined")) {
      selected_genes_final <- select_genes_lasso(gene_dev, y_dev, max_genes = max_genes)
      cat("Final selected genes:", length(selected_genes_final), "\n")
    }
    
    x_dev_final <- make_feature_set(
      fs,
      clinical_x = clinical_dev,
      prot_x = prot_dev,
      gene_x = gene_dev,
      selected_genes = selected_genes_final
    )
    
    x_test_final <- make_feature_set(
      fs,
      clinical_x = clinical_test,
      prot_x = prot_test,
      gene_x = gene_test,
      selected_genes = selected_genes_final
    )
    
    prep_final <- prepare_train_valid(x_dev_final, x_test_final)
    
    # E. Final fit on full dev, predict locked-away outer test.
    final_model <- e1071::svm(
      x = prep_final$train,
      y = y_dev,
      kernel = "radial",
      cost = best_cost,
      gamma = best_gamma,
      probability = TRUE
    )
    
    pred_test <- predict(final_model, prep_final$valid)
    pred_test <- factor(pred_test, levels = levels(status_final))
    
    met_test <- evaluate_pred(pred_test, y_test)
    
    all_outer_predictions[[length(all_outer_predictions) + 1]] <- data.frame(
      sample_id = rownames(clinical_features)[test_idx],
      outer_fold = outer_i,
      feature_set = fs,
      truth = y_test,
      prediction = pred_test
    )
    
    all_outer_metrics[[length(all_outer_metrics) + 1]] <- data.frame(
      outer_fold = outer_i,
      feature_set = fs,
      cost = best_cost,
      gamma = best_gamma,
      n_features = ncol(prep_final$train),
      n_selected_genes = ifelse(is.null(selected_genes_final), 0, length(selected_genes_final)),
      met_test
    )
  }
}
## 
## ==============================
## Outer fold 1 of 5 
## ==============================
## 
## Feature set: clinical 
## Best cost = 0.1 | best gamma = 0.001 
## 
## Feature set: protein 
## Best cost = 0.1 | best gamma = 0.001 
## 
## Feature set: gene 
## Best cost = 0.1 | best gamma = 0.001 
## Final selected genes: 7 
## 
## Feature set: combined 
## Best cost = 0.1 | best gamma = 0.001 
## Final selected genes: 17 
## 
## ==============================
## Outer fold 2 of 5 
## ==============================
## 
## Feature set: clinical 
## Best cost = 0.1 | best gamma = 0.001 
## 
## Feature set: protein 
## Best cost = 0.1 | best gamma = 0.001 
## 
## Feature set: gene 
## Best cost = 0.1 | best gamma = 0.001 
## Final selected genes: 16 
## 
## Feature set: combined 
## Best cost = 0.1 | best gamma = 0.001 
## Final selected genes: 12 
## 
## ==============================
## Outer fold 3 of 5 
## ==============================
## 
## Feature set: clinical 
## Best cost = 0.1 | best gamma = 0.001 
## 
## Feature set: protein 
## Best cost = 0.1 | best gamma = 0.001 
## 
## Feature set: gene 
## Best cost = 0.1 | best gamma = 0.001 
## Final selected genes: 18 
## 
## Feature set: combined 
## Best cost = 0.1 | best gamma = 0.001 
## Final selected genes: 25 
## 
## ==============================
## Outer fold 4 of 5 
## ==============================
## 
## Feature set: clinical 
## Best cost = 0.1 | best gamma = 0.001 
## 
## Feature set: protein 
## Best cost = 0.1 | best gamma = 0.001 
## 
## Feature set: gene 
## Best cost = 0.1 | best gamma = 0.001 
## Final selected genes: 45 
## 
## Feature set: combined 
## Best cost = 0.1 | best gamma = 0.001 
## Final selected genes: 17 
## 
## ==============================
## Outer fold 5 of 5 
## ==============================
## 
## Feature set: clinical 
## Best cost = 0.1 | best gamma = 0.001 
## 
## Feature set: protein 
## Best cost = 0.1 | best gamma = 0.001 
## 
## Feature set: gene 
## Best cost = 0.1 | best gamma = 0.1 
## Final selected genes: 5 
## 
## Feature set: combined 
## Best cost = 0.1 | best gamma = 0.001 
## Final selected genes: 19
outer_predictions <- bind_rows(all_outer_predictions)
outer_metrics <- bind_rows(all_outer_metrics)
inner_results <- bind_rows(all_inner_results)

write.csv(outer_predictions, "nested_svm_outer_predictions.csv", row.names = FALSE)
write.csv(outer_metrics, "nested_svm_outer_metrics.csv", row.names = FALSE)
write.csv(inner_results, "nested_svm_inner_results.csv", row.names = FALSE)

outer_metrics
##    outer_fold feature_set cost gamma n_features n_selected_genes  accuracy
## 1           1    clinical  0.1 0.001          6                0 0.4307692
## 2           1     protein  0.1 0.001       1015                0 0.4307692
## 3           1        gene  0.1 0.001          7                7 0.4307692
## 4           1    combined  0.1 0.001       1038               17 0.4307692
## 5           2    clinical  0.1 0.001          6                0 0.4307692
## 6           2     protein  0.1 0.001       1015                0 0.4307692
## 7           2        gene  0.1 0.001         16               16 0.4307692
## 8           2    combined  0.1 0.001       1033               12 0.4307692
## 9           3    clinical  0.1 0.001          6                0 0.4375000
## 10          3     protein  0.1 0.001       1015                0 0.4375000
## 11          3        gene  0.1 0.001         18               18 0.4375000
## 12          3    combined  0.1 0.001       1046               25 0.4375000
## 13          4    clinical  0.1 0.001          6                0 0.4242424
## 14          4     protein  0.1 0.001       1015                0 0.4242424
## 15          4        gene  0.1 0.001         45               45 0.4242424
## 16          4    combined  0.1 0.001       1038               17 0.4242424
## 17          5    clinical  0.1 0.001          6                0 0.4375000
## 18          5     protein  0.1 0.001       1015                0 0.4375000
## 19          5        gene  0.1 0.100          5                5 0.5312500
## 20          5    combined  0.1 0.001       1040               19 0.4375000
##     macro_f1 balanced_accuracy
## 1  0.6021505         0.3333333
## 2  0.6021505         0.3333333
## 3  0.6021505         0.3333333
## 4  0.6021505         0.3333333
## 5  0.6021505         0.3333333
## 6  0.6021505         0.3333333
## 7  0.6021505         0.3333333
## 8  0.6021505         0.3333333
## 9  0.6086957         0.3333333
## 10 0.6086957         0.3333333
## 11 0.6086957         0.3333333
## 12 0.6086957         0.3333333
## 13 0.5957447         0.3333333
## 14 0.5957447         0.3333333
## 15 0.5957447         0.3333333
## 16 0.5957447         0.3333333
## 17 0.6086957         0.3333333
## 18 0.6086957         0.3333333
## 19 0.6164557         0.4740896
## 20 0.6086957         0.3333333

8 8. Aggregate performance across outer folds

summary_metrics <- outer_metrics %>%
  group_by(feature_set) %>%
  summarise(
    accuracy_mean = mean(accuracy, na.rm = TRUE),
    accuracy_sd = sd(accuracy, na.rm = TRUE),
    macro_f1_mean = mean(macro_f1, na.rm = TRUE),
    macro_f1_sd = sd(macro_f1, na.rm = TRUE),
    balanced_accuracy_mean = mean(balanced_accuracy, na.rm = TRUE),
    balanced_accuracy_sd = sd(balanced_accuracy, na.rm = TRUE),
    mean_n_features = mean(n_features, na.rm = TRUE),
    .groups = "drop"
  ) %>%
  arrange(desc(macro_f1_mean))

summary_metrics
## # A tibble: 4 × 8
##   feature_set accuracy_mean accuracy_sd macro_f1_mean macro_f1_sd
##   <chr>               <dbl>       <dbl>         <dbl>       <dbl>
## 1 gene                0.451     0.0452          0.605     0.00785
## 2 clinical            0.432     0.00556         0.603     0.00543
## 3 combined            0.432     0.00556         0.603     0.00543
## 4 protein             0.432     0.00556         0.603     0.00543
## # ℹ 3 more variables: balanced_accuracy_mean <dbl>, balanced_accuracy_sd <dbl>,
## #   mean_n_features <dbl>
write.csv(summary_metrics, "nested_svm_summary_metrics.csv", row.names = FALSE)
plot_df <- summary_metrics %>%
  select(feature_set, macro_f1_mean, macro_f1_sd, accuracy_mean, accuracy_sd) %>%
  pivot_longer(
    cols = c(macro_f1_mean, accuracy_mean),
    names_to = "metric",
    values_to = "mean"
  ) %>%
  mutate(
    sd = ifelse(metric == "macro_f1_mean", macro_f1_sd, accuracy_sd),
    metric = recode(metric,
                    macro_f1_mean = "Macro F1",
                    accuracy_mean = "Accuracy")
  )

ggplot(plot_df, aes(x = feature_set, y = mean, fill = metric)) +
  geom_col(position = position_dodge(width = 0.8), width = 0.7) +
  geom_errorbar(aes(ymin = mean - sd, ymax = mean + sd),
                position = position_dodge(width = 0.8), width = 0.2) +
  theme_bw() +
  labs(title = "Nested 5-fold CV SVM performance",
       subtitle = "Mean ± SD across outer folds",
       x = "Feature set",
       y = "Score")

ggsave("nested_svm_performance_barplot.png", width = 9, height = 5, dpi = 150)

9 9. Confusion matrices and sensitivity/specificity

conf_df <- outer_predictions %>%
  group_by(feature_set, truth, prediction) %>%
  summarise(n = n(), .groups = "drop")

ggplot(conf_df, aes(x = prediction, y = truth, fill = n)) +
  geom_tile() +
  geom_text(aes(label = n), color = "white", size = 4) +
  facet_wrap(~ feature_set) +
  theme_bw() +
  labs(title = "Outer-fold aggregated confusion matrices",
       x = "Predicted class",
       y = "True class")

ggsave("nested_svm_confusion_matrices.png", width = 8, height = 5, dpi = 150)

per_class_all <- outer_predictions %>%
  group_by(feature_set) %>%
  group_modify(~ per_class_metrics(
    pred = factor(.x$prediction, levels = levels(status_final)),
    truth = factor(.x$truth, levels = levels(status_final))
  )) %>%
  ungroup()

per_class_all
## # A tibble: 12 × 6
##    feature_set class sensitivity specificity precision     f1
##    <chr>       <chr>       <dbl>       <dbl>     <dbl>  <dbl>
##  1 clinical    CTL         0          1         NA     NA    
##  2 clinical    MCI         0          1         NA     NA    
##  3 clinical    AD          1          0          0.432  0.603
##  4 combined    CTL         0          1         NA     NA    
##  5 combined    MCI         0          1         NA     NA    
##  6 combined    AD          1          0          0.432  0.603
##  7 gene        CTL         0.103      0.983      0.692  0.18 
##  8 gene        MCI         0          1         NA     NA    
##  9 gene        AD          0.979      0.0543     0.441  0.608
## 10 protein     CTL         0          1         NA     NA    
## 11 protein     MCI         0          1         NA     NA    
## 12 protein     AD          1          0          0.432  0.603
write.csv(per_class_all, "nested_svm_per_class_metrics.csv", row.names = FALSE)
ggplot(per_class_all, aes(x = class, y = feature_set, fill = sensitivity)) +
  geom_tile() +
  geom_text(aes(label = round(sensitivity, 2)), color = "white") +
  theme_bw() +
  labs(title = "Per-class sensitivity heatmap",
       x = "Class",
       y = "Feature set")

ggsave("nested_svm_sensitivity_heatmap.png", width = 7, height = 4, dpi = 150)

10 10. One-vs-rest ROC curves for the best feature set

# e1071 probability estimates inside nested CV are not saved above.
# For honest final reporting, use confusion matrix, macro F1, accuracy, sensitivity, specificity.
# If a ROC plot is required, it should be generated inside each outer fold from saved probabilities.
# This template focuses on the metrics requested in the workflow figure.

cat("Best feature set by Macro F1:\n")
## Best feature set by Macro F1:
print(summary_metrics[1, ])
## # A tibble: 1 × 8
##   feature_set accuracy_mean accuracy_sd macro_f1_mean macro_f1_sd
##   <chr>               <dbl>       <dbl>         <dbl>       <dbl>
## 1 gene                0.451      0.0452         0.605     0.00785
## # ℹ 3 more variables: balanced_accuracy_mean <dbl>, balanced_accuracy_sd <dbl>,
## #   mean_n_features <dbl>

11 11. Workflow

cat("Figure 1 workflow:\n")
## Figure 1 workflow:
cat("\nRaw clinical, proteomics and gene-expression data")
## 
## Raw clinical, proteomics and gene-expression data
cat("\n→ remove embedded clinical columns, log2-transform proteomics, map probes to gene symbols")
## 
## → remove embedded clinical columns, log2-transform proteomics, map probes to gene symbols
cat("\n→ align baseline AD/MCI/CTL samples")
## 
## → align baseline AD/MCI/CTL samples
cat("\n→ outer 5-fold CV")
## 
## → outer 5-fold CV
cat("\n   → for each outer fold: dev/test split")
## 
##    → for each outer fold: dev/test split
cat("\n   → ComBat batch correction on gene expression using batch only, no diagnosis labels")
## 
##    → ComBat batch correction on gene expression using batch only, no diagnosis labels
cat("\n   → inner 5-fold CV on dev")
## 
##    → inner 5-fold CV on dev
cat("\n      → LASSO feature selection on inner_train only")
## 
##       → LASSO feature selection on inner_train only
cat("\n      → standardisation on inner_train only")
## 
##       → standardisation on inner_train only
cat("\n      → tune SVM C and gamma")
## 
##       → tune SVM C and gamma
cat("\n   → refit LASSO and SVM on full dev")
## 
##    → refit LASSO and SVM on full dev
cat("\n   → predict locked-away outer test fold")
## 
##    → predict locked-away outer test fold
cat("\n→ aggregate accuracy, macro F1, sensitivity and specificity across 5 outer folds\n")
## 
## → aggregate accuracy, macro F1, sensitivity and specificity across 5 outer folds