
Introduction
Nếu Python có thư viện scikit-learn cho các thuật toán học máy thì R có caret. Cho đến hiện tại có 238 thuật toán học máy, trong đó có 191 thuật toán học giám sát (Supervised Machine Learning Algorithms) cho bài toán phân loại được hỗ trợ bởi caret. Danh sách các thuật toán này và các thông tin chi tiết khác có thể đọc ở đây.
Trong Post này chúng ta sẽ đồng thời huấn luyện và so sánh 125 (trong số 191) thuật toán phân loại cho bộ dữ liệu pima được lưu ở Center for Machine Learning and Intelligent Systems thuộc University of California - Irvine. Đây là bộ dữ liệu xuất hiện trong nhiều paper. Mô tả về bộ dữ liệu này có thể đọc ở đây. Chúng ta sử dụng một phiên bản của bộ dữ liệu này từ thư viện faraway:
# Clear workspace:
rm(list = ls())
# Load pima data:
library(faraway)
data("pima")
Rồi xem qua dữ liệu:
## pregnant glucose diastolic triceps insulin bmi diabetes age test
## 1 6 148 72 35 0 33.6 0.627 50 1
## 2 1 85 66 29 0 26.6 0.351 31 0
## 3 8 183 64 0 0 23.3 0.672 32 1
## 4 1 89 66 23 94 28.1 0.167 21 0
## 5 0 137 40 35 168 43.1 2.288 33 1
## 6 5 116 74 0 0 25.6 0.201 30 0
## 7 3 78 50 32 88 31.0 0.248 26 1
## 8 10 115 0 0 0 35.3 0.134 29 0
## 9 2 197 70 45 543 30.5 0.158 53 1
## 10 8 125 96 0 0 0.0 0.232 54 1
Data Preparing
Không người nào có nồng độ insulin là zero được. Các chỉ số sinh lí khác cũng vậy. Do đó trước hết cần convert những giá trị 0 về NA như sau:
#===================
# Prepare data
#===================
# Function replaces zero by NA value:
replace_zero_na <- function(x) {
x[x == 0] <- NA
return(x)
}
# Use the function:
library(dplyr)
pima %>% mutate(glucose = replace_zero_na(glucose),
diastolic = replace_zero_na(diastolic),
triceps = replace_zero_na(triceps),
insulin = replace_zero_na(insulin),
bmi = replace_zero_na(bmi),
diabetes = replace_zero_na(diabetes)) -> pima_new
Relabel lại cho cột biến test (1 có nghĩa là bị tiểu đường và ngược lại là 0):
# Relabel for test and rename:
pima_new %>%
mutate(diabete_status = case_when(test == 1 ~ "Yes", TRUE ~ "No"),
diabete_status = as.factor(diabete_status),
test = NULL) -> pima_new
Xem qua dữ liệu:
## pregnant glucose diastolic triceps insulin bmi diabetes age diabete_status
## 1 6 148 72 35 NA 33.6 0.627 50 Yes
## 2 1 85 66 29 NA 26.6 0.351 31 No
## 3 8 183 64 NA NA 23.3 0.672 32 Yes
## 4 1 89 66 23 94 28.1 0.167 21 No
## 5 0 137 40 35 168 43.1 2.288 33 Yes
## 6 5 116 74 NA NA 25.6 0.201 30 No
## 7 3 78 50 32 88 31.0 0.248 26 Yes
## 8 10 115 NA NA NA 35.3 0.134 29 No
## 9 2 197 70 45 543 30.5 0.158 53 Yes
## 10 8 125 96 NA NA NA 0.232 54 Yes
Sử dụng thuật toán Random Forest để “lấp đầy” (data imputation) những quan sát là NA:
# Impute NA values by Random Forest:
missRanger::missRanger(pima_new, seed = 29) -> pima_RFimputed
##
## Missing value imputation by random forests
##
## Variables to impute: glucose, diastolic, triceps, insulin, bmi
## Variables used to impute: pregnant, glucose, diastolic, triceps, insulin, bmi, diabetes, age, diabete_status
## iter 1: .....
## iter 2: .....
## iter 3: .....
Dữ liệu sau khi được xử lí missing data:
## pregnant glucose diastolic triceps insulin bmi diabetes age
## 1 6 148 72.0000 35.00000 186.62733 33.60000 0.627 50
## 2 1 85 66.0000 29.00000 66.11457 26.60000 0.351 31
## 3 8 183 64.0000 24.11980 219.50917 23.30000 0.672 32
## 4 1 89 66.0000 23.00000 94.00000 28.10000 0.167 21
## 5 0 137 40.0000 35.00000 168.00000 43.10000 2.288 33
## 6 5 116 74.0000 24.84307 106.75267 25.60000 0.201 30
## 7 3 78 50.0000 32.00000 88.00000 31.00000 0.248 26
## 8 10 115 74.2445 35.30387 170.13053 35.30000 0.134 29
## 9 2 197 70.0000 45.00000 543.00000 30.50000 0.158 53
## 10 8 125 96.0000 29.62103 222.13257 31.48421 0.232 54
## diabete_status
## 1 Yes
## 2 No
## 3 Yes
## 4 No
## 5 Yes
## 6 No
## 7 Yes
## 8 No
## 9 Yes
## 10 Yes
Chuẩn hóa 0-1 cho tất cả các features:
# Scale data:
pima_RFimputed %>%
mutate_if(is.numeric, function(x) {(x - min(x)) / (max(x) - min(x))}) -> pima_RFimputed
Install Required Packages
Để huấn luyện 125 thuật toán đã nói ở trên thì việc tiếp theo là phải cài đặt những thư viện (package) cần thiết tương ứng với những thuật toán này. Trước hết lấy ra danh sách các thuật toán phân loại được hỗ trợ bởi thư viện caret cùng package tương ứng:
#============================
# Compare/chek many models
#============================
library(caret)
getModelInfo() -> models
models %>% names() -> all_models
library(stringr)
get_model_info <- function(model_name) {
models[[model_name]]$label -> full_name
models[[model_name]]$library %>% str_flatten(collapse = "|") -> libr
models[[model_name]]$type %>% str_flatten(collapse = "|") -> reg_cla
return(data.frame(model = model_name, full_name = full_name, library = libr, type = reg_cla))
}
lapply(all_models, get_model_info) -> modelList
do.call("bind_rows", modelList) -> df_models
df_models %>%
filter(str_detect(type, "Classification")) -> class_models
Danh sách (một số) thuật toán phân loại này:
library(kableExtra) # For presenting table.
class_models %>%
head(10) %>%
select(2, 4) %>%
kbl(caption = "Table 1: Some Machine Learning Algorithms for Classification",
col.names = c("Model Name", "Type"),
escape = TRUE) %>%
kable_classic(full_width = FALSE, html_font = "Cambria")
Table 1: Some Machine Learning Algorithms for Classification
Model Name
|
Type
|
Boosted Classification Trees
|
Classification
|
Bagged AdaBoost
|
Classification
|
AdaBoost.M1
|
Classification
|
AdaBoost Classification Trees
|
Classification
|
Adaptive Mixture Discriminant Analysis
|
Classification
|
Model Averaged Neural Network
|
Classification|Regression
|
Naive Bayes Classifier with Attribute Weighting
|
Classification
|
Tree Augmented Naive Bayes Classifier with Attribute Weighting
|
Classification
|
Bagged Model
|
Regression|Classification
|
Bagged MARS
|
Regression|Classification
|
Một số thuật toán như Model Averaged Neural Network có thể sử dụng cho cả bài toán phân loại và hồi quy như ta thấy ở Table 1. So sánh các thư viện cần phải cài đặt với những thư viện đã được cài:
# R packages installed:
installed.packages() -> libr_installed
libr_installed %>% as.data.frame() %>% row.names() -> my_Rpackages
# R packages required by caret:
class_models$library %>%
str_split(pattern = "\\|", simplify = TRUE) %>%
as.vector() %>%
unique() -> required_caret
# R package must be installed:
required_caret[str_detect(required_caret, "")] -> required_caret
required_caret[!required_caret %in% my_Rpackages] -> new_packages_inst
Rồi cài đặt các packages còn thiếu:
install.packages(new_packages_inst, dependencies = TRUE)
Trong số các packages trên có một số packages đã bị gỡ phải CRAN (tức là phải cài đặt theo một cách thức khác). Các thuật toán tương ứng với những package bị gỡ này:
removed_lib <- c("adaptDA", "CHAID", "elmNN", "gpls",
"logicFS", "FCNN4R", "mxnet", "randomGLM", "rrlda", "vbmp")
Do vậy mà các thuật toán còn lại là:
class_models %>%
filter(!library %in% removed_lib) %>%
pull(model) -> my_models
my_models[!my_models %in% c("gaussprLinear", "gaussprRadial", "mlpSGD", "rbf")] -> my_models
Comparision Method
Để đánh giá - so sánh cũng như tìm tham số tối ưu của thuật toán chúng ta sử dụng cách tiếp cận gọi là K-Fold Cross Validation. Dưới đây là minh họa cho quán trình này với K = 10:

Ở đây dữ liệu ban đầu sẽ được chia thành 10 phần (bằng hoặc xấp xỉ bằng nhau). Trong đó 9 phần để huấn luyện mô hình và 1 phần còn lại được sử dụng để đánh giá ngược lại mô hình đã có căn cứ vào tiêu chí E nào đó. Nếu là bài toán phân loại nhị phân thì E thường là ROC-AUC. Với K = 10 thì sẽ có 10 giá trị E và các mô hình cũng như tham số tối ưu sẽ được lựa chọn bằng cách so sánh giá trị trung bình của E.
Nếu bộ dữ liệu mà lớn thì K = 10 là giá trị thường được lựa chọn. Ngược lại thì K = 4 (tương ứng với 75% training data và 25% testing data) hoặc K = 5 (80% training data và 20% testing data) thường được lựa chọn. Trong post này chúng ta sẽ chọn K = 4 (number = 4
) và lặp lại quá trình 2 lần (repeats = 2
):
# Conditions for training and comparing:
fitControl1 <- trainControl(method = "repeatedcv",
number = 4,
repeats = 2,
classProbs = TRUE,
allowParallel = FALSE,
summaryFunction = twoClassSummary)
Như vậy nếu một thuật toán mà, giả sử, có 3 tham số, mỗi tham số chọn 10 ứng viên để train, thì tất cả sẽ có 10×10×10×4*2 + 1 = 8001 “mô hình con” được huấn luyện. 8001 vì sau khi có tham số tối ưu rồi mô hình cuối cùng sẽ được refit (tức là huấn luyện lại).
Để rút ngắn thời gian huấn luyện, chúng ta chỉ chọn 3 ứng viên cho một tham số của mô hình học máy để huấn luyện, thể hiện qua tuneLength = 3
của hàm dưới đây:
#======================
# Check caret models
#======================
check_caret <- function(method_name) {
set.seed(1)
system.time(train(diabete_status ~.,
data = pima_RFimputed,
method = method_name,
metric = "ROC",
tuneLength = 3,
trControl = fitControl1) -> caret_model) -> time_train
caret_model$results %>%
slice(which.max(ROC)) %>%
mutate(method = method_name) %>%
select(method, max_ROC = ROC, Sens, Spec) %>%
mutate(train_time = time_train[3]) %>%
return()
}
Nếu đủ kiên trì và máy tính đủ mạnh thì có thể, ví dụ, chọn tuneLength = 15
. Có thể ước lượng rằng với lựa chọn này thì thời gian huấn luyện sẽ tăng lên 5 lần.
Trong quá trình huấn luyện có thể một thuật toán nào đó không thực hiện được (lí do có thể là còn thiếu dependencies hay chưa được update thư viện chẳng hạn). Tạm thời chúng ta “bỏ qua” các thuật toán này bằng cách viết hàm check_caret_tryCatch()
như sau:
# Training all models selected:
check_caret_tryCatch <- function(j) {tryCatch(check_caret(my_models[j]), error = function(e) {NULL})}
Huấn luyện các thuận toán đồng thời theo dõi tiến độ. Lưu ý rằng quá trình huấn luyện này có thể mất từ 8h đến vài ba ngày tùy cấu hình máy tính:
n_models <- length(my_models)
pb <- txtProgressBar(min = 0, max = n_models, style = 3)
space_df <- data.frame()
system.time(
for(j in 1:n_models) {
print(paste0(j, "_th:"))
mess <- paste("Method", my_models[j], "is training...")
print(mess)
df_results <- check_caret_tryCatch(j)
space_df <- bind_rows(space_df, df_results)
print(df_results)
setTxtProgressBar(pb, j)
}
)
close(pb)
Lưu lại kết quả:
write.csv(space_df, "space_df.csv", row.names = FALSE)
Main Results
Nếu chọn ROC-AUC là tiêu chí đánh giá và lựa chọn thuật toán thì gamboost, ranger, và bam là những thuật toán có xếp hạng cao nhất và các kết quả chênh lệch nhau không đáng kể. Tuy nhiên thời gian huân luyện thì lại có chênh lệch rất lớn. Tạo ra kết quả tương tự nhau nhưng gamboost chỉ mất 5.72 giây nhưng bam thì cần 1717.86 giây (Table 2):
space_df %>%
arrange(-max_ROC) %>%
mutate(rank = 1:nrow(.)) %>%
mutate_if(is.numeric, function(x) {round(x, 3)}) %>%
kbl(caption = "Table 2: Model Performance in Descending Order of ROC", escape = TRUE) %>%
kable_classic(full_width = FALSE, html_font = "Cambria")
Table 2: Model Performance in Descending Order of ROC
method
|
max_ROC
|
Sens
|
Spec
|
train_time
|
rank
|
gamboost
|
0.861
|
0.864
|
0.642
|
5.72
|
1
|
ranger
|
0.860
|
0.857
|
0.655
|
14.64
|
2
|
bam
|
0.859
|
0.854
|
0.674
|
1717.86
|
3
|
nnet
|
0.859
|
0.865
|
0.618
|
4.69
|
4
|
gamSpline
|
0.859
|
0.856
|
0.632
|
5.58
|
5
|
rf
|
0.858
|
0.854
|
0.662
|
9.20
|
6
|
parRF
|
0.858
|
0.852
|
0.662
|
7.72
|
7
|
gam
|
0.858
|
0.849
|
0.675
|
41.02
|
8
|
cforest
|
0.858
|
0.861
|
0.644
|
70.85
|
9
|
extraTrees
|
0.858
|
0.852
|
0.666
|
24.59
|
10
|
Rborist
|
0.857
|
0.853
|
0.655
|
270.75
|
11
|
ORFpls
|
0.857
|
0.846
|
0.655
|
473.31
|
12
|
xgbDART
|
0.857
|
0.866
|
0.619
|
4367.11
|
13
|
bagFDAGCV
|
0.857
|
0.854
|
0.649
|
10.09
|
14
|
rotationForestCp
|
0.857
|
0.868
|
0.660
|
20.22
|
15
|
ORFridge
|
0.856
|
0.832
|
0.692
|
999.31
|
16
|
ORFsvm
|
0.856
|
0.863
|
0.627
|
2340.72
|
17
|
bagFDA
|
0.856
|
0.856
|
0.625
|
27.53
|
18
|
ada
|
0.856
|
0.842
|
0.670
|
71.13
|
19
|
loclda
|
0.856
|
0.890
|
0.543
|
28.56
|
20
|
ORFlog
|
0.855
|
0.834
|
0.694
|
1054.09
|
21
|
rotationForest
|
0.855
|
0.854
|
0.653
|
11.03
|
22
|
avNNet
|
0.855
|
0.821
|
0.707
|
21.25
|
23
|
xgbTree
|
0.855
|
0.851
|
0.662
|
109.29
|
24
|
gamLoess
|
0.855
|
0.848
|
0.649
|
3.11
|
25
|
dwdRadial
|
0.854
|
0.864
|
0.595
|
12.89
|
26
|
gbm
|
0.854
|
0.859
|
0.634
|
2.77
|
27
|
gaussprPoly
|
0.853
|
0.867
|
0.582
|
108.61
|
28
|
dwdPoly
|
0.853
|
0.875
|
0.590
|
90.92
|
29
|
wsrf
|
0.852
|
0.848
|
0.644
|
7.94
|
30
|
earth
|
0.852
|
0.852
|
0.649
|
2.86
|
31
|
gcvEarth
|
0.852
|
0.852
|
0.649
|
1.36
|
32
|
AdaBoost.M1
|
0.851
|
0.857
|
0.625
|
3395.17
|
33
|
RRF
|
0.851
|
0.842
|
0.679
|
185.86
|
34
|
RRFglobal
|
0.851
|
0.839
|
0.681
|
30.85
|
35
|
fda
|
0.851
|
0.852
|
0.644
|
1.47
|
36
|
nodeHarvest
|
0.850
|
0.883
|
0.543
|
205.51
|
37
|
svmPoly
|
0.849
|
0.888
|
0.543
|
25.56
|
38
|
pcaNNet
|
0.849
|
0.849
|
0.627
|
5.32
|
39
|
mlp
|
0.847
|
0.835
|
0.659
|
5.73
|
40
|
mlpML
|
0.847
|
0.835
|
0.659
|
6.17
|
41
|
naive_bayes
|
0.847
|
0.787
|
0.711
|
1.06
|
42
|
nb
|
0.847
|
0.788
|
0.711
|
4.90
|
43
|
svmRadialSigma
|
0.847
|
0.879
|
0.576
|
12.50
|
44
|
mlpWeightDecay
|
0.846
|
0.861
|
0.619
|
15.25
|
45
|
mlpWeightDecayML
|
0.846
|
0.861
|
0.619
|
15.69
|
46
|
blackboost
|
0.846
|
0.879
|
0.549
|
28.14
|
47
|
AdaBag
|
0.844
|
0.797
|
0.694
|
1189.19
|
48
|
treebag
|
0.844
|
0.828
|
0.660
|
2.88
|
49
|
svmRadial
|
0.843
|
0.881
|
0.550
|
4.84
|
50
|
svmRadialWeights
|
0.843
|
0.883
|
0.545
|
10.30
|
51
|
svmRadialCost
|
0.843
|
0.880
|
0.550
|
4.86
|
52
|
spls
|
0.842
|
0.877
|
0.552
|
4.46
|
53
|
kernelpls
|
0.842
|
0.876
|
0.556
|
1.12
|
54
|
pls
|
0.842
|
0.876
|
0.556
|
1.17
|
55
|
simpls
|
0.842
|
0.876
|
0.556
|
1.33
|
56
|
widekernelpls
|
0.842
|
0.876
|
0.556
|
25.18
|
57
|
regLogistic
|
0.842
|
0.881
|
0.547
|
3.40
|
58
|
glmnet
|
0.842
|
0.884
|
0.543
|
2.28
|
59
|
multinom
|
0.842
|
0.875
|
0.554
|
1.07
|
60
|
mda
|
0.842
|
0.872
|
0.590
|
2.03
|
61
|
pda2
|
0.842
|
0.876
|
0.556
|
1.25
|
62
|
plr
|
0.842
|
0.874
|
0.560
|
1.48
|
63
|
ordinalNet
|
0.842
|
0.884
|
0.534
|
1159.87
|
64
|
glmStepAIC
|
0.841
|
0.878
|
0.556
|
2.32
|
65
|
bayesglm
|
0.841
|
0.874
|
0.560
|
1.98
|
66
|
dwdLinear
|
0.841
|
0.878
|
0.567
|
2.17
|
67
|
glmboost
|
0.841
|
0.876
|
0.560
|
1.59
|
68
|
glm
|
0.841
|
0.874
|
0.563
|
0.99
|
69
|
vglmAdjCat
|
0.841
|
0.874
|
0.563
|
2.95
|
70
|
vglmContRatio
|
0.841
|
0.874
|
0.563
|
12.13
|
71
|
vglmCumulative
|
0.841
|
0.874
|
0.563
|
10.44
|
72
|
pda
|
0.841
|
0.876
|
0.558
|
1.20
|
73
|
sparseLDA
|
0.841
|
0.000
|
1.000
|
1.92
|
74
|
sda
|
0.841
|
0.876
|
0.563
|
1.35
|
75
|
hdda
|
0.841
|
0.876
|
0.556
|
1.45
|
76
|
lda
|
0.841
|
0.876
|
0.556
|
0.99
|
77
|
lda2
|
0.841
|
0.876
|
0.556
|
0.97
|
78
|
rda
|
0.841
|
0.876
|
0.556
|
7.42
|
79
|
monmlp
|
0.840
|
0.857
|
0.623
|
59.94
|
80
|
msaenet
|
0.840
|
0.819
|
0.666
|
79.53
|
81
|
svmLinearWeights
|
0.840
|
0.764
|
0.735
|
5.04
|
82
|
svmLinear
|
0.839
|
0.882
|
0.543
|
2.84
|
83
|
svmLinear2
|
0.839
|
0.878
|
0.556
|
2.14
|
84
|
C5.0
|
0.839
|
0.839
|
0.625
|
4.65
|
85
|
pam
|
0.835
|
0.929
|
0.414
|
1.08
|
86
|
xgbLinear
|
0.834
|
0.826
|
0.668
|
117.73
|
87
|
adaboost
|
0.832
|
0.834
|
0.662
|
118.41
|
88
|
LMT
|
0.831
|
0.864
|
0.567
|
6.09
|
89
|
knn
|
0.825
|
0.830
|
0.604
|
1.15
|
90
|
slda
|
0.822
|
0.858
|
0.552
|
1.14
|
91
|
PART
|
0.816
|
0.813
|
0.674
|
3.51
|
92
|
qda
|
0.816
|
0.839
|
0.575
|
0.93
|
93
|
ctree2
|
0.815
|
0.790
|
0.698
|
2.66
|
94
|
ctree
|
0.814
|
0.833
|
0.610
|
1.53
|
95
|
Linda
|
0.811
|
0.754
|
0.707
|
3.43
|
96
|
kknn
|
0.811
|
0.811
|
0.601
|
7.49
|
97
|
hda
|
0.809
|
0.867
|
0.511
|
26.33
|
98
|
xyf
|
0.801
|
0.876
|
0.552
|
16.91
|
99
|
rbfDDA
|
0.800
|
0.848
|
0.616
|
21.88
|
100
|
C5.0Tree
|
0.800
|
0.803
|
0.631
|
1.08
|
101
|
LogitBoost
|
0.798
|
0.811
|
0.621
|
1.95
|
102
|
sdwd
|
0.794
|
1.000
|
0.000
|
1.79
|
103
|
stepLDA
|
0.794
|
0.879
|
0.500
|
9.47
|
104
|
stepQDA
|
0.794
|
0.887
|
0.487
|
9.39
|
105
|
J48
|
0.793
|
0.815
|
0.623
|
7.99
|
106
|
rpart2
|
0.792
|
0.800
|
0.638
|
1.44
|
107
|
rpart1SE
|
0.787
|
0.809
|
0.618
|
1.22
|
108
|
rpart
|
0.775
|
0.827
|
0.586
|
1.40
|
109
|
JRip
|
0.769
|
0.828
|
0.700
|
16.94
|
110
|
C5.0Rules
|
0.768
|
0.828
|
0.625
|
1.07
|
111
|
evtree
|
0.767
|
0.841
|
0.618
|
126.20
|
112
|
QdaCov
|
0.764
|
0.683
|
0.743
|
2.60
|
113
|
rmda
|
0.745
|
0.554
|
0.815
|
16.53
|
114
|
dnn
|
0.731
|
1.000
|
0.000
|
51.50
|
115
|
PRIM
|
0.725
|
0.777
|
0.563
|
260.00
|
116
|
OneR
|
0.638
|
0.822
|
0.453
|
1.21
|
117
|
mlpKerasDecay
|
0.627
|
0.517
|
0.521
|
231.19
|
118
|
mlpKerasDecayCost
|
0.626
|
0.254
|
0.746
|
572.08
|
119
|
mlpKerasDropoutCost
|
0.598
|
0.214
|
0.787
|
577.20
|
120
|
mlpKerasDropout
|
0.526
|
0.750
|
0.252
|
195.55
|
121
|
null
|
0.500
|
1.000
|
0.000
|
0.84
|
122
|
bagEarth
|
0.195
|
0.102
|
0.532
|
42.21
|
123
|
plsRglm
|
0.168
|
0.865
|
0.558
|
8.37
|
124
|
bagEarthGCV
|
0.143
|
0.147
|
0.353
|
12.43
|
125
|
Final Notes
Kết quả của 125 thuật toán học máy cho bộ dữ liệu pima dẫn đến một số ngụ ý sau:
Có rất nhiều bài viết kiểu như Top 10 Machine Learning Algorithms, LightGBM vs XGBOOST – Which algorithm is better, Kaggle Competitions Top Classification Algorithm nói về thuật toán nào là “tốt nhất”. Tuy nhiên kết quả từ Table 2 cho thấy chỉ có right algorithm for right data mà thôi (nhại lại câu thành ngữ right tool for the job).
Thông thường các dự án data science bị giới hạn bởi thời gian và có deadline rõ ràng. Mặt khác chúng ta hoàn toàn không thể biết trước thuật toán nào là tốt nhất với bộ dữ liệu của dự án. Những kết quả ở trên có thể cho một gợi ý: huấn luyện đồng thời chỉ một số (hoặc tất cả nếu đủ thời gian) mô hình lựa chọn trước xem kết quả ra sao để chọn ra mô hình có vẻ ổn nhất. Kế tiếp lại tinh chỉnh sâu hơn nữa các tham số cho mô hình ổn nhất đó. Chẳng hạn từ Table 2 thì gamboost (Gradient Boosting with Componentwise Smoothing Splines) là thuật toán có ROC cao nhất trên bộ dữ liệu pima. Chúng ta có thể tinh chỉnh cẩn thận hơn nữa các tham số của thuật toán này để có ROC cao hơn.
Table 2 cũng cho thấy có một số thuật toán mà thời gian huấn luyện là rất lớn trong khi model perfornamce thì không được tốt như kì vọng, thậm chí là thua kém hơn so với những thuật toán đơn giản khác. Những thuật toán này có lẽ không nên được lựa chọn khi mà nguồn lực thời gian cho dự án là không nhiều hoặc deadline là sắp hết.
An Extension
Ngoài tinh chỉnh tham số cho một thuật toán được chọn thì chúng ta cũng còn một lựa chọn khác: sử dụng Ensemble Learning Methods. Thay vì sử dụng For Loop như trên chúng ta có thể sử dụng hàm caretList()
của thư viện caretEnsemble để huấn luyện đồng thời nhiều thuật toán. Dưới đây là R codes cho cách tiếp cận này (thời gian huấn luyện khá nhiều nên không hiển thị kết quả):
# Algorithms with training time <= 300s:
df_space %>%
filter(train_time <= 300) %>%
pull(method) -> some_models_selected
#===================
# Ensemble
#===================
set.seed(1)
number <- 5
repeats <- 3
library(caret)
control <- trainControl(method = "repeatedcv",
number = number ,
repeats = repeats,
classProbs = TRUE,
allowParallel = FALSE,
summaryFunction = twoClassSummary,
savePredictions = "final",
index = createResample(pima_RFimputed$diabete_status, repeats*number))
# Train these ML Models:
library(caretEnsemble)
set.seed(1)
system.time(model_list1 <- caretList(diabete_status ~.,
data = pima_RFimputed,
trControl = control,
metric = "ROC",
methodList = some_models_selected))
# Extract results:
list_of_results <- lapply(some_models_selected, function(x) {model_list1[[x]]$resample})
# Convert to data frame:
total_df <- do.call("bind_rows", list_of_results)
total_df %>% mutate(Model = lapply(some_models_selected, function(x) {rep(x, number*repeats)}) %>% unlist()) -> total_df
# Save results:
write.csv(total_df, "total_df.csv", row.names = FALSE)
# Ensemble 1: M
greedy_ensemble <- caretEnsemble(model_list1,
metric = "ROC",
trControl = control)
# Draft Results:
summary(greedy_ensemble)
# Add Ensemble 1 Model:
total_df_en <- bind_rows(total_df, greedy_ensemble$ens_model$resample %>% mutate(Model = "ensemble"))
# Ensemble 2:
gbm_ensemble <- caretStack(model_list1,
method = "gbm",
metric = "ROC",
verbose = FALSE,
trControl = control)
# Add Ensemble 2 Model:
total_df_en2 <- bind_rows(total_df_en, gbm_ensemble$ens_model$resample %>% mutate(Model = "ensemble2"))
# Average ROC based on 15 samples:
total_df_en2 %>%
group_by(Model) %>%
summarise(avg_roc = mean(ROC)) %>%
ungroup() %>%
arrange(-avg_roc) %>%
mutate(rank = 1:nrow(.)) %>%
mutate_if(is.numeric, function(x) {round(x, 3)}) -> df_with_en2
df_with_en2 %>%
filter(stringr::str_detect(Model, "ensemble"))
# Rank = 14 for ensemble 1, = 14 for ensemble 2:
df_with_en2 %>%
mutate_if(is.numeric, function(x) {round(x, 3)}) %>%
kbl(caption = "Table 3: Ensemble Approach by ROC", escape = TRUE) %>%
kable_classic(full_width = FALSE, html_font = "Cambria")