1 Steps to build a machine learning model with mlr3

  • task: declares the input data source, which can be a data.frame, a database table, ...
  • learner: declares the algorithm to use, e.g. random forest, xgboost, lightgbm, ...
  • search_space: declares the learner hyperparameters to be tuned, e.g. number of trees, mtry, ...
  • measure: declares the criterion used to evaluate the learner's performance, e.g. AUC, CE (classification error), ...
  • resampling: declares how the data are split into training and test sets, e.g. cross-validation, holdout, ...
  • tuning: selects the tuning algorithm/strategy, e.g. hyperband, grid search, ...

2 Example

We use a random forest to predict a customer's probability of default. The input data is the german_credit dataset.
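
The code in this example assumes the following packages are loaded (mlr3tuning and mlr3hyperband provide the tuning machinery; tidyverse and knitr only supply the %>%, select() and kable() helpers used below):
library(mlr3)          # tsk(), lrn(), rsmp(), msr()
library(mlr3learners)  # the classif.ranger learner
library(mlr3tuning)    # tnr(), trm(), TuningInstanceSingleCrit
library(mlr3hyperband) # the hyperband tuner
library(tidyverse)     # %>% and select()
library(knitr)         # kable()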

2.1 Declare the task

task_classif <- tsk('german_credit')
task_classif
## <TaskClassif:german_credit> (1000 x 21)
## * Target: credit_risk
## * Properties: twoclass
## * Features (20):
##   - fct (14): credit_history, employment_duration, foreign_worker,
##     housing, job, other_debtors, other_installment_plans,
##     people_liable, personal_status_sex, property, purpose, savings,
##     status, telephone
##   - int (3): age, amount, duration
##   - ord (3): installment_rate, number_credits, present_residence
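
The AUC measure and the prob.good / prob.bad columns shown later depend on which level of credit_risk is treated as the positive class; this can be checked directly on the task, for example:
# positive class and target distribution
task_classif$positive
table(task_classif$truth())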

2.2 Declare the learner

classif_learner <- lrn('classif.ranger')
classif_learner
## <LearnerClassifRanger:classif.ranger>
## * Model: -
## * Parameters: num.threads=1
## * Packages: ranger
## * Predict Type: response
## * Feature types: logical, integer, numeric, character, factor, ordered
## * Properties: importance, multiclass, oob_error, twoclass, weights
# request class probabilities instead of hard class labels
classif_learner$predict_type <- "prob"
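
Equivalently, the prediction type can be requested when constructing the learner:
# same learner, probabilities requested at construction time
classif_learner <- lrn('classif.ranger', predict_type = "prob")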

2.3 Declare the parameters to tune

  • Before declaring the search space, check which parameters the learner supports and whether each one is a factor, numeric, etc., so that it can be declared appropriately:
classif_learner$param_set %>% 
  as.data.table() %>% 
  select(id, class, lower, upper, levels) %>% 
  kable()
id                            class     lower  upper  levels
num.trees                     ParamInt      1    Inf  NULL
mtry                          ParamInt      1    Inf  NULL
importance                    ParamFct     NA     NA  none, impurity, impurity_corrected, permutation
write.forest                  ParamLgl     NA     NA  TRUE, FALSE
min.node.size                 ParamInt      1    Inf  NULL
replace                       ParamLgl     NA     NA  TRUE, FALSE
sample.fraction               ParamDbl      0      1  NULL
class.weights                 ParamDbl   -Inf    Inf  NULL
splitrule                     ParamFct     NA     NA  gini, extratrees
num.random.splits             ParamInt      1    Inf  NULL
split.select.weights          ParamDbl      0      1  NULL
always.split.variables        ParamUty     NA     NA  NULL
respect.unordered.factors     ParamFct     NA     NA  ignore, order, partition
scale.permutation.importance  ParamLgl     NA     NA  TRUE, FALSE
keep.inbag                    ParamLgl     NA     NA  TRUE, FALSE
holdout                       ParamLgl     NA     NA  TRUE, FALSE
num.threads                   ParamInt      1    Inf  NULL
save.memory                   ParamLgl     NA     NA  TRUE, FALSE
verbose                       ParamLgl     NA     NA  TRUE, FALSE
oob.error                     ParamLgl     NA     NA  TRUE, FALSE
max.depth                     ParamInt   -Inf    Inf  NULL
alpha                         ParamDbl   -Inf    Inf  NULL
min.prop                      ParamDbl   -Inf    Inf  NULL
regularization.factor         ParamUty     NA     NA  NULL
regularization.usedepth       ParamLgl     NA     NA  TRUE, FALSE
seed                          ParamInt   -Inf    Inf  NULL
minprop                       ParamDbl   -Inf    Inf  NULL
se.method                     ParamFct     NA     NA  jack, infjack
  • Select the parameters to tune. num.trees is tagged "budget" because hyperband requires exactly one numeric parameter with this tag to use as its fidelity/budget dimension; a quick sanity check of the resulting space is sketched after the code below.
ps_ranger = ps(
    num.trees = p_int(300, 800, tags = "budget"),
    mtry = p_int(8, 15),
    sample.fraction = p_dbl(0.7, 0.8)
  )
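
To sanity-check the declared space, a few random configurations can be drawn from it; generate_design_random() is from the paradox package (the same package that provides ps() and p_int()):
# draw 3 random configurations from the search space as a sanity check
library(paradox)
generate_design_random(ps_ranger, n = 3)$data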

2.4 Declare the resampling

# cross-validation with 5 folds
resampling_inner = rsmp("cv", folds = 5)
resampling_inner 
## <ResamplingCV> with 5 iterations
## * Instantiated: FALSE
## * Parameters: folds=5
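
Other strategies listed in section 1 are declared analogously, e.g. a simple holdout split:
# 80/20 train/test holdout as an alternative resampling
rsmp("holdout", ratio = 0.8)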

2.5 Declare the measure

measure = msr("classif.auc")
measure
## <MeasureBinarySimple:classif.auc>
## * Packages: mlr3measures
## * Range: [0, 1]
## * Minimize: FALSE
## * Parameters: list()
## * Properties: -
## * Predict type: prob
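
classif.auc requires predict_type = "prob" on the learner, which is why it was set in section 2.2. A measure based on hard class labels, such as the classification error (CE), is declared the same way:
# classification error as an alternative measure (works with predict_type = "response")
msr("classif.ce")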

2.6 Declare the tuning algorithm

tuner = tnr("hyperband", eta = 2)
tuner
## <TunerHyperband>
## * Parameters: eta=2
## * Parameter classes: ParamLgl, ParamInt, ParamDbl, ParamFct
## * Properties: dependencies, single-crit, multi-crit
## * Packages: -
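
With eta = 2, each hyperband stage keeps roughly half of the surviving configurations while doubling the budget (num.trees here), which is what the bracket/bracket_stage columns in the archive below reflect. Other tuners listed in section 1 are constructed the same way, e.g. grid search:
# grid search as an alternative tuner; resolution = grid points per numeric parameter
tnr("grid_search", resolution = 5)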

2.7 Run the tuning

2.7.1 Declare the tuning instance

tune_single_crit = TuningInstanceSingleCrit$new(
      task = task_classif,
      learner = classif_learner,
      resampling = resampling_inner,
      measure = measure,
      terminator = trm("none"), # hyperband terminates itself
      search_space = ps_ranger
    )
tune_single_crit
## <TuningInstanceSingleCrit>
## * State:  Not optimized
## * Objective: <ObjectiveTuning:classif.ranger_on_german_credit>
## * Search Space:
## <ParamSet>
##                 id    class lower upper nlevels        default value
## 1:       num.trees ParamInt 300.0 800.0     501 <NoDefault[3]>      
## 2:            mtry ParamInt   8.0  15.0       8 <NoDefault[3]>      
## 3: sample.fraction ParamDbl   0.7   0.8     Inf <NoDefault[3]>      
## * Terminator: <TerminatorNone>
## * Terminated: FALSE
## * Archive:
## <ArchiveTuning>
## Null data.table (0 rows and 0 cols)

2.7.2 Tuning
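
The optimization itself is started by handing the instance to the tuner; the call below was run but its rather verbose log is not shown:
# run hyperband; evaluated configurations are written into tune_single_crit$archive
tuner$optimize(tune_single_crit)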

  • Tuning results; the best configuration is read off right after the archive below
tune_single_crit$archive
## <ArchiveTuning>
##    mtry sample.fraction num.trees bracket bracket_stage budget_scaled
## 1:   11            0.72       400       1             0           1.3
## 2:    9            0.76       400       1             0           1.3
## 3:    9            0.76       800       1             1           2.7
## 4:    9            0.76       800       0             0           2.7
## 5:   14            0.78       800       0             0           2.7
##    budget_real n_configs classif.auc           timestamp batch_nr
## 1:         400         2        0.79 2021-11-24 23:18:25        1
## 2:         400         2        0.79 2021-11-24 23:18:25        1
## 3:         800         1        0.79 2021-11-24 23:18:27        2
## 4:         800         2        0.79 2021-11-24 23:18:33        3
## 5:         800         2        0.79 2021-11-24 23:18:33        3
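The single best configuration and its cross-validated AUC can be read directly from the instance, for example:
# best hyperparameter values (on the original scale) and the corresponding AUC
tune_single_crit$result_x_domain
tune_single_crit$result_y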
  • Prediction

Apply the best hyperparameters found above and retrain on the full dataset.

tuned_learner <- classif_learner$clone()
tuned_learner$param_set$values = tune_single_crit$result_learner_param_vals # best parameters
tuned_learner$train(task_classif)

The resulting predictions:

tuned_learner$predict(task_classif)
## <PredictionClassif> for 1000 observations:
##     row_ids truth response prob.good   prob.bad
##           1  good     good 0.8959236 0.10407639
##           2   bad      bad 0.2060665 0.79393353
##           3  good     good 0.9492312 0.05076885
## ---                                            
##         998  good     good 0.9708611 0.02913889
##         999   bad      bad 0.1580159 0.84198413
##        1000  good     good 0.6234355 0.37656448
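
The prediction object can be scored with the measure from section 2.5 and inspected further; keep in mind this is in-sample performance, since the model is evaluated on the same data it was trained on. A minimal sketch:
pred <- tuned_learner$predict(task_classif)
pred$score(msr("classif.auc"))  # AUC on the training data (optimistic)
pred$confusion                  # confusion matrix of truth vs. response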