library(caret)
library(rpart.plot)
# Anchor knitr at the RStudio project root so the relative paths used below
# ("R/...", "data/...") are resolved against the project, not the document.
knitr::opts_knit$set(root.dir = rprojroot::find_rstudio_root_file())
# Move the current R session to that same directory as well.
setwd(knitr::opts_knit$get("root.dir"))
# Project helper scripts.
# NOTE(review): presumably these define loadModelData(), inputMatrix() and
# ALL_MHS, and attach the tidyverse packages used below -- confirm.
source("R/flowshop.R")
source("R/models/model_utils.R")

# Data set shared by all per-MH models; trainRPart() reads it as a global.
recommendation_data <- loadModelData("MH recommendation")
# Train and evaluate a binary "recommend this MH?" classifier for one MH.
#
# Despite the function name (kept so existing callers keep working), the model
# fitted here is caret's 'xgbTree', not an rpart tree.
#
# Args:
#   mh: name of the MH; the target column is 'rec_<mh>' in lower case, read
#       from the global `recommendation_data`.
#
# Returns a list with:
#   model            - the final model extracted from the caret fit
#   confusion_matrix - caret confusionMatrix on the held-out 20%
#   performance      - named numeric vector: postResample() metrics plus
#                      Precision, Recall and F1
#   test             - list(predicted, reference) factors for the held-out rows
#
# NOTE(review): precision(), recall(), F_meas() and confusionMatrix() are
# called with their defaults, so caret uses the first factor level ('no') as
# the relevant/positive class. The reported Precision/Recall/F1 therefore
# describe the 'no' class -- confirm this is intended.
trainRPart <- function(mh) {
  # Fixed seed: the partition and the CV resampling below are reproducible.
  set.seed(123)
  input_data <- inputMatrix(recommendation_data, FALSE)

  # input_data <- removeCorrelatedFeatures(input_data)
  out_col_name <- str_to_lower(paste0('rec_', mh))
  # The column's distinct values, in sorted order, are relabelled no / yes.
  output_data <- factor(recommendation_data[, out_col_name],
                        labels = c('no', 'yes'))

  # 80/20 split drawn on this MH's own labels; the held-out rows are therefore
  # not guaranteed to be the same rows for different MHs.
  train_idxs <- createDataPartition(output_data,
                                    p = 0.8,
                                    list = FALSE)

  train_input_data <- input_data[train_idxs, ]
  train_output_data <- output_data[train_idxs]
  test_input_data <- input_data[-train_idxs, ]
  test_output_data <- output_data[-train_idxs]

  # Tune on Kappa with 10x10 repeated cross-validation.
  train_result <- train(
    x = train_input_data,
    y = train_output_data,
    method = 'xgbTree',
    metric = 'Kappa',
    trControl = trainControl(
      method = 'repeatedcv',
      number = 10,
      repeats = 10,
      # FALSE rather than F: F is an ordinary variable and can be reassigned.
      classProbs = FALSE
    )
  )

  # Score the tuned model on the held-out rows.
  predicted_test_output <- predict(train_result, test_input_data, type = "raw")
  perf <- postResample(predicted_test_output, test_output_data)
  perf['Precision'] <- precision(predicted_test_output, reference = test_output_data)
  perf['Recall'] <- recall(predicted_test_output, reference = test_output_data)
  perf['F1'] <- F_meas(predicted_test_output, reference = test_output_data)

  list(
    model = train_result$finalModel,
    confusion_matrix = confusionMatrix(predicted_test_output, test_output_data),
    performance = perf,
    test = list(
      predicted = predicted_test_output,
      reference = test_output_data
    )
  )
}

# The per-MH models are restored from disk; the commented-out map()/save()
# lines are what (re)generates and stores that file.
# xgb_recommendation <- map(ALL_MHS, trainRPart)
load("data/models/xgb_recommendation.Rdata")
# Key the loaded list by MH name so it can be indexed with xgb_recommendation[[mh]].
names(xgb_recommendation) <- ALL_MHS
# save(xgb_recommendation, file = "data/models/xgb_recommendation.Rdata")
# One row of test-set metrics per MH, labelled in ALL_MHS order.
performances <- bind_rows(map(xgb_recommendation, ~ as.list(.x$performance)))
performances$MH <- ALL_MHS

# Reshape the per-MH list(predicted, reference) pairs into a long table with
# columns MH, predicted, reference: one column per MH holding the two vectors
# as list elements, then gather/spread to swap rows and columns, then unnest.
predictions <- map(xgb_recommendation, ~ .x$test) %>%
  as_tibble() %>%
  # Row labels rely on each $test list being ordered predicted, reference.
  mutate(rowname = c("predicted", "reference")) %>%
  gather(MH, value, -rowname) %>% 
  spread(rowname, value) %>%
  unnest() %>% 
  group_by(MH) %>%
  # Positional id within each MH's test set.
  # NOTE(review): if every model drew its own train/test split, position k may
  # not be the same instance across MHs, which would make the per-instance
  # grouping below mix unrelated rows -- confirm the splits are aligned.
  mutate(instance_id = 1:n())
# One row per instance_id with the per-MH predictions nested in `data`.
predictions_by_instance <- predictions %>% 
  group_by(instance_id) %>% 
  nest()

Performance of the gradient-boosted tree (XGBoost) models per MH

# Render the per-MH test metrics as a table.
knitr::kable(performances)
Accuracy Kappa Precision Recall F1 MH
0.9149485 0.4327485 0.6250000 0.3846154 0.4761905 IHC
0.8788660 0.1484871 0.3333333 0.1463415 0.2033898 ISA
0.8943299 0.7172615 0.8000000 0.7755102 0.7875648 TS
0.8427835 0.5607928 0.6976744 0.6315789 0.6629834 ACO
0.9123711 0.2471179 0.3684211 0.2413793 0.2916667 ILS
0.9484536 0.3852004 0.4666667 0.3684211 0.4117647 IG

Performance of the models overall

Micro-level

# Pool the test predictions of all MHs into one confusion matrix and show its
# overall and by-class statistics as a single-row table.
cm <- confusionMatrix(
  data = predictions[["predicted"]],
  reference = predictions[["reference"]]
)
knitr::kable(
  t(
    c(
      cm$overall[c("Accuracy", "Kappa")],
      cm$byClass[c("Precision", "Recall", "F1")]
    )
  )
)
Accuracy Kappa Precision Recall F1
0.8986254 0.5346327 0.6653696 0.5327103 0.5916955

Macro-level

# For every instance_id, build the set of MHs predicted as recommended and the
# set of MHs actually recommended. A factor code of 2 is the second level
# ('yes', given the no/yes labelling used when training).
predictions_by_instance <- predictions %>%
  group_by(instance_id) %>%
  nest() %>%
  mutate(
    predicted_set = map(data, function(per_mh) {
      per_mh %>%
        filter(as.integer(predicted) == 2) %>%
        pull(MH)
    }),
    reference_set = map(data, function(per_mh) {
      per_mh %>%
        filter(as.integer(reference) == 2) %>%
        pull(MH)
    })
  ) %>%
  select(-data)

# Hamming loss for one instance: the number of labels on which the predicted
# and reference sets disagree, divided by the number of labels M.
hammingLoss <- function(pred, ref, M = 2) {
  only_in_pred <- length(setdiff(pred, ref))
  only_in_ref <- length(setdiff(ref, pred))
  (only_in_pred + only_in_ref) / M
}

# Exact-match indicator for one instance: TRUE only when the predicted and
# reference sets contain the same labels (order and duplicates ignored).
classAccuracy <- function(pred, ref) {
  pred_within_ref <- all(pred %in% ref)
  ref_within_pred <- all(ref %in% pred)
  pred_within_ref && ref_within_pred
}

# Per-instance precision: the share of predicted labels that are also in the
# reference set, |pred n ref| / |pred|.
# The denominator is the predicted set; it was previously the reference set,
# which made this function compute recall.
# Yields NaN when `pred` is empty (0 / 0).
macroPrecision <- function(pred, ref) {
  length(intersect(pred, ref)) / length(pred)
}

# Per-instance recall: the share of reference labels that were predicted,
# |pred n ref| / |ref|.
# The denominator is the reference set; it was previously the predicted set,
# which made this function compute precision.
# Yields NaN when `ref` is empty (0 / 0).
macroRecall <- function(pred, ref) {
  length(intersect(pred, ref)) / length(ref)
}

# Per-instance F1: twice the number of shared labels over the combined size of
# the predicted and reference sets.
macroF1 <- function(pred, ref) {
  n_shared <- length(intersect(pred, ref))
  n_total <- length(pred) + length(ref)
  2 * n_shared / n_total
}

# Per-instance accuracy as set overlap: labels in both sets divided by labels
# in either set.
macroAccuracy <- function(pred, ref) {
  n_shared <- length(intersect(pred, ref))
  n_either <- length(union(pred, ref))
  n_shared / n_either
}

# Number of labels (MHs); the denominator of the Hamming loss.
M <- length(ALL_MHS)
# Score every instance's predicted set against its reference set, then average
# each metric over instances. Despite the variable name, these are the
# per-instance averages reported under the "Macro-level" heading.
micro_performances <- predictions_by_instance %>%
  # Drop any grouping first: on a grouped table select() cannot remove the
  # grouping column, so instance_id would otherwise be averaged as a metric.
  ungroup() %>%
  mutate(
    "Hamming loss" = map2_dbl(predicted_set, reference_set,
                              ~ hammingLoss(.x, .y, M = M)),
    "Classification Acc." = map2_dbl(predicted_set, reference_set,
                                     ~ classAccuracy(.x, .y)),
    Precision = map2_dbl(predicted_set, reference_set,
                         ~ macroPrecision(.x, .y)),
    Recall = map2_dbl(predicted_set, reference_set,
                      ~ macroRecall(.x, .y)),
    F1 = map2_dbl(predicted_set, reference_set,
                  ~ macroF1(.x, .y)),
    Accuracy = map2_dbl(predicted_set, reference_set,
                        ~ macroAccuracy(.x, .y))
  ) %>%
  select(-instance_id, -predicted_set, -reference_set) %>%
  # Pass the function directly; funs() is deprecated in dplyr.
  summarise_all(mean)

knitr::kable(micro_performances)
Hamming loss Classification Acc. Precision Recall F1 Accuracy
0.1013746 0.5309278 0.9568299 0.9270619 0.9348422 0.8891323

Model details

# For each MH, emit a level-3 markdown heading, plot the ten most important
# features, call the tree-depth plot, and print the stored confusion matrix
# inside a fenced block.
walk(ALL_MHS, function(mh) {
  fit <- xgb_recommendation[[mh]]
  booster <- fit$model

  cat('\n\n### ', mh, ' recommendation\n\n')

  importance <- xgboost::xgb.importance(model = booster)
  plot(xgboost::xgb.ggplot.importance(importance, top_n = 10))
  xgboost::xgb.ggplot.deepness(booster)

  cat('\n```\n')
  print(fit$confusion_matrix)
  cat('\n```\n')
})

IHC recommendation

Confusion Matrix and Statistics

          Reference
Prediction  no yes
       no   15   9
       yes  24 340
                                          
               Accuracy : 0.9149          
                 95% CI : (0.8826, 0.9407)
    No Information Rate : 0.8995          
    P-Value [Acc > NIR] : 0.17722         
                                          
                  Kappa : 0.4327          
 Mcnemar's Test P-Value : 0.01481         
                                          
            Sensitivity : 0.38462         
            Specificity : 0.97421         
         Pos Pred Value : 0.62500         
         Neg Pred Value : 0.93407         
             Prevalence : 0.10052         
         Detection Rate : 0.03866         
   Detection Prevalence : 0.06186         
      Balanced Accuracy : 0.67941         
                                          
       'Positive' Class : no              
                                          

ISA recommendation

Confusion Matrix and Statistics

          Reference
Prediction  no yes
       no    6  12
       yes  35 335
                                          
               Accuracy : 0.8789          
                 95% CI : (0.8422, 0.9096)
    No Information Rate : 0.8943          
    P-Value [Acc > NIR] : 0.857931        
                                          
                  Kappa : 0.1485          
 Mcnemar's Test P-Value : 0.001332        
                                          
            Sensitivity : 0.14634         
            Specificity : 0.96542         
         Pos Pred Value : 0.33333         
         Neg Pred Value : 0.90541         
             Prevalence : 0.10567         
         Detection Rate : 0.01546         
   Detection Prevalence : 0.04639         
      Balanced Accuracy : 0.55588         
                                          
       'Positive' Class : no              
                                          

TS recommendation

Confusion Matrix and Statistics

          Reference
Prediction  no yes
       no   76  19
       yes  22 271
                                          
               Accuracy : 0.8943          
                 95% CI : (0.8594, 0.9231)
    No Information Rate : 0.7474          
    P-Value [Acc > NIR] : 2.984e-13       
                                          
                  Kappa : 0.7173          
 Mcnemar's Test P-Value : 0.7548          
                                          
            Sensitivity : 0.7755          
            Specificity : 0.9345          
         Pos Pred Value : 0.8000          
         Neg Pred Value : 0.9249          
             Prevalence : 0.2526          
         Detection Rate : 0.1959          
   Detection Prevalence : 0.2448          
      Balanced Accuracy : 0.8550          
                                          
       'Positive' Class : no              
                                          

ACO recommendation

Confusion Matrix and Statistics

          Reference
Prediction  no yes
       no   60  26
       yes  35 267
                                          
               Accuracy : 0.8428          
                 95% CI : (0.8027, 0.8776)
    No Information Rate : 0.7552          
    P-Value [Acc > NIR] : 1.763e-05       
                                          
                  Kappa : 0.5608          
 Mcnemar's Test P-Value : 0.3057          
                                          
            Sensitivity : 0.6316          
            Specificity : 0.9113          
         Pos Pred Value : 0.6977          
         Neg Pred Value : 0.8841          
             Prevalence : 0.2448          
         Detection Rate : 0.1546          
   Detection Prevalence : 0.2216          
      Balanced Accuracy : 0.7714          
                                          
       'Positive' Class : no              
                                          

ILS recommendation

Confusion Matrix and Statistics

          Reference
Prediction  no yes
       no    7  12
       yes  22 347
                                          
               Accuracy : 0.9124          
                 95% CI : (0.8797, 0.9386)
    No Information Rate : 0.9253          
    P-Value [Acc > NIR] : 0.8554          
                                          
                  Kappa : 0.2471          
 Mcnemar's Test P-Value : 0.1227          
                                          
            Sensitivity : 0.24138         
            Specificity : 0.96657         
         Pos Pred Value : 0.36842         
         Neg Pred Value : 0.94038         
             Prevalence : 0.07474         
         Detection Rate : 0.01804         
   Detection Prevalence : 0.04897         
      Balanced Accuracy : 0.60398         
                                          
       'Positive' Class : no              
                                          

IG recommendation

Confusion Matrix and Statistics

          Reference
Prediction  no yes
       no    7   8
       yes  12 361
                                          
               Accuracy : 0.9485          
                 95% CI : (0.9215, 0.9682)
    No Information Rate : 0.951           
    P-Value [Acc > NIR] : 0.6494          
                                          
                  Kappa : 0.3852          
 Mcnemar's Test P-Value : 0.5023          
                                          
            Sensitivity : 0.36842         
            Specificity : 0.97832         
         Pos Pred Value : 0.46667         
         Neg Pred Value : 0.96783         
             Prevalence : 0.04897         
         Detection Rate : 0.01804         
   Detection Prevalence : 0.03866         
      Balanced Accuracy : 0.67337         
                                          
       'Positive' Class : no