## Attach dependencies quietly, in a fixed order (MASS supplies mvrnorm,
## Matrix supplies nearPD, the others are used for wrangling and plotting).
## A "package ... was built under R version ..." warning for forcats has been
## observed at this point in earlier runs.
invisible(suppressPackageStartupMessages(
  lapply(
    c("MASS", "Matrix", "dplyr", "tidyr", "ggplot2", "forcats", "scales", "tibble"),
    library,
    character.only = TRUE
  )
))

## Fix the RNG stream so the whole simulation is reproducible.
set.seed(12345)

## ===== Global config =====
nsim      <- 10000   # Monte Carlo replicates per scenario
n_per_arm <- 225     # subjects simulated in each arm
alpha_tot <- 0.05    # total alpha, split across hypotheses by the initial weights
sided     <- "two"   # "two" or "one"; forwarded to p_two_or_one()
pC        <- 0.50    # control-arm response probability for the binary H3

## Hypothesis labels H1..H6; used as names throughout.
endpts <- paste0("H", 1:6)

## ===== Correlation =====
## Correlation of the six latent standard-normal endpoint draws (rows/cols follow
## `endpts`). The H3 entry applies to the latent normal that is thresholded into
## the binary H3 outcome. Every latent draw is later scaled by its endpoint SD,
## which assumes unit variances. If the matrix is not positive definite, it is
## therefore projected to the nearest *correlation* matrix (corr = TRUE keeps the
## unit diagonal), not merely the nearest PD covariance matrix.
Sigma <- matrix(c(
  1.0, 0.1, 0.5, 0.1, 0.1, 0.1,
  0.1, 1.0, 0.1, 0.8, 0.5, 0.5,
  0.5, 0.1, 1.0, 0.1, 0.1, 0.1,
  0.1, 0.8, 0.1, 1.0, 0.5, 0.5,
  0.1, 0.5, 0.1, 0.5, 1.0, 0.5,
  0.1, 0.5, 0.1, 0.5, 0.5, 1.0
), 6, 6, byrow=TRUE, dimnames=list(endpts,endpts))
if (min(eigen(Sigma, symmetric=TRUE, only.values=TRUE)$values) < 1e-8) {
  Sigma <- as.matrix(nearPD(Sigma, corr = TRUE)$mat)
  # Re-assert labels so downstream name-based indexing keeps working.
  dimnames(Sigma) <- list(endpts, endpts)
}
## Remove any floating-point asymmetry before sampling.
Sigma <- (Sigma + t(Sigma))/2

## ===== Graphs for D1–D6 =====
## Each graph is a 6x6 transition matrix; entry [from, to] is the fraction of
## a rejected hypothesis's alpha passed to `to`. Built from edge lists on a
## zero matrix labelled by `endpts`.

## D1–D3: H1 and H2 both pass to H3, then the chain H3 -> H4 -> H5 -> H6.
G13 <- matrix(0, 6, 6, dimnames = list(endpts, endpts))
G13[cbind(c("H1", "H2", "H3", "H4", "H5"),
          c("H3", "H3", "H4", "H5", "H6"))] <- 1

## D4: cycle H1 -> H3 -> H2 -> H4 -> H5 -> H1; H6 has no edges.
G4 <- matrix(0, 6, 6, dimnames = list(endpts, endpts))
G4[cbind(c("H1", "H2", "H3", "H4", "H5"),
         c("H3", "H4", "H2", "H5", "H1"))] <- 1

## D5: H1 splits 2/3 to H2 and 1/3 to H3; then H2 -> H3 -> H4 -> H5 -> H6 -> H1.
G5 <- matrix(0, 6, 6, dimnames = list(endpts, endpts))
G5[cbind(c("H1", "H1", "H2", "H3", "H4", "H5", "H6"),
         c("H2", "H3", "H3", "H4", "H5", "H6", "H1"))] <- c(2/3, 1/3, 1, 1, 1, 1, 1)

## D6: cycle H1 -> H3 -> H5 -> H2 -> H4 -> H1; H6 has no edges.
G6 <- matrix(0, 6, 6, dimnames = list(endpts, endpts))
G6[cbind(c("H1", "H2", "H3", "H4", "H5"),
         c("H3", "H4", "H5", "H1", "H2"))] <- 1

## ===== Initial weights =====
## Initial alpha fractions per design. Only H1 and H2 start with alpha; H3–H6
## start at zero.
w_D1 <- setNames(c(1/2, 1/2, 0, 0, 0, 0), paste0("H", 1:6))
w_D2 <- setNames(c(4/5, 1/5, 0, 0, 0, 0), paste0("H", 1:6))
w_D3 <- setNames(c(2/3, 1/3, 0, 0, 0, 0), paste0("H", 1:6))
w_D4 <- setNames(c(1/2, 1/2, 0, 0, 0, 0), paste0("H", 1:6))
w_D5 <- setNames(c(4/5, 1/5, 0, 0, 0, 0), paste0("H", 1:6))
w_D6 <- setNames(c(4/5, 1/5, 0, 0, 0, 0), paste0("H", 1:6))

## ===== Colors =====
## Bar fill colour per design label (consumed by scale_fill_manual).
cols <- setNames(
  c("#1f77b4", "#ff7f0e", "#2ca02c", "#9467bd", "#8c564b", "#17becf"),
  paste0("D", 1:6)
)

## ===== Helpers (dynamic 25/50/75/100 top, ticks every 5%) =====
## Smallest of 0.25 / 0.50 / 0.75 / 1.00 that is >= `mx`.
## NA or non-positive input falls back to 0.25; anything above 0.75 gives 1.
axis_cap_25_50_75 <- function(mx) {
  if (is.na(mx) || mx <= 0) return(0.25)
  caps <- c(0.25, 0.50, 0.75)
  fitting <- caps[mx <= caps]
  if (length(fitting) > 0) fitting[1] else 1.00
}
## Axis tick positions from 0 up to `top` in 0.05 steps.
axis_breaks_5pct <- function(top) {
  seq(from = 0, to = top, by = 0.05)
}

## ===== Alpha flow helpers =====
## Breadth-first search from hypothesis `start` along edges with G > 0.
## Returns TRUE as soon as a reachable neighbour is not rejected, FALSE if
## every reachable hypothesis is already rejected. Rows of G are indexed by
## position in `rejected`; neighbours are matched back by name.
has_path_to_nonrejected <- function(start, rejected, G) {
  node_names <- names(rejected)
  seen <- setNames(logical(length(rejected)), node_names)
  frontier <- which(node_names == start)
  while (length(frontier) > 0) {
    cur <- frontier[1]
    frontier <- frontier[-1]
    if (seen[cur]) next
    seen[cur] <- TRUE
    for (nb in names(which(G[cur, ] > 0))) {
      k <- which(node_names == nb)
      if (!rejected[k]) return(TRUE)
      if (!seen[k]) frontier <- c(frontier, k)
    }
  }
  FALSE
}
## Move alpha off rejected hypotheses until none of them holds any.
## Each pass visits hypotheses in order. A rejected hypothesis with positive
## alpha forwards it along its row of G when some non-rejected hypothesis is
## reachable; otherwise its alpha is discarded. Passes repeat until one pass
## makes no change.
push_alpha_through_rejected <- function(a, rejected, G) {
  changed <- TRUE
  while (changed) {
    changed <- FALSE
    for (idx in seq_along(a)) {
      if (!(rejected[idx] && a[idx] > 0)) next
      if (has_path_to_nonrejected(names(a)[idx], rejected, G)) {
        share <- a[idx]
        a <- a + share * G[idx, ]
      }
      a[idx] <- 0
      changed <- TRUE
    }
  }
  a
}

## Sequentially rejective graphical test.
##
## Args:
##   p:         p-values, compared positionally with `w` (expected in the same order).
##   w:         initial weights summing to 1; names must equal colnames(G).
##   G:         transition matrix; G[i, j] is the fraction of H_i's alpha that
##              moves to H_j when H_i is rejected.
##   alpha_tot: total alpha; local levels start at alpha_tot * w.
## Returns:
##   list(tested, rejected): named logical vectors. `tested` marks hypotheses
##   that held positive local alpha at some point.
## Notes:
##   Only hypotheses currently holding positive alpha are rejectable, so a
##   p-value that underflows to 0 cannot reject an untested hypothesis.
##   A missing p-value never rejects.
run_graph <- function(p, w, G, alpha_tot=0.05) {
  stopifnot(abs(sum(w)-1) < 1e-12,
            identical(names(w), colnames(G)),
            length(p) == length(w))
  k <- length(w)
  a <- alpha_tot * w
  tested   <- setNames(rep(FALSE, k), names(w))
  rejected <- setNames(rep(FALSE, k), names(w))
  tested[a > 0] <- TRUE
  repeat {
    can_reject <- (!rejected) & (a > 0) & (!is.na(p)) & (p <= a)
    if (!any(can_reject)) break
    for (i in names(which(can_reject))) {
      # Forward the rejected hypothesis's alpha along its outgoing edges.
      ai <- a[i]
      if (ai > 0) { a <- a + ai * G[i, ]; a[i] <- 0 }
      rejected[i] <- TRUE
    }
    # Alpha that landed on already-rejected nodes keeps flowing onward.
    a <- push_alpha_through_rejected(a, rejected, G)
    tested[a > 0] <- TRUE
  }
  list(tested=tested, rejected=rejected)
}

## AND-gate for H6 only.
## H1–H5 are tested with the graphical procedure: local levels alpha_tot * w,
## recycled along G. H6 is tested only once all of H1–H5 are rejected, at the
## full `alpha_tot` supplied by the caller.
##
## Args: as run_graph(); `w`, `p` must have length 6 with H6 in position 6.
## Returns: list(tested, rejected), named logical vectors.
run_andgate_H6 <- function(p, w, G, alpha_tot=0.05) {
  stopifnot(length(w) == 6, length(p) == 6,
            abs(sum(w)-1) < 1e-12,
            identical(names(w), colnames(G)))
  gate <- 1:5
  a <- alpha_tot * w
  tested   <- setNames(rep(FALSE, 6), names(w))
  rejected <- setNames(rep(FALSE, 6), names(w))
  tested[a > 0] <- TRUE
  repeat {
    # Only gate hypotheses currently holding alpha are rejectable; NA never rejects.
    can_reject <- (!rejected[gate]) & (a[gate] > 0) & (!is.na(p[gate])) &
      (p[gate] <= a[gate])
    if (!any(can_reject)) break
    for (i in names(which(can_reject))) {
      ai <- a[i]
      if (ai > 0) { a <- a + ai * G[i, ]; a[i] <- 0 }
      rejected[i] <- TRUE
    }
    a <- push_alpha_through_rejected(a, rejected, G)
    tested[a > 0] <- TRUE
  }
  if (all(rejected[gate])) {
    tested["H6"] <- TRUE
    rejected["H6"] <- !is.na(p["H6"]) && p["H6"] <= alpha_tot
  }
  list(tested=tested, rejected=rejected)
}

## ===== Data generation & p-values =====
## Simulate one arm of `n` subjects.
## Draws correlated latent standard normals (correlation Sigma). Continuous
## endpoints are mean + SD * latent, with the treated arm (trt == 1) shifted by
## Delta on H1, H2, H4, H5, H6. Binary H3 (`imp`) is 1 when the latent H3 value
## is below qnorm(response probability): pC for control, pC + Delta["H3"]
## (clamped into (0, 1)) for treated.
## Returns a data.frame: kccq, pvo2, zscore, ntprob, lavi, trt, imp.
gen_arm <- function(trt, n, Sigma, Delta, SD, pC) {
  latent <- mvrnorm(n, mu=rep(0,6), Sigma=Sigma)
  colnames(latent) <- endpts
  shifted <- c("H1","H2","H4","H5","H6")
  mu <- setNames(rep(0,6), endpts)
  if (trt == 1) {
    mu[shifted] <- Delta[shifted]
  }
  outcome <- function(h) mu[h] + SD[h] * latent[, h]
  dat <- data.frame(
    kccq   = outcome("H1"),
    pvo2   = outcome("H2"),
    zscore = outcome("H4"),
    ntprob = outcome("H5"),
    lavi   = outcome("H6"),
    trt    = trt
  )
  p_resp <- if (trt == 1) {
    pmin(pmax(pC + Delta["H3"], 1e-8), 1 - 1e-8)
  } else {
    pC
  }
  dat$imp <- as.integer(latent[, "H3"] <= qnorm(p_resp))
  dat
}
## Normal-approximation p-value for a test statistic.
##   "two": 2 * P(Z >= |tstat|).
##   "one": P(Z >= delta_sign * tstat), i.e. small when tstat has the same
##          sign as delta_sign.
p_two_or_one <- function(tstat, sided=c("two","one"), delta_sign=1) {
  sided <- match.arg(sided)
  switch(sided,
         two = 2 * pnorm(-abs(tstat)),
         one = pnorm(-delta_sign * tstat))
}
## Per-hypothesis p-values (named by `endpts`) for one simulated trial `dt`.
## All statistics are oriented as treated (trt == 1) minus control, so their
## sign agrees with Delta. That orientation matters when the global `sided`
## is "one". Continuous endpoints use the Welch t statistic. Binary H3 uses the
## pooled two-proportion z, whose square is the uncorrected Pearson chi-square
## on the 2x2 table. All are converted with p_two_or_one().
pvals_from_data <- function(dt, Delta) {
  treated <- dt$trt == 1
  control <- dt$trt == 0
  # Welch t for treated minus control. The formula form `y ~ trt` would report
  # control minus treated, because group "0" is the first factor level.
  welch_t <- function(y) {
    t.test(y[treated], y[control], var.equal=FALSE)$statistic
  }
  # Pooled z for the difference in response proportions. A degenerate table
  # (all or no responders) carries no evidence, so it gives z = 0.
  prop_z <- function(y) {
    y1 <- y[treated]; y0 <- y[control]
    pbar <- mean(y)
    se <- sqrt(pbar * (1 - pbar) * (1 / length(y1) + 1 / length(y0)))
    if (!is.finite(se) || se == 0) return(0)
    (mean(y1) - mean(y0)) / se
  }
  cont_cols <- c(H1="kccq", H2="pvo2", H4="zscore", H5="ntprob", H6="lavi")
  pv <- setNames(rep(NA_real_,6), endpts)
  for (h in names(cont_cols)) {
    pv[h] <- p_two_or_one(welch_t(dt[[cont_cols[[h]]]]), sided=sided,
                          delta_sign=sign(Delta[h]))
  }
  pv["H3"] <- p_two_or_one(prop_z(dt$imp), sided=sided,
                           delta_sign=sign(Delta["H3"]))
  pv
}

## ===== Scenarios (16) =====
H5_fix <- list(diff=-1.0, sd=1.0)
H6_fix <- list(diff=-3.5, sd=10.0)

## 4 x 4 grid of scenarios. The H1 effect is the outer factor, H2 the inner:
## "Scenario 1".."Scenario 4" share H1 = 5.00 with H2 = 1.00, 0.75, 0.54, 0.30,
## and so on. H3–H6 are the same in every scenario.
scenarios <- local({
  h1_diffs <- c(5.00, 3.75, 2.70, 1.50)
  h2_diffs <- c(1.00, 0.75, 0.54, 0.30)
  grid <- expand.grid(h2 = h2_diffs, h1 = h1_diffs)  # first column varies fastest
  built <- Map(function(d1, d2) {
    list(H1 = list(diff = d1, sd = 15),
         H2 = list(diff = d2, sd = 3),
         H3 = 0.15,
         H4 = list(diff = 0.2, sd = 1),
         H5 = H5_fix,
         H6 = H6_fix)
  }, grid$h1, grid$h2)
  names(built) <- paste("Scenario", seq_along(built))
  built
})
stopifnot(length(scenarios) == 16)

## ===== Title/Caption helpers for scenarios =====
## Plot title: scenario name followed by each hypothesis's effect size
## (H3 is the response-probability shift).
mk_title <- function(name, sc) {
  effects <- paste0(
    c("H1=", "H2=", "H3=", "H4=", "H5=", "H6="),
    c(sc$H1$diff, sc$H2$diff, sc$H3, sc$H4$diff, sc$H5$diff, sc$H6$diff)
  )
  paste0(name, " — ", paste(effects, collapse = ", "))
}
## Plot caption listing the SDs of the continuous endpoints (H3 has none).
mk_caption <- function(sc) {
  sds <- paste0(
    c("H1=", "H2=", "H4=", "H5=", "H6="),
    c(sc$H1$sd, sc$H2$sd, sc$H4$sd, sc$H5$sd, sc$H6$sd)
  )
  paste0("SDs: ", paste(sds, collapse = ", "))
}

## ===== Simulation per scenario =====
## Estimate operating characteristics of designs D1–D6 for one scenario.
##
## Args:
##   sc: one element of `scenarios`. H1, H2, H4–H6 are list(diff, sd); H3 is
##       a scalar added to pC for the treated arm.
## Returns:
##   list(endpt = tibble per design x endpoint of P(tested) and P(rejected),
##        plot  = `sum12` plus a `power_to_plot` column,
##        sum12 = tibble per design of P(reject H1 or H2) and P(reject H1 and H2)).
## Reads globals: nsim, n_per_arm, Sigma, pC, alpha_tot, endpts, w_D1..w_D6,
## G13/G4/G5/G6. The nsim p-value vectors are generated once and replayed through
## all six designs, so the designs are compared on the same simulated trials.
simulate_designs_for_scenario <- function(sc) {
  Delta <- c(H1=sc$H1$diff, H2=sc$H2$diff, H3=sc$H3, H4=sc$H4$diff, H5=sc$H5$diff, H6=sc$H6$diff)
  SD    <- c(H1=sc$H1$sd,   H2=sc$H2$sd,                 H4=sc$H4$sd,   H5=sc$H5$sd,   H6=sc$H6$sd)

  ## Generate one p-value vector per replicate. Within each replicate the
  ## treated arm is drawn before the control arm, which fixes how the RNG
  ## stream is consumed.
  p_list <- vector("list", nsim)
  for (b in seq_len(nsim)) {
    trt <- gen_arm(1, n_per_arm, Sigma, Delta, SD, pC)
    ctl <- gen_arm(0, n_per_arm, Sigma, Delta, SD, pC)
    dt  <- rbind(trt, ctl)
    p_list[[b]] <- pvals_from_data(dt, Delta)
  }

  ## Replay the shared p-values through one design's procedure.
  ## mode "generic" -> run_graph(); mode "andgate" -> run_andgate_H6().
  ## Returns per-endpoint P(tested)/P(rejected) and the H1/H2 summary rows.
  simulate_with_p <- function(label, w_init, Gmat, mode=c("generic","andgate")) {
    mode <- match.arg(mode)
    tested_acc <- matrix(0L, nsim, 6, dimnames=list(NULL,endpts))
    reject_acc <- matrix(0L, nsim, 6, dimnames=list(NULL,endpts))
    h1_rej <- logical(nsim); h2_rej <- logical(nsim)
    for (b in seq_len(nsim)) {
      p <- p_list[[b]]
      out <- if (mode=="andgate") run_andgate_H6(p, w_init, Gmat, alpha_tot)
             else                  run_graph      (p, w_init, Gmat, alpha_tot)
      tested_acc[b, ] <- as.integer(out$tested)
      reject_acc[b, ] <- as.integer(out$rejected)
      h1_rej[b] <- out$rejected["H1"]
      h2_rej[b] <- out$rejected["H2"]
    }
    ## Column means of the 0/1 indicator matrices give per-endpoint probabilities.
    res_endpt <- tibble(endpoint=endpts,
                        prob_tested=colMeans(tested_acc),
                        power=colMeans(reject_acc),
                        design=label)
    res_sum12 <- tibble(
      design = label,
      power_either_H1_or_H2 = mean(h1_rej | h2_rej),
      power_both_H1_and_H2  = mean(h1_rej & h2_rej)
    )
    list(endpt=res_endpt, sum12=res_sum12)
  }

  ## Design table: initial weights, transition graph and procedure per design.
  parts <- list(
    simulate_with_p("D1", w_D1, G13, "generic"),
    simulate_with_p("D2", w_D2, G13, "generic"),
    simulate_with_p("D3", w_D3, G13, "generic"),
    simulate_with_p("D4", w_D4, G4,  "andgate"),
    simulate_with_p("D5", w_D5, G5,  "generic"),
    simulate_with_p("D6", w_D6, G6,  "andgate")
  )

  res_endpt  <- bind_rows(lapply(parts, `[[`, "endpt"))
  res_sum12  <- bind_rows(lapply(parts, `[[`, "sum12"))

  ## single column to plot per request:
  ## - D1–D3: P(H1 OR H2)
  ## - D4–D6: P(H1 AND H2)
  res_plot <- res_sum12 %>%
    mutate(power_to_plot = ifelse(design %in% c("D1","D2","D3"),
                                  power_either_H1_or_H2,
                                  power_both_H1_and_H2))
  list(endpt=res_endpt, plot=res_plot, sum12=res_sum12)
}

## ===== Output dir for CSVs =====
## Destination for the per-scenario CSVs. Set the SIM_OUT_DIR environment
## variable to redirect output (e.g. on non-Windows machines). When it is unset
## or empty, the original hard-coded Windows path is used.
out_dir <- Sys.getenv("SIM_OUT_DIR", unset = "")
if (!nzchar(out_dir)) out_dir <- "C:/Users/admin/Downloads"
if (!dir.exists(out_dir)) dir.create(out_dir, recursive = TRUE)

## ===== Run all scenarios and plot requested bars with scenario titles =====
for (sc_name in names(scenarios)) {
  sc <- scenarios[[sc_name]]

  ## Simulate all designs; keep bars in order of first appearance (D1..D6).
  sim <- simulate_designs_for_scenario(sc)
  res_plot <- mutate(sim$plot, design = fct_inorder(design))

  ## Persist the plotted data, e.g. "Scenario 3" -> scenario_3_H1H2_powers.csv.
  out_csv <- file.path(
    out_dir,
    paste0(gsub("\\s+", "_", tolower(sc_name)), "_H1H2_powers.csv")
  )
  write.csv(res_plot, out_csv, row.names = FALSE)

  ## Axis top: smallest of 25/50/75/100% that fits the tallest bar.
  top_cap <- axis_cap_25_50_75(max(res_plot$power_to_plot, na.rm = TRUE))

  title_txt   <- mk_title(sc_name, sc)
  caption_txt <- mk_caption(sc)

  ## One figure per scenario, assembled layer by layer.
  p <- ggplot(res_plot, aes(x = design, y = power_to_plot, fill = design))
  p <- p + geom_col(width = 0.72)
  p <- p + labs(
    title    = title_txt,
    subtitle = "D1–D3: P(reject H1 OR H2) | D4–D6: P(reject H1 AND H2) after alpha recycling",
    x        = "Design",
    y        = "Power",
    caption  = caption_txt
  )
  p <- p + scale_y_continuous(
    limits = c(0, top_cap),
    breaks = axis_breaks_5pct(top_cap),
    labels = percent_format(accuracy = 1),
    expand = expansion(mult = c(0, 0))
  )
  p <- p +
    theme_minimal(base_size = 12) +
    scale_fill_manual(values = cols) +
    theme(legend.position = "none")
  print(p)
}