Introduction

Nếu Python có thư viện scikit-learn cho các thuật toán học máy thì R có caret. Cho đến hiện tại có 238 thuật toán học máy, trong đó có 191 thuật toán học giám sát (Supervised Machine Learning Algorithms) cho bài toán phân loại được hỗ trợ bởi caret. Danh sách các thuật toán này và các thông tin chi tiết khác có thể đọc ở đây.

Trong Post này chúng ta sẽ đồng thời huấn luyện và so sánh 125 (trong số 191) thuật toán phân loại cho bộ dữ liệu pima được lưu ở Center for Machine Learning and Intelligent Systems thuộc University of California - Irvine. Đây là bộ dữ liệu xuất hiện trong nhiều paper. Mô tả về bộ dữ liệu này có thể đọc ở đây. Chúng ta sử dụng một phiên bản của bộ dữ liệu này từ thư viện faraway:

# Clear workspace: 
rm(list = ls())

# Load pima data: 
library(faraway)
data("pima")

Rồi xem qua dữ liệu:

head(pima, 10)
##    pregnant glucose diastolic triceps insulin  bmi diabetes age test
## 1         6     148        72      35       0 33.6    0.627  50    1
## 2         1      85        66      29       0 26.6    0.351  31    0
## 3         8     183        64       0       0 23.3    0.672  32    1
## 4         1      89        66      23      94 28.1    0.167  21    0
## 5         0     137        40      35     168 43.1    2.288  33    1
## 6         5     116        74       0       0 25.6    0.201  30    0
## 7         3      78        50      32      88 31.0    0.248  26    1
## 8        10     115         0       0       0 35.3    0.134  29    0
## 9         2     197        70      45     543 30.5    0.158  53    1
## 10        8     125        96       0       0  0.0    0.232  54    1

Data Preparing

Không người nào có nồng độ insulin là zero được. Các chỉ số sinh lí khác cũng vậy. Do đó trước hết cần convert những giá trị 0 về NA như sau:

#===================
#  Prepare data
#===================

# Function replaces zero by NA value: 

replace_zero_na <- function(x) {
  
  x[x == 0] <- NA
  
  return(x)
}


# Use the function: 

library(dplyr)

pima %>% mutate(glucose = replace_zero_na(glucose), 
                diastolic = replace_zero_na(diastolic), 
                triceps = replace_zero_na(triceps), 
                insulin = replace_zero_na(insulin), 
                bmi = replace_zero_na(bmi), 
                diabetes = replace_zero_na(diabetes)) -> pima_new

Relabel lại cho cột biến test (1 có nghĩa là bị tiểu đường và ngược lại là 0):

# Relabel for test and rename: 

pima_new %>% 
  mutate(diabete_status = case_when(test == 1 ~ "Yes", TRUE ~ "No"), 
         diabete_status = as.factor(diabete_status), 
         test = NULL) -> pima_new

Xem qua dữ liệu:

head(pima_new, 10)
##    pregnant glucose diastolic triceps insulin  bmi diabetes age diabete_status
## 1         6     148        72      35      NA 33.6    0.627  50            Yes
## 2         1      85        66      29      NA 26.6    0.351  31             No
## 3         8     183        64      NA      NA 23.3    0.672  32            Yes
## 4         1      89        66      23      94 28.1    0.167  21             No
## 5         0     137        40      35     168 43.1    2.288  33            Yes
## 6         5     116        74      NA      NA 25.6    0.201  30             No
## 7         3      78        50      32      88 31.0    0.248  26            Yes
## 8        10     115        NA      NA      NA 35.3    0.134  29             No
## 9         2     197        70      45     543 30.5    0.158  53            Yes
## 10        8     125        96      NA      NA   NA    0.232  54            Yes

Sử dụng thuật toán Random Forest để “lấp đầy” (data imputation) những quan sát là NA:

# Impute NA values by Random Forest: 

missRanger::missRanger(pima_new, seed = 29) -> pima_RFimputed
## 
## Missing value imputation by random forests
## 
##   Variables to impute:       glucose, diastolic, triceps, insulin, bmi
##   Variables used to impute:  pregnant, glucose, diastolic, triceps, insulin, bmi, diabetes, age, diabete_status
## iter 1:  .....
## iter 2:  .....
## iter 3:  .....

Dữ liệu sau khi được xử lí missing data:

head(pima_RFimputed, 10)
##    pregnant glucose diastolic  triceps   insulin      bmi diabetes age
## 1         6     148   72.0000 35.00000 186.62733 33.60000    0.627  50
## 2         1      85   66.0000 29.00000  66.11457 26.60000    0.351  31
## 3         8     183   64.0000 24.11980 219.50917 23.30000    0.672  32
## 4         1      89   66.0000 23.00000  94.00000 28.10000    0.167  21
## 5         0     137   40.0000 35.00000 168.00000 43.10000    2.288  33
## 6         5     116   74.0000 24.84307 106.75267 25.60000    0.201  30
## 7         3      78   50.0000 32.00000  88.00000 31.00000    0.248  26
## 8        10     115   74.2445 35.30387 170.13053 35.30000    0.134  29
## 9         2     197   70.0000 45.00000 543.00000 30.50000    0.158  53
## 10        8     125   96.0000 29.62103 222.13257 31.48421    0.232  54
##    diabete_status
## 1             Yes
## 2              No
## 3             Yes
## 4              No
## 5             Yes
## 6              No
## 7             Yes
## 8              No
## 9             Yes
## 10            Yes

Chuẩn hóa 0-1 cho tất cả các features:

# Scale data: 

pima_RFimputed %>% 
  mutate_if(is.numeric, function(x) {(x - min(x)) / (max(x) - min(x))}) -> pima_RFimputed

Install Required Packages

Để huấn luyện 125 thuật toán đã nói ở trên thì việc tiếp theo là phải cài đặt những thư viện (package) cần thiết tương ứng với những thuật toán này. Trước hết lấy ra danh sách các thuật toán phân loại được hỗ trợ bởi thư viện caret cùng package tương ứng:

#============================
#  Compare/chek many models
#============================

library(caret)

getModelInfo() -> models

models %>% names() -> all_models

library(stringr)

get_model_info <- function(model_name) {
  
  models[[model_name]]$label -> full_name
  
  models[[model_name]]$library %>% str_flatten(collapse = "|") -> libr
  
  models[[model_name]]$type %>% str_flatten(collapse = "|") -> reg_cla
  
  return(data.frame(model = model_name, full_name = full_name, library = libr, type = reg_cla))
  
}

lapply(all_models, get_model_info) -> modelList

do.call("bind_rows", modelList) -> df_models

df_models %>% 
  filter(str_detect(type, "Classification")) -> class_models

Danh sách (một số) thuật toán phân loại này:

library(kableExtra) # For presenting table. 

class_models %>% 
  head(10) %>% 
  select(2, 4) %>% 
  kbl(caption = "Table 1: Some Machine Learning Algorithms for Classification", 
      col.names = c("Model Name", "Type"), 
      escape = TRUE) %>%
  kable_classic(full_width = FALSE, html_font = "Cambria")
Table 1: Some Machine Learning Algorithms for Classification
Model Name Type
Boosted Classification Trees Classification
Bagged AdaBoost Classification
AdaBoost.M1 Classification
AdaBoost Classification Trees Classification
Adaptive Mixture Discriminant Analysis Classification
Model Averaged Neural Network Classification|Regression
Naive Bayes Classifier with Attribute Weighting Classification
Tree Augmented Naive Bayes Classifier with Attribute Weighting Classification
Bagged Model Regression|Classification
Bagged MARS Regression|Classification

Một số thuật toán như Model Averaged Neural Network có thể sử dụng cho cả bài toán phân loại và hồi quy như ta thấy ở Table 1. So sánh các thư viện cần phải cài đặt với những thư viện đã được cài:

# R packages installed: 
installed.packages() -> libr_installed

libr_installed %>% as.data.frame() %>% row.names() -> my_Rpackages

# R packages required by caret: 
class_models$library %>% 
  str_split(pattern = "\\|", simplify = TRUE) %>% 
  as.vector() %>% 
  unique() -> required_caret

# R package must be installed: 
required_caret[str_detect(required_caret, "")] -> required_caret

required_caret[!required_caret %in% my_Rpackages] -> new_packages_inst

Rồi cài đặt các packages còn thiếu:

install.packages(new_packages_inst, dependencies = TRUE)

Trong số các packages trên có một số packages đã bị gỡ phải CRAN (tức là phải cài đặt theo một cách thức khác). Các thuật toán tương ứng với những package bị gỡ này:

removed_lib <- c("adaptDA", "CHAID", "elmNN", "gpls", 
                 "logicFS", "FCNN4R", "mxnet", "randomGLM", "rrlda", "vbmp")

Do vậy mà các thuật toán còn lại là:

class_models %>% 
  filter(!library %in% removed_lib) %>% 
  pull(model) -> my_models

my_models[!my_models %in% c("gaussprLinear", "gaussprRadial", "mlpSGD", "rbf")] -> my_models

Comparision Method

Để đánh giá - so sánh cũng như tìm tham số tối ưu của thuật toán chúng ta sử dụng cách tiếp cận gọi là K-Fold Cross Validation. Dưới đây là minh họa cho quán trình này với K = 10:

Ở đây dữ liệu ban đầu sẽ được chia thành 10 phần (bằng hoặc xấp xỉ bằng nhau). Trong đó 9 phần để huấn luyện mô hình và 1 phần còn lại được sử dụng để đánh giá ngược lại mô hình đã có căn cứ vào tiêu chí E nào đó. Nếu là bài toán phân loại nhị phân thì E thường là ROC-AUC. Với K = 10 thì sẽ có 10 giá trị E và các mô hình cũng như tham số tối ưu sẽ được lựa chọn bằng cách so sánh giá trị trung bình của E.

Nếu bộ dữ liệu mà lớn thì K = 10 là giá trị thường được lựa chọn. Ngược lại thì K = 4 (tương ứng với 75% training data và 25% testing data) hoặc K = 5 (80% training data và 20% testing data) thường được lựa chọn. Trong post này chúng ta sẽ chọn K = 4 (number = 4) và lặp lại quá trình 2 lần (repeats = 2):

# Conditions for training and comparing: 

fitControl1 <- trainControl(method = "repeatedcv", 
                            number = 4,
                            repeats = 2,
                            classProbs = TRUE, 
                            allowParallel = FALSE, 
                            summaryFunction = twoClassSummary)

Như vậy nếu một thuật toán mà, giả sử, có 3 tham số, mỗi tham số chọn 10 ứng viên để train, thì tất cả sẽ có 10×10×10×4*2 + 1 = 8001 “mô hình con” được huấn luyện. 8001 vì sau khi có tham số tối ưu rồi mô hình cuối cùng sẽ được refit (tức là huấn luyện lại).

Để rút ngắn thời gian huấn luyện, chúng ta chỉ chọn 3 ứng viên cho một tham số của mô hình học máy để huấn luyện, thể hiện qua tuneLength = 3 của hàm dưới đây:

#======================
#  Check caret models
#======================
check_caret <- function(method_name) {
  
  set.seed(1)
  
  system.time(train(diabete_status ~.,
                    data = pima_RFimputed,
                    method = method_name, 
                    metric = "ROC",
                    tuneLength = 3, 
                    trControl = fitControl1) -> caret_model) -> time_train
  
  caret_model$results %>% 
    slice(which.max(ROC)) %>% 
    mutate(method = method_name) %>% 
    select(method, max_ROC = ROC, Sens, Spec) %>% 
    mutate(train_time = time_train[3]) %>% 
    return()
  
}

Nếu đủ kiên trì và máy tính đủ mạnh thì có thể, ví dụ, chọn tuneLength = 15. Có thể ước lượng rằng với lựa chọn này thì thời gian huấn luyện sẽ tăng lên 5 lần.

Trong quá trình huấn luyện có thể một thuật toán nào đó không thực hiện được (lí do có thể là còn thiếu dependencies hay chưa được update thư viện chẳng hạn). Tạm thời chúng ta “bỏ qua” các thuật toán này bằng cách viết hàm check_caret_tryCatch() như sau:

# Training all models selected: 

check_caret_tryCatch <- function(j) {tryCatch(check_caret(my_models[j]), error = function(e) {NULL})}

Huấn luyện các thuận toán đồng thời theo dõi tiến độ. Lưu ý rằng quá trình huấn luyện này có thể mất từ 8h đến vài ba ngày tùy cấu hình máy tính:

n_models <- length(my_models)

pb <- txtProgressBar(min = 0, max = n_models, style = 3)

space_df <- data.frame()

system.time(
  
  for(j in 1:n_models) {
    
    print(paste0(j, "_th:"))
    
    mess <- paste("Method", my_models[j], "is training...")
    
    print(mess)
    
    df_results <- check_caret_tryCatch(j)
    
    space_df <- bind_rows(space_df, df_results)
    
    print(df_results)
    
    setTxtProgressBar(pb, j)
  }
)

close(pb)

Lưu lại kết quả:

write.csv(space_df, "space_df.csv", row.names = FALSE)

Main Results

Nếu chọn ROC-AUC là tiêu chí đánh giá và lựa chọn thuật toán thì gamboost, ranger, và bam là những thuật toán có xếp hạng cao nhất và các kết quả chênh lệch nhau không đáng kể. Tuy nhiên thời gian huân luyện thì lại có chênh lệch rất lớn. Tạo ra kết quả tương tự nhau nhưng gamboost chỉ mất 5.72 giây nhưng bam thì cần 1717.86 giây (Table 2):

space_df %>% 
  arrange(-max_ROC) %>% 
  mutate(rank = 1:nrow(.)) %>% 
  mutate_if(is.numeric, function(x) {round(x, 3)}) %>% 
  kbl(caption = "Table 2: Model Performance in Descending Order of ROC", escape = TRUE) %>%
  kable_classic(full_width = FALSE, html_font = "Cambria")
Table 2: Model Performance in Descending Order of ROC
method max_ROC Sens Spec train_time rank
gamboost 0.861 0.864 0.642 5.72 1
ranger 0.860 0.857 0.655 14.64 2
bam 0.859 0.854 0.674 1717.86 3
nnet 0.859 0.865 0.618 4.69 4
gamSpline 0.859 0.856 0.632 5.58 5
rf 0.858 0.854 0.662 9.20 6
parRF 0.858 0.852 0.662 7.72 7
gam 0.858 0.849 0.675 41.02 8
cforest 0.858 0.861 0.644 70.85 9
extraTrees 0.858 0.852 0.666 24.59 10
Rborist 0.857 0.853 0.655 270.75 11
ORFpls 0.857 0.846 0.655 473.31 12
xgbDART 0.857 0.866 0.619 4367.11 13
bagFDAGCV 0.857 0.854 0.649 10.09 14
rotationForestCp 0.857 0.868 0.660 20.22 15
ORFridge 0.856 0.832 0.692 999.31 16
ORFsvm 0.856 0.863 0.627 2340.72 17
bagFDA 0.856 0.856 0.625 27.53 18
ada 0.856 0.842 0.670 71.13 19
loclda 0.856 0.890 0.543 28.56 20
ORFlog 0.855 0.834 0.694 1054.09 21
rotationForest 0.855 0.854 0.653 11.03 22
avNNet 0.855 0.821 0.707 21.25 23
xgbTree 0.855 0.851 0.662 109.29 24
gamLoess 0.855 0.848 0.649 3.11 25
dwdRadial 0.854 0.864 0.595 12.89 26
gbm 0.854 0.859 0.634 2.77 27
gaussprPoly 0.853 0.867 0.582 108.61 28
dwdPoly 0.853 0.875 0.590 90.92 29
wsrf 0.852 0.848 0.644 7.94 30
earth 0.852 0.852 0.649 2.86 31
gcvEarth 0.852 0.852 0.649 1.36 32
AdaBoost.M1 0.851 0.857 0.625 3395.17 33
RRF 0.851 0.842 0.679 185.86 34
RRFglobal 0.851 0.839 0.681 30.85 35
fda 0.851 0.852 0.644 1.47 36
nodeHarvest 0.850 0.883 0.543 205.51 37
svmPoly 0.849 0.888 0.543 25.56 38
pcaNNet 0.849 0.849 0.627 5.32 39
mlp 0.847 0.835 0.659 5.73 40
mlpML 0.847 0.835 0.659 6.17 41
naive_bayes 0.847 0.787 0.711 1.06 42
nb 0.847 0.788 0.711 4.90 43
svmRadialSigma 0.847 0.879 0.576 12.50 44
mlpWeightDecay 0.846 0.861 0.619 15.25 45
mlpWeightDecayML 0.846 0.861 0.619 15.69 46
blackboost 0.846 0.879 0.549 28.14 47
AdaBag 0.844 0.797 0.694 1189.19 48
treebag 0.844 0.828 0.660 2.88 49
svmRadial 0.843 0.881 0.550 4.84 50
svmRadialWeights 0.843 0.883 0.545 10.30 51
svmRadialCost 0.843 0.880 0.550 4.86 52
spls 0.842 0.877 0.552 4.46 53
kernelpls 0.842 0.876 0.556 1.12 54
pls 0.842 0.876 0.556 1.17 55
simpls 0.842 0.876 0.556 1.33 56
widekernelpls 0.842 0.876 0.556 25.18 57
regLogistic 0.842 0.881 0.547 3.40 58
glmnet 0.842 0.884 0.543 2.28 59
multinom 0.842 0.875 0.554 1.07 60
mda 0.842 0.872 0.590 2.03 61
pda2 0.842 0.876 0.556 1.25 62
plr 0.842 0.874 0.560 1.48 63
ordinalNet 0.842 0.884 0.534 1159.87 64
glmStepAIC 0.841 0.878 0.556 2.32 65
bayesglm 0.841 0.874 0.560 1.98 66
dwdLinear 0.841 0.878 0.567 2.17 67
glmboost 0.841 0.876 0.560 1.59 68
glm 0.841 0.874 0.563 0.99 69
vglmAdjCat 0.841 0.874 0.563 2.95 70
vglmContRatio 0.841 0.874 0.563 12.13 71
vglmCumulative 0.841 0.874 0.563 10.44 72
pda 0.841 0.876 0.558 1.20 73
sparseLDA 0.841 0.000 1.000 1.92 74
sda 0.841 0.876 0.563 1.35 75
hdda 0.841 0.876 0.556 1.45 76
lda 0.841 0.876 0.556 0.99 77
lda2 0.841 0.876 0.556 0.97 78
rda 0.841 0.876 0.556 7.42 79
monmlp 0.840 0.857 0.623 59.94 80
msaenet 0.840 0.819 0.666 79.53 81
svmLinearWeights 0.840 0.764 0.735 5.04 82
svmLinear 0.839 0.882 0.543 2.84 83
svmLinear2 0.839 0.878 0.556 2.14 84
C5.0 0.839 0.839 0.625 4.65 85
pam 0.835 0.929 0.414 1.08 86
xgbLinear 0.834 0.826 0.668 117.73 87
adaboost 0.832 0.834 0.662 118.41 88
LMT 0.831 0.864 0.567 6.09 89
knn 0.825 0.830 0.604 1.15 90
slda 0.822 0.858 0.552 1.14 91
PART 0.816 0.813 0.674 3.51 92
qda 0.816 0.839 0.575 0.93 93
ctree2 0.815 0.790 0.698 2.66 94
ctree 0.814 0.833 0.610 1.53 95
Linda 0.811 0.754 0.707 3.43 96
kknn 0.811 0.811 0.601 7.49 97
hda 0.809 0.867 0.511 26.33 98
xyf 0.801 0.876 0.552 16.91 99
rbfDDA 0.800 0.848 0.616 21.88 100
C5.0Tree 0.800 0.803 0.631 1.08 101
LogitBoost 0.798 0.811 0.621 1.95 102
sdwd 0.794 1.000 0.000 1.79 103
stepLDA 0.794 0.879 0.500 9.47 104
stepQDA 0.794 0.887 0.487 9.39 105
J48 0.793 0.815 0.623 7.99 106
rpart2 0.792 0.800 0.638 1.44 107
rpart1SE 0.787 0.809 0.618 1.22 108
rpart 0.775 0.827 0.586 1.40 109
JRip 0.769 0.828 0.700 16.94 110
C5.0Rules 0.768 0.828 0.625 1.07 111
evtree 0.767 0.841 0.618 126.20 112
QdaCov 0.764 0.683 0.743 2.60 113
rmda 0.745 0.554 0.815 16.53 114
dnn 0.731 1.000 0.000 51.50 115
PRIM 0.725 0.777 0.563 260.00 116
OneR 0.638 0.822 0.453 1.21 117
mlpKerasDecay 0.627 0.517 0.521 231.19 118
mlpKerasDecayCost 0.626 0.254 0.746 572.08 119
mlpKerasDropoutCost 0.598 0.214 0.787 577.20 120
mlpKerasDropout 0.526 0.750 0.252 195.55 121
null 0.500 1.000 0.000 0.84 122
bagEarth 0.195 0.102 0.532 42.21 123
plsRglm 0.168 0.865 0.558 8.37 124
bagEarthGCV 0.143 0.147 0.353 12.43 125

Final Notes

Kết quả của 125 thuật toán học máy cho bộ dữ liệu pima dẫn đến một số ngụ ý sau:

  • Có rất nhiều bài viết kiểu như Top 10 Machine Learning Algorithms, LightGBM vs XGBOOST – Which algorithm is better, Kaggle Competitions Top Classification Algorithm nói về thuật toán nào là “tốt nhất”. Tuy nhiên kết quả từ Table 2 cho thấy chỉ có right algorithm for right data mà thôi (nhại lại câu thành ngữ right tool for the job).

  • Thông thường các dự án data science bị giới hạn bởi thời gian và có deadline rõ ràng. Mặt khác chúng ta hoàn toàn không thể biết trước thuật toán nào là tốt nhất với bộ dữ liệu của dự án. Những kết quả ở trên có thể cho một gợi ý: huấn luyện đồng thời chỉ một số (hoặc tất cả nếu đủ thời gian) mô hình lựa chọn trước xem kết quả ra sao để chọn ra mô hình có vẻ ổn nhất. Kế tiếp lại tinh chỉnh sâu hơn nữa các tham số cho mô hình ổn nhất đó. Chẳng hạn từ Table 2 thì gamboost (Gradient Boosting with Componentwise Smoothing Splines) là thuật toán có ROC cao nhất trên bộ dữ liệu pima. Chúng ta có thể tinh chỉnh cẩn thận hơn nữa các tham số của thuật toán này để có ROC cao hơn.

  • Table 2 cũng cho thấy có một số thuật toán mà thời gian huấn luyện là rất lớn trong khi model perfornamce thì không được tốt như kì vọng, thậm chí là thua kém hơn so với những thuật toán đơn giản khác. Những thuật toán này có lẽ không nên được lựa chọn khi mà nguồn lực thời gian cho dự án là không nhiều hoặc deadline là sắp hết.

An Extension

Ngoài tinh chỉnh tham số cho một thuật toán được chọn thì chúng ta cũng còn một lựa chọn khác: sử dụng Ensemble Learning Methods. Thay vì sử dụng For Loop như trên chúng ta có thể sử dụng hàm caretList() của thư viện caretEnsemble để huấn luyện đồng thời nhiều thuật toán. Dưới đây là R codes cho cách tiếp cận này (thời gian huấn luyện khá nhiều nên không hiển thị kết quả):

# Algorithms with training time <= 300s: 

df_space %>% 
  filter(train_time <= 300) %>% 
  pull(method) -> some_models_selected

#===================
#     Ensemble
#===================

set.seed(1)
number <- 5
repeats <- 3

library(caret)
control <- trainControl(method = "repeatedcv", 
                        number = number , 
                        repeats = repeats, 
                        classProbs = TRUE, 
                        allowParallel = FALSE, 
                        summaryFunction = twoClassSummary, 
                        savePredictions = "final", 
                        index = createResample(pima_RFimputed$diabete_status, repeats*number))


# Train these ML Models: 

library(caretEnsemble)
set.seed(1)
system.time(model_list1 <- caretList(diabete_status ~., 
                                     data = pima_RFimputed,
                                     trControl = control, 
                                     metric = "ROC", 
                                     methodList = some_models_selected))

# Extract results: 

list_of_results <- lapply(some_models_selected, function(x) {model_list1[[x]]$resample})

# Convert to data frame: 
total_df <- do.call("bind_rows", list_of_results)
total_df %>% mutate(Model = lapply(some_models_selected, function(x) {rep(x, number*repeats)}) %>% unlist()) -> total_df

# Save results: 
write.csv(total_df, "total_df.csv", row.names = FALSE)

# Ensemble 1: M
greedy_ensemble <- caretEnsemble(model_list1, 
                                 metric = "ROC",
                                 trControl = control)
# Draft Results: 
summary(greedy_ensemble)

# Add Ensemble 1 Model: 
total_df_en <- bind_rows(total_df, greedy_ensemble$ens_model$resample %>% mutate(Model = "ensemble"))

# Ensemble 2: 
gbm_ensemble <- caretStack(model_list1, 
                           method = "gbm",                                  
                           metric = "ROC",
                           verbose = FALSE, 
                           trControl = control)

# Add Ensemble 2 Model: 
total_df_en2 <- bind_rows(total_df_en, gbm_ensemble$ens_model$resample %>% mutate(Model = "ensemble2"))

# Average ROC based on 15 samples: 

total_df_en2 %>% 
  group_by(Model) %>% 
  summarise(avg_roc = mean(ROC)) %>% 
  ungroup() %>% 
  arrange(-avg_roc) %>% 
  mutate(rank = 1:nrow(.)) %>% 
  mutate_if(is.numeric, function(x) {round(x, 3)}) -> df_with_en2


df_with_en2 %>% 
  filter(stringr::str_detect(Model, "ensemble"))

# Rank = 14 for ensemble 1, = 14 for ensemble 2: 

df_with_en2 %>% 
  mutate_if(is.numeric, function(x) {round(x, 3)}) %>% 
  kbl(caption = "Table 3: Ensemble Approach by ROC", escape = TRUE) %>%
  kable_classic(full_width = FALSE, html_font = "Cambria")