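mlr was first released on CRAN in 2013, and its core design and architecture date back even further. For various reasons, the development team rewrote mlr's functionality, and the package has been superseded by mlr3.
Like tidymodels, mlr3 is a collection of R packages for modeling. The extension packages are listed at https://github.com/mlr-org/mlr3/wiki/Extension-Packages
The mlr3 packages are written with R6 classes, so we start by learning R6. R has several object-oriented systems: S3, S4, RC, and R6.
R6 is not part of base R; to use it, install and load the R6 package: install.packages("R6"). R6 creates classes and their methods with the R6Class() function.
The two most important arguments of R6Class() are classname, the name of the class, and public, a list of the fields and methods that form the object's public interface.
Let's look at an example: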
library(R6)
Accumulator <- R6Class("Accumulator", list(
sum = 0,
add = function(x = 1) {
self$sum <- self$sum + x
invisible(self)
})
)
Accumulator
## <Accumulator> object generator
## Public:
## sum: 0
## add: function (x = 1)
## clone: function (deep = FALSE)
## Parent env: <environment: R_GlobalEnv>
## Locked objects: TRUE
## Locked class: FALSE
## Portable: TRUE
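Calling the $new() method of an R6 class constructs a new object. In R6, methods belong to objects, and both methods and fields are accessed with the dollar sign ($):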
x <- Accumulator$new()
x$add(100)
x$sum
## [1] 100
In this class, the fields and methods are public and can be accessed freely.
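Because add() returns the object invisibly, calls can be chained. A minimal sketch, repeating the Accumulator class from above so that the block runs on its own (the variable name acc is only illustrative):

```r
library(R6)

# Same class as above, repeated so this block is self-contained.
Accumulator <- R6Class("Accumulator", list(
  sum = 0,
  add = function(x = 1) {
    self$sum <- self$sum + x
    invisible(self)  # returning self is what enables chaining
  }
))

acc <- Accumulator$new()
acc$add(10)$add(10)  # two chained calls on the same object
acc$sum
```

Methods that are called mainly for their side effects usually return invisible(self) for exactly this reason.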
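Most classes define two important methods: $initialize() and $print(). $initialize() overrides the default behaviour of $new(), and $print() controls how the object is printed.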
Person <- R6Class("Person", list(
name = NULL,
age = NA,
initialize = function(name, age = NA) {
stopifnot(is.character(name), length(name) == 1)
stopifnot(is.numeric(age), length(age) == 1)
self$name <- name
self$age <- age
}
))
# hadley <- Person$new("Hadley", age = "thirty-eight")
hadley <- Person$new("Hadley", age = 38)
Person <- R6Class("Person", list(
name = NULL,
age = NA,
initialize = function(name, age = NA) {
self$name <- name
self$age <- age
},
print = function(...) {
cat("Person: \n")
cat(" Name: ", self$name, "\n", sep = "")
cat(" Age: ", self$age, "\n", sep = "")
invisible(self)
}
))
hadley2 <- Person$new("Hadley")
hadley2
## Person:
## Name: Hadley
## Age: NA
You can also add fields and methods to an existing class with $set():
Accumulator <- R6Class("Accumulator")
Accumulator$set("public", "sum", 0)
Accumulator$set("public", "add", function(x = 1) {
self$sum <- self$sum + x
invisible(self)
})
Note that new methods and fields are only available to new objects; they are not retroactively added to existing objects.
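The following sketch illustrates this behaviour. The class Counter and the method reset are hypothetical names invented for the example:

```r
library(R6)

# Hypothetical class used only for this illustration.
Counter <- R6Class("Counter", list(n = 0))

old <- Counter$new()  # created before the class is modified

# Add a method to the class after the fact.
Counter$set("public", "reset", function() {
  self$n <- 0
  invisible(self)
})

new <- Counter$new()  # created after the class is modified

is.null(old$reset)      # the existing object did not gain the method
is.function(new$reset)  # the new object has it
```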
To inherit behaviour from an existing class, pass the class object to the inherit argument:
AccumulatorChatty <- R6Class("AccumulatorChatty",
inherit = Accumulator,
public = list(
add = function(x = 1) {
cat("Adding ", x, "\n", sep = "")
super$add(x = x)
}
)
)
x2 <- AccumulatorChatty$new()
x2$add(10)$add(1)$sum
## Adding 10
## Adding 1
## [1] 11
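R6Class() has two further arguments that work like public: private, which holds fields and methods that can only be accessed from inside the class, and active, which defines active bindings (fields that are computed when accessed).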
Person <- R6Class("Person",
public = list(
initialize = function(name, age = NA) {
private$name <- name
private$age <- age
},
print = function(...) {
cat("Person: \n")
cat(" Name: ", private$name, "\n", sep = "")
cat(" Age: ", private$age, "\n", sep = "")
}
),
private = list(
age = NA,
name = NULL
)
)
hadley3 <- Person$new("Hadley")
hadley3
## Person:
## Name: Hadley
## Age: NA
hadley3$name
## NULL
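Since name is private, hadley3$name returns NULL from the outside. One way to expose a private field read-only is an active binding, supplied through the active argument of R6Class(). The sketch below is an illustration only; the field name .name and the error message are choices made for this example:

```r
library(R6)

Person <- R6Class("Person",
  public = list(
    initialize = function(name) {
      private$.name <- name
    }
  ),
  private = list(
    .name = NULL  # the actual storage, hidden from users
  ),
  active = list(
    # Looks like a field from the outside, but runs this function on access.
    name = function(value) {
      if (missing(value)) {
        private$.name          # read access
      } else {
        stop("`$name` is read only", call. = FALSE)  # write access
      }
    }
  )
)

p <- Person$new("Hadley")
p$name
```

Assigning to p$name raises an error, so the name can be read but not changed after construction.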
Reference semantics and copying
y1 <- Accumulator$new()
y2 <- y1
y1$add(10)
c(y1 = y1$sum, y2 = y2$sum)
## y1 y2
## 10 10
If you want a copy, you need to explicitly $clone() the object:
y1 <- Accumulator$new()
y2 <- y1$clone()
y1$add(10)
c(y1 = y1$sum, y2 = y2$sum)
## y1 y2
## 10 0
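$clone() is shallow by default: fields that are themselves R6 objects are still shared between the original and the copy. Passing deep = TRUE clones nested R6 objects as well. A small sketch with two hypothetical classes, Counter and Holder, invented for this example:

```r
library(R6)

Counter <- R6Class("Counter", list(
  n = 0,
  inc = function() {
    self$n <- self$n + 1
    invisible(self)
  }
))

# Holder stores another R6 object in one of its fields.
Holder <- R6Class("Holder", list(
  counter = NULL,
  initialize = function() {
    self$counter <- Counter$new()
  }
))

h1 <- Holder$new()
shallow <- h1$clone()             # shares h1's counter
deep <- h1$clone(deep = TRUE)     # gets its own counter

h1$counter$inc()
c(original = h1$counter$n, shallow = shallow$counter$n, deep = deep$counter$n)
```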
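Start with the tutorial: https://mlr3book.mlr-org.com/
Then go through the packages one by one: https://mlr3book.mlr-org.com/
The first step is to install mlr3.
We use the iris dataset to train a first model. The learner below, classif.featureless, is a baseline that ignores the features and always predicts the most frequent class of the training data.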
library("mlr3")
## Warning: package 'mlr3' was built under R version 4.1.2
task = tsk("iris")
learner = lrn("classif.featureless")
# train a model of this learner for a subset of the task
learner$train(task, row_ids = 1:120)
# inspect the fitted model: class counts of the training data
learner$model
## $tab
##
## setosa versicolor virginica
## 50 50 20
##
## $features
## [1] "Petal.Length" "Petal.Width" "Sepal.Length" "Sepal.Width"
##
## attr(,"class")
## [1] "classif.featureless_model"
To see which learners are available, call lrn() without arguments:
lrn()
## <DictionaryLearner> with 6 stored values
## Keys: classif.debug, classif.featureless, classif.rpart, regr.debug,
## regr.featureless, regr.rpart
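More learners are available through the extension package mlr3learners, including:
classification learners
regression learners
survival learners
We now use the trained model to make predictions: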
predictions = learner$predict(task, row_ids = 121:150)
predictions
## <PredictionClassif> for 30 observations:
## row_ids truth response
## 121 virginica versicolor
## 122 virginica versicolor
## 123 virginica versicolor
## ---
## 148 virginica versicolor
## 149 virginica versicolor
## 150 virginica versicolor
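Evaluate the model by computing its accuracy. It is 0 here because rows 121:150 of iris are all virginica, a class the featureless learner never predicts after being trained on rows 1:120.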
# accuracy of our model on the test set of the final 30 rows
predictions$score(msr("classif.acc"))
## classif.acc
## 0
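A random train/test split avoids evaluating on a single class. The sketch below swaps in classif.rpart, an actual decision tree, and draws the 120 training rows at random; the seed and the variable names train_ids and test_ids are arbitrary choices for illustration, and the resulting accuracy depends on the split.

```r
library(mlr3)

set.seed(1)  # arbitrary seed, only for reproducibility
task <- tsk("iris")
learner <- lrn("classif.rpart")  # decision tree from the rpart package

# Random split instead of the first 120 rows.
train_ids <- sample(task$nrow, 120)
test_ids <- setdiff(seq_len(task$nrow), train_ids)

learner$train(task, row_ids = train_ids)
prediction <- learner$predict(task, row_ids = test_ids)

acc <- prediction$score(msr("classif.acc"))
acc
```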
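This part covers modeling with mlr3 in somewhat more detail.
A Task object holds the data and defines the machine learning problem.
We create a regression task from the mtcars dataset (part of R's datasets package).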
library("mlr3")
task_mtcars = as_task_regr(mtcars, target = "mpg", id = "cars")
print(task_mtcars)
## <TaskRegr:cars> (32 x 11)
## * Target: mpg
## * Properties: -
## * Features (10):
## - dbl (10): am, carb, cyl, disp, drat, gear, hp, qsec, vs, wt
class(task_mtcars)
## [1] "TaskRegr" "TaskSupervised" "Task" "R6"
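as_task_regr() is called here with three arguments: (1) the data, (2) target, the name of the target column, and (3) id, an optional identifier for the task.
An alternative is to call the constructor directly: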
task_mtcars <- TaskRegr$new(backend = mtcars,target = "mpg",id = "cars")
task_mtcars
## <TaskRegr:cars> (32 x 11)
## * Target: mpg
## * Properties: -
## * Features (10):
## - dbl (10): am, carb, cyl, disp, drat, gear, hp, qsec, vs, wt
task_mtcars$print()
## <TaskRegr:cars> (32 x 11)
## * Target: mpg
## * Properties: -
## * Features (10):
## - dbl (10): am, carb, cyl, disp, drat, gear, hp, qsec, vs, wt
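We can also visualize the data with mlr3viz. Note that $select() modifies the task in place, so after the call below task_mtcars keeps only the features am and carb.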
library("mlr3viz")
## Warning: package 'mlr3viz' was built under R version 4.1.2
autoplot(task_mtcars$select(c("am","carb")), type = "pairs")
## Registered S3 method overwritten by 'GGally':
## method from
## +.gg ggplot2
mlr3 ships with some predefined tasks; print mlr_tasks to list them:
mlr_tasks
## <DictionaryTask> with 11 stored values
## Keys: boston_housing, breast_cancer, german_credit, iris, mtcars,
## penguins, pima, sonar, spam, wine, zoo
The dictionary can be converted to a data.table:
as.data.table(mlr_tasks)
## key label task_type nrow ncol properties lgl
## 1: boston_housing Boston Housing Prices regr 506 19 0
## 2: breast_cancer Wisconsin Breast Cancer classif 683 10 twoclass 0
## 3: german_credit German Credit classif 1000 21 twoclass 0
## 4: iris Iris Flowers classif 150 5 multiclass 0
## 5: mtcars Motor Trends regr 32 11 0
## 6: penguins Palmer Penguins classif 344 8 multiclass 0
## 7: pima Pima Indian Diabetes classif 768 9 twoclass 0
## 8: sonar Sonar: Mines vs. Rocks classif 208 61 twoclass 0
## 9: spam HP Spam Detection classif 4601 58 twoclass 0
## 10: wine Wine Regions classif 178 14 multiclass 0
## 11: zoo Zoo Animals classif 101 17 multiclass 15
## int dbl chr fct ord pxc
## 1: 3 13 0 2 0 0
## 2: 0 0 0 0 9 0
## 3: 3 0 0 14 3 0
## 4: 0 4 0 0 0 0
## 5: 0 10 0 0 0 0
## 6: 3 2 0 2 0 0
## 7: 0 8 0 0 0 0
## 8: 0 60 0 0 0 0
## 9: 0 57 0 0 0 0
## 10: 2 11 0 0 0 0
## 11: 1 0 0 0 0 0
To retrieve a task, use the $get() method:
mlr_tasks$get("sonar")
## <TaskClassif:sonar> (208 x 61): Sonar: Mines vs. Rocks
## * Target: Class
## * Properties: twoclass
## * Features (60):
## - dbl (60): V1, V10, V11, V12, V13, V14, V15, V16, V17, V18, V19, V2,
## V20, V21, V22, V23, V24, V25, V26, V27, V28, V29, V3, V30, V31,
## V32, V33, V34, V35, V36, V37, V38, V39, V4, V40, V41, V42, V43,
## V44, V45, V46, V47, V48, V49, V5, V50, V51, V52, V53, V54, V55,
## V56, V57, V58, V59, V6, V60, V7, V8, V9
These can be thought of as built-in example tasks. The shortcut function tsk() retrieves one as well:
task_penguins = tsk("penguins")
print(task_penguins)
## <TaskClassif:penguins> (344 x 8): Palmer Penguins
## * Target: species
## * Properties: multiclass
## * Features (7):
## - int (3): body_mass, flipper_length, year
## - dbl (2): bill_depth, bill_length
## - fct (2): island, sex
A task exposes its dimensions, data, and column names through fields and methods:
task_mtcars$nrow
## [1] 32
task_mtcars$ncol
## [1] 3
task_mtcars$data()
## mpg am carb
## 1: 21.0 1 4
## 2: 21.0 1 4
## 3: 22.8 1 1
## 4: 21.4 0 1
## 5: 18.7 0 2
## 6: 18.1 0 1
## 7: 14.3 0 4
## 8: 24.4 0 2
## 9: 22.8 0 2
## 10: 19.2 0 4
## 11: 17.8 0 4
## 12: 16.4 0 3
## 13: 17.3 0 3
## 14: 15.2 0 3
## 15: 10.4 0 4
## 16: 10.4 0 4
## 17: 14.7 0 4
## 18: 32.4 1 1
## 19: 30.4 1 2
## 20: 33.9 1 1
## 21: 21.5 0 1
## 22: 15.5 0 2
## 23: 15.2 0 2
## 24: 13.3 0 4
## 25: 19.2 0 2
## 26: 27.3 1 1
## 27: 26.0 1 2
## 28: 30.4 1 2
## 29: 15.8 1 4
## 30: 19.7 1 6
## 31: 15.0 1 8
## 32: 21.4 1 2
## mpg am carb
task_mtcars$feature_names
## [1] "am" "carb"
task_mtcars$target_names
## [1] "mpg"
For a binary classification task, the positive argument defines which class is treated as the positive one.
# during construction
data("Sonar", package = "mlbench")
task = as_task_classif(Sonar, target = "Class", positive = "R")
# switch the positive class to level "M"
task$positive = "M"
Rows and columns can be assigned different roles. These roles affect how the task behaves in different operations.
print(task_mtcars$col_roles)
## $feature
## [1] "am" "carb"
##
## $target
## [1] "mpg"
##
## $name
## character(0)
##
## $order
## character(0)
##
## $stratum
## character(0)
##
## $group
## character(0)
##
## $weight
## character(0)
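Columns can have no role (in which case they are ignored) or several roles. Roles can be set when the task is constructed or afterwards with the $set_col_roles() method. The call below assumes the data contains a column rn holding the row names (for example after as.data.table(mtcars, keep.rownames = TRUE)); the task created above has no such column.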
task_mtcars$set_col_roles("rn", roles = "name")
A task also supports simple data manipulation, such as selecting columns and filtering rows:
task_penguins = tsk("penguins")
task_penguins$select(c("body_mass", "flipper_length")) # keep only these features
task_penguins$filter(1:3) # keep only these rows
task_penguins$head()
## species body_mass flipper_length
## 1: Adelie 3750 181
## 2: Adelie 3800 186
## 3: Adelie 3250 195
task_penguins$cbind(data.frame(letters = letters[1:3])) # add column letters
task_penguins$head()
## species body_mass flipper_length letters
## 1: Adelie 3750 181 a
## 2: Adelie 3800 186 b
## 3: Adelie 3250 195 c
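$select() and $filter() change the task in place. If the original task is still needed, one option is to work on a clone, as in this sketch (the variable name small is only illustrative):

```r
library(mlr3)

task <- tsk("penguins")

# Clone first, then mutate the copy; both methods return the task itself,
# so the calls can be chained.
small <- task$clone()$select(c("body_mass", "flipper_length"))$filter(1:3)

c(original_rows = task$nrow, subset_rows = small$nrow)
small$feature_names
```

The original task keeps all of its rows and features, while small contains only the selected subset.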
Visualizing tasks
library("mlr3viz")
# get the pima indians task
task = tsk("pima")
# subset task to only use the 3 first features
task$select(head(task$feature_names, 3))
# default plot: class frequencies
autoplot(task)
# pairs plot (requires package GGally)
autoplot(task, type = "pairs")
## Warning in ggally_statistic(data = data, mapping = mapping, na.rm = na.rm, :
## Removed 5 rows containing missing values
## Warning in ggally_statistic(data = data, mapping = mapping, na.rm = na.rm, :
## Removed 374 rows containing missing values
## Warning in ggally_statistic(data = data, mapping = mapping, na.rm = na.rm, :
## Removed 375 rows containing missing values
## Warning: Removed 5 rows containing non-finite values (stat_boxplot).
## Warning: Removed 374 rows containing non-finite values (stat_boxplot).
## `stat_bin()` using `bins = 30`. Pick better value with `binwidth`.
## `stat_bin()` using `bins = 30`. Pick better value with `binwidth`.
## Warning: Removed 5 rows containing non-finite values (stat_bin).
## Warning: Removed 5 rows containing missing values (geom_point).
## Warning: Removed 5 rows containing non-finite values (stat_density).
## `stat_bin()` using `bins = 30`. Pick better value with `binwidth`.
## Warning: Removed 374 rows containing non-finite values (stat_bin).
## Warning: Removed 374 rows containing missing values (geom_point).
## Warning: Removed 375 rows containing missing values (geom_point).
## Warning: Removed 374 rows containing non-finite values (stat_density).
# duo plot (requires package GGally)
autoplot(task, type = "duo")
## Warning: Removed 5 rows containing non-finite values (stat_boxplot).
## Warning: Removed 374 rows containing non-finite values (stat_boxplot).
library("mlr3viz")
# get the complete mtcars task
task = tsk("mtcars")
# subset task to only use the 3 first features
task$select(head(task$feature_names, 3))
# default plot: boxplot of target variable
autoplot(task)
# pairs plot (requires package GGally)
autoplot(task, type = "pairs")
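The mlr3 package itself ships with the following set of classification and regression learners.
Call lrn() without arguments to see which learners are available: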
lrn()
## <DictionaryLearner> with 6 stored values
## Keys: classif.debug, classif.featureless, classif.rpart, regr.debug,
## regr.featureless, regr.rpart
The mlr3learners package provides many more learners.
The mlr3extralearners package provides further learners as well. It is installed from GitHub: remotes::install_github("mlr-org/mlr3extralearners")
mlr3extralearners::list_mlr3learners()
## name class id mlr3_package
## 1: AdaBoostM1 classif classif.AdaBoostM1 mlr3extralearners
## 2: bart classif classif.bart mlr3extralearners
## 3: C50 classif classif.C50 mlr3extralearners
## 4: catboost classif classif.catboost mlr3extralearners
## 5: cforest classif classif.cforest mlr3extralearners
## ---
## 132: ranger surv surv.ranger mlr3learners
## 133: rfsrc surv surv.rfsrc mlr3extralearners
## 134: rpart surv surv.rpart mlr3proba
## 135: svm surv surv.svm mlr3extralearners
## 136: xgboost surv surv.xgboost mlr3learners
## required_packages
## 1: mlr3,mlr3extralearners,RWeka
## 2: mlr3,mlr3extralearners,dbarts
## 3: mlr3,mlr3extralearners,C50
## 4: mlr3,mlr3extralearners,catboost
## 5: mlr3,mlr3extralearners,partykit,sandwich,coin
## ---
## 132: mlr3,mlr3proba,mlr3learners,ranger
## 133: mlr3,mlr3proba,mlr3extralearners,randomForestSRC,pracma
## 134: mlr3,mlr3proba,rpart,distr6,survival
## 135: mlr3,mlr3proba,mlr3extralearners,survivalsvm
## 136: mlr3,mlr3proba,mlr3learners,xgboost
## properties
## 1: multiclass,twoclass
## 2: twoclass,weights
## 3: missings,multiclass,twoclass,weights
## 4: importance,missings,multiclass,twoclass,weights
## 5: multiclass,oob_error,twoclass,weights
## ---
## 132: importance,oob_error,weights
## 133: importance,missings,oob_error,weights
## 134: importance,missings,selected_features,weights
## 135:
## 136: importance,missings,weights
## feature_types predict_types
## 1: numeric,factor,ordered,integer response,prob
## 2: integer,numeric,factor,ordered response,prob
## 3: numeric,factor,ordered response,prob
## 4: numeric,factor,ordered response,prob
## 5: integer,numeric,factor,ordered response,prob
## ---
## 132: logical,integer,numeric,character,factor,ordered distr,crank
## 133: logical,integer,numeric,factor crank,distr
## 134: logical,integer,numeric,character,factor,ordered crank,distr
## 135: integer,numeric,character,factor,logical crank,response
## 136: integer,numeric crank,lp
After loading mlr3verse, the mlr_learners dictionary also lists all available learners:
library("mlr3verse")
## Warning: package 'mlr3verse' was built under R version 4.1.2
mlr_learners
## <DictionaryLearner> with 136 stored values
## Keys: classif.AdaBoostM1, classif.bart, classif.C50, classif.catboost,
## classif.cforest, classif.ctree, classif.cv_glmnet, classif.debug,
## classif.earth, classif.extratrees, classif.featureless, classif.fnn,
## classif.gam, classif.gamboost, classif.gausspr, classif.gbm,
## classif.glmboost, classif.glmnet, classif.IBk, classif.J48,
## classif.JRip, classif.kknn, classif.ksvm, classif.lda,
## classif.liblinear, classif.lightgbm, classif.LMT, classif.log_reg,
## classif.lssvm, classif.mob, classif.multinom, classif.naive_bayes,
## classif.nnet, classif.OneR, classif.PART, classif.qda,
## classif.randomForest, classif.ranger, classif.rfsrc, classif.rpart,
## classif.svm, classif.xgboost, clust.agnes, clust.ap, clust.cmeans,
## clust.cobweb, clust.dbscan, clust.diana, clust.em, clust.fanny,
## clust.featureless, clust.ff, clust.hclust, clust.kkmeans,
## clust.kmeans, clust.MBatchKMeans, clust.meanshift, clust.pam,
## clust.SimpleKMeans, clust.xmeans, dens.hist, dens.kde, dens.kde_kd,
## dens.kde_ks, dens.locfit, dens.logspline, dens.mixed, dens.nonpar,
## dens.pen, dens.plug, dens.spline, regr.bart, regr.catboost,
## regr.cforest, regr.ctree, regr.cubist, regr.cv_glmnet, regr.debug,
## regr.earth, regr.extratrees, regr.featureless, regr.fnn, regr.gam,
## regr.gamboost, regr.gausspr, regr.gbm, regr.glm, regr.glmboost,
## regr.glmnet, regr.IBk, regr.kknn, regr.km, regr.ksvm, regr.liblinear,
## regr.lightgbm, regr.lm, regr.M5Rules, regr.mars, regr.mob,
## regr.randomForest, regr.ranger, regr.rfsrc, regr.rpart, regr.rvm,
## regr.svm, regr.xgboost, surv.akritas, surv.blackboost, surv.cforest,
## surv.coxboost, surv.coxph, surv.coxtime, surv.ctree,
## surv.cv_coxboost, surv.cv_glmnet, surv.deephit, surv.deepsurv,
## surv.dnnsurv, surv.flexible, surv.gamboost, surv.gbm, surv.glmboost,
## surv.glmnet, surv.kaplan, surv.loghaz, surv.mboost, surv.nelson,
## surv.obliqueRSF, surv.parametric, surv.pchazard, surv.penalized,
## surv.ranger, surv.rfsrc, surv.rpart, surv.svm, surv.xgboost
Training a model requires a learner object:
learner = lrn("classif.rpart")
print(learner)
## <LearnerClassifRpart:classif.rpart>: Classification Tree
## * Model: -
## * Parameters: xval=0
## * Packages: mlr3, rpart
## * Predict Type: response
## * Feature types: logical, integer, numeric, factor, ordered
## * Properties: importance, missings, multiclass, selected_features,
## twoclass, weights
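A learner object carries a lot of information: the packages required for training, the predict type, the supported feature types, and further properties and methods.
Every learner class has its own set of hyperparameters: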
learner$param_set
## <ParamSet>
## id class lower upper nlevels default value
## 1: cp ParamDbl 0 1 Inf 0.01
## 2: keep_model ParamLgl NA NA 2 FALSE
## 3: maxcompete ParamInt 0 Inf Inf 4
## 4: maxdepth ParamInt 1 30 30 30
## 5: maxsurrogate ParamInt 0 Inf Inf 5
## 6: minbucket ParamInt 1 Inf Inf <NoDefault[3]>
## 7: minsplit ParamInt 1 Inf Inf 20
## 8: surrogatestyle ParamInt 0 1 2 0
## 9: usesurrogate ParamInt 0 2 3 2
## 10: xval ParamInt 0 Inf Inf 10 0
The parameter values can be set directly:
learner$param_set$values = list(cp = 0.01, xval = 0)
learner
## <LearnerClassifRpart:classif.rpart>: Classification Tree
## * Model: -
## * Parameters: cp=0.01, xval=0
## * Packages: mlr3, rpart
## * Predict Type: response
## * Feature types: logical, integer, numeric, factor, ordered
## * Properties: importance, missings, multiclass, selected_features,
## twoclass, weights
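This operation overwrites all previously set parameter values. Alternatively, retrieve the current hyperparameter values, modify them, and write them back to the learner: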
pv = learner$param_set$values
pv$cp = 0.02
learner$param_set$values = pv
Parameters can also be specified when the learner object is created:
learner = lrn("classif.rpart", id = "rp", cp = 0.001)
learner$id
## [1] "rp"
learner$param_set$values
## $xval
## [1] 0
##
## $cp
## [1] 0.001
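To change a single hyperparameter without touching the others, the value can be assigned by name. A short sketch (the value 5 for maxdepth is an arbitrary choice):

```r
library(mlr3)

learner <- lrn("classif.rpart", cp = 0.001)

# Update one value by name; cp and xval stay as they were.
learner$param_set$values$maxdepth <- 5

learner$param_set$values
```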
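In classification, a model can usually return predicted probabilities. By default the positive class is predicted if its probability exceeds 50%, and the negative class otherwise; this threshold can be changed.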
data("Sonar", package = "mlbench")
task = as_task_classif(Sonar, target = "Class", positive = "M")
learner = lrn("classif.rpart", predict_type = "prob")
pred = learner$train(task)$predict(task)
measures = msrs(c("classif.tpr", "classif.tnr")) # use msrs() to get a list of multiple measures
pred$confusion
## truth
## response M R
## M 95 10
## R 16 87
pred$score(measures)
## classif.tpr classif.tnr
## 0.8558559 0.8969072
pred$set_threshold(0.2)
pred$confusion
## truth
## response M R
## M 104 25
## R 7 72
pred$score(measures)
## classif.tpr classif.tnr
## 0.9369369 0.7422680
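Lowering the threshold turns more observations into positive predictions, which raises the true positive rate and lowers the true negative rate. The base R sketch below shows the mechanism on toy numbers invented for illustration (they are not taken from the Sonar task), using a simplified rule rather than mlr3's internal implementation:

```r
# Toy data, invented for illustration.
truth  <- c("M", "M", "M", "R", "R")
prob_m <- c(0.90, 0.60, 0.30, 0.25, 0.10)  # predicted probability of "M"

# Simplified rule: predict the positive class "M" when its probability
# exceeds the threshold.
classify <- function(prob, threshold) ifelse(prob > threshold, "M", "R")

# True positive rate and true negative rate with "M" as the positive class.
rates <- function(truth, response) {
  c(tpr = mean(response[truth == "M"] == "M"),
    tnr = mean(response[truth == "R"] == "R"))
}

at_50 <- rates(truth, classify(prob_m, 0.5))
at_20 <- rates(truth, classify(prob_m, 0.2))
rbind(at_50, at_20)
```

With the threshold at 0.5 the toy example has a TPR of 2/3 and a TNR of 1; at 0.2 the TPR rises to 1 while the TNR drops to 0.5.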
The mlr3pipelines package can tune the threshold automatically with respect to a performance measure, via PipeOpTuneThreshold.
Let's look at another example.
library("mlr3verse")
task = tsk("penguins")
learner = lrn("classif.rpart")
train_set = sample(task$nrow, 0.8 * task$nrow)
test_set = setdiff(seq_len(task$nrow), train_set)
learner$model
## NULL
learner$train(task, row_ids = train_set)
print(learner$model) # inspect the model
## n= 275
##
## node), split, n, loss, yval, (yprob)
## * denotes terminal node
##
## 1) root 275 150 Adelie (0.454545455 0.203636364 0.341818182)
## 2) flipper_length< 206.5 176 53 Adelie (0.698863636 0.295454545 0.005681818)
## 4) bill_length< 43.35 123 3 Adelie (0.975609756 0.024390244 0.000000000) *
## 5) bill_length>=43.35 53 4 Chinstrap (0.056603774 0.924528302 0.018867925) *
## 3) flipper_length>=206.5 99 6 Gentoo (0.020202020 0.040404040 0.939393939)
## 6) bill_depth>=17.2 8 4 Chinstrap (0.250000000 0.500000000 0.250000000) *
## 7) bill_depth< 17.2 91 0 Gentoo (0.000000000 0.000000000 1.000000000) *
# make predictions
prediction = learner$predict(task, row_ids = test_set)
print(prediction)
## <PredictionClassif> for 69 observations:
## row_ids truth response
## 1 Adelie Adelie
## 4 Adelie Adelie
## 5 Adelie Adelie
## ---
## 334 Chinstrap Chinstrap
## 335 Chinstrap Chinstrap
## 340 Chinstrap Chinstrap
# convert the predictions to a data.table
head(as.data.table(prediction)) # show first six predictions
## row_ids truth response
## 1: 1 Adelie Adelie
## 2: 4 Adelie Adelie
## 3: 5 Adelie Adelie
## 4: 16 Adelie Adelie
## 5: 22 Adelie Adelie
## 6: 26 Adelie Adelie
# compute the confusion matrix
prediction$confusion
## truth
## response Adelie Chinstrap Gentoo
## Adelie 26 2 0
## Chinstrap 1 10 0
## Gentoo 0 0 30
# change the predict type
learner$predict_type = "prob"
# re-fit the model
learner$train(task, row_ids = train_set)
# rebuild prediction object
prediction = learner$predict(task, row_ids = test_set)
# data.table conversion
head(as.data.table(prediction)) # show first six
## row_ids truth response prob.Adelie prob.Chinstrap prob.Gentoo
## 1: 1 Adelie Adelie 0.9756098 0.02439024 0
## 2: 4 Adelie Adelie 0.9756098 0.02439024 0
## 3: 5 Adelie Adelie 0.9756098 0.02439024 0
## 4: 16 Adelie Adelie 0.9756098 0.02439024 0
## 5: 22 Adelie Adelie 0.9756098 0.02439024 0
## 6: 26 Adelie Adelie 0.9756098 0.02439024 0
# directly access the predicted labels:
head(prediction$response)
## [1] Adelie Adelie Adelie Adelie Adelie Adelie
## Levels: Adelie Chinstrap Gentoo
# directly access the matrix of probabilities:
head(prediction$prob)
## Adelie Chinstrap Gentoo
## [1,] 0.9756098 0.02439024 0
## [2,] 0.9756098 0.02439024 0
## [3,] 0.9756098 0.02439024 0
## [4,] 0.9756098 0.02439024 0
## [5,] 0.9756098 0.02439024 0
## [6,] 0.9756098 0.02439024 0
# visualize the result
task = tsk("penguins")
learner = lrn("classif.rpart", predict_type = "prob")
learner$train(task)
prediction = learner$predict(task)
autoplot(prediction)
The final step of modeling is usually to evaluate the performance of the trained model. The mlr_measures dictionary lists the available measures:
mlr_measures
## <DictionaryMeasure> with 88 stored values
## Keys: aic, bic, classif.acc, classif.auc, classif.bacc, classif.bbrier,
## classif.ce, classif.costs, classif.dor, classif.fbeta, classif.fdr,
## classif.fn, classif.fnr, classif.fomr, classif.fp, classif.fpr,
## classif.logloss, classif.mbrier, classif.mcc, classif.npv,
## classif.ppv, classif.prauc, classif.precision, classif.recall,
## classif.sensitivity, classif.specificity, classif.tn, classif.tnr,
## classif.tp, classif.tpr, clust.ch, clust.db, clust.dunn,
## clust.silhouette, clust.wss, debug, dens.logloss, oob_error,
## regr.bias, regr.ktau, regr.mae, regr.mape, regr.maxae, regr.medae,
## regr.medse, regr.mse, regr.msle, regr.pbias, regr.rae, regr.rmse,
## regr.rmsle, regr.rrse, regr.rse, regr.rsq, regr.sae, regr.smape,
## regr.srho, regr.sse, selected_features, sim.jaccard, sim.phi,
## surv.brier, surv.calib_alpha, surv.calib_beta, surv.chambless_auc,
## surv.cindex, surv.dcalib, surv.graf, surv.hung_auc, surv.intlogloss,
## surv.logloss, surv.mae, surv.mse, surv.nagelk_r2, surv.oquigley_r2,
## surv.rcll, surv.rmse, surv.schmid, surv.song_auc, surv.song_tnr,
## surv.song_tpr, surv.uno_auc, surv.uno_tnr, surv.uno_tpr, surv.xu_r2,
## time_both, time_predict, time_train
For example, we choose accuracy:
measure = msr("classif.acc")
print(measure)
## <MeasureClassifSimple:classif.acc>: Classification Accuracy
## * Packages: mlr3, mlr3measures
## * Range: [0, 1]
## * Minimize: FALSE
## * Average: macro
## * Parameters: list()
## * Properties: -
## * Predict type: response
prediction$score(measure)
## classif.acc
## 0.9651163
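Several measures can be computed in one call by passing a list created with msrs(). As a sanity check, accuracy and classification error always add up to 1. A sketch on the penguins task (evaluated on the training data only to keep the example short):

```r
library(mlr3)

task <- tsk("penguins")
learner <- lrn("classif.rpart")
learner$train(task)
prediction <- learner$predict(task)

# msrs() returns a list of measures; score() returns a named numeric vector.
scores <- prediction$score(msrs(c("classif.acc", "classif.ce")))
scores
sum(scores)  # accuracy + classification error
```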
We first look at binary classification, for example computing the confusion matrix and plotting the ROC curve:
data("Sonar", package = "mlbench")
task = as_task_classif(Sonar, target = "Class", positive = "M")
learner = lrn("classif.rpart", predict_type = "prob")
pred = learner$train(task)$predict(task)
C = pred$confusion
print(C)
## truth
## response M R
## M 95 10
## R 16 87
The ROC curve can be plotted in several ways; here we use mlr3viz:
library("mlr3viz")
# TPR vs FPR / Sensitivity vs (1 - Specificity)
autoplot(pred, type = "roc")
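We can also plot the precision-recall curve (PPV vs. TPR). The main difference between ROC curves and precision-recall curves (PRC) is that the number of true negatives is not used to construct a PRC. For imbalanced class distributions, the PRC is often preferred over the ROC curve.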
# Precision vs Recall
autoplot(pred, type = "prc")
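When evaluating a model, we are interested in its generalization performance: how well does it perform on new data that it has not seen during training? We can estimate this by evaluating the model on a test set.
There are many strategies for splitting a dataset into training and test sets; in mlr3 these strategies are called "resampling". mlr3 ships a set of predefined resampling strategies, such as holdout, cross-validation, and bootstrap.
Let's look at an example: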
library("mlr3verse")
task = tsk("penguins")
learner = lrn("classif.rpart")
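When resampling a dataset, we first need to decide which strategy to use.
The mlr_resamplings dictionary lists the available strategies along with the parameters that can be changed to adjust each strategy's behaviour: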
as.data.table(mlr_resamplings)
## key label params iters
## 1: bootstrap Bootstrap ratio,repeats 30
## 2: custom Custom Splits NA
## 3: custom_cv Custom Split Cross-Validation NA
## 4: cv Cross-Validation folds 10
## 5: holdout Holdout ratio 1
## 6: insample Insample Resampling 1
## 7: loo Leave-One-Out NA
## 8: repeated_cv Repeated Cross-Validation folds,repeats 100
## 9: subsampling Subsampling ratio,repeats 30
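Additional resampling methods for special use cases are available through extension packages, for example mlr3spatiotempcv for spatial data.
We use holdout as the resampling strategy: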
resampling <- rsmp("holdout")
print(resampling)
## <ResamplingHoldout>: Holdout
## * Iterations: 1
## * Instantiated: FALSE
## * Parameters: ratio=0.6667
Equivalently:
resampling <- mlr_resamplings$get("holdout")
resampling
## <ResamplingHoldout>: Holdout
## * Iterations: 1
## * Instantiated: FALSE
## * Parameters: ratio=0.6667
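Note that the $is_instantiated field is FALSE. This means the strategy has not yet been applied to a dataset.
By default, the data is split 2/3 for training and 1/3 for testing. There are two ways to change this ratio: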
resampling$param_set$values = list(ratio = 0.8)
rsmp("holdout", ratio = 0.8)
## <ResamplingHoldout>: Holdout
## * Iterations: 1
## * Instantiated: FALSE
## * Parameters: ratio=0.8
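So far we have only defined the resampling strategy; the next step is to instantiate it on data.
To actually perform the split and obtain the row indices of the training and test sets, the resampling needs a Task. Calling the $instantiate() method splits the row indices of the data into training and test indices, which are stored in the Resampling object.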
resampling$instantiate(task)
str(resampling$train_set(1))
## int [1:275] 1 2 3 4 5 6 7 10 11 13 ...
str(resampling$test_set(1))
## int [1:69] 8 9 12 15 26 35 36 41 42 52 ...
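After instantiation, the training and test indices should be disjoint and together cover every row of the task. The sketch below checks this for a holdout split; the seed and the variable names train_ids and test_ids are arbitrary choices for illustration:

```r
library(mlr3)

set.seed(42)  # arbitrary seed, only for reproducibility
task <- tsk("penguins")
resampling <- rsmp("holdout", ratio = 0.8)
resampling$instantiate(task)

train_ids <- resampling$train_set(1)
test_ids <- resampling$test_set(1)

length(intersect(train_ids, test_ids))               # no row is in both sets
length(train_ids) + length(test_ids) == task$nrow    # every row is used once
```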
The next step is to call resample():
task = tsk("penguins")
learner = lrn("classif.rpart", maxdepth = 3, predict_type = "prob")
resampling = rsmp("cv", folds = 3)
rr = resample(task, learner, resampling, store_models = TRUE)
## INFO [23:35:27.806] [mlr3] Applying learner 'classif.rpart' on task 'penguins' (iter 1/3)
## INFO [23:35:27.837] [mlr3] Applying learner 'classif.rpart' on task 'penguins' (iter 2/3)
## INFO [23:35:27.851] [mlr3] Applying learner 'classif.rpart' on task 'penguins' (iter 3/3)
print(rr)
## <ResampleResult> of 3 iterations
## * Task: penguins
## * Learner: classif.rpart
## * Warnings: 0 in 0 iterations
## * Errors: 0 in 0 iterations
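Here we use 3-fold cross-validation, which trains and evaluates the learner on three different train/test splits. The returned ResampleResult is stored in rr.
A lot of information can be extracted from rr. For example, compute the performance averaged over all resampling iterations, measured by classification error: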
rr$aggregate(msr("classif.ce"))
## classif.ce
## 0.07561658
Extract the performance of the individual resampling iterations:
rr$score(msr("classif.ce"))
## task task_id learner learner_id
## 1: <TaskClassif[50]> penguins <LearnerClassifRpart[38]> classif.rpart
## 2: <TaskClassif[50]> penguins <LearnerClassifRpart[38]> classif.rpart
## 3: <TaskClassif[50]> penguins <LearnerClassifRpart[38]> classif.rpart
## resampling resampling_id iteration prediction
## 1: <ResamplingCV[20]> cv 1 <PredictionClassif[20]>
## 2: <ResamplingCV[20]> cv 2 <PredictionClassif[20]>
## 3: <ResamplingCV[20]> cv 3 <PredictionClassif[20]>
## classif.ce
## 1: 0.09565217
## 2: 0.04347826
## 3: 0.08771930
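For classif.ce, the aggregated value is the mean of the per-iteration scores (in the output above, the three fold errors average to 0.0756). The sketch below reproduces that relationship on a fresh resampling; the seed is arbitrary, so the individual numbers differ from those shown above:

```r
library(mlr3)

set.seed(1)  # arbitrary seed, only for reproducibility
rr <- resample(tsk("penguins"), lrn("classif.rpart"), rsmp("cv", folds = 3))

# One classification error per fold.
per_fold <- rr$score(msr("classif.ce"))$classif.ce

# Aggregated performance over all folds.
agg <- rr$aggregate(msr("classif.ce"))

all.equal(unname(agg), mean(per_fold))
```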
Check for warnings and errors:
rr$warnings
## Empty data.table (0 rows and 2 cols): iteration,msg
rr$errors
## Empty data.table (0 rows and 2 cols): iteration,msg
Extract and inspect the resampling splits:
rr$resampling
## <ResamplingCV>: Cross-Validation
## * Iterations: 3
## * Instantiated: TRUE
## * Parameters: folds=3
Inspect the fitted models (available because store_models = TRUE):
learners = rr$learners
learners
## [[1]]
## <LearnerClassifRpart:classif.rpart>: Classification Tree
## * Model: rpart
## * Parameters: xval=0, maxdepth=3
## * Packages: mlr3, rpart
## * Predict Type: prob
## * Feature types: logical, integer, numeric, factor, ordered
## * Properties: importance, missings, multiclass, selected_features,
## twoclass, weights
##
## [[2]]
## <LearnerClassifRpart:classif.rpart>: Classification Tree
## * Model: rpart
## * Parameters: xval=0, maxdepth=3
## * Packages: mlr3, rpart
## * Predict Type: prob
## * Feature types: logical, integer, numeric, factor, ordered
## * Properties: importance, missings, multiclass, selected_features,
## twoclass, weights
##
## [[3]]
## <LearnerClassifRpart:classif.rpart>: Classification Tree
## * Model: rpart
## * Parameters: xval=0, maxdepth=3
## * Packages: mlr3, rpart
## * Predict Type: prob
## * Feature types: logical, integer, numeric, factor, ordered
## * Properties: importance, missings, multiclass, selected_features,
## twoclass, weights
Extract the predictions, either merged into one object or per iteration:
rr$prediction() # all predictions merged into a single Prediction object
## <PredictionClassif> for 344 observations:
## row_ids truth response prob.Adelie prob.Chinstrap prob.Gentoo
## 8 Adelie Adelie 0.9719626 0.02803738 0.0000000
## 11 Adelie Adelie 0.9719626 0.02803738 0.0000000
## 15 Adelie Adelie 0.9719626 0.02803738 0.0000000
## ---
## 333 Chinstrap Chinstrap 0.1063830 0.87234043 0.0212766
## 334 Chinstrap Chinstrap 0.1063830 0.87234043 0.0212766
## 344 Chinstrap Chinstrap 0.1063830 0.87234043 0.0212766
rr$predictions()
## [[1]]
## <PredictionClassif> for 115 observations:
## row_ids truth response prob.Adelie prob.Chinstrap prob.Gentoo
## 8 Adelie Adelie 0.97196262 0.02803738 0.0000000
## 11 Adelie Adelie 0.97196262 0.02803738 0.0000000
## 15 Adelie Adelie 0.97196262 0.02803738 0.0000000
## ---
## 340 Chinstrap Gentoo 0.00000000 0.02352941 0.9764706
## 342 Chinstrap Chinstrap 0.05405405 0.94594595 0.0000000
## 343 Chinstrap Gentoo 0.00000000 0.02352941 0.9764706
##
## [[2]]
## <PredictionClassif> for 115 observations:
## row_ids truth response prob.Adelie prob.Chinstrap prob.Gentoo
## 2 Adelie Adelie 0.95049505 0.04950495 0.00000000
## 3 Adelie Adelie 0.95049505 0.04950495 0.00000000
## 4 Adelie Adelie 0.95049505 0.04950495 0.00000000
## ---
## 338 Chinstrap Chinstrap 0.02272727 0.95454545 0.02272727
## 339 Chinstrap Chinstrap 0.02272727 0.95454545 0.02272727
## 341 Chinstrap Adelie 0.95049505 0.04950495 0.00000000
##
## [[3]]
## <PredictionClassif> for 114 observations:
## row_ids truth response prob.Adelie prob.Chinstrap prob.Gentoo
## 1 Adelie Adelie 0.9892473 0.01075269 0.0000000
## 5 Adelie Adelie 0.9892473 0.01075269 0.0000000
## 6 Adelie Adelie 0.9892473 0.01075269 0.0000000
## ---
## 333 Chinstrap Chinstrap 0.1063830 0.87234043 0.0212766
## 334 Chinstrap Chinstrap 0.1063830 0.87234043 0.0212766
## 344 Chinstrap Chinstrap 0.1063830 0.87234043 0.0212766
Filter the result to keep only the specified resampling iterations:
rr$filter(c(1, 3))
print(rr)
## <ResampleResult> of 2 iterations
## * Task: penguins
## * Learner: classif.rpart
## * Warnings: 0 in 0 iterations
## * Errors: 0 in 0 iterations
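Resampling can also be customized. Sometimes it is necessary to resample with user-defined splits; the "custom" template creates such a manual resampling instance.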
resampling = rsmp("custom")
resampling$instantiate(task,
train = list(c(1:10, 51:60, 101:110)), # rows of the training set
test = list(c(11:20, 61:70, 111:120)) # rows of the test set
)
resampling$iters
## [1] 1
resampling$train_set(1)
## [1] 1 2 3 4 5 6 7 8 9 10 51 52 53 54 55 56 57 58 59
## [20] 60 101 102 103 104 105 106 107 108 109 110
resampling$test_set(1)
## [1] 11 12 13 14 15 16 17 18 19 20 61 62 63 64 65 66 67 68 69
## [20] 70 111 112 113 114 115 116 117 118 119 120
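mlr3viz provides an autoplot() method for resampling results. As an example, we create a binary classification task with two features, run 10-fold cross-validation, and visualize the results: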
task = tsk("pima")
task$select(c("glucose", "mass"))
learner = lrn("classif.rpart", predict_type = "prob")
rr = resample(task, learner, rsmp("cv"), store_models = TRUE)
## INFO [23:35:28.042] [mlr3] Applying learner 'classif.rpart' on task 'pima' (iter 6/10)
## INFO [23:35:28.056] [mlr3] Applying learner 'classif.rpart' on task 'pima' (iter 9/10)
## INFO [23:35:28.075] [mlr3] Applying learner 'classif.rpart' on task 'pima' (iter 7/10)
## INFO [23:35:28.089] [mlr3] Applying learner 'classif.rpart' on task 'pima' (iter 1/10)
## INFO [23:35:28.102] [mlr3] Applying learner 'classif.rpart' on task 'pima' (iter 4/10)
## INFO [23:35:28.115] [mlr3] Applying learner 'classif.rpart' on task 'pima' (iter 5/10)
## INFO [23:35:28.129] [mlr3] Applying learner 'classif.rpart' on task 'pima' (iter 10/10)
## INFO [23:35:28.141] [mlr3] Applying learner 'classif.rpart' on task 'pima' (iter 3/10)
## INFO [23:35:28.153] [mlr3] Applying learner 'classif.rpart' on task 'pima' (iter 8/10)
## INFO [23:35:28.164] [mlr3] Applying learner 'classif.rpart' on task 'pima' (iter 2/10)
# boxplot of AUC values across the 10 folds
autoplot(rr, measure = msr("classif.auc"))
Plot the ROC curve:
# ROC curve, averaged over 10 folds
autoplot(rr, type = "roc")
We can also plot the predictions of a single model:
# learner predictions for the first fold
rr$filter(1)
autoplot(rr, type = "prediction")
## Warning: Removed 3 rows containing missing values (geom_point).
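Comparing the performance of different learners on multiple tasks and/or different resampling schemes is a common task. In machine learning this is usually called "benchmarking".
In mlr3, a benchmark experiment is specified through a design. Such a design is essentially a table of scenarios to be evaluated, namely unique combinations of Task, Learner, and Resampling triplets.
We use the benchmark_grid() function to create an exhaustive design (every learner is evaluated on every task with every resampling) and to instantiate the resamplings properly, so that all learners run on the same train/test splits for each task. We let the learners predict probabilities and also tell them to predict on the observations of the training set (by setting predict_sets to c("train", "test")). In addition, we use tsks(), lrns(), and rsmps() to retrieve lists of Task, Learner, and Resampling objects in the same way as tsk(), lrn(), and rsmp().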
library("mlr3verse")
design = benchmark_grid(
tasks = tsks(c("spam", "german_credit", "sonar")),
learners = lrns(c("classif.ranger", "classif.rpart", "classif.featureless"),
predict_type = "prob", predict_sets = c("train", "test")),
resamplings = rsmps("cv", folds = 3)
)
print(design)
## task learner resampling
## 1: <TaskClassif[50]> <LearnerClassifRanger[38]> <ResamplingCV[20]>
## 2: <TaskClassif[50]> <LearnerClassifRpart[38]> <ResamplingCV[20]>
## 3: <TaskClassif[50]> <LearnerClassifFeatureless[38]> <ResamplingCV[20]>
## 4: <TaskClassif[50]> <LearnerClassifRanger[38]> <ResamplingCV[20]>
## 5: <TaskClassif[50]> <LearnerClassifRpart[38]> <ResamplingCV[20]>
## 6: <TaskClassif[50]> <LearnerClassifFeatureless[38]> <ResamplingCV[20]>
## 7: <TaskClassif[50]> <LearnerClassifRanger[38]> <ResamplingCV[20]>
## 8: <TaskClassif[50]> <LearnerClassifRpart[38]> <ResamplingCV[20]>
## 9: <TaskClassif[50]> <LearnerClassifFeatureless[38]> <ResamplingCV[20]>
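An exhaustive design is simply the Cartesian product of the three lists. The following base R sketch reproduces the row count using only the identifier strings from the call above, not the actual mlr3 objects; the variable names are illustrative and not part of mlr3.

```r
# Identifier strings taken from the benchmark_grid() call above
task_ids = c("spam", "german_credit", "sonar")
learner_ids = c("classif.ranger", "classif.rpart", "classif.featureless")
resampling_ids = "cv"

# Every task is paired with every learner and every resampling
design_ids = expand.grid(
  learner = learner_ids,
  task = task_ids,
  resampling = resampling_ids,
  stringsAsFactors = FALSE
)

n_rows = nrow(design_ids)    # one row per (task, learner, resampling) triplet
n_iterations = n_rows * 3    # each row is evaluated on 3 cross-validation folds
print(design_ids)
```

Three tasks times three learners times one resampling gives the 9 rows of the design, and with 3 folds per row the benchmark has to run 27 train/predict iterations.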
The created design can be passed to benchmark() to start the computation.
bmr = benchmark(design)
## INFO [23:35:29.275] [mlr3] Running benchmark with 27 resampling iterations
## INFO [23:35:29.279] [mlr3] Applying learner 'classif.ranger' on task 'german_credit' (iter 3/3)
## INFO [23:35:29.563] [mlr3] Applying learner 'classif.ranger' on task 'sonar' (iter 1/3)
## INFO [23:35:29.647] [mlr3] Applying learner 'classif.featureless' on task 'spam' (iter 1/3)
## INFO [23:35:29.669] [mlr3] Applying learner 'classif.rpart' on task 'spam' (iter 1/3)
## INFO [23:35:29.785] [mlr3] Applying learner 'classif.featureless' on task 'sonar' (iter 1/3)
## INFO [23:35:29.796] [mlr3] Applying learner 'classif.featureless' on task 'german_credit' (iter 1/3)
## INFO [23:35:29.807] [mlr3] Applying learner 'classif.rpart' on task 'sonar' (iter 3/3)
## INFO [23:35:29.832] [mlr3] Applying learner 'classif.rpart' on task 'spam' (iter 2/3)
## INFO [23:35:29.903] [mlr3] Applying learner 'classif.ranger' on task 'spam' (iter 3/3)
## INFO [23:35:31.255] [mlr3] Applying learner 'classif.rpart' on task 'german_credit' (iter 2/3)
## INFO [23:35:31.285] [mlr3] Applying learner 'classif.ranger' on task 'sonar' (iter 3/3)
## INFO [23:35:31.365] [mlr3] Applying learner 'classif.ranger' on task 'german_credit' (iter 2/3)
## INFO [23:35:31.632] [mlr3] Applying learner 'classif.featureless' on task 'sonar' (iter 2/3)
## INFO [23:35:31.644] [mlr3] Applying learner 'classif.featureless' on task 'german_credit' (iter 2/3)
## INFO [23:35:31.656] [mlr3] Applying learner 'classif.rpart' on task 'sonar' (iter 1/3)
## INFO [23:35:31.681] [mlr3] Applying learner 'classif.ranger' on task 'german_credit' (iter 1/3)
## INFO [23:35:31.936] [mlr3] Applying learner 'classif.rpart' on task 'sonar' (iter 2/3)
## INFO [23:35:31.961] [mlr3] Applying learner 'classif.rpart' on task 'spam' (iter 3/3)
## INFO [23:35:32.032] [mlr3] Applying learner 'classif.ranger' on task 'spam' (iter 2/3)
## INFO [23:35:33.406] [mlr3] Applying learner 'classif.featureless' on task 'german_credit' (iter 3/3)
## INFO [23:35:33.419] [mlr3] Applying learner 'classif.featureless' on task 'sonar' (iter 3/3)
## INFO [23:35:33.429] [mlr3] Applying learner 'classif.featureless' on task 'spam' (iter 3/3)
## INFO [23:35:33.452] [mlr3] Applying learner 'classif.rpart' on task 'german_credit' (iter 3/3)
## INFO [23:35:33.485] [mlr3] Applying learner 'classif.ranger' on task 'spam' (iter 1/3)
## INFO [23:35:34.830] [mlr3] Applying learner 'classif.rpart' on task 'german_credit' (iter 1/3)
## INFO [23:35:34.857] [mlr3] Applying learner 'classif.featureless' on task 'spam' (iter 2/3)
## INFO [23:35:34.879] [mlr3] Applying learner 'classif.ranger' on task 'sonar' (iter 2/3)
## INFO [23:35:34.960] [mlr3] Finished benchmark
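Once the benchmark is done (and, depending on the size of your design, this can take quite some time), we can use the $aggregate() method to measure performance on both the training set and the test set.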
measures = list(
msr("classif.auc", predict_sets = "train", id = "auc_train"),
msr("classif.auc", id = "auc_test")
)
tab = bmr$aggregate(measures)
print(tab)
## nr resample_result task_id learner_id resampling_id
## 1: 1 <ResampleResult[22]> spam classif.ranger cv
## 2: 2 <ResampleResult[22]> spam classif.rpart cv
## 3: 3 <ResampleResult[22]> spam classif.featureless cv
## 4: 4 <ResampleResult[22]> german_credit classif.ranger cv
## 5: 5 <ResampleResult[22]> german_credit classif.rpart cv
## 6: 6 <ResampleResult[22]> german_credit classif.featureless cv
## 7: 7 <ResampleResult[22]> sonar classif.ranger cv
## 8: 8 <ResampleResult[22]> sonar classif.rpart cv
## 9: 9 <ResampleResult[22]> sonar classif.featureless cv
## iters auc_train auc_test
## 1: 3 0.9994479 0.9849603
## 2: 3 0.9041325 0.8951998
## 3: 3 0.5000000 0.5000000
## 4: 3 0.9983067 0.7986295
## 5: 3 0.8163585 0.7034258
## 6: 3 0.5000000 0.5000000
## 7: 3 1.0000000 0.8857579
## 8: 3 0.9242398 0.7547127
## 9: 3 0.5000000 0.5000000
We can aggregate the results further. For example, we might want to know which learner performed best across all tasks.
library("data.table")
# group by levels of task_id, return columns:
# - learner_id
# - rank of col '-auc_train' (within each level of task_id)
# - rank of col '-auc_test' (within each level of task_id)
ranks = tab[, .(learner_id, rank_train = rank(-auc_train), rank_test = rank(-auc_test)), by = task_id]
print(ranks)
## task_id learner_id rank_train rank_test
## 1: spam classif.ranger 1 1
## 2: spam classif.rpart 2 2
## 3: spam classif.featureless 3 3
## 4: german_credit classif.ranger 1 1
## 5: german_credit classif.rpart 2 2
## 6: german_credit classif.featureless 3 3
## 7: sonar classif.ranger 1 1
## 8: sonar classif.rpart 2 2
## 9: sonar classif.featureless 3 3
# group by levels of learner_id, return columns:
# - mean rank of col 'rank_train' (per level of learner_id)
# - mean rank of col 'rank_test' (per level of learner_id)
ranks = ranks[, .(mrank_train = mean(rank_train), mrank_test = mean(rank_test)), by = learner_id]
# print the final table, ordered by mean rank of AUC test
ranks[order(mrank_test)]
## learner_id mrank_train mrank_test
## 1: classif.ranger 1 1
## 2: classif.rpart 2 2
## 3: classif.featureless 3 3
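The same ranking can be reproduced without rerunning the benchmark. The sketch below is a base R equivalent of the data.table code above, using the test AUC values copied from the printed aggregate table; tab_df and mean_ranks are illustrative names.

```r
# Test AUC values copied from the aggregated benchmark table printed above
tab_df = data.frame(
  task_id = rep(c("spam", "german_credit", "sonar"), each = 3),
  learner_id = rep(c("classif.ranger", "classif.rpart", "classif.featureless"), times = 3),
  auc_test = c(0.9849603, 0.8951998, 0.5000000,
               0.7986295, 0.7034258, 0.5000000,
               0.8857579, 0.7547127, 0.5000000)
)

# Rank learners within each task (rank 1 = highest test AUC)
tab_df$rank_test = ave(-tab_df$auc_test, tab_df$task_id, FUN = rank)

# Average the per-task ranks for each learner and sort by mean rank
mean_ranks = aggregate(rank_test ~ learner_id, data = tab_df, FUN = mean)
mean_ranks = mean_ranks[order(mean_ranks$rank_test), ]
print(mean_ranks)
```

Negating the AUC before ranking makes the largest AUC receive rank 1, which is why the code above ranks -auc_train and -auc_test.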
We can also visualize the benchmark results.
autoplot(bmr) + ggplot2::theme(axis.text.x = ggplot2::element_text(angle = 45, hjust = 1))
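We can also plot ROC (receiver operating characteristic) curves. To do so, we first filter the benchmark result so that it contains only a single task.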
bmr_small = bmr$clone()$filter(task_id = "german_credit")
autoplot(bmr_small, type = "roc")
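Extracting resample results
A BenchmarkResult object is essentially a collection of multiple ResampleResult objects. Since these are stored in a column of the aggregated data.table(), we can easily extract them: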
tab = bmr$aggregate(measures)
rr = tab[task_id == "german_credit" & learner_id == "classif.ranger"]$resample_result[[1]]
print(rr)
## <ResampleResult> of 3 iterations
## * Task: german_credit
## * Learner: classif.ranger
## * Warnings: 0 in 0 iterations
## * Errors: 0 in 0 iterations
Inspect the AUC of this resample result:
measure = msr("classif.auc")
rr$aggregate(measure)
## classif.auc
## 0.7986295
# get the iteration with worst AUC
perf = rr$score(measure)
i = which.min(perf$classif.auc)
# get the corresponding learner and training set
print(rr$learners[[i]])
## <LearnerClassifRanger:classif.ranger>
## * Model: -
## * Parameters: num.threads=1
## * Packages: mlr3, mlr3learners, ranger
## * Predict Type: prob
## * Feature types: logical, integer, numeric, character, factor, ordered
## * Properties: hotstart_backward, importance, multiclass, oob_error,
## twoclass, weights
head(rr$resampling$train_set(i))
## [1] 1 5 7 8 10 13
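Converting and merging
A ResampleResult can be converted to a BenchmarkResult with the function as_benchmark_result(). We can also merge two BenchmarkResults into one larger result object, for example for two related benchmarks that were run on different machines.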
task = tsk("iris")
resampling = rsmp("holdout")$instantiate(task)
rr1 = resample(task, lrn("classif.rpart"), resampling)
## INFO [23:35:36.425] [mlr3] Applying learner 'classif.rpart' on task 'iris' (iter 1/1)
rr2 = resample(task, lrn("classif.featureless"), resampling)
## INFO [23:35:36.450] [mlr3] Applying learner 'classif.featureless' on task 'iris' (iter 1/1)
# Cast both ResampleResults to BenchmarkResults
bmr1 = as_benchmark_result(rr1)
bmr2 = as_benchmark_result(rr2)
# Merge 2nd BMR into the first BMR
bmr1$combine(bmr2)
bmr1
## <BenchmarkResult> of 2 rows with 2 resampling runs
## nr task_id learner_id resampling_id iters warnings errors
## 1 iris classif.rpart holdout 1 0 0
## 2 iris classif.featureless holdout 1 0 0
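A model can be optimized from many angles, for example by selecting better features or by tuning its hyperparameters.
Hyperparameter tuning is supported through the mlr3tuning extension package.
Let's look at a simple example. First, we define the task: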
library("mlr3verse")
task = tsk("pima")
print(task)
## <TaskClassif:pima> (768 x 9): Pima Indian Diabetes
## * Target: diabetes
## * Properties: twoclass
## * Features (8):
## - dbl (8): age, glucose, insulin, mass, pedigree, pregnant, pressure,
## triceps
Then we define the learner and look at its hyperparameters:
learner = lrn("classif.rpart")
learner$param_set
## <ParamSet>
## id class lower upper nlevels default value
## 1: cp ParamDbl 0 1 Inf 0.01
## 2: keep_model ParamLgl NA NA 2 FALSE
## 3: maxcompete ParamInt 0 Inf Inf 4
## 4: maxdepth ParamInt 1 30 30 30
## 5: maxsurrogate ParamInt 0 Inf Inf 5
## 6: minbucket ParamInt 1 Inf Inf <NoDefault[3]>
## 7: minsplit ParamInt 1 Inf Inf 20
## 8: surrogatestyle ParamInt 0 1 2 0
## 9: usesurrogate ParamInt 0 2 3 2
## 10: xval ParamInt 0 Inf Inf 10 0
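Here we choose to tune two hyperparameters: cp (the complexity parameter) and minsplit (the minimum number of observations a node must contain for a split to be attempted).
The search space has to be bounded by a lower and an upper limit for each hyperparameter: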
search_space = ps(
cp = p_dbl(lower = 0.001, upper = 0.1),
minsplit = p_int(lower = 1, upper = 10)
)
search_space
## <ParamSet>
## id class lower upper nlevels default value
## 1: cp ParamDbl 0.001 0.1 Inf <NoDefault[3]>
## 2: minsplit ParamInt 1.000 10.0 10 <NoDefault[3]>
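Next, we need to specify how to evaluate the performance of a trained model. For this we choose a resampling strategy and a performance measure.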
hout = rsmp("holdout")
measure = msr("classif.ce")
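Finally, we have to specify a termination criterion for the tuning. This is a crucial step, because exhaustively evaluating all possible hyperparameter configurations is usually not feasible.
mlr3 lets you specify termination criteria, including complex ones, by selecting one of the available Terminators. Here we use the "evals" terminator, which stops after 20 evaluations.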
library("mlr3tuning")
## Warning: package 'mlr3tuning' was built under R version 4.1.2
## Loading required package: paradox
## Warning: package 'paradox' was built under R version 4.1.2
evals20 = trm("evals", n_evals = 20)
instance = TuningInstanceSingleCrit$new(
task = task,
learner = learner,
resampling = hout,
measure = measure,
search_space = search_space,
terminator = evals20
)
instance
## <TuningInstanceSingleCrit>
## * State: Not optimized
## * Objective: <ObjectiveTuning:classif.rpart_on_pima>
## * Search Space:
## id class lower upper nlevels
## 1: cp ParamDbl 0.001 0.1 Inf
## 2: minsplit ParamInt 1.000 10.0 10
## * Terminator: <TerminatorEvals>
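To start tuning, we still need to choose how the optimization is carried out. In other words, we need to choose an optimization algorithm via the Tuner class.
mlr3tuning currently implements several such algorithms.
We will use a simple grid search with a grid resolution of 5.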
tuner = tnr("grid_search", resolution = 5)
Since we only have numeric parameters, TunerGridSearch will create an equidistant grid between the respective lower and upper bounds.
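To make the size of this grid concrete, here is a base R sketch under stated assumptions. The cp values are computed with seq(); the minsplit values are not recomputed but copied from the tuning log further below, because the exact way the integer parameter is discretized is internal to the tuner. All variable names are illustrative.

```r
# Five equally spaced cp values between the bounds of the search space
cp_grid = seq(0.001, 0.1, length.out = 5)

# minsplit is an integer parameter; these five values are the ones that
# appear in the tuning log, they are not recomputed here
minsplit_grid = c(1L, 3L, 5L, 8L, 10L)

# The full grid contains every combination of the two
grid = expand.grid(cp = cp_grid, minsplit = minsplit_grid)

n_configs = nrow(grid)              # 5 * 5 combinations
n_evaluated = min(n_configs, 20)    # the terminator allows at most 20 evaluations
print(cp_grid)
```

The grid holds 25 configurations, but the terminator stops the search after 20 evaluations, so 5 grid points are never evaluated in this setup.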
Now we can start the tuning:
tuner$optimize(instance)
## INFO [23:35:36.619] [bbotk] Starting to optimize 2 parameter(s) with '<TunerGridSearch>' and '<TerminatorEvals> [n_evals=20, k=0]'
## INFO [23:35:36.622] [bbotk] Evaluating 1 configuration(s)
## INFO [23:35:36.630] [mlr3] Running benchmark with 1 resampling iterations
## INFO [23:35:36.634] [mlr3] Applying learner 'classif.rpart' on task 'pima' (iter 1/1)
## INFO [23:35:36.646] [mlr3] Finished benchmark
## INFO [23:35:36.661] [bbotk] Result of batch 1:
## INFO [23:35:36.662] [bbotk] cp minsplit classif.ce warnings errors runtime_learners
## INFO [23:35:36.662] [bbotk] 0.02575 1 0.2382812 0 0 0.008
## INFO [23:35:36.662] [bbotk] uhash
## INFO [23:35:36.662] [bbotk] bfeccc54-6dbc-433d-a50c-7253c4670497
## INFO [23:35:36.663] [bbotk] Evaluating 1 configuration(s)
## INFO [23:35:36.670] [mlr3] Running benchmark with 1 resampling iterations
## INFO [23:35:36.675] [mlr3] Applying learner 'classif.rpart' on task 'pima' (iter 1/1)
## INFO [23:35:36.689] [mlr3] Finished benchmark
## INFO [23:35:36.707] [bbotk] Result of batch 2:
## INFO [23:35:36.708] [bbotk] cp minsplit classif.ce warnings errors runtime_learners
## INFO [23:35:36.708] [bbotk] 0.1 5 0.2773438 0 0 0.008
## INFO [23:35:36.708] [bbotk] uhash
## INFO [23:35:36.708] [bbotk] 7e462295-d7b9-4fa6-8f36-177c8ddcf206
## INFO [23:35:36.709] [bbotk] Evaluating 1 configuration(s)
## INFO [23:35:36.717] [mlr3] Running benchmark with 1 resampling iterations
## INFO [23:35:36.722] [mlr3] Applying learner 'classif.rpart' on task 'pima' (iter 1/1)
## INFO [23:35:36.738] [mlr3] Finished benchmark
## INFO [23:35:36.754] [bbotk] Result of batch 3:
## INFO [23:35:36.755] [bbotk] cp minsplit classif.ce warnings errors runtime_learners
## INFO [23:35:36.755] [bbotk] 0.0505 8 0.2773438 0 0 0.009
## INFO [23:35:36.755] [bbotk] uhash
## INFO [23:35:36.755] [bbotk] 5c8e97eb-e9fc-4280-963d-63631afe7dd6
## INFO [23:35:36.756] [bbotk] Evaluating 1 configuration(s)
## INFO [23:35:36.763] [mlr3] Running benchmark with 1 resampling iterations
## INFO [23:35:36.767] [mlr3] Applying learner 'classif.rpart' on task 'pima' (iter 1/1)
## INFO [23:35:36.779] [mlr3] Finished benchmark
## INFO [23:35:36.795] [bbotk] Result of batch 4:
## INFO [23:35:36.796] [bbotk] cp minsplit classif.ce warnings errors runtime_learners
## INFO [23:35:36.796] [bbotk] 0.02575 10 0.2382812 0 0 0.007
## INFO [23:35:36.796] [bbotk] uhash
## INFO [23:35:36.796] [bbotk] 54f3a8bc-f1f2-47fa-98ff-d1b8c5466eb9
## INFO [23:35:36.797] [bbotk] Evaluating 1 configuration(s)
## INFO [23:35:36.804] [mlr3] Running benchmark with 1 resampling iterations
## INFO [23:35:36.808] [mlr3] Applying learner 'classif.rpart' on task 'pima' (iter 1/1)
## INFO [23:35:36.822] [mlr3] Finished benchmark
## INFO [23:35:36.838] [bbotk] Result of batch 5:
## INFO [23:35:36.839] [bbotk] cp minsplit classif.ce warnings errors runtime_learners
## INFO [23:35:36.839] [bbotk] 0.07525 8 0.2773438 0 0 0.007
## INFO [23:35:36.839] [bbotk] uhash
## INFO [23:35:36.839] [bbotk] 6021e15b-fe4c-4d0a-9472-8974d76ac6a8
## INFO [23:35:36.840] [bbotk] Evaluating 1 configuration(s)
## INFO [23:35:36.847] [mlr3] Running benchmark with 1 resampling iterations
## INFO [23:35:36.851] [mlr3] Applying learner 'classif.rpart' on task 'pima' (iter 1/1)
## INFO [23:35:36.869] [mlr3] Finished benchmark
## INFO [23:35:36.886] [bbotk] Result of batch 6:
## INFO [23:35:36.887] [bbotk] cp minsplit classif.ce warnings errors runtime_learners
## INFO [23:35:36.887] [bbotk] 0.07525 1 0.2773438 0 0 0.007
## INFO [23:35:36.887] [bbotk] uhash
## INFO [23:35:36.887] [bbotk] f636a86b-7128-4809-9fb0-d3a25161a67f
## INFO [23:35:36.888] [bbotk] Evaluating 1 configuration(s)
## INFO [23:35:36.895] [mlr3] Running benchmark with 1 resampling iterations
## INFO [23:35:36.899] [mlr3] Applying learner 'classif.rpart' on task 'pima' (iter 1/1)
## INFO [23:35:36.911] [mlr3] Finished benchmark
## INFO [23:35:36.926] [bbotk] Result of batch 7:
## INFO [23:35:36.927] [bbotk] cp minsplit classif.ce warnings errors runtime_learners
## INFO [23:35:36.927] [bbotk] 0.02575 8 0.2382812 0 0 0.007
## INFO [23:35:36.927] [bbotk] uhash
## INFO [23:35:36.927] [bbotk] 3be12953-ca21-4d41-93d5-623bfae1b72d
## INFO [23:35:36.928] [bbotk] Evaluating 1 configuration(s)
## INFO [23:35:36.935] [mlr3] Running benchmark with 1 resampling iterations
## INFO [23:35:36.939] [mlr3] Applying learner 'classif.rpart' on task 'pima' (iter 1/1)
## INFO [23:35:36.954] [mlr3] Finished benchmark
## INFO [23:35:36.970] [bbotk] Result of batch 8:
## INFO [23:35:36.971] [bbotk] cp minsplit classif.ce warnings errors runtime_learners
## INFO [23:35:36.971] [bbotk] 0.001 10 0.2890625 0 0 0.009
## INFO [23:35:36.971] [bbotk] uhash
## INFO [23:35:36.971] [bbotk] 9f3a7b8f-b6b8-4680-afdf-0ed2f907c66a
## INFO [23:35:36.971] [bbotk] Evaluating 1 configuration(s)
## INFO [23:35:36.978] [mlr3] Running benchmark with 1 resampling iterations
## INFO [23:35:36.982] [mlr3] Applying learner 'classif.rpart' on task 'pima' (iter 1/1)
## INFO [23:35:36.994] [mlr3] Finished benchmark
## INFO [23:35:37.010] [bbotk] Result of batch 9:
## INFO [23:35:37.011] [bbotk] cp minsplit classif.ce warnings errors runtime_learners
## INFO [23:35:37.011] [bbotk] 0.0505 1 0.2773438 0 0 0.007
## INFO [23:35:37.011] [bbotk] uhash
## INFO [23:35:37.011] [bbotk] c26e62a2-496e-4093-961a-fbf953cbad76
## INFO [23:35:37.012] [bbotk] Evaluating 1 configuration(s)
## INFO [23:35:37.019] [mlr3] Running benchmark with 1 resampling iterations
## INFO [23:35:37.023] [mlr3] Applying learner 'classif.rpart' on task 'pima' (iter 1/1)
## INFO [23:35:37.037] [mlr3] Finished benchmark
## INFO [23:35:37.053] [bbotk] Result of batch 10:
## INFO [23:35:37.055] [bbotk] cp minsplit classif.ce warnings errors runtime_learners
## INFO [23:35:37.055] [bbotk] 0.0505 10 0.2773438 0 0 0.008
## INFO [23:35:37.055] [bbotk] uhash
## INFO [23:35:37.055] [bbotk] 48b683c6-e936-45b1-9a75-9c4e87687d7d
## INFO [23:35:37.055] [bbotk] Evaluating 1 configuration(s)
## INFO [23:35:37.063] [mlr3] Running benchmark with 1 resampling iterations
## INFO [23:35:37.068] [mlr3] Applying learner 'classif.rpart' on task 'pima' (iter 1/1)
## INFO [23:35:37.084] [mlr3] Finished benchmark
## INFO [23:35:37.100] [bbotk] Result of batch 11:
## INFO [23:35:37.101] [bbotk] cp minsplit classif.ce warnings errors runtime_learners
## INFO [23:35:37.101] [bbotk] 0.02575 3 0.2382812 0 0 0.008
## INFO [23:35:37.101] [bbotk] uhash
## INFO [23:35:37.101] [bbotk] 503cd25f-3d6d-4f86-afa3-c4dc127ae332
## INFO [23:35:37.102] [bbotk] Evaluating 1 configuration(s)
## INFO [23:35:37.109] [mlr3] Running benchmark with 1 resampling iterations
## INFO [23:35:37.114] [mlr3] Applying learner 'classif.rpart' on task 'pima' (iter 1/1)
## INFO [23:35:37.128] [mlr3] Finished benchmark
## INFO [23:35:37.144] [bbotk] Result of batch 12:
## INFO [23:35:37.145] [bbotk] cp minsplit classif.ce warnings errors runtime_learners
## INFO [23:35:37.145] [bbotk] 0.001 1 0.3046875 0 0 0.009
## INFO [23:35:37.145] [bbotk] uhash
## INFO [23:35:37.145] [bbotk] 778da658-0d85-46e1-92cf-6da057a5cf45
## INFO [23:35:37.146] [bbotk] Evaluating 1 configuration(s)
## INFO [23:35:37.153] [mlr3] Running benchmark with 1 resampling iterations
## INFO [23:35:37.157] [mlr3] Applying learner 'classif.rpart' on task 'pima' (iter 1/1)
## INFO [23:35:37.170] [mlr3] Finished benchmark
## INFO [23:35:37.188] [bbotk] Result of batch 13:
## INFO [23:35:37.189] [bbotk] cp minsplit classif.ce warnings errors runtime_learners
## INFO [23:35:37.189] [bbotk] 0.001 3 0.296875 0 0 0.008
## INFO [23:35:37.189] [bbotk] uhash
## INFO [23:35:37.189] [bbotk] 5a373c68-3d1e-4396-b258-44cd30143129
## INFO [23:35:37.189] [bbotk] Evaluating 1 configuration(s)
## INFO [23:35:37.197] [mlr3] Running benchmark with 1 resampling iterations
## INFO [23:35:37.201] [mlr3] Applying learner 'classif.rpart' on task 'pima' (iter 1/1)
## INFO [23:35:37.214] [mlr3] Finished benchmark
## INFO [23:35:37.230] [bbotk] Result of batch 14:
## INFO [23:35:37.231] [bbotk] cp minsplit classif.ce warnings errors runtime_learners
## INFO [23:35:37.231] [bbotk] 0.001 5 0.2773438 0 0 0.007
## INFO [23:35:37.231] [bbotk] uhash
## INFO [23:35:37.231] [bbotk] f6d1f72f-73f6-41f5-bbbe-a2693a7ea8c0
## INFO [23:35:37.232] [bbotk] Evaluating 1 configuration(s)
## INFO [23:35:37.240] [mlr3] Running benchmark with 1 resampling iterations
## INFO [23:35:37.244] [mlr3] Applying learner 'classif.rpart' on task 'pima' (iter 1/1)
## INFO [23:35:37.256] [mlr3] Finished benchmark
## INFO [23:35:37.273] [bbotk] Result of batch 15:
## INFO [23:35:37.274] [bbotk] cp minsplit classif.ce warnings errors runtime_learners
## INFO [23:35:37.274] [bbotk] 0.07525 10 0.2773438 0 0 0.007
## INFO [23:35:37.274] [bbotk] uhash
## INFO [23:35:37.274] [bbotk] efa79e53-6047-4ed6-9255-55da95576fa9
## INFO [23:35:37.275] [bbotk] Evaluating 1 configuration(s)
## INFO [23:35:37.282] [mlr3] Running benchmark with 1 resampling iterations
## INFO [23:35:37.291] [mlr3] Applying learner 'classif.rpart' on task 'pima' (iter 1/1)
## INFO [23:35:37.303] [mlr3] Finished benchmark
## INFO [23:35:37.318] [bbotk] Result of batch 16:
## INFO [23:35:37.319] [bbotk] cp minsplit classif.ce warnings errors runtime_learners
## INFO [23:35:37.319] [bbotk] 0.0505 5 0.2773438 0 0 0.007
## INFO [23:35:37.319] [bbotk] uhash
## INFO [23:35:37.319] [bbotk] 980d871c-4801-4b1e-b0af-276442f1b2f4
## INFO [23:35:37.320] [bbotk] Evaluating 1 configuration(s)
## INFO [23:35:37.327] [mlr3] Running benchmark with 1 resampling iterations
## INFO [23:35:37.332] [mlr3] Applying learner 'classif.rpart' on task 'pima' (iter 1/1)
## INFO [23:35:37.345] [mlr3] Finished benchmark
## INFO [23:35:37.361] [bbotk] Result of batch 17:
## INFO [23:35:37.362] [bbotk] cp minsplit classif.ce warnings errors runtime_learners
## INFO [23:35:37.362] [bbotk] 0.0505 3 0.2773438 0 0 0.008
## INFO [23:35:37.362] [bbotk] uhash
## INFO [23:35:37.362] [bbotk] 35ece802-7247-4269-a797-76cee3b9a0af
## INFO [23:35:37.362] [bbotk] Evaluating 1 configuration(s)
## INFO [23:35:37.369] [mlr3] Running benchmark with 1 resampling iterations
## INFO [23:35:37.373] [mlr3] Applying learner 'classif.rpart' on task 'pima' (iter 1/1)
## INFO [23:35:37.385] [mlr3] Finished benchmark
## INFO [23:35:37.401] [bbotk] Result of batch 18:
## INFO [23:35:37.402] [bbotk] cp minsplit classif.ce warnings errors runtime_learners
## INFO [23:35:37.402] [bbotk] 0.1 10 0.2773438 0 0 0.007
## INFO [23:35:37.402] [bbotk] uhash
## INFO [23:35:37.402] [bbotk] a494383c-9d09-40c5-a537-08791be99302
## INFO [23:35:37.403] [bbotk] Evaluating 1 configuration(s)
## INFO [23:35:37.411] [mlr3] Running benchmark with 1 resampling iterations
## INFO [23:35:37.415] [mlr3] Applying learner 'classif.rpart' on task 'pima' (iter 1/1)
## INFO [23:35:37.428] [mlr3] Finished benchmark
## INFO [23:35:37.445] [bbotk] Result of batch 19:
## INFO [23:35:37.446] [bbotk] cp minsplit classif.ce warnings errors runtime_learners
## INFO [23:35:37.446] [bbotk] 0.1 1 0.2773438 0 0 0.009
## INFO [23:35:37.446] [bbotk] uhash
## INFO [23:35:37.446] [bbotk] 2824c2f3-e732-4240-9e9c-109dc3b52b3d
## INFO [23:35:37.447] [bbotk] Evaluating 1 configuration(s)
## INFO [23:35:37.454] [mlr3] Running benchmark with 1 resampling iterations
## INFO [23:35:37.458] [mlr3] Applying learner 'classif.rpart' on task 'pima' (iter 1/1)
## INFO [23:35:37.471] [mlr3] Finished benchmark
## INFO [23:35:37.488] [bbotk] Result of batch 20:
## INFO [23:35:37.489] [bbotk] cp minsplit classif.ce warnings errors runtime_learners
## INFO [23:35:37.489] [bbotk] 0.001 8 0.2773438 0 0 0.008
## INFO [23:35:37.489] [bbotk] uhash
## INFO [23:35:37.489] [bbotk] b70e2328-6018-4f69-95fb-6c5c1efa9da5
## INFO [23:35:37.497] [bbotk] Finished optimizing after 20 evaluation(s)
## INFO [23:35:37.498] [bbotk] Result:
## INFO [23:35:37.498] [bbotk] cp minsplit learner_param_vals x_domain classif.ce
## INFO [23:35:37.498] [bbotk] 0.02575 1 <list[3]> <list[2]> 0.2382812
## cp minsplit learner_param_vals x_domain classif.ce
## 1: 0.02575 1 <list[3]> <list[2]> 0.2382812
We can inspect the best hyperparameter values and the corresponding performance:
instance$result_learner_param_vals
## $xval
## [1] 0
##
## $cp
## [1] 0.02575
##
## $minsplit
## [1] 1
instance$result_y
## classif.ce
## 0.2382812
All evaluations that were performed are stored in the archive:
as.data.table(instance$archive)
## cp minsplit classif.ce x_domain_cp x_domain_minsplit runtime_learners
## 1: 0.02575 1 0.2382812 0.02575 1 0.008
## 2: 0.10000 5 0.2773438 0.10000 5 0.008
## 3: 0.05050 8 0.2773438 0.05050 8 0.009
## 4: 0.02575 10 0.2382812 0.02575 10 0.007
## 5: 0.07525 8 0.2773438 0.07525 8 0.007
## 6: 0.07525 1 0.2773438 0.07525 1 0.007
## 7: 0.02575 8 0.2382812 0.02575 8 0.007
## 8: 0.00100 10 0.2890625 0.00100 10 0.009
## 9: 0.05050 1 0.2773438 0.05050 1 0.007
## 10: 0.05050 10 0.2773438 0.05050 10 0.008
## 11: 0.02575 3 0.2382812 0.02575 3 0.008
## 12: 0.00100 1 0.3046875 0.00100 1 0.009
## 13: 0.00100 3 0.2968750 0.00100 3 0.008
## 14: 0.00100 5 0.2773438 0.00100 5 0.007
## 15: 0.07525 10 0.2773438 0.07525 10 0.007
## 16: 0.05050 5 0.2773438 0.05050 5 0.007
## 17: 0.05050 3 0.2773438 0.05050 3 0.008
## 18: 0.10000 10 0.2773438 0.10000 10 0.007
## 19: 0.10000 1 0.2773438 0.10000 1 0.009
## 20: 0.00100 8 0.2773438 0.00100 8 0.008
## timestamp batch_nr warnings errors resample_result
## 1: 2022-04-27 23:35:36 1 0 0 <ResampleResult[22]>
## 2: 2022-04-27 23:35:36 2 0 0 <ResampleResult[22]>
## 3: 2022-04-27 23:35:36 3 0 0 <ResampleResult[22]>
## 4: 2022-04-27 23:35:36 4 0 0 <ResampleResult[22]>
## 5: 2022-04-27 23:35:36 5 0 0 <ResampleResult[22]>
## 6: 2022-04-27 23:35:36 6 0 0 <ResampleResult[22]>
## 7: 2022-04-27 23:35:36 7 0 0 <ResampleResult[22]>
## 8: 2022-04-27 23:35:36 8 0 0 <ResampleResult[22]>
## 9: 2022-04-27 23:35:37 9 0 0 <ResampleResult[22]>
## 10: 2022-04-27 23:35:37 10 0 0 <ResampleResult[22]>
## 11: 2022-04-27 23:35:37 11 0 0 <ResampleResult[22]>
## 12: 2022-04-27 23:35:37 12 0 0 <ResampleResult[22]>
## 13: 2022-04-27 23:35:37 13 0 0 <ResampleResult[22]>
## 14: 2022-04-27 23:35:37 14 0 0 <ResampleResult[22]>
## 15: 2022-04-27 23:35:37 15 0 0 <ResampleResult[22]>
## 16: 2022-04-27 23:35:37 16 0 0 <ResampleResult[22]>
## 17: 2022-04-27 23:35:37 17 0 0 <ResampleResult[22]>
## 18: 2022-04-27 23:35:37 18 0 0 <ResampleResult[22]>
## 19: 2022-04-27 23:35:37 19 0 0 <ResampleResult[22]>
## 20: 2022-04-27 23:35:37 20 0 0 <ResampleResult[22]>
The associated resampling iterations can be accessed through the BenchmarkResult stored in the tuning instance:
instance$archive$benchmark_result
## <BenchmarkResult> of 20 rows with 20 resampling runs
## nr task_id learner_id resampling_id iters warnings errors
## 1 pima classif.rpart holdout 1 0 0
## 2 pima classif.rpart holdout 1 0 0
## 3 pima classif.rpart holdout 1 0 0
## 4 pima classif.rpart holdout 1 0 0
## 5 pima classif.rpart holdout 1 0 0
## 6 pima classif.rpart holdout 1 0 0
## 7 pima classif.rpart holdout 1 0 0
## 8 pima classif.rpart holdout 1 0 0
## 9 pima classif.rpart holdout 1 0 0
## 10 pima classif.rpart holdout 1 0 0
## 11 pima classif.rpart holdout 1 0 0
## 12 pima classif.rpart holdout 1 0 0
## 13 pima classif.rpart holdout 1 0 0
## 14 pima classif.rpart holdout 1 0 0
## 15 pima classif.rpart holdout 1 0 0
## 16 pima classif.rpart holdout 1 0 0
## 17 pima classif.rpart holdout 1 0 0
## 18 pima classif.rpart holdout 1 0 0
## 19 pima classif.rpart holdout 1 0 0
## 20 pima classif.rpart holdout 1 0 0
We can score the predictions of every evaluated configuration with another measure, here the accuracy:
instance$archive$benchmark_result$score(msr("classif.acc"))
## uhash nr task task_id
## 1: bfeccc54-6dbc-433d-a50c-7253c4670497 1 <TaskClassif[50]> pima
## 2: 7e462295-d7b9-4fa6-8f36-177c8ddcf206 2 <TaskClassif[50]> pima
## 3: 5c8e97eb-e9fc-4280-963d-63631afe7dd6 3 <TaskClassif[50]> pima
## 4: 54f3a8bc-f1f2-47fa-98ff-d1b8c5466eb9 4 <TaskClassif[50]> pima
## 5: 6021e15b-fe4c-4d0a-9472-8974d76ac6a8 5 <TaskClassif[50]> pima
## 6: f636a86b-7128-4809-9fb0-d3a25161a67f 6 <TaskClassif[50]> pima
## 7: 3be12953-ca21-4d41-93d5-623bfae1b72d 7 <TaskClassif[50]> pima
## 8: 9f3a7b8f-b6b8-4680-afdf-0ed2f907c66a 8 <TaskClassif[50]> pima
## 9: c26e62a2-496e-4093-961a-fbf953cbad76 9 <TaskClassif[50]> pima
## 10: 48b683c6-e936-45b1-9a75-9c4e87687d7d 10 <TaskClassif[50]> pima
## 11: 503cd25f-3d6d-4f86-afa3-c4dc127ae332 11 <TaskClassif[50]> pima
## 12: 778da658-0d85-46e1-92cf-6da057a5cf45 12 <TaskClassif[50]> pima
## 13: 5a373c68-3d1e-4396-b258-44cd30143129 13 <TaskClassif[50]> pima
## 14: f6d1f72f-73f6-41f5-bbbe-a2693a7ea8c0 14 <TaskClassif[50]> pima
## 15: efa79e53-6047-4ed6-9255-55da95576fa9 15 <TaskClassif[50]> pima
## 16: 980d871c-4801-4b1e-b0af-276442f1b2f4 16 <TaskClassif[50]> pima
## 17: 35ece802-7247-4269-a797-76cee3b9a0af 17 <TaskClassif[50]> pima
## 18: a494383c-9d09-40c5-a537-08791be99302 18 <TaskClassif[50]> pima
## 19: 2824c2f3-e732-4240-9e9c-109dc3b52b3d 19 <TaskClassif[50]> pima
## 20: b70e2328-6018-4f69-95fb-6c5c1efa9da5 20 <TaskClassif[50]> pima
## learner learner_id resampling
## 1: <LearnerClassifRpart[38]> classif.rpart <ResamplingHoldout[20]>
## 2: <LearnerClassifRpart[38]> classif.rpart <ResamplingHoldout[20]>
## 3: <LearnerClassifRpart[38]> classif.rpart <ResamplingHoldout[20]>
## 4: <LearnerClassifRpart[38]> classif.rpart <ResamplingHoldout[20]>
## 5: <LearnerClassifRpart[38]> classif.rpart <ResamplingHoldout[20]>
## 6: <LearnerClassifRpart[38]> classif.rpart <ResamplingHoldout[20]>
## 7: <LearnerClassifRpart[38]> classif.rpart <ResamplingHoldout[20]>
## 8: <LearnerClassifRpart[38]> classif.rpart <ResamplingHoldout[20]>
## 9: <LearnerClassifRpart[38]> classif.rpart <ResamplingHoldout[20]>
## 10: <LearnerClassifRpart[38]> classif.rpart <ResamplingHoldout[20]>
## 11: <LearnerClassifRpart[38]> classif.rpart <ResamplingHoldout[20]>
## 12: <LearnerClassifRpart[38]> classif.rpart <ResamplingHoldout[20]>
## 13: <LearnerClassifRpart[38]> classif.rpart <ResamplingHoldout[20]>
## 14: <LearnerClassifRpart[38]> classif.rpart <ResamplingHoldout[20]>
## 15: <LearnerClassifRpart[38]> classif.rpart <ResamplingHoldout[20]>
## 16: <LearnerClassifRpart[38]> classif.rpart <ResamplingHoldout[20]>
## 17: <LearnerClassifRpart[38]> classif.rpart <ResamplingHoldout[20]>
## 18: <LearnerClassifRpart[38]> classif.rpart <ResamplingHoldout[20]>
## 19: <LearnerClassifRpart[38]> classif.rpart <ResamplingHoldout[20]>
## 20: <LearnerClassifRpart[38]> classif.rpart <ResamplingHoldout[20]>
## resampling_id iteration prediction classif.acc
## 1: holdout 1 <PredictionClassif[20]> 0.7617188
## 2: holdout 1 <PredictionClassif[20]> 0.7226562
## 3: holdout 1 <PredictionClassif[20]> 0.7226562
## 4: holdout 1 <PredictionClassif[20]> 0.7617188
## 5: holdout 1 <PredictionClassif[20]> 0.7226562
## 6: holdout 1 <PredictionClassif[20]> 0.7226562
## 7: holdout 1 <PredictionClassif[20]> 0.7617188
## 8: holdout 1 <PredictionClassif[20]> 0.7109375
## 9: holdout 1 <PredictionClassif[20]> 0.7226562
## 10: holdout 1 <PredictionClassif[20]> 0.7226562
## 11: holdout 1 <PredictionClassif[20]> 0.7617188
## 12: holdout 1 <PredictionClassif[20]> 0.6953125
## 13: holdout 1 <PredictionClassif[20]> 0.7031250
## 14: holdout 1 <PredictionClassif[20]> 0.7226562
## 15: holdout 1 <PredictionClassif[20]> 0.7226562
## 16: holdout 1 <PredictionClassif[20]> 0.7226562
## 17: holdout 1 <PredictionClassif[20]> 0.7226562
## 18: holdout 1 <PredictionClassif[20]> 0.7226562
## 19: holdout 1 <PredictionClassif[20]> 0.7226562
## 20: holdout 1 <PredictionClassif[20]> 0.7226562
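The accuracy values are the complement of the classification errors reported during tuning. As a rough check, assume that the holdout resampling used a split of two thirds for training (an assumption, since the ratio is not printed above). The test set then has 256 of the 768 rows, and the best error of 0.2382812 corresponds to 61 misclassified observations; this count is inferred from the printed error, not taken from the output.

```r
n_total = 768                               # rows in the pima task
n_test = n_total - round(n_total * 2 / 3)   # assumed holdout ratio of 2/3 for training
n_wrong = 61                                # inferred: 0.2382812 * 256 is 61 up to rounding

ce = n_wrong / n_test    # classification error
acc = 1 - ce             # accuracy is the complement of the error
print(c(ce = ce, acc = acc))
```

Under these assumptions the two values agree with the printed classif.ce of 0.2382812 and classif.acc of 0.7617188 for the best configuration.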
Now we can take the optimized hyperparameters, set them on the previously created Learner, and train it on the full dataset.
learner$param_set$values = instance$result_learner_param_vals
learner$train(task)
The trained model can now be used to make predictions on new, external data.
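A minimal sketch of such a prediction, assuming the mlr3 package is installed. No external data is available here, so the feature columns of the first five rows of the task stand in for it; new_data is an illustrative name, and the hyperparameter values are copied from the tuning result shown earlier.

```r
library(mlr3)

task = tsk("pima")

# Hyperparameter values copied from the tuning result shown earlier
learner = lrn("classif.rpart", cp = 0.02575, minsplit = 1)
learner$train(task)

# Stand-in for external data: feature columns only, without the target
new_data = task$data(rows = 1:5, cols = task$feature_names)

# Predict on a plain data.table / data.frame of features
prediction = learner$predict_newdata(new_data)
print(prediction$response)
```

Because new_data carries no target column, the resulting prediction has no ground truth to score against; it only contains the predicted classes.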