## directed ud function
#' Directed search for a design whose pairwise distances look uniform
#'
#' Starts from `n` points drawn uniformly in [0, 1]^dim. Each outer iteration
#' locates the pairwise (Euclidean) distance at which the empirical CDF of the
#' design's distances deviates most from the empirical CDF of an evenly spaced
#' grid on [.Machine$double.eps, dmax * sqrt(dim)]. The pair of points behind
#' that distance is then redrawn uniformly at random, up to `max_prop` times,
#' until the two-sample Kolmogorov-Smirnov statistic against the grid
#' decreases. Rejected proposals are undone.
#'
#' @param n Number of points (rows of the design); must be at least 2.
#' @param dim Dimension of each point (columns of the design).
#' @param dmax Scale for the upper end of the target distance grid,
#'   which is `dmax * sqrt(dim)`.
#' @param T Number of outer iterations. (Named `T` for backward
#'   compatibility; inside this function it shadows the `T` alias of TRUE.)
#' @param n_du Number of points in the target distance grid.
#' @param max_prop Maximum number of proposals per outer iteration.
#' @return A list with `x` (final n x dim design), `ksopt` (KS statistic of
#'   the final design), `ksprog` (KS statistic at the start of each outer
#'   iteration, length `T`) and `ind` (total number of proposals made).
ud_direct <- function(n, dim, dmax = 1, T = 1000,
                      n_du = 1000, max_prop = 50) {
  # dist() of a single point is empty and ks.test() would fail on it.
  stopifnot(n >= 2)

  x <- matrix(runif(dim * n), ncol = dim)
  du <- seq(from = .Machine$double.eps, to = dmax * sqrt(dim),
            length.out = n_du)
  # The target ECDF never changes, so build it once, outside the loop.
  ecdf2 <- ecdf(du)
  ks_of <- function(design) ks.test(dist(design), du)$statistic

  ind <- 0              # total number of proposals made
  ksprog <- rep(NA, T)  # KS statistic at the start of each outer iteration
  # An accepted move sets ksold to the KS value of the new x, and a rejected
  # move restores x. So ksold always matches the current design and only has
  # to be computed once, up front.
  ksold <- ks_of(x)

  for (j in seq_len(T)) {
    ksprog[j] <- ksold
    distance1 <- dist(x)
    ecdf1 <- ecdf(distance1)
    dist_diff <- abs(ecdf1(distance1) - ecdf2(distance1))
    # Translate the position of the worst distance into its (row, col) pair.
    ppair <- loc_search(n, which.max(dist_diff))

    for (i in seq_len(max_prop)) {
      ind <- ind + 1
      # Redrawing both points of the pair worked better than redrawing one.
      xold <- x[ppair, ]
      x[ppair, ] <- matrix(runif(dim * 2), ncol = dim)
      ksnew <- ks_of(x)
      if (ksnew < ksold) {
        ksold <- ksnew
        break
      }
      x[ppair, ] <- xold
    }
  }

  list(x = x, ksopt = ksold, ksprog = ksprog, ind = ind)
}


# Map a position in a dist() vector back to its (row, column) point pair.
#
# dist() stores the strict lower triangle of the n x n distance matrix column
# by column: column j holds the n - j entries for rows j + 1, ..., n.
#
# n            number of points (>= 2)
# dist_max_loc 1-based position in the dist() vector, in 1..n*(n-1)/2
# returns      c(row, col) with row > col; errors if the position is invalid
loc_search <- function(n, dist_max_loc) {
  stopifnot(length(n) == 1, n >= 2,
            length(dist_max_loc) == 1,
            dist_max_loc >= 1, dist_max_loc <= n * (n - 1) / 2)
  # col_end[j] is the position of the last entry (row n) of column j.
  col_end <- cumsum((n - 1):1)
  j <- which(col_end >= dist_max_loc)[1]
  c(n - (col_end[j] - dist_max_loc), j)
}
# Compare ud() (loaded from ud.R) with ud_direct() on three problem sizes:
# n = 8, 16, 32 points in dim = 2, 3, 4. Each configuration is run three times
# per method. system.time() reports the total time of the three runs. ksopt is
# the final KS statistic (named "D" in the output); ud_direct() only accepts
# moves that lower it, so smaller is better. ud_direct() uses T = 2000, which
# gives wall times comparable to ud() in the runs recorded below. In those runs
# ud() reached the lower ksopt for every size.
# Lines starting with "##" are captured console output.
# NOTE(review): the paths below are absolute and machine-specific. setwd() may
# be needed if ud.R sources other files relative to this directory -- confirm.
setwd("/home/boyazhang/repos/unifdist/code")
# Presumably defines ud(); the messages below are printed while loading it.
source("/home/boyazhang/repos/unifdist/code/ud.R")
## Loading required package: mvtnorm
## Loading required package: tgp
## 
## Attaching package: 'EnvStats'
## The following objects are masked from 'package:stats':
## 
##     predict, predict.lm
## The following object is masked from 'package:base':
## 
##     print.default
# n = 8 points, dim = 2
system.time({test11 <- ud(8,2,1); test12 <- ud(8,2,1); test13 <- ud(8,2,1)})
##    user  system elapsed 
##  61.266   0.000  61.267
system.time({test21 <- ud_direct(8,2,T=2000); test22 <- ud_direct(8,2,T=2000); test23 <- ud_direct(8,2,T=2000)})
##    user  system elapsed 
##  69.567   0.003  69.579
# ksopt of ud and ksopt of ud_direct
c(test11$ksopt,test12$ksopt,test13$ksopt) 
##          D          D          D 
## 0.03600000 0.04800000 0.03771429
c(test21$ksopt,test22$ksopt,test23$ksopt)
##          D          D          D 
## 0.08357143 0.06485714 0.07928571
# n = 16 points, dim = 3
system.time({test11 <- ud(16,3,1); test12 <- ud(16,3,1); test13 <- ud(16,3,1)})
##    user  system elapsed 
##  67.064   0.000  67.065
system.time({test21 <- ud_direct(16,3,T=2000); test22 <- ud_direct(16,3,T=2000); test23 <- ud_direct(16,3,T=2000)})
##    user  system elapsed 
##  71.726   0.000  71.730
# ksopt of ud and ksopt of ud_direct
c(test11$ksopt,test12$ksopt,test13$ksopt) 
##          D          D          D 
## 0.03400000 0.04400000 0.04266667
c(test21$ksopt,test22$ksopt,test23$ksopt)
##         D         D         D 
## 0.1293333 0.1783333 0.1486667
# n = 32 points, dim = 4
system.time({test11 <- ud(32,4,1); test12 <- ud(32,4,1); test13 <- ud(32,4,1)})
##    user  system elapsed 
##  88.387   0.000  88.389
system.time({test21 <- ud_direct(32,4,T=2000); test22 <- ud_direct(32,4,T=2000); test23 <- ud_direct(32,4,T=2000)})
##    user  system elapsed 
##  93.141   0.000  93.142
# ksopt of ud and ksopt of ud_direct
c(test11$ksopt,test12$ksopt,test13$ksopt) 
##          D          D          D 
## 0.07596774 0.04493548 0.07598387
c(test21$ksopt,test22$ksopt,test23$ksopt)
##         D         D         D 
## 0.1590323 0.2865323 0.2247419