# Load necessary library
library(survival)
## Warning: package 'survival' was built under R version 4.2.3
# Piecewise exponential hazard model with two periods:
#   Period 1 = Week 0-12 (12 weeks), Period 2 = Week 12-76 (64 weeks).
#
# The hazards are chosen so that, in each arm, the cumulative hazard over both
# periods equals -log(event-free probability at Week 76):
#   12 * haz.1 + 64 * haz.2 = -log(p)
#
# Arguments:
#   p.end.ref  Week-76 event-free probability in the control arm.
#   p.delta    Difference (experimental - control) in that probability.
#   hr         Ratio of the cumulative hazard accrued in Period 2 to the
#              cumulative hazard accrued in Period 1 (64 * haz.2 / (12 * haz.1)).
#              Note this is a between-period ratio, not a between-arm hazard
#              ratio. Defaults to the previously hard-coded 2.2.
#
# Returns an unnamed numeric vector of weekly hazards in the order
#   c(control P1, experimental P1, control P2, experimental P2).
haz.model <- function(p.end.ref = 0.80, p.delta = 0.02, hr = 2.2) {
  weeks.p1 <- 12
  weeks.p2 <- 64

  # Week-76 event-free probabilities: control first, experimental second
  p.end <- c(p.end.ref, p.end.ref + p.delta)

  # A share 1 / (1 + hr) of the total cumulative hazard falls in Period 1 and
  # hr / (1 + hr) in Period 2; dividing by the period length gives the rate.
  haz.p1 <- -log(p.end) / (weeks.p1 * (1 + hr))
  haz.p2 <- -hr * log(p.end) / (weeks.p2 * (1 + hr))

  c(haz.p1, haz.p2)
}

# Sample size for a two-arm non-inferiority comparison of two proportions.
#
# Arguments:
#   pA, pB  Assumed proportions in groups A and B.
#   delta   Non-inferiority margin (subtracted from pA - pB).
#   kappa   Divisor applied to group A's variance term (allocation ratio).
#   alpha   One-sided significance level.
#   beta    Type II error rate (power = 1 - beta).
#
# Returns the computed size for group B, rounded up to a whole number.
MyPass.N <- function(pA = 0.04, pB = 0.04, delta = -0.04, kappa = 1, alpha = 0.025, beta = 0.10) {
  variance.sum <- pA * (1 - pA) / kappa + pB * (1 - pB)
  z.total <- qnorm(1 - alpha) + qnorm(1 - beta)
  # Distance between the assumed difference and the margin
  effect <- pA - pB - delta
  size.b <- variance.sum * (z.total / effect)^2
  ceiling(size.b)
}

# Simulate one trial replicate with an unblinded interim analysis that may
# increase the sample size.
#
# Steps:
#   1. Draw piecewise-exponential event times for `n` subjects per arm and map
#      them onto the visit schedule.
#   2. Build the interim data set: `ncompl` subjects per arm have reached
#      Week 76, the others are censored at a randomly drawn last visit.
#   3. Estimate the event-free probability per arm by Kaplan-Meier and
#      re-estimate the total sample size from those estimates.
#   4. Complete the trial to the planned size (p.value.0) and, when both
#      interim estimates are at or below `cut.off`, to the re-estimated size
#      (p.value.1).
#
# Arguments:
#   p.end.ref  True Week-76 event-free probability, control arm.
#   p.delta    True difference (experimental - control).
#   n          Subjects per arm included in the interim analysis.
#   ncompl     Subjects per arm with complete Week-76 follow-up at interim.
#   n.tot      Planned subjects per arm at the final analysis.
#   cut.off    Sample size may only be increased when both interim
#              Kaplan-Meier estimates are <= cut.off.
#   recruit    Probabilities (one per visit before Week 76, i.e. 18 values)
#              for the last visit reached by subjects still ongoing at interim.
#   up.cap     Upper limit for the re-estimated total sample size.
#
# Returns list(p.value.0, p.value.1): one-sided non-inferiority p-values at the
# planned and at the (possibly) increased sample size.
#
# The order of random draws is deliberately kept as rexp (control P1,
# experimental P1, control P2, experimental P2), rmultinom (control,
# experimental), rbinom (control, experimental) so seeded runs stay comparable.
sim.data.unblinded <- function(p.end.ref, p.delta, n, ncompl, n.tot = 172, cut.off = 0.76, recruit, up.cap = 548) {
  vst <- c(1, 2, 3, 4, 8, 12, 16, 20, 24, 28, 32, 36, 40, 44, 48, 52, 60, 68, 76)
  wk.switch <- 12               # end of hazard Period 1
  wk.end <- vst[length(vst)]    # final visit (Week 76)
  ni.margin <- -0.10            # non-inferiority margin on the difference scale
  n.planned <- 2 * n.tot        # planned total over both arms (344 for n.tot = 172)

  # First scheduled visit at or after the event week; NA if after the last visit
  first.visit.after <- function(y) {
    vst[findInterval(y, vst, left.open = TRUE) + 1]
  }

  # Event indicator and observed time given the last visit attended
  censor.at <- function(event.time, last.visit) {
    status <- as.numeric(!is.na(event.time) & event.time <= last.visit)
    list(status = status, time = ifelse(status == 1, event.time, last.visit))
  }

  # One-sided NI test of H0: p.1 - p.0 <= ni.margin with unpooled variance
  ni.p.value <- function(resp.0, resp.1, n.arm) {
    p.0 <- mean(resp.0)
    p.1 <- mean(resp.1)
    vrs <- p.0 * (1 - p.0) + p.1 * (1 - p.1)
    z <- (p.1 - p.0 - ni.margin) / sqrt(vrs / n.arm)
    1 - pnorm(z)
  }

  hz <- haz.model(p.end.ref = p.end.ref, p.delta = p.delta)

  # Simulate event times (whole weeks), one candidate draw per period
  y.0.1 <- ceiling(rexp(n, rate = hz[1]))
  y.1.1 <- ceiling(rexp(n, rate = hz[2]))
  y.0.2 <- ceiling(rexp(n, rate = hz[3]))
  y.1.2 <- ceiling(rexp(n, rate = hz[4]))

  # Use the Period-1 draw if it falls in Period 1, otherwise restart the clock
  y.0 <- ifelse(y.0.1 <= wk.switch, y.0.1, wk.switch + y.0.2)
  y.1 <- ifelse(y.1.1 <= wk.switch, y.1.1, wk.switch + y.1.2)

  event.time.0 <- first.visit.after(y.0)
  event.time.1 <- first.visit.after(y.1)

  # Interim analysis: draw the last visit reached by each ongoing subject
  visits.ongoing <- vst[-length(vst)]
  tmp.0 <- rmultinom(n - ncompl, size = 1, prob = recruit)
  tmp.1 <- rmultinom(n - ncompl, size = 1, prob = recruit)
  last.visit.0 <- c(as.vector(visits.ongoing %*% tmp.0), rep(wk.end, ncompl))
  last.visit.1 <- c(as.vector(visits.ongoing %*% tmp.1), rep(wk.end, ncompl))

  interim.0 <- censor.at(event.time.0, last.visit.0)
  interim.1 <- censor.at(event.time.1, last.visit.1)

  # Unblinded Kaplan-Meier analysis: lowest value of each survival curve
  km.0 <- survfit(Surv(interim.0$time, interim.0$status) ~ 1)
  km.1 <- survfit(Surv(interim.1$time, interim.1$status) ~ 1)
  est.p.0 <- min(km.0$surv)
  est.p.1 <- min(km.1$surv)

  # Sample size re-estimation from the unblinded estimates.
  # NOTE(review): MyPass.N's formula computes a single-group size (nB) and is
  # called with the control estimate as pA, while NewN is used below as a
  # two-arm total -- confirm this is intended.
  NewN <- MyPass.N(pA = est.p.0, pB = est.p.1, delta = ni.margin, kappa = 1, alpha = 0.025, beta = 0.20)
  allow.increase <- est.p.0 <= cut.off && est.p.1 <= cut.off
  if (!(allow.increase && NewN > n.planned)) {
    NewN <- n.planned
  }
  NewN <- min(NewN, up.cap)

  # Final binary outcome (1 = event-free through Week 76): interim subjects use
  # their simulated event time, later subjects are drawn directly.
  resp.0 <- c(as.numeric(y.0 > wk.end), rbinom(n.tot - n, 1, p.end.ref))
  resp.1 <- c(as.numeric(y.1 > wk.end), rbinom(n.tot - n, 1, p.end.ref + p.delta))

  # Analysis at the planned sample size
  p.value.0 <- ni.p.value(resp.0, resp.1, n.tot)

  # Analysis at the re-estimated sample size
  if (allow.increase) {
    # NOTE(review): NewN can be odd, making NewN / 2 fractional -- confirm the
    # intended rounding of the per-arm size.
    n.extra <- NewN / 2 - n.tot
    resp.0 <- c(resp.0, rbinom(n.extra, 1, p.end.ref))
    resp.1 <- c(resp.1, rbinom(n.extra, 1, p.end.ref + p.delta))
    p.value.1 <- ni.p.value(resp.0, resp.1, NewN / 2)
  } else {
    p.value.1 <- p.value.0
  }

  list(p.value.0 = p.value.0, p.value.1 = p.value.1)
}

# Run `nrep` replicates of sim.data.unblinded() and report the proportion of
# replicates whose one-sided p-value falls below 0.025, separately for the
# planned-size analysis and the increased-size analysis.
#
# Arguments: `nrep` is the number of replicates; all other arguments are passed
# straight through to sim.data.unblinded() (`p.ref.end` is passed as its
# `p.end.ref`).
#
# Returns a named numeric vector c(TypeI_Error_Initial, TypeI_Error_Increased).
#
# NOTE(review): these proportions are type I error rates only when `p.delta`
# places the truth on the null boundary of the test in sim.data.unblinded()
# (margin -0.10). For p.delta = 0 they are rejection rates under
# non-inferiority, i.e. power -- confirm the intended scenario.
sim.trials.unblinded.type1 <- function(nrep = 1000, p.ref.end, p.delta, n, ncompl, n.tot, cut.off, recruit, up.cap) {
  run.one <- function(i) {
    sim.data.unblinded(
      p.end.ref = p.ref.end, p.delta = p.delta, n = n, ncompl = ncompl,
      n.tot = n.tot, cut.off = cut.off, recruit = recruit, up.cap = up.cap
    )
  }
  replicates <- lapply(seq_len(nrep), run.one)

  # Share of replicates rejecting at the one-sided 2.5% level
  rejection.rate <- function(field) {
    p.values <- vapply(replicates, function(r) r[[field]], numeric(1))
    mean(p.values < 0.025)
  }

  c(
    TypeI_Error_Initial = rejection.rate("p.value.0"),
    TypeI_Error_Increased = rejection.rate("p.value.1")
  )
}

# Example: run the unblinded-interim simulation over a grid of control-arm
# Week-76 event-free probabilities and print the rejection rates.
# All scenarios share p.delta = 0, n.tot = 172, cut.off = 0.76, up.cap = 548 and
# the same last-visit distribution; only the 0.75 scenario uses a larger
# interim (n = 120, ncompl = 30).
# NOTE(review): with p.delta = 0 the printed "Type I error" values are rejection
# rates under equal arms, not at the -0.10 margin -- confirm the labelling.
assumptions <- local({
  recruit.probs <- c(
    0, 0, 0, 0.1, 0.1, 0.08, 0.07, 0.07, 0.07,
    0.07, 0.06, 0.05, 0.05, 0.04, 0.05, 0.04, 0.11, 0.04
  )
  p.ref.grid <- c(0.50, 0.55, 0.60, 0.65, 0.70, 0.75, 0.80, 0.85, 0.90)
  n.interim <- c(100, 100, 100, 100, 100, 120, 100, 100, 100)
  n.complete <- c(20, 20, 20, 20, 20, 30, 20, 20, 20)

  lapply(seq_along(p.ref.grid), function(i) {
    list(
      p.ref.end = p.ref.grid[i], p.delta = 0,
      n = n.interim[i], ncompl = n.complete[i],
      n.tot = 172, cut.off = 0.76, recruit = recruit.probs, up.cap = 548
    )
  })
})

for (assumption in assumptions) {
  # Each scenario's element names match the simulator's argument names
  res <- do.call(sim.trials.unblinded.type1, c(list(nrep = 1000), assumption))
  print(paste("Type I error for p.ref.end =", assumption$p.ref.end, ":"))
  print(res)
}
## [1] "Type I error for p.ref.end = 0.5 :"
##   TypeI_Error_Initial TypeI_Error_Increased 
##                 0.473                 0.520 
## [1] "Type I error for p.ref.end = 0.55 :"
##   TypeI_Error_Initial TypeI_Error_Increased 
##                 0.460                 0.511 
## [1] "Type I error for p.ref.end = 0.6 :"
##   TypeI_Error_Initial TypeI_Error_Increased 
##                 0.488                 0.540 
## [1] "Type I error for p.ref.end = 0.65 :"
##   TypeI_Error_Initial TypeI_Error_Increased 
##                 0.512                 0.566 
## [1] "Type I error for p.ref.end = 0.7 :"
##   TypeI_Error_Initial TypeI_Error_Increased 
##                 0.537                 0.569 
## [1] "Type I error for p.ref.end = 0.75 :"
##   TypeI_Error_Initial TypeI_Error_Increased 
##                 0.586                 0.604 
## [1] "Type I error for p.ref.end = 0.8 :"
##   TypeI_Error_Initial TypeI_Error_Increased 
##                 0.653                 0.657 
## [1] "Type I error for p.ref.end = 0.85 :"
##   TypeI_Error_Initial TypeI_Error_Increased 
##                  0.73                  0.73 
## [1] "Type I error for p.ref.end = 0.9 :"
##   TypeI_Error_Initial TypeI_Error_Increased 
##                 0.858                 0.858