library(caret)
library(rpart.plot)
# Evaluate everything from the RStudio project root so that the relative
# paths below ("R/...", "data/models/...") resolve.
knitr::opts_knit$set(root.dir = rprojroot::find_rstudio_root_file())
# knitr's root.dir applies to knitted chunks; setwd() makes the same directory
# current when the code is run outside knitr as well.
setwd(knitr::opts_knit$get("root.dir"))
# Project helpers. NOTE(review): presumably these define ALL_MHS,
# nestedCVtrainRecommendation() and inputMatrix() used below -- confirm.
source("R/flowshop.R")
source("R/models/model_utils.R")
# Decision trees (rpart). The nested-CV training for every MH is commented
# out; uncomment it to retrain and re-save the cache. Otherwise the cached
# `rpart_models` table is loaded from disk.
# train_mhs <- ALL_MHS
# names(train_mhs) <- ALL_MHS
# rpart_models <- map_dfr(
#   train_mhs,
#   nestedCVtrainRecommendation,
#   model_name = "rpart",
#   .id = "mh"
# )
# save(rpart_models, file = "data/models/rpart_models.Rdata")
load("data/models/rpart_models.Rdata")
# Random forests: same pattern, cached as `rf_models`.
# train_mhs <- ALL_MHS
# names(train_mhs) <- ALL_MHS
# rf_models <- map_dfr(
#   train_mhs, 
#   nestedCVtrainRecommendation,
#   model_name = "rf",
#   .id = "mh"
# )
# save(rf_models, file = "data/models/rf_models.Rdata")
load("data/models/rf_models.Rdata")
# Gradient-boosted trees (xgbTree): same pattern, plus a preprocessing step
# through inputMatrix() (its FALSE argument is not documented here -- see
# R/models/model_utils.R). Cached as `xgbTree_models`.
# train_mhs <- ALL_MHS
# names(train_mhs) <- ALL_MHS
# xgbTree_models <- map_dfr(
#   train_mhs, 
#   nestedCVtrainRecommendation,
#   model_name = "xgbTree",
#   preprocess = function(dt) {
#     inputMatrix(dt, FALSE)
#   },
#   .id = "mh"
# )
# save(xgbTree_models, file = "data/models/xgbTree_models.Rdata")
load("data/models/xgbTree_models.Rdata")

## Performance per metaheuristic (MH)

# Flatten the per-fold performance metrics of one model family into long
# format.
#
# dt: models table with columns mh, fold and a list column `performance`
#   (each element is passed through as.data.frame()).
# model_name: short label stored in the `model` column (e.g. "dt", "rf").
# Returns a data frame with id columns mh, fold, model and one row per
#   metric, in columns `metric` / `value`.
performanceDf <- function(dt, model_name) {
  dt %>%
    select(mh, fold, performance) %>%
    mutate(
      model = model_name,
      performance = map(performance, as.data.frame)
    ) %>%
    # Name the list column explicitly: a bare unnest() is deprecated in
    # tidyr >= 1.0 and would unnest whatever list columns it finds.
    unnest(performance) %>%
    gather(metric, value, -mh, -model, -fold)
}

# Long-format per-MH performance table for all three model families, labelled
# "dt" (rpart), "rf" and "xgb" (xgbTree).
all_performances <- bind_rows(
  Map(
    performanceDf,
    list(rpart_models, rf_models, xgbTree_models),
    c("dt", "rf", "xgb")
  )
)

# Label(s) of the model with the highest mean `value`, ignoring NA values.
# When several models share the maximal mean, all of them are returned.
#
# dt: data frame with columns `model` and `value`.
bestMean <- function(dt) {
  model_means <- dt %>%
    filter(!is.na(value)) %>%
    group_by(model) %>%
    summarise(mean_val = mean(value))
  top_mean <- max(model_means$mean_val)
  model_means$model[model_means$mean_val == top_mean]
}

# Mirror a triangular matrix so that entry (a, b) equals entry (b, a).
#
# Two layouts are handled:
# * square matrices whose rows and columns carry the same labels (or no
#   labels): the lower triangle is copied onto the upper triangle;
# * pairwise p-value tables (stats::pairwise.table(), PMCMRplus tests) whose
#   rows are groups[-1] and columns groups[-k]. Mirroring those by position
#   would attach values to the wrong label pairs. Instead a square matrix
#   over the union of labels is built, and each known value is placed at both
#   (a, b) and (b, a). Pairs without a value (e.g. the diagonal) stay NA.
#
# m: numeric matrix.
# Returns the symmetrised matrix.
makeSymm <- function(m) {
  row_lab <- rownames(m)
  col_lab <- colnames(m)
  if (is.null(row_lab) || is.null(col_lab) || identical(row_lab, col_lab)) {
    m[upper.tri(m)] <- t(m)[upper.tri(m)]
    return(m)
  }
  all_lab <- union(col_lab, row_lab)
  full <- matrix(
    NA_real_,
    nrow = length(all_lab),
    ncol = length(all_lab),
    dimnames = list(all_lab, all_lab)
  )
  full[row_lab, col_lab] <- m
  # Fill each still-empty cell from its transposed partner.
  empty <- is.na(full)
  full[empty] <- t(full)[empty]
  full
}

# Models that the post-hoc test cannot distinguish from the best one(s).
#
# Runs PMCMRplus::frdAllPairsNemenyiTest() on `value`, with `model` as the
# groups and `fold` as the blocks. Returns `best_mean` followed by every model
# whose pairwise p-value against a best model exceeds `alpha`.
#
# dt: data frame with columns value, model and fold.
# best_mean: label(s) of the best model(s), e.g. from bestMean().
# alpha: significance threshold (default 0.05).
# Returns a character vector of unique model labels, best model(s) first.
tiesWithBest <- function(dt, best_mean, alpha = 0.05) {
  # normality_test <- dt %>% 
  #   group_by(model) %>%
  #   nest() %>%
  #   mutate(is_normal = map_lgl(data, ~shapiro.test(.x$value)$p.value > 0.05))
  # if (!all(normality_test$is_normal)) {
  #   print("some metrics are not normal!")
  # }
  phtest <- PMCMRplus::frdAllPairsNemenyiTest(dt$value, dt$model, dt$fold)
  # The p-value matrix is lower-triangular, with rows labelled groups[-1] and
  # columns groups[-k], so each unordered pair appears exactly once. It is used
  # as is: mirroring it by position would pair p-values with the wrong labels.
  p_values <- as.data.frame(
    as.table(phtest$p.value),
    stringsAsFactors = FALSE
  )
  colnames(p_values) <- c("AlgoA", "AlgoB", "p.value")
  p_values <- p_values[!is.na(p_values$p.value), , drop = FALSE]
  # Both columns are checked, so the pair's orientation does not matter.
  involves_best <- p_values$AlgoA %in% best_mean |
    p_values$AlgoB %in% best_mean
  tied <- p_values[involves_best & p_values$p.value > alpha, , drop = FALSE]
  unique(c(best_mean, tied$AlgoA, tied$AlgoB))
}


# For every (mh, metric) pair: the best-mean model(s), and the models that
# tiesWithBest() cannot separate from them.
test_dt <- all_performances %>%
  group_by(mh, metric) %>%
  nest() %>%
  mutate(best_mean = map(data, bestMean)) %>%
  mutate(ties_with_best = map2(data, best_mean, tiesWithBest))

# Format a "mean (sd)" cell as HTML. The best model is shown in bold and
# models tied with it in italics.
#
# model: model label of this cell (scalar).
# mv, sv: mean and standard deviation, printed with two decimals.
# best_mean: label(s) with the highest mean.
# ties_with_best: labels not separated from the best by the post-hoc test.
# ...: ignored, so the function can be called with extra named arguments.
# Returns a single character string.
formatedVals <- function(model, mv, sv, best_mean, ties_with_best, ...) {
  res <- sprintf("%.2f (%.2f)", mv, sv)
  # Best takes precedence over tie, as in the markdown tables, so the best
  # model is not additionally wrapped in <i>.
  if (model %in% best_mean) {
    res <- paste0("<strong>", res, "</strong>")
  } else if (model %in% ties_with_best) {
    res <- paste0("<i>", res, "</i>")
  }
  res
}

# Per (mh, model, metric): the mean and sd over folds, formatted as
# "mean (sd)". Markdown bold marks the best-mean model and italics mark models
# tied with it.
summ_performances <- all_performances %>%
  group_by(mh, model, metric) %>%
  summarise(
    mv = mean(value, na.rm = TRUE),
    sv = sd(value, na.rm = TRUE)
  ) %>%
  left_join(test_dt, by = c("mh", "metric")) %>%
  mutate(
    is_best = map2_lgl(model, best_mean, ~ .x %in% .y),
    is_equal = map2_lgl(model, ties_with_best, ~ .x %in% .y),
    # case_when takes the first matching branch, so the best model is only
    # bolded even though ties_with_best also contains it.
    formated = case_when(
      is_best ~ sprintf("**%.2f (%.2f)**", mv, sv),
      is_equal ~ sprintf("*%.2f (%.2f)*", mv, sv),
      TRUE ~ sprintf("%.2f (%.2f)", mv, sv)
    )
  ) %>%
  select(mh, model, metric, formated) %>%
  ungroup()

# Print one markdown table per metric: rows are models, columns are MHs.
# The raw markdown emphasis in `formated` must not be escaped.
for (mt in unique(summ_performances$metric)) {
  cat("\n### ", mt, "\n")
  metric_summ <- summ_performances %>%
    ungroup() %>%
    filter(metric == mt) %>%
    select(-metric) %>%
    spread(mh, formated) %>%
    knitr::kable(format = "markdown", escape = FALSE) %>%
    print()
}

### Accuracy

| model | ACO | IG | IHC | ILS | ISA | TS |
|:------|:----|:---|:----|:----|:----|:---|
| dt | 0.79 (0.03) | 0.95 (0.01) | 0.89 (0.01) | 0.93 (0.01) | 0.88 (0.02) | 0.82 (0.02) |
| rf | 0.85 (0.04) | 0.94 (0.02) | 0.91 (0.03) | 0.91 (0.01) | 0.89 (0.01) | 0.88 (0.03) |
| xgb | 0.85 (0.05) | 0.95 (0.01) | 0.90 (0.02) | 0.91 (0.02) | 0.87 (0.01) | 0.88 (0.02) |

### F1

| model | ACO | IG | IHC | ILS | ISA | TS |
|:------|:----|:---|:----|:----|:----|:---|
| dt | 0.47 (0.10) | 0.24 (0.08) | 0.28 (0.09) | 0.34 (0.16) | 0.17 (0.13) | 0.61 (0.05) |
| rf | 0.69 (0.08) | 0.30 (0.13) | 0.50 (0.15) | 0.26 (0.11) | 0.15 (0.09) | 0.77 (0.05) |
| xgb | 0.68 (0.08) | 0.35 (0.12) | 0.45 (0.10) | 0.28 (0.11) | 0.16 (0.05) | 0.76 (0.04) |

### Kappa

| model | ACO | IG | IHC | ILS | ISA | TS |
|:------|:----|:---|:----|:----|:----|:---|
| dt | 0.35 (0.10) | 0.20 (0.11) | 0.24 (0.09) | 0.28 (0.18) | 0.05 (0.12) | 0.50 (0.06) |
| rf | 0.59 (0.10) | 0.28 (0.14) | 0.45 (0.17) | 0.22 (0.11) | 0.10 (0.10) | 0.69 (0.07) |
| xgb | 0.58 (0.11) | 0.29 (0.15) | 0.40 (0.10) | 0.23 (0.11) | 0.07 (0.08) | 0.68 (0.06) |

### Precision

| model | ACO | IG | IHC | ILS | ISA | TS |
|:------|:----|:---|:----|:----|:----|:---|
| dt | 0.62 (0.08) | 0.67 (0.33) | 0.51 (0.19) | 0.59 (0.21) | 0.33 (0.35) | 0.70 (0.08) |
| rf | 0.71 (0.09) | 0.49 (0.31) | 0.60 (0.20) | 0.36 (0.10) | 0.31 (0.17) | 0.78 (0.07) |
| xgb | 0.70 (0.10) | 0.56 (0.27) | 0.55 (0.14) | 0.36 (0.15) | 0.23 (0.15) | 0.78 (0.05) |

### Recall

| model | ACO | IG | IHC | ILS | ISA | TS |
|:------|:----|:---|:----|:----|:----|:---|
| dt | 0.39 (0.12) | 0.14 (0.07) | 0.22 (0.10) | 0.23 (0.16) | 0.06 (0.10) | 0.56 (0.08) |
| rf | 0.67 (0.08) | 0.24 (0.12) | 0.44 (0.14) | 0.22 (0.12) | 0.09 (0.07) | 0.75 (0.06) |
| xgb | 0.67 (0.08) | 0.25 (0.14) | 0.40 (0.12) | 0.24 (0.11) | 0.09 (0.06) | 0.75 (0.06) |

## Macro-level performances (set-based metrics over all MHs per instance)

# Turn one model family's per-MH predictions into per-instance label sets.
#
# dt: models table with columns mh, fold, reference and predicted
#   (NOTE(review): presumably list columns of per-instance factor vectors --
#   confirm against nestedCVtrainRecommendation()).
# model_name: label stored in the `model` column (e.g. "dt").
# Returns one row per (model, fold, instance_id) with list columns
#   predicted_set / reference_set: the MH names whose prediction / reference
#   has integer code 2 for that instance.
macroPerformanceDf <- function(dt, model_name) {
  dt %>%
    select(mh, fold, reference, predicted) %>% 
    unnest() %>% 
    # NOTE(review): instance_id is a running row number within each MH
    # (across all folds). Rows of different MHs are then grouped together by
    # it, which assumes every MH lists the same instances in the same order --
    # confirm.
    group_by(mh) %>% 
    mutate(instance_id = row_number(),
           model = model_name) %>% 
    group_by(model, fold, instance_id) %>% 
    nest() %>%
    mutate(
      # Integer code 2 = second factor level, presumably the positive class
      # (MH recommended) -- confirm the level order.
      predicted_set = map(
        data, 
        ~pull(filter(.x, as.integer(predicted) == 2), mh)
      ),
      reference_set = map(
        data,
        ~ pull(filter(.x, as.integer(reference) == 2), mh)
      )
    ) %>%
    select(-data)
}

# Per-instance predicted/reference MH sets for all three model families.
# NOTE: this overwrites the per-MH `all_performances` table built earlier.
all_performances <- bind_rows(
  Map(
    macroPerformanceDf,
    list(rpart_models, rf_models, xgbTree_models),
    c("dt", "rf", "xgb")
  )
)

# Hamming loss for one instance: the fraction of labels on which the predicted
# and reference sets disagree, i.e. |symmetric difference| / n_labels.
#
# pred, ref: vectors of labels predicted positive / truly positive.
# n_labels: total number of labels. The denominator must be the label count
#   (6 MHs in the tables above), not a fixed constant. Defaults to
#   length(ALL_MHS). ALL_MHS is assumed to come from the sourced project files
#   and is looked up only when the default is used.
# Returns a number in [0, 1] when pred and ref are subsets of the labels.
hammingLoss <- function(pred, ref, n_labels = length(ALL_MHS)) {
  # setdiff() returns unique values and the two differences are disjoint, so
  # this concatenation is the symmetric difference without duplicates.
  symm_set_diff <- c(setdiff(pred, ref), setdiff(ref, pred))
  length(symm_set_diff) / n_labels
}

# Exact-match accuracy for one instance: TRUE when the predicted and reference
# label sets hold the same elements (order and duplicates ignored).
classAccuracy <- function(pred, ref) {
  missed <- setdiff(ref, pred)
  extra <- setdiff(pred, ref)
  length(missed) == 0L && length(extra) == 0L
}

# Per-instance precision: the share of predicted labels that are truly
# positive, |intersect(pred, ref)| / |pred|. NaN when nothing is predicted
# (0/0). getMetric() averages with na.rm, so such instances are skipped.
macroPrecision <- function(pred, ref) {
  length(intersect(pred, ref)) / length(pred)
}

# Per-instance recall: the share of truly positive labels that were
# predicted, |intersect(pred, ref)| / |ref|. NaN when the reference set is
# empty.
macroRecall <- function(pred, ref) {
  length(intersect(pred, ref)) / length(ref)
}

# Per-instance F1: 2 * |intersect(pred, ref)| / (|pred| + |ref|).
# NaN when both sets are empty (0/0).
macroF1 <- function(pred, ref) {
  overlap <- length(intersect(pred, ref))
  total <- length(pred) + length(ref)
  2 * overlap / total
}

# Per-instance set accuracy (Jaccard index): |intersection| / |union| of the
# predicted and reference label sets. NaN when both sets are empty.
macroAccuracy <- function(pred, ref) {
  shared <- intersect(pred, ref)
  either <- union(pred, ref)
  length(shared) / length(either)
}

# Average a per-instance set metric over all instances in `dt`.
#
# dt: object with parallel list elements predicted_set and reference_set.
# metricF: function(pred, ref) returning one number (or TRUE/FALSE).
# Returns the mean of the per-instance values, skipping NA/NaN results
# (e.g. 0/0 ratios on empty sets).
getMetric <- function(dt, metricF) {
  # vapply enforces one double per instance; logical results are promoted.
  scores <- vapply(
    seq_along(dt$predicted_set),
    function(i) metricF(dt$predicted_set[[i]], dt$reference_set[[i]]),
    numeric(1)
  )
  mean(scores, na.rm = TRUE)
}

# NOTE(review): no function visible in this file reads this global `M`;
# confirm it is needed elsewhere before removing it.
M <- length(ALL_MHS)
# One row per (model, fold, metric). Each set-based metric is evaluated per
# instance of the fold and averaged by getMetric(), which skips NA/NaN
# per-instance values. The column order below fixes the order of `metric`
# after gather().
macro_performances <- all_performances %>%
  group_by(model, fold) %>%
  nest() %>%
  mutate(
    "Hamming loss" = map_dbl(data, getMetric, metricF = hammingLoss),
    "Classification Acc." = map_dbl(data, getMetric, metricF = classAccuracy),
    Precision = map_dbl(data, getMetric, metricF = macroPrecision),
    Recall = map_dbl(data, getMetric, metricF = macroRecall),
    F1 = map_dbl(data, getMetric, metricF = macroF1),
    Accuracy = map_dbl(data, getMetric, metricF = macroAccuracy)
  ) %>%
  select(-data) %>%
  gather(metric, value, -model, -fold)

# For every macro metric: the best-mean model(s), and the models that
# tiesWithBest() cannot separate from them.
test_dt <- macro_performances %>%
  group_by(metric) %>%
  nest() %>%
  mutate(best_mean = map(data, bestMean)) %>%
  mutate(ties_with_best = map2(data, best_mean, tiesWithBest))
  
# Per (model, metric): the mean and sd over folds, formatted as "mean (sd)".
# Markdown bold marks the best-mean model and italics mark models tied with it.
summ_performances <- macro_performances %>%
  group_by(model, metric) %>%
  summarise(
    mv = mean(value, na.rm = TRUE),
    sv = sd(value, na.rm = TRUE)
  ) %>%
  left_join(test_dt, by = c("metric")) %>%
  mutate(
    is_best = map2_lgl(model, best_mean, ~ .x %in% .y),
    is_equal = map2_lgl(model, ties_with_best, ~ .x %in% .y),
    # case_when takes the first matching branch, so the best model is only
    # bolded even though ties_with_best also contains it.
    formated = case_when(
      is_best ~ sprintf("**%.2f (%.2f)**", mv, sv),
      is_equal ~ sprintf("*%.2f (%.2f)*", mv, sv),
      TRUE ~ sprintf("%.2f (%.2f)", mv, sv)
    )
  ) %>%
  select(model, metric, formated) %>%
  ungroup()
# Print a single markdown table: rows are models, columns are macro metrics.
# The raw markdown emphasis in `formated` must not be escaped.
summ_performances %>%
  ungroup() %>%
  spread(metric, formated) %>%
  knitr::kable(format = "markdown", escape = FALSE) %>%
  print()
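| model | Accuracy | Classification Acc. | F1 | Hamming loss | Precision | Recall |
|:------|:---------|:--------------------|:---|:-------------|:----------|:-------|
| dt | 0.87 (0.01) | 0.46 (0.03) | 0.93 (0.00) | 0.36 (0.02) | 0.96 (0.01) | 0.90 (0.01) |
| rf | 0.89 (0.01) | 0.53 (0.04) | 0.93 (0.01) | 0.31 (0.03) | 0.96 (0.01) | 0.93 (0.01) |
| xgb | 0.88 (0.01) | 0.52 (0.02) | 0.93 (0.01) | 0.32 (0.02) | 0.95 (0.01) | 0.93 (0.01) |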