library(lhs)
source("/home/boyazhang/repos/unifdist/code/ud_new.R")
source("/home/boyazhang/repos/unifdist/code/lhsbeta_design.R")

ninit <- 8


## Noisy two-dimensional test function evaluated on coded inputs.
##
## X:  matrix or data.frame with one input per row; only the first two
##     columns are used.  Each of them is rescaled by (x - 0.5)*6 + 1,
##     which maps [0, 1] onto [-2, 4], before the response is evaluated.
## sd: standard deviation of the additive Gaussian noise (sd=0 gives the
##     noise-free surface).
##
## Returns a numeric vector of length nrow(X) holding
## x1 * exp(-x1^2 - x2^2) + noise on the rescaled inputs.
f <- function(X, sd=0.01) {
  x1 <- (X[,1] - 0.5)*6 + 1
  x2 <- (X[,2] - 0.5)*6 + 1
  ## rnorm() is called even when sd=0 so the RNG stream advances exactly as
  ## before; the sum is the last expression (not an assignment), so the
  ## result is returned visibly
  x1 * exp(-x1^2 - x2^2) + rnorm(nrow(X), sd=sd)
}
## 100 x 100 evaluation grid on [0,1]^2 (one row per grid point) and the
## noise-free response on it; both are used later to score GP predictions
x1 <- seq(0, 1, length.out=100)
x2 <- x1
XX <- expand.grid(x1, x2)
ytrue <- f(XX, sd=0)


library(laGP)
## square root of machine epsilon; not referenced by name in the visible part
## of this script -- NOTE(review): may be relied on by the sourced helper
## files, confirm before removing
eps <- sqrt(.Machine$double.eps)
## Random-search maximin design in the unit hypercube.
##
## Starts from n uniform random points in [0,1]^m and, T times, proposes
## replacing one randomly chosen row by a fresh uniform draw.  A proposal is
## kept only if it strictly increases the smallest distance in the design.
##
## n:     number of design points (rows) to generate
## m:     input dimension (columns)
## T:     number of random proposals; T=0 returns the initial random design.
##        (The name shadows the built-in T alias for TRUE inside this
##        function; it is kept because callers pass it by name.)
## Xorig: optional matrix of existing points with m columns; when supplied,
##        distances from the new points to these rows also count towards the
##        minimum, so the new design keeps away from them as well.
##
## Returns the n x m design matrix.
##
## Distances come from distance(), presumably laGP's pairwise (squared
## Euclidean) distance -- only their ordering matters here.
mymaximin <- function(n, m, T=100000, Xorig=NULL) 
{   
  ## smallest distance among the rows of X and, if given, between X and Xorig
  min.dist <- function(X) {
    d <- distance(X)
    d <- as.numeric(d[upper.tri(d)])
    ## with a single row there are no pairs; min(numeric(0)) would give the
    ## same Inf but with a warning on every call
    md <- if(length(d) > 0) min(d) else Inf
    if(!is.null(Xorig)) md <- min(md, distance(X, Xorig))
    md
  }

  X <- matrix(runif(n*m), ncol=m)  ## initial design
  md <- min.dist(X)

  ## seq_len() so that T=0 performs no proposals (1:0 would run twice)
  for(t in seq_len(T)) {
    row <- sample.int(n, 1)   ## random row selection
    xold <- X[row,]
    X[row,] <- runif(m)       ## random new row
    mdprime <- min.dist(X)
    if(mdprime > md) { md <- mdprime  ## accept
    } else { X[row,] <- xold }        ## reject
  }
  X
}

## ALM criterion written for a minimiser: the negated predictive standard
## deviation (square root of the s2 component returned by predGP) of the GP
## referenced by gpi at the single input location x
obj.alm <- function(x, gpi)
{
  xx <- matrix(x, nrow=1)                ## predGP is given a one-row matrix
  pred <- predGP(gpi, xx, lite=TRUE)
  -sqrt(pred$s2)
}

## Multi-start search for the input that maximises the ALM criterion.
##
## X:   current design matrix, one point per row, inputs coded to [0,1]
## gpi: GP reference passed through to obj.alm
##
## nrow(X) starting points are generated by mymaximin so that they are spread
## out and kept away from the existing rows of X.  From each start, optim
## (L-BFGS-B, box [0,1] in every coordinate) minimises obj.alm.
##
## Returns a data.frame with one row per start: the start (s1, s2, ...), the
## local optimum found from it (x1, x2, ...) and the criterion value there
## (val, with the sign flipped back so that larger is better).  For a
## two-column X the columns are s1, s2, x1, x2, val.
alm.search <- function(X, gpi)
{
  n <- nrow(X)
  m <- ncol(X)

  ## start from maximin; the dimension follows X instead of being fixed at 2
  start <- mymaximin(n, m, T=100*n, Xorig=X)

  xnew <- matrix(NA, nrow=n, ncol=m+1)
  for(i in seq_len(n)) {
    out <- optim(start[i,], obj.alm, method="L-BFGS-B", 
                 lower=0, upper=1, gpi=gpi)
    xnew[i,] <- c(out$par, -out$value)   ## undo the sign flip in obj.alm
  }

  solns <- data.frame(cbind(start, xnew))
  names(solns) <- c(paste0("s", seq_len(m)), paste0("x", seq_len(m)), "val")
  solns
}

## Run one ALM sequential-design experiment from a given initial design.
##
## design: initial design type, one of "lhs" (randomLHS), "lhs_beta"
##         (lhs_beta_1) or "bd" (bd); the latter two use shape1 = 3 and
##         shape2 = 6
## nstop:  controls how many points are added, see the note at the loop
##
## Reads the globals ninit (initial design size), f (simulator), XX
## (prediction grid) and ytrue (noise-free response on XX).
##
## Returns a list with
##   X       final design matrix
##   gp_mean GP predictive mean on XX after the last update
##   gp_s2   GP predictive variance (s2) on XX after the last update
##   prog    best ALM criterion value found by each search
##   mle     jmleGP output, one row per fit
##   rmse    RMSE of the predictive mean against ytrue after each fit
## The last three were computed but discarded before; they are extra
## elements, X / gp_mean / gp_s2 are unchanged.
opt_alm_mean <- function(design, nstop){

  stopifnot(is.character(design), length(design) == 1)

  ## this is where the initial design is chosen (uniform LHS or a beta design)
  X <- switch(design,
    lhs = randomLHS(ninit, 2),
    lhs_beta = lhs_beta_1(n=ninit, m=2, shape1 = 3, shape2 = 6 )$X,
    bd = bd(n=ninit, m=2, shape1 = 3, shape2 = 6 )$X,
    stop("unknown design '", design,
         "'; expected \"lhs\", \"lhs_beta\" or \"bd\"", call.=FALSE))

  y <- f(X)
  gpi <- newGP(X, y, d=0.1, g=0.1*var(y), dK=TRUE)
  ## release the GP object even if a later step fails
  on.exit(deleteGP(gpi), add=TRUE)
  g <- garg(list(mle=TRUE), y)

  ## this is the crucial line here
  ## d <- darg(list(mle = TRUE, max=0.5), X)
  d <- darg(list(mle=TRUE), X)

  mle <- jmleGP(gpi, c(d$min, d$max), c(g$min, g$max), d$ab, g$ab)
  rmse <- sqrt(mean((ytrue - predGP(gpi, XX, lite=TRUE)$mean)^2))

  ## initial search; only its best value is recorded.  It is kept (although
  ## the first loop iteration searches again) so that the random number
  ## stream, and therefore the results, match earlier runs.
  solns <- alm.search(X, gpi)
  m <- which.max(solns$val)
  prog <- solns$val[m]

  ## NOTE(review): this adds nstop - nrow(X) + 1 points, so the final design
  ## has nstop + 1 rows (e.g. 33 for ninit = 8, nstop = 32).  Kept as is so
  ## existing results do not change -- confirm whether nstop rows were meant.
  ## When nstop < nrow(X) nothing is added (nrow(X):nstop would count down).
  nadd <- max(0, nstop - nrow(X) + 1)
  for(i in seq_len(nadd)) {
    solns <- alm.search(X, gpi)
    m <- which.max(solns$val); prog <- c(prog, solns$val[m])
    xnew <- as.matrix(solns[m,3:4])
    X <- rbind(X, xnew); y <- c(y, f(xnew))
    updateGP(gpi, xnew, y[length(y)])
    mle <- rbind(mle, jmleGP(gpi, c(d$min, d$max), c(g$min, g$max), d$ab, g$ab))
    rmse <- c(rmse, sqrt(mean((ytrue - predGP(gpi, XX, lite=TRUE)$mean)^2)))
  }

  ## one prediction on the grid supplies both the mean and the variance
  pgrid <- predGP(gpi, XX, lite=TRUE)
  list(X = X, gp_mean = pgrid$mean, gp_s2 = pgrid$s2,
       prog = prog, mle = mle, rmse = rmse)
}

## Monte Carlo comparison: for every replicate, record the smallest and the
## largest value of the final GP predictive mean on the grid.
## First the "lhs" initial design.
reps <- 100
lhs_min <- lhs_max <- rep(NA, reps)

Sys.time()
for(i in seq_len(reps)){
  p <- opt_alm_mean(design="lhs", nstop=32)
  lhs_min[i] <- min(p[["gp_mean"]])
  lhs_max[i] <- max(p[["gp_mean"]])
}

## Same replicated experiment, starting from the "lhs_beta" initial design
lhs_beta_min <- lhs_beta_max <- rep(NA, reps)

for(i in seq_len(reps)){
  p <- opt_alm_mean(design="lhs_beta", nstop=32)
  lhs_beta_min[i] <- min(p[["gp_mean"]])
  lhs_beta_max[i] <- max(p[["gp_mean"]])
}

## Same replicated experiment, starting from the "bd" initial design
bd_min <- bd_max <- rep(NA, reps)

for(i in seq_len(reps)){
  p <- opt_alm_mean(design="bd", nstop=32)
  bd_min[i] <- min(p[["gp_mean"]])
  bd_max[i] <- max(p[["gp_mean"]])
}
Sys.time()

## Store the whole workspace, then read it back (the load lets the plotting
## and testing code below be run on its own from the saved results)
save.image("/home/boyazhang/repos/unifdist/code/alm_example.RData")
load("/home/boyazhang/repos/unifdist/code/alm_example.RData")

## side-by-side boxplots of the per-replicate extremes of the GP mean
par(mfcol = c(1, 2))
boxplot(list(lhs_max, lhs_beta_max, bd_max),
        names = c("lhs", "lhs-beta", "beta"),
        main = "global maximum, 8->32, 100 reps")
boxplot(list(lhs_min, lhs_beta_min, bd_min),
        names = c("lhs", "lhs-beta", "beta"),
        main = "global minimum, 8->32, 100 reps")

## One-sided two-sample t-tests of the uniform LHS start against each beta
## start.  "less": the mean of the first sample is smaller than that of the
## second; "greater": it is larger.
## The commented values are p-values recorded from an earlier run; no
## set.seed() call is visible in this script, so a rerun will give different
## numbers.

## lhs vs. bd
t.test(x = lhs_max, y = bd_max, alternative = "less")[["p.value"]]
## [1] 0.0006561552
t.test(x = lhs_min, y = bd_min, alternative = "greater")[["p.value"]]
## [1] 0.0006917197

## lhs vs. lhs_beta
t.test(x = lhs_max, y = lhs_beta_max, alternative = "less")[["p.value"]]
## [1] 0.08174546
t.test(x = lhs_min, y = lhs_beta_min, alternative = "greater")[["p.value"]]
## [1] 0.305438