Bắt tay vào thực hành phân tích dữ liệu với các kĩ thuật machine learning là lúc phải đối diện với nhiều câu hỏi cùng một lúc. Phải chuẩn bị dữ liệu như thế nào, xử lí những NA’s, chuẩn hóa các biến số định lượng, chọn thuật toán (algoithm) nào, chọn mô hình (model) nào, chọn những biến số nào để đưa vào mô hình và đánh giá performance của các mô hình bằng các tiêu chí nào cho phù hợp.
Đôi khi chúng ta thấy rối rắm, nhức đầu vì các câu hỏi trên. Phải gọi nhiều package khác nhau để xử lí các công đoạn của “dự án” mà ta đang thực hiện. Dùng MICE để xử lí missing, Caret để phân chia trainset và testset, dùng các package chuyên biệt cho các thuật toán như rpart, C50, gbm hay randomForest để xử lí các mô hình hoặc các package tích hợp như Caret, mlr ? Sau đó so sánh các mô hình, thuật toán để tìm ra giải pháp tốt nhất.
Liệu chúng ta có chọn nhầm thuật toán hay không, chúng ta có bỏ sót một thuật toán nào mà nó có thể là chọn lựa tốt hơn hết cho bài toán của chúng ta ?
Các package chuyên biệt như C50, rpat, gbm giúp chúng ta hiểu sâu các khái niệm, thông số của thuật toán. Tuy nhiên, khi vào thực hành, chúng ta cần có một integrated package giúp ta so sánh trong phạm vi rộng hơn để tìm ra chọn lựa tốt nhất.
Qua thực hành, tôi thấy package mlr là một gói tích hợp nhiều thủ thuật có thể đơn giản hóa công việc của người phân tích dữ liệu với machine learning. Bài viết này trình bày việc sử dụng mlr để giải quyết nhiều vấn đề nêu trên khi làm việc với các mô hình machine learning.
Tôi sử dụng dataset hepatitis từ openML để minh họa. Dataset này có thể tải về máy tính từ link https://www.openml.org/d/55, sau đó gọi vào RStudio như sau:
library(tidyverse)
mydf=read.csv("F:/R/Data/dataset_55_hepatitis.csv", na.strings = "?")
lưu ý na.strings = “?” giúp chúng ta đọc những NA trong dataset khi chúng được kí hiệu bằng dấu “?”
str(mydf)
## 'data.frame': 155 obs. of 20 variables:
## $ AGE : int 30 50 78 31 34 34 51 23 39 30 ...
## $ SEX : Factor w/ 2 levels "female","male": 2 1 1 1 1 1 1 1 1 1 ...
## $ STEROID : Factor w/ 2 levels "no","yes": 1 1 2 NA 2 2 1 2 2 2 ...
## $ ANTIVIRALS : Factor w/ 2 levels "no","yes": 1 1 1 2 1 1 1 1 1 1 ...
## $ FATIGUE : Factor w/ 2 levels "no","yes": 1 2 2 1 1 1 2 1 2 1 ...
## $ MALAISE : Factor w/ 2 levels "no","yes": 1 1 1 1 1 1 1 1 1 1 ...
## $ ANOREXIA : Factor w/ 2 levels "no","yes": 1 1 1 1 1 1 2 1 1 1 ...
## $ LIVER_BIG : Factor w/ 2 levels "no","yes": 1 1 2 2 2 2 2 2 2 2 ...
## $ LIVER_FIRM : Factor w/ 2 levels "no","yes": 1 1 1 1 1 1 1 1 2 1 ...
## $ SPLEEN_PALPABLE: Factor w/ 2 levels "no","yes": 1 1 1 1 1 1 2 1 1 1 ...
## $ SPIDERS : Factor w/ 2 levels "no","yes": 1 1 1 1 1 1 2 1 1 1 ...
## $ ASCITES : Factor w/ 2 levels "no","yes": 1 1 1 1 1 1 1 1 1 1 ...
## $ VARICES : Factor w/ 2 levels "no","yes": 1 1 1 1 1 1 1 1 1 1 ...
## $ BILIRUBIN : num 1 0.9 0.7 0.7 1 0.9 NA 1 0.7 1 ...
## $ ALK_PHOSPHATE : int 85 135 96 46 NA 95 NA NA NA NA ...
## $ SGOT : int 18 42 32 52 200 28 NA NA 48 120 ...
## $ ALBUMIN : num 4 3.5 4 4 4 4 NA NA 4.4 3.9 ...
## $ PROTIME : int NA NA NA 80 NA 75 NA NA NA NA ...
## $ HISTOLOGY : Factor w/ 2 levels "no","yes": 1 1 1 1 1 1 1 1 1 1 ...
## $ Class : Factor w/ 2 levels "DIE","LIVE": 2 2 2 2 2 2 1 2 2 2 ...
head(mydf)
## AGE SEX STEROID ANTIVIRALS FATIGUE MALAISE ANOREXIA LIVER_BIG
## 1 30 male no no no no no no
## 2 50 female no no yes no no no
## 3 78 female yes no yes no no yes
## 4 31 female <NA> yes no no no yes
## 5 34 female yes no no no no yes
## 6 34 female yes no no no no yes
## LIVER_FIRM SPLEEN_PALPABLE SPIDERS ASCITES VARICES BILIRUBIN
## 1 no no no no no 1.0
## 2 no no no no no 0.9
## 3 no no no no no 0.7
## 4 no no no no no 0.7
## 5 no no no no no 1.0
## 6 no no no no no 0.9
## ALK_PHOSPHATE SGOT ALBUMIN PROTIME HISTOLOGY Class
## 1 85 18 4.0 NA no LIVE
## 2 135 42 3.5 NA no LIVE
## 3 96 32 4.0 NA no LIVE
## 4 46 52 4.0 80 no LIVE
## 5 NA 200 4.0 NA no LIVE
## 6 95 28 4.0 75 no LIVE
summary(mydf)
## AGE SEX STEROID ANTIVIRALS FATIGUE MALAISE
## Min. : 7.0 female:139 no :76 no :131 no : 54 no :93
## 1st Qu.:32.0 male : 16 yes :78 yes: 24 yes :100 yes :61
## Median :39.0 NA's: 1 NA's: 1 NA's: 1
## Mean :41.2
## 3rd Qu.:50.0
## Max. :78.0
##
## ANOREXIA LIVER_BIG LIVER_FIRM SPLEEN_PALPABLE SPIDERS ASCITES
## no :122 no : 25 no :84 no :120 no :99 no :130
## yes : 32 yes :120 yes :60 yes : 30 yes :51 yes : 20
## NA's: 1 NA's: 10 NA's:11 NA's: 5 NA's: 5 NA's: 5
##
##
##
##
## VARICES BILIRUBIN ALK_PHOSPHATE SGOT
## no :132 Min. :0.300 Min. : 26.00 Min. : 14.00
## yes : 18 1st Qu.:0.700 1st Qu.: 74.25 1st Qu.: 31.50
## NA's: 5 Median :1.000 Median : 85.00 Median : 58.00
## Mean :1.428 Mean :105.33 Mean : 85.89
## 3rd Qu.:1.500 3rd Qu.:132.25 3rd Qu.:100.50
## Max. :8.000 Max. :295.00 Max. :648.00
## NA's :6 NA's :29 NA's :4
## ALBUMIN PROTIME HISTOLOGY Class
## Min. :2.100 Min. : 0.00 no :85 DIE : 32
## 1st Qu.:3.400 1st Qu.: 46.00 yes:70 LIVE:123
## Median :4.000 Median : 61.00
## Mean :3.817 Mean : 61.85
## 3rd Qu.:4.200 3rd Qu.: 76.25
## Max. :6.400 Max. :100.00
## NA's :16 NA's :67
Đây là bộ dữ liệu của những bệnh nhân viêm gan, thông qua các biến số về nhân chủng học, triệu chứng lâm sàng và xét nghiệm chức năng gan để tiên lượng tử vong (LIVE hoặc DIE).
Chúng ta sẽ vận dụng package mlr để chạy các thuật toán machine learning nhằm tìm ra mô hình tiên lượng tử vong nào tối ưu cho những đối tượng nghiên cứu này.
Dataset có 155 dòng và 20 biến số, trong đó có 6 biến số định lượng và 14 biến số định tính (factor), có nhiều NA trong các biến số. Vấn đề này cần giải quyết vì nhiều thuật toán machine learning không chạy được khi có NA.
# Data visualisation
mydf%>%gather(BILIRUBIN,ALK_PHOSPHATE,SGOT,ALBUMIN, key="Outcome",value="Value")%>%
ggplot(aes(x=AGE,fill=Class))+
geom_density(alpha=0.5,col="black")+
theme_minimal()+
facet_wrap(~Outcome,scales="free",ncol=2)+
scale_fill_manual(values=c("green","pink"))
Nhóm tử vong có tuổi cao hơn và các chỉ số men gan cao hơn.
mydf%>%gather(SEX:VARICES,key="Factors",value="Status")%>%
ggplot(aes(x=Status,fill=Class))+
geom_bar(stat="count",position="fill",alpha=0.5,col="black")+
scale_y_continuous(labels=NULL)+
theme_minimal()+
facet_wrap(~Factors,scales="free",ncol=6)+
scale_fill_manual(values=c("green","pink"))
library(mlr)
imp = impute(mydf, target = "Class", classes = list(integer = imputeMean(), numeric = imputeMean(),
factor = imputeMode()))
mydf <- imp$data
imp$desc
## Imputation description
## Target: Class
## Features: 19; Imputed: 19
## impute.new.levels: TRUE
## recode.factor.levels: TRUE
## dummy.type: factor
sum(is.na(mydf$BILIRUBIN))
## [1] 0
Kiểm tra lại với BILIRUBIN ta thấy đã không còn NA.
Tiếp tục chuẩn hóa dữ liệu định lượng (scale và centerize)
normalizeFeatures(mydf,target="Class",method="standardize")
Đến đây dữ liệu chúng ta đã chuẩn bị xong, sẳn sàng cho việc phân tích.
Tiên lượng tử vong (LIVE hoặc DIE) thông qua các thuật toán phân loại (classification). mlr sẽ cho ta biết các thuật toán nào phù hợp với task1 của chúng ta:
task1 = makeClassifTask(data = mydf, target = "Class")
View(listLearners(task1, properties = "prob"))
listLearners("classif", properties = "prob")[c("class", "package")]
## class package
## 1 classif.ada ada,rpart
## 2 classif.adaboostm1 RWeka
## 3 classif.bartMachine bartMachine
## 4 classif.binomial stats
## 5 classif.blackboost mboost,party
## 6 classif.boosting adabag,rpart
## ... (#rows: 66, #cols: 2)
Kết quả cho thấy 19 thuật toán (learners) có thể được vận dụng. Phải lưu ý là có thể RStudio của bạn chưa install đầy đủ những packages liên quan đến các thuật toán (learners) được liệt kê ở trên. Nếu như vậy một số functions sẽ báo lỗi vì thiếu package đó.
Trong phép classification thông thường chúng ta sử dụng các measures như accuracy (acc), mmce, tpr để đánh giá performance của các mô hình.
Chúng ta tiếp tục đặt câu hỏi là trong 19 learners trên, những learners là tối ưu hơn về acc, mmce, tpr.
# Tuning hyperparameters
set.seed(2018)
rdesc = makeResampleDesc("CV", iters = 3) # stratify = TRUE for imbalaned data
# Building multiple learners
lrns = list(makeLearner("classif.binomial" , predict.type = "prob"),
"classif.cvglmnet" , "classif.featureless" , "classif.gausspr" ,
"classif.gbm", "classif.glmnet" , "classif.ksvm" ,
"classif.lda", "classif.logreg" , "classif.lssvm" ,
"classif.mda" , "classif.multinom" , "classif.naiveBayes" ,
"classif.nnet" , "classif.probit" , "classif.randomForest",
"classif.randomForestSRC" ,"classif.rpart" , "classif.svm"
)
# Comparing learners
bmr = benchmark(lrns, task1, rdesc, measure = list(acc,tpr))
## Using automatic sigma estimation (sigest) for RBF or laplace kernel
## Using automatic sigma estimation (sigest) for RBF or laplace kernel
## Using automatic sigma estimation (sigest) for RBF or laplace kernel
## Distribution not specified, assuming bernoulli ...
## Distribution not specified, assuming bernoulli ...
## Distribution not specified, assuming bernoulli ...
## Using automatic sigma estimation (sigest) for RBF or laplace kernel
## Using automatic sigma estimation (sigest) for RBF or laplace kernel
## Using automatic sigma estimation (sigest) for RBF or laplace kernel
## # weights: 21 (20 variable)
## initial value 72.087307
## iter 10 value 28.148045
## iter 20 value 25.456933
## iter 30 value 25.425814
## iter 40 value 25.419918
## final value 25.419850
## converged
## # weights: 21 (20 variable)
## initial value 71.394160
## iter 10 value 22.845310
## iter 20 value 17.642022
## iter 30 value 17.609646
## iter 40 value 17.608130
## iter 50 value 17.607370
## iter 60 value 17.607349
## iter 70 value 17.607331
## iter 80 value 17.607194
## final value 17.607138
## converged
## # weights: 21 (20 variable)
## initial value 71.394160
## iter 10 value 13.910094
## iter 20 value 10.568199
## iter 30 value 10.510257
## iter 40 value 10.495410
## iter 50 value 10.490648
## iter 60 value 10.488000
## iter 70 value 10.485917
## iter 80 value 10.485342
## final value 10.485284
## converged
## # weights: 64
## initial value 59.261328
## iter 10 value 51.422510
## iter 20 value 45.486253
## iter 30 value 44.649394
## iter 40 value 44.628624
## iter 50 value 42.765780
## iter 60 value 42.261831
## iter 70 value 42.173959
## iter 80 value 41.366345
## iter 90 value 40.744585
## iter 100 value 40.739016
## final value 40.739016
## stopped after 100 iterations
## # weights: 64
## initial value 93.352403
## iter 10 value 43.804454
## iter 20 value 42.689600
## iter 30 value 42.368059
## iter 40 value 42.363899
## iter 50 value 39.321621
## iter 60 value 38.984377
## iter 70 value 37.620221
## iter 80 value 37.201900
## iter 90 value 37.100696
## iter 100 value 36.758103
## final value 36.758103
## stopped after 100 iterations
## # weights: 64
## initial value 55.053890
## iter 10 value 49.538762
## iter 20 value 48.904244
## final value 48.902416
## converged
bmr
## task.id learner.id acc.test.mean tpr.test.mean
## 1 mydf classif.binomial 0.8387381 0.42393162
## 2 mydf classif.cvglmnet 0.8065611 0.06666667
## 3 mydf classif.featureless 0.7937406 0.00000000
## 4 mydf classif.gausspr 0.8451483 0.35356125
## 5 mydf classif.gbm 0.7937406 0.00000000
## 6 mydf classif.glmnet 0.8257919 0.41623932
## 7 mydf classif.ksvm 0.8388638 0.30598291
## 8 mydf classif.lda 0.8257919 0.48660969
## 9 mydf classif.logreg 0.8387381 0.42393162
## 10 mydf classif.lssvm 0.8132227 0.38005698
## 11 mydf classif.mda 0.8455254 0.49031339
## 12 mydf classif.multinom 0.8387381 0.42393162
## 13 mydf classif.naiveBayes 0.8323278 0.64900285
## 14 mydf classif.nnet 0.8068125 0.14814815
## 15 mydf classif.probit 0.8386124 0.41253561
## 16 mydf classif.randomForest 0.8386124 0.39829060
## 17 mydf classif.randomForestSRC 0.8257919 0.39059829
## 18 mydf classif.rpart 0.7351684 0.36125356
## 19 mydf classif.svm 0.8323278 0.26894587
Ở bảng kết quả này cho thấy top 5 các learners có acc và tpr khả dĩ nhất là
task.id | learner.id | acc.test.mean | tpr.test.mean |
---|---|---|---|
3 mydf | classif.lda | 0.8257919 | 0.48660969 |
5 mydf | classif.multinom | 0.8387381 | 0.42393162 |
4 mydf | classif.binomial | 0.8387381 | 0.42393162 |
1 mydf | classif.naiveBayes | 0.8319507 | 0.65000000 |
2 mydf | classif.mda | 0.8455254 | 0.49031339 |
Từ đây tôi chọn ra 3 learners tiềm năng để phân tích và so sánh.
# partitioning dataset
set.seed(2018)
n = nrow(mydf)
train.set = sample(n, size = 2/3*n)
test.set = setdiff(1:n, train.set)
# classif.naiveBayes
set.seed(2018)
naiveBayes.lrn = makeLearner("classif.naiveBayes", predict.type = "prob")
naiveBayes.model = train(naiveBayes.lrn, task = task1, subset = train.set)
naiveBayes.pred = predict(naiveBayes.model,task = task1, subset = test.set)
naiveBayes.perf <- performance(naiveBayes.pred, measures = list(mmce, acc, tpr))
naiveBayes.perf
## mmce acc tpr
## 0.1346154 0.8653846 0.8000000
calculateConfusionMatrix(naiveBayes.pred)
## predicted
## true DIE LIVE -err.-
## DIE 8 2 2
## LIVE 5 37 5
## -err.- 5 2 7
# classif.mda
set.seed(2018)
mda.lrn = makeLearner("classif.mda", predict.type = "prob")
mda.model = train(mda.lrn, task = task1, subset = train.set)
mda.pred = predict(mda.model,task = task1, subset = test.set)
mda.perf <- performance(mda.pred, measures = list(mmce, acc, fpr))
mda.perf
## mmce acc fpr
## 0.1346154 0.8653846 0.0952381
calculateConfusionMatrix(mda.pred)
## predicted
## true DIE LIVE -err.-
## DIE 7 3 3
## LIVE 4 38 4
## -err.- 4 3 7
# classif.lda
set.seed(2018)
lda.lrn = makeLearner("classif.lda", predict.type = "prob")
lda.model = train(lda.lrn, task = task1, subset = train.set)
lda.pred = predict(lda.model,task = task1, subset = test.set)
lda.perf <- performance(lda.pred, measures = list(mmce, acc, tpr))
lda.perf
## mmce acc tpr
## 0.1346154 0.8653846 0.7000000
calculateConfusionMatrix(lda.pred)
## predicted
## true DIE LIVE -err.-
## DIE 7 3 3
## LIVE 4 38 4
## -err.- 4 3 7
rbind(naiveBayes.perf, mda.perf, lda.perf)
## mmce acc tpr
## naiveBayes.perf 0.1346154 0.8653846 0.8000000
## mda.perf 0.1346154 0.8653846 0.0952381
## lda.perf 0.1346154 0.8653846 0.7000000
Mô hình naiveBayes và lda đều có accuracy cao nhất. Tuy nhiên, ở đây tôi muốn phát hiện càng chính xác càng tốt những bệnh nhân có nguy tử vong (DIE) cho nên tôi quan tâm hơn đến tiêu chí tpr (true positive rate). Vì thế tôi chọn mô hình classif.naiveBayes để phân tích tiếp vì có tpr = 0.8.
library(randomForestSRC)
feats <- generateFilterValuesData(task1, method = "randomForestSRC.rfsrc",
nselect = getTaskNFeats(task1))
plotFilterValues(feats)
Biểu đồ đánh giá mức độ quan trọng của các biến số trong dataset trên khi phân loại LIVE hoặc DIE theo phương pháp randomForestSRC.
Tuy nhiên, chuyên biệt hơn, chúng ta sẽ chọn lựa biến số riêng cho mô hình theo thuật toán naiveBayes mà chúng ta vừa xác định trên đây.
# Tuning hyperparameters
set.seed(2018)
rdesc = makeResampleDesc("CV", iters = 3, stratify = TRUE ) # stratify = TRUE for imbalaned data
contrl <- makeFeatSelControlRandom(maxit=100, prob=0.5, max.features=10) # creat a control
fsr <- selectFeatures(naiveBayes.lrn, task1, resampling = rdesc, control = contrl) # feature selection
fsr
## FeatSel result:
## Features (9): SEX, ANTIVIRALS, FATIGUE, MALAISE, ANOREXIA, VARICES, BILIRUBIN, PROTIME, HISTOLOGY
## mmce.test.mean=0.1290850
Kêt quả là có 9 biến số ở trên được chọn lựa trong số 19 biến số predictors trong dataset.
Đến đây mlr cho phép chúng ta update các biến số được chọn lựa này vào mô hình rút gọn 9 biến số mà không cần viết code learner lại từ đầu. Tôi thích function này.
# remodeling with 8 selected features for naiveBayes
set.seed(2018)
task2 = subsetTask(task1,features=fsr$x)
naiveBayes.model2 = train(naiveBayes.lrn, task2, subset = train.set)
naiveBayes.pred2 = predict(naiveBayes.model2, task = task2, subset = test.set)
perf2 <-performance(naiveBayes.pred2, measures = list(mmce, acc, tpr))
round(perf2, digits = 3)
## mmce acc tpr
## 0.115 0.885 0.700
calculateConfusionMatrix(naiveBayes.pred2)
## predicted
## true DIE LIVE -err.-
## DIE 7 3 3
## LIVE 3 39 3
## -err.- 3 3 6
Việc tinh lọc biến số ở trên mô hình đã cải thiện rất đáng kể tính chính xác (acc = 0.885, tuy nhiên, tỉ lệ tpr = 0.70 (true positive rate) giảm hơn mô hình đầy đủ biến số (full model với thuật toán naiveBayes, tpr = 0.8)
Với tiêu chí ưu tiên phát hiện những trường hợp nguy cơ tử vong cao (không bỏ sót trường hợp nguy cơ cao nào) chúng ta chọn lựa full model có tpr = 0.8, acc = 0.865.
Chúng ta vẽ biểu đồ ROC của mô hình đầy đủ
d1 = generateThreshVsPerfData(naiveBayes.pred, measures = list(fpr, tpr))
plotROCCurves(d1)
calculateROCMeasures(naiveBayes.pred)
## predicted
## true DIE LIVE
## DIE 8 2 tpr: 0.8 fnr: 0.2
## LIVE 5 37 fpr: 0.12 tnr: 0.88
## ppv: 0.62 for: 0.05 lrp: 6.72 acc: 0.87
## fdr: 0.38 npv: 0.95 lrm: 0.23 dor: 29.6
##
##
## Abbreviations:
## tpr - True positive rate (Sensitivity, Recall)
## fpr - False positive rate (Fall-out)
## fnr - False negative rate (Miss rate)
## tnr - True negative rate (Specificity)
## ppv - Positive predictive value (Precision)
## for - False omission rate
## lrp - Positive likelihood ratio (LR+)
## fdr - False discovery rate
## npv - Negative predictive value
## acc - Accuracy
## lrm - Negative likelihood ratio (LR-)
## dor - Diagnostic odds ratio
Mỗi một package đều có những lợi thế riêng. Nếu sử dụng package chuyên biệt cho từng thuật toán, chúng ta sẽ nắm vững những đặc điểm, tính năng đặc trưng của package đó.
Khi sử dụng những package tích hợp như mlr hoặc caret chúng ta có thể đồng thời chạy nhiều thuật toán machine learning khác nhau để so sánh, chọn lựa thuật toán và mô hình khác nhau.
Mong nhận được những ý kiến của các bạn. Cám ơn các bạn đã dành thời gian đọc bài viết này.
Tài liệu tham khảo thêm mlr ở link này: https://mlr-org.github.io/mlr-tutorial/release/html/index.html
Cheatsheet tra cứu nhanh: https://github.com/mlr-org/mlr-tutorial/raw/gh-pages/cheatsheet/MlrCheatsheet.pdf