0. SETUP ###================================================================

## 0.1 SET PATHS ---------------------------------------------------------------
# Project root; every other directory below is resolved relative to it.
root_path <- "C:/Users/alden/Dissertation/TP53_isoforms_project"

# Working directories: generated outputs, raw inputs, and helper scripts.
output_dir  <- file.path(root_path, "output")
input_dir   <- file.path(root_path, "input")
scripts_dir <- file.path(root_path, "scripts")

## 0.2 OTHER -------------------------------------------------------------------
# Fix the RNG seed so stochastic steps later in the script are reproducible.
set.seed(seed = 42)

1. INSTALL PACKAGES ###=====================================================

## 1.1 INSTALL PACKAGES --------------------------------------------------------
# BiocManager comes first: it installs both CRAN and Bioconductor packages in
# the step below.
if (!requireNamespace("BiocManager", quietly = TRUE)) {
  install.packages("BiocManager")
}

# Every package this script attaches.
# "patchwork" and "vcd" are listed because the Bayesian-network plotting code
# further down calls patchwork::wrap_plots()/plot_annotation() and the vcd
# mosaic()/labeling_border() functions; without them that code cannot run.
packages <- c(
  "BiocManager",
  "MultiAssayExperiment",
  "SummarizedExperiment",
  "GenomicRanges",
  "snapcount",
  "edgeR",
  "limma",
  "GSVA",
  "BiocParallel",
  "MOFA2",
  "factoextra",
  "infotheo",
  "bnlearn",
  "survival",
  "gridExtra",
  "dplyr",
  "tidyr",
  "ggplot2",
  "stringr",
  "tibble",
  "data.table",
  "readxl",
  "patchwork",
  "vcd"
)

# Packages from the list that are not installed yet.
missing_pkgs <- packages[!(packages %in% rownames(installed.packages()))]

# Install only what is missing; never touch already-installed packages.
if (length(missing_pkgs) > 0) {
  BiocManager::install(missing_pkgs, update = FALSE, ask = FALSE)
}

# Attach everything (library() return values are suppressed).
invisible(lapply(packages, library, character.only = TRUE))

PART A - BUILDING THE MAE

2. GET ISOFORM COUNTS ###===================================================

## 2.1 DEFINE TARGET REGION ----------------------------------------------------
# chr17 minus-strand intervals for the TP53 regions of interest. `start`, `end`
# and `label` are parallel vectors: element i of each describes one region.
targets <- GenomicRanges::GRanges(
  seqnames = "chr17",
  ranges = IRanges(
    start = c(7673267, 7673207, 7685260, 7675239, 7674526),
    end   = c(7673339, 7673266, 7686371, 7675493, 7674858)
  ),
  strand = "-",
  label = c("exon 9B", "exon 9y", "intron 2", "intron 4", "intron 6")
)

## 2.2 GET DATA ----------------------------------------------------------------
# Query exon-level counts for TP53 across the whole TCGA compilation and get
# the result back as a RangedSummarizedExperiment.
qb <- snapcount::QueryBuilder(compilation = "tcga", regions = "TP53")
rse <- snapcount::query_exon(qb, return_rse = TRUE)

# Keep only features whose coordinates are exactly equal to a target interval.
hits <- findOverlaps(targets, rowRanges(rse), type = "equal")

# Pull the matched rows (drop = FALSE keeps a 2-D object even for one hit),
# coerce to a base matrix, and label each row with its target's name.
filtered_counts <- assay(rse)[subjectHits(hits), , drop = FALSE]
filtered_counts <- as.matrix(filtered_counts)
rownames(filtered_counts) <- targets$label[queryHits(hits)]

# `filtered_counts` is already a base matrix, so it is used as-is.
isoform_counts <- filtered_counts

## 2.3 SUMMARISE + SAVE --------------------------------------------------------
# Preview the first 10 sample columns (at most 10 rows).
head(isoform_counts[, 1:10], n = 10)
         rail_59761 rail_59762 rail_59763 rail_59764 rail_59765
exon 9B        1262        335        275          0         59
exon 9y         851         99        373          0         41
intron 2      22219       1857      19857       2196       4192
intron 4       2422         48        152          0         48
intron 6       3312        126        153         78         92
         rail_59766 rail_59767 rail_59768 rail_59769 rail_59770
exon 9B          97        222        142        490          0
exon 9y         124         65         62        494          0
intron 2       7174       2275      16291       7602       1949
intron 4          0          0          0        209          0
intron 6         67        116         74        878          0
dim(isoform_counts)
[1]     5 11284
# Persist the labelled raw count matrix to the output directory.
saveRDS(
  object = isoform_counts,
  file = file.path(output_dir, "TP53_isoform_counts.rds")
)

3. GET METADATA ###=========================================================

## 3.1 DEFINE METADATA ---------------------------------------------------------
# Named character vector: names are the short column names used in this
# script, values are the corresponding column names in colData(rse).
meta_cols <- c(
  # Identifiers
  barcode_full = "gdc_cases.samples.portions.analytes.aliquots.submitter_id",
  project      = "gdc_cases.project.project_id",

  # Library information
  lib.size     = "mapped_read_count",
  a260_a280    = "gdc_cases.samples.portions.analytes.a260_a280_ratio",

  # BRCA hormone / HER2 receptor status
  er_status    = "xml_breast_carcinoma_estrogen_receptor_status",
  pr_status    = "xml_breast_carcinoma_progesterone_receptor_status",
  her2_ish     = "xml_lab_procedure_her2_neu_in_situ_hybrid_outcome_type",
  her2_ihc     = "xml_lab_proc_her2_neu_immunohistochemistry_receptor_status",

  # Therapy information
  drug_name    = "cgc_drug_therapy_drug_name",
  drug_type    = "cgc_drug_therapy_pharmaceutical_therapy_type"
)

## 3.2 EXTRACT METADATA --------------------------------------------------------
# select() with a named vector inside any_of() both keeps and renames the
# columns, so no separate rename step is needed.
meta <- as.data.frame(colData(rse)) %>%
  dplyr::select(any_of(meta_cols)) %>%
  dplyr::mutate(
    # Sample identifiers derived from the row names and the aliquot barcode:
    # the first 15 / 12 characters of the barcode and its 6th "-" field.
    rail_id            = rownames(.),
    patient_barcode_15 = substr(barcode_full, 1, 15),
    patient_barcode_12 = substr(barcode_full, 1, 12),
    batch              = stringr::str_split_i(barcode_full, "-", 6),

    # HER2 status: the ISH result wins when it is Positive/Negative, otherwise
    # the IHC result is used; anything else becomes "".
    her2_consolidated = dplyr::case_when(
      her2_ish %in% c("Positive", "Negative") ~ her2_ish,
      her2_ihc %in% c("Positive", "Negative") ~ her2_ihc,
      TRUE ~ ""
    ),

    # Hormone-receptor status: Positive if ER or PR is Positive, Negative only
    # if both are Negative; anything else becomes "".
    hr_consolidated = dplyr::case_when(
      er_status == "Positive" | pr_status == "Positive" ~ "Positive",
      er_status == "Negative" & pr_status == "Negative" ~ "Negative",
      TRUE ~ ""
    ),

    # Subtype label built from the consolidated HR/HER2 receptor status.
    # NOTE(review): these are receptor-status surrogate labels, not
    # expression-based PAM50 calls, although they reuse PAM50-style names.
    breast_cancer_subtype = dplyr::case_when(
      project != "TCGA-BRCA" ~ "Non-BRCA",
      hr_consolidated == "Positive" & her2_consolidated == "Negative" ~ "LuminalA",
      hr_consolidated == "Positive" & her2_consolidated == "Positive" ~ "LuminalB",
      hr_consolidated == "Negative" & her2_consolidated == "Positive" ~ "HER2-enriched",
      hr_consolidated == "Negative" & her2_consolidated == "Negative" ~ "Triple-negative",
      TRUE ~ "Unclassified BRCA"
    )
  ) %>%
  # Keep the first row for each 15-character barcode (de-duplicates aliquots).
  dplyr::distinct(patient_barcode_15, .keep_all = TRUE) %>%
  # Replace the row names with the 15-character barcode.
  tibble::remove_rownames() %>%
  tibble::column_to_rownames("patient_barcode_15")

## 3.3 UPDATE "isoform_counts" -------------------------------------------------

# Reorder the count columns to follow the row order of "meta" (matched through
# rail_id), then relabel them with the barcodes stored in rownames(meta).
isoform_counts <- isoform_counts[, meta$rail_id]
colnames(isoform_counts) <- rownames(meta)

## 3.4 SUBSET TO BRCA ----------------------------------------------------------

# Keep only TCGA-BRCA samples in the metadata ...
meta <- dplyr::filter(meta, project == "TCGA-BRCA")

# ... remember which samples those are ...
brca_samples <- rownames(meta)

# ... and restrict the count matrix to the same samples, in the same order.
isoform_counts <- isoform_counts[, brca_samples]

## 3.5 SUMMARISE ---------------------------------------------------------------
# Mean raw count of each region across the retained samples.
isoform_means <- rowMeans(isoform_counts)
print(isoform_means)
  exon 9B   exon 9y  intron 2  intron 4  intron 6 
 158.2591  115.4488 4274.9299  126.4282  224.2822 
# Inspect the BRCA-only metadata: first rows, then its (rows, columns) size.
head(meta)
dim(meta)
[1] 1212   16
# Inspect the BRCA-only count matrix: first 10 sample columns, at most 10 rows.
head(isoform_counts[, 1:10], 10)
         TCGA-A2-A0CL-01 TCGA-A8-A08G-01 TCGA-A2-A0D2-01
exon 9B              490               0             235
exon 9y              494               0             112
intron 2            7602            2266            8967
intron 4             209               0             153
intron 6             878              55              87
         TCGA-A7-A0DB-01 TCGA-BH-A0BT-11 TCGA-AN-A04D-01
exon 9B               68              63             204
exon 9y              111              87              73
intron 2            1194            4390            6909
intron 4              53              92             238
intron 6             174             118             362
         TCGA-BH-A0DO-11 TCGA-BH-A0B5-11 TCGA-AO-A0J3-01
exon 9B              368              78              29
exon 9y              236               5              55
intron 2            2632             699            6399
intron 4              50               0              50
intron 6             444             161              70
         TCGA-B6-A0RE-01
exon 9B               20
exon 9y              169
intron 2            3938
intron 4             100
intron 6             235
dim(isoform_counts)
[1]    5 1212

4. NORMALISATION ###========================================================

# Build the edgeR container. Library sizes are taken from meta$lib.size
# instead of being derived from the column sums of the count matrix.
dge <- edgeR::DGEList(
  counts   = isoform_counts,
  samples  = meta,
  lib.size = meta$lib.size
)

# Compute TMM scaling factors.
dge <- edgeR::calcNormFactors(object = dge, method = "TMM")

# Log-scale counts-per-million with a prior count of 1.
isoform_logcpm <- edgeR::cpm(dge, prior.count = 1, log = TRUE)

NOTE: Batch correction with limma::removeBatchEffect reduced the mutual information between the isoforms and survival, so no batch-correction step is included below.

6. PCA ###==================================================================

## 6.1 COMPUTE PCA -------------------------------------------------------------
# The logCPM matrix is transposed so that samples are rows and regions are the
# variables; each variable is scaled to unit variance before the PCA.
isoform_pca <- prcomp(x = t(isoform_logcpm), scale. = TRUE)

# Loadings of each region on the principal components.
head(isoform_pca$rotation, n = 10)
                 PC1        PC2        PC3         PC4          PC5
exon 9B  -0.59484692 -0.2692719  0.2526056 -0.09666798 -0.707457072
exon 9y  -0.59611159 -0.2607293  0.2434739 -0.13704115  0.706124211
intron 2 -0.07122502  0.7164767  0.6734204  0.16754286  0.004741838
intron 4 -0.32466305  0.5533393 -0.4727046 -0.60364507 -0.023928730
intron 6 -0.42465163  0.1999743 -0.4471758  0.76119448  0.017263268
## 6.2 SCREE PLOT --------------------------------------------------------------
# Line plot of the variance associated with each principal component, with
# value labels on the points. The stray trailing comma after the last argument
# (an empty argument) has been removed.
factoextra::fviz_eig(
  isoform_pca,
  choice = "variance",
  geom = "line",
  addlabels = TRUE
) +
  theme_minimal() +
  # Strip grid lines and the panel border; keep plain black axis lines.
  theme(
    panel.grid.major = element_blank(),
    panel.grid.minor = element_blank(),
    axis.line = element_line(colour = "black"),
    panel.border = element_blank()
  )


## 6.3 BIPLOTS -----------------------------------------------------------------

# Build a variable-loading plot for one pair of principal components.
#
# Arguments:
#   pca_obj  PCA result passed straight to factoextra::fviz_pca_var().
#   pc_x     Index of the component drawn on the x axis.
#   pc_y     Index of the component drawn on the y axis.
# Returns the styled plot object (variables coloured by contribution).
create_biplot <- function(pca_obj, pc_x, pc_y) {

  # One transparent rectangle reused for plot, panel and legend backgrounds.
  transparent_bg <- element_rect(fill = "transparent", color = NA)

  loading_plot <- factoextra::fviz_pca_var(
    pca_obj,
    axes = c(pc_x, pc_y),
    col.var = "contrib",
    gradient.cols = c("#006EAE", "#CA9B23", "#C5373D"),
    repel = TRUE,
    title = paste("PC", pc_x, "vs PC", pc_y)
  )

  # Minimal theme without grid lines, black axes, small bold title.
  loading_plot +
    theme_minimal() +
    theme(
      panel.grid.major = element_blank(),
      panel.grid.minor = element_blank(),
      axis.line = element_line(colour = "black"),
      plot.title = element_text(size = 10, face = "bold"),
      plot.background = transparent_bg,
      panel.background = transparent_bg,
      legend.background = transparent_bg
    )
}

# Build one loading plot per consecutive PC pair (PC1 vs PC2, PC2 vs PC3, ...),
# up to four plots. seq_len() is used so that a PCA with a single component
# yields zero pairs instead of the backwards sequence 1:0.
num_pcs <- ncol(isoform_pca$x)
n_pairs <- min(4, num_pcs - 1)
plot_list <- lapply(
  seq_len(n_pairs),
  function(i) create_biplot(isoform_pca, i, i + 1)
)

# Arrange the plots in a 2x2 grid under a common title. grid.arrange() draws
# the grid and also returns it, so it can be saved below.
combined_pca_plot <- gridExtra::grid.arrange(
  grobs = plot_list,
  ncol = 2,
  nrow = 2,
  top = grid::textGrob("TP53 Isoform PCA Loading Comparisons",
                       gp = grid::gpar(fontsize = 14, fontface = "bold"))
)


# Save the grid to the output directory with a transparent background.
ggplot2::ggsave(
  filename = file.path(output_dir, "TP53_PCA_grid_TRANSPARENT.png"),
  plot = combined_pca_plot,
  width = 10,
  height = 10,
  dpi = 300,
  bg = "transparent"
)

# 6.4 EXTRACT PATIENT SCORES ------------------------------------------------------------------
# From here on "isoform_pca" no longer holds the prcomp object: it is replaced
# by a data frame of per-sample PC scores (samples in rows, PCs in columns).
isoform_pca <- as.data.frame(isoform_pca$x)

7. LOAD TCGA DATA ###=======================================================

# Read one feature-by-sample table from `input_dir` and return it as a matrix
# restricted to the requested samples.
#
# Arguments:
#   file_name  File name inside `input_dir`. The first column is used as the
#              feature identifier; the remaining columns are samples.
#   samples    Sample identifiers to keep (matched against column names).
# Returns a matrix with row names taken from the first column. drop = FALSE
# keeps the result a matrix even when only one sample matches.
load_sample_matrix <- function(file_name, samples) {
  tbl <- data.table::fread(file.path(input_dir, file_name))
  mat <- as.matrix(tbl[, -1, with = FALSE], rownames = tbl[[1]])
  mat[, intersect(colnames(mat), samples), drop = FALSE]
}

## 7.1 RNASEQ ------------------------------------------------------------------
rnaseq_file <- "EB++AdjustPANCAN_IlluminaHiSeq_RNASeqV2.geneExp.xena.gz"
rnaseq <- load_sample_matrix(rnaseq_file, brca_samples)
# Keep only the first row for any duplicated feature identifier.
rnaseq <- rnaseq[!duplicated(rownames(rnaseq)), , drop = FALSE]

## 7.2 MUTATION ----------------------------------------------------------------
mutation_file <- "mc3.v0.2.8.PUBLIC.nonsilentGene.xena.gz"
mutation <- load_sample_matrix(mutation_file, brca_samples)

## 7.3 RPPA --------------------------------------------------------------------
rppa_file <- "TCGA-RPPA-pancan-clean.xena.gz"
rppa <- load_sample_matrix(rppa_file, brca_samples)

## 7.4 STEMNESS ----------------------------------------------------------------
stemness_file <- "StemnessScores_RNAexp_20170127.2.tsv.gz"
stemness <- load_sample_matrix(stemness_file, brca_samples)

## 7.5 HRD ---------------------------------------------------------------------
# The variable keeps its original name "hdr"; the data come from the HRD file
# and are stored under the name "hrd" when the MAE is built.
hdr_file <- "TCGA.HRD_withSampleID.txt.gz"
hdr <- load_sample_matrix(hdr_file, brca_samples)

## 7.6 IMMUNE ------------------------------------------------------------------
immune_subtype_file <- "TCGA_pancancer_10852whitelistsamples_68ImmuneSigs.xena.gz"
immune <- load_sample_matrix(immune_subtype_file, brca_samples)

8. LOAD TCGA METADATA ###===================================================

## 8.1 SURVIVAL DATA -----------------------------------------------------------
survival_file <- "Survival_SupplementalTable_S1_20171025_xena_sp"

# Read as a plain data.frame (data.table = FALSE): the steps below store the
# sample identifier in the row names and merge on them, and row names are not
# something data.table objects are designed to carry.
survival <- data.table::fread(
  file.path(input_dir, survival_file),
  data.table = FALSE
)

survival <- survival %>%
  # Keep and rename the clinical and survival-endpoint columns.
  dplyr::select(
    sample,
    age = age_at_initial_pathologic_diagnosis,
    histology = histological_type,
    menopause = menopause_status,
    tumor_status,
    OS, OS.time, DSS, DSS.time, DFI, DFI.time, PFI, PFI.time,
    ajcc_stage = ajcc_pathologic_tumor_stage
  ) %>%
  # Collapse AJCC sub-stages (e.g. "Stage IIA") into four ordered levels;
  # any value that does not match one of the patterns becomes NA.
  dplyr::mutate(
    stage = dplyr::case_when(
      stringr::str_detect(ajcc_stage, "^Stage I[A-B]?$") ~ "Stage 1",
      stringr::str_detect(ajcc_stage, "^Stage II[A-B]?$") ~ "Stage 2",
      stringr::str_detect(ajcc_stage, "^Stage III[A-C]?$") ~ "Stage 3",
      stringr::str_detect(ajcc_stage, "^Stage IV$") ~ "Stage 4",
      TRUE ~ NA_character_
    ),
    stage = factor(stage, levels = c("Stage 1", "Stage 2", "Stage 3", "Stage 4"))
  ) %>%
  # Restrict to the BRCA samples, keep one row per sample, and move the
  # sample identifier into the row names.
  dplyr::filter(sample %in% brca_samples) %>%
  dplyr::distinct(sample, .keep_all = TRUE) %>%
  tibble::column_to_rownames("sample")

## 8.2 MERGE WITH "meta" -------------------------------------------------------
# Left join on row names: every row of "meta" is kept, and samples without a
# survival record get NA in the added columns. merge() returns the key as a
# "Row.names" column, which is turned back into row names.
meta <- merge(meta, survival, by = "row.names", all.x = TRUE) %>%
  tibble::column_to_rownames("Row.names")

9. MOFA FACTORS ###=========================================================

## 9.1 FEATURE SELECTION -------------------------------------------------------

# RNASEQ: up to 5000 features with the highest variance across samples.
# head() is used instead of [1:5000] so that a matrix with fewer than 5000
# usable features does not produce NA feature names.
rna_vars <- apply(rnaseq, 1, var, na.rm = TRUE)
mofa_rna_features <- head(names(sort(rna_vars, decreasing = TRUE)), 5000)

# MUTATION: features whose row mean across samples is at least 0.01.
# which() drops rows whose mean is NA/NaN (e.g. all values missing), which
# would otherwise appear as NA feature names.
mut_freq <- rowMeans(mutation, na.rm = TRUE)
mofa_mut_features <- names(mut_freq)[which(mut_freq >= 0.01)]

# RPPA: keep features that have at least one non-missing value.
mofa_protein_features <- rownames(rppa)[rowSums(!is.na(rppa)) > 0]

## 9.2 INITIALIZE MOFA ---------------------------------------------------------

# Union of the sample identifiers present in any of the four views.
all_samples <- Reduce(
  union,
  list(colnames(rnaseq), colnames(mutation), colnames(rppa), colnames(immune))
)

# Restrict each view to its selected features (all immune features are kept).
mofa_rna  <- rnaseq[mofa_rna_features, ]
mofa_mut  <- mutation[mofa_mut_features, ]
mofa_prot <- rppa[mofa_protein_features, ]
mofa_imm  <- immune

# Give every view the same columns in the same order. A sample that is absent
# from a view is matched to NA, which yields an all-NA column for it; the
# column names are then set to the common sample list.
mofa_input <- lapply(
  list(
    RNAseq   = mofa_rna,
    Mutation = mofa_mut,
    Protein  = mofa_prot,
    Immune   = mofa_imm
  ),
  function(view) {
    aligned <- view[, match(all_samples, colnames(view))]
    colnames(aligned) <- all_samples
    aligned
  }
)

# Build the MOFA object from the aligned list of views.
MOFAobj <- MOFA2::create_mofa(mofa_input)

## 9.3 TRAIN MODEL -------------------------------------------------------------

# Start from the default model options, request 24 factors, and model the
# "Mutation" view with a Bernoulli likelihood.
model_opts <- MOFA2::get_default_model_options(MOFAobj)
model_opts[["num_factors"]] <- 24
model_opts$likelihoods[["Mutation"]] <- "bernoulli"

# Attach the options to the object, then train.
MOFAobj <- MOFA2::prepare_mofa(MOFAobj, model_options = model_opts)
MOFAobj <- MOFA2::run_mofa(MOFAobj, use_basilisk = TRUE)

## 9.4 SCREE PLOT --------------------------------------------------------------

# Variance explained per factor and view for the first sample group
# (one row per factor, one column per view).
vars <- MOFA2::get_variance_explained(MOFAobj)$r2_per_factor[[1]]

# Sum across views to get one value per factor. The factor levels are fixed
# in row order so the x axis is not sorted alphabetically.
scree_data <- data.frame(
  Factor = paste0("Factor", seq_len(nrow(vars))),
  Variance = rowSums(vars)
) %>%
  dplyr::mutate(Factor = factor(Factor, levels = Factor))

# Scree plot. `linewidth` replaces the deprecated `size` argument for lines.
ggplot(scree_data, aes(x = Factor, y = Variance, group = 1)) +
  geom_line(color = "steelblue", linewidth = 1) +
  geom_point(color = "darkblue", size = 3) +
  theme_minimal() +
  labs(title = "MOFA Scree Plot",
       subtitle = "Total Variance Explained across all Omics Views",
       y = "Total Variance Explained (%)",
       x = "Latent Factors") +
  theme(axis.text.x = element_text(angle = 45, hjust = 1))

# Factor values for the first sample group; the columns are prefixed with
# "MOFA_" (e.g. "MOFA_Factor1").
mofa_factors <- MOFA2::get_factors(MOFAobj, factors = "all")[[1]]
colnames(mofa_factors) <- paste0("MOFA_", colnames(mofa_factors))

## 9.5 FACTOR LOADINGS ---------------------------------------------------------

# Example: weights of the "Immune" view features on Factor3.
imm_weights <- MOFA2::get_weights(MOFAobj, views = "Immune", factors = "Factor3")[[1]]
imm_weights

## 9.6 K-MEANS CLUSTERING ------------------------------------------------------

# Number of k-means clusters, defined once so that the clustering call and the
# plot label cannot disagree.
# NOTE(review): the previous version ran kmeans(centers = 3) while its comment
# and the UMAP subtitle both said K=4. The value actually computed (3) is kept
# here; change this single line if 4 was intended.
n_clusters <- 3

# Values of the first 8 factors (first sample group), samples in rows.
mofa_factors_mat <- MOFA2::get_factors(MOFAobj)[[1]][, 1:8]

# UMAP embedding computed on the same 8 factors.
MOFAobj <- MOFA2::run_umap(MOFAobj, factors = 1:8, n_neighbors = 15, min_dist = 0.1)

# Diagnostics for choosing k: within-cluster sum of squares and silhouette.
p_elbow <- factoextra::fviz_nbclust(mofa_factors_mat, kmeans, method = "wss") +
  labs(title = "Elbow Method (WSS)", x = "Number of Clusters (k)") +
  theme_minimal()

p_sil <- factoextra::fviz_nbclust(mofa_factors_mat, kmeans, method = "silhouette") +
  labs(title = "Silhouette Analysis", x = "Number of Clusters (k)") +
  theme_minimal()

# Save both diagnostics side by side.
diag_combined <- gridExtra::arrangeGrob(p_elbow, p_sil, ncol = 2)
ggplot2::ggsave(file.path(output_dir, "MOFA_Clustering_Diagnostics_Top8.png"),
                diag_combined, width = 10, height = 5, bg = "transparent")

# K-means with 25 random starts.
km_res <- kmeans(mofa_factors_mat, centers = n_clusters, nstart = 25)

# Copy the cluster labels into "meta", matched by sample name. Samples with no
# cluster assignment stay NA instead of becoming the literal level "Cluster_NA".
cluster_idx <- km_res$cluster[match(rownames(meta), names(km_res$cluster))]
meta$mofa_cluster <- factor(
  ifelse(is.na(cluster_idx), NA_character_, paste0("Cluster_", cluster_idx))
)

# Attach the metadata to the MOFA object, with the sample identifier as the
# first column named "sample".
MOFA2::samples_metadata(MOFAobj) <- meta %>%
  dplyr::mutate(sample = rownames(.)) %>%
  dplyr::relocate(sample)

# Shared plot theme: no grid, transparent backgrounds, black axis lines.
clean_theme <- theme_minimal() +
  theme(
    panel.grid.major = element_blank(),
    panel.grid.minor = element_blank(),
    panel.background = element_rect(fill = "transparent", color = NA),
    plot.background  = element_rect(fill = "transparent", color = NA),
    axis.line        = element_line(color = "black"),
    legend.background = element_rect(fill = "transparent", color = NA)
  )

# UMAP coloured by k-means cluster; the subtitle reports the k actually used.
p1 <- MOFA2::plot_dimred(
  MOFAobj,
  method = "UMAP",
  color_by = "mofa_cluster",
  dot_size = 2
) +
  scale_color_brewer(palette = "Set1") +
  clean_theme +
  labs(title = "MOFA UMAP: Multi-omic Clusters",
       subtitle = paste0("K-means (K=", n_clusters, ") on Top 8 Factors"))

# UMAP coloured by the "breast_cancer_subtype" metadata column.
p2 <- MOFA2::plot_dimred(
  MOFAobj,
  method = "UMAP",
  color_by = "breast_cancer_subtype",
  dot_size = 2
) +
  scale_color_brewer(palette = "Dark2") +
  clean_theme +
  labs(title = "MOFA UMAP: PAM50 Subtypes", subtitle = "Clinical Label Comparison")

# Save both UMAPs side by side.
umap_combined <- gridExtra::arrangeGrob(p1, p2, ncol = 2)
ggplot2::ggsave(file.path(output_dir, "MOFA_UMAP_Clusters_vs_Subtypes.png"),
                umap_combined, width = 12, height = 6, bg = "transparent")

# Draw the combined figure on the current device.
grid::grid.draw(umap_combined)

10. MARTINGALE RESIDUALS ###=================================================

## 10.1 EXTRACT RESIDUALS ------------------------------------------------------
# Output-column prefix -> name of the event column in "meta"; the matching
# time column is "<event>.time".
target_outcomes <- list(os = "OS", pfi = "PFI")

for (m in names(target_outcomes)) {

  endpoint <- target_outcomes[[m]]

  # Intercept-only survival formula, e.g. survival::Surv(OS.time, OS) ~ 1
  surv_formula <- as.formula(
    sprintf("survival::Surv(%s.time, %s) ~ 1", endpoint, endpoint)
  )

  # Cox model without covariates. na.exclude keeps the residual vector the
  # same length as "meta", with NA for rows dropped because of missing values.
  fit <- survival::coxph(surv_formula, data = meta, na.action = na.exclude)

  # Store the martingale residuals as "<prefix>_risk_score".
  meta[[paste0(m, "_risk_score")]] <- residuals(fit, type = "martingale")
}

11. BUILD MAE OBJECT ###=====================================================

## 11.1 TRANSPOSE --------------------------------------------------------------
# Both objects are transposed so that, like the other assays in the list
# below, they have samples in columns.
isoform_pca  <- t(isoform_pca)
mofa_factors <- t(mofa_factors)

## 11.2 PREPARE EXPERIMENT LIST ------------------------------------------------
# Named list of assays. The HRD matrix is held in the variable "hdr" but is
# stored under the experiment name "hrd".
EXP_LIST <- list(
  isoform_pca    = isoform_pca,
  isoform_logcpm = isoform_logcpm,
  rnaseq         = rnaseq,
  mutation       = mutation,
  rppa           = rppa,
  stemness       = stemness,
  hrd            = hdr,
  immune         = immune,
  mofa_factors   = mofa_factors
)

## 11.3 CREATE MAE OBJECT ------------------------------------------------------
# "meta" provides the per-sample annotation (colData).
mae <- MultiAssayExperiment::MultiAssayExperiment(
  colData     = meta,
  experiments = EXP_LIST
)

# Show the summary of the assembled object.
print(mae)
A MultiAssayExperiment object of 9 listed
 experiments with user-defined names and respective classes.
 Containing an ExperimentList class object of length 9:
 [1] isoform_pca: matrix with 5 rows and 1212 columns
 [2] isoform_logcpm: matrix with 5 rows and 1212 columns
 [3] rnaseq: matrix with 20530 rows and 1212 columns
 [4] mutation: matrix with 40543 rows and 788 columns
 [5] rppa: matrix with 258 rows and 874 columns
 [6] stemness: matrix with 2 rows and 1187 columns
 [7] hrd: matrix with 4 rows and 1053 columns
 [8] immune: matrix with 68 rows and 1187 columns
 [9] mofa_factors: matrix with 24 rows and 1212 columns
Functionality:
 experiments() - obtain the ExperimentList instance
 colData() - the primary/phenotype DataFrame
 sampleMap() - the sample coordination DataFrame
 `$`, `[`, `[[` - extract colData columns, subset, or experiment
 *Format() - convert into a long or wide DataFrame
 assays() - convert ExperimentList to a SimpleList of matrices
 exportClass() - save data to flat files
## 11.4 SAVING & CLEANUP -------------------------------------------------------
# Write the assembled MultiAssayExperiment to the output directory.
saveRDS(
  object = mae,
  file = file.path(output_dir, "TP53_Isoforms_MAE.rds")
)

12. FUNCTION TO EXTRACT DATA ###=============================================

## 12.1 FUNCTION TO RETRIEVE DATA -----------------------------------------------
#' Extract selected features from a MultiAssayExperiment as one wide data frame
#'
#' @param mae A MultiAssayExperiment.
#' @param config Named list; each name is an experiment in `mae`. The value is
#'   either `TRUE` (take every feature of that experiment) or a character
#'   vector of feature (row) names to take.
#' @param clinical Optional character vector of colData columns to include
#'   (passed to `longForm()` as `colDataCols`).
#' @return A data frame with one row per primary identifier (stored in the row
#'   names) and one column per extracted feature, plus the requested colData
#'   columns. Requested features that do not exist in an experiment are
#'   skipped, and a warning names them.
get_data <- function(mae, config, clinical = NULL) {

  extracted <- list()
  for (assay_name in names(config)) {
    items <- config[[assay_name]]
    assay_mat <- mae[[assay_name]]

    if (isTRUE(items)) {
      extracted[[assay_name]] <- assay_mat
    } else {
      # Report requested features that are absent instead of dropping them
      # silently; a typo in a feature name would otherwise go unnoticed.
      not_found <- setdiff(items, rownames(assay_mat))
      if (length(not_found) > 0) {
        warning(
          "get_data(): feature(s) not found in '", assay_name, "': ",
          paste(not_found, collapse = ", "),
          call. = FALSE
        )
      }
      valid <- intersect(items, rownames(assay_mat))
      extracted[[assay_name]] <- assay_mat[valid, , drop = FALSE]
    }
  }

  # Temporary MAE holding only the selected features, with the full colData.
  temp_mae <- MultiAssayExperiment(experiments = extracted, colData = colData(mae))

  df <- as.data.frame(longForm(temp_mae, colDataCols = clinical)) %>%
    # Drop the per-assay columns so that rows belonging to the same primary
    # identifier collapse into a single row when widened.
    dplyr::select(-dplyr::all_of(c("assay", "colname"))) %>%
    tidyr::pivot_wider(names_from = "rowname", values_from = "value") %>%
    # Move the primary identifier into the row names.
    tibble::column_to_rownames("primary")

  df
}

PART B - BAYESIAN NETWORK INFERENCE

13. MAIN FUNCTION ###=======================================================

#' Run the Bayesian-network pipeline over six discretisation/score settings
#'
#' For each combination of discretisation level (2 or 3 quantile bins) and
#' network score ("aic", "bic", "bde") the function learns networks with
#' structural EM from several starting graphs, averages them into a consensus
#' network, and writes an .rds file and a PNG plot. It then writes an
#' adjacency-comparison grid across all settings and, for the reference
#' setting "2disc.aic", a conditional table (CSV) and a mosaic plot of the
#' "RISK" node given its Markov blanket.
#'
#' Side effects: creates `file.path(output_dir, run_name)`, writes files into
#' it, resets the RNG seed to 123 for every setting, and assigns each result to
#' the global environment under the name `<run_name>.<id>`.
#'
#' @param df_raw Data frame with one column per node; must contain a column
#'   named "RISK".
#' @param n_restarts Number of starting graphs per setting (one empty graph
#'   plus `n_restarts - 1` random graphs).
#' @param bl Blacklist passed to the hill-climbing step (or NULL).
#' @param run_name Name of the output folder created inside `output_dir`.
#' @param ref_color Colour for the reference setting ("Reference Only" tiles
#'   and the mosaic gradient).
#' @param comp_color Colour for Markov-blanket highlighting and "Match Ref"
#'   tiles.
#' @return Named list with one element per setting (e.g. "2disc.aic"), each a
#'   list with elements `dag`, `adjacency`, `nodes`, `data`, and `mb`.
run_networks <- function(df_raw, n_restarts = 50, bl = NULL,
                         run_name = "Final_TP53_Analysis",
                         ref_color = "orange3",
                         comp_color = "orange") {

  # --- 1. RUN 6 BNs -------------------------------
  run_path <- file.path(output_dir, run_name)
  dir.create(run_path, recursive = TRUE, showWarnings = FALSE)

  cat("================================================\n")
  cat("BN PIPELINE RUN: ", run_name, "\n")
  cat("SAMPLE SIZE (N):", nrow(df_raw), "\n")
  cat("================================================\n")

  disc_levels  <- c(2, 3)
  scores       <- c("aic", "bic", "bde")
  all_results  <- list()
  ref_id       <- "2disc.aic"
  target_node  <- "RISK"
  master_nodes <- colnames(df_raw)

  # Fail early with a clear message if the target node is missing.
  if (!(target_node %in% master_nodes)) {
    stop("run_networks(): df_raw must contain a column named '", target_node,
         "'.", call. = FALSE)
  }

  for (d in disc_levels) {
    # Quantile discretisation into d bins; unused levels are dropped and the
    # original column order is restored.
    df_disc <- bnlearn::discretize(df_raw, method = 'quantile', breaks = d)
    df_disc <- as.data.frame(lapply(df_disc, droplevels))[, master_nodes]

    for (s in scores) {
      id <- paste0(d, "disc.", s)
      cat("\n>>> PROCESSING CONFIGURATION:", id, "\n")

      # Same seed for every setting so all of them use the same random starts.
      set.seed(123)

      # Starting graphs: one empty graph plus (n_restarts - 1) random DAGs.
      # random.graph() returns a single network (not a list) when num = 1, so
      # that case is wrapped to keep `starts` a flat list of networks.
      n_random <- n_restarts - 1
      random_starts <- list()
      if (n_random > 0) {
        random_starts <- random.graph(nodes = master_nodes, num = n_random,
                                      method = "ic-dag", max.degree = 1)
        if (inherits(random_starts, "bn")) {
          random_starts <- list(random_starts)
        }
      }
      starts <- c(list(empty.graph(master_nodes)), random_starts)

      # Structural EM from every start; a failed start yields NULL.
      net_list <- lapply(starts, function(g) {
        tryCatch({ structural.em(df_disc, maximize = "hc", start = g,
                                 maximize.args = list(score = s, blacklist = bl))
        }, error = function(e) return(NULL))
      })

      n_failed <- sum(vapply(net_list, is.null, logical(1)))
      net_list <- Filter(Negate(is.null), net_list)
      if (n_failed > 0) {
        message(n_failed, " of ", length(starts),
                " structural.em start(s) failed for ", id)
      }
      if (length(net_list) == 0) {
        stop("run_networks(): every structural.em start failed for ", id,
             call. = FALSE)
      }

      # Consensus network from the arc strengths across all learned networks.
      avg_dag  <- averaged.network(custom.strength(net_list, nodes = master_nodes))

      print(avg_dag)

      mb_nodes <- mb(avg_dag, target_node)
      curr_amat <- amat(avg_dag)[master_nodes, master_nodes]

      res_obj <- list(dag = avg_dag, adjacency = curr_amat, nodes = master_nodes, data = df_disc, mb = mb_nodes)
      all_results[[id]] <- res_obj

      saveRDS(res_obj, file.path(run_path, paste0(id, ".rds")))
      # Also exposed in the global environment as "<run_name>.<id>".
      assign(paste0(run_name, ".", id), res_obj, envir = .GlobalEnv)

      # HIGH RES BN PLOTS
      # `finally` closes the device even if plotting fails, so an error does
      # not leave a PNG device open.
      png(file.path(run_path, paste0(id, ".png")), width = 2400, height = 2100, res = 300, bg = "transparent", type = "cairo")
      tryCatch({
        if (length(mb_nodes) > 0) {
          graphviz.plot(avg_dag, highlight = list(nodes = mb_nodes, fill = comp_color, col = "black"),
                        main = paste(id, "| MB of", target_node))
        } else {
          graphviz.plot(avg_dag, main = paste(id, "| No MB for", target_node))
        }
      }, finally = dev.off())
    }
  }

  # --- 2. ADJACENCY MATRIX -------------------------------
  adj_mats <- lapply(all_results, function(x) x$adjacency)
  summed_mat <- Reduce("+", adj_mats)
  # TRUE where an arc is present in every setting.
  universal_mat <- (summed_mat == length(adj_mats))
  ref_mat <- all_results[[ref_id]]$adjacency
  plot_list <- list()
  config_names <- names(all_results)

  for (i in seq_along(config_names)) {
    id <- config_names[i]
    curr_mat <- all_results[[id]]$adjacency

    plot_df <- as.data.frame(curr_mat) %>%
      tibble::rownames_to_column("from") %>%
      tidyr::pivot_longer(-from, names_to = "to", values_to = "exists")

    # Two-column character index (from, to): looks up all arcs in one
    # vectorised step instead of row by row.
    edge_idx <- cbind(plot_df$from, plot_df$to)

    plot_df <- plot_df %>%
      mutate(
        is_univ = universal_mat[edge_idx],
        is_ref  = ref_mat[edge_idx],
        category = case_when(
          exists == 1  & is_univ == TRUE ~ "Universal",
          id == ref_id & exists == 1 & is_univ == FALSE ~ "Reference Only",
          exists == 1  & is_ref == 1 & is_univ == FALSE ~ "Match Ref",
          exists == 1  & is_ref == 0 & is_univ == FALSE ~ "New Edge",
          TRUE ~ "None"
        )
      )

    # Axis labels only on the outer edge of the 3-column x 2-row layout.
    is_left_col   <- i %in% c(1, 4)
    is_bottom_row <- i %in% c(4, 5, 6)

    p <- ggplot(plot_df, aes(x = factor(to, levels = master_nodes),
                             y = factor(from, levels = rev(master_nodes)))) +
      geom_tile(aes(fill = category), color = "gray85", linewidth = 0.2) +
      scale_fill_manual(
        values = c("Universal" = "black", "Reference Only" = ref_color,
                   "Match Ref" = comp_color, "New Edge" = "lightgray", "None" = "white"),
        guide = "none"
      ) +
      labs(title = id, x = NULL, y = NULL) +
      theme_minimal() +
      theme(
        aspect.ratio = 1,
        panel.border = element_rect(color = "black", fill = NA, linewidth = 0.8),
        panel.grid   = element_blank(),
        plot.title   = element_text(hjust = 0.5, size = 10, face = "bold"),
        axis.text.x  = if(is_bottom_row) element_text(angle = 90, vjust = 0.5, hjust = 1, size = 7) else element_blank(),
        axis.ticks.x = if(is_bottom_row) element_line(color = "black") else element_line(color = NA),
        axis.text.y  = if(is_left_col) element_text(size = 7) else element_blank(),
        axis.ticks.y = if(is_left_col) element_line(color = "black") else element_line(color = NA),
        axis.ticks.length = unit(3, "pt"),
        plot.margin = margin(5, 5, 5, 5),
        plot.background = element_rect(fill = "transparent", color = NA)
      )
    plot_list[[i]] <- p
  }

  final_grid <- patchwork::wrap_plots(plot_list, ncol = 3, nrow = 2) +
    patchwork::plot_annotation(
      title = paste("Structural Stability Analysis:", run_name),
      subtitle = "Row 1: 2-level Discretization | Row 2: 3-level Discretization",
      theme = theme(plot.title = element_text(size = 14, face = "bold", hjust = 0.5),
                    plot.subtitle = element_text(size = 11, hjust = 0.5),
                    plot.background = element_rect(fill = "transparent", color = NA))
    )

  ggsave(file.path(run_path, "Adjacency_Comparison_Grid_Final.png"),
         plot = final_grid, width = 11, height = 7.5, bg = "transparent", type = "cairo")

  # --- 3. CPT (REFERENCE ONLY) -------------------------------
  ref_res  <- all_results[[ref_id]]
  mb_nodes <- ref_res$mb

  # Order the table dimensions: non-PC blanket nodes first, then PC nodes
  # (PC3 last among them if present), and the target node as the last one.
  pc_in_mb <- grep("^PC[1-5]$", mb_nodes, value = TRUE)
  other_mb <- setdiff(mb_nodes, pc_in_mb)

  if ("PC3" %in% pc_in_mb) {
    ordered_vars <- c(other_mb, setdiff(pc_in_mb, "PC3"), "PC3", target_node)
  } else {
    ordered_vars <- c(other_mb, pc_in_mb, target_node)
  }

  # Counts (N) per level combination, and proportions of the target node
  # within each combination of the other variables. drop = FALSE and seq_len()
  # keep this valid when the Markov blanket is empty (target node only).
  raw_counts   <- table(ref_res$data[, ordered_vars, drop = FALSE])
  prob_table   <- prop.table(raw_counts, margin = seq_len(length(ordered_vars) - 1))

  df_probs  <- as.data.frame(prob_table)
  df_counts <- as.data.frame(raw_counts)

  # Both data frames come from tables of identical shape, so their rows align.
  final_cpt_output <- df_probs %>%
    dplyr::rename(Probability = Freq) %>%
    dplyr::mutate(N = df_counts$Freq)

  write.csv(final_cpt_output, file.path(run_path, "Reference_Risk_CPT_with_N.csv"), row.names = FALSE)

  # --- 4. MOSAIC PLOT (REFERENCE ONLY) -------------------------------
  # Uses the tables built in step 3 (`raw_counts`, `prob_table`).
  # Number of level combinations of all dimensions except the last (target).
  n_configs <- prod(dim(raw_counts)[-length(dim(raw_counts))])
  # Proportion of the second target level for each combination. This relies
  # on the target being the last dimension and, as in the 2-bin reference
  # setting, having two levels.
  p_high    <- as.vector(prob_table)[(n_configs + 1):(2 * n_configs)]
  p_high[is.na(p_high)] <- 0

  # Logistic stretch around 0.5 to increase colour contrast.
  p_contrast <- 1 / (1 + exp(-10 * (p_high - 0.5)))
  grad_pal   <- colorRampPalette(c("#F5F5F5", ref_color))(100)
  risk_colors <- grad_pal[round(p_contrast * 99) + 1]
  # First target level: fully transparent tiles; second level: gradient colour.
  color_array <- array(c(rep("#FFFFFF00", n_configs), risk_colors), dim = dim(raw_counts))

  # MOSAIC PLOT WITH THIN BLACK OUTLINES
  png(file.path(run_path, "Reference_Mosaic_Risk_Final.png"), width = 3600, height = 2700, res = 300, bg = "transparent", type = "cairo")
  tryCatch({
    vcd::mosaic(raw_counts,
                gp = grid::gpar(fill = color_array, col = "black", lwd = 0.5),
                main = paste("Reference Run:", run_name, "Prognostic Risk Hierarchy"),
                labeling = vcd::labeling_border(
                  gp_labels = grid::gpar(fontsize = 9, fontface = "bold"),
                  rot_labels = c(0, 90, 0, 90)
                ))
  }, finally = dev.off())

  cat("\nDONE. All outputs saved to:", run_path, "\n")
  return(all_results)
}

14. MOFA FACTORS ###========================================================

## 14.1 GET DATA
# Assemble the input table for the "mofa" network run.
# DEFINE VARIABLES
# `selection` is handed to get_data() (defined elsewhere in this file);
# presumably each entry names a data layer and the features to pull from it
# -- TODO confirm against get_data().
selection <- list(
  isoform_pca  = TRUE, 
  mofa_factors = paste0("Factor", 1:8), 
  hrd          = "HRD",
  stemness     = "RNAss"
)

# DEFINE METADATA
meta_vars <- c("os_risk_score")

# RUN get_data
# The risk score column is renamed to RISK, the node name used by the
# downstream network code.
mofa_data <- as.data.frame(get_data(mae, selection, meta_vars)) %>%
  rename(RISK = os_risk_score)

## 14.2 BLACKLIST
# One row per non-RISK column, each of the form RISK -> <column>. Passed as
# `bl` below; presumably used as a bnlearn arc blacklist so RISK cannot be a
# parent of any other node -- verify in run_networks().
mofa_bl <- data.frame(
  from = "RISK", 
  to   = setdiff(colnames(mofa_data), "RISK")
)

## 14.3 RUN MAIN FUNCTION
# The console log printed below this call is the captured output of this run.
my_nets <- run_networks(
  df_raw = mofa_data,
  n_restarts = 100,
  bl = mofa_bl,
  run_name = "mofa",
  ref_color = "#006EAE",
  comp_color = "#9BCAE9"
)
================================================
BN PIPELINE RUN:  mofa 
SAMPLE SIZE (N): 1212 
================================================

>>> PROCESSING CONFIGURATION: 2disc.aic 

  Consensus Bayesian network

  model:
    [partially directed graph]
  nodes:                                 16 
  arcs:                                  55 
    undirected arcs:                     1 
    directed arcs:                       54 
  average markov blanket size:           8.25 
  average neighbourhood size:            6.88 
  average branching factor:              3.38 

  generation algorithm:                  Model Averaging 
  significance threshold:                0.89 


>>> PROCESSING CONFIGURATION: 2disc.bic 

  Consensus Bayesian network

  model:
    [partially directed graph]
  nodes:                                 16 
  arcs:                                  29 
    undirected arcs:                     5 
    directed arcs:                       24 
  average markov blanket size:           4.25 
  average neighbourhood size:            3.62 
  average branching factor:              1.50 

  generation algorithm:                  Model Averaging 
  significance threshold:                0.33 


>>> PROCESSING CONFIGURATION: 2disc.bde 

  Consensus Bayesian network

  model:
    [partially directed graph]
  nodes:                                 16 
  arcs:                                  27 
    undirected arcs:                     3 
    directed arcs:                       24 
  average markov blanket size:           4.25 
  average neighbourhood size:            3.38 
  average branching factor:              1.50 

  generation algorithm:                  Model Averaging 
  significance threshold:                0.62 


>>> PROCESSING CONFIGURATION: 3disc.aic 

  Consensus Bayesian network

  model:
    [partially directed graph]
  nodes:                                 16 
  arcs:                                  35 
    undirected arcs:                     4 
    directed arcs:                       31 
  average markov blanket size:           5.00 
  average neighbourhood size:            4.38 
  average branching factor:              1.94 

  generation algorithm:                  Model Averaging 
  significance threshold:                0.03 


>>> PROCESSING CONFIGURATION: 3disc.bic 

  Consensus Bayesian network

  model:
    [partially directed graph]
  nodes:                                 16 
  arcs:                                  20 
    undirected arcs:                     4 
    directed arcs:                       16 
  average markov blanket size:           2.75 
  average neighbourhood size:            2.50 
  average branching factor:              1.00 

  generation algorithm:                  Model Averaging 
  significance threshold:                0.01 


>>> PROCESSING CONFIGURATION: 3disc.bde 

  Consensus Bayesian network

  model:
    [partially directed graph]
  nodes:                                 16 
  arcs:                                  16 
    undirected arcs:                     6 
    directed arcs:                       10 
  average markov blanket size:           2.12 
  average neighbourhood size:            2.00 
  average branching factor:              0.62 

  generation algorithm:                  Model Averaging 
  significance threshold:                0.01 


DONE. All outputs saved to: C:/Users/alden/Dissertation/TP53_isoforms_project/output/mofa 

15. FIT EVIDENCE ###========================================================

# 1. PREPARE THE TEMPLATE AND GLOBAL DATA --------------------------------------
# Builds the fixed graph structure and the discretized data set that the
# per-cluster refit loop further down reuses.
# EXTRACT THE COMPLETED DIRECTED ACYCLIC GRAPH
# NOTE(review): `mofa.2disc.aic` is not created in the visible part of this
# file; presumably it is the 2disc.aic result of the "mofa" run -- confirm
# where it is assigned.
global_dag <- bnlearn::cextend(mofa.2disc.aic$dag)
DAG_NODES  <- bnlearn::nodes(global_dag)

# CRITICAL: RE-SYNC MOFA CLUSTERS TO YOUR DATA FRAME
# ASSUMING 'mae' IS YOUR MASTER EXPERIMENT OBJECT AND CLUSTERS ARE IN colData
# Lookup is by row name, so rownames(mofa_data) must be sample IDs present in
# colData(mae); unmatched samples would get NA here.
mofa_data$mofa_cluster <- colData(mae)[rownames(mofa_data), "mofa_cluster"]

# PERFORM GLOBAL DISCRETIZATION ONLY ON DAG NODES
# THIS ENSURES NO EXTRA COLUMNS (LIKE CLUSTER) BREAK THE BN.FIT LATER
# Cut points are computed once on the full cohort (2 quantile bins per
# variable), so every cluster subset shares the same level definitions.
full_disc <- bnlearn::discretize(
  mofa_data[, intersect(colnames(mofa_data), DAG_NODES)], 
  method = 'quantile', 
  breaks = 2
)

# RE-ATTACH CLUSTER ASSIGNMENTS TO THE DISCRETIZED DATA
# Relies on discretize() preserving row order.
full_disc$mofa_cluster <- mofa_data$mofa_cluster

# VERIFY CLUSTERS ARE PRESENT BEFORE PROCEEDING
# IF THIS PRINTS 0, THE CLUSTER ASSIGNMENT ABOVE FAILED
cat("UNIQUE CLUSTERS FOUND:", length(unique(na.omit(full_disc$mofa_cluster))), "\n")
UNIQUE CLUSTERS FOUND: 3 
# DEFINE TARGET OUTCOME AND PRIMARY DRIVER NODES
target_node     <- "RISK"
driver_node     <- "PC3"
# Each variable was cut into 2 quantile bins above, so level 1 is treated as
# the "low" bin and level 2 as the "high" bin.
high_risk_label <- levels(full_disc[[target_node]])[2]
low_pc3_label   <- levels(full_disc[[driver_node]])[1]
high_pc3_label  <- levels(full_disc[[driver_node]])[2]

# 2. THE REFIT LOOP ------------------------------------------------------------
# IDENTIFY UNIQUE CLUSTERS WHILE IGNORING NA VALUES
clusters <- as.character(unique(na.omit(full_disc$mofa_cluster)))
# Filled by the loop below with one single-row data frame per cluster.
refit_results <- list()

# EXTRACT EXACT NODE NAMES REQUIRED BY THE GLOBAL DAG
# Same value as DAG_NODES defined earlier in this section.
dag_nodes <- bnlearn::nodes(global_dag)

# For each MOFA cluster: refit the parameters of the fixed global DAG on that
# cluster's samples, then summarise how P(RISK = high) changes between the low
# and high bins of the driver node across all other parent configurations.
# Writes one fitted network per cluster to disk and appends one summary row to
# `refit_results`.
for (cl in clusters) {
    cat(">>> REFITTING GLOBAL GRAPH FOR:", cl, "\n")
    
    # A. SUBSET DATA FOR THE CURRENT CLUSTER
    # which() discards NA comparisons. Plain logical indexing would return an
    # all-NA row for every sample without a cluster label, inflating N and
    # feeding empty rows to bn.fit().
    cl_rows  <- which(full_disc$mofa_cluster == cl)
    cl_data  <- full_disc[cl_rows, , drop = FALSE]
    cl_n     <- nrow(cl_data)
    
    # B. ENSURE DATA ONLY CONTAINS DAG NODES TO PREVENT DIMENSION ERRORS
    cl_input <- cl_data[, intersect(colnames(cl_data), dag_nodes), drop = FALSE]
    
    # VALIDATE THAT ALL REQUIRED NODES ARE PRESENT IN THE SUBSET
    missing_nodes <- setdiff(dag_nodes, colnames(cl_input))
    if (length(missing_nodes) > 0) {
        stop(paste("DATA IS MISSING NODES REQUIRED BY DAG:", paste(missing_nodes, collapse=", ")))
    }

    # C. REFIT PARAMETERS USING BAYESIAN ESTIMATION (ISS=1 HANDLES SPARSE DATA)
    cl_fit <- bnlearn::bn.fit(global_dag, cl_input, method = "bayes", iss = 1)
    
    # D. CALCULATE MARGINAL EFFECTS VIA CONDITIONAL PROBABILITY TABLE ANALYSIS
    # Long-format CPT of the target node, restricted to the "high" outcome rows.
    cpt_risk <- as.data.frame(cl_fit[[target_node]]$prob)
    cpt_high <- cpt_risk[cpt_risk[[target_node]] == high_risk_label, ]
    
    # Column names the pivot will produce for the two driver bins.
    driver_prefix <- paste0(driver_node, "_")
    col_low  <- paste0(driver_prefix, low_pc3_label)
    col_high <- paste0(driver_prefix, high_pc3_label)
    
    # Default to NA; only overwritten when the driver really is a parent of
    # the target in this DAG and both of its bins appear in the CPT.
    mean_delta     <- NA_real_
    mean_abs_delta <- NA_real_
    interaction_idx <- NA_real_
    
    if (driver_node %in% colnames(cpt_high)) {
        # PIVOT TABLE TO COMPARE PC3 HIGH VS LOW ACROSS ALL PARENT STATES
        wide_cpt <- tidyr::pivot_wider(
          cpt_high, 
          names_from = dplyr::all_of(driver_node), 
          values_from = Freq, 
          names_prefix = driver_prefix
        )
        
        if (col_low %in% colnames(wide_cpt) && col_high %in% colnames(wide_cpt)) {
            # COMPUTE DIFFERENCE IN RISK FOR EVERY BACKGROUND CONFIGURATION
            all_deltas <- wide_cpt[[col_high]] - wide_cpt[[col_low]]
            
            # 1. CALCULATE MEAN NET EFFECT (DIRECTIONAL IMPACT)
            mean_delta <- mean(all_deltas, na.rm = TRUE)
            
            # 2. CALCULATE MEAN ABSOLUTE INFLUENCE (TOTAL BIOLOGICAL WEIGHT)
            mean_abs_delta <- mean(abs(all_deltas), na.rm = TRUE)
            
            # 3. CALCULATE INTERACTION INDEX (SD OF EFFECTS / CONTEXT DEPENDENCY)
            interaction_idx <- sd(all_deltas, na.rm = TRUE)
            
            cat("    MEAN DELTA RISK:", round(mean_delta, 4), "\n")
            cat("    MEAN ABSOLUTE INFLUENCE:", round(mean_abs_delta, 4), "\n")
        }
    }
    
    # E. STORE RESULTS IN DATAFRAME
    refit_results[[cl]] <- data.frame(
        Cluster = cl,
        N = cl_n,
        Mean_Net_Effect = mean_delta,
        Mean_Absolute_Influence = mean_abs_delta,
        Interaction_Index = interaction_idx
    )
    
    # SAVE THE CLUSTER-SPECIFIC FITTED OBJECT
    saveRDS(cl_fit, file.path(output_dir, "mofa", paste0("Refitted_Graph_", cl, ".rds")))
}
>>> REFITTING GLOBAL GRAPH FOR: Cluster_2 
    MEAN DELTA RISK: -0.1505 
    MEAN ABSOLUTE INFLUENCE: 0.241 
>>> REFITTING GLOBAL GRAPH FOR: Cluster_3 
    MEAN DELTA RISK: -0.0541 
    MEAN ABSOLUTE INFLUENCE: 0.1563 
>>> REFITTING GLOBAL GRAPH FOR: Cluster_1 
    MEAN DELTA RISK: 0.0324 
    MEAN ABSOLUTE INFLUENCE: 0.1136 
# 3. FINAL SUMMARY AND EXPORT --------------------------------------------------
# Stack the one-row-per-cluster summaries collected in the refit loop.
final_delta_df <- do.call(rbind, refit_results)
# rbind() on a named list uses the cluster names as row names; drop them since
# the Cluster column already carries that information.
rownames(final_delta_df) <- NULL
print(final_delta_df)

write.csv(final_delta_df, file.path(output_dir, "mofa", "Cluster_PC3_Marginal_Effects_Full.csv"), row.names = FALSE)

16. VIOLIN PLOTS ###========================================================

# 1. DATA EXTRACTION -----------------------------------------------------------
# Pull four RPPA features and two RNA-seq genes together with each sample's
# MOFA cluster label (via get_data(), defined elsewhere in this file).
SELECTION_HALLMARKS <- list(
   rppa = c("ERALPHA", "PR", "HER2", "GATA3"),
   rnaseq = c("SOX10", "MKI67") 
)
HALLMARK_DATA <- as.data.frame(get_data(mae, SELECTION_HALLMARKS, c("mofa_cluster")))

# 2. LONG FORMAT PREP ----------------------------------------------------------
# TARGET_FEATURES repeats the names in SELECTION_HALLMARKS; keep the two in
# sync when editing either.
TARGET_FEATURES <- c("ERALPHA", "PR", "HER2", "GATA3", "SOX10", "MKI67")
# Drop samples without a cluster, then reshape to one row per
# (sample, feature) so the plot can facet by Feature.
HALLMARK_LONG <- HALLMARK_DATA %>%
  select(mofa_cluster, all_of(TARGET_FEATURES)) %>%
  filter(!is.na(mofa_cluster)) %>%
  mutate(mofa_cluster = factor(mofa_cluster)) %>%
  pivot_longer(cols = all_of(TARGET_FEATURES), 
               names_to = "Feature", 
               values_to = "Level")

# 3. SETTINGS ------------------------------------------------------------------
# Three fill colours: assumes exactly three cluster levels.
MY_PALETTE <- c("#006EAE", "#CA9B23", "#C5373D")
# Every unordered pair of cluster levels, for the pairwise tests in the plot.
MY_COMPARISONS <- combn(levels(HALLMARK_LONG$mofa_cluster), 2, simplify = FALSE)

# 4. PLOT GENERATION -----------------------------------------------------------
# Violin + narrow boxplot of each hallmark feature per MOFA cluster, one facet
# per feature, annotated with pairwise and global test results and group sizes.
# NOTE(review): stat_compare_means() and theme_pubr() are presumably from
# ggpubr, which is not in the package list at the top of this file -- confirm
# it is loaded before this section runs.
P_HALLMARKS_FINAL <- ggplot(HALLMARK_LONG, aes(x = mofa_cluster, y = Level, fill = mofa_cluster)) +
  geom_violin(trim = FALSE, alpha = 0.7, color = "black") +
  geom_boxplot(width = 0.1, fill = "white", outlier.shape = NA, color = "black") +
  
  # Independent axes per facet: the features are on different scales.
  facet_wrap(~Feature, scales = "free", ncol = 3) +
  
  # Pairwise Wilcoxon tests shown as significance stars above the violins.
  stat_compare_means(comparisons = MY_COMPARISONS, 
                     label = "p.signif", 
                     method = "wilcox.test",
                     step.increase = 0.15,
                     vjust = -1.1) +
  
  # Global Kruskal-Wallis p-value placed near the top-left of each facet.
  stat_compare_means(method = "kruskal.test", 
                     label = "p.format", 
                     label.x.npc = 0.05, 
                     label.y.npc = 0.95, 
                     size = 3) +
  
  # Sample sizes at the bottom
  # Label sits 10% of the group's range below its minimum value.
  stat_summary(fun.data = function(x) {
    return(data.frame(y = min(x) - (diff(range(x)) * 0.1), label = paste0("n=", length(x))))
  }, geom = "text", size = 2.5, fontface = "italic") +

  # clip = "off" lets annotations extend past the panel; the y expansion
  # leaves room for the n labels (bottom) and the test brackets (top).
  coord_cartesian(clip = "off") + 
  scale_y_continuous(expand = expansion(mult = c(0.15, 0.35))) +
  scale_fill_manual(values = MY_PALETTE) +
  
  theme_pubr() + 
  theme(
    legend.position = "none",
    # --- THE "TRUE TRANSPARENCY" TRINITY ---
    panel.background = element_rect(fill = "transparent", colour = NA), 
    plot.background  = element_rect(fill = "transparent", colour = NA),
    legend.background = element_rect(fill = "transparent", colour = NA),
    # ---------------------------------------
    strip.background = element_blank(),
    strip.text = element_text(face = "bold", size = 11),
    axis.text.x = element_text(angle = 45, hjust = 1),
    panel.grid.major = element_blank(),
    panel.grid.minor = element_blank(),
    panel.spacing.y = unit(3, "lines"),
    plot.margin = margin(t = 30, r = 15, b = 15, l = 15)
  ) +
  labs(title = "Clinical Hallmark Expression by Cluster",
       x = NULL, y = "Relative Level")

# 5. SAVE COMMAND --------------------------------------------------------------
# Make sure to include bg = "transparent" here too!
# Written next to the "mofa" run outputs.
ggsave(file.path(output_dir, "mofa", "Hallmarks_Final_Pure_Transparent.png"), 
       P_HALLMARKS_FINAL, 
       width = 12, height = 10, 
       dpi = 300, 
       bg = "transparent")

17. STEMNESS ###============================================================

## 17.1 GET DATA
# Assemble the input table for the "stemness" network run, restricted to
# samples in Cluster_2.
# DEFINE VARIABLES
selection <- list(
  isoform_pca  = TRUE, 
  stemness     = "RNAss",
  mutation     = "TP53",
  rnaseq       = c(    
    
    # p53 CONTEXT
    "TP63", "TP73","PPP1R13B", "TP53BP2", "TP53BP1", "PPP1R13L", "MDM2", 
    
    # NEUROENDOCRINE MARKERS
    "ASCL1", "INSM1", "SYP", "CHGA", "GATA3", "VGF",
    
    # STEMNESS
    "SOX10", "SOX2", "MYC", "PROM1",
    
    # EMT
    "SNAI1", "ZEB1", "CDH1")
)

# DEFINE METADATA
# mofa_cluster is fetched only to filter on; it is dropped again below.
meta_vars <- c("os_risk_score", "mofa_cluster")

# RUN get_data
# Rename the risk score to the RISK node name, keep Cluster_2 samples only,
# then remove the cluster column so it does not become a network node.
stemness_data <- as.data.frame(get_data(mae, selection, meta_vars)) %>%
  rename(RISK = os_risk_score) %>%
  filter(mofa_cluster == "Cluster_2") %>%
  select(-mofa_cluster)

## 17.2 BLACKLIST
# One row per non-RISK column of THIS run's data, each of the form
# RISK -> <column>. The node list must come from stemness_data: taking it from
# another run's data frame would name nodes that do not exist here and leave
# the stemness-specific gene nodes unlisted.
stemness_bl <- data.frame(
  from = "RISK", 
  to   = setdiff(colnames(stemness_data), "RISK")
)

## 17.3 AS FACTOR
# TP53 mutation status is categorical; convert it before it reaches
# run_networks().
stemness_data$TP53 <- as.factor(stemness_data$TP53)

## 17.4 RUN MAIN FUNCTION
# NOTE: this overwrites `my_nets` from the previous run. The console log
# printed below this call is the captured output of this run.
my_nets <- run_networks(
  df_raw = stemness_data,
  n_restarts = 100,
  bl = stemness_bl,
  run_name = "stemness",
  ref_color = "#CA9B23",
  comp_color = "#F6DC87"
)
================================================
BN PIPELINE RUN:  stemness 
SAMPLE SIZE (N): 861 
================================================

>>> PROCESSING CONFIGURATION: 2disc.aic 

  Consensus Bayesian network

  model:
   [PC1][PC2|PC1][PC5|PC1][PC3|PC1:PC2:PC5][RNAss|PC2][PC4|PC1:PC2:PC3][TP53|RNAss][TP63|PC1:PC4:RNAss:TP53]
   [ZEB1|PC2:RNAss:TP63][TP73|PC3:PC5:ZEB1][TP53BP2|PC3:RNAss:TP73:ZEB1][CDH1|RNAss:TP73:TP53BP2:ZEB1]
   [PPP1R13B|TP63:TP73:CDH1][TP53BP1|PC3:TP53BP2:CDH1][INSM1|PC4:TP53:CDH1][PPP1R13L|RNAss:TP63:PPP1R13B:TP53BP1:ZEB1]
   [SYP|PC5:TP53BP1:INSM1][GATA3|PC2:RNAss:TP53:TP53BP1:ZEB1][CHGA|RNAss:TP63:SYP][SOX2|PC5:RNAss:INSM1:GATA3]
   [MDM2|TP53:PPP1R13L:CHGA:CDH1][VGF|PC2:RNAss:TP53BP2:SYP:CHGA][MYC|RNAss:TP63:CHGA:VGF:CDH1]
   [SNAI1|TP53:TP53BP1:MDM2:ZEB1][ASCL1|RNAss:TP53:TP53BP2:PPP1R13L:SNAI1][SOX10|RNAss:TP63:TP73:VGF:MYC]
   [RISK|PC3:PC5:SOX10][PROM1|TP63:SYP:SOX10:MYC:CDH1]
  nodes:                                 28 
  arcs:                                  93 
    undirected arcs:                     0 
    directed arcs:                       93 
  average markov blanket size:           11.00 
  average neighbourhood size:            6.64 
  average branching factor:              3.32 

  generation algorithm:                  Model Averaging 
  significance threshold:                0.49 


>>> PROCESSING CONFIGURATION: 2disc.bic 

  Consensus Bayesian network

  model:
    [partially directed graph]
  nodes:                                 28 
  arcs:                                  45 
    undirected arcs:                     2 
    directed arcs:                       43 
  average markov blanket size:           4.29 
  average neighbourhood size:            3.21 
  average branching factor:              1.54 

  generation algorithm:                  Model Averaging 
  significance threshold:                0.56 


>>> PROCESSING CONFIGURATION: 2disc.bde 

  Consensus Bayesian network

  model:
    [partially directed graph]
  nodes:                                 28 
  arcs:                                  41 
    undirected arcs:                     2 
    directed arcs:                       39 
  average markov blanket size:           3.71 
  average neighbourhood size:            2.93 
  average branching factor:              1.39 

  generation algorithm:                  Model Averaging 
  significance threshold:                0.73 


>>> PROCESSING CONFIGURATION: 3disc.aic 

  Consensus Bayesian network

  model:
    [partially directed graph]
  nodes:                                 28 
  arcs:                                  57 
    undirected arcs:                     3 
    directed arcs:                       54 
  average markov blanket size:           5.79 
  average neighbourhood size:            4.07 
  average branching factor:              1.93 

  generation algorithm:                  Model Averaging 
  significance threshold:                0.46 


>>> PROCESSING CONFIGURATION: 3disc.bic 

  Consensus Bayesian network

  model:
    [partially directed graph]
  nodes:                                 28 
  arcs:                                  27 
    undirected arcs:                     12 
    directed arcs:                       15 
  average markov blanket size:           2.00 
  average neighbourhood size:            1.93 
  average branching factor:              0.54 

  generation algorithm:                  Model Averaging 
  significance threshold:                0.11 


>>> PROCESSING CONFIGURATION: 3disc.bde 

  Consensus Bayesian network

  model:
    [partially directed graph]
  nodes:                                 28 
  arcs:                                  21 
    undirected arcs:                     10 
    directed arcs:                       11 
  average markov blanket size:           1.64 
  average neighbourhood size:            1.50 
  average branching factor:              0.39 

  generation algorithm:                  Model Averaging 
  significance threshold:                0.99 


DONE. All outputs saved to: C:/Users/alden/Dissertation/TP53_isoforms_project/output/stemness 

18. IMMUNE ###==============================================================

## 18.1 GET DATA
# Assemble the input table for the "immune" network run, restricted to
# samples in Cluster_2.
# DEFINE VARIABLES
selection <- list(
  isoform_pca  = TRUE, 
  mutation     = "TP53",
  rnaseq       = c(    
        "TP63", "TP73","PPP1R13B", "TP53BP2", "TP53BP1", "PPP1R13L", "MDM2"
    
    ),
  immune = c(
        "TAMsurr_score", "Tcell_receptors_score", "Bcell_receptors_score", "PD1_PDL1_score",
    "IFNG_score_21050467", "MHC2_21978456", "TGFB_PCA_17349583", "Troester_WoundSig_19887484",
    "Chemokine12_score", "Module11_Prolif_score"
    )
)

# DEFINE METADATA
# mofa_cluster is fetched only to filter on; it is dropped again below.
meta_vars <- c("os_risk_score", "mofa_cluster")

# RUN get_data
# Shorten the signature column names to readable node labels, keep Cluster_2
# samples only, then remove the cluster column so it is not a network node.
immune_data <- as.data.frame(get_data(mae, selection, meta_vars)) %>%
  rename(
    # NEW_NAME = OLD_NAME
    RISK           = os_risk_score,
    TAMsurr        = TAMsurr_score,
    TCell_Rec      = Tcell_receptors_score,
    BCell_Rec      = Bcell_receptors_score,
    PD1_PDL1       = PD1_PDL1_score,
    IFNG           = IFNG_score_21050467,
    MHC2           = MHC2_21978456,
    TGFB           = TGFB_PCA_17349583,
    Wound_Healing  = Troester_WoundSig_19887484,
    Chemokine12    = Chemokine12_score,
    Proliferation  = Module11_Prolif_score
  ) %>%
  filter(mofa_cluster == "Cluster_2") %>%
  select(-mofa_cluster)

## 18.2 BLACKLIST
# One row per non-RISK column, each of the form RISK -> <column>.
immune_bl <- data.frame(
  from = "RISK", 
  to   = setdiff(colnames(immune_data), "RISK")
)

## 18.3 AS FACTOR
# TP53 mutation status is categorical; convert it before it reaches
# run_networks().
immune_data$TP53 <- as.factor(immune_data$TP53)

## 18.4 RUN MAIN FUNCTION
# NOTE: this overwrites `my_nets` from the previous run. The console log
# printed below this call is the captured output of this run.
my_nets <- run_networks(
  df_raw = immune_data,
  n_restarts = 100,
  bl = immune_bl,
  run_name = "immune",
  ref_color = "#C5373D",
  comp_color = "#E9A0A5"
)
================================================
BN PIPELINE RUN:  immune 
SAMPLE SIZE (N): 861 
================================================

>>> PROCESSING CONFIGURATION: 2disc.aic 

  Consensus Bayesian network

  model:
    [partially directed graph]
  nodes:                                 24 
  arcs:                                  78 
    undirected arcs:                     1 
    directed arcs:                       77 
  average markov blanket size:           10.50 
  average neighbourhood size:            6.50 
  average branching factor:              3.21 

  generation algorithm:                  Model Averaging 
  significance threshold:                0.5 


>>> PROCESSING CONFIGURATION: 2disc.bic 

  Consensus Bayesian network

  model:
    [partially directed graph]
  nodes:                                 24 
  arcs:                                  42 
    undirected arcs:                     1 
    directed arcs:                       41 
  average markov blanket size:           4.83 
  average neighbourhood size:            3.50 
  average branching factor:              1.71 

  generation algorithm:                  Model Averaging 
  significance threshold:                0.43 


>>> PROCESSING CONFIGURATION: 2disc.bde 

  Consensus Bayesian network

  model:
    [partially directed graph]
  nodes:                                 24 
  arcs:                                  40 
    undirected arcs:                     2 
    directed arcs:                       38 
  average markov blanket size:           4.50 
  average neighbourhood size:            3.33 
  average branching factor:              1.58 

  generation algorithm:                  Model Averaging 
  significance threshold:                0.51 


>>> PROCESSING CONFIGURATION: 3disc.aic 

  Consensus Bayesian network

  model:
    [partially directed graph]
  nodes:                                 24 
  arcs:                                  50 
    undirected arcs:                     5 
    directed arcs:                       45 
  average markov blanket size:           6.08 
  average neighbourhood size:            4.17 
  average branching factor:              1.88 

  generation algorithm:                  Model Averaging 
  significance threshold:                0.49 


>>> PROCESSING CONFIGURATION: 3disc.bic 

  Consensus Bayesian network

  model:
    [partially directed graph]
  nodes:                                 24 
  arcs:                                  28 
    undirected arcs:                     7 
    directed arcs:                       21 
  average markov blanket size:           2.67 
  average neighbourhood size:            2.33 
  average branching factor:              0.88 

  generation algorithm:                  Model Averaging 
  significance threshold:                0.49 


>>> PROCESSING CONFIGURATION: 3disc.bde 

  Consensus Bayesian network

  model:
    [partially directed graph]
  nodes:                                 24 
  arcs:                                  26 
    undirected arcs:                     11 
    directed arcs:                       15 
  average markov blanket size:           2.42 
  average neighbourhood size:            2.17 
  average branching factor:              0.62 

  generation algorithm:                  Model Averaging 
  significance threshold:                0.25 


DONE. All outputs saved to: C:/Users/alden/Dissertation/TP53_isoforms_project/output/immune 

19. RFS ###=================================================================

## ============================================================
## 19. RANDOM FOREST SURVIVAL — FULL PIPELINE
## ============================================================

library(survival)
library(survminer)
library(randomForestSRC)
library(ggplot2)
library(patchwork)
library(dplyr)
library(scales)

## ============================================================
## 19.1 GET DATA
## ============================================================

# DEFINE VARIABLES
# Predictors for the random survival forest: two isoform PCs, three MOFA
# factors, the stemness index, SOX10 expression, one immune signature and
# TP53 mutation status.
selection <- list(
  isoform_pca  = c("PC3", "PC5"),
  mofa_factors = c("Factor2", "Factor3", "Factor8"),
  stemness     = "RNAss",
  rnaseq       = c("SOX10"),
  immune       = c("Troester_WoundSig_19887484"),
  mutation     = "TP53"
)

# DEFINE METADATA
# OS / OS.time are the survival outcome; mofa_cluster is used only to filter.
meta_vars <- c("OS", "OS.time", "mofa_cluster")

# RUN get_data
# Keep Cluster_2 samples only and drop the cluster column afterwards.
rfs_data <- as.data.frame(get_data(mae, selection, meta_vars)) %>%
  filter(mofa_cluster == "Cluster_2") %>%
  select(-mofa_cluster)

## ── PRE-FILTER DATA (7-Year Window) ─────────────────────────────────────────
# Keeps samples whose OS.time is at most 2555 (presumably days, i.e. 7 x 365)
# and then drops every row containing any NA.
# NOTE(review): this removes samples followed for longer than the window
# instead of censoring them at 2555; confirm that excluding long-term
# follow-up is intended.
rfs_clean <- rfs_data %>%
  filter(OS.time <= 2555) %>%
  drop_na()

cat(sprintf("\nSample size after filtering: N = %d\n", nrow(rfs_clean)))

Sample size after filtering: N = 517
## ── CREATE OUTPUT DIRECTORY ──────────────────────────────────────────────────

# All survival-analysis outputs (plots and CSV tables) are written here.
survival_dir <- file.path(output_dir, "survival")
# recursive = TRUE also creates output_dir itself if it does not exist yet.
if (!dir.exists(survival_dir)) dir.create(survival_dir, recursive = TRUE)
cat(sprintf("Saving plots to: %s\n", survival_dir))
Saving plots to: C:/Users/alden/Dissertation/TP53_isoforms_project/output/survival
## ── Shared save helper ───────────────────────────────────────────────────────
# Single function so dpi/bg are consistent everywhere

# Write a ggplot to `survival_dir` as a 300-dpi PNG with a transparent
# background and report the file name on the console.
#
# plot     : the plot object handed to ggsave().
# filename : file name (not a full path); joined onto the global
#            `survival_dir`.
# width, height : figure size forwarded to ggsave().
save_png <- function(plot, filename, width = 10, height = 6) {
  target_path <- file.path(survival_dir, filename)
  ggsave(filename = target_path, plot = plot,
         width = width, height = height,
         dpi = 300, bg = "transparent")
  cat(sprintf("Saved: %s\n", filename))
}

## ============================================================
## 19.2 DEFINE FEATURE SETS
## ============================================================

# Every column of rfs_clean except the survival outcome and the two isoform
# PCs. Note that this set still contains the TP53 mutation column.
all_features <- Filter(
  function(col_name) !(col_name %in% c("OS", "OS.time", "PC3", "PC5")),
  colnames(rfs_clean)
)
# "Full" model: isoform PCs first, then all remaining predictors.
features_full   <- append(c("PC3", "PC5"), all_features)
# "No_p53" model: same predictors without the isoform PCs.
features_no_p53 <- all_features

## ============================================================
## 19.3 MONTE CARLO CROSS-VALIDATION (20 × 50:50 Split)
## ============================================================

# Monte Carlo cross-validation of the random survival forest.
# For each of the two feature sets ("Full", "No_p53") and each of 20 seeded
# 50:50 splits: train an RSF on one half, predict risk on the other half,
# dichotomise the predicted risk, and record the High-vs-Low Cox hazard ratio.
# Results are collected per model in `all_model_results` (one row per
# iteration) and `all_km_data` (held-out samples with their risk group).
all_model_results <- list()
all_km_data       <- list()

for (model_type in c("Full", "No_p53")) {

  current_features <- if (model_type == "Full") features_full else features_no_p53
  rf_formula <- as.formula(
    paste("Surv(OS.time, OS) ~", paste(current_features, collapse = "+"))
  )

  results_list <- list()
  km_data_list <- list()

  for (i in 1:20) {

    # Seeding with the iteration number makes split i identical for both
    # models, so their results are paired.
    set.seed(i)

    # A. 50:50 random split
    train_idx <- sample(seq_len(nrow(rfs_clean)), size = floor(0.50 * nrow(rfs_clean)))
    train_set <- rfs_clean[train_idx, ]
    test_set  <- rfs_clean[-train_idx, ]

    # B. Train RSF
    rf_model <- randomForestSRC::rfsrc(rf_formula, data = train_set, ntree = 1000)

    # C. Predict on test set
    # Dispatch through the predict() generic rather than reaching into the
    # package namespace with `:::`.
    rf_pred <- predict(rf_model, newdata = test_set)
    test_set$Predicted_Risk <- rf_pred$predicted

    # D. Optimal cutpoint
    # NOTE(review): the cutpoint is selected on the same held-out half that
    # the Cox model below is evaluated on, so the reported HR / p-value are
    # likely optimistic -- confirm this is acceptable for the analysis.
    res_cut <- tryCatch(
      survminer::surv_cutpoint(test_set,
                               time      = "OS.time",
                               event     = "OS",
                               variables = "Predicted_Risk"),
      error = function(e) {
        # Report the skipped iteration instead of dropping it silently.
        message(sprintf("[%s] iteration %d skipped: surv_cutpoint failed (%s)",
                        model_type, i, conditionMessage(e)))
        NULL
      }
    )

    if (!is.null(res_cut)) {

      optimal_cut    <- res_cut$cutpoint$cutpoint
      test_set$Group <- factor(
        ifelse(test_set$Predicted_Risk > optimal_cut, "High", "Low"),
        levels = c("Low", "High")
      )

      # E. Discrete Cox model
      # "Low" is the reference level, so HR > 1 means higher hazard in "High".
      cox_mod <- survival::coxph(Surv(OS.time, OS) ~ Group, data = test_set)
      sum_cox <- summary(cox_mod)

      # conf.int is a one-row matrix; positions 1, 3, 4 are the hazard ratio
      # and its lower / upper confidence limits.
      # NOTE(review): C_Index is taken from the training forest's err.rate
      # (its out-of-bag error), not from the held-out prediction -- confirm
      # that is the intended metric.
      results_list[[i]] <- data.frame(
        Iteration = i,
        Model     = model_type,
        P_Value   = sum_cox$logtest[["pvalue"]],
        HR        = sum_cox$conf.int[1],
        Lower_CI  = sum_cox$conf.int[3],
        Upper_CI  = sum_cox$conf.int[4],
        C_Index   = 1 - rf_model$err.rate[rf_model$ntree]
      )

      # F. Store test set for KM
      km_data_list[[i]] <- test_set %>%
        select(OS.time, OS, Group) %>%
        mutate(Iteration = i)
    }
  }

  # Skipped iterations leave NULL slots, which bind_rows() ignores.
  all_model_results[[model_type]] <- dplyr::bind_rows(results_list)
  all_km_data[[model_type]]       <- dplyr::bind_rows(km_data_list)
}

## ============================================================
## 19.4 ITERATION TABLE — ALL 20 FOLDS + SUMMARY
## ============================================================

# One row per (model, iteration). Rows with HR >= 1000 are dropped
# (presumably degenerate Cox fits), so a model can contribute fewer than 20
# rows.
iteration_table <- dplyr::bind_rows(all_model_results) %>%
  filter(HR < 1000) %>%
  select(Model, Iteration, HR, Lower_CI, Upper_CI, P_Value, C_Index) %>%
  arrange(Model, Iteration)

cat("\n════════════════════════════════════════════════════════════════\n")

════════════════════════════════════════════════════════════════
cat("  PER-ITERATION RESULTS (all 20 folds, both models)\n")
  PER-ITERATION RESULTS (all 20 folds, both models)
cat("════════════════════════════════════════════════════════════════\n")
════════════════════════════════════════════════════════════════
print(iteration_table, row.names = FALSE, digits = 3)

# Per-model mean and SD of each metric across the retained iterations.
# N is the number of iterations that survived the HR < 1000 filter.
summary_table <- iteration_table %>%
  group_by(Model) %>%
  summarise(
    N           = n(),
    Mean_HR     = mean(HR,       na.rm = TRUE),
    SD_HR       = sd(HR,         na.rm = TRUE),
    Mean_LCI    = mean(Lower_CI, na.rm = TRUE),
    SD_LCI      = sd(Lower_CI,   na.rm = TRUE),
    Mean_UCI    = mean(Upper_CI, na.rm = TRUE),
    SD_UCI      = sd(Upper_CI,   na.rm = TRUE),
    Mean_P      = mean(P_Value,  na.rm = TRUE),
    SD_P        = sd(P_Value,    na.rm = TRUE),
    Mean_CIndex = mean(C_Index,  na.rm = TRUE),
    SD_CIndex   = sd(C_Index,    na.rm = TRUE),
    .groups     = "drop"
  )

cat("\n════════════════════════════════════════════════════════════════\n")

════════════════════════════════════════════════════════════════
cat("  SUMMARY: MEAN ± SD ACROSS 20 ITERATIONS\n")
  SUMMARY: MEAN ± SD ACROSS 20 ITERATIONS
cat("════════════════════════════════════════════════════════════════\n")
════════════════════════════════════════════════════════════════
print(as.data.frame(summary_table), row.names = FALSE, digits = 3)

# ── Save tables as CSVs ───────────────────────────────────────────────────────
# Per-iteration results (after the HR < 1000 filter) go to survival_dir.

write.csv(iteration_table,
          file.path(survival_dir, "iterations_all_folds.csv"),
          row.names = FALSE)
cat("Saved: iterations_all_folds.csv\n")
Saved: iterations_all_folds.csv
# Per-model mean / SD summary, written alongside the per-iteration table.
write.csv(summary_table,
          file.path(survival_dir, "iterations_summary.csv"),
          row.names = FALSE)
cat("Saved: iterations_summary.csv\n")
Saved: iterations_summary.csv
## ============================================================
## 19.5 KAPLAN-MEIER: 20 ITERATION CURVES + MEAN 95% CI
## ============================================================

# Common time axis (0 to 2555 in steps of 5) on which every Kaplan-Meier curve
# is evaluated, so curves from different iterations can be averaged pointwise.
time_grid <- seq(0, 2555, by = 5)

# ── Helper: per-iteration step curves ────────────────────────────────────────

# Build Kaplan-Meier step curves for every iteration and risk group.
#
# km_data_model : data frame with columns OS.time, OS, Group and Iteration
#                 (the held-out samples stored by the cross-validation loop).
# Returns a data frame with columns time, surv, Group and Iteration, holding
# each curve evaluated on the global `time_grid`. Groups with fewer than two
# samples in an iteration contribute no rows.
build_iter_curves <- function(km_data_model) {

  # KM curve of one risk group within one iteration, or NULL if too small.
  curve_for_group <- function(iter_rows, grp_label) {
    grp_rows <- iter_rows[iter_rows$Group == grp_label, ]
    if (nrow(grp_rows) < 2) {
      return(NULL)
    }
    km_fit <- survfit(Surv(OS.time, OS) ~ 1, data = grp_rows)
    # Right-continuous step function starting at survival = 1 before the
    # first observed time.
    surv_at <- stepfun(km_fit$time, c(1, km_fit$surv))
    data.frame(
      time      = time_grid,
      surv      = surv_at(time_grid),
      Group     = grp_label,
      Iteration = unique(iter_rows$Iteration)
    )
  }

  # Low and High curves of one iteration, stacked.
  curves_for_iteration <- function(iter_rows) {
    per_group <- lapply(
      c("Low", "High"),
      function(grp_label) curve_for_group(iter_rows, grp_label)
    )
    dplyr::bind_rows(per_group)
  }

  per_iteration <- lapply(
    split(km_data_model, km_data_model$Iteration),
    curves_for_iteration
  )
  dplyr::bind_rows(per_iteration)
}

# ── Helper: mean survival + 95% CI ───────────────────────────────────────────

# Pointwise mean survival and a normal-approximation 95% band across
# iterations.
#
# curves_df : data frame with columns Group, time and surv (one row per
#             iteration, group and grid time).
# Returns one row per (Group, time) with mean_surv, sd_surv, n, se_surv and
# the band limits ci_lo / ci_hi, clamped to [0, 1].
build_mean_ci <- function(curves_df) {
  by_group_time <- dplyr::group_by(curves_df, Group, time)

  pooled <- dplyr::summarise(
    by_group_time,
    mean_surv = mean(surv, na.rm = TRUE),
    sd_surv   = sd(surv,   na.rm = TRUE),
    n         = sum(!is.na(surv)),
    .groups   = "drop"
  )

  # Standard error of the mean across iterations, then mean +/- 1.96 * SE.
  dplyr::mutate(
    pooled,
    se_surv = sd_surv / sqrt(n),
    ci_lo   = pmax(mean_surv - 1.96 * se_surv, 0),
    ci_hi   = pmin(mean_surv + 1.96 * se_surv, 1)
  )
}

# Per-iteration curves and their pointwise mean / band for each model; these
# are the inputs to the KM plotting function below.
curves_full   <- build_iter_curves(all_km_data[["Full"]])
curves_nop53  <- build_iter_curves(all_km_data[["No_p53"]])
mean_ci_full  <- build_mean_ci(curves_full)
mean_ci_nop53 <- build_mean_ci(curves_nop53)

# ── Plot function ─────────────────────────────────────────────────────────────

# Plot every per-iteration Kaplan-Meier curve of one model together with the
# across-iteration mean curve and its 95% CI ribbon.
#
# curves_df:        per-iteration curves (columns time, surv, Group, Iteration).
# mean_ci_df:       mean curve per group (columns time, mean_surv, ci_lo,
#                   ci_hi, Group).
# model_results_df: per-iteration metrics (columns HR, Lower_CI, Upper_CI,
#                   P_Value, C_Index) averaged for the subtitle.
# title_text:       plot title.
# Returns a ggplot object.
plot_20_km <- function(curves_df, mean_ci_df, model_results_df, title_text) {

  # Subtitle statistics; iterations whose HR is missing or >= 1000 are left out.
  stable_runs <- dplyr::filter(model_results_df, HR < 1000)
  hr_summary <- dplyr::summarise(
    stable_runs,
    mHR  = mean(HR,       na.rm = TRUE),
    sdHR = sd(HR,         na.rm = TRUE),
    mLCI = mean(Lower_CI, na.rm = TRUE),
    mUCI = mean(Upper_CI, na.rm = TRUE),
    mP   = mean(P_Value,  na.rm = TRUE),
    mCI  = mean(C_Index,  na.rm = TRUE)
  )

  subtitle_text <- sprintf(
    "Mean HR = %.2f (SD %.2f)  |  Mean 95%% CI [%.2f–%.2f]  |  Mean p = %.3f  |  Mean C-index = %.3f",
    hr_summary$mHR, hr_summary$sdHR,
    hr_summary$mLCI, hr_summary$mUCI,
    hr_summary$mP,   hr_summary$mCI
  )

  # Shared aesthetics for the colour and fill scales.
  risk_colours <- c(High = "#E84855", Low = "#2E86AB")
  risk_labels  <- c(High = "High Risk", Low = "Low Risk")

  # One transparent rectangle reused for every background element.
  clear_rect <- element_rect(fill = "transparent", colour = NA)

  # Layer 1: faint curve per iteration and group.
  iteration_layer <- geom_step(
    data = curves_df,
    aes(x = time, y = surv, colour = Group,
        group = interaction(Group, Iteration)),
    alpha = 0.50, linewidth = 0.45
  )

  # Layer 2: ribbon for the 95% CI of the mean.
  ribbon_layer <- geom_ribbon(
    data = mean_ci_df,
    aes(x = time, ymin = ci_lo, ymax = ci_hi, fill = Group, group = Group),
    alpha = 0.20
  )

  # Layer 3: bold mean curve per group.
  mean_layer <- geom_step(
    data = mean_ci_df,
    aes(x = time, y = mean_surv, colour = Group, group = Group),
    linewidth = 1.3
  )

  ggplot() +
    iteration_layer +
    ribbon_layer +
    mean_layer +
    scale_colour_manual(values = risk_colours, labels = risk_labels) +
    scale_fill_manual(values = risk_colours, labels = risk_labels) +
    scale_x_continuous(
      breaks = seq(0, 2555, by = 365),
      labels = paste0(0:7, "y")
    ) +
    scale_y_continuous(
      limits = c(0, 1),
      labels = scales::percent_format(accuracy = 1)
    ) +
    labs(
      title    = title_text,
      subtitle = subtitle_text,
      x        = "Time",
      y        = "Overall Survival Probability",
      colour   = "Risk Group",
      fill     = "Risk Group",
      caption  = "Faint lines = 20 Monte Carlo iterations (50:50 split) | Bold = mean | Ribbon = mean 95% CI"
    ) +
    theme_classic(base_size = 13) +
    theme(
      panel.background  = clear_rect,
      plot.background   = clear_rect,
      legend.background = clear_rect,
      legend.key        = clear_rect,
      plot.title        = element_text(face = "bold", size = 14),
      plot.subtitle     = element_text(size = 9.5, colour = "grey35"),
      plot.caption      = element_text(size = 8,   colour = "grey55"),
      legend.position   = "bottom",
      legend.title      = element_text(face = "bold")
    )
}

# Build one KM figure per model from its curves, mean/CI table and
# per-iteration metrics.
p_full  <- plot_20_km(curves_full,  mean_ci_full,
                      all_model_results[["Full"]],
                      "RFS Model: Full (with PC3/PC5)")

p_nop53 <- plot_20_km(curves_nop53, mean_ci_nop53,
                      all_model_results[["No_p53"]],
                      "RFS Model: No PC3/PC5")

# Show both figures stacked.
# NOTE(review): `/` on two ggplot objects presumably relies on a plot
# composition package (e.g. patchwork) being attached elsewhere - confirm.
print(p_full / p_nop53)

save_png(p_full,            "km_full_model.png",    width = 10, height = 6)
Saved: km_full_model.png
save_png(p_nop53,           "km_no_pc35_model.png", width = 10, height = 6)
Saved: km_no_pc35_model.png
save_png(p_full / p_nop53,  "km_both_models.png",   width = 10, height = 12)
Saved: km_both_models.png

## ============================================================
## 19.6 FINAL MODEL (Full) — VIMP + PARTIAL DEPENDENCE PLOTS
## ============================================================

cat("\nTraining final model on full dataset for VIMP and PDP...\n")

Training final model on full dataset for VIMP and PDP...
# Fix the RNG so the forest grown below is reproducible.
set.seed(42)
# Survival formula with every name in features_full as an additive predictor.
rf_formula_full <- as.formula(
  paste("Surv(OS.time, OS) ~", paste(features_full, collapse = "+"))
)

# Random survival forest with 1000 trees fitted to all of rfs_clean;
# importance = TRUE makes the fit return variable importance (VIMP).
final_model <- randomForestSRC::rfsrc(
  rf_formula_full,
  data       = rfs_clean,
  ntree      = 1000,
  importance = TRUE
)

# ── VIMP table ────────────────────────────────────────────────────────────────

# One row per feature with its VIMP from the final model, largest first.
vimp_df <- data.frame(
  Feature = names(final_model$importance),
  VIMP    = as.numeric(final_model$importance)
) %>%
  arrange(desc(VIMP))

cat("\n════════════════════════════════════════\n")

════════════════════════════════════════
cat("  VARIABLE IMPORTANCE (VIMP) — Full Model\n")
  VARIABLE IMPORTANCE (VIMP) — Full Model
cat("════════════════════════════════════════\n")
════════════════════════════════════════
print(vimp_df, row.names = FALSE, digits = 4)

write.csv(vimp_df,
          file.path(survival_dir, "vimp_full_model.csv"),
          row.names = FALSE)
cat("Saved: vimp_full_model.csv\n")
Saved: vimp_full_model.csv
# ── VIMP plot ─────────────────────────────────────────────────────────────────

# Horizontal bar chart of VIMP, features ordered by VIMP; bars are blue when
# VIMP > 0 and red otherwise, with a dashed reference line at zero.
p_vimp <- ggplot(vimp_df,
                 aes(x = reorder(Feature, VIMP), y = VIMP, fill = VIMP > 0)) +
  geom_col(width = 0.7) +
  geom_hline(yintercept = 0, linetype = "dashed", colour = "grey50") +
  # The fill is a logical, so the palette is keyed by "TRUE" / "FALSE".
  scale_fill_manual(
    values = c(`TRUE` = "#2E86AB", `FALSE` = "#E84855"),
    guide  = "none"
  ) +
  # Flip so feature names read horizontally.
  coord_flip() +
  labs(
    title   = "Variable Importance (VIMP) — Full RFS Model",
    x       = NULL,
    y       = "VIMP",
    caption = "Blue = positive importance  |  Red = feature may add noise"
  ) +
  theme_classic(base_size = 13) +
  # Transparent backgrounds throughout.
  theme(
    panel.background  = element_rect(fill = "transparent", colour = NA),
    plot.background   = element_rect(fill = "transparent", colour = NA),
    legend.background = element_rect(fill = "transparent", colour = NA),
    legend.key        = element_rect(fill = "transparent", colour = NA),
    plot.title        = element_text(face = "bold")
  )

print(p_vimp)
save_png(p_vimp, "vimp_full_model.png", width = 8, height = 6)
Saved: vimp_full_model.png
# ── Partial dependence plots ──────────────────────────────────────────────────
# plot.variable() is base R graphics; png() / dev.off() is the correct
# capture method. bg = "transparent" sets the device background.
# Note: the plot panels themselves will retain a white fill as this is
# controlled internally by randomForestSRC and cannot be overridden
# without reimplementing PDPs in ggplot2.

# Open a 10 x 8 inch, 300 dpi PNG device with a transparent background, draw
# the partial dependence plots into it, then close the device.
# NOTE(review): the file name has no page counter (e.g. %03d); if
# plot.variable() starts more than one page, presumably only the last page
# remains in the file - confirm against the number of features.
png(file.path(survival_dir, "pdp_full_model.png"),
    width  = 10, height = 8,
    units  = "in",
    res    = 300,
    bg     = "transparent")
  randomForestSRC::plot.variable(
    final_model,
    partial        = TRUE,
    sorted         = TRUE,
    plots.per.page = 3,
    main           = "Partial Dependence — Full RFS Model (with PC3/PC5)"
  )
dev.off()
png 
  2 

cat("Saved: pdp_full_model.png\n")
Saved: pdp_full_model.png
# Also render to screen
# Same call as the one used for the PNG above, drawn on the current device.
# NOTE(review): this repeats the whole plot.variable() call rather than
# reusing a stored result - confirm whether the duplicate work is acceptable.
randomForestSRC::plot.variable(
  final_model,
  partial        = TRUE,
  sorted         = TRUE,
  plots.per.page = 3,
  main           = "Partial Dependence — Full RFS Model (with PC3/PC5)"
)


## ── Final summary ─────────────────────────────────────────────────────────────

cat("\n════════════════════════════════════════════════════════════════\n")

════════════════════════════════════════════════════════════════
cat("  ALL FILES SAVED TO:", survival_dir, "\n")
  ALL FILES SAVED TO: C:/Users/alden/Dissertation/TP53_isoforms_project/output/survival 
cat("════════════════════════════════════════════════════════════════\n")
════════════════════════════════════════════════════════════════
cat("  iterations_all_folds.csv\n")
  iterations_all_folds.csv
cat("  iterations_summary.csv\n")
  iterations_summary.csv
cat("  km_full_model.png\n")
  km_full_model.png
cat("  km_no_pc35_model.png\n")
  km_no_pc35_model.png
cat("  km_both_models.png\n")
  km_both_models.png
cat("  vimp_full_model.csv\n")
  vimp_full_model.csv
cat("  vimp_full_model.png\n")
  vimp_full_model.png
cat("  pdp_full_model.png\n")
  pdp_full_model.png
cat("════════════════════════════════════════════════════════════════\n")
════════════════════════════════════════════════════════════════
---
title: "R Notebook"
output: html_notebook
---

### 0. SETUP ###================================================================

```{r}
## 0.1 SET PATHS ---------------------------------------------------------------
# DEFINE ROOT DIRECTORY
# THE ENVIRONMENT VARIABLE "TP53_PROJECT_ROOT" OVERRIDES THE DEFAULT, SO THE
# NOTEBOOK CAN RUN ON ANOTHER MACHINE WITHOUT EDITING THIS CHUNK
default_root <- "C:/Users/alden/Dissertation/TP53_isoforms_project"
root_path <- Sys.getenv("TP53_PROJECT_ROOT", unset = default_root)

# AN EMPTY VARIABLE IS TREATED AS UNSET
if (!nzchar(root_path)) {
  root_path <- default_root
}

# SET OTHER DIRECTORIES RELATIVE TO ROOT
scripts_dir <- file.path(root_path, "scripts")
input_dir   <- file.path(root_path, "input")
output_dir  <- file.path(root_path, "output")

## 0.2 OTHER -------------------------------------------------------------------
# SET SEED FOR REPRODUCIBILITY
set.seed(42)
```

### 1. INSTALL PACKAGES ###=====================================================

```{r}
## 1.1 INSTALL PACKAGES --------------------------------------------------------
# INSTALL BIOCMANAGER
if (!requireNamespace("BiocManager", quietly = TRUE)) {
    install.packages("BiocManager")
}

# LIST PACKAGES
packages <- c(
  "BiocManager",
  "MultiAssayExperiment",
  "SummarizedExperiment",
  "GenomicRanges",
  "snapcount",
  "edgeR",
  "limma",
  "GSVA",
  "BiocParallel",
  "MOFA2",
  "factoextra",
  "infotheo",
  "bnlearn",
  "survival",
  "gridExtra",
  "dplyr",
  "tidyr",
  "ggplot2",
  "stringr",
  "tibble",
  "data.table",
  "readxl"
)

# CHECK FOR MISSING PACKAGES
# requireNamespace() TESTS EACH PACKAGE DIRECTLY; THE DOCUMENTATION OF
# installed.packages() ADVISES AGAINST USING IT TO CHECK WHETHER A NAMED
# PACKAGE IS INSTALLED
is_installed <- vapply(packages, requireNamespace, logical(1), quietly = TRUE)
missing_pkgs <- packages[!is_installed]

# INSTALL
if (length(missing_pkgs) > 0) {
    BiocManager::install(missing_pkgs, update = FALSE, ask = FALSE)
}

# LOAD
invisible(lapply(packages, library, character.only = TRUE))
```

################################################################################

# PART A - BUILDING THE MAE

################################################################################

### 2. GET ISOFORM COUNTS ###===================================================

```{r}
## 2.1 DEFINE TARGET REGION ----------------------------------------------------
# FIVE LABELLED REGIONS ON THE MINUS STRAND OF chr17
targets <- GenomicRanges::GRanges(
    seqnames = "chr17",
    ranges = IRanges(
        start = c(7673267, 7673207, 7685260, 7675239, 7674526),
        end = c(7673339, 7673266, 7686371, 7675493, 7674858)
    ),
    strand = "-",
    label = c("exon 9B", "exon 9y", "intron 2", "intron 4", "intron 6")
)

## 2.2 GET DATA ----------------------------------------------------------------
# INITIALIZE QUERY FOR TP53 ACROSS ALL TCGA SAMPLES
qb <- snapcount::QueryBuilder(compilation = "tcga", regions = "TP53")
rse <- snapcount::query_exon(qb, return_rse = TRUE)

# IDENTIFY OVERLAPS THAT MATCH hg38 COORDINATES
hits <- findOverlaps(targets, rowRanges(rse), type = "equal")

# WARN IF A TARGET HAS NO EXACT MATCH (IT WOULD OTHERWISE VANISH SILENTLY
# FROM THE COUNT MATRIX)
unmatched <- setdiff(seq_along(targets), queryHits(hits))
if (length(unmatched) > 0) {
    warning(
        "No exact exon match for target(s): ",
        paste(targets$label[unmatched], collapse = ", "),
        call. = FALSE
    )
}

# EXTRACT COUNTS TO MATRIX AND APPLY LABELS
filtered_counts <- as.matrix(assay(rse)[subjectHits(hits), , drop = FALSE])
rownames(filtered_counts) <- targets$label[queryHits(hits)]

# "filtered_counts" IS ALREADY A MATRIX, SO NO SECOND CONVERSION IS NEEDED
isoform_counts <- filtered_counts

## 2.3 SUMMARISE + SAVE --------------------------------------------------------
# GET HEAD AND DIMENSION
head(isoform_counts[, 1:10], 10)
dim(isoform_counts)

# SAVE
saveRDS(isoform_counts, file = file.path(output_dir, "TP53_isoform_counts.rds"))
```

### 3. GET METADATA ###=========================================================

```{r}
## 3.1 DEFINE METADATA ---------------------------------------------------------
# CREATE METADATA LIST AND RENAME
# NAMED VECTOR: NAME = SHORT COLUMN NAME USED FROM HERE ON,
# VALUE = ORIGINAL COLUMN NAME IN colData(rse)
meta_cols <- c(
  
    # IDENTIFIERS
    barcode_full  =   "gdc_cases.samples.portions.analytes.aliquots.submitter_id",
    project       =   "gdc_cases.project.project_id",
    
    # LIBRARY INFORMATION
    lib.size      =   "mapped_read_count",
    a260_a280     =   "gdc_cases.samples.portions.analytes.a260_a280_ratio",
    
    # BRCA HORMONE STATUS
    er_status     =   "xml_breast_carcinoma_estrogen_receptor_status", 
    pr_status     =   "xml_breast_carcinoma_progesterone_receptor_status", 
    her2_ish      =   "xml_lab_procedure_her2_neu_in_situ_hybrid_outcome_type", 
    her2_ihc      =   "xml_lab_proc_her2_neu_immunohistochemistry_receptor_status",
    
    # THERAPY INFORMATION
    drug_name     =   "cgc_drug_therapy_drug_name",
    drug_type     =   "cgc_drug_therapy_pharmaceutical_therapy_type"
)

## 3.2 EXTRACT METADATA --------------------------------------------------------
# KEEP THE COLUMNS LISTED IN "meta_cols" THAT EXIST AND GIVE THEM SHORT NAMES
# NOTE(review): select(any_of()) WITH A NAMED VECTOR PRESUMABLY ALREADY
# RENAMES, WHICH WOULD MAKE THE rename() STEP A NO-OP - CONFIRM
meta <- as.data.frame(colData(rse)) %>%
  dplyr::select(any_of(meta_cols)) %>%
  dplyr::rename(any_of(meta_cols)) %>%
  
# DERIVE PATIENT IDENTIFIERS
# rail_id IS TAKEN FROM THE ROW NAMES OF colData(rse); THE BARCODE PREFIXES
# ARE THE FIRST 15 AND 12 CHARACTERS; batch IS THE 6TH "-"-SEPARATED FIELD
  dplyr::mutate(
    rail_id            = rownames(.), 
    patient_barcode_15 = substr(barcode_full, 1, 15),
    patient_barcode_12 = substr(barcode_full, 1, 12),
    batch              = stringr::str_split_i(barcode_full, "-", 6)
  ) %>%
  
# CONSOLIDATE SUBTYPE
# HER2: ISH RESULT IS USED WHEN IT IS POSITIVE/NEGATIVE, OTHERWISE IHC
# HR: POSITIVE IF ER OR PR IS POSITIVE, NEGATIVE ONLY IF BOTH ARE NEGATIVE
# ANY OTHER COMBINATION BECOMES AN EMPTY STRING
  dplyr::mutate(
    her2_consolidated = dplyr::case_when(
      her2_ish %in% c("Positive", "Negative") ~ her2_ish,
      her2_ihc %in% c("Positive", "Negative") ~ her2_ihc,
      TRUE ~ ""
    ),
    hr_consolidated = dplyr::case_when(
      er_status == "Positive" | pr_status == "Positive" ~ "Positive", 
      er_status == "Negative" & pr_status == "Negative" ~ "Negative",
      TRUE ~ ""
    ),

# CLASSIFY CANCER SUBTYPE FROM HR AND HER2 RECEPTOR STATUS
# (A RECEPTOR-STATUS SURROGATE; NO PAM50 EXPRESSION CALL IS MADE HERE)
    breast_cancer_subtype = dplyr::case_when(
      project != "TCGA-BRCA" ~ "Non-BRCA",
      hr_consolidated == "Positive" & her2_consolidated == "Negative" ~ "LuminalA",
      hr_consolidated == "Positive" & her2_consolidated == "Positive" ~ "LuminalB",
      hr_consolidated == "Negative" & her2_consolidated == "Positive" ~ "HER2-enriched",
      hr_consolidated == "Negative" & her2_consolidated == "Negative" ~ "Triple-negative",
      TRUE ~ "Unclassified BRCA"
    )
  ) %>%

# DEDUPLICATE ALIQUOTS
# KEEPS THE FIRST ROW FOR EACH 15-CHARACTER BARCODE
  dplyr::distinct(patient_barcode_15, .keep_all = TRUE) %>%

# UPDATE ROWNAMES
  tibble::remove_rownames() %>%
  tibble::column_to_rownames("patient_barcode_15")

## 3.3 UPDATE "isoform_counts" -------------------------------------------------

# RENAME USING TCGA BARCODE
# COLUMNS ARE FIRST REORDERED/SUBSET BY rail_id SO THAT POSITION i OF
# "isoform_counts" CORRESPONDS TO ROW i OF "meta" BEFORE THE NAMES ARE REPLACED
isoform_counts <- isoform_counts[, meta$rail_id] # EXPLICIT ORDERING FOR CORRECT MAPPING
colnames(isoform_counts) <- rownames(meta)

## 3.4 SUBSET TO BRCA ----------------------------------------------------------

# SUBSET "meta"
meta <- meta %>%
  dplyr::filter(project == "TCGA-BRCA")

# SAVE SAMPLE LIST
# REUSED BY LATER CHUNKS TO SUBSET THE OTHER DATA SETS
brca_samples <- rownames(meta)

# SUBSET "isoform_counts"
isoform_counts <- isoform_counts[, brca_samples]

## 3.5 SUMMARISE ---------------------------------------------------------------
# GET ISOFORM MEANS
isoform_means <- rowMeans(isoform_counts)
print(isoform_means)

# GET "meta" HEAD AND DIMENSION
head(meta)
dim(meta)

# GET "isoform_counts" HEAD AND DIMENSION
head(isoform_counts[, 1:10], 10)
dim(isoform_counts)
```

### 4. NORMALISATION ###========================================================

```{r}
# PREPARE DGE LIST
# LIBRARY SIZES COME FROM THE "lib.size" COLUMN OF "meta" RATHER THAN FROM THE
# COLUMN SUMS OF THE COUNT MATRIX
# NOTE(review): "meta" IS ALSO PASSED AS samples= AND ALREADY HAS A "lib.size"
# COLUMN, SO dge$samples PRESUMABLY ENDS UP WITH TWO SUCH COLUMNS - CONFIRM
dge <- edgeR::DGEList(
    counts = isoform_counts, 
    samples = meta, 
    lib.size = meta$lib.size
)

# TMM NORMALISATION
# NOTE(review): THE FACTORS ARE ESTIMATED FROM THE ROWS OF "isoform_counts"
# ONLY (THE LABELLED TP53 REGIONS) - CONFIRM TMM IS APPROPRIATE FOR SO FEW ROWS
dge <- edgeR::calcNormFactors(dge, method = "TMM")

# LOGCPM NORMALISATION
isoform_logcpm <- edgeR::cpm(dge, log = TRUE, prior.count = 1)
```

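> **Note (no section 5):** batch correction with `limma::removeBatchEffect` reduced the mutual information between the isoforms and survival, so the uncorrected logCPM values are used in the sections below.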

### 6. PCA ###==================================================================

```{r}
## 6.1 COMPUTE PCA -------------------------------------------------------------
# SAMPLES AS ROWS, ISOFORMS AS (SCALED) VARIABLES
isoform_pca <- prcomp(t(isoform_logcpm), scale. = TRUE)
head(isoform_pca$rotation, 10)

## 6.2 SCREE PLOT --------------------------------------------------------------
# NO TRAILING COMMA AFTER THE LAST ARGUMENT: AN EMPTY ARGUMENT IS PASSED TO
# THE FUNCTION AS A MISSING POSITIONAL ARGUMENT
factoextra::fviz_eig(
    isoform_pca,
    choice = "variance",
    geom = "line",
    addlabels = TRUE
) +
theme_minimal() +
theme(
    panel.grid.major = element_blank(),
    panel.grid.minor = element_blank(),
    axis.line = element_line(colour = "black"),
    panel.border = element_blank()
)

## 6.3 BIPLOTS -----------------------------------------------------------------

# FUNCTION TO CREATE BIPLOTS
# DRAWS THE VARIABLE LOADINGS OF "pca_obj" ON COMPONENTS "pc_x" AND "pc_y",
# COLOURED BY CONTRIBUTION, AND RETURNS THE PLOT
create_biplot <- function(pca_obj, pc_x, pc_y) {

  # ONE TRANSPARENT RECTANGLE REUSED FOR EVERY BACKGROUND
  clear_bg <- element_rect(fill = "transparent", color = NA)

  # LOADINGS PLOT
  loadings_plot <- factoextra::fviz_pca_var(
    pca_obj,
    axes = c(pc_x, pc_y),
    col.var = "contrib",
    gradient.cols = c("#006EAE", "#CA9B23", "#C5373D"),
    repel = TRUE,
    title = paste("PC", pc_x, "vs PC", pc_y)
  )

  # THEME: NO GRID, BLACK AXIS LINES, SMALL BOLD TITLE, TRANSPARENT BACKGROUNDS
  loadings_plot +
    theme_minimal() +
    theme(
      panel.grid.major = element_blank(),
      panel.grid.minor = element_blank(),
      axis.line = element_line(colour = "black"),
      plot.title = element_text(size = 10, face = "bold"),
      plot.background = clear_bg,
      panel.background = clear_bg,
      legend.background = clear_bg
    )
}

# RUN FUNCTION FOR CONSECUTIVE PC PAIRS (PC1 vs PC2, PC2 vs PC3, ...), AT MOST 4
num_pcs <- ncol(isoform_pca$x)
n_biplots <- min(4, num_pcs - 1)

# 1:n_biplots WOULD COUNT BACKWARDS (1, 0) IF FEWER THAN TWO PCs EXISTED
if (n_biplots < 1) {
  stop("PCA needs at least two components to draw biplots", call. = FALSE)
}

plot_list <- lapply(seq_len(n_biplots), function(i) {
  create_biplot(isoform_pca, i, i + 1)
})

# ARRANGE IN 2X2 GRID
combined_pca_plot <- gridExtra::grid.arrange(
  grobs = plot_list,
  ncol = 2,
  nrow = 2,
  top = grid::textGrob("TP53 Isoform PCA Loading Comparisons",
                       gp = grid::gpar(fontsize = 14, fontface = "bold"))
)

# SAVE TO OUTPUT DIRECTORY
ggplot2::ggsave(
  filename = file.path(output_dir, "TP53_PCA_grid_TRANSPARENT.png"),
  plot = combined_pca_plot,
  width = 10,
  height = 10,
  dpi = 300,
  bg = "transparent"
)

## 6.4 EXTRACT PATIENT SCORES --------------------------------------------------
# FROM HERE ON "isoform_pca" HOLDS THE PER-SAMPLE SCORES (A DATA FRAME), NOT
# THE prcomp OBJECT; THE CODE ABOVE MUST BE RE-RUN TO GET THE prcomp OBJECT BACK
isoform_pca <- as.data.frame(isoform_pca$x)
```

### 7. LOAD TCGA DATA ###=======================================================

```{r}
## 7.0 HELPER ------------------------------------------------------------------
# READ A FEATURE-BY-SAMPLE TABLE FROM "input_dir" (FIRST COLUMN = FEATURE
# NAMES) AND KEEP ONLY THE COLUMNS NAMED IN "samples"
load_tcga_matrix <- function(file_name, samples = brca_samples) {
  tbl <- data.table::fread(file.path(input_dir, file_name))

  # rownames.value= TAKES A VECTOR OF ROW NAMES; A LENGTH-1 VECTOR GIVEN TO
  # rownames= WOULD BE READ AS A COLUMN NAME INSTEAD
  mat <- as.matrix(tbl[, -1, with = FALSE], rownames.value = tbl[[1]])

  # drop = FALSE KEEPS A MATRIX EVEN IF ONLY ONE ROW OR COLUMN REMAINS
  mat[, intersect(colnames(mat), samples), drop = FALSE]
}

## 7.1 RNASEQ ------------------------------------------------------------------
rnaseq_file <- "EB++AdjustPANCAN_IlluminaHiSeq_RNASeqV2.geneExp.xena.gz"
rnaseq <- load_tcga_matrix(rnaseq_file)
# KEEP THE FIRST ROW FOR EACH DUPLICATED GENE NAME
rnaseq <- rnaseq[!duplicated(rownames(rnaseq)), , drop = FALSE]

## 7.2 MUTATION ----------------------------------------------------------------
mutation_file <- "mc3.v0.2.8.PUBLIC.nonsilentGene.xena.gz"
mutation <- load_tcga_matrix(mutation_file)

## 7.3 RPPA --------------------------------------------------------------------
rppa_file <- "TCGA-RPPA-pancan-clean.xena.gz"
rppa <- load_tcga_matrix(rppa_file)

## 7.4 STEMNESS ----------------------------------------------------------------
stemness_file <- "StemnessScores_RNAexp_20170127.2.tsv.gz"
stemness <- load_tcga_matrix(stemness_file)

## 7.5 HRD ---------------------------------------------------------------------
# VARIABLE NAMES KEEP THE ORIGINAL "hdr" SPELLING BECAUSE LATER CHUNKS USE THEM
hdr_file <- "TCGA.HRD_withSampleID.txt.gz"
hdr <- load_tcga_matrix(hdr_file)

## 7.6 IMMUNE ------------------------------------------------------------------
immune_subtype_file <- "TCGA_pancancer_10852whitelistsamples_68ImmuneSigs.xena.gz"
immune <- load_tcga_matrix(immune_subtype_file)
```

### 8. LOAD TCGA METADATA ###===================================================

```{r}
## 8.1 SURVIVAL DATA -----------------------------------------------------------
survival_file <- "Survival_SupplementalTable_S1_20171025_xena_sp"

# READ, RENAME, COLLAPSE STAGE AND SUBSET IN ONE PIPELINE
survival <- data.table::fread(file.path(input_dir, survival_file)) %>%

  # KEEP THE CLINICAL AND OUTCOME COLUMNS UNDER SHORTER NAMES
  dplyr::select(
    sample,
    age = age_at_initial_pathologic_diagnosis,
    histology = histological_type,
    menopause = menopause_status,
    tumor_status,
    OS, OS.time, DSS, DSS.time, DFI, DFI.time, PFI, PFI.time,
    ajcc_stage = ajcc_pathologic_tumor_stage
  ) %>%

  # COLLAPSE AJCC SUB-STAGES INTO 4 ORDERED LEVELS; ANYTHING ELSE BECOMES NA
  dplyr::mutate(
    stage = dplyr::case_when(
      stringr::str_detect(ajcc_stage, "^Stage I[A-B]?$") ~ "Stage 1",
      stringr::str_detect(ajcc_stage, "^Stage II[A-B]?$") ~ "Stage 2",
      stringr::str_detect(ajcc_stage, "^Stage III[A-C]?$") ~ "Stage 3",
      stringr::str_detect(ajcc_stage, "^Stage IV$") ~ "Stage 4",
      TRUE ~ NA_character_
    ),
    stage = factor(stage, levels = c("Stage 1", "Stage 2", "Stage 3", "Stage 4"))
  ) %>%

  # KEEP ONE ROW PER BRCA SAMPLE AND MOVE THE SAMPLE ID TO THE ROW NAMES
  dplyr::filter(sample %in% brca_samples) %>%
  dplyr::distinct(sample, .keep_all = TRUE) %>%
  tibble::column_to_rownames("sample")

## 8.2 MERGE WITH "meta" -------------------------------------------------------
# LEFT JOIN ON ROW NAMES: EVERY ROW OF "meta" IS KEPT
meta <- merge(meta, survival, by = "row.names", all.x = TRUE) %>%
  tibble::column_to_rownames("Row.names")
```

### 9. MOFA FACTORS ###=========================================================

```{r}
## 9.1 FEATURE SELECTION -------------------------------------------------------

# RNASEQ: TOP 5000 BY VARIANCE
# head() RETURNS FEWER NAMES INSTEAD OF NA PADDING IF FEWER THAN 5000 EXIST
rna_vars <- apply(rnaseq, 1, var, na.rm = TRUE)
mofa_rna_features <- head(names(sort(rna_vars, decreasing = TRUE)), 5000)

# MUTATION: 1% FREQUENCY
# which() DROPS GENES WHOSE FREQUENCY IS NA (E.G. ALL VALUES MISSING), WHICH
# WOULD OTHERWISE PRODUCE AN NA FEATURE NAME
mut_freq <- rowMeans(mutation, na.rm = TRUE)
mofa_mut_features <- names(mut_freq)[which(mut_freq >= 0.01)]

# RPPA: REMOVE EMPTY ROWS
mofa_protein_features <- rownames(rppa)[rowSums(!is.na(rppa)) > 0]

## 9.2 INITIALIZE MOFA ---------------------------------------------------------

# GET UNIQUE SAMPLES
all_samples <- unique(c(colnames(rnaseq), colnames(mutation), colnames(rppa), colnames(immune)))

# SUBSET DATA
mofa_rna <- rnaseq[mofa_rna_features, , drop = FALSE]
mofa_mut <- mutation[mofa_mut_features, , drop = FALSE]
mofa_prot <- rppa[mofa_protein_features, , drop = FALSE]
mofa_imm  <- immune

# ALIGN
# ONE COLUMN PER ENTRY OF "samples", IN THAT ORDER; A SAMPLE ABSENT FROM THE
# VIEW GETS AN ALL-NA COLUMN (match() RETURNS NA FOR IT)
align_to_samples <- function(mat, samples) {
  aligned <- mat[, match(samples, colnames(mat)), drop = FALSE]
  colnames(aligned) <- samples
  aligned
}

mofa_input <- lapply(
  list(
    RNAseq   = mofa_rna,
    Mutation = mofa_mut,
    Protein  = mofa_prot,
    Immune   = mofa_imm
  ),
  align_to_samples,
  samples = all_samples
)

# CREATE MOFA OBJECT
MOFAobj <- MOFA2::create_mofa(mofa_input)

## 9.3 TRAIN MODEL -------------------------------------------------------------

# SETTINGS
# START FROM THE DEFAULTS, REQUEST 24 FACTORS AND USE A BERNOULLI LIKELIHOOD
# FOR THE "Mutation" VIEW
model_opts <- MOFA2::get_default_model_options(MOFAobj)
model_opts$num_factors <- 24 
model_opts$likelihoods[["Mutation"]] <- "bernoulli"

# PREPARE AND TRAIN; "MOFAobj" IS REPLACED BY THE TRAINED MODEL
MOFAobj <- MOFA2::prepare_mofa(MOFAobj, model_options = model_opts) %>% 
  MOFA2::run_mofa(use_basilisk = TRUE)

## 9.4 SCREE PLOT --------------------------------------------------------------

# EXTRACT VARIANCE EXPLAINED (FACTORS x VIEWS, FIRST GROUP)
vars <- MOFA2::get_variance_explained(MOFAobj)$r2_per_factor[[1]]

# CALCULATE TOTAL VARIANCE PER FACTOR
# THE FACTOR LEVELS FOLLOW ROW ORDER SO THE X AXIS IS NOT SORTED ALPHABETICALLY
scree_data <- data.frame(
  Factor = paste0("Factor", seq_len(nrow(vars))),
  Variance = rowSums(vars)
) %>%
  mutate(Factor = factor(Factor, levels = Factor))

# SCREE PLOT
# LINE WIDTH IS SET WITH linewidth= (size= IS DEPRECATED FOR LINES)
ggplot(scree_data, aes(x = Factor, y = Variance, group = 1)) +
  geom_line(color = "steelblue", linewidth = 1) +
  geom_point(color = "darkblue", size = 3) +
  theme_minimal() +
  labs(title = "MOFA Scree Plot",
       subtitle = "Total Variance Explained across all Omics Views",
       y = "Total Variance Explained (%)",
       x = "Latent Factors") +
  theme(axis.text.x = element_text(angle = 45, hjust = 1))

# EXTRACT FACTORS (SAMPLES x FACTORS) AND PREFIX THE COLUMN NAMES
mofa_factors <- MOFA2::get_factors(MOFAobj, factors = "all")[[1]]
colnames(mofa_factors) <- paste0("MOFA_", colnames(mofa_factors))

## 9.5 FACTOR LOADINGS ---------------------------------------------------------

# GET FACTOR LOADINGS (EXAMPLE USAGE)
imm_weights <- MOFA2::get_weights(MOFAobj, views = "Immune", factors = "Factor3")[[1]]
imm_weights

## 9.6 K-MEANS CLUSTERING ------------------------------------------------------

# CLUSTERING SETTINGS
# DEFINED ONCE SO THE MODEL AND THE PLOT LABELS CANNOT DISAGREE
n_top_factors <- 8
n_clusters    <- 3

# EXTRACT TOP FACTORS
mofa_factors_mat <- MOFA2::get_factors(MOFAobj)[[1]][, seq_len(n_top_factors)]

# RUN UMAP
MOFAobj <- MOFA2::run_umap(
  MOFAobj,
  factors = seq_len(n_top_factors),
  n_neighbors = 15,
  min_dist = 0.1
)

# DIAGNOSTIC PLOTS
p_elbow <- factoextra::fviz_nbclust(mofa_factors_mat, kmeans, method = "wss") +
  labs(title = "Elbow Method (WSS)", x = "Number of Clusters (k)") +
  theme_minimal()

p_sil <- factoextra::fviz_nbclust(mofa_factors_mat, kmeans, method = "silhouette") +
  labs(title = "Silhouette Analysis", x = "Number of Clusters (k)") +
  theme_minimal()

# SAVE DIAGNOSTIC PLOTS
diag_combined <- gridExtra::arrangeGrob(p_elbow, p_sil, ncol = 2)
ggplot2::ggsave(file.path(output_dir, "MOFA_Clustering_Diagnostics_Top8.png"),
                diag_combined, width = 10, height = 5, bg = "transparent")

# RUN K-MEANS CLUSTERING
km_res <- kmeans(mofa_factors_mat, centers = n_clusters, nstart = 25)

# SYNC CLUSTERS TO METADATA AND MOFA OBJ
# SAMPLES WITHOUT A CLUSTER STAY NA INSTEAD OF BECOMING A "Cluster_NA" LEVEL
cluster_id <- km_res$cluster[match(rownames(meta), names(km_res$cluster))]
meta$mofa_cluster <- factor(
  ifelse(is.na(cluster_id), NA_character_, paste0("Cluster_", cluster_id))
)

MOFA2::samples_metadata(MOFAobj) <- meta %>%
  mutate(sample = rownames(.)) %>%
  relocate(sample)

# DEFINE THEME
clean_theme <- theme_minimal() +
  theme(
    panel.grid.major = element_blank(),
    panel.grid.minor = element_blank(),
    panel.background = element_rect(fill = "transparent", color = NA),
    plot.background  = element_rect(fill = "transparent", color = NA),
    axis.line        = element_line(color = "black"),
    legend.background = element_rect(fill = "transparent", color = NA)
  )

# PLOT UMAP (MOFA CLUSTERS)
p1 <- MOFA2::plot_dimred(
  MOFAobj,
  method = "UMAP",
  color_by = "mofa_cluster",
  dot_size = 2
) +
  scale_color_brewer(palette = "Set1") +
  clean_theme +
  labs(
    title = "MOFA UMAP: Multi-omic Clusters",
    subtitle = sprintf("K-means (K=%d) on Top %d Factors", n_clusters, n_top_factors)
  )

# PLOT UMAP (RECEPTOR-STATUS SUBTYPE)
# "breast_cancer_subtype" IS A METADATA COLUMN NAMED AFTER ITS RECEPTOR-STATUS
# DERIVATION, SO THE TITLE DOES NOT CALL IT PAM50
p2 <- MOFA2::plot_dimred(
  MOFAobj,
  method = "UMAP",
  color_by = "breast_cancer_subtype",
  dot_size = 2
) +
  scale_color_brewer(palette = "Dark2") +
  clean_theme +
  labs(title = "MOFA UMAP: Receptor-Status Subtypes", subtitle = "Clinical Label Comparison")

# SAVE UMAPS
umap_combined <- gridExtra::arrangeGrob(p1, p2, ncol = 2)
ggplot2::ggsave(file.path(output_dir, "MOFA_UMAP_Clusters_vs_Subtypes.png"),
                umap_combined, width = 12, height = 6, bg = "transparent")

# DISPLAY UMAPS
grid::grid.draw(umap_combined)
```

### 10. MARTINGALE RESIDUALS ###=================================================

```{r}
## 10.1 EXTRACT RESIDUALS ------------------------------------------------------
# NAME = PREFIX OF THE NEW "meta" COLUMN, VALUE = OUTCOME COLUMN IN "meta"
target_outcomes <- list(os = "OS", pfi = "PFI")

for (m in names(target_outcomes)) {

  # INTERCEPT-ONLY SURVIVAL FORMULA, E.G. survival::Surv(OS.time, OS) ~ 1
  surv_formula <- as.formula(
    sprintf("survival::Surv(%1$s.time, %1$s) ~ 1", target_outcomes[[m]])
  )

  # NULL COX MODEL; na.exclude PADS THE RESIDUALS WITH NA FOR INCOMPLETE ROWS
  # SO THEY LINE UP WITH THE ROWS OF "meta"
  fit <- survival::coxph(surv_formula, data = meta, na.action = na.exclude)

  # STORE THE MARTINGALE RESIDUALS AS "<prefix>_risk_score"
  meta[[paste0(m, "_risk_score")]] <- residuals(fit, type = "martingale")
}
```

### 11. BUILD MAE OBJECT ###=====================================================

```{r}
## 11.1 TRANSPOSE --------------------------------------------------------------
# BOTH OBJECTS ARE OVERWRITTEN WITH THEIR TRANSPOSE
# NOTE(review): RE-RUNNING THIS CHUNK ON ITS OWN TRANSPOSES THEM BACK AGAIN;
# RE-RUN THE CHUNKS THAT CREATE THEM FIRST
isoform_pca <- t(isoform_pca)
mofa_factors <- t(mofa_factors)

## 11.2 PREPARE EXPERIMENT LIST ------------------------------------------------
# ONE NAMED ENTRY PER ASSAY; THE "hrd" ENTRY IS FILLED FROM THE VARIABLE "hdr"
EXP_LIST <- list(
    isoform_pca    = isoform_pca,  
    isoform_logcpm = isoform_logcpm,  
    rnaseq         = rnaseq,          
    mutation       = mutation,        
    rppa           = rppa,           
    stemness       = stemness,        
    hrd            = hdr,             
    immune         = immune,          
    mofa_factors   = mofa_factors
)

## 11.3 CREATE MAE OBJECT ------------------------------------------------------
# "meta" SUPPLIES THE SAMPLE-LEVEL colData
mae <- MultiAssayExperiment::MultiAssayExperiment(
    experiments = EXP_LIST,
    colData     = meta
)

# PRINT SUMMARY
print(mae)

## 11.4 SAVING & CLEANUP -------------------------------------------------------
saveRDS(mae, file = file.path(output_dir, "TP53_Isoforms_MAE.rds"))
```

### 12. FUNCTION TO EXTRACT DATA ###=============================================

```{r}
## 12.1 FUNCTION TO RETRIEVE DATA -----------------------------------------------
#' EXTRACT SELECTED FEATURES FROM A MAE AS ONE WIDE DATA FRAME
#'
#' @param mae MultiAssayExperiment TO READ FROM
#' @param config NAMED LIST, ONE ENTRY PER ASSAY: TRUE TAKES THE WHOLE ASSAY,
#'   A CHARACTER VECTOR TAKES ONLY THE ROWS WITH THOSE NAMES
#' @param clinical colData COLUMNS TO CARRY ALONG (PASSED TO longForm AS
#'   colDataCols); NULL BY DEFAULT
#' @return DATA FRAME WITH THE "primary" IDENTIFIER AS ROW NAMES AND ONE COLUMN
#'   PER REQUESTED FEATURE PLUS THE REQUESTED CLINICAL COLUMNS
get_data <- function(mae, config, clinical = NULL) {

  extracted <- list()

  # "assay_name" (NOT "assay") SO THE LOOP VARIABLE DOES NOT SHADOW assay() OR
  # THE "assay" COLUMN DROPPED FURTHER DOWN
  for (assay_name in names(config)) {
    items <- config[[assay_name]]

    if (isTRUE(items)) {
      extracted[[assay_name]] <- mae[[assay_name]]
    } else {
      available <- rownames(mae[[assay_name]])

      # REPORT REQUESTED FEATURES THAT DO NOT EXIST INSTEAD OF DROPPING THEM
      # SILENTLY
      absent <- setdiff(items, available)
      if (length(absent) > 0) {
        warning(
          "Not found in assay '", assay_name, "': ",
          paste(absent, collapse = ", "),
          call. = FALSE
        )
      }

      valid <- intersect(items, available)
      extracted[[assay_name]] <- mae[[assay_name]][valid, , drop = FALSE]
    }
  }

  temp_mae <- MultiAssayExperiment(experiments = extracted, colData = colData(mae))

  # LONG -> WIDE: ONE ROW PER "primary", ONE COLUMN PER FEATURE ("rowname")
  as.data.frame(longForm(temp_mae, colDataCols = clinical)) %>%
    # DROP THESE TO ALLOW COLLAPSING
    dplyr::select(-dplyr::all_of(c("assay", "colname"))) %>%
    tidyr::pivot_wider(names_from = "rowname", values_from = "value") %>%
    # CONVERT BARCODE COLUMN TO ROWNAMES
    tibble::column_to_rownames("primary")
}
```

################################################################################

# PART B - BAYESIAN NETWORK INFERENCE

################################################################################

### 13. MAIN FUNCTION ###=======================================================

```{r}
#' RUN BAYESIAN NETWORKS
#' @param df_raw DATAFRAME
#' @param n_restarts NUMBER OF RANDOM RESTARTS
#' @param bl BLACKLIST
#' @param run_name FOLDER NAME OF OUTPUT
#' @param ref_color REFERENCE COLOUR (AIC)
#' @param comp_color COMPARISON COLOUR (MATCHES/MB)

#' RUN BAYESIAN NETWORKS (REFACTORED V3)
#'
#' Learns one averaged Bayesian network per configuration, where a
#' configuration is a quantile discretisation depth (2 or 3 bins) combined
#' with a network score ("aic", "bic", "bde"). For each configuration the
#' structure is learned with `structural.em()` (hill-climbing) from several
#' starting graphs and the resulting networks are averaged with
#' `custom.strength()` / `averaged.network()`.
#'
#' Files written under `file.path(output_dir, run_name)`:
#'   * `<id>.rds` and `<id>.png` for every configuration,
#'   * `Adjacency_Comparison_Grid_Final.png`,
#'   * `Reference_Risk_CPT_with_N.csv` (reference configuration "2disc.aic"),
#'   * `Reference_Mosaic_Risk_Final.png`.
#'
#' Side effects: reads the global `output_dir`; resets the RNG seed to 123
#' before generating starting graphs; assigns every configuration result into
#' the global environment as `<run_name>.<id>` (e.g. `mofa.2disc.aic`).
#'
#' @param df_raw Data frame with one column per node. Must contain a column
#'   named "RISK" (the target node).
#' @param n_restarts Number of starting graphs: one empty graph plus
#'   `n_restarts - 1` random graphs.
#' @param bl Optional blacklist forwarded to hill-climbing via `maximize.args`.
#' @param run_name Name of the output sub-directory and prefix of the global
#'   result objects.
#' @param ref_color Colour for "Reference Only" edges and the mosaic gradient.
#' @param comp_color Colour for "Match Ref" edges and Markov-blanket highlights.
#' @return Named list with one element per configuration; each element is a
#'   list with `dag`, `adjacency`, `nodes`, `data` and `mb`.
run_networks <- function(df_raw, n_restarts = 50, bl = NULL,
                         run_name = "Final_TP53_Analysis",
                         ref_color = "orange3",
                         comp_color = "orange") {

  # --- 13.1 RUN 6 BNs -------------------------------
  run_path <- file.path(output_dir, run_name)
  dir.create(run_path, recursive = TRUE, showWarnings = FALSE)

  cat("================================================\n")
  cat("BN PIPELINE RUN: ", run_name, "\n")
  cat("SAMPLE SIZE (N):", nrow(df_raw), "\n")
  cat("================================================\n")

  disc_levels  <- c(2, 3)
  scores       <- c("aic", "bic", "bde")
  all_results  <- list()
  ref_id       <- "2disc.aic"
  target_node  <- "RISK"
  master_nodes <- colnames(df_raw)

  # FAIL EARLY: EVERYTHING DOWNSTREAM (MB, CPT, MOSAIC) NEEDS THE TARGET NODE
  if (!target_node %in% master_nodes) {
    stop("df_raw must contain a '", target_node, "' column.", call. = FALSE)
  }

  for (d in disc_levels) {
    df_disc <- bnlearn::discretize(df_raw, method = "quantile", breaks = d)
    df_disc <- as.data.frame(lapply(df_disc, droplevels))[, master_nodes]

    for (s in scores) {
      id <- paste0(d, "disc.", s)
      cat("\n>>> PROCESSING CONFIGURATION:", id, "\n")

      # SAME SEED FOR EVERY CONFIGURATION SO ALL SHARE THE SAME STARTING GRAPHS
      set.seed(123)
      n_random <- n_restarts - 1
      random_starts <- list()
      if (n_random >= 1) {
        random_starts <- random.graph(nodes = master_nodes, num = n_random,
                                      method = "ic-dag", max.degree = 1)
        # random.graph() RETURNS A BARE bn OBJECT (NOT A LIST) WHEN num == 1
        if (inherits(random_starts, "bn")) {
          random_starts <- list(random_starts)
        }
      }
      starts <- c(list(empty.graph(master_nodes)), random_starts)

      net_list <- lapply(starts, function(g) {
        tryCatch(
          structural.em(df_disc, maximize = "hc", start = g,
                        maximize.args = list(score = s, blacklist = bl)),
          error = function(e) {
            # BEST EFFORT: A FAILED RESTART IS SKIPPED, BUT NO LONGER SILENTLY
            message("  [", id, "] restart skipped: ", conditionMessage(e))
            NULL
          }
        )
      })

      net_list <- net_list[!vapply(net_list, is.null, logical(1))]
      if (length(net_list) == 0) {
        stop("All structural.em restarts failed for configuration ", id,
             call. = FALSE)
      }
      avg_dag <- averaged.network(custom.strength(net_list, nodes = master_nodes))

      print(avg_dag)

      mb_nodes  <- mb(avg_dag, target_node)
      curr_amat <- amat(avg_dag)[master_nodes, master_nodes]

      res_obj <- list(dag = avg_dag, adjacency = curr_amat,
                      nodes = master_nodes, data = df_disc, mb = mb_nodes)
      all_results[[id]] <- res_obj

      saveRDS(res_obj, file.path(run_path, paste0(id, ".rds")))
      # LATER CHUNKS READ THESE GLOBALS (E.G. mofa.2disc.aic), SO KEEP THEM
      assign(paste0(run_name, ".", id), res_obj, envir = .GlobalEnv)

      # HIGH RES BN PLOTS
      png(file.path(run_path, paste0(id, ".png")), width = 2400, height = 2100,
          res = 300, bg = "transparent", type = "cairo")
      # finally = dev.off() CLOSES THE DEVICE EVEN IF PLOTTING FAILS
      tryCatch({
        if (length(mb_nodes) > 0) {
          graphviz.plot(avg_dag,
                        highlight = list(nodes = mb_nodes, fill = comp_color, col = "black"),
                        main = paste(id, "| MB of", target_node))
        } else {
          graphviz.plot(avg_dag, main = paste(id, "| No MB for", target_node))
        }
      }, finally = dev.off())
    }
  }

  # --- 13.2 ADJACENCY MATRIX -------------------------------
  adj_mats      <- lapply(all_results, function(x) x$adjacency)
  summed_mat    <- Reduce("+", adj_mats)
  # AN EDGE IS "UNIVERSAL" WHEN IT APPEARS IN EVERY CONFIGURATION
  universal_mat <- (summed_mat == length(adj_mats))
  ref_mat       <- all_results[[ref_id]]$adjacency
  plot_list     <- list()
  config_names  <- names(all_results)
  n_col         <- 3

  for (i in seq_along(config_names)) {
    id <- config_names[i]
    curr_mat <- all_results[[id]]$adjacency

    # TWO-COLUMN CHARACTER INDEXING LOOKS UP EVERY (from, to) PAIR AT ONCE
    plot_df <- as.data.frame(curr_mat) %>%
      tibble::rownames_to_column("from") %>%
      tidyr::pivot_longer(-from, names_to = "to", values_to = "exists") %>%
      dplyr::mutate(
        is_univ = universal_mat[cbind(from, to)],
        is_ref  = ref_mat[cbind(from, to)],
        category = dplyr::case_when(
          exists == 1  & is_univ == TRUE ~ "Universal",
          id == ref_id & exists == 1 & is_univ == FALSE ~ "Reference Only",
          exists == 1  & is_ref == 1 & is_univ == FALSE ~ "Match Ref",
          exists == 1  & is_ref == 0 & is_univ == FALSE ~ "New Edge",
          TRUE ~ "None"
        )
      )

    # ONLY THE OUTER PANELS OF THE GRID CARRY AXIS LABELS
    is_left_col   <- (i - 1) %% n_col == 0
    is_bottom_row <- i > length(config_names) - n_col

    p <- ggplot(plot_df, aes(x = factor(to, levels = master_nodes),
                             y = factor(from, levels = rev(master_nodes)))) +
      geom_tile(aes(fill = category), color = "gray85", linewidth = 0.2) +
      scale_fill_manual(
        values = c("Universal" = "black", "Reference Only" = ref_color,
                   "Match Ref" = comp_color, "New Edge" = "lightgray", "None" = "white"),
        guide = "none"
      ) +
      labs(title = id, x = NULL, y = NULL) +
      theme_minimal() +
      theme(
        aspect.ratio = 1,
        panel.border = element_rect(color = "black", fill = NA, linewidth = 0.8),
        panel.grid   = element_blank(),
        plot.title   = element_text(hjust = 0.5, size = 10, face = "bold"),
        axis.text.x  = if (is_bottom_row) element_text(angle = 90, vjust = 0.5, hjust = 1, size = 7) else element_blank(),
        axis.ticks.x = if (is_bottom_row) element_line(color = "black") else element_line(color = NA),
        axis.text.y  = if (is_left_col) element_text(size = 7) else element_blank(),
        axis.ticks.y = if (is_left_col) element_line(color = "black") else element_line(color = NA),
        axis.ticks.length = unit(3, "pt"),
        plot.margin = margin(5, 5, 5, 5),
        plot.background = element_rect(fill = "transparent", color = NA)
      )
    plot_list[[i]] <- p
  }

  final_grid <- patchwork::wrap_plots(plot_list, ncol = n_col, nrow = 2) +
    patchwork::plot_annotation(
      title = paste("Structural Stability Analysis:", run_name),
      subtitle = "Row 1: 2-level Discretization | Row 2: 3-level Discretization",
      theme = theme(plot.title = element_text(size = 14, face = "bold", hjust = 0.5),
                    plot.subtitle = element_text(size = 11, hjust = 0.5),
                    plot.background = element_rect(fill = "transparent", color = NA))
    )

  ggsave(file.path(run_path, "Adjacency_Comparison_Grid_Final.png"),
         plot = final_grid, width = 11, height = 7.5, bg = "transparent", type = "cairo")

  # --- 13.3 CPT (REFERENCE ONLY) -------------------------------
  ref_res  <- all_results[[ref_id]]
  mb_nodes <- ref_res$mb

  # ORDER VARIABLES: NON-PC BLANKET NODES, THEN PCs (PC3 LAST), THEN THE TARGET
  pc_in_mb <- grep("^PC[1-5]$", mb_nodes, value = TRUE)
  other_mb <- setdiff(mb_nodes, pc_in_mb)

  if ("PC3" %in% pc_in_mb) {
    ordered_vars <- c(other_mb, setdiff(pc_in_mb, "PC3"), "PC3", target_node)
  } else {
    ordered_vars <- c(other_mb, pc_in_mb, target_node)
  }

  # 1. FREQUENCY TABLE (N) AND P(RISK | BLANKET CONFIGURATION)
  # drop = FALSE KEEPS A DATA FRAME WHEN THE BLANKET IS EMPTY; seq_len() AVOIDS
  # THE c(1, 0) MARGIN THAT 1:0 WOULD GIVE (EMPTY MARGIN = PLAIN PROPORTIONS)
  raw_counts <- table(ref_res$data[, ordered_vars, drop = FALSE])
  prob_table <- prop.table(raw_counts, margin = seq_len(length(ordered_vars) - 1))

  # 2. CONVERT TO DATA FRAMES (SAME CELL ORDER IN BOTH)
  df_probs  <- as.data.frame(prob_table)
  df_counts <- as.data.frame(raw_counts)

  # 3. MERGE PROBABILITIES AND SAMPLE NUMBERS (N)
  final_cpt_output <- df_probs %>%
    dplyr::rename(Probability = Freq) %>%
    dplyr::mutate(N = df_counts$Freq)

  # 4. SAVE CSV
  write.csv(final_cpt_output, file.path(run_path, "Reference_Risk_CPT_with_N.csv"), row.names = FALSE)

  # --- 13.4 MOSAIC PLOT (REFERENCE ONLY) -------------------------------
  # FIX: THIS SECTION USED UNDEFINED NAMES (counts_table / cpt_table); IT NOW
  # USES THE TABLES BUILT IN 13.3.
  # RISK IS THE LAST DIMENSION, SO IN COLUMN-MAJOR ORDER THE FIRST n_configs
  # CELLS ARE RISK LEVEL 1 AND THE NEXT n_configs ARE RISK LEVEL 2. THIS RELIES
  # ON THE REFERENCE CONFIGURATION HAVING EXACTLY TWO RISK LEVELS.
  n_configs <- prod(dim(raw_counts)[-length(dim(raw_counts))])
  p_high    <- as.vector(prob_table)[(n_configs + 1):(2 * n_configs)]
  p_high[is.na(p_high)] <- 0

  # LOGISTIC STRETCH AROUND 0.5 TO INCREASE COLOUR CONTRAST
  p_contrast  <- 1 / (1 + exp(-10 * (p_high - 0.5)))
  grad_pal    <- colorRampPalette(c("#F5F5F5", ref_color))(100)
  risk_colors <- grad_pal[round(p_contrast * 99) + 1]
  # LEVEL-1 CELLS ARE TRANSPARENT; LEVEL-2 CELLS CARRY THE RISK GRADIENT
  color_array <- array(c(rep("#FFFFFF00", n_configs), risk_colors), dim = dim(raw_counts))

  # MOSAIC PLOT WITH THIN BLACK OUTLINES
  png(file.path(run_path, "Reference_Mosaic_Risk_Final.png"), width = 3600, height = 2700,
      res = 300, bg = "transparent", type = "cairo")
  tryCatch(
    mosaic(raw_counts,
           gp = gpar(fill = color_array, col = "black", lwd = 0.5),
           main = paste("Reference Run:", run_name, "Prognostic Risk Hierarchy"),
           labeling = labeling_border(gp_labels = gpar(fontsize = 9, fontface = "bold"),
                                      rot_labels = c(0, 90, 0, 90))),
    finally = dev.off()
  )

  cat("\nDONE. All outputs saved to:", run_path, "\n")
  all_results
}
```

### 14. MOFA FACTORS ###========================================================

```{r}
## 14.1 GET DATA
# VARIABLE SELECTION HANDED TO get_data()
selection <- list(
  isoform_pca  = TRUE,
  mofa_factors = paste0("Factor", seq_len(8)),
  hrd          = "HRD",
  stemness     = "RNAss"
)

# METADATA COLUMN TO ATTACH (BECOMES THE RISK NODE)
meta_vars <- "os_risk_score"

# PULL THE DATA AND RENAME THE OUTCOME COLUMN TO RISK
mofa_data <- get_data(mae, selection, meta_vars) %>%
  as.data.frame() %>%
  rename(RISK = os_risk_score)

## 14.2 BLACKLIST
# FORBID EVERY ARC LEAVING RISK: RISK MAY ONLY BE A CHILD
mofa_bl <- data.frame(
  from = "RISK",
  to   = setdiff(colnames(mofa_data), "RISK")
)

## 14.3 RUN MAIN FUNCTION
my_nets <- run_networks(
  df_raw     = mofa_data,
  n_restarts = 100,
  bl         = mofa_bl,
  run_name   = "mofa",
  ref_color  = "#006EAE",
  comp_color = "#9BCAE9"
)
```

### 15. FIT EVIDENCE ###========================================================

```{r}
## 15.1 PREPARE THE TEMPLATE AND DATA ------------------------------------------
# mofa.2disc.aic IS THE GLOBAL CREATED BY run_networks(run_name = "mofa").
# cextend() TURNS THE AVERAGED NETWORK INTO A FULLY DIRECTED DAG FOR bn.fit().
global_dag <- bnlearn::cextend(mofa.2disc.aic$dag)
DAG_NODES  <- bnlearn::nodes(global_dag)

# RE-SYNC MOFA CLUSTERS TO THE DATA FRAME (MATCHED BY ROW NAME)
mofa_data$mofa_cluster <- colData(mae)[rownames(mofa_data), "mofa_cluster"]

# PERFORM GLOBAL DISCRETIZATION ONLY ON DAG NODES, SO THE BIN BOUNDARIES ARE
# SHARED BY ALL CLUSTERS
full_disc <- bnlearn::discretize(
  mofa_data[, intersect(colnames(mofa_data), DAG_NODES), drop = FALSE],
  method = "quantile",
  breaks = 2
)

# ADD CLUSTER ASSIGNMENTS TO THE DISCRETIZED DATA
full_disc$mofa_cluster <- mofa_data$mofa_cluster

# DEFINE TARGET OUTCOME AND PRIMARY DRIVER NODES
# WITH TWO QUANTILE BINS, LEVEL 1 IS THE LOWER BIN AND LEVEL 2 THE UPPER BIN
target_node     <- "RISK"
driver_node     <- "PC3"
high_risk_label <- levels(full_disc[[target_node]])[2]
low_pc3_label   <- levels(full_disc[[driver_node]])[1]
high_pc3_label  <- levels(full_disc[[driver_node]])[2]

## 15.2 THE REFIT LOOP ---------------------------------------------------------
# IDENTIFY UNIQUE CLUSTERS WHILE IGNORING SAMPLES WITHOUT A CLUSTER (NA)
clusters <- as.character(unique(na.omit(full_disc$mofa_cluster)))
refit_results <- list()

# EXTRACT EXACT NODE NAMES REQUIRED BY THE GLOBAL DAG
dag_nodes <- bnlearn::nodes(global_dag)

# VALIDATE ONCE THAT EVERY DAG NODE HAS A DATA COLUMN (SAME FOR ALL CLUSTERS)
missing_nodes <- setdiff(dag_nodes, colnames(full_disc))
if (length(missing_nodes) > 0) {
  stop(paste("DATA IS MISSING NODES REQUIRED BY DAG:", paste(missing_nodes, collapse = ", ")))
}
model_cols <- intersect(colnames(full_disc), dag_nodes)

# COLUMN NAMES PRODUCED BY THE PIVOT BELOW
driver_prefix <- paste0(driver_node, "_")
col_low       <- paste0(driver_prefix, low_pc3_label)
col_high      <- paste0(driver_prefix, high_pc3_label)

for (cl in clusters) {
  cat(">>> REFITTING GLOBAL GRAPH FOR:", cl, "\n")

  # A. SUBSET DATA FOR THE CURRENT CLUSTER
  # which() DROPS NA COMPARISONS; A BARE LOGICAL INDEX WOULD ADD ONE ALL-NA ROW
  # PER UNCLUSTERED SAMPLE AND INFLATE N.
  in_cluster <- which(full_disc$mofa_cluster == cl)

  # B. KEEP ONLY DAG NODES TO PREVENT DIMENSION ERRORS
  cl_input <- full_disc[in_cluster, model_cols, drop = FALSE]
  cl_n     <- nrow(cl_input)

  # C. REFIT PARAMETERS USING BAYESIAN ESTIMATION (ISS=1)
  cl_fit <- bnlearn::bn.fit(global_dag, cl_input, method = "bayes", iss = 1)

  # D. CALCULATE MARGINAL EFFECTS VIA CONDITIONAL PROBABILITY TABLE
  # RESET PER ITERATION SO NO VALUE LEAKS IN FROM THE PREVIOUS CLUSTER.
  mean_delta     <- NA_real_
  mean_abs_delta <- NA_real_
  # NOTE(review): NO FORMULA FOR THE INTERACTION INDEX EXISTS IN THIS SCRIPT.
  # IT WAS PREVIOUSLY UNDEFINED ON THE SUCCESS PATH (RUNTIME ERROR); IT IS NOW
  # EXPLICITLY NA UNTIL A DEFINITION IS PROVIDED.
  interaction_idx <- NA_real_

  cpt_risk <- as.data.frame(cl_fit[[target_node]]$prob)

  # THE COMPARISON ONLY MAKES SENSE WHEN THE DRIVER IS A PARENT OF THE TARGET
  if (all(c(target_node, driver_node) %in% colnames(cpt_risk))) {
    cpt_high <- cpt_risk[cpt_risk[[target_node]] == high_risk_label, ]

    # PIVOT TABLE TO COMPARE DRIVER HIGH VS LOW ACROSS ALL PARENT STATES
    wide_cpt <- tidyr::pivot_wider(
      cpt_high,
      names_from   = tidyr::all_of(driver_node),
      values_from  = Freq,
      names_prefix = driver_prefix
    )

    if (col_low %in% colnames(wide_cpt) && col_high %in% colnames(wide_cpt)) {
      # COMPUTE DIFFERENCE IN RISK FOR EVERY BACKGROUND CONFIGURATION
      all_deltas <- wide_cpt[[col_high]] - wide_cpt[[col_low]]

      # 1. MEAN NET EFFECT (DIRECTIONAL IMPACT)
      mean_delta <- mean(all_deltas, na.rm = TRUE)

      # 2. MEAN ABSOLUTE INFLUENCE (MAGNITUDE REGARDLESS OF SIGN)
      mean_abs_delta <- mean(abs(all_deltas), na.rm = TRUE)

      cat("    MEAN DELTA RISK:", round(mean_delta, 4), "\n")
      cat("    MEAN ABSOLUTE INFLUENCE:", round(mean_abs_delta, 4), "\n")
    }
  } else {
    message("    ", driver_node, " is not a parent of ", target_node,
            " in the DAG; effects recorded as NA.")
  }

  # E. STORE RESULTS IN DATAFRAME
  refit_results[[cl]] <- data.frame(
    Cluster = cl,
    N = cl_n,
    Mean_Net_Effect = mean_delta,
    Mean_Absolute_Influence = mean_abs_delta,
    Interaction_Index = interaction_idx
  )

  # SAVE THE CLUSTER-SPECIFIC FITTED OBJECT
  saveRDS(cl_fit, file.path(output_dir, "mofa", paste0("Refitted_Graph_", cl, ".rds")))
}

## 15.3 FINAL SUMMARY AND EXPORT -----------------------------------------------
# bind_rows() ALSO COPES WITH AN EMPTY RESULT LIST (do.call(rbind, list()) IS NULL)
final_delta_df <- dplyr::bind_rows(refit_results)
rownames(final_delta_df) <- NULL
print(final_delta_df)

write.csv(final_delta_df, file.path(output_dir, "mofa", "Cluster_PC3_Marginal_Effects_Full.csv"), row.names = FALSE)
```

### 16. VIOLIN PLOTS ###========================================================

```{r}
# 16.1 DATA EXTRACTION ---------------------------------------------------------
# FOUR RPPA PROTEINS AND TWO RNA-SEQ GENES, PLUS THE CLUSTER LABEL
SELECTION_HALLMARKS <- list(
  rppa   = c("ERALPHA", "PR", "HER2", "GATA3"),
  rnaseq = c("SOX10", "MKI67")
)

HALLMARK_DATA <- as.data.frame(
  get_data(mae, SELECTION_HALLMARKS, "mofa_cluster")
)

# 16.2 LONG FORMAT PREP --------------------------------------------------------
# SAME SIX FEATURES, IN SELECTION ORDER
TARGET_FEATURES <- unlist(SELECTION_HALLMARKS, use.names = FALSE)

# ONE ROW PER SAMPLE x FEATURE; SAMPLES WITHOUT A CLUSTER ARE DROPPED
HALLMARK_LONG <- HALLMARK_DATA %>%
  filter(!is.na(mofa_cluster)) %>%
  select(mofa_cluster, all_of(TARGET_FEATURES)) %>%
  mutate(mofa_cluster = factor(mofa_cluster)) %>%
  pivot_longer(
    cols      = all_of(TARGET_FEATURES),
    names_to  = "Feature",
    values_to = "Level"
  )

# 16.3 SETTINGS ----------------------------------------------------------------
# ONE FILL COLOUR PER CLUSTER; EVERY PAIR OF CLUSTERS IS TESTED
MY_PALETTE     <- c("#006EAE", "#CA9B23", "#C5373D")
MY_COMPARISONS <- combn(levels(HALLMARK_LONG$mofa_cluster), 2, simplify = FALSE)

# 16.4 PLOT GENERATION ---------------------------------------------------------
P_HALLMARKS_FINAL <- ggplot(
  HALLMARK_LONG,
  aes(x = mofa_cluster, y = Level, fill = mofa_cluster)
) +
  # DISTRIBUTION SHAPE WITH A NARROW BOXPLOT OVERLAID
  geom_violin(trim = FALSE, alpha = 0.7, color = "black") +
  geom_boxplot(width = 0.1, fill = "white", outlier.shape = NA, color = "black") +
  # PAIRWISE WILCOXON TESTS SHOWN AS SIGNIFICANCE STARS
  stat_compare_means(
    comparisons = MY_COMPARISONS,
    label       = "p.signif",
    method      = "wilcox.test"
  ) +
  # ONE PANEL PER FEATURE, EACH WITH ITS OWN AXES
  facet_wrap(~Feature, scales = "free", ncol = 3) +
  # clip = "off" LETS THE BRACKETS EXTEND BEYOND THE PANEL AREA
  coord_cartesian(clip = "off") +
  scale_fill_manual(values = MY_PALETTE) +
  labs(
    title = "Clinical Hallmark Expression by Cluster",
    x     = NULL,
    y     = "Relative Level"
  ) +
  theme_pubr() +
  theme(
    legend.position   = "none",
    # BLANK BACKGROUNDS SO THE EXPORT STAYS TRANSPARENT
    panel.background  = element_blank(),
    plot.background   = element_blank(),
    legend.background = element_blank(),
    strip.background  = element_blank(),
    strip.text        = element_text(face = "bold", size = 12),
    axis.text.x       = element_text(angle = 45, hjust = 1),
    panel.grid.major  = element_blank(),
    panel.grid.minor  = element_blank(),
    panel.spacing.y   = unit(2, "lines"),
    plot.margin       = margin(t = 30, r = 15, b = 15, l = 15)
  )

# 16.5 SAVE  -------------------------------------------------------------------
ggsave(
  file.path(output_dir, "mofa", "Hallmarks_Final_Transparent.png"),
  P_HALLMARKS_FINAL,
  width  = 12,
  height = 10,
  dpi    = 300,
  bg     = "transparent"
)
```

### 17. STEMNESS ###============================================================

```{r}
## 17.1 GET DATA
# VARIABLE SELECTION HANDED TO get_data()
selection <- list(
  isoform_pca = TRUE,
  stemness    = "RNAss",
  mutation    = "TP53",
  rnaseq      = c(
    # p53 CONTEXT
    c("TP63", "TP73", "PPP1R13B", "TP53BP2", "TP53BP1", "PPP1R13L", "MDM2"),
    # NEUROENDOCRINE MARKERS
    c("ASCL1", "INSM1", "SYP", "CHGA", "GATA3", "VGF"),
    # STEMNESS
    c("SOX10", "SOX2", "MYC", "PROM1"),
    # EMT
    c("SNAI1", "ZEB1", "CDH1")
  )
)

# METADATA COLUMNS: THE OUTCOME AND THE CLUSTER LABEL USED FOR FILTERING
meta_vars <- c("os_risk_score", "mofa_cluster")

# PULL THE DATA, KEEP CLUSTER_2 SAMPLES ONLY, THEN DROP THE CLUSTER COLUMN
stemness_data <- get_data(mae, selection, meta_vars) %>%
  as.data.frame() %>%
  rename(RISK = os_risk_score) %>%
  filter(mofa_cluster == "Cluster_2") %>%
  select(-mofa_cluster)

## 17.2 BLACKLIST
# FORBID EVERY ARC LEAVING RISK: RISK MAY ONLY BE A CHILD
stemness_bl <- data.frame(
  from = "RISK",
  to   = setdiff(colnames(stemness_data), "RISK")
)

## 17.3 AS FACTOR
# TP53 IS TREATED AS A CATEGORICAL NODE
stemness_data$TP53 <- as.factor(stemness_data$TP53)

## 17.4 RUN MAIN FUNCTION
my_nets <- run_networks(
  df_raw     = stemness_data,
  n_restarts = 100,
  bl         = stemness_bl,
  run_name   = "stemness",
  ref_color  = "#CA9B23",
  comp_color = "#F6DC87"
)
```

### 18. IMMUNE ###==============================================================

```{r}
## 18.1 GET DATA
# VARIABLE SELECTION HANDED TO get_data()
selection <- list(
  isoform_pca = TRUE,
  mutation    = "TP53",
  rnaseq      = c(
    "TP63", "TP73", "PPP1R13B", "TP53BP2", "TP53BP1", "PPP1R13L", "MDM2"
  ),
  immune      = c(
    "TAMsurr_score", "Tcell_receptors_score", "Bcell_receptors_score",
    "PD1_PDL1_score", "IFNG_score_21050467", "MHC2_21978456",
    "TGFB_PCA_17349583", "Troester_WoundSig_19887484",
    "Chemokine12_score", "Module11_Prolif_score"
  )
)

# METADATA COLUMNS: THE OUTCOME AND THE CLUSTER LABEL USED FOR FILTERING
meta_vars <- c("os_risk_score", "mofa_cluster")

# PULL THE DATA, SHORTEN THE SIGNATURE NAMES, KEEP CLUSTER_2 SAMPLES ONLY
immune_data <- get_data(mae, selection, meta_vars) %>%
  as.data.frame() %>%
  rename(
    # NEW_NAME = OLD_NAME
    RISK          = os_risk_score,
    TAMsurr       = TAMsurr_score,
    TCell_Rec     = Tcell_receptors_score,
    BCell_Rec     = Bcell_receptors_score,
    PD1_PDL1      = PD1_PDL1_score,
    IFNG          = IFNG_score_21050467,
    MHC2          = MHC2_21978456,
    TGFB          = TGFB_PCA_17349583,
    Wound_Healing = Troester_WoundSig_19887484,
    Chemokine12   = Chemokine12_score,
    Proliferation = Module11_Prolif_score
  ) %>%
  filter(mofa_cluster == "Cluster_2") %>%
  select(-mofa_cluster)

## 18.2 BLACKLIST
# FORBID EVERY ARC LEAVING RISK: RISK MAY ONLY BE A CHILD
immune_bl <- data.frame(
  from = "RISK",
  to   = setdiff(colnames(immune_data), "RISK")
)

## 18.3 AS FACTOR
# TP53 IS TREATED AS A CATEGORICAL NODE
immune_data$TP53 <- as.factor(immune_data$TP53)

## 18.4 RUN MAIN FUNCTION
my_nets <- run_networks(
  df_raw     = immune_data,
  n_restarts = 100,
  bl         = immune_bl,
  run_name   = "immune",
  ref_color  = "#C5373D",
  comp_color = "#E9A0A5"
)
```

### 19. RFS ###=================================================================


```{r}
## 19.1 GET DATA ---------------------------------------------------------------

# DEFINE VARIABLES
selection <- list(
  isoform_pca  = c("PC3", "PC5"),
  mofa_factors = c("Factor2", "Factor3", "Factor8"),
  stemness     = "RNAss",
  rnaseq       = c("SOX10"),
  immune       = c("Troester_WoundSig_19887484"),
  mutation     = "TP53"
)

# DEFINE METADATA (EVENT INDICATOR, FOLLOW-UP TIME, CLUSTER LABEL)
meta_vars <- c("OS", "OS.time", "mofa_cluster")

# RUN get_data, KEEP CLUSTER_2 SAMPLES ONLY, THEN DROP THE CLUSTER COLUMN
rfs_data <- as.data.frame(get_data(mae, selection, meta_vars)) %>%
  filter(mofa_cluster == "Cluster_2") %>%
  select(-mofa_cluster)

## 19.2 PRE-FILTER DATA (7-YEARS ONLY) -----------------------------------------
# 2555 DAYS = 7 x 365. ROWS WITH ANY MISSING VALUE ARE REMOVED.
# NOTE(review): THIS DROPS SAMPLES FOLLOWED FOR MORE THAN 7 YEARS RATHER THAN
# CENSORING THEM AT 2555 DAYS - CONFIRM THAT EXCLUSION IS INTENDED.
rfs_clean <- rfs_data %>%
  filter(OS.time <= 2555) %>%
  drop_na()

cat(sprintf("\nSample size after filtering: N = %d\n", nrow(rfs_clean)))

## 19.3 SET SURVIVAL DIR -------------------------------------------------------

survival_dir <- file.path(output_dir, "survival")
# CREATE THE FOLDER UP FRONT: write.csv(), ggsave() AND png() BELOW DO NOT
# CREATE MISSING DIRECTORIES AND WOULD FAIL ON A FRESH RUN
dir.create(survival_dir, recursive = TRUE, showWarnings = FALSE)

## 19.4 SAVE FUNCTION ----------------------------------------------------------

#' Write a plot into `survival_dir` as a 300-dpi PNG with transparent
#' background and report the file name on the console.
#'
#' @param plot Plot object handed to `ggsave()`.
#' @param filename File name (not a path) inside the global `survival_dir`.
#' @param width,height Figure size forwarded to `ggsave()`.
save_png <- function(plot, filename, width = 10, height = 6) {
  target_path <- file.path(survival_dir, filename)
  ggsave(filename = target_path, plot = plot,
         width = width, height = height,
         dpi = 300, bg = "transparent")
  cat("Saved: ", filename, "\n", sep = "")
}

## 19.5 DEFINE FEATURE SETS ----------------------------------------------------
# "FULL" = ISOFORM PCs (PC3, PC5) PLUS EVERY OTHER PREDICTOR COLUMN.
# "NO_P53" = THE SAME SET WITHOUT PC3/PC5 ONLY; ALL OTHER COLUMNS RETURNED BY
# get_data() IN 19.1 REMAIN IN THIS SET.

all_features    <- colnames(rfs_clean)[!colnames(rfs_clean) %in% c("OS", "OS.time", "PC3", "PC5")]
features_full   <- c("PC3", "PC5", all_features)
features_no_p53 <- all_features

## 19.6 CROSS-VALIDATION (20 × 50:50 Split) ------------------------------------
# REPEATED RANDOM 50:50 TRAIN/TEST SPLITS. THE SEED IS THE ITERATION NUMBER,
# SO BOTH MODELS SEE IDENTICAL SPLITS.

all_model_results <- list()
all_km_data       <- list()

for (model_type in c("Full", "No_p53")) {

  current_features <- if (model_type == "Full") features_full else features_no_p53
  rf_formula <- as.formula(
    paste("Surv(OS.time, OS) ~", paste(current_features, collapse = "+"))
  )

  results_list <- list()
  km_data_list <- list()

  for (i in seq_len(20)) {

    set.seed(i)

    # A. 50:50 RANDOM SPLITS
    train_idx <- sample(seq_len(nrow(rfs_clean)), size = floor(0.50 * nrow(rfs_clean)))
    train_set <- rfs_clean[train_idx, ]
    test_set  <- rfs_clean[-train_idx, ]

    # B. TRAIN RFS
    rf_model <- randomForestSRC::rfsrc(rf_formula, data = train_set, ntree = 1000)

    # C. PREDICT ON TEST SET
    # USE THE predict() GENERIC INSTEAD OF REACHING INTO THE NAMESPACE WITH :::
    # (THE METHOD IS REGISTERED ONCE randomForestSRC IS LOADED BY rfsrc() ABOVE)
    rf_pred <- predict(rf_model, newdata = test_set)
    test_set$Predicted_Risk <- rf_pred$predicted

    # D. OPTIMAL CUTPOINT
    # NOTE(review): THE CUTPOINT IS CHOSEN ON THE SAME TEST SET THE COX MODEL IS
    # THEN EVALUATED ON, SO P-VALUES/HRs ARE LIKELY OPTIMISTIC - CONFIRM INTENT.
    res_cut <- tryCatch(
      survminer::surv_cutpoint(test_set,
                               time      = "OS.time",
                               event     = "OS",
                               variables = "Predicted_Risk"),
      error = function(e) {
        # BEST EFFORT: SKIP THIS ITERATION, BUT SAY SO INSTEAD OF FAILING SILENTLY
        message(sprintf("[%s | iteration %d] cutpoint search failed, skipped: %s",
                        model_type, i, conditionMessage(e)))
        NULL
      }
    )
    if (is.null(res_cut)) next

    optimal_cut    <- res_cut$cutpoint$cutpoint
    test_set$Group <- factor(
      ifelse(test_set$Predicted_Risk > optimal_cut, "High", "Low"),
      levels = c("Low", "High")
    )

    # E. DISCRETE COX MODEL (HIGH VS LOW, LOW AS REFERENCE)
    cox_mod <- survival::coxph(Surv(OS.time, OS) ~ Group, data = test_set)
    sum_cox <- summary(cox_mod)

    # conf.int COLUMNS: exp(coef), exp(-coef), LOWER .95, UPPER .95
    # C_Index IS DERIVED FROM THE TRAINING FOREST'S err.rate AT THE LAST TREE,
    # NOT FROM THE TEST SET
    results_list[[i]] <- data.frame(
      Iteration = i,
      Model     = model_type,
      P_Value   = sum_cox$logtest["pvalue"],
      HR        = sum_cox$conf.int[1],
      Lower_CI  = sum_cox$conf.int[3],
      Upper_CI  = sum_cox$conf.int[4],
      C_Index   = 1 - rf_model$err.rate[rf_model$ntree]
    )

    # F. STORE TEST SETS FOR KM
    km_data_list[[i]] <- test_set %>%
      select(OS.time, OS, Group) %>%
      mutate(Iteration = i)
  }

  # SKIPPED ITERATIONS LEAVE NULL SLOTS, WHICH bind_rows() IGNORES
  all_model_results[[model_type]] <- dplyr::bind_rows(results_list)
  all_km_data[[model_type]]       <- dplyr::bind_rows(km_data_list)
}

## 19.7 ITERATION TABLE --------------------------------------------------------
# STACK BOTH MODELS, DISCARD ITERATIONS WITH HR >= 1000, SORT FOR DISPLAY

iteration_table <- all_model_results %>%
  dplyr::bind_rows() %>%
  filter(HR < 1000) %>%
  select(Model, Iteration, HR, Lower_CI, Upper_CI, P_Value, C_Index) %>%
  arrange(Model, Iteration)

cat(
  "\n════════════════════════════════════════════════════════════════\n",
  "  PER-ITERATION RESULTS (all 20 folds, both models)\n",
  "════════════════════════════════════════════════════════════════\n",
  sep = ""
)
print(iteration_table, row.names = FALSE, digits = 3)

# MEAN AND SD OF EVERY METRIC PER MODEL, OVER THE RETAINED ITERATIONS
summary_table <- iteration_table %>%
  group_by(Model) %>%
  summarise(
    N           = n(),
    Mean_HR     = mean(HR,       na.rm = TRUE),
    SD_HR       = sd(HR,         na.rm = TRUE),
    Mean_LCI    = mean(Lower_CI, na.rm = TRUE),
    SD_LCI      = sd(Lower_CI,   na.rm = TRUE),
    Mean_UCI    = mean(Upper_CI, na.rm = TRUE),
    SD_UCI      = sd(Upper_CI,   na.rm = TRUE),
    Mean_P      = mean(P_Value,  na.rm = TRUE),
    SD_P        = sd(P_Value,    na.rm = TRUE),
    Mean_CIndex = mean(C_Index,  na.rm = TRUE),
    SD_CIndex   = sd(C_Index,    na.rm = TRUE),
    .groups     = "drop"
  )

cat(
  "\n════════════════════════════════════════════════════════════════\n",
  "  SUMMARY: MEAN ± SD ACROSS 20 ITERATIONS\n",
  "════════════════════════════════════════════════════════════════\n",
  sep = ""
)
print(as.data.frame(summary_table), row.names = FALSE, digits = 3)

# SAVE TABLES AS CSV

write.csv(
  iteration_table,
  file.path(survival_dir, "iterations_all_folds.csv"),
  row.names = FALSE
)
cat("Saved: iterations_all_folds.csv\n")

write.csv(
  summary_table,
  file.path(survival_dir, "iterations_summary.csv"),
  row.names = FALSE
)
cat("Saved: iterations_summary.csv\n")

## 19.8 KAPLAN-MEIER CURVES ----------------------------------------------------

# COMMON TIME AXIS (DAYS 0..2555 IN STEPS OF 5) ON WHICH ALL CURVES ARE EVALUATED
time_grid <- seq(0, 2555, by = 5)

#' Evaluate one Kaplan-Meier curve per iteration and risk group on `time_grid`.
#'
#' @param km_data_model Data frame with columns OS.time, OS, Group, Iteration.
#' @return Row-bound data frame with columns time, surv, Group, Iteration.
#'   Iteration/group combinations with fewer than two samples are left out.
build_iter_curves <- function(km_data_model) {
  by_iteration <- split(km_data_model, km_data_model$Iteration)

  per_iteration <- lapply(by_iteration, function(iter_df) {
    group_curves <- list()
    for (grp in c("Low", "High")) {
      members <- iter_df[iter_df$Group == grp, ]
      if (nrow(members) < 2) {
        next
      }
      km_fit <- survfit(Surv(OS.time, OS) ~ 1, data = members)
      # STEP FUNCTION STARTING AT 1 BEFORE THE FIRST OBSERVED TIME
      km_step <- stepfun(km_fit$time, c(1, km_fit$surv))
      group_curves[[grp]] <- data.frame(
        time      = time_grid,
        surv      = km_step(time_grid),
        Group     = grp,
        Iteration = unique(iter_df$Iteration)
      )
    }
    dplyr::bind_rows(group_curves)
  })

  dplyr::bind_rows(per_iteration)
}

# 19.9 FUNCTION TO AVERAGE CURVES ACROSS ITERATIONS ----------------------------

#' Mean survival curve per group with a normal-approximation 95% band.
#'
#' @param curves_df Output of `build_iter_curves()` (time, surv, Group,
#'   Iteration).
#' @return One row per Group x time with mean_surv, sd_surv, n, se_surv,
#'   ci_lo and ci_hi; the band is clamped to [0, 1].
build_mean_ci <- function(curves_df) {
  per_point <- curves_df %>%
    group_by(Group, time) %>%
    summarise(
      mean_surv = mean(surv, na.rm = TRUE),
      sd_surv   = sd(surv, na.rm = TRUE),
      n         = sum(!is.na(surv)),
      .groups   = "drop"
    )

  per_point %>%
    mutate(
      se_surv = sd_surv / sqrt(n),
      ci_lo   = pmax(mean_surv - 1.96 * se_surv, 0),
      ci_hi   = pmin(mean_surv + 1.96 * se_surv, 1)
    )
}

# PER-ITERATION CURVES FIRST, THEN THEIR MEAN BANDS, FOR BOTH MODELS
curves_full   <- build_iter_curves(all_km_data[["Full"]])
mean_ci_full  <- build_mean_ci(curves_full)
curves_nop53  <- build_iter_curves(all_km_data[["No_p53"]])
mean_ci_nop53 <- build_mean_ci(curves_nop53)

# 19.10 PLOT OVERLAID CURVES ---------------------------------------------------

#' Overlay all per-iteration KM curves with their mean curve and 95% band.
#'
#' @param curves_df Per-iteration curves from `build_iter_curves()`.
#' @param mean_ci_df Mean curve and band from `build_mean_ci()`.
#' @param model_results_df Per-iteration statistics (HR, Lower_CI, Upper_CI,
#'   P_Value, C_Index); rows with HR >= 1000 are excluded from the subtitle.
#' @param title_text Plot title.
#' @return A ggplot object.
plot_20_km <- function(curves_df, mean_ci_df, model_results_df, title_text) {

  # SUBTITLE STATISTICS, AVERAGED OVER THE RETAINED ITERATIONS
  kept <- dplyr::filter(model_results_df, HR < 1000)
  avg  <- function(v) mean(v, na.rm = TRUE)

  subtitle_text <- sprintf(
    "Mean HR = %.2f (SD %.2f)  |  Mean 95%% CI [%.2f–%.2f]  |  Mean p = %.3f  |  Mean C-index = %.3f",
    avg(kept$HR), sd(kept$HR, na.rm = TRUE),
    avg(kept$Lower_CI), avg(kept$Upper_CI),
    avg(kept$P_Value), avg(kept$C_Index)
  )

  # SHARED BY THE COLOUR AND FILL SCALES
  group_cols <- c(High = "#E84855", Low = "#2E86AB")
  group_labs <- c(High = "High Risk", Low = "Low Risk")
  clear_rect <- element_rect(fill = "transparent", colour = NA)

  # LAYER 1: FAINT INDIVIDUAL CURVES (ONE PER GROUP x ITERATION)
  faint_curves <- geom_step(
    data = curves_df,
    aes(x = time, y = surv, colour = Group,
        group = interaction(Group, Iteration)),
    alpha = 0.50, linewidth = 0.45
  )
  # LAYER 2: BAND AROUND THE MEAN CURVE
  mean_band <- geom_ribbon(
    data = mean_ci_df,
    aes(x = time, ymin = ci_lo, ymax = ci_hi, fill = Group, group = Group),
    alpha = 0.20
  )
  # LAYER 3: BOLD MEAN CURVE
  mean_curve <- geom_step(
    data = mean_ci_df,
    aes(x = time, y = mean_surv, colour = Group, group = Group),
    linewidth = 1.3
  )

  ggplot() +
    faint_curves +
    mean_band +
    mean_curve +
    scale_colour_manual(values = group_cols, labels = group_labs) +
    scale_fill_manual(values = group_cols, labels = group_labs) +
    # ONE TICK PER 365 DAYS, LABELLED 0y..7y
    scale_x_continuous(
      breaks = seq(0, 2555, by = 365),
      labels = paste0(0:7, "y")
    ) +
    scale_y_continuous(
      limits = c(0, 1),
      labels = scales::percent_format(accuracy = 1)
    ) +
    labs(
      title    = title_text,
      subtitle = subtitle_text,
      x        = "Time",
      y        = "Overall Survival Probability",
      colour   = "Risk Group",
      fill     = "Risk Group",
      caption  = "Faint lines = 20 Monte Carlo iterations (50:50 split) | Bold = mean | Ribbon = mean 95% CI"
    ) +
    theme_classic(base_size = 13) +
    theme(
      panel.background  = clear_rect,
      plot.background   = clear_rect,
      legend.background = clear_rect,
      legend.key        = clear_rect,
      plot.title        = element_text(face = "bold", size = 14),
      plot.subtitle     = element_text(size = 9.5, colour = "grey35"),
      plot.caption      = element_text(size = 8,   colour = "grey55"),
      legend.position   = "bottom",
      legend.title      = element_text(face = "bold")
    )
}

# BUILD ONE KM FIGURE PER MODEL
p_full <- plot_20_km(
  curves_df        = curves_full,
  mean_ci_df       = mean_ci_full,
  model_results_df = all_model_results[["Full"]],
  title_text       = "RFS Model: Full (with PC3/PC5)"
)

p_nop53 <- plot_20_km(
  curves_df        = curves_nop53,
  mean_ci_df       = mean_ci_nop53,
  model_results_df = all_model_results[["No_p53"]],
  title_text       = "RFS Model: No PC3/PC5"
)

# SHOW BOTH MODELS STACKED (FULL ON TOP)
print(p_full / p_nop53)

# EXPORT EACH FIGURE, PLUS THE STACKED VERSION AT DOUBLE HEIGHT
save_png(p_full, "km_full_model.png", width = 10, height = 6)
save_png(p_nop53, "km_no_pc35_model.png", width = 10, height = 6)
save_png(p_full / p_nop53, "km_both_models.png", width = 10, height = 12)

## 19.11 MODEL ON ALL DATA (VIMP + PDPs) ---------------------------------------
# ONE FOREST ON EVERY SAMPLE IN rfs_clean, WITH VARIABLE IMPORTANCE ENABLED

cat("\nTraining final model on full dataset for VIMP and PDP...\n")

set.seed(42)
rf_formula_full <- as.formula(
  sprintf("Surv(OS.time, OS) ~ %s", paste(features_full, collapse = "+"))
)

final_model <- randomForestSRC::rfsrc(
  formula    = rf_formula_full,
  data       = rfs_clean,
  ntree      = 1000,
  importance = TRUE
)

## 19.12 VIMP PLOT -------------------------------------------------------------
# IMPORTANCE TABLE, LARGEST VALUE FIRST

vimp_values <- final_model$importance
vimp_df <- arrange(
  data.frame(
    Feature = names(vimp_values),
    VIMP    = as.numeric(vimp_values)
  ),
  desc(VIMP)
)

cat(
  "\n════════════════════════════════════════\n",
  "  VARIABLE IMPORTANCE (VIMP) — Full Model\n",
  "════════════════════════════════════════\n",
  sep = ""
)
print(vimp_df, row.names = FALSE, digits = 4)

write.csv(
  vimp_df,
  file.path(survival_dir, "vimp_full_model.csv"),
  row.names = FALSE
)
cat("Saved: vimp_full_model.csv\n")

# PLOT VIMP PLOT
# HORIZONTAL BARS ORDERED BY IMPORTANCE; BAR COLOUR ENCODES THE SIGN OF VIMP

p_vimp <- ggplot(
  vimp_df,
  aes(x = reorder(Feature, VIMP), y = VIMP, fill = VIMP > 0)
) +
  geom_col(width = 0.7) +
  geom_hline(yintercept = 0, linetype = "dashed", colour = "grey50") +
  coord_flip() +
  scale_fill_manual(
    values = c(`TRUE` = "#2E86AB", `FALSE` = "#E84855"),
    guide  = "none"
  ) +
  labs(
    title   = "Variable Importance (VIMP) — Full RFS Model",
    x       = NULL,
    y       = "VIMP",
    caption = "Blue = positive importance  |  Red = feature may add noise"
  ) +
  theme_classic(base_size = 13) +
  theme(
    panel.background  = element_rect(fill = "transparent", colour = NA),
    plot.background   = element_rect(fill = "transparent", colour = NA),
    legend.background = element_rect(fill = "transparent", colour = NA),
    legend.key        = element_rect(fill = "transparent", colour = NA),
    plot.title        = element_text(face = "bold")
  )

print(p_vimp)
save_png(p_vimp, "vimp_full_model.png", width = 8, height = 6)

# SAVE PDP ---------------------------------------------------------------------
# PARTIAL-DEPENDENCE PLOTS OF THE FINAL MODEL, WRITTEN TO A 300-DPI PNG.
# finally = dev.off() CLOSES THE PNG DEVICE EVEN IF plot.variable() ERRORS;
# OTHERWISE THE DEVICE WOULD STAY OPEN AND CAPTURE LATER PLOTS.

png(file.path(survival_dir, "pdp_full_model.png"),
    width  = 10, height = 8,
    units  = "in",
    res    = 300,
    bg     = "transparent")
tryCatch(
  randomForestSRC::plot.variable(
    final_model,
    partial        = TRUE,
    sorted         = TRUE,
    plots.per.page = 3,
    main           = "Partial Dependence — Full RFS Model (with PC3/PC5)"
  ),
  finally = dev.off()
)
cat("Saved: pdp_full_model.png\n")

# DRAW THE SAME FIGURE AGAIN ON THE ACTIVE DEVICE SO IT APPEARS IN THE OUTPUT
randomForestSRC::plot.variable(
  final_model,
  partial        = TRUE,
  sorted         = TRUE,
  plots.per.page = 3,
  main           = "Partial Dependence — Full RFS Model (with PC3/PC5)"
)
```




