# Coefficient of determination (R-squared) of predicted against actual
# outcome values: 1 - SS_residual / SS_total.
# The result is not clamped: predictions worse than the mean of `true` give a
# negative value. When `true` is constant (SS_total = 0) the result is NaN
# (perfect predictions) or -Inf (any residual).
# `predicted` may be a vector or a one-column matrix (e.g. predict() output).
rsquare <- function(true, predicted) {
  residual_ss <- sum((true - predicted)^2)
  total_ss <- sum((true - mean(true))^2)
  1 - residual_ss / total_ss
}


# Optimise an elastic-net linear regression with repeated k-fold
# cross-validation: tune alpha on validation sets, score on test sets.
#
# df      data.frame: column 1 holds unique row IDs, column 2 the numeric
#         outcome, columns 3+ the predictors.
# n_runs  number of times the whole k-fold procedure is repeated.
# cv_n    number of folds per run (>= 2). Each fold holds out
#         `portion` = round(nrow(df) / cv_n) rows (the last fold takes every
#         remaining row); the held-out rows are split into a test half and a
#         validation half, so `portion` is the test + validation size combined.
# alphas  elastic-net mixing values tried in every fold (default 0, 0.05, ..., 1).
# thresh  fraction of folds a predictor must be selected in to count towards
#         num.meetingthreshfreq in the summary.
#
# In every fold, cv.glmnet is fitted on the training rows for each alpha; the
# alpha whose lambda.1se model has the lowest mean absolute error on the
# validation rows is kept, and its R-squared on the test rows is recorded.
# Missing predictor values are imputed with training-set column means
# (rounded for columns with exactly two distinct observed values).
#
# Returns a list:
#   Bs      one row per predictor: chosen (count of folds with a non-zero
#           coefficient), chosenpc (same, in percent), mean.B, abs.B,
#           cilo/cihi (2.5%/97.5% quantiles across folds), one coefficient
#           column per fold named "a<alpha>run<fold counter>", sign (values
#           for aligning text to the y axis) and sd.
#   log     per-fold data.frame (fold, testn, test.rsq, lambda, alpha).
#   summary one-row data.frame of headline results (also printed).
#   Rsq     mean test-set R-squared over all folds, in percent.
#
# Side effects: calls set.seed(100), replacing the global RNG state, and
# prints the elapsed time and the transposed summary.
# Needs cv.glmnet() and makeX() (glmnet), move_columns() (presumably sjmisc)
# and rsquare() from this file.
linearEN <- function(df, n_runs, cv_n,
                     alphas = seq(from = 0, to = 1, by = 0.05),
                     thresh = 0.95) {
  n_rows <- nrow(df)
  n_folds <- n_runs * cv_n
  # rows held out per fold (validation + test combined)
  portion <- round(n_rows / cv_n)
  if (cv_n < 2 || portion < 2 || n_rows - portion * (cv_n - 1) < 2) {
    stop("cv_n = ", cv_n, " does not suit ", n_rows, " rows: need cv_n >= 2 ",
         "and at least 2 held-out rows in every fold.", call. = FALSE)
  }
  pred_cols <- 3:ncol(df)

  # sample() without its length-one trap (sample(7, 1) draws from 1:7);
  # identical to sample(x, size) whenever length(x) > 1
  resample <- function(x, size) x[sample.int(length(x), size)]

  # training-set column means used for imputation; columns with exactly two
  # distinct observed values are treated as binary and get the rounded mean
  train_fill <- function(x_train) {
    fill <- colMeans(x_train, na.rm = TRUE)
    binary <- vapply(seq_len(ncol(x_train)),
                     function(j) length(unique(na.omit(x_train[, j]))) == 2,
                     logical(1))
    fill[binary] <- round(fill[binary])
    fill
  }

  # replace each NA in matrix `m` with the `fill` value of its column
  fill_na <- function(m, fill) {
    na_idx <- which(is.na(m), arr.ind = TRUE)
    if (nrow(na_idx) > 0) m[na_idx] <- fill[na_idx[, "col"]]
    m
  }

  # per-predictor bookkeeping; the intercept gets its own row
  Bs <- data.frame(predictor = c(names(df)[-(1:2)], "(Intercept)"))
  Bs$chosen <- 0  # number of folds in which the coefficient was non-zero
  coef_mat <- matrix(NA_real_, nrow = nrow(Bs), ncol = n_folds)
  coef_names <- character(n_folds)
  err_sum <- numeric(n_rows)  # summed |validation error| per row of df
  err_n <- integer(n_rows)    # number of folds each row was validated in
  rsq.test <- rep(NA_real_, n_folds)
  rsq.test.CV <- rep(NA_real_, n_runs)
  grandcvlog <- vector("list", n_runs)
  runcounter <- 0

  start.time <- Sys.time()
  set.seed(100)
  for (x in seq_len(n_runs)) {
    cvlog <- data.frame(fold = rep(NA, cv_n), testn = rep(NA, cv_n),
                        test.rsq = rep(NA, cv_n), lambda = rep(NA, cv_n),
                        alpha = rep(NA, cv_n))
    unused <- seq_len(n_rows)  # rows not yet held out in this run
    for (fold in seq_len(cv_n)) {
      runcounter <- runcounter + 1
      # hold out one fold (the last fold takes every remaining row) and split
      # it into a test half and a validation half
      n_hold <- if (fold == cv_n) length(unused) else portion
      testvalrows <- resample(unused, n_hold)
      unused <- setdiff(unused, testvalrows)
      testrows <- resample(testvalrows, round(length(testvalrows) / 2))
      valrows <- setdiff(testvalrows, testrows)

      y.train <- df[-testvalrows, 2]
      y.val <- df[valrows, 2]
      y.test <- df[testrows, 2]
      # impute all three sets with training-set means only (no leakage)
      x.train <- makeX(df[-testvalrows, pred_cols, drop = FALSE])
      fill <- train_fill(x.train)
      x.train <- fill_na(x.train, fill)
      x.val <- fill_na(makeX(df[valrows, pred_cols, drop = FALSE]), fill)
      x.test <- fill_na(makeX(df[testrows, pred_cols, drop = FALSE]), fill)

      # one cross-validated fit per alpha on the training rows
      fits <- lapply(alphas, function(a) {
        cv.glmnet(x.train, y.train, type.measure = "mae", alpha = a,
                  family = "gaussian")
      })
      # choose alpha by the validation-set MAE of each lambda.1se model
      val_pred <- lapply(fits, function(f) {
        predict(f, s = f$lambda.1se, newx = x.val)
      })
      val_mae <- vapply(val_pred, function(p) mean(abs(y.val - p)),
                        numeric(1))
      best <- which.min(val_mae)  # a single index even when MAEs tie
      best.alpha <- alphas[best]
      best.fit <- fits[[best]]
      best.lambda <- best.fit$lambda.1se

      # accumulate each validation subject's absolute error
      abs_err <- abs(y.val - val_pred[[best]][, 1])
      err_sum[valrows] <- err_sum[valrows] + abs_err
      err_n[valrows] <- err_n[valrows] + 1L

      # record which coefficients are non-zero in this fold, and their values;
      # positional [, 1] avoids depending on glmnet's coefficient column name
      coefs <- as.matrix(coef(best.fit, s = "lambda.1se"))[, 1]
      nonzero <- coefs[abs(coefs) > 0]
      hit <- match(names(nonzero), Bs$predictor)
      found <- !is.na(hit)  # names with no row in Bs are skipped
      Bs$chosen[hit[found]] <- Bs$chosen[hit[found]] + 1
      coef_mat[hit[found], runcounter] <- nonzero[found]
      # runcounter keeps the column name unique across runs and folds
      coef_names[runcounter] <- paste0("a", best.alpha, "run", runcounter)

      # score the chosen model on the unseen test rows
      pred.test <- predict(best.fit, s = "lambda.1se", newx = x.test)
      test_rsq <- rsquare(y.test, pred.test)
      rsq.test[runcounter] <- test_rsq

      cvlog$fold[fold] <- fold
      cvlog$testn[fold] <- length(testrows)
      cvlog$test.rsq[fold] <- test_rsq
      cvlog$lambda[fold] <- best.lambda
      cvlog$alpha[fold] <- best.alpha
    }
    # mean test R-squared over this run's folds only
    rsq.test.CV[x] <- mean(cvlog$test.rsq)
    grandcvlog[[x]] <- cvlog
  }
  grandcvlog <- do.call(rbind, grandcvlog)

  fin.time <- Sys.time()
  print(fin.time - start.time)

  # mean coefficient over the folds in which it was non-zero
  Bs$mean.B <- rowMeans(coef_mat, na.rm = TRUE)
  is_icpt <- Bs$predictor == "(Intercept)"
  subj_mae <- ifelse(err_n > 0, err_sum / err_n, NA_real_)

  res_summary <- round(data.frame(
    sample_n = n_rows,
    num.runs = n_folds,
    num.inputs = length(df) - 2,
    # predictors (excluding the intercept) selected in >= thresh of folds
    num.meetingthreshfreq = sum(Bs$chosen[!is_icpt] >= n_folds * thresh),
    intercept = Bs$mean.B[is_icpt],
    mean.rsq = mean(rsq.test) * 100,
    sd.rsq = sd(rsq.test) * 100,
    rsq.CIlo = quantile(rsq.test, probs = 0.025) * 100,
    rsq.CIhi = quantile(rsq.test, probs = 0.975) * 100,
    CVmean.ofrsqs = mean(rsq.test.CV) * 100,
    CVrsq.CIlo = quantile(rsq.test.CV, probs = 0.025) * 100,
    CVrsq.CIhi = quantile(rsq.test.CV, probs = 0.975) * 100,
    error.mean = mean(subj_mae, na.rm = TRUE),
    error.CIlo = as.numeric(quantile(subj_mae, probs = 0.025, na.rm = TRUE)),
    error.CIhi = as.numeric(quantile(subj_mae, probs = 0.975, na.rm = TRUE))
  ), digits = 2)
  rownames(res_summary) <- c()
  print(t(res_summary))

  # per-predictor summaries across folds
  Bs$chosenpc <- round(Bs$chosen / runcounter * 100)  # selection frequency %
  Bs$abs.B <- abs(Bs$mean.B)
  Bs$sign <- sign(Bs$mean.B)
  Bs$sign[which(Bs$sign == 1)] <- 1.05    # values for aligning text to y axis
  Bs$sign[which(Bs$sign == -1)] <- -0.05
  Bs$cilo <- apply(coef_mat, 1, function(b) {
    unname(quantile(b, probs = 0.025, na.rm = TRUE))
  })
  Bs$cihi <- apply(coef_mat, 1, function(b) {
    unname(quantile(b, probs = 0.975, na.rm = TRUE))
  })
  Bs$sd <- apply(coef_mat, 1, sd, na.rm = TRUE)

  # plotting columns first, then the per-fold coefficients
  run_cols <- setNames(as.data.frame(coef_mat), coef_names)
  Bs <- cbind(Bs[, c("predictor", "chosen", "chosenpc", "mean.B", "abs.B",
                     "cilo", "cihi")],
              run_cols,
              Bs[, c("sign", "sd")])
  Bs <- Bs[!is_icpt, ]
  rownames(Bs) <- Bs$predictor
  # NOTE(review): no "t_est" column is created in this function; what this
  # does with a missing .after target depends on move_columns() — confirm.
  Bs <- move_columns(Bs, sign, .after = "t_est")

  list("Bs" = Bs, "log" = grandcvlog, "summary" = res_summary,
       "Rsq" = mean(rsq.test) * 100)
}

# *******************
# Optimise an elastic-net logistic regression (binary 0/1 outcome) with
# repeated k-fold cross-validation on case/control-balanced subsamples.
#
# df      data.frame: column 1 holds unique row IDs, column 2 the 0/1 outcome
#         (1 = case), columns 3+ the predictors.
# n_runs  number of repeats. Each run keeps every case and draws an equal
#         number of controls at random ("balanced"), so accuracy is not
#         biased by class imbalance.
# cv_n    number of folds per run (>= 2). Each fold holds out
#         `portion` = round(nrow(balanced data) / cv_n) rows (the last fold
#         takes the remainder), split into a test half and a validation half,
#         so `portion` is the test + validation size combined.
# alphas  elastic-net mixing values tried in every fold.
# thresh  fraction of folds a predictor must be selected in to count towards
#         num.meetingthreshfreq in the summary.
#
# For every alpha, cv.glmnet (type.measure = "auc") is fitted on the training
# rows; the alpha with the highest cross-validated AUC at lambda.1se is kept.
# NOTE(review): alpha is chosen from cv.glmnet's internal CV on the training
# rows; the validation half of each fold is held out but not used.
# The chosen model is scored on the test rows: AUC from predicted
# probabilities, sensitivity/specificity from class predictions.
# Missing predictor values are imputed with training-set column means
# (rounded for columns with exactly two distinct observed values).
#
# Returns a list:
#   Bs      one row per predictor: chosen/chosenpc (selection count / %),
#           mean.B (mean log-odds coefficient), exp.B (odds ratio),
#           exp.cilo/exp.cihi, sd, one coefficient column per fold named
#           "a<alpha>run<fold counter>", prob, sign (values for aligning text
#           to the y axis) and cilo/cihi (log-odds quantiles).
#   log     per-fold data.frame (fold, testn, testn_cases, auc.train = CV AUC
#           on the training rows, train_tpr, train_tnr, auc.test, test_tpr,
#           test_tnr).
#   summary one-row data.frame of headline results (also printed).
#   auc     mean test-set AUC over all folds, in percent.
#
# Prints the elapsed time and the transposed summary.
# Needs cv.glmnet() and makeX() (glmnet) and move_columns() (presumably sjmisc).
balanced.logisticEN <- function(df, n_runs, cv_n,
                                alphas = seq(from = 0, to = 1, by = 0.05),
                                thresh = 0.95) {
  n_folds <- n_runs * cv_n
  case_rows <- which(df[, 2] == 1)
  control_pool <- which(df[, 2] == 0)
  if (length(control_pool) < length(case_rows)) {
    stop("balanced.logisticEN needs at least as many controls (",
         length(control_pool), ") as cases (", length(case_rows), ").",
         call. = FALSE)
  }
  n_bal <- 2 * length(case_rows)
  # rows held out per fold (validation + test combined)
  portion <- round(n_bal / cv_n)
  if (cv_n < 2 || portion < 2 || n_bal - portion * (cv_n - 1) < 2) {
    stop("cv_n = ", cv_n, " does not suit ", n_bal, " balanced rows: need ",
         "cv_n >= 2 and at least 2 held-out rows in every fold.",
         call. = FALSE)
  }
  pred_cols <- 3:ncol(df)

  # sample() without its length-one trap (sample(7, 1) draws from 1:7);
  # identical to sample(x, size) whenever length(x) > 1
  resample <- function(x, size) x[sample.int(length(x), size)]

  # training-set column means used for imputation; columns with exactly two
  # distinct observed values are treated as binary and get the rounded mean
  train_fill <- function(x_train) {
    fill <- colMeans(x_train, na.rm = TRUE)
    binary <- vapply(seq_len(ncol(x_train)),
                     function(j) length(unique(na.omit(x_train[, j]))) == 2,
                     logical(1))
    fill[binary] <- round(fill[binary])
    fill
  }

  # replace each NA in matrix `m` with the `fill` value of its column
  fill_na <- function(m, fill) {
    na_idx <- which(is.na(m), arr.ind = TRUE)
    if (nrow(na_idx) > 0) m[na_idx] <- fill[na_idx[, "col"]]
    m
  }

  # sensitivity (true-positive rate) and specificity (true-negative rate)
  class_rates <- function(predicted, actual) {
    tp <- sum(predicted == 1 & actual == 1)
    tn <- sum(predicted == 0 & actual == 0)
    fn <- sum(predicted == 0 & actual == 1)
    fp <- sum(predicted == 1 & actual == 0)
    c(tpr = tp / (tp + fn), tnr = tn / (tn + fp))
  }

  # ROC AUC via the Mann-Whitney rank statistic (ties count one half);
  # NA when `actual` does not contain both 0s and 1s
  rank_auc <- function(actual, score) {
    pos <- actual == 1
    n_pos <- sum(pos)
    n_neg <- sum(actual == 0)
    if (n_pos == 0 || n_neg == 0) return(NA_real_)
    r <- rank(score)
    (sum(r[pos]) - n_pos * (n_pos + 1) / 2) / (n_pos * n_neg)
  }

  # per-predictor bookkeeping; the intercept gets its own row
  Bs <- data.frame(predictor = c(names(df)[-(1:2)], "(Intercept)"))
  Bs$chosen <- 0  # number of folds in which the coefficient was non-zero
  coef_mat <- matrix(NA_real_, nrow = nrow(Bs), ncol = n_folds)
  coef_names <- character(n_folds)
  auc.test.CV <- rep(NA_real_, n_runs)
  grandcvlog <- vector("list", n_runs)
  runcounter <- 0

  start.time <- Sys.time()
  for (x in seq_len(n_runs)) {
    # keep every case plus an equal-sized random sample of controls
    control_rows <- resample(control_pool, length(case_rows))
    df_balanced <- rbind(df[case_rows, ], df[control_rows, ])

    cvlog <- data.frame(fold = rep(NA, cv_n),
                        testn = rep(NA, cv_n),
                        testn_cases = rep(NA, cv_n),
                        auc.train = rep(NA, cv_n),
                        train_tpr = rep(NA, cv_n),
                        train_tnr = rep(NA, cv_n),
                        auc.test = rep(NA, cv_n),
                        test_tpr = rep(NA, cv_n),
                        test_tnr = rep(NA, cv_n))
    unused <- seq_len(n_bal)  # rows not yet held out in this run

    for (fold in seq_len(cv_n)) {
      runcounter <- runcounter + 1
      # hold out one fold (the last fold takes every remaining row); half of
      # it becomes the test set, the rest is the (unused) validation half
      n_hold <- if (fold == cv_n) length(unused) else portion
      testvalrows <- resample(unused, n_hold)
      unused <- setdiff(unused, testvalrows)
      testrows <- resample(testvalrows, round(length(testvalrows) / 2))

      y.train <- df_balanced[-testvalrows, 2]
      y.test <- df_balanced[testrows, 2]
      # impute with training-set means only (no leakage)
      x.train <- makeX(df_balanced[-testvalrows, pred_cols, drop = FALSE])
      fill <- train_fill(x.train)
      x.train <- fill_na(x.train, fill)
      x.test <- fill_na(makeX(df_balanced[testrows, pred_cols, drop = FALSE]),
                        fill)

      # one cross-validated fit per alpha on the training rows
      fits <- lapply(alphas, function(a) {
        cv.glmnet(x.train, y.train, type.measure = "auc", alpha = a,
                  family = "binomial")
      })
      # cross-validated AUC of each fit at its lambda.1se
      cv_auc <- vapply(fits, function(f) f$cvm[match(f$lambda.1se, f$lambda)],
                       numeric(1))
      best <- which.max(cv_auc)  # a single index even when AUCs tie
      best.alpha <- alphas[best]
      best.fit <- fits[[best]]

      # training-set accuracy of the chosen model
      train_class <- predict(best.fit, s = best.fit$lambda.1se,
                             newx = x.train, type = "class")
      train_rates <- class_rates(train_class[, 1], y.train)

      # record which coefficients are non-zero in this fold, and their values;
      # positional [, 1] avoids depending on glmnet's coefficient column name
      coefs <- as.matrix(coef(best.fit, s = "lambda.1se"))[, 1]
      nonzero <- coefs[abs(coefs) > 0]
      hit <- match(names(nonzero), Bs$predictor)
      found <- !is.na(hit)  # names with no row in Bs are skipped
      Bs$chosen[hit[found]] <- Bs$chosen[hit[found]] + 1
      coef_mat[hit[found], runcounter] <- nonzero[found]
      # runcounter keeps the column name unique across runs and folds
      coef_names[runcounter] <- paste0("a", best.alpha, "run", runcounter)

      # score the chosen model on the unseen test rows
      test_class <- predict(best.fit, s = "lambda.1se", newx = x.test,
                            type = "class")
      test_prob <- predict(best.fit, s = "lambda.1se", newx = x.test,
                           type = "response")
      test_rates <- class_rates(test_class[, 1], y.test)

      cvlog$fold[fold] <- fold
      cvlog$testn[fold] <- length(testrows)
      cvlog$testn_cases[fold] <- sum(y.test)
      cvlog$auc.train[fold] <- cv_auc[best]
      cvlog$train_tpr[fold] <- train_rates[["tpr"]]  # "1s" we got right
      cvlog$train_tnr[fold] <- train_rates[["tnr"]]  # "0s" we got right
      cvlog$auc.test[fold] <- rank_auc(y.test, test_prob[, 1])
      cvlog$test_tpr[fold] <- test_rates[["tpr"]]
      cvlog$test_tnr[fold] <- test_rates[["tnr"]]
    }
    # mean test AUC over this run's folds
    auc.test.CV[x] <- mean(cvlog$auc.test, na.rm = TRUE)
    grandcvlog[[x]] <- cvlog
  }
  grandcvlog <- do.call(rbind, grandcvlog)

  fin.time <- Sys.time()
  print(fin.time - start.time)

  # mean coefficient over the folds in which it was non-zero
  Bs$mean.B <- rowMeans(coef_mat, na.rm = TRUE)
  is_icpt <- Bs$predictor == "(Intercept)"
  test_auc <- grandcvlog$auc.test

  res_summary <- round(data.frame(
    sample_n = nrow(df),
    num.runs = n_folds,
    num.inputs = length(df) - 2,
    # predictors (excluding the intercept) selected in >= thresh of folds
    num.meetingthreshfreq = sum(Bs$chosen[!is_icpt] >= n_folds * thresh),
    intercept = Bs$mean.B[is_icpt],
    mean.auc = mean(test_auc, na.rm = TRUE) * 100,
    sd.auc = sd(test_auc, na.rm = TRUE) * 100,
    aucCIlo = quantile(test_auc, probs = 0.025, na.rm = TRUE) * 100,
    aucCIhi = quantile(test_auc, probs = 0.975, na.rm = TRUE) * 100,
    mean.sensit = mean(grandcvlog$test_tpr, na.rm = TRUE),
    mean.specif = mean(grandcvlog$test_tnr, na.rm = TRUE),
    CVmean.ofAUCs = mean(auc.test.CV, na.rm = TRUE) * 100,
    CVAUC.CIlo = quantile(auc.test.CV, probs = 0.025, na.rm = TRUE) * 100,
    CVAUC.CIhi = quantile(auc.test.CV, probs = 0.975, na.rm = TRUE) * 100
  ), digits = 2)
  rownames(res_summary) <- c()
  print(t(res_summary))

  # per-predictor summaries across folds
  Bs$chosenpc <- round(Bs$chosen / runcounter * 100)  # selection frequency %
  Bs$exp.B <- exp(Bs$mean.B)                 # odds ratio
  Bs$prob <- Bs$exp.B / (1 + Bs$exp.B)       # inverse-logit of mean.B
  Bs$sign <- sign(Bs$mean.B)
  Bs$sign[which(Bs$sign == 1)] <- 1.05    # values for aligning text to y axis
  Bs$sign[which(Bs$sign == -1)] <- -0.05
  Bs$cilo <- apply(coef_mat, 1, function(b) {
    unname(quantile(b, probs = 0.025, na.rm = TRUE))
  })
  Bs$cihi <- apply(coef_mat, 1, function(b) {
    unname(quantile(b, probs = 0.975, na.rm = TRUE))
  })
  Bs$sd <- apply(coef_mat, 1, sd, na.rm = TRUE)
  # quantiles are on the log-odds scale; exponentiate for odds ratios
  Bs$exp.cilo <- exp(Bs$cilo)
  Bs$exp.cihi <- exp(Bs$cihi)

  # plotting columns first, then the per-fold coefficients
  run_cols <- setNames(as.data.frame(coef_mat), coef_names)
  Bs <- cbind(Bs[, c("predictor", "chosen", "chosenpc", "mean.B", "exp.B",
                     "exp.cilo", "exp.cihi", "sd")],
              run_cols,
              Bs[, c("prob", "sign", "cilo", "cihi")])
  Bs <- Bs[!is_icpt, ]
  rownames(Bs) <- Bs$predictor
  # NOTE(review): no "t_est" column is created in this function; what this
  # does with a missing .after target depends on move_columns() — confirm.
  Bs <- move_columns(Bs, sign, .after = "t_est")

  list("Bs" = Bs,
       "log" = grandcvlog,
       "summary" = res_summary,
       "auc" = mean(test_auc, na.rm = TRUE) * 100)
}

# A function for optimizing elastic net with a binary outcome on training & validation sets, and testing on a test set,
# given an input data.frame df. Note: column 1 of df must hold unique row IDs & column 2 the outcome.
# Note: this does not balance the # of cases against the # of controls.
# Note: portion is the # of subjects in the test set and validation set *combined*.
# NOTE(review): no function definition follows this header here — confirm whether it was removed or lives elsewhere.



# Collate the headline linear elastic-net results for one sub-group into a
# one-row table.
#
# mod       result list from linearEN(); only its "summary" element is read.
# groupname label stored in the leading "group" column.
# Returns a data.frame with columns group, N, intercept, Rsq_mean, Rsq_sd,
# Rsq_lowCI, Rsq_hiCI, num_predictors and num_total.
collate_linear <- function(mod, groupname) {
  smry <- mod[["summary"]]
  data.frame(
    group = groupname,
    N = smry$sample_n,
    intercept = smry$intercept,
    Rsq_mean = smry$mean.rsq,
    Rsq_sd = smry$sd.rsq,
    Rsq_lowCI = smry$rsq.CIlo,
    Rsq_hiCI = smry$rsq.CIhi,
    num_predictors = smry$num.meetingthreshfreq,
    num_total = smry$num.inputs,
    stringsAsFactors = FALSE
  )
}

# Collate the headline logistic elastic-net results for one sub-group into a
# one-row table.
#
# mod       result list with a "summary" data.frame (e.g. from
#           balanced.logisticEN()); only that element is read.
# groupname label stored in the leading "group" column.
# AUC columns use the per-run CV fields (CVmean.ofAUCs, CVAUC.CIlo,
# CVAUC.CIhi) when the summary has them, otherwise the per-fold fields
# (mean.auc, aucCIlo, aucCIhi), otherwise NA. Previously, missing fields made
# data.frame() silently drop the AUC columns.
# Returns a data.frame with columns group, N, intercept, AUC_mean, AUC_lowCI,
# AUC_hiCI, num_predictors and num_total.
collate_logistic <- function(mod, groupname) {
  smry <- mod[["summary"]]
  # first of the candidate summary fields that is present, else NA
  pick <- function(...) {
    for (nm in c(...)) {
      if (!is.null(smry[[nm]])) return(smry[[nm]])
    }
    NA
  }
  data.frame(
    group = groupname,
    N = smry$sample_n,
    intercept = smry$intercept,
    AUC_mean = pick("CVmean.ofAUCs", "mean.auc"),
    AUC_lowCI = pick("CVAUC.CIlo", "aucCIlo"),
    AUC_hiCI = pick("CVAUC.CIhi", "aucCIhi"),
    num_predictors = smry$num.meetingthreshfreq,
    num_total = smry$num.inputs,
    stringsAsFactors = FALSE
  )
}