1 Load Libraries

knitr::opts_chunk$set(
  echo    = TRUE,
  message = FALSE,
  warning = FALSE,
  fig.align = "center",
  dpi     = 150
)

library(Seurat)
library(reticulate)
library(ggplot2)
library(pheatmap)
library(dplyr)
library(RColorBrewer)

# Python environment
Sys.setenv(RETICULATE_PYTHON = "/home/bioinfo/.virtualenvs/r-reticulate/bin/python")
use_python("/home/bioinfo/.virtualenvs/r-reticulate/bin/python", required = TRUE)

sc <- import("scanpy")
ad <- import("anndata")

cat("✓ Python libraries imported\n")
✓ Python libraries imported
# Output directory
dir.create("Output_Figures", showWarnings = FALSE)

# State constants
STATE_ORDER <- c("CD4 Naive","CD4 TCM","CD4 TEM","CD4 Temra/CTL","Treg")
STATE_COLORS <- c(
  "CD4 Naive"     = "#4472C4",
  "CD4 TCM"       = "#70AD47",
  "CD4 TEM"       = "#ED7D31",
  "CD4 Temra/CTL" = "#C00000",
  "Treg"          = "#7030A0"
)

2 Load Reference Object

cd4_ref <- readRDS("cd4_ref_dual_trajectory.rds")

cat("Cells loaded:", ncol(cd4_ref), "\n")
Cells loaded: 11466 
stopifnot(
  "predicted.celltype.l2 missing"    = "predicted.celltype.l2" %in% colnames(cd4_ref@meta.data),
  "milestone missing"                = "milestone" %in% colnames(cd4_ref@meta.data),
  "mst_pseudotime_norm missing"      = "mst_pseudotime_norm" %in% colnames(cd4_ref@meta.data),
  "monocle3_pseudotime_norm missing" = "monocle3_pseudotime_norm" %in% colnames(cd4_ref@meta.data),
  "pca reduction missing"            = "pca" %in% names(cd4_ref@reductions),
  "umap reduction missing"           = "umap" %in% names(cd4_ref@reductions)
)
cat("✅ All slots verified\n")
✅ All slots verified
cat("\nCell state distribution:\n")

Cell state distribution:
print(table(cd4_ref$predicted.celltype.l2))

    CD4 Naive       CD4 TCM       CD4 TEM CD4 Temra/CTL          Treg 
         2037          9067           145            10           207 
cat("\nMilestone distribution:\n")

Milestone distribution:
print(table(cd4_ref$milestone))

 M00  M01  M02  M03  M04  M05  M06  M07 
2037  388 8679   64   81   10  146   61 

3 UMAP Validation

p1 <- DimPlot(cd4_ref, group.by = "predicted.celltype.l2",
              reduction = "umap", label = TRUE, repel = TRUE,
              label.size = 3) + NoLegend() +
      ggtitle("Azimuth l2 states")

p2 <- DimPlot(cd4_ref, group.by = "milestone",
              reduction = "umap", label = TRUE, repel = TRUE,
              label.size = 3) + NoLegend() +
      ggtitle("Milestones M00–M07")

p3 <- DimPlot(cd4_ref, group.by = "seurat_clusters",
              reduction = "umap", label = TRUE, repel = TRUE,
              label.size = 3) + NoLegend() +
      ggtitle("Seurat clusters")

print(p1 | p2 | p3)


4 Build AnnData Object

# ── Build AnnData — PCA as X, no obsm needed ──────────────────────────────
DefaultAssay(cd4_ref) <- "integrated"

# Pass PCA embeddings directly as X (exactly like TF matrix worked before)
pca_embed  <- Embeddings(cd4_ref, "pca")[, 1:20]
umap_embed <- Embeddings(cd4_ref, "umap")

obs <- data.frame(
  row.names   = colnames(cd4_ref),
  celltype_l2 = as.character(cd4_ref$predicted.celltype.l2),
  milestone   = as.character(cd4_ref$milestone),
  mst_pt      = cd4_ref$mst_pseudotime_norm,
  monocle3_pt = cd4_ref$monocle3_pseudotime_norm
)

# ✅ PCA goes into X directly — same pattern that worked for TF matrix
adata <- ad$AnnData(X = pca_embed, obs = obs)

cat(sprintf("✅ AnnData: %d cells × %d PCs\n",
            py_to_r(adata$n_obs), py_to_r(adata$n_vars)))
✅ AnnData: 11466 cells × 20 PCs
# ── Neighbors using X directly (no obsm needed) ───────────────────────────
sc$pp$neighbors(
  adata,
  use_rep     = "X",     # ← matches working script pattern
  n_neighbors = 30L,
  metric      = "cosine"
)
cat("✅ KNN graph computed\n")
✅ KNN graph computed
# ── PAGA ──────────────────────────────────────────────────────────────────
sc$tl$paga(adata, groups = "celltype_l2")
cat("✅ PAGA state-level complete\n")
✅ PAGA state-level complete

5 PAGA Analysis

5.1 PAGA at State Level (l2)


cat("Running PAGA — Azimuth l2 states...\n")
Running PAGA — Azimuth l2 states...
sc$tl$paga(adata, groups = "celltype_l2")

py_run_string("
import numpy as np
paga_conn = r.adata.uns['paga']['connectivities'].todense()
state_names = r.adata.obs['celltype_l2'].cat.categories.values
np.savetxt('paga_state_conn.csv', paga_conn, delimiter=',', header=','.join(state_names), comments='')
print('State connectivity saved')
", local = TRUE)

# Read CSV - skip row names column, assign manually
paga_state_df <- read.csv("paga_state_conn.csv", header=TRUE, check.names=FALSE)
state_names <- colnames(paga_state_df)
paga_state_conn <- as.matrix(paga_state_df)
rownames(paga_state_conn) <- state_names
colnames(paga_state_conn) <- state_names

cat("✅ State-level PAGA complete\n\nConnectivity matrix:\n")
✅ State-level PAGA complete

Connectivity matrix:
print(round(paga_state_conn, 3))
              CD4 Naive CD4 TCM CD4 TEM CD4 Temra/CTL  Treg
CD4 Naive         0.000   0.630   0.006         0.000 0.143
CD4 TCM           0.630   0.000   0.844         0.889 0.490
CD4 TEM           0.006   0.844   0.000         1.000 0.037
CD4 Temra/CTL     0.000   0.889   1.000         0.000 0.000
Treg              0.143   0.490   0.037         0.000 0.000
# Clean up temp file
file.remove("paga_state_conn.csv")
[1] TRUE

5.2 PAGA at Milestone Level (M00–M07)

cat("Running PAGA — milestones M00-M07...\n")
Running PAGA — milestones M00-M07...
sc$tl$paga(adata, groups = "milestone")

py_run_string("
import numpy as np
paga_conn = r.adata.uns['paga']['connectivities'].todense()
ms_names = r.adata.obs['milestone'].cat.categories.values
np.savetxt('paga_ms_conn.csv', paga_conn, delimiter=',', header=','.join(ms_names), comments='')
print('Milestone connectivity saved')
", local = TRUE)

# Read CSV - skip row names column, assign manually
paga_ms_df <- read.csv("paga_ms_conn.csv", header=TRUE, check.names=FALSE)
ms_names <- colnames(paga_ms_df)
paga_ms_conn <- as.matrix(paga_ms_df)
rownames(paga_ms_conn) <- ms_names
colnames(paga_ms_conn) <- ms_names

cat("✅ Milestone-level PAGA complete\n\nConnectivity matrix:\n")
✅ Milestone-level PAGA complete

Connectivity matrix:
print(round(paga_ms_conn, 3))
      M00   M01   M02   M03   M04   M05   M06   M07
M00 0.000 0.183 0.650 0.013 0.000 0.000 0.203 0.000
M01 0.183 0.000 0.124 0.007 0.006 0.000 0.010 0.000
M02 0.650 0.124 0.000 0.741 0.992 0.929 0.440 0.681
M03 0.013 0.007 0.741 0.000 1.000 1.000 0.099 0.047
M04 0.000 0.006 0.992 1.000 0.000 1.000 0.000 0.000
M05 0.000 0.000 0.929 1.000 1.000 0.000 0.000 0.000
M06 0.203 0.010 0.440 0.099 0.000 0.000 0.000 1.000
M07 0.000 0.000 0.681 0.047 0.000 0.000 1.000 0.000
# Clean up temp file
file.remove("paga_ms_conn.csv")
[1] TRUE

5.3 Save PAGA Figures (Python) — State-Level PAGA

# ✅ EXACT SCT script pattern
sc$tl$paga(adata, groups = "celltype_l2")
sc$pl$paga(adata, threshold = 0.15, show = FALSE, save = "_states.png")

file.rename("figures/paga_states.png", "Output_Figures/PAGA_states.png")
cat("✅ State PAGA saved\n")

5.4 Save PAGA Figures (Python) — Milestone-Level PAGA


sc$tl$paga(adata, groups = "milestone")
sc$pl$paga(adata, threshold = 0.10, show = FALSE, save = "_milestones.png")
file.rename("figures/paga_milestones.png", "Output_Figures/PAGA_milestones.png")

knitr::include_graphics("Output_Figures/PAGA_states.png")

NA
NA
NA

5.5 PAGA Connectivity Heatmaps (R)


# ── State-level heatmap ───────────────────────────────────────────────────
conn_sub <- paga_state_conn[
  intersect(STATE_ORDER, rownames(paga_state_conn)),
  intersect(STATE_ORDER, colnames(paga_state_conn))
]

# Save to file
pheatmap(
  conn_sub,
  color           = colorRampPalette(c("white","#2980b9","#c0392b"))(50),
  display_numbers = round(conn_sub, 2),
  number_color    = "black",
  fontsize_number = 11,
  main            = "PAGA connectivity — CD4 T cell states",
  cluster_rows    = FALSE,
  cluster_cols    = FALSE,
  filename        = "Output_Figures/PAGA_state_heatmap.png",
  width = 7, height = 6
)

# Display in notebook  ← ADD THIS
pheatmap(
  conn_sub,
  color           = colorRampPalette(c("white","#2980b9","#c0392b"))(50),
  display_numbers = round(conn_sub, 2),
  number_color    = "black",
  fontsize_number = 11,
  main            = "PAGA connectivity — CD4 T cell states",
  cluster_rows    = FALSE,
  cluster_cols    = FALSE
  # NO filename = displays in notebook
)

# ── Milestone-level heatmap ───────────────────────────────────────────────
ms_order <- paste0("M", sprintf("%02d", 0:7))
ms_sub   <- paga_ms_conn[
  intersect(ms_order, rownames(paga_ms_conn)),
  intersect(ms_order, colnames(paga_ms_conn))
]

# Save to file
pheatmap(
  ms_sub,
  color           = colorRampPalette(c("white","#2980b9","#c0392b"))(50),
  display_numbers = round(ms_sub, 2),
  number_color    = "black",
  fontsize_number = 10,
  main            = "PAGA connectivity — Milestones M00–M07",
  cluster_rows    = FALSE,
  cluster_cols    = FALSE,
  filename        = "Output_Figures/PAGA_milestone_heatmap.png",
  width = 8, height = 7
)

# Display in notebook  ← ADD THIS
pheatmap(
  ms_sub,
  color           = colorRampPalette(c("white","#2980b9","#c0392b"))(50),
  display_numbers = round(ms_sub, 2),
  number_color    = "black",
  fontsize_number = 10,
  main            = "PAGA connectivity — Milestones M00–M07",
  cluster_rows    = FALSE,
  cluster_cols    = FALSE
  # NO filename = displays in notebook
)


cat("✅ Heatmaps saved\n")
✅ Heatmaps saved

6 MST Topology Validation

6.1 Define MST Edges

# MST topology:
# M00(Naive) → M01(TCM-early) → M02(TCM-late) → M03(TEM-early) → M04(TEM-late) → M05(Temra)
#                                        ↓
#                                   M06(Treg-naive) → M07(Treg-terminal)

# ── FIX 3: Correct state edges (removed duplicate CD4 TCM → CD4 TCM) ─────
mst_state_edges <- data.frame(
  from = c("CD4 Naive","CD4 TCM","CD4 TEM","CD4 TCM"),
  to   = c("CD4 TCM",  "CD4 TEM","CD4 Temra/CTL","Treg"),
  edge_label = c("Main axis","Main axis","Terminal","Treg branch")
)

mst_milestone_edges <- data.frame(
  from    = c("M00","M01","M02","M02","M03","M04","M06"),
  to      = c("M01","M02","M03","M06","M04","M05","M07"),
  biology = c("Naive→TCM-early","TCM-early→late","TCM-late→TEM-early",
              "TCM-late→Treg-naive","TEM-early→late","TEM-late→Temra",
              "Treg-naive→terminal")
)

cat("MST state edges:\n"); print(mst_state_edges)
MST state edges:
       from            to  edge_label
1 CD4 Naive       CD4 TCM   Main axis
2   CD4 TCM       CD4 TEM   Main axis
3   CD4 TEM CD4 Temra/CTL    Terminal
4   CD4 TCM          Treg Treg branch
cat("\nMST milestone edges:\n"); print(mst_milestone_edges)

MST milestone edges:
  from  to             biology
1  M00 M01     Naive→TCM-early
2  M01 M02      TCM-early→late
3  M02 M03  TCM-late→TEM-early
4  M02 M06 TCM-late→Treg-naive
5  M03 M04      TEM-early→late
6  M04 M05      TEM-late→Temra
7  M06 M07 Treg-naive→terminal

6.2 State-Level Validation

mst_state_edges$paga_score <- mapply(
  function(f, t) {
    if (f %in% rownames(paga_state_conn) && t %in% colnames(paga_state_conn))
      round(paga_state_conn[f, t], 3) else NA_real_
  },
  mst_state_edges$from, mst_state_edges$to
)

mst_state_edges$status <- case_when(
  mst_state_edges$paga_score > 0.50 ~ "Strong",
  mst_state_edges$paga_score > 0.30 ~ "Moderate",
  mst_state_edges$paga_score > 0.15 ~ "Weak",
  TRUE                               ~ "Not confirmed"
)

cat("\nMST state-level validation:\n")

MST state-level validation:
print(mst_state_edges)
       from            to  edge_label paga_score   status
1 CD4 Naive       CD4 TCM   Main axis      0.630   Strong
2   CD4 TCM       CD4 TEM   Main axis      0.844   Strong
3   CD4 TEM CD4 Temra/CTL    Terminal      1.000   Strong
4   CD4 TCM          Treg Treg branch      0.490 Moderate
cat(sprintf("\n%d/%d MST state edges confirmed (PAGA > 0.3)\n",
            sum(mst_state_edges$paga_score > 0.3, na.rm=TRUE),
            nrow(mst_state_edges)))

4/4 MST state edges confirmed (PAGA > 0.3)
ggplot(mst_state_edges,
       aes(x = reorder(paste(from, "→", to), paga_score),
           y = paga_score, fill = status)) +
  geom_col(width = 0.7) +
  geom_hline(yintercept = 0.3, linetype="dashed",
             colour="grey40", linewidth=0.8) +
  geom_hline(yintercept = 0.5, linetype="dotted",
             colour="grey40", linewidth=0.8) +
  annotate("text", x=0.6, y=0.32, label="Moderate (0.3)",
           size=3, hjust=0, colour="grey40") +
  annotate("text", x=0.6, y=0.52, label="Strong (0.5)",
           size=3, hjust=0, colour="grey40") +
  scale_fill_manual(
    values = c("Strong"="#27ae60","Moderate"="#2980b9",
               "Weak"="#f39c12","Not confirmed"="#c0392b"),
    name   = "Validation"
  ) +
  coord_flip() +
  theme_classic() +
  labs(x=NULL, y="PAGA connectivity score",
       title="Custom MST state edges — PAGA validation",
       subtitle="Dashed = 0.3 | Dotted = 0.5") +
  theme(plot.title=element_text(size=13, face="bold"))


ggsave("Output_Figures/MST_PAGA_state_validation.png",
       width=9, height=5, dpi=150)

6.3 Milestone-Level Validation

mst_milestone_edges$paga_score <- mapply(
  function(f, t) {
    if (f %in% rownames(paga_ms_conn) && t %in% colnames(paga_ms_conn))
      round(paga_ms_conn[f, t], 3) else NA_real_
  },
  mst_milestone_edges$from, mst_milestone_edges$to
)

mst_milestone_edges$status <- case_when(
  mst_milestone_edges$paga_score > 0.50 ~ "Strong",
  mst_milestone_edges$paga_score > 0.30 ~ "Moderate",
  mst_milestone_edges$paga_score > 0.15 ~ "Weak",
  TRUE                                   ~ "Not confirmed"
)

cat("MST milestone-level validation:\n")
MST milestone-level validation:
print(mst_milestone_edges)
  from  to             biology paga_score        status
1  M00 M01     Naive→TCM-early      0.183          Weak
2  M01 M02      TCM-early→late      0.124 Not confirmed
3  M02 M03  TCM-late→TEM-early      0.741        Strong
4  M02 M06 TCM-late→Treg-naive      0.440      Moderate
5  M03 M04      TEM-early→late      1.000        Strong
6  M04 M05      TEM-late→Temra      1.000        Strong
7  M06 M07 Treg-naive→terminal      1.000        Strong
cat(sprintf("\n%d/%d MST milestone edges confirmed (PAGA > 0.3)\n",
            sum(mst_milestone_edges$paga_score > 0.3, na.rm=TRUE),
            nrow(mst_milestone_edges)))

5/7 MST milestone edges confirmed (PAGA > 0.3)
ggplot(mst_milestone_edges,
       aes(x = reorder(paste(from, "→", to), paga_score),
           y = paga_score, fill = status)) +
  geom_col(width = 0.7) +
  geom_hline(yintercept = 0.3, linetype="dashed",
             colour="grey40", linewidth=0.8) +
  geom_text(aes(label=biology), hjust=-0.1, size=3) +
  scale_fill_manual(
    values = c("Strong"="#27ae60","Moderate"="#2980b9",
               "Weak"="#f39c12","Not confirmed"="#c0392b"),
    name   = "Validation"
  ) +
  coord_flip() +
  expand_limits(y=1.2) +
  theme_classic() +
  labs(x=NULL, y="PAGA connectivity score",
       title="MST milestone edges (M00–M07) — PAGA validation") +
  theme(plot.title=element_text(size=13, face="bold"))


ggsave("Output_Figures/MST_PAGA_milestone_validation.png",
       width=10, height=6, dpi=150)

6.4 Final Summary

cat("══════════════════════════════════════════\n")
══════════════════════════════════════════
cat("PAGA VALIDATION SUMMARY\n")
PAGA VALIDATION SUMMARY
cat("══════════════════════════════════════════\n")
══════════════════════════════════════════
cat(sprintf("State-level:     %d/%d edges confirmed (> 0.3)\n",
            sum(mst_state_edges$paga_score > 0.3, na.rm=TRUE),
            nrow(mst_state_edges)))
State-level:     4/4 edges confirmed (> 0.3)
cat(sprintf("Milestone-level: %d/%d edges confirmed (> 0.3)\n",
            sum(mst_milestone_edges$paga_score > 0.3, na.rm=TRUE),
            nrow(mst_milestone_edges)))
Milestone-level: 5/7 edges confirmed (> 0.3)
cat("\nFigures saved:\n")

Figures saved:
for (f in list.files("Output_Figures", pattern="PAGA|MST"))
  cat(sprintf("  ✅ %s\n", f))
  ✅ MST_PAGA_milestone_validation.png
  ✅ MST_PAGA_state_validation.png
  ✅ PAGA_milestone_heatmap.png
  ✅ PAGA_milestones.png
  ✅ PAGA_state_heatmap.png
  ✅ PAGA_states.png
# ── Save PAGA results ──────────────────────────────────────────────────────


# Add PAGA connectivity scores back to cd4_ref metadata (optional)
# So every cell knows its state's PAGA connectivity
saveRDS(cd4_ref, "cd4_ref_dual_trajectory_with_PAGA.rds")  # Re-save with any updates
cat("✅ cd4_ref updated → cd4_ref_dual_trajectory.rds\n")
✅ cd4_ref updated → cd4_ref_dual_trajectory.rds
# ── Quick reload check ─────────────────────────────────────────────────────
cat("\nTo reload PAGA results later:\n")

To reload PAGA results later:
cat("  paga_results <- readRDS('paga_validation_results.rds')\n")
  paga_results <- readRDS('paga_validation_results.rds')
cat("  paga_state_conn <- paga_results$state_connectivity\n")
  paga_state_conn <- paga_results$state_connectivity
cat("  paga_ms_conn    <- paga_results$milestone_connectivity\n")
  paga_ms_conn    <- paga_results$milestone_connectivity
cat("  adata <- anndata$read_h5ad('cd4_ref_PAGA_validated.h5ad')\n")
  adata <- anndata$read_h5ad('cd4_ref_PAGA_validated.h5ad')
# ── Final inventory ────────────────────────────────────────────────────────
cat("\n══════════════════════════════════════════\n")

══════════════════════════════════════════
cat("SAVED OBJECTS\n")
SAVED OBJECTS
cat("══════════════════════════════════════════\n")
══════════════════════════════════════════
cat("  cd4_ref_PAGA_validated.h5ad     ← AnnData with PAGA\n")
  cd4_ref_PAGA_validated.h5ad     ← AnnData with PAGA
cat("  paga_validation_results.rds     ← R matrices + validation\n")
  paga_validation_results.rds     ← R matrices + validation
cat("  cd4_ref_dual_trajectory.rds     ← Seurat object (unchanged)\n")
  cd4_ref_dual_trajectory.rds     ← Seurat object (unchanged)
cat("\nOutput_Figures/:\n")

Output_Figures/:
for (f in list.files("Output_Figures", pattern="PAGA|MST"))
  cat(sprintf("  ✅ %s\n", f))
  ✅ MST_PAGA_milestone_validation.png
  ✅ MST_PAGA_state_validation.png
  ✅ PAGA_milestone_heatmap.png
  ✅ PAGA_milestones.png
  ✅ PAGA_state_heatmap.png
  ✅ PAGA_states.png
---
title: "PAGA Topology Validation — CD4 T Cell Reference"
subtitle: "PAGA connectivity | Custom MST edge validation"
author: "Nasir Mahmood Abbasi"
date: "`r Sys.Date()`"
output:
  html_notebook:
    number_sections: true
    toc: true
    toc_float:
      collapsed: true
    theme: journal
---


# Load Libraries
```{r setup, include=TRUE}
# Chunk defaults for the whole notebook: echo code, hide messages and
# warnings, centre figures and render them at 150 dpi.
knitr::opts_chunk$set(
  echo = TRUE, message = FALSE, warning = FALSE,
  fig.align = "center", dpi = 150
)

# R packages attached for the analysis.
library(Seurat)
library(reticulate)
library(ggplot2)
library(pheatmap)
library(dplyr)
library(RColorBrewer)

# Python environment: the interpreter path is written once and used both for
# the RETICULATE_PYTHON environment variable and for use_python().
python_bin <- "/home/bioinfo/.virtualenvs/r-reticulate/bin/python"
Sys.setenv(RETICULATE_PYTHON = python_bin)
use_python(python_bin, required = TRUE)

# Python modules used through reticulate in the chunks below.
sc <- import("scanpy")
ad <- import("anndata")

cat("✓ Python libraries imported\n")

# Output directory (no warning if it already exists).
dir.create("Output_Figures", showWarnings = FALSE)

# State constants: display order of the five cell-state labels and one hex
# colour per label, named by the label.
STATE_ORDER <- c("CD4 Naive", "CD4 TCM", "CD4 TEM", "CD4 Temra/CTL", "Treg")
STATE_COLORS <- setNames(
  c("#4472C4", "#70AD47", "#ED7D31", "#C00000", "#7030A0"),
  STATE_ORDER
)

```

# Load Reference Object

```{r load-object}
# Read the reference object from the working directory and report its size.
cd4_ref <- readRDS("cd4_ref_dual_trajectory.rds")

cat("Cells loaded:", ncol(cd4_ref), "\n")

# Collect the available metadata columns and reductions once, then fail fast
# (with the named message) on the first one that is absent.
meta_cols      <- colnames(cd4_ref@meta.data)
reduction_keys <- names(cd4_ref@reductions)

stopifnot(
  "predicted.celltype.l2 missing"    = is.element("predicted.celltype.l2", meta_cols),
  "milestone missing"                = is.element("milestone", meta_cols),
  "mst_pseudotime_norm missing"      = is.element("mst_pseudotime_norm", meta_cols),
  "monocle3_pseudotime_norm missing" = is.element("monocle3_pseudotime_norm", meta_cols),
  "pca reduction missing"            = is.element("pca", reduction_keys),
  "umap reduction missing"           = is.element("umap", reduction_keys)
)
cat("✅ All slots verified\n")

# Cell counts per l2 label and per milestone.
cat("\nCell state distribution:\n")
print(table(cd4_ref$predicted.celltype.l2))

cat("\nMilestone distribution:\n")
print(table(cd4_ref$milestone))
```

# UMAP Validation

```{r umap-check, fig.width=14, fig.height=5}
# Build one labelled, legend-free UMAP of cd4_ref coloured by a metadata
# column.
#   group_col   : name of the metadata column passed to group.by
#   panel_title : title shown above the panel
labelled_umap <- function(group_col, panel_title) {
  DimPlot(
    cd4_ref,
    group.by   = group_col,
    reduction  = "umap",
    label      = TRUE,
    repel      = TRUE,
    label.size = 3
  ) +
    NoLegend() +
    ggtitle(panel_title)
}

p1 <- labelled_umap("predicted.celltype.l2", "Azimuth l2 states")
p2 <- labelled_umap("milestone",             "Milestones M00–M07")
p3 <- labelled_umap("seurat_clusters",       "Seurat clusters")

# Three panels side by side.
print(p1 | p2 | p3)
```

---

# Build AnnData Object

```{r build-anndata}
# ── AnnData construction: PCA coordinates are stored as X ─────────────────
# NOTE(review): this changes the default assay on cd4_ref itself, and the
# object is written to disk again at the end of the notebook — confirm that
# is intended.
DefaultAssay(cd4_ref) <- "integrated"

# First 20 principal components per cell; these become AnnData's X matrix.
n_pcs      <- 20L
pca_embed  <- Embeddings(cd4_ref, "pca")[, seq_len(n_pcs)]
# NOTE(review): umap_embed is not referenced again in the visible part of
# this file.
umap_embed <- Embeddings(cd4_ref, "umap")

# Per-cell annotations, keyed by cell name. Labels are passed as plain
# character vectors; pseudotime columns are passed through as stored.
obs <- data.frame(
  celltype_l2 = as.character(cd4_ref$predicted.celltype.l2),
  milestone   = as.character(cd4_ref$milestone),
  mst_pt      = cd4_ref$mst_pseudotime_norm,
  monocle3_pt = cd4_ref$monocle3_pseudotime_norm,
  row.names   = colnames(cd4_ref)
)

adata <- ad$AnnData(X = pca_embed, obs = obs)

n_cells <- py_to_r(adata$n_obs)
n_feats <- py_to_r(adata$n_vars)
cat(sprintf("✅ AnnData: %d cells × %d PCs\n", n_cells, n_feats))

# ── Neighbour graph computed straight from X (use_rep = "X") ──────────────
sc$pp$neighbors(
  adata,
  use_rep     = "X",
  n_neighbors = 30L,
  metric      = "cosine"
)
cat("✅ KNN graph computed\n")

# ── PAGA grouped by the l2 cell-state label ───────────────────────────────
sc$tl$paga(adata, groups = "celltype_l2")
cat("✅ PAGA state-level complete\n")

```

---

# PAGA Analysis

## PAGA at State Level (l2)

```{r paga-state}

cat("Running PAGA — Azimuth l2 states...\n")
sc$tl$paga(adata, groups = "celltype_l2")

# Python side: densify the PAGA connectivity matrix and write it to a CSV in
# the working directory, using the category labels as the header row.
py_run_string("
import numpy as np
paga_conn = r.adata.uns['paga']['connectivities'].todense()
state_names = r.adata.obs['celltype_l2'].cat.categories.values
np.savetxt('paga_state_conn.csv', paga_conn, delimiter=',', header=','.join(state_names), comments='')
print('State connectivity saved')
", local = TRUE)

# R side: the CSV has a header row but no row-name column, so the header
# labels are reused for both the row and the column names of the matrix.
paga_state_df   <- read.csv("paga_state_conn.csv", header = TRUE,
                            check.names = FALSE)
state_names     <- names(paga_state_df)
paga_state_conn <- as.matrix(paga_state_df)
dimnames(paga_state_conn) <- list(state_names, state_names)

cat("✅ State-level PAGA complete\n\nConnectivity matrix:\n")
print(round(paga_state_conn, 3))

# The CSV was only a hand-off file; delete it.
file.remove("paga_state_conn.csv")

```

## PAGA at Milestone Level (M00–M07)

```{r paga-milestone}
cat("Running PAGA — milestones M00-M07...\n")
sc$tl$paga(adata, groups = "milestone")

# Python side: densify the PAGA connectivity matrix and write it to a CSV in
# the working directory, using the milestone labels as the header row.
py_run_string("
import numpy as np
paga_conn = r.adata.uns['paga']['connectivities'].todense()
ms_names = r.adata.obs['milestone'].cat.categories.values
np.savetxt('paga_ms_conn.csv', paga_conn, delimiter=',', header=','.join(ms_names), comments='')
print('Milestone connectivity saved')
", local = TRUE)

# R side: the CSV has a header row but no row-name column, so the header
# labels are reused for both the row and the column names of the matrix.
paga_ms_df   <- read.csv("paga_ms_conn.csv", header = TRUE,
                         check.names = FALSE)
ms_names     <- names(paga_ms_df)
paga_ms_conn <- as.matrix(paga_ms_df)
dimnames(paga_ms_conn) <- list(ms_names, ms_names)

cat("✅ Milestone-level PAGA complete\n\nConnectivity matrix:\n")
print(round(paga_ms_conn, 3))

# The CSV was only a hand-off file; delete it.
file.remove("paga_ms_conn.csv")


```

## Save PAGA Figures (Python) — State-Level PAGA
```{r}
# Re-run PAGA grouped by l2 state before plotting: the preceding chunk ran
# PAGA grouped by milestone, and the plot should show the state-level graph.
sc$tl$paga(adata, groups = "celltype_l2")
sc$pl$paga(adata, threshold = 0.15, show = FALSE, save = "_states.png")

# The plot is expected at figures/paga_states.png (the path this notebook has
# always renamed from). Move it into Output_Figures/ and stop if the move
# fails: warnings are suppressed notebook-wide, so an unchecked file.rename()
# would fail silently and the success message below would be misleading.
paga_states_src <- "figures/paga_states.png"
paga_states_dst <- "Output_Figures/PAGA_states.png"
if (!file.rename(paga_states_src, paga_states_dst)) {
  stop("Could not move ", paga_states_src, " to ", paga_states_dst,
       call. = FALSE)
}
cat("✅ State PAGA saved\n")



```
## Save PAGA Figures (Python) — Milestone-Level PAGA
```{r}

# Re-run PAGA grouped by milestone so the plot shows the milestone-level
# graph, then draw it with a 0.10 edge threshold.
sc$tl$paga(adata, groups = "milestone")
sc$pl$paga(adata, threshold = 0.10, show = FALSE, save = "_milestones.png")

# The plot is expected at figures/paga_milestones.png. Move it into
# Output_Figures/ and stop if the move fails: warnings are suppressed
# notebook-wide, so an unchecked file.rename() would fail silently and later
# chunks would try to embed a file that is not there.
paga_ms_src <- "figures/paga_milestones.png"
paga_ms_dst <- "Output_Figures/PAGA_milestones.png"
if (!file.rename(paga_ms_src, paga_ms_dst)) {
  stop("Could not move ", paga_ms_src, " to ", paga_ms_dst, call. = FALSE)
}
```

```{r display-paga-graphs}

# Embed the PNG at Output_Figures/PAGA_states.png in the rendered notebook.
knitr::include_graphics("Output_Figures/PAGA_states.png")



```

```{r paga-milestones-nb, echo=FALSE}

# Embed the PNG at Output_Figures/PAGA_milestones.png in the rendered notebook
# (this chunk is declared with echo=FALSE, so only the image is shown).
knitr::include_graphics("Output_Figures/PAGA_milestones.png")


```


## PAGA Connectivity Heatmaps (R)

```{r paga-heatmaps-Cell_states, fig.width=8, fig.height=6}

# ── State-level heatmap ───────────────────────────────────────────────────
# Keep only the states listed in STATE_ORDER, in that order.
state_rows <- intersect(STATE_ORDER, rownames(paga_state_conn))
state_cols <- intersect(STATE_ORDER, colnames(paga_state_conn))
conn_sub   <- paga_state_conn[state_rows, state_cols]

# Draw the unclustered connectivity heatmap with the scores printed in the
# cells. With the default NA arguments the plot goes to the current device
# (i.e. the notebook); with a filename it is written to that file instead.
draw_state_heatmap <- function(filename = NA, width = NA, height = NA) {
  pheatmap(
    conn_sub,
    color           = colorRampPalette(c("white","#2980b9","#c0392b"))(50),
    display_numbers = round(conn_sub, 2),
    number_color    = "black",
    fontsize_number = 11,
    main            = "PAGA connectivity — CD4 T cell states",
    cluster_rows    = FALSE,
    cluster_cols    = FALSE,
    filename        = filename,
    width           = width,
    height          = height
  )
}

# Save to file
draw_state_heatmap("Output_Figures/PAGA_state_heatmap.png",
                   width = 7, height = 6)

# Display in notebook
draw_state_heatmap()
```

```{r paga-heatmaps-milestones, fig.width=8, fig.height=6}
# ── Milestone-level heatmap ───────────────────────────────────────────────
# Milestone labels M00 … M07; keep only those present in the matrix.
ms_order <- sprintf("M%02d", 0:7)
ms_rows  <- intersect(ms_order, rownames(paga_ms_conn))
ms_cols  <- intersect(ms_order, colnames(paga_ms_conn))
ms_sub   <- paga_ms_conn[ms_rows, ms_cols]

# Draw the unclustered connectivity heatmap with the scores printed in the
# cells. With the default NA arguments the plot goes to the current device
# (i.e. the notebook); with a filename it is written to that file instead.
draw_milestone_heatmap <- function(filename = NA, width = NA, height = NA) {
  pheatmap(
    ms_sub,
    color           = colorRampPalette(c("white","#2980b9","#c0392b"))(50),
    display_numbers = round(ms_sub, 2),
    number_color    = "black",
    fontsize_number = 10,
    main            = "PAGA connectivity — Milestones M00–M07",
    cluster_rows    = FALSE,
    cluster_cols    = FALSE,
    filename        = filename,
    width           = width,
    height          = height
  )
}

# Save to file
draw_milestone_heatmap("Output_Figures/PAGA_milestone_heatmap.png",
                       width = 8, height = 7)

# Display in notebook
draw_milestone_heatmap()

cat("✅ Heatmaps saved\n")

```

---

# MST Topology Validation

## Define MST Edges

```{r mst-edges}
# MST topology:
# M00(Naive) → M01(TCM-early) → M02(TCM-late) → M03(TEM-early) → M04(TEM-late) → M05(Temra)
#                                        ↓
#                                   M06(Treg-naive) → M07(Treg-terminal)

# Each edge is written as one row (from, to, annotation) so it can be read
# left to right; the rows are then split into data-frame columns.

# State-level edges. There is deliberately no CD4 TCM → CD4 TCM self-edge.
state_edge_spec <- rbind(
  c("CD4 Naive", "CD4 TCM",       "Main axis"),
  c("CD4 TCM",   "CD4 TEM",       "Main axis"),
  c("CD4 TEM",   "CD4 Temra/CTL", "Terminal"),
  c("CD4 TCM",   "Treg",          "Treg branch")
)
mst_state_edges <- data.frame(
  from       = state_edge_spec[, 1],
  to         = state_edge_spec[, 2],
  edge_label = state_edge_spec[, 3]
)

# Milestone-level edges, with the transition each one stands for.
milestone_edge_spec <- rbind(
  c("M00", "M01", "Naive→TCM-early"),
  c("M01", "M02", "TCM-early→late"),
  c("M02", "M03", "TCM-late→TEM-early"),
  c("M02", "M06", "TCM-late→Treg-naive"),
  c("M03", "M04", "TEM-early→late"),
  c("M04", "M05", "TEM-late→Temra"),
  c("M06", "M07", "Treg-naive→terminal")
)
mst_milestone_edges <- data.frame(
  from    = milestone_edge_spec[, 1],
  to      = milestone_edge_spec[, 2],
  biology = milestone_edge_spec[, 3]
)

cat("MST state edges:\n")
print(mst_state_edges)
cat("\nMST milestone edges:\n")
print(mst_milestone_edges)
```

## State-Level Validation

```{r mst-validate-state, fig.width=8, fig.height=5}
# Look up the PAGA connectivity of every MST state edge in one step.
# match() gives NA for a label that is absent from the matrix, and indexing a
# matrix with an NA position yields NA, so unknown edges get an NA score.
# The scores carry the `from` labels as names.
state_from_idx <- match(mst_state_edges$from, rownames(paga_state_conn))
state_to_idx   <- match(mst_state_edges$to,   colnames(paga_state_conn))
mst_state_edges$paga_score <- setNames(
  round(paga_state_conn[cbind(state_from_idx, state_to_idx)], 3),
  mst_state_edges$from
)

# Grade each edge: > 0.50 Strong, > 0.30 Moderate, > 0.15 Weak; anything
# else, including a missing score, is "Not confirmed".
mst_state_edges$status <- with(
  mst_state_edges,
  case_when(
    paga_score > 0.50 ~ "Strong",
    paga_score > 0.30 ~ "Moderate",
    paga_score > 0.15 ~ "Weak",
    TRUE              ~ "Not confirmed"
  )
)

cat("\nMST state-level validation:\n")
print(mst_state_edges)

n_state_confirmed <- sum(mst_state_edges$paga_score > 0.3, na.rm = TRUE)
cat(sprintf("\n%d/%d MST state edges confirmed (PAGA > 0.3)\n",
            n_state_confirmed, nrow(mst_state_edges)))

# Horizontal bars, one per edge, ordered by score, with reference lines at
# the Moderate (0.3) and Strong (0.5) cut-offs.
status_palette <- c("Strong" = "#27ae60", "Moderate" = "#2980b9",
                    "Weak" = "#f39c12", "Not confirmed" = "#c0392b")

state_validation_plot <-
  ggplot(
    mst_state_edges,
    aes(x = reorder(paste(from, "→", to), paga_score),
        y = paga_score, fill = status)
  ) +
  geom_col(width = 0.7) +
  geom_hline(yintercept = 0.3, linetype = "dashed",
             colour = "grey40", linewidth = 0.8) +
  geom_hline(yintercept = 0.5, linetype = "dotted",
             colour = "grey40", linewidth = 0.8) +
  annotate("text", x = 0.6, y = 0.32, label = "Moderate (0.3)",
           size = 3, hjust = 0, colour = "grey40") +
  annotate("text", x = 0.6, y = 0.52, label = "Strong (0.5)",
           size = 3, hjust = 0, colour = "grey40") +
  scale_fill_manual(values = status_palette, name = "Validation") +
  coord_flip() +
  theme_classic() +
  labs(x = NULL, y = "PAGA connectivity score",
       title = "Custom MST state edges — PAGA validation",
       subtitle = "Dashed = 0.3 | Dotted = 0.5") +
  theme(plot.title = element_text(size = 13, face = "bold"))

print(state_validation_plot)

ggsave("Output_Figures/MST_PAGA_state_validation.png",
       plot = state_validation_plot,
       width = 9, height = 5, dpi = 150)
```

## Milestone-Level Validation

```{r mst-validate-milestone, fig.width=10, fig.height=6}
# Look up the PAGA connectivity of every MST milestone edge in one step.
# match() gives NA for a label that is absent from the matrix, and indexing a
# matrix with an NA position yields NA, so unknown edges get an NA score.
# The scores carry the `from` labels as names.
ms_from_idx <- match(mst_milestone_edges$from, rownames(paga_ms_conn))
ms_to_idx   <- match(mst_milestone_edges$to,   colnames(paga_ms_conn))
mst_milestone_edges$paga_score <- setNames(
  round(paga_ms_conn[cbind(ms_from_idx, ms_to_idx)], 3),
  mst_milestone_edges$from
)

# Grade each edge: > 0.50 Strong, > 0.30 Moderate, > 0.15 Weak; anything
# else, including a missing score, is "Not confirmed".
mst_milestone_edges$status <- with(
  mst_milestone_edges,
  case_when(
    paga_score > 0.50 ~ "Strong",
    paga_score > 0.30 ~ "Moderate",
    paga_score > 0.15 ~ "Weak",
    TRUE              ~ "Not confirmed"
  )
)

cat("MST milestone-level validation:\n")
print(mst_milestone_edges)

n_ms_confirmed <- sum(mst_milestone_edges$paga_score > 0.3, na.rm = TRUE)
cat(sprintf("\n%d/%d MST milestone edges confirmed (PAGA > 0.3)\n",
            n_ms_confirmed, nrow(mst_milestone_edges)))

# Horizontal bars, one per edge, ordered by score and labelled with the
# transition they stand for; the y range is widened to 1.2 to leave room for
# the labels. The dashed line marks the 0.3 cut-off.
ms_status_palette <- c("Strong" = "#27ae60", "Moderate" = "#2980b9",
                       "Weak" = "#f39c12", "Not confirmed" = "#c0392b")

milestone_validation_plot <-
  ggplot(
    mst_milestone_edges,
    aes(x = reorder(paste(from, "→", to), paga_score),
        y = paga_score, fill = status)
  ) +
  geom_col(width = 0.7) +
  geom_hline(yintercept = 0.3, linetype = "dashed",
             colour = "grey40", linewidth = 0.8) +
  geom_text(aes(label = biology), hjust = -0.1, size = 3) +
  scale_fill_manual(values = ms_status_palette, name = "Validation") +
  coord_flip() +
  expand_limits(y = 1.2) +
  theme_classic() +
  labs(x = NULL, y = "PAGA connectivity score",
       title = "MST milestone edges (M00–M07) — PAGA validation") +
  theme(plot.title = element_text(size = 13, face = "bold"))

print(milestone_validation_plot)

ggsave("Output_Figures/MST_PAGA_milestone_validation.png",
       plot = milestone_validation_plot,
       width = 10, height = 6, dpi = 150)
```

## Final Summary

```{r summary}
# Number of edges in a validation table whose PAGA score exceeds 0.3
# (edges with a missing score are not counted).
count_confirmed <- function(edges) {
  sum(edges$paga_score > 0.3, na.rm = TRUE)
}

summary_rule <- "══════════════════════════════════════════\n"
cat(summary_rule, "PAGA VALIDATION SUMMARY\n", summary_rule, sep = "")

cat(sprintf("State-level:     %d/%d edges confirmed (> 0.3)\n",
            count_confirmed(mst_state_edges),
            nrow(mst_state_edges)))
cat(sprintf("Milestone-level: %d/%d edges confirmed (> 0.3)\n",
            count_confirmed(mst_milestone_edges),
            nrow(mst_milestone_edges)))

# One line per figure in Output_Figures/ whose name contains PAGA or MST.
cat("\nFigures saved:\n")
saved_figures <- list.files("Output_Figures", pattern = "PAGA|MST")
cat(sprintf("  ✅ %s\n", saved_figures), sep = "")
```



```{r }
# ── Save PAGA results ──────────────────────────────────────────────────────
# The reload hints and inventory printed below name three outputs, so all
# three are written here.

# 1. R-side results: both connectivity matrices plus both edge-validation
#    tables. The element names match the reload hints printed below.
paga_results <- list(
  state_connectivity     = paga_state_conn,
  milestone_connectivity = paga_ms_conn,
  state_validation       = mst_state_edges,
  milestone_validation   = mst_milestone_edges
)
saveRDS(paga_results, "paga_validation_results.rds")

# 2. The AnnData object used for the neighbour graph and PAGA.
adata$write_h5ad("cd4_ref_PAGA_validated.h5ad")

# 3. A copy of the Seurat object under a new name; the input .rds is not
#    overwritten. No PAGA columns are added to the metadata in this chunk, so
#    the copy only reflects changes made to cd4_ref earlier in the notebook.
saveRDS(cd4_ref, "cd4_ref_dual_trajectory_with_PAGA.rds")
cat("✅ cd4_ref saved → cd4_ref_dual_trajectory_with_PAGA.rds\n")

# ── Quick reload check ─────────────────────────────────────────────────────
cat("\nTo reload PAGA results later:\n")
cat("  paga_results <- readRDS('paga_validation_results.rds')\n")
cat("  paga_state_conn <- paga_results$state_connectivity\n")
cat("  paga_ms_conn    <- paga_results$milestone_connectivity\n")
cat("  ad    <- reticulate::import('anndata')\n")
cat("  adata <- ad$read_h5ad('cd4_ref_PAGA_validated.h5ad')\n")

# ── Final inventory ────────────────────────────────────────────────────────
cat("\n══════════════════════════════════════════\n")
cat("SAVED OBJECTS\n")
cat("══════════════════════════════════════════\n")
cat("  cd4_ref_PAGA_validated.h5ad             ← AnnData with PAGA\n")
cat("  paga_validation_results.rds             ← R matrices + validation\n")
cat("  cd4_ref_dual_trajectory_with_PAGA.rds   ← Seurat object copy\n")
cat("  cd4_ref_dual_trajectory.rds             ← input Seurat object (unchanged)\n")
cat("\nOutput_Figures/:\n")
for (f in list.files("Output_Figures", pattern = "PAGA|MST")) {
  cat(sprintf("  ✅ %s\n", f))
}

```
