suppressPackageStartupMessages({
  library(MASS)
  library(Matrix)
  library(dplyr)
  library(tidyr)
  library(ggplot2)
  library(forcats)
  library(scales)
  library(tibble)
  library(gridExtra)  # <- for tableGrob
  library(grid)       # <- for grid.draw
})
## Warning: package 'forcats' was built under R version 4.4.3
## Warning: package 'gridExtra' was built under R version 4.4.3
## ===== Global config =====
# Fix the RNG stream so every scenario's simulated trials are reproducible.
set.seed(12345)

# Endpoint labels; every named vector / matrix below uses this order.
endpts <- paste0("H", 1:6)

nsim      <- 10000   # Monte Carlo replicates per scenario
n_per_arm <- 225     # patients per arm
alpha_tot <- 0.05    # overall significance level shared by all hypotheses
sided     <- "two"   # "two" or "one"; forwarded to p_two_or_one()
pC        <- 0.50    # control-arm probability that the binary H3 outcome is 1

## ===== Correlation =====
Sigma <- matrix(c(
  1.0, 0.1, 0.5, 0.1, 0.1, 0.1,
  0.1, 1.0, 0.1, 0.8, 0.5, 0.5,
  0.5, 0.1, 1.0, 0.1, 0.1, 0.1,
  0.1, 0.8, 0.1, 1.0, 0.5, 0.5,
  0.1, 0.5, 0.1, 0.5, 1.0, 0.5,
  0.1, 0.5, 0.1, 0.5, 0.5, 1.0
), 6, 6, byrow=TRUE, dimnames=list(endpts,endpts))
# Repair only if the matrix is not (numerically) positive definite.
# gen_arm() treats every column of mvrnorm(..., Sigma) as a standard normal
# (H3 is dichotomized at qnorm(p); the other columns are scaled by SD), so any
# repair must keep a unit diagonal. nearPD(corr = TRUE) returns the nearest
# correlation matrix; the default corr = FALSE does not constrain the diagonal.
if (min(eigen(Sigma, symmetric=TRUE, only.values=TRUE)$values) < 1e-8) {
  Sigma <- as.matrix(nearPD(Sigma, corr = TRUE)$mat)
  dimnames(Sigma) <- list(endpts, endpts)  # keep endpoint labels after repair
}
# Enforce exact symmetry (guards against floating-point asymmetry).
Sigma <- (Sigma + t(Sigma))/2

## ===== Graphs for D1, D5, D7, D8 =====
## Build a transition-weight matrix over `endpts` (rows = "from", cols = "to")
## from parallel vectors of edges; every cell not listed is 0.
graph_from_edges <- function(from, to, weight) {
  G <- matrix(0, nrow = length(endpts), ncol = length(endpts),
              dimnames = list(endpts, endpts))
  G[cbind(from, to)] <- weight
  G
}

# D1: G13 — H1 and H2 both pass to H3; chain H3 -> H4 -> H5 -> H6; H6 has no outflow
G13 <- graph_from_edges(
  from   = c("H1", "H2", "H3", "H4", "H5"),
  to     = c("H3", "H3", "H4", "H5", "H6"),
  weight = 1
)

# D5: G5 — H1 splits 2/3 to H2 and 1/3 to H3; chain to H6; H6 recycles to H1
G5 <- graph_from_edges(
  from   = c("H1", "H1", "H2", "H3", "H4", "H5", "H6"),
  to     = c("H2", "H3", "H3", "H4", "H5", "H6", "H1"),
  weight = c(2/3,  1/3,  1,    1,    1,    1,    1)
)

# D7: like D5 but no flow to H6; H6 is tested at the full alpha (0.05 here)
# only if H1–H5 all reject
G7 <- graph_from_edges(
  from   = c("H1", "H1", "H2", "H3", "H4", "H5"),
  to     = c("H2", "H3", "H3", "H4", "H5", "H1"),  # H5 recycles to H1 only
  weight = c(2/3,  1/3,  1,    1,    1,    1)
)

# D8: like D7, but H1 recycles 50% to H2 and 50% to H3; initial split 0.025 / 0.025
G8 <- graph_from_edges(
  from   = c("H1", "H1", "H2", "H3", "H4", "H5"),
  to     = c("H2", "H3", "H3", "H4", "H5", "H1"),  # H5 recycles to H1 only (same as D7)
  weight = c(1/2,  1/2,  1,    1,    1,    1)
)

## ===== Initial weights (D1, D5, D7, D8) =====
# Only H1 and H2 start with alpha; values are shares of alpha_tot.
init_weights <- function(h1, h2) c(H1 = h1, H2 = h2, H3 = 0, H4 = 0, H5 = 0, H6 = 0)

w_D1 <- init_weights(1/2, 1/2)  # 0.025, 0.025
w_D5 <- init_weights(4/5, 1/5)  # 0.04, 0.01
w_D7 <- w_D5                    # same split as D5: 0.04, 0.01
w_D8 <- init_weights(1/2, 1/2)  # 0.025, 0.025

## ===== Colors (D1, D5, D7, D8) =====
cols <- c(
  D1 = "#1f77b4",
  D5 = "#2ca02c",  # green
  D7 = "#8c564b",
  D8 = "#9467bd"
)

## ===== Helpers (dynamic 25/50/75/100 top, ticks every 5%) =====
## Smallest of 25% / 50% / 75% / 100% that is >= mx (used as the y-axis top).
## Missing or non-positive maxima fall back to 25%; anything above 75% maps to 100%.
axis_cap_25_50_75 <- function(mx) {
  if (is.na(mx) || mx <= 0) return(0.25)
  caps <- c(0.25, 0.50, 0.75, 1.00)
  # left.open = TRUE makes each bin (lower, upper], so mx == 0.50 -> 0.50.
  caps[findInterval(mx, caps[-4], left.open = TRUE) + 1L]
}

## Tick positions every 5 percentage points from 0 up to `top`.
axis_breaks_5pct <- function(top) {
  seq(from = 0, to = top, by = 0.05)
}

## ===== Alpha flow helpers =====
## Breadth-first search from `start` along edges with positive weight in G,
## walking only through rejected hypotheses. Returns TRUE as soon as an edge
## reaches a hypothesis that is not rejected, FALSE once the search is exhausted
## (also FALSE when `start` is not among names(rejected)).
## Rows of G are indexed by position, so the order of `rejected` is assumed to
## match G's row order; neighbours are looked up by column name.
has_path_to_nonrejected <- function(start, rejected, G) {
  node_names <- names(rejected)
  seen <- setNames(logical(length(node_names)), node_names)
  queue <- which(node_names == start)
  pos <- 1L
  while (pos <= length(queue)) {
    cur <- queue[pos]
    pos <- pos + 1L
    if (seen[cur]) next
    seen[cur] <- TRUE
    for (nb_name in names(which(G[cur, ] > 0))) {
      nb <- which(node_names == nb_name)
      if (!rejected[nb]) return(TRUE)
      if (!seen[nb]) queue <- c(queue, nb)
    }
  }
  FALSE
}
## Repeatedly move alpha held by rejected hypotheses onward along G until no
## rejected hypothesis holds positive alpha. A rejected node that can still
## reach a non-rejected one forwards its alpha (a += a_k * G[k, ]); otherwise
## its alpha is dropped. Returns the updated alpha vector.
push_alpha_through_rejected <- function(a, rejected, G) {
  still_moving <- TRUE
  while (still_moving) {
    still_moving <- FALSE
    for (k in seq_along(a)) {
      if (!rejected[k] || a[k] <= 0) next
      if (has_path_to_nonrejected(names(a)[k], rejected, G)) {
        a <- a + a[k] * G[k, ]
      }
      a[k] <- 0
      still_moving <- TRUE
    }
  }
  a
}

## One run of the graphical multiple-testing procedure (alpha recycling).
##   p         p-values, in the same order as w
##   w         initial weights; must sum to 1 and be named exactly like colnames(G)
##   G         transition matrix: on rejecting H_i its local alpha is split
##             across H_j in proportion to G[i, j]
##   alpha_tot overall significance level
## Returns list(tested, rejected) of named logical vectors; a hypothesis counts
## as tested once it has held positive local alpha.
run_graph <- function(p, w, G, alpha_tot=0.05) {
  stopifnot(abs(sum(w)-1) < 1e-12, identical(names(w), colnames(G)))
  a <- alpha_tot * w
  tested   <- setNames(rep(FALSE, length(w)), names(w))
  rejected <- setNames(rep(FALSE, length(w)), names(w))
  tested[a > 0] <- TRUE
  repeat {
    # A hypothesis needs positive local alpha to be rejected: with a == 0, a
    # p-value that underflows to 0 would satisfy p <= a and be "rejected" while
    # never marked tested. NA p-values are never rejected (otherwise any()
    # returns NA and the loop stops with an error).
    can_reject <- (!rejected) & (a > 0) & (!is.na(p)) & (p <= a)
    if (!any(can_reject)) break
    rej_nodes <- names(which(can_reject))
    for (i in rej_nodes) {
      ai <- a[i]
      if (ai > 0) { a <- a + ai * G[i, ]; a[i] <- 0 }
      rejected[i] <- TRUE
    }
    # Alpha that landed on already-rejected hypotheses is forwarded along G,
    # or dropped when no non-rejected hypothesis is reachable.
    a <- push_alpha_through_rejected(a, rejected, G)
    tested[a > 0] <- TRUE
  }
  list(tested=tested, rejected=rejected)
}

## AND-gate for H6 only (used by D7 and D8).
## Runs the graphical procedure on H1–H5 only (H6's p-value is ignored there);
## H6 is then tested at the full alpha_tot, and only if H1–H5 were all rejected.
## Arguments and return value are as for run_graph(); assumes six hypotheses
## ordered H1..H6, named exactly like colnames(G).
run_andgate_H6 <- function(p, w, G, alpha_tot=0.05) {
  stopifnot(abs(sum(w)-1) < 1e-12, identical(names(w), colnames(G)))
  gated <- 1:5  # hypotheses handled by the graph; H6 sits behind the gate
  a <- alpha_tot * w
  tested   <- setNames(rep(FALSE, 6), names(w))
  rejected <- setNames(rep(FALSE, 6), names(w))
  tested[a > 0] <- TRUE
  repeat {
    # Rejection needs positive local alpha and a non-missing p-value (same
    # rule as run_graph): zero-alpha hypotheses cannot be rejected and an NA
    # p-value must not turn any() into NA.
    can_reject <- (!rejected[gated]) & (a[gated] > 0) &
                  (!is.na(p[gated])) & (p[gated] <= a[gated])
    if (!any(can_reject)) break
    rej_nodes <- names(which(can_reject))
    for (i in rej_nodes) {
      ai <- a[i]
      if (ai > 0) { a <- a + ai * G[i, ]; a[i] <- 0 }
      rejected[i] <- TRUE
    }
    a <- push_alpha_through_rejected(a, rejected, G)
    tested[a > 0] <- TRUE
  }
  if (all(rejected[gated])) {
    tested["H6"]  <- TRUE
    # Use the caller's overall level rather than a literal 0.05.
    rejected["H6"] <- isTRUE(p[["H6"]] <= alpha_tot)
  }
  list(tested=tested, rejected=rejected)
}

## ===== Data generation & p-values =====
## Simulate one arm of n patients.
##   trt   1 = treatment (effects applied), anything else = control
##   Delta named true effects (H3 entry is a shift in response probability)
##   SD    named SDs for the continuous endpoints H1, H2, H4, H5, H6
##   pC    control-arm probability that imp == 1
## Returns a data.frame with kccq, pvo2, zscore, ntprob, lavi, trt and imp.
gen_arm <- function(trt, n, Sigma, Delta, SD, pC) {
  # Correlated latent draws, one column per endpoint.
  Z <- mvrnorm(n, mu = rep(0, 6), Sigma = Sigma)
  colnames(Z) <- endpts

  # Mean shifts: zero in control; Delta on the continuous endpoints for treatment.
  shift <- setNames(rep(0, 6), endpts)
  continuous <- c("H1", "H2", "H4", "H5", "H6")
  if (trt == 1) shift[continuous] <- Delta[continuous]
  outcome <- function(h) shift[h] + SD[h] * Z[, h]

  dat <- data.frame(
    kccq   = outcome("H1"),
    pvo2   = outcome("H2"),
    zscore = outcome("H4"),
    ntprob = outcome("H5"),
    lavi   = outcome("H6"),
    trt    = trt
  )

  # H3: threshold the latent H3 column at the normal quantile of the arm's
  # response probability (treatment arm clamped to [1e-8, 1 - 1e-8]).
  prob <- pC
  if (trt == 1) prob <- pmin(pmax(pC + Delta["H3"], 1e-8), 1 - 1e-8)
  dat$imp <- as.integer(Z[, "H3"] <= qnorm(prob))
  dat
}
## Normal-approximation p-value for a test statistic.
##   "two": 2 * P(Z >= |tstat|)
##   "one": P(Z >= delta_sign * tstat), small when tstat points in the
##          direction given by delta_sign
p_two_or_one <- function(tstat, sided=c("two","one"), delta_sign=1) {
  side <- match.arg(sided)
  switch(side,
    two = 2 * pnorm(-abs(tstat)),
    one = pnorm(-delta_sign * tstat)
  )
}
## p-values for H1–H6 from one simulated trial.
##   dt    data frame with trt (1 = treatment, 0 = control), continuous outcomes
##         kccq/pvo2/zscore/ntprob/lavi and the binary imp column
##   Delta named true effects; only sign(Delta) is used, to orient one-sided tests
## Reads the global `sided` ("two" or "one").
## Continuous endpoints: Welch t statistic (treatment minus control) converted
## with the normal approximation in p_two_or_one(). H3: Pearson chi-square
## without continuity correction.
pvals_from_data <- function(dt, Delta) {
  side <- match.arg(sided, c("two", "one"))
  pv <- setNames(rep(NA_real_, 6), endpts)
  in_trt <- dt$trt == 1

  # Pass treatment as `x` so the statistic is mean(trt) - mean(ctl). The formula
  # interface (y ~ trt) orders groups by level ("0" first) and would give
  # control minus treatment, inverting every one-sided p-value.
  cont_cols <- c(H1 = "kccq", H2 = "pvo2", H4 = "zscore", H5 = "ntprob", H6 = "lavi")
  for (h in names(cont_cols)) {
    vals  <- dt[[cont_cols[[h]]]]
    tstat <- t.test(x = vals[in_trt], y = vals[!in_trt], var.equal = FALSE)$statistic
    pv[h] <- p_two_or_one(tstat, sided = side, delta_sign = sign(Delta[h]))
  }

  # H3: the chi-square p-value is two-sided by construction. For a one-sided
  # test, halve it when the observed risk difference points in the direction
  # of Delta["H3"], otherwise take 1 - p/2.
  p3 <- suppressWarnings(chisq.test(table(dt$trt, dt$imp), correct=FALSE)$p.value)
  if (side == "one") {
    obs_dir <- sign(mean(dt$imp[in_trt]) - mean(dt$imp[!in_trt]))
    p3 <- if (isTRUE(obs_dir * sign(Delta[["H3"]]) > 0)) p3 / 2 else 1 - p3 / 2
  }
  pv["H3"] <- p3
  pv
}

## ===== 8 Scenarios =====
H5_fix <- list(diff = -1.0, sd = 1.0)

## Full 2 x 2 x 2 factorial over the H1, H2 and H6 effects; H3, H4 and H5 are
## fixed. expand.grid varies its first column fastest, which reproduces the
## Scenario 1..8 order (H2 fastest, then H1, then H6).
scenarios <- local({
  make_scenario <- function(h1, h2, h6) {
    list(
      H1 = list(diff = h1, sd = 15),
      H2 = list(diff = h2, sd = 3),
      H3 = 0.15,
      H4 = list(diff = 0.25, sd = 1),
      H5 = H5_fix,
      H6 = list(diff = h6, sd = 10)
    )
  }
  grid <- expand.grid(H2 = c(0.75, 0.54), H1 = c(3.75, 2.70), H6 = c(-3, -1.8))
  setNames(Map(make_scenario, grid$H1, grid$H2, grid$H6),
           paste("Scenario", seq_len(nrow(grid))))
})
stopifnot(length(scenarios) == 8)

## ===== Title helper: standardized diffs (diff/sd) for H1, H2, H4, H6; raw risk difference for H3 =====
## Plot title: scenario name followed by diff/sd for H1, H2, H4, H6 and the
## raw H3 value, e.g. "<name> — H1=0.25; H2=0.25; H3=0.15; H4=0.25; H6=-0.3".
mk_title <- function(name, sc) {
  std_diff <- function(h) sc[[h]]$diff / sc[[h]]$sd
  parts <- c(
    H1 = std_diff("H1"),
    H2 = std_diff("H2"),
    H3 = sc$H3,
    H4 = std_diff("H4"),
    H6 = std_diff("H6")
  )
  paste0(name, " — ", paste0(names(parts), "=", parts, collapse = "; "))
}

## ===== Simulation per scenario (D1, D5, D7, D8; per-endpoint results plus combined H1_or_H2 / H1_and_H2 power) =====
## Simulate nsim trials for one scenario, then evaluate every design on the
## same p-values. Returns a tibble with columns endpoint, prob_tested, power and
## design: one row per endpoint H1–H6 plus the combined H1_or_H2 / H1_and_H2
## rows (prob_tested = NA) for each design.
simulate_designs_for_scenario <- function(sc) {
  Delta <- c(H1 = sc$H1$diff, H2 = sc$H2$diff, H3 = sc$H3,
             H4 = sc$H4$diff, H5 = sc$H5$diff, H6 = sc$H6$diff)
  SD    <- c(H1 = sc$H1$sd, H2 = sc$H2$sd,
             H4 = sc$H4$sd, H5 = sc$H5$sd, H6 = sc$H6$sd)

  # One row of p-values per simulated trial, reused by every design below so
  # the designs are compared on identical data.
  p_mat <- matrix(NA_real_, nrow = nsim, ncol = length(endpts),
                  dimnames = list(NULL, endpts))
  for (b in seq_len(nsim)) {
    arm_trt <- gen_arm(1, n_per_arm, Sigma, Delta, SD, pC)
    arm_ctl <- gen_arm(0, n_per_arm, Sigma, Delta, SD, pC)
    p_mat[b, ] <- pvals_from_data(rbind(arm_trt, arm_ctl), Delta)
  }

  # Apply one design to every simulated trial and summarise.
  summarise_design <- function(label, w_init, Gmat, mode = c("generic", "andgate")) {
    mode <- match.arg(mode)
    procedure <- switch(mode, generic = run_graph, andgate = run_andgate_H6)
    tested_acc <- matrix(0L, nsim, length(endpts), dimnames = list(NULL, endpts))
    reject_acc <- tested_acc
    for (b in seq_len(nsim)) {
      out <- procedure(p_mat[b, ], w_init, Gmat, alpha_tot)
      tested_acc[b, ] <- as.integer(out$tested)
      reject_acc[b, ] <- as.integer(out$rejected)
    }
    h1_rej <- reject_acc[, "H1"] == 1L
    h2_rej <- reject_acc[, "H2"] == 1L

    bind_rows(
      tibble(endpoint = endpts,
             prob_tested = colMeans(tested_acc),
             power = colMeans(reject_acc),
             design = label),
      tibble(endpoint = "H1_or_H2",  prob_tested = NA_real_,
             power = mean(h1_rej | h2_rej), design = label),
      tibble(endpoint = "H1_and_H2", prob_tested = NA_real_,
             power = mean(h1_rej & h2_rej), design = label)
    )
  }

  bind_rows(
    summarise_design("D1", w_D1, G13, "generic"),
    summarise_design("D5", w_D5, G5,  "generic"),
    summarise_design("D7", w_D7, G7,  "andgate"),  # H6 only after H1–H5 all reject
    summarise_design("D8", w_D8, G8,  "andgate")   # same AND-gate rule; different split & H1 routing
  )
}

## ===== Run all scenarios and plot (no saving; titles show standardized diffs) =====
## Endpoint order for the power panel and the summary table.
power_levels <- c("H1_or_H2", "H1_and_H2", paste0("H", 1:6))

## Dodged bar chart of one metric by endpoint and design. The y-axis top is the
## smallest of 25/50/75/100% covering the data, with ticks every 5%.
design_bar_plot <- function(df, metric, plot_title, y_label, subtitle = NULL) {
  top <- axis_cap_25_50_75(max(df[[metric]], na.rm = TRUE))
  plt <- ggplot(df, aes(endpoint, .data[[metric]], fill = design)) +
    geom_col(position = position_dodge(width = 0.82), width = 0.72) +
    labs(title = plot_title, x = "Endpoint", y = y_label)
  if (!is.null(subtitle)) plt <- plt + labs(subtitle = subtitle)
  plt +
    scale_y_continuous(limits = c(0, top),
                       breaks = axis_breaks_5pct(top),
                       labels = percent_format(accuracy = 1),
                       expand = expansion(mult = c(0, 0))) +
    theme_minimal(base_size = 12) +
    scale_fill_manual(values = cols) +
    theme(plot.title = element_text(face = "bold", size = 14))
}

for (sc_name in names(scenarios)) {
  sc         <- scenarios[[sc_name]]
  res        <- simulate_designs_for_scenario(sc)
  plot_title <- mk_title(sc_name, sc)

  ## Probability of being tested: ONLY H3–H6
  prob_df <- res %>%
    filter(endpoint %in% paste0("H", 3:6)) %>%
    mutate(endpoint = factor(endpoint, levels = paste0("H", 3:6)))
  print(design_bar_plot(prob_df, "prob_tested", plot_title, "Probability tested",
                        subtitle = "Probability of being tested (H3–H6)"))

  ## Power: H1_or_H2, H1_and_H2, then H1–H6
  power_df <- res %>%
    mutate(endpoint = factor(endpoint, levels = power_levels))
  print(design_bar_plot(power_df, "power", plot_title, "Power"))

  ## Summary table drawn as a graphic in a smaller font
  cat("\n===== ", sc_name, " — Table =====\n", sep = "")
  tbl <- res %>%
    arrange(match(endpoint, power_levels), design)
  grid.newpage()
  grid.draw(tableGrob(tbl, rows = NULL, theme = ttheme_minimal(base_size = 9)))
}

## 
## ===== Scenario 1 — Table =====

## 
## ===== Scenario 2 — Table =====

## 
## ===== Scenario 3 — Table =====

## 
## ===== Scenario 4 — Table =====

## 
## ===== Scenario 5 — Table =====

## 
## ===== Scenario 6 — Table =====

## 
## ===== Scenario 7 — Table =====

## 
## ===== Scenario 8 — Table =====