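mlr was first released on CRAN in 2013, and its core design and architecture date back even further. For various reasons, the development team reimplemented mlr's functionality, and the package has since been superseded by mlr3.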

Like tidymodels, mlr3 is a suite of R packages for modeling. The packages are listed at https://github.com/mlr-org/mlr3/wiki/Extension-Packages

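R6 classes

The mlr3 packages are written with R6 classes, so we first learn about R6. R has several object-oriented systems: S3, S4, RC (Reference Classes), and R6.

R6 is not part of base R. To use it, install and load the R6 package with install.packages("R6"). R6 creates classes and their methods with the R6Class() function.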

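The two most important arguments of R6Class() are:

  1. The first argument is the class name.
  2. The second argument is a list of methods and fields (the public interface).

Let's look at an example: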

library(R6)
Accumulator <- R6Class("Accumulator", list(
  sum = 0,
  add = function(x = 1) {
    self$sum <- self$sum + x 
    invisible(self)
  })
)

Accumulator
## <Accumulator> object generator
##   Public:
##     sum: 0
##     add: function (x = 1) 
##     clone: function (deep = FALSE) 
##   Parent env: <environment: R_GlobalEnv>
##   Locked objects: TRUE
##   Locked class: FALSE
##   Portable: TRUE

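Use the class's $new() method to create a new object. In R6, methods belong to objects, and both methods and fields are accessed with the dollar sign ($):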

x <- Accumulator$new() 

x$add(100)
x$sum
## [1] 100

In this class, the fields and methods are public and can be accessed freely.
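Because the fields are public, they can also be assigned from outside the object. The short sketch below redefines the Accumulator class from above so that it is self-contained. It shows direct field access together with method chaining, which works because add() returns invisible(self):

```r
library(R6)

Accumulator <- R6Class("Accumulator", list(
  sum = 0,
  add = function(x = 1) {
    self$sum <- self$sum + x
    invisible(self)  # returning self makes chaining possible
  })
)

acc <- Accumulator$new()
acc$add(1)$add(2)         # chained calls: sum is now 3
acc$sum <- acc$sum * 10   # public fields can be read and written directly
acc$sum                   # 30
```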

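Most classes define two important methods: $initialize() and $print(). $new() calls $initialize(), which makes it a good place to validate the inputs. $print() overrides the default printing behavior.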

Person <- R6Class("Person", list(
  name = NULL,
  age = NA,
  initialize = function(name, age = NA) {
    stopifnot(is.character(name), length(name) == 1)
    stopifnot(is.numeric(age), length(age) == 1)
    
    self$name <- name
    self$age <- age
  }
))

# Uncommenting the next line raises an error, because age must be numeric:
# hadley <- Person$new("Hadley", age = "thirty-eight")

hadley <- Person$new("Hadley", age = 38)


Person <- R6Class("Person", list(
  name = NULL,
  age = NA,
  initialize = function(name, age = NA) {
    self$name <- name
    self$age <- age
  },
  print = function(...) {
    cat("Person: \n")
    cat("  Name: ", self$name, "\n", sep = "")
    cat("  Age:  ", self$age, "\n", sep = "")
    invisible(self)
  }
))

hadley2 <- Person$new("Hadley")
hadley2
## Person: 
##   Name: Hadley
##   Age:  NA

You can also add fields and methods to an existing class after it has been created, using $set():

Accumulator <- R6Class("Accumulator")
Accumulator$set("public", "sum", 0)
Accumulator$set("public", "add", function(x = 1) {
  self$sum <- self$sum + x 
  invisible(self)
})

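Note that new methods and fields are only available to objects created afterwards. They are not retroactively added to existing objects.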
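The following minimal sketch demonstrates this with a hypothetical Counter class. The method added with $set() exists on an object created after the call, but not on one created before it:

```r
library(R6)

Counter <- R6Class("Counter", list(n = 0))
old_obj <- Counter$new()   # created before the method is added

# add a method to the class after old_obj already exists
Counter$set("public", "increment", function() {
  self$n <- self$n + 1
  invisible(self)
})

new_obj <- Counter$new()   # created after the method was added
new_obj$increment()
new_obj$n                  # 1
is.null(old_obj$increment) # TRUE: old_obj never received the method
```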

To inherit behavior from an existing class, pass the class object to the inherit argument:

AccumulatorChatty <- R6Class("AccumulatorChatty", 
  inherit = Accumulator,
  public = list(
    add = function(x = 1) {
      cat("Adding ", x, "\n", sep = "")
      super$add(x = x)
    }
  )
)

x2 <- AccumulatorChatty$new()
x2$add(10)$add(1)$sum
## Adding 10
## Adding 1
## [1] 11

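R6Class() has two more arguments that work like public:

  1. private lets you create fields and methods that are available only inside the class, not outside it.
  2. active lets you use accessor functions to define dynamic, or active, fields.
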
Person <- R6Class("Person", 
  public = list(
    initialize = function(name, age = NA) {
      private$name <- name
      private$age <- age
    },
    print = function(...) {
      cat("Person: \n")
      cat("  Name: ", private$name, "\n", sep = "")
      cat("  Age:  ", private$age, "\n", sep = "")
    }
  ),
  private = list(
    age = NA,
    name = NULL
  )
)

hadley3 <- Person$new("Hadley")
hadley3
## Person: 
##   Name: Hadley
##   Age:  NA
hadley3$name
## NULL
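The text above does not include an active field example, so here is a minimal sketch using a hypothetical Rectangle class. The area field is recomputed from the public fields every time it is read. An active binding function receives any assigned value as its value argument, and this one rejects assignments:

```r
library(R6)

Rectangle <- R6Class("Rectangle",
  public = list(
    width = NULL,
    height = NULL,
    initialize = function(width, height) {
      self$width <- width
      self$height <- height
    }
  ),
  active = list(
    # recomputed on every access; assigning to it raises an error
    area = function(value) {
      if (!missing(value)) stop("area is read-only", call. = FALSE)
      self$width * self$height
    }
  )
)

r <- Rectangle$new(2, 3)
r$area          # 6
r$width <- 10
r$area          # 30
```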

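Reference semantics and copying

R6 objects have reference semantics. Assigning an object to another name does not copy it, so y1 and y2 below refer to the same object.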

y1 <- Accumulator$new() 
y2 <- y1

y1$add(10)
c(y1 = y1$sum, y2 = y2$sum)
## y1 y2 
## 10 10

If you want a copy, you have to call $clone() on the object explicitly:

y1 <- Accumulator$new() 
y2 <- y1$clone()

y1$add(10)
c(y1 = y1$sum, y2 = y2$sum)
## y1 y2 
## 10  0
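By default, $clone() copies the object itself, but any fields that hold other R6 objects are still shared with the original. The sketch below uses two hypothetical classes, Inner and Outer, to show the difference that $clone(deep = TRUE) makes. The Inner object is created in initialize() rather than given as a default field value, because a default R6 object in the field list would be shared by all instances.

```r
library(R6)

Inner <- R6Class("Inner", list(value = 0))
Outer <- R6Class("Outer", list(
  inner = NULL,
  initialize = function() self$inner <- Inner$new()
))

a <- Outer$new()
shallow <- a$clone()             # still shares a's Inner object
deep    <- a$clone(deep = TRUE)  # gets its own copy of the Inner object

a$inner$value <- 42
shallow$inner$value  # 42
deep$inner$value     # 0
```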

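mlr3-related packages

  1. bbotk: black-box optimization toolkit.
  2. mlr3benchmark: post-hoc analysis of benchmark results (statistical tests, plots), i.e. comparing the results of multiple learners on multiple tasks.
  3. mlr3cluster: extension for cluster analysis.
  4. mlr3data: datasets.
  5. mlr3db: data backends for databases.
  6. mlr3filters: filter methods for feature selection.
  7. mlr3fselect: wrapper feature selection, e.g. sequential forward/backward search, exhaustive search, or genetic algorithms.
  8. mlr3hyperband: hyperparameter tuning with Hyperband.
  9. mlr3learners: additional learners for regression and classification.
  10. mlr3measures: performance measures for classification and regression.
  11. mlr3oml: connector to OpenML (https://openml.org/), an open machine learning platform.
  12. mlr3pipelines: pipelines and DAGs for preprocessing and building complex workflows.
  13. mlr3proba: extension for probabilistic supervised learning (including survival analysis).
  14. mlr3spatiotempcv: resampling methods for spatiotemporal tasks.
  15. mlr3tuning: hyperparameter tuning, e.g. via random search or grid search.
  16. mlr3verse: meta-package for installing and loading the released core packages.
  17. mlr3viz: visualization via ggplot2's autoplot() function.
  18. mlrintermbo: connects mlrMBO, a package for Bayesian optimization, to mlr3, making it usable for efficiently tuning machine learning algorithms.
  19. mlr3batchmark: connector between mlr3 and batch computing tools (batchtools).
  20. mlr3extralearners: access to and management of additional learners from other packages.
  21. mlr3forecasting: extension for time series forecasting.
  22. mlr3keras: extension for deep learning via keras.
  23. mlr3mbo: hyperparameter tuning via model-based optimization (also known as Bayesian optimization).
  24. mlr3ordinal: extension for ordinal regression.
  25. miesmuschel: mixed-integer evolution strategy optimization.
  26. mlr3tuningspaces: predefined search spaces for tuning.
  27. mlr3fairness: extension for machine learning fairness; a Google Summer of Code project.
  28. mlr3spatial: extension for spatial data analysis.
  29. mlr3fda: extension for functional data analysis.
  30. mlr3multilabel: extension for multilabel classification.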

We first follow the tutorial (the mlr3 book, https://mlr3book.mlr-org.com/) and then introduce the packages one at a time.

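Introduction

The first step is to install mlr3 with install.packages("mlr3").

We train a first model on the iris dataset. For simplicity, the code below uses the featureless baseline learner classif.featureless. This learner ignores all features and predicts the most frequent class of the training data.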

library("mlr3")
## Warning: package 'mlr3' was built under R version 4.1.2
task = tsk("iris")
learner = lrn("classif.featureless")

# train a model of this learner for a subset of the task
learner$train(task, row_ids = 1:120)
# inspect the fitted model: the featureless learner only stores the class frequencies of the training rows
learner$model
## $tab
## 
##     setosa versicolor  virginica 
##         50         50         20 
## 
## $features
## [1] "Petal.Length" "Petal.Width"  "Sepal.Length" "Sepal.Width" 
## 
## attr(,"class")
## [1] "classif.featureless_model"

To see which learners are available, run the following code:

lrn()
## <DictionaryLearner> with 6 stored values
## Keys: classif.debug, classif.featureless, classif.rpart, regr.debug,
##   regr.featureless, regr.rpart

More learners are available in the extension package mlr3learners, including the following (key: model (package)):

Classification learners

  1. classif.cv_glmnet: Penalized Logistic Regression (glmnet)
  2. classif.glmnet: Penalized Logistic Regression (glmnet)
  3. classif.kknn: k-Nearest Neighbors (kknn)
  4. classif.lda: LDA (MASS)
  5. classif.log_reg: Logistic Regression (stats)
  6. classif.multinom: Multinomial Log-Linear Model (nnet)
  7. classif.naive_bayes: Naive Bayes (e1071)
  8. classif.nnet: Single Layer Neural Network (nnet)
  9. classif.qda: QDA (MASS)
  10. classif.ranger: Random Forest (ranger)
  11. classif.svm: SVM (e1071)
  12. classif.xgboost: Gradient Boosting (xgboost)

Regression learners

  1. regr.cv_glmnet: Penalized Linear Regression (glmnet)
  2. regr.glmnet: Penalized Linear Regression (glmnet)
  3. regr.kknn: k-Nearest Neighbors (kknn)
  4. regr.km: Kriging (DiceKriging)
  5. regr.lm: Linear Regression (stats)
  6. regr.ranger: Random Forest (ranger)
  7. regr.svm: SVM (e1071)
  8. regr.xgboost: Gradient Boosting (xgboost)

Survival learners

  1. surv.cv_glmnet: Penalized Cox Regression (glmnet)
  2. surv.glmnet: Penalized Cox Regression (glmnet)
  3. surv.ranger: Random Forest (ranger)
  4. surv.xgboost: Gradient Boosting (xgboost)
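The sketch below shows how these learners become available. It assumes mlr3learners is installed. Loading the package registers its learners in the mlr_learners dictionary. Training a learner such as classif.ranger also requires the underlying package (ranger) to be installed.

```r
library(mlr3)
library(mlr3learners)  # registers the additional learners in mlr_learners

# keys of all classification learners that are now available
mlr_learners$keys("^classif")

# constructing the learner works once mlr3learners is loaded;
# training it additionally assumes the ranger package is installed
learner = lrn("classif.ranger", num.trees = 100)
```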

We use the trained model to make predictions:

predictions = learner$predict(task, row_ids = 121:150)
predictions
## <PredictionClassif> for 30 observations:
##     row_ids     truth   response
##         121 virginica versicolor
##         122 virginica versicolor
##         123 virginica versicolor
## ---                             
##         148 virginica versicolor
##         149 virginica versicolor
##         150 virginica versicolor

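Evaluate the model by computing its accuracy on the held-out rows 121 to 150. The accuracy turns out to be 0. The featureless learner always predicts a single majority class of the training rows (here versicolor), but iris is sorted by species, so rows 121 to 150 are all virginica.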

# accuracy of our model on the test set of the final 30 rows
predictions$score(msr("classif.acc"))
## classif.acc 
##           0
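Splitting by row order is what left the test set with only one species. The following sketch uses a random split instead and pairs it with an actual decision tree (classif.rpart). The seed and the 120/30 split sizes are arbitrary choices:

```r
library(mlr3)

set.seed(1)  # make the random split reproducible
task = tsk("iris")
train_ids = sample(task$nrow, 120)
test_ids = setdiff(seq_len(task$nrow), train_ids)

learner = lrn("classif.rpart")  # a classification tree this time
learner$train(task, row_ids = train_ids)
acc = learner$predict(task, row_ids = test_ids)$score(msr("classif.acc"))
acc
```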

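Basic example

In this part we look at modeling with mlr3 in a bit more detail.

The Task class

A Task object usually contains the data and defines the machine learning problem, e.g. which column is the target.

We create a regression task from the mtcars dataset.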

library("mlr3")

task_mtcars = as_task_regr(mtcars, target = "mpg", id = "cars")
print(task_mtcars)
## <TaskRegr:cars> (32 x 11)
## * Target: mpg
## * Properties: -
## * Features (10):
##   - dbl (10): am, carb, cyl, disp, drat, gear, hp, qsec, vs, wt
class(task_mtcars)
## [1] "TaskRegr"       "TaskSupervised" "Task"           "R6"

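Here as_task_regr() takes three arguments: (1) the data, (2) target, the name of the target variable, and (3) id, an optional identifier for the task.

An equivalent way to write this is: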

task_mtcars <- TaskRegr$new(backend = mtcars,target = "mpg",id = "cars")

task_mtcars
## <TaskRegr:cars> (32 x 11)
## * Target: mpg
## * Properties: -
## * Features (10):
##   - dbl (10): am, carb, cyl, disp, drat, gear, hp, qsec, vs, wt
task_mtcars$print()
## <TaskRegr:cars> (32 x 11)
## * Target: mpg
## * Properties: -
## * Features (10):
##   - dbl (10): am, carb, cyl, disp, drat, gear, hp, qsec, vs, wt

We can also visualize the data with mlr3viz.

library("mlr3viz")
## Warning: package 'mlr3viz' was built under R version 4.1.2
# note: $select() modifies task_mtcars in place; afterwards only the features am and carb remain
autoplot(task_mtcars$select(c("am","carb")), type = "pairs")
## Registered S3 method overwritten by 'GGally':
##   method from   
##   +.gg   ggplot2

mlr3 includes a number of predefined tasks, which you can list with mlr_tasks:

mlr_tasks
## <DictionaryTask> with 11 stored values
## Keys: boston_housing, breast_cancer, german_credit, iris, mtcars,
##   penguins, pima, sonar, spam, wine, zoo

The dictionary can be converted to a table:

as.data.table(mlr_tasks)
##                key                   label task_type nrow ncol properties lgl
##  1: boston_housing   Boston Housing Prices      regr  506   19              0
##  2:  breast_cancer Wisconsin Breast Cancer   classif  683   10   twoclass   0
##  3:  german_credit           German Credit   classif 1000   21   twoclass   0
##  4:           iris            Iris Flowers   classif  150    5 multiclass   0
##  5:         mtcars            Motor Trends      regr   32   11              0
##  6:       penguins         Palmer Penguins   classif  344    8 multiclass   0
##  7:           pima    Pima Indian Diabetes   classif  768    9   twoclass   0
##  8:          sonar  Sonar: Mines vs. Rocks   classif  208   61   twoclass   0
##  9:           spam       HP Spam Detection   classif 4601   58   twoclass   0
## 10:           wine            Wine Regions   classif  178   14 multiclass   0
## 11:            zoo             Zoo Animals   classif  101   17 multiclass  15
##     int dbl chr fct ord pxc
##  1:   3  13   0   2   0   0
##  2:   0   0   0   0   9   0
##  3:   3   0   0  14   3   0
##  4:   0   4   0   0   0   0
##  5:   0  10   0   0   0   0
##  6:   3   2   0   2   0   0
##  7:   0   8   0   0   0   0
##  8:   0  60   0   0   0   0
##  9:   0  57   0   0   0   0
## 10:   2  11   0   0   0   0
## 11:   1   0   0   0   0   0

To retrieve a task, use the $get() method:

mlr_tasks$get("sonar")
## <TaskClassif:sonar> (208 x 61): Sonar: Mines vs. Rocks
## * Target: Class
## * Properties: twoclass
## * Features (60):
##   - dbl (60): V1, V10, V11, V12, V13, V14, V15, V16, V17, V18, V19, V2,
##     V20, V21, V22, V23, V24, V25, V26, V27, V28, V29, V3, V30, V31,
##     V32, V33, V34, V35, V36, V37, V38, V39, V4, V40, V41, V42, V43,
##     V44, V45, V46, V47, V48, V49, V5, V50, V51, V52, V53, V54, V55,
##     V56, V57, V58, V59, V6, V60, V7, V8, V9

You can think of these as built-in example datasets. The tsk() shortcut also retrieves a task:

task_penguins = tsk("penguins")
print(task_penguins)
## <TaskClassif:penguins> (344 x 8): Palmer Penguins
## * Target: species
## * Properties: multiclass
## * Features (7):
##   - int (3): body_mass, flipper_length, year
##   - dbl (2): bill_depth, bill_length
##   - fct (2): island, sex

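We can query and manipulate a task. Because $select() was called on task_mtcars earlier, the task now contains only the target mpg and the features am and carb: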

task_mtcars$nrow
## [1] 32
task_mtcars$ncol
## [1] 3
task_mtcars$data()
##      mpg am carb
##  1: 21.0  1    4
##  2: 21.0  1    4
##  3: 22.8  1    1
##  4: 21.4  0    1
##  5: 18.7  0    2
##  6: 18.1  0    1
##  7: 14.3  0    4
##  8: 24.4  0    2
##  9: 22.8  0    2
## 10: 19.2  0    4
## 11: 17.8  0    4
## 12: 16.4  0    3
## 13: 17.3  0    3
## 14: 15.2  0    3
## 15: 10.4  0    4
## 16: 10.4  0    4
## 17: 14.7  0    4
## 18: 32.4  1    1
## 19: 30.4  1    2
## 20: 33.9  1    1
## 21: 21.5  0    1
## 22: 15.5  0    2
## 23: 15.2  0    2
## 24: 13.3  0    4
## 25: 19.2  0    2
## 26: 27.3  1    1
## 27: 26.0  1    2
## 28: 30.4  1    2
## 29: 15.8  1    4
## 30: 19.7  1    6
## 31: 15.0  1    8
## 32: 21.4  1    2
##      mpg am carb
task_mtcars$feature_names
## [1] "am"   "carb"
task_mtcars$target_names
## [1] "mpg"
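The sketch below shows a few more accessors on a fresh copy of the mtcars task. Each of these is a field or method of mlr3's Task class:

```r
library(mlr3)

task = tsk("mtcars")
task$data(rows = 1:3, cols = c("mpg", "cyl"))  # selected rows and columns only
head(task$row_ids)                             # row identifiers
head(task$truth())                             # values of the target mpg
task$feature_types                             # feature names and their types
```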

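For binary classification, the positive argument defines which class is the positive one. You can set it during construction or change it later through the $positive field.

# during construction (the Sonar data comes from the mlbench package)
data("Sonar", package = "mlbench")
task = as_task_classif(Sonar, target = "Class", positive = "R")

# switch the positive class
task$positive = "M"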

Rows and columns can be assigned different roles. These roles affect how the task behaves in different operations.

print(task_mtcars$col_roles)
## $feature
## [1] "am"   "carb"
## 
## $target
## [1] "mpg"
## 
## $name
## character(0)
## 
## $order
## character(0)
## 
## $stratum
## character(0)
## 
## $group
## character(0)
## 
## $weight
## character(0)

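A column can have no role (in which case it is ignored) or several roles. Roles can be set when the task is created, or later with the $set_col_roles() method:

# assumes the task contains a column "rn", e.g. row names kept via as.data.table(mtcars, keep.rownames = TRUE)
task_mtcars$set_col_roles("rn", roles = "name")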
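The following self-contained sketch shows the full workflow, assuming the row names of mtcars should serve as row labels. data.table's as.data.table() with keep.rownames = TRUE stores the row names in a column called rn. That column is treated as an ordinary character feature until its role is changed to name:

```r
library(mlr3)

# keep the row names of mtcars as an extra character column "rn"
data = data.table::as.data.table(mtcars, keep.rownames = TRUE)
task = as_task_regr(data, target = "mpg", id = "cars")
"rn" %in% task$feature_names   # TRUE: initially "rn" is a feature

# give "rn" only the "name" role: it labels rows but is no longer a feature
task$set_col_roles("rn", roles = "name")
task$col_roles$name            # "rn"
"rn" %in% task$feature_names   # FALSE
```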

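Tasks also support simple data manipulation such as selecting rows and columns. Note that these methods modify the task in place: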

task_penguins = tsk("penguins")
task_penguins$select(c("body_mass", "flipper_length")) # keep only these features
task_penguins$filter(1:3) # keep only these rows
task_penguins$head()
##    species body_mass flipper_length
## 1:  Adelie      3750            181
## 2:  Adelie      3800            186
## 3:  Adelie      3250            195
task_penguins$cbind(data.frame(letters = letters[1:3])) # add column letters
task_penguins$head()
##    species body_mass flipper_length letters
## 1:  Adelie      3750            181       a
## 2:  Adelie      3800            186       b
## 3:  Adelie      3250            195       c
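Task is an R6 class, so these in-place modifications follow the reference semantics described in the R6 section. If the original task should stay intact, one option is to filter a $clone() of it, as in this sketch:

```r
library(mlr3)

task = tsk("penguins")
small = task$clone()$filter(1:10)  # filter a copy; $filter() returns the task
c(original = task$nrow, copy = small$nrow)
```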

Visualizing tasks

library("mlr3viz")

# get the pima indians task
task = tsk("pima")

# subset task to only use the 3 first features
task$select(head(task$feature_names, 3))

# default plot: class frequencies
autoplot(task)

# pairs plot (requires package GGally)
autoplot(task, type = "pairs")
## Warning in ggally_statistic(data = data, mapping = mapping, na.rm = na.rm, :
## Removed 5 rows containing missing values
## Warning in ggally_statistic(data = data, mapping = mapping, na.rm = na.rm, :
## Removed 374 rows containing missing values
## Warning in ggally_statistic(data = data, mapping = mapping, na.rm = na.rm, :
## Removed 375 rows containing missing values
## Warning: Removed 5 rows containing non-finite values (stat_boxplot).
## Warning: Removed 374 rows containing non-finite values (stat_boxplot).
## `stat_bin()` using `bins = 30`. Pick better value with `binwidth`.
## `stat_bin()` using `bins = 30`. Pick better value with `binwidth`.
## Warning: Removed 5 rows containing non-finite values (stat_bin).
## Warning: Removed 5 rows containing missing values (geom_point).
## Warning: Removed 5 rows containing non-finite values (stat_density).
## `stat_bin()` using `bins = 30`. Pick better value with `binwidth`.
## Warning: Removed 374 rows containing non-finite values (stat_bin).
## Warning: Removed 374 rows containing missing values (geom_point).
## Warning: Removed 375 rows containing missing values (geom_point).
## Warning: Removed 374 rows containing non-finite values (stat_density).

# duo plot (requires package GGally)
autoplot(task, type = "duo")
## Warning: Removed 5 rows containing non-finite values (stat_boxplot).
## Warning: Removed 374 rows containing non-finite values (stat_boxplot).

library("mlr3viz")

# get the complete mtcars task
task = tsk("mtcars")

# subset task to only use the 3 first features
task$select(head(task$feature_names, 3))

# default plot: boxplot of target variable
autoplot(task)

# pairs plot (requires package GGally)
autoplot(task, type = "pairs")

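The Learner class

The mlr3 package itself ships with the following classification and regression learners:

  1. mlr_learners_classif.featureless: a simple baseline classification learner (inherits from LearnerClassif). By default it always predicts the label that is most frequent in the training set.
  2. mlr_learners_regr.featureless: a simple baseline regression learner (inherits from LearnerRegr). By default it always predicts the mean of the target in the training set.
  3. mlr_learners_classif.rpart: a single classification tree from the package rpart.
  4. mlr_learners_regr.rpart: a single regression tree from the package rpart.

Call lrn() without arguments to list the available learners: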

lrn()
## <DictionaryLearner> with 6 stored values
## Keys: classif.debug, classif.featureless, classif.rpart, regr.debug,
##   regr.featureless, regr.rpart

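The mlr3learners package provides many more learners.

The mlr3extralearners package provides many more as well. Install it with remotes::install_github("mlr-org/mlr3extralearners"). You can then list the learners available across the mlr3 packages with: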

mlr3extralearners::list_mlr3learners()
##            name   class                 id      mlr3_package
##   1: AdaBoostM1 classif classif.AdaBoostM1 mlr3extralearners
##   2:       bart classif       classif.bart mlr3extralearners
##   3:        C50 classif        classif.C50 mlr3extralearners
##   4:   catboost classif   classif.catboost mlr3extralearners
##   5:    cforest classif    classif.cforest mlr3extralearners
##  ---                                                        
## 132:     ranger    surv        surv.ranger      mlr3learners
## 133:      rfsrc    surv         surv.rfsrc mlr3extralearners
## 134:      rpart    surv         surv.rpart         mlr3proba
## 135:        svm    surv           surv.svm mlr3extralearners
## 136:    xgboost    surv       surv.xgboost      mlr3learners
##                                            required_packages
##   1:                            mlr3,mlr3extralearners,RWeka
##   2:                           mlr3,mlr3extralearners,dbarts
##   3:                              mlr3,mlr3extralearners,C50
##   4:                         mlr3,mlr3extralearners,catboost
##   5:           mlr3,mlr3extralearners,partykit,sandwich,coin
##  ---                                                        
## 132:                      mlr3,mlr3proba,mlr3learners,ranger
## 133: mlr3,mlr3proba,mlr3extralearners,randomForestSRC,pracma
## 134:                    mlr3,mlr3proba,rpart,distr6,survival
## 135:            mlr3,mlr3proba,mlr3extralearners,survivalsvm
## 136:                     mlr3,mlr3proba,mlr3learners,xgboost
##                                           properties
##   1:                             multiclass,twoclass
##   2:                                twoclass,weights
##   3:            missings,multiclass,twoclass,weights
##   4: importance,missings,multiclass,twoclass,weights
##   5:           multiclass,oob_error,twoclass,weights
##  ---                                                
## 132:                    importance,oob_error,weights
## 133:           importance,missings,oob_error,weights
## 134:   importance,missings,selected_features,weights
## 135:                                                
## 136:                     importance,missings,weights
##                                         feature_types  predict_types
##   1:                   numeric,factor,ordered,integer  response,prob
##   2:                   integer,numeric,factor,ordered  response,prob
##   3:                           numeric,factor,ordered  response,prob
##   4:                           numeric,factor,ordered  response,prob
##   5:                   integer,numeric,factor,ordered  response,prob
##  ---                                                                
## 132: logical,integer,numeric,character,factor,ordered    distr,crank
## 133:                   logical,integer,numeric,factor    crank,distr
## 134: logical,integer,numeric,character,factor,ordered    crank,distr
## 135:         integer,numeric,character,factor,logical crank,response
## 136:                                  integer,numeric       crank,lp

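Alternatively, after loading mlr3verse, the mlr_learners dictionary shows all learners that are currently available, including those registered by loaded extension packages: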

library("mlr3verse")
## Warning: package 'mlr3verse' was built under R version 4.1.2
mlr_learners
## <DictionaryLearner> with 136 stored values
## Keys: classif.AdaBoostM1, classif.bart, classif.C50, classif.catboost,
##   classif.cforest, classif.ctree, classif.cv_glmnet, classif.debug,
##   classif.earth, classif.extratrees, classif.featureless, classif.fnn,
##   classif.gam, classif.gamboost, classif.gausspr, classif.gbm,
##   classif.glmboost, classif.glmnet, classif.IBk, classif.J48,
##   classif.JRip, classif.kknn, classif.ksvm, classif.lda,
##   classif.liblinear, classif.lightgbm, classif.LMT, classif.log_reg,
##   classif.lssvm, classif.mob, classif.multinom, classif.naive_bayes,
##   classif.nnet, classif.OneR, classif.PART, classif.qda,
##   classif.randomForest, classif.ranger, classif.rfsrc, classif.rpart,
##   classif.svm, classif.xgboost, clust.agnes, clust.ap, clust.cmeans,
##   clust.cobweb, clust.dbscan, clust.diana, clust.em, clust.fanny,
##   clust.featureless, clust.ff, clust.hclust, clust.kkmeans,
##   clust.kmeans, clust.MBatchKMeans, clust.meanshift, clust.pam,
##   clust.SimpleKMeans, clust.xmeans, dens.hist, dens.kde, dens.kde_kd,
##   dens.kde_ks, dens.locfit, dens.logspline, dens.mixed, dens.nonpar,
##   dens.pen, dens.plug, dens.spline, regr.bart, regr.catboost,
##   regr.cforest, regr.ctree, regr.cubist, regr.cv_glmnet, regr.debug,
##   regr.earth, regr.extratrees, regr.featureless, regr.fnn, regr.gam,
##   regr.gamboost, regr.gausspr, regr.gbm, regr.glm, regr.glmboost,
##   regr.glmnet, regr.IBk, regr.kknn, regr.km, regr.ksvm, regr.liblinear,
##   regr.lightgbm, regr.lm, regr.M5Rules, regr.mars, regr.mob,
##   regr.randomForest, regr.ranger, regr.rfsrc, regr.rpart, regr.rvm,
##   regr.svm, regr.xgboost, surv.akritas, surv.blackboost, surv.cforest,
##   surv.coxboost, surv.coxph, surv.coxtime, surv.ctree,
##   surv.cv_coxboost, surv.cv_glmnet, surv.deephit, surv.deepsurv,
##   surv.dnnsurv, surv.flexible, surv.gamboost, surv.gbm, surv.glmboost,
##   surv.glmnet, surv.kaplan, surv.loghaz, surv.mboost, surv.nelson,
##   surv.obliqueRSF, surv.parametric, surv.pchazard, surv.penalized,
##   surv.ranger, surv.rfsrc, surv.rpart, surv.svm, surv.xgboost

To train a model, we first create a Learner object:

learner = lrn("classif.rpart")
print(learner)
## <LearnerClassifRpart:classif.rpart>: Classification Tree
## * Model: -
## * Parameters: xval=0
## * Packages: mlr3, rpart
## * Predict Type: response
## * Feature types: logical, integer, numeric, factor, ordered
## * Properties: importance, missings, multiclass, selected_features,
##   twoclass, weights

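The learner object contains a lot of information, including the packages needed for training, the predict type, the supported feature types, other properties, and its methods.

Every learner has its own set of hyperparameters, stored in $param_set: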

learner$param_set
## <ParamSet>
##                 id    class lower upper nlevels        default value
##  1:             cp ParamDbl     0     1     Inf           0.01      
##  2:     keep_model ParamLgl    NA    NA       2          FALSE      
##  3:     maxcompete ParamInt     0   Inf     Inf              4      
##  4:       maxdepth ParamInt     1    30      30             30      
##  5:   maxsurrogate ParamInt     0   Inf     Inf              5      
##  6:      minbucket ParamInt     1   Inf     Inf <NoDefault[3]>      
##  7:       minsplit ParamInt     1   Inf     Inf             20      
##  8: surrogatestyle ParamInt     0     1       2              0      
##  9:   usesurrogate ParamInt     0     2       3              2      
## 10:           xval ParamInt     0   Inf     Inf             10     0

We can set the parameter values directly:

learner$param_set$values = list(cp = 0.01, xval = 0)
learner
## <LearnerClassifRpart:classif.rpart>: Classification Tree
## * Model: -
## * Parameters: cp=0.01, xval=0
## * Packages: mlr3, rpart
## * Predict Type: response
## * Feature types: logical, integer, numeric, factor, ordered
## * Properties: importance, missings, multiclass, selected_features,
##   twoclass, weights

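This overwrites all previously set parameter values. You can also retrieve the current hyperparameter values, modify them, and write them back to the learner: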

pv = learner$param_set$values
pv$cp = 0.02
learner$param_set$values = pv
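You can also update a single hyperparameter in place while keeping the other values. This sketch uses a common R idiom (nested $ assignment), and the cp value is arbitrary:

```r
library(mlr3)

learner = lrn("classif.rpart", cp = 0.01)
# update only cp; other values (e.g. the default xval = 0) are kept
learner$param_set$values$cp = 0.05
learner$param_set$values
```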

The id and the hyperparameters can also be set when the learner is created:

learner = lrn("classif.rpart", id = "rp", cp = 0.001)
learner$id
## [1] "rp"
learner$param_set$values
## $xval
## [1] 0
## 
## $cp
## [1] 0.001

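For classification, learners can predict class probabilities (predict_type = "prob"). For a binary task, the positive class is predicted if its probability exceeds 50%, and the negative class otherwise. This threshold can be changed: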

data("Sonar", package = "mlbench")
task = as_task_classif(Sonar, target = "Class", positive = "M")
learner = lrn("classif.rpart", predict_type = "prob")
pred = learner$train(task)$predict(task)

measures = msrs(c("classif.tpr", "classif.tnr")) # use msrs() to get a list of multiple measures
pred$confusion
##         truth
## response  M  R
##        M 95 10
##        R 16 87
pred$score(measures)
## classif.tpr classif.tnr 
##   0.8558559   0.8969072
pred$set_threshold(0.2)
pred$confusion
##         truth
## response   M   R
##        M 104  25
##        R   7  72
pred$score(measures)
## classif.tpr classif.tnr 
##   0.9369369   0.7422680
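You can verify both measures by hand from the confusion matrix at the default threshold: TPR = TP / (TP + FN) = 95 / (95 + 16) and TNR = TN / (TN + FP) = 87 / (87 + 10). The base R sketch below copies the counts from the output above:

```r
# confusion matrix at threshold 0.5 (rows: response, columns: truth)
C = matrix(c(95, 16, 10, 87), nrow = 2,
           dimnames = list(response = c("M", "R"), truth = c("M", "R")))

tpr = C["M", "M"] / sum(C[, "M"])  # true positives / all actual positives
tnr = C["R", "R"] / sum(C[, "R"])  # true negatives / all actual negatives
round(c(tpr = tpr, tnr = tnr), 4)  # tpr 0.8559, tnr 0.8969
```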

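Lowering the threshold to 0.2 makes the learner predict M more often. As a result, the true positive rate rises (0.856 to 0.937) while the true negative rate drops (0.897 to 0.742). The mlr3pipelines package can tune the threshold automatically with respect to a performance measure, via PipeOpTuneThreshold.

Let's look at another example.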

library("mlr3verse")
task = tsk("penguins")
learner = lrn("classif.rpart")
train_set = sample(task$nrow, 0.8 * task$nrow)
test_set = setdiff(seq_len(task$nrow), train_set)


learner$model
## NULL
learner$train(task, row_ids = train_set)

print(learner$model) # inspect the fitted model
## n= 275 
## 
## node), split, n, loss, yval, (yprob)
##       * denotes terminal node
## 
## 1) root 275 150 Adelie (0.454545455 0.203636364 0.341818182)  
##   2) flipper_length< 206.5 176  53 Adelie (0.698863636 0.295454545 0.005681818)  
##     4) bill_length< 43.35 123   3 Adelie (0.975609756 0.024390244 0.000000000) *
##     5) bill_length>=43.35 53   4 Chinstrap (0.056603774 0.924528302 0.018867925) *
##   3) flipper_length>=206.5 99   6 Gentoo (0.020202020 0.040404040 0.939393939)  
##     6) bill_depth>=17.2 8   4 Chinstrap (0.250000000 0.500000000 0.250000000) *
##     7) bill_depth< 17.2 91   0 Gentoo (0.000000000 0.000000000 1.000000000) *
# make predictions on the test set

prediction = learner$predict(task, row_ids = test_set)
print(prediction)
## <PredictionClassif> for 69 observations:
##     row_ids     truth  response
##           1    Adelie    Adelie
##           4    Adelie    Adelie
##           5    Adelie    Adelie
## ---                            
##         334 Chinstrap Chinstrap
##         335 Chinstrap Chinstrap
##         340 Chinstrap Chinstrap
# convert the predictions to a data.table

head(as.data.table(prediction)) # show first six predictions
##    row_ids  truth response
## 1:       1 Adelie   Adelie
## 2:       4 Adelie   Adelie
## 3:       5 Adelie   Adelie
## 4:      16 Adelie   Adelie
## 5:      22 Adelie   Adelie
## 6:      26 Adelie   Adelie
# compute the confusion matrix

prediction$confusion
##            truth
## response    Adelie Chinstrap Gentoo
##   Adelie        26         2      0
##   Chinstrap      1        10      0
##   Gentoo         0         0     30
# change the predict type to class probabilities

learner$predict_type = "prob"

# re-fit the model
learner$train(task, row_ids = train_set)

# rebuild prediction object
prediction = learner$predict(task, row_ids = test_set)


# data.table conversion
head(as.data.table(prediction)) # show first six
##    row_ids  truth response prob.Adelie prob.Chinstrap prob.Gentoo
## 1:       1 Adelie   Adelie   0.9756098     0.02439024           0
## 2:       4 Adelie   Adelie   0.9756098     0.02439024           0
## 3:       5 Adelie   Adelie   0.9756098     0.02439024           0
## 4:      16 Adelie   Adelie   0.9756098     0.02439024           0
## 5:      22 Adelie   Adelie   0.9756098     0.02439024           0
## 6:      26 Adelie   Adelie   0.9756098     0.02439024           0
# directly access the predicted labels:
head(prediction$response)
## [1] Adelie Adelie Adelie Adelie Adelie Adelie
## Levels: Adelie Chinstrap Gentoo
# directly access the matrix of probabilities:
head(prediction$prob)
##         Adelie  Chinstrap Gentoo
## [1,] 0.9756098 0.02439024      0
## [2,] 0.9756098 0.02439024      0
## [3,] 0.9756098 0.02439024      0
## [4,] 0.9756098 0.02439024      0
## [5,] 0.9756098 0.02439024      0
## [6,] 0.9756098 0.02439024      0
# visualize the predictions

task = tsk("penguins")
learner = lrn("classif.rpart", predict_type = "prob")
learner$train(task)
prediction = learner$predict(task)
autoplot(prediction)

The last step of modeling is usually to evaluate the performance of the trained model. mlr_measures lists the available performance measures:

mlr_measures
## <DictionaryMeasure> with 88 stored values
## Keys: aic, bic, classif.acc, classif.auc, classif.bacc, classif.bbrier,
##   classif.ce, classif.costs, classif.dor, classif.fbeta, classif.fdr,
##   classif.fn, classif.fnr, classif.fomr, classif.fp, classif.fpr,
##   classif.logloss, classif.mbrier, classif.mcc, classif.npv,
##   classif.ppv, classif.prauc, classif.precision, classif.recall,
##   classif.sensitivity, classif.specificity, classif.tn, classif.tnr,
##   classif.tp, classif.tpr, clust.ch, clust.db, clust.dunn,
##   clust.silhouette, clust.wss, debug, dens.logloss, oob_error,
##   regr.bias, regr.ktau, regr.mae, regr.mape, regr.maxae, regr.medae,
##   regr.medse, regr.mse, regr.msle, regr.pbias, regr.rae, regr.rmse,
##   regr.rmsle, regr.rrse, regr.rse, regr.rsq, regr.sae, regr.smape,
##   regr.srho, regr.sse, selected_features, sim.jaccard, sim.phi,
##   surv.brier, surv.calib_alpha, surv.calib_beta, surv.chambless_auc,
##   surv.cindex, surv.dcalib, surv.graf, surv.hung_auc, surv.intlogloss,
##   surv.logloss, surv.mae, surv.mse, surv.nagelk_r2, surv.oquigley_r2,
##   surv.rcll, surv.rmse, surv.schmid, surv.song_auc, surv.song_tnr,
##   surv.song_tpr, surv.uno_auc, surv.uno_tnr, surv.uno_tpr, surv.xu_r2,
##   time_both, time_predict, time_train

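For example, we choose classification accuracy. Note that the learner above was trained and evaluated on the same rows, so this is an optimistic, in-sample estimate: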

measure = msr("classif.acc")
print(measure)
## <MeasureClassifSimple:classif.acc>: Classification Accuracy
## * Packages: mlr3, mlr3measures
## * Range: [0, 1]
## * Minimize: FALSE
## * Average: macro
## * Parameters: list()
## * Properties: -
## * Predict type: response
prediction$score(measure)
## classif.acc 
##   0.9651163
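To compute several measures at once, pass a list created with msrs(). The sketch below uses the penguins task and again evaluates on the training data, so the numbers are optimistic:

```r
library(mlr3)

task = tsk("penguins")
learner = lrn("classif.rpart")
prediction = learner$train(task)$predict(task)

# msrs() builds a list of measures; $score() returns one named value per measure
scores = prediction$score(msrs(c("classif.acc", "classif.ce")))
scores
# accuracy and classification error are complements, so they sum to 1
```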

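Performance evaluation and comparison

ROC and precision-recall curves

We start with binary classification, e.g. computing the confusion matrix and plotting the ROC curve: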

data("Sonar", package = "mlbench")
task = as_task_classif(Sonar, target = "Class", positive = "M")

learner = lrn("classif.rpart", predict_type = "prob")
pred = learner$train(task)$predict(task)
C = pred$confusion
print(C)
##         truth
## response  M  R
##        M 95 10
##        R 16 87

There are several ways to plot an ROC curve. With mlr3viz it takes a single autoplot() call:

library("mlr3viz")

# TPR vs FPR / Sensitivity vs (1 - Specificity)
autoplot(pred, type = "roc")

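We can also plot the precision-recall curve (PPV vs. TPR). The main difference between an ROC curve and a precision-recall curve (PRC) is that the PRC does not use the number of true negatives. For imbalanced data, where the positive class is rare, the PRC is often more informative than the ROC curve.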

# Precision vs Recall
autoplot(pred, type = "prc")
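An ROC curve is often summarized by the area under it (AUC). The sketch below recomputes the Sonar prediction from above and scores it with the classif.auc measure, which requires predict_type = "prob". As before, this is an in-sample estimate:

```r
library(mlr3)

data("Sonar", package = "mlbench")  # assumes mlbench is installed
task = as_task_classif(Sonar, target = "Class", positive = "M")
learner = lrn("classif.rpart", predict_type = "prob")
pred = learner$train(task)$predict(task)

# area under the ROC curve: 0.5 is random guessing, 1 is a perfect ranking
auc = pred$score(msr("classif.auc"))
auc
```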

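Resampling

When evaluating a model, we are interested in its generalization performance: how well does it perform on new data that it did not see during training? We can estimate generalization performance by evaluating the model on a test set.

There are many strategies for splitting a dataset into training and test sets, and mlr3 calls these strategies "resampling". mlr3 includes the following predefined resampling strategies:

  1. cross-validation ("cv"),
  2. leave-one-out cross-validation ("loo"),
  3. repeated cross-validation ("repeated_cv"),
  4. bootstrapping ("bootstrap"),
  5. subsampling ("subsampling"),
  6. holdout ("holdout"),
  7. in-sample resampling ("insample"),
  8. custom resampling ("custom").

Let's look at an example.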

library("mlr3verse")

task = tsk("penguins")
learner = lrn("classif.rpart")

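When resampling a dataset, we first need to define which strategy to use.

Converting mlr_resamplings to a table lists all strategies together with the parameters that control each strategy's behavior: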

as.data.table(mlr_resamplings)
##            key                         label        params iters
## 1:   bootstrap                     Bootstrap ratio,repeats    30
## 2:      custom                 Custom Splits                  NA
## 3:   custom_cv Custom Split Cross-Validation                  NA
## 4:          cv              Cross-Validation         folds    10
## 5:     holdout                       Holdout         ratio     1
## 6:    insample           Insample Resampling                   1
## 7:         loo                 Leave-One-Out                  NA
## 8: repeated_cv     Repeated Cross-Validation folds,repeats   100
## 9: subsampling                   Subsampling ratio,repeats    30

Additional resampling methods for special use cases are available in extension packages, e.g. mlr3spatiotempcv for spatiotemporal data.

Here we use holdout resampling:

resampling <- rsmp("holdout")
print(resampling)
## <ResamplingHoldout>: Holdout
## * Iterations: 1
## * Instantiated: FALSE
## * Parameters: ratio=0.6667

Alternatively:

resampling <- mlr_resamplings$get("holdout")

resampling
## <ResamplingHoldout>: Holdout
## * Iterations: 1
## * Instantiated: FALSE
## * Parameters: ratio=0.6667

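Note that the $is_instantiated field is FALSE ("Instantiated: FALSE" above). This means the strategy has not yet been applied to a dataset.

By default, holdout splits the data into training and test sets in a ratio of 2/3 to 1/3 (ratio = 0.6667). There are two ways to change this ratio:

  1. Overwrite the $param_set$values slot with a named list:
resampling$param_set$values = list(ratio = 0.8)
  2. Specify the resampling parameters directly during construction with rsmp():
rsmp("holdout", ratio = 0.8)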
## <ResamplingHoldout>: Holdout
## * Iterations: 1
## * Instantiated: FALSE
## * Parameters: ratio=0.8

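We have now defined the resampling strategy, and the next step is to instantiate it with data.

To actually perform the split and obtain the indices of the training and test sets, the resampling needs a Task. Calling the $instantiate() method splits the row indices of the data into training and test indices, which are stored in the Resampling object: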

resampling$instantiate(task)

str(resampling$train_set(1))
##  int [1:275] 1 2 3 4 5 6 7 10 11 13 ...
str(resampling$test_set(1))
##  int [1:69] 8 9 12 15 26 35 36 41 42 52 ...
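As a quick sanity check, the sketch below confirms that the train and test indices of an instantiated holdout split are disjoint and together cover all rows of the task:

```r
library(mlr3)

task = tsk("penguins")
resampling = rsmp("holdout", ratio = 0.8)
resampling$instantiate(task)

train_ids = resampling$train_set(1)
test_ids = resampling$test_set(1)

length(intersect(train_ids, test_ids))          # 0: no row is in both sets
setequal(c(train_ids, test_ids), task$row_ids)  # TRUE: every row is used
resampling$is_instantiated                      # TRUE
```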

The next step is to run the resampling with the resample() function:

task = tsk("penguins")
learner = lrn("classif.rpart", maxdepth = 3, predict_type = "prob")
resampling = rsmp("cv", folds = 3)

rr = resample(task, learner, resampling, store_models = TRUE)
## INFO  [23:35:27.806] [mlr3] Applying learner 'classif.rpart' on task 'penguins' (iter 1/3) 
## INFO  [23:35:27.837] [mlr3] Applying learner 'classif.rpart' on task 'penguins' (iter 2/3) 
## INFO  [23:35:27.851] [mlr3] Applying learner 'classif.rpart' on task 'penguins' (iter 3/3)
print(rr)
## <ResampleResult> of 3 iterations
## * Task: penguins
## * Learner: classif.rpart
## * Warnings: 0 in 0 iterations
## * Errors: 0 in 0 iterations

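Here we use 3-fold cross-validation, which trains and evaluates the learner on three different train/test splits. The returned ResampleResult is stored as rr.

We can extract a lot of information from rr. For example, the average performance over all resampling iterations, measured by the classification error: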

rr$aggregate(msr("classif.ce"))
## classif.ce 
## 0.07561658

Extract the performance of each resampling iteration:

rr$score(msr("classif.ce"))
##                 task  task_id                   learner    learner_id
## 1: <TaskClassif[50]> penguins <LearnerClassifRpart[38]> classif.rpart
## 2: <TaskClassif[50]> penguins <LearnerClassifRpart[38]> classif.rpart
## 3: <TaskClassif[50]> penguins <LearnerClassifRpart[38]> classif.rpart
##            resampling resampling_id iteration              prediction
## 1: <ResamplingCV[20]>            cv         1 <PredictionClassif[20]>
## 2: <ResamplingCV[20]>            cv         2 <PredictionClassif[20]>
## 3: <ResamplingCV[20]>            cv         3 <PredictionClassif[20]>
##    classif.ce
## 1: 0.09565217
## 2: 0.04347826
## 3: 0.08771930
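The two outputs are linked: by default mlr3 measures use macro averaging, i.e. the aggregate is the mean of the per-iteration scores rather than the error of all predictions pooled together. The arithmetic below checks this against the numbers printed above. The misclassification counts (11, 5, 10) are back-calculated from the printed scores and the fold sizes (115, 115, 114) of the per-iteration predictions; pooling them (micro averaging) gives a slightly different value.

```r
# Misclassification counts per fold, back-calculated from the scores above:
# 0.09565217 * 115 = 11, 0.04347826 * 115 = 5, 0.08771930 * 114 = 10
errors <- c(11, 5, 10)
sizes  <- c(115, 115, 114)          # test-set sizes of the three folds

per_fold <- errors / sizes          # what rr$score() reports per iteration
macro <- mean(per_fold)             # mean of the per-fold errors
micro <- sum(errors) / sum(sizes)   # all test predictions pooled first

round(per_fold, 8)
round(macro, 8)  # 0.07561658, the value returned by rr$aggregate()
round(micro, 8)
```

On real results, the pooled value can be obtained by scoring the merged prediction directly: rr$prediction()$score(msr("classif.ce")).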

Check for warnings and errors:

rr$warnings
## Empty data.table (0 rows and 2 cols): iteration,msg
rr$errors
## Empty data.table (0 rows and 2 cols): iteration,msg

Extract and inspect the resampling splits:

rr$resampling
## <ResamplingCV>: Cross-Validation
## * Iterations: 3
## * Instantiated: TRUE
## * Parameters: folds=3

Inspect the fitted models (they are available because we set store_models = TRUE):

learners = rr$learners  # avoid naming the list `lrn`, which would mask mlr3's lrn()
learners
## [[1]]
## <LearnerClassifRpart:classif.rpart>: Classification Tree
## * Model: rpart
## * Parameters: xval=0, maxdepth=3
## * Packages: mlr3, rpart
## * Predict Type: prob
## * Feature types: logical, integer, numeric, factor, ordered
## * Properties: importance, missings, multiclass, selected_features,
##   twoclass, weights
## 
## [[2]]
## <LearnerClassifRpart:classif.rpart>: Classification Tree
## * Model: rpart
## * Parameters: xval=0, maxdepth=3
## * Packages: mlr3, rpart
## * Predict Type: prob
## * Feature types: logical, integer, numeric, factor, ordered
## * Properties: importance, missings, multiclass, selected_features,
##   twoclass, weights
## 
## [[3]]
## <LearnerClassifRpart:classif.rpart>: Classification Tree
## * Model: rpart
## * Parameters: xval=0, maxdepth=3
## * Packages: mlr3, rpart
## * Predict Type: prob
## * Feature types: logical, integer, numeric, factor, ordered
## * Properties: importance, missings, multiclass, selected_features,
##   twoclass, weights

Extract the predictions, either merged into one object or per iteration:

rr$prediction() # all predictions merged into a single Prediction object
## <PredictionClassif> for 344 observations:
##     row_ids     truth  response prob.Adelie prob.Chinstrap prob.Gentoo
##           8    Adelie    Adelie   0.9719626     0.02803738   0.0000000
##          11    Adelie    Adelie   0.9719626     0.02803738   0.0000000
##          15    Adelie    Adelie   0.9719626     0.02803738   0.0000000
## ---                                                                   
##         333 Chinstrap Chinstrap   0.1063830     0.87234043   0.0212766
##         334 Chinstrap Chinstrap   0.1063830     0.87234043   0.0212766
##         344 Chinstrap Chinstrap   0.1063830     0.87234043   0.0212766
rr$predictions() 
## [[1]]
## <PredictionClassif> for 115 observations:
##     row_ids     truth  response prob.Adelie prob.Chinstrap prob.Gentoo
##           8    Adelie    Adelie  0.97196262     0.02803738   0.0000000
##          11    Adelie    Adelie  0.97196262     0.02803738   0.0000000
##          15    Adelie    Adelie  0.97196262     0.02803738   0.0000000
## ---                                                                   
##         340 Chinstrap    Gentoo  0.00000000     0.02352941   0.9764706
##         342 Chinstrap Chinstrap  0.05405405     0.94594595   0.0000000
##         343 Chinstrap    Gentoo  0.00000000     0.02352941   0.9764706
## 
## [[2]]
## <PredictionClassif> for 115 observations:
##     row_ids     truth  response prob.Adelie prob.Chinstrap prob.Gentoo
##           2    Adelie    Adelie  0.95049505     0.04950495  0.00000000
##           3    Adelie    Adelie  0.95049505     0.04950495  0.00000000
##           4    Adelie    Adelie  0.95049505     0.04950495  0.00000000
## ---                                                                   
##         338 Chinstrap Chinstrap  0.02272727     0.95454545  0.02272727
##         339 Chinstrap Chinstrap  0.02272727     0.95454545  0.02272727
##         341 Chinstrap    Adelie  0.95049505     0.04950495  0.00000000
## 
## [[3]]
## <PredictionClassif> for 114 observations:
##     row_ids     truth  response prob.Adelie prob.Chinstrap prob.Gentoo
##           1    Adelie    Adelie   0.9892473     0.01075269   0.0000000
##           5    Adelie    Adelie   0.9892473     0.01075269   0.0000000
##           6    Adelie    Adelie   0.9892473     0.01075269   0.0000000
## ---                                                                   
##         333 Chinstrap Chinstrap   0.1063830     0.87234043   0.0212766
##         334 Chinstrap Chinstrap   0.1063830     0.87234043   0.0212766
##         344 Chinstrap Chinstrap   0.1063830     0.87234043   0.0212766
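The three per-iteration Prediction objects above contain 115, 115 and 114 rows, and the merged prediction contains all 344. This is what k-fold cross-validation produces: every row is a test row in exactly one iteration, and fold sizes differ by at most one. A base-R sketch of such an assignment follows; the helper `assign_folds` is hypothetical and mlr3's actual implementation may differ in detail.

```r
# Hypothetical helper illustrating k-fold assignment; not mlr3's internal code.
assign_folds <- function(n, k, seed = NULL) {
  if (!is.null(seed)) set.seed(seed)
  # deal fold labels 1..k round-robin, then shuffle them over the rows
  sample(rep_len(seq_len(k), n))
}

fold <- assign_folds(344, k = 3, seed = 1)
table(fold)                              # two folds of 115 rows, one of 114
test_sets <- split(seq_len(344), fold)   # row ids used as test set per fold
```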

Filter the result to keep only the specified resampling iterations:

rr$filter(c(1, 3))
print(rr)
## <ResampleResult> of 2 iterations
## * Task: penguins
## * Learner: classif.rpart
## * Warnings: 0 in 0 iterations
## * Errors: 0 in 0 iterations

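Resampling can also be customized. Sometimes it is necessary to resample with custom splits; in that case a manual resampling instance can be created with the "custom" template: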

resampling = rsmp("custom")
resampling$instantiate(task,
  train = list(c(1:10, 51:60, 101:110)), # rows of the training set
  test = list(c(11:20, 61:70, 111:120)) # rows of the test set
)
resampling$iters
## [1] 1
resampling$train_set(1)
##  [1]   1   2   3   4   5   6   7   8   9  10  51  52  53  54  55  56  57  58  59
## [20]  60 101 102 103 104 105 106 107 108 109 110
resampling$test_set(1)
##  [1]  11  12  13  14  15  16  17  18  19  20  61  62  63  64  65  66  67  68  69
## [20]  70 111 112 113 114 115 116 117 118 119 120
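Hand-written index lists are easy to get wrong, for example a row that ends up in both sets. The sketch below checks the lists before they are passed to instantiate(); `check_custom_split` is a hypothetical helper, not part of mlr3, and it assumes row ids 1..n, as is the case for the penguins task used here.

```r
# Hypothetical helper (not part of mlr3): sanity-check custom train/test lists.
check_custom_split <- function(train, test, n_rows) {
  stopifnot(length(train) == length(test))  # one test set per training set
  for (i in seq_along(train)) {
    tr <- train[[i]]
    te <- test[[i]]
    if (length(intersect(tr, te)) > 0)
      stop(sprintf("iteration %d: rows appear in both train and test", i))
    if (any(c(tr, te) < 1 | c(tr, te) > n_rows))
      stop(sprintf("iteration %d: row ids outside 1..%d", i, n_rows))
  }
  invisible(TRUE)
}

train <- list(c(1:10, 51:60, 101:110))
test  <- list(c(11:20, 61:70, 111:120))
check_custom_split(train, test, n_rows = 344)  # passes silently
```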

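mlr3viz provides an autoplot() method for resampling results. For example, we create a binary classification task with two features, run a 10-fold cross-validation and visualize the results: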

library("mlr3viz")  # provides the autoplot() methods used below
task = tsk("pima")
task$select(c("glucose", "mass"))
learner = lrn("classif.rpart", predict_type = "prob")
rr = resample(task, learner, rsmp("cv"), store_models = TRUE)
## INFO  [23:35:28.042] [mlr3] Applying learner 'classif.rpart' on task 'pima' (iter 6/10) 
## INFO  [23:35:28.056] [mlr3] Applying learner 'classif.rpart' on task 'pima' (iter 9/10) 
## INFO  [23:35:28.075] [mlr3] Applying learner 'classif.rpart' on task 'pima' (iter 7/10) 
## INFO  [23:35:28.089] [mlr3] Applying learner 'classif.rpart' on task 'pima' (iter 1/10) 
## INFO  [23:35:28.102] [mlr3] Applying learner 'classif.rpart' on task 'pima' (iter 4/10) 
## INFO  [23:35:28.115] [mlr3] Applying learner 'classif.rpart' on task 'pima' (iter 5/10) 
## INFO  [23:35:28.129] [mlr3] Applying learner 'classif.rpart' on task 'pima' (iter 10/10) 
## INFO  [23:35:28.141] [mlr3] Applying learner 'classif.rpart' on task 'pima' (iter 3/10) 
## INFO  [23:35:28.153] [mlr3] Applying learner 'classif.rpart' on task 'pima' (iter 8/10) 
## INFO  [23:35:28.164] [mlr3] Applying learner 'classif.rpart' on task 'pima' (iter 2/10)
# boxplot of AUC values across the 10 folds
autoplot(rr, measure = msr("classif.auc"))

Plot the ROC curve:

# ROC curve, averaged over 10 folds
autoplot(rr, type = "roc")

We can also plot the predictions of a single model:

# learner predictions for the first fold
rr$filter(1)
autoplot(rr, type = "prediction")
## Warning: Removed 3 rows containing missing values (geom_point).

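Benchmarking

Comparing the performance of different learners on multiple tasks and/or with different resampling schemes is a common task. In machine learning this is usually called "benchmarking".

In mlr3, benchmark experiments are specified through a design. Such a design is essentially a table of scenarios to evaluate, namely unique combinations of Task, Learner and Resampling triplets.

We use the benchmark_grid() function to create an exhaustive design (every learner is evaluated on every task with every resampling) and to instantiate the resamplings properly, so that all learners are run on the same train/test splits of each task. We let the learners predict probabilities and also tell them to predict the observations of the training set (by setting predict_sets to c("train", "test")). In addition, we use tsks(), lrns() and rsmps() to retrieve lists of Task, Learner and Resampling objects, in the same way that tsk(), lrn() and rsmp() retrieve single objects.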

library("mlr3verse")

design = benchmark_grid(
  tasks = tsks(c("spam", "german_credit", "sonar")),
  learners = lrns(c("classif.ranger", "classif.rpart", "classif.featureless"),
    predict_type = "prob", predict_sets = c("train", "test")),
  resamplings = rsmps("cv", folds = 3)
)
print(design)
##                 task                         learner         resampling
## 1: <TaskClassif[50]>      <LearnerClassifRanger[38]> <ResamplingCV[20]>
## 2: <TaskClassif[50]>       <LearnerClassifRpart[38]> <ResamplingCV[20]>
## 3: <TaskClassif[50]> <LearnerClassifFeatureless[38]> <ResamplingCV[20]>
## 4: <TaskClassif[50]>      <LearnerClassifRanger[38]> <ResamplingCV[20]>
## 5: <TaskClassif[50]>       <LearnerClassifRpart[38]> <ResamplingCV[20]>
## 6: <TaskClassif[50]> <LearnerClassifFeatureless[38]> <ResamplingCV[20]>
## 7: <TaskClassif[50]>      <LearnerClassifRanger[38]> <ResamplingCV[20]>
## 8: <TaskClassif[50]>       <LearnerClassifRpart[38]> <ResamplingCV[20]>
## 9: <TaskClassif[50]> <LearnerClassifFeatureless[38]> <ResamplingCV[20]>
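benchmark_grid() crosses every task with every learner and every resampling. A base-R expand.grid() sketch of the same full factorial design (character ids only, no mlr3 objects) reproduces the 9 rows above and the 27 resampling iterations (9 rows × 3 folds) reported by benchmark(). The row order (learner varying fastest, task slowest) is chosen to match the aggregated table of results.

```r
tasks       <- c("spam", "german_credit", "sonar")
learners    <- c("classif.ranger", "classif.rpart", "classif.featureless")
resamplings <- "cv"
folds       <- 3

# expand.grid varies its first argument fastest: learner fastest, task slowest
design <- expand.grid(learner = learners, resampling = resamplings, task = tasks,
                      stringsAsFactors = FALSE)
design <- design[, c("task", "learner", "resampling")]
nrow(design)          # 9 task/learner/resampling combinations
nrow(design) * folds  # 27 resampling iterations
```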

The resulting design can be passed to benchmark() to start the computation:

bmr = benchmark(design)
## INFO  [23:35:29.275] [mlr3] Running benchmark with 27 resampling iterations 
## INFO  [23:35:29.279] [mlr3] Applying learner 'classif.ranger' on task 'german_credit' (iter 3/3) 
## INFO  [23:35:29.563] [mlr3] Applying learner 'classif.ranger' on task 'sonar' (iter 1/3) 
## INFO  [23:35:29.647] [mlr3] Applying learner 'classif.featureless' on task 'spam' (iter 1/3) 
## INFO  [23:35:29.669] [mlr3] Applying learner 'classif.rpart' on task 'spam' (iter 1/3) 
## INFO  [23:35:29.785] [mlr3] Applying learner 'classif.featureless' on task 'sonar' (iter 1/3) 
## INFO  [23:35:29.796] [mlr3] Applying learner 'classif.featureless' on task 'german_credit' (iter 1/3) 
## INFO  [23:35:29.807] [mlr3] Applying learner 'classif.rpart' on task 'sonar' (iter 3/3) 
## INFO  [23:35:29.832] [mlr3] Applying learner 'classif.rpart' on task 'spam' (iter 2/3) 
## INFO  [23:35:29.903] [mlr3] Applying learner 'classif.ranger' on task 'spam' (iter 3/3) 
## INFO  [23:35:31.255] [mlr3] Applying learner 'classif.rpart' on task 'german_credit' (iter 2/3) 
## INFO  [23:35:31.285] [mlr3] Applying learner 'classif.ranger' on task 'sonar' (iter 3/3) 
## INFO  [23:35:31.365] [mlr3] Applying learner 'classif.ranger' on task 'german_credit' (iter 2/3) 
## INFO  [23:35:31.632] [mlr3] Applying learner 'classif.featureless' on task 'sonar' (iter 2/3) 
## INFO  [23:35:31.644] [mlr3] Applying learner 'classif.featureless' on task 'german_credit' (iter 2/3) 
## INFO  [23:35:31.656] [mlr3] Applying learner 'classif.rpart' on task 'sonar' (iter 1/3) 
## INFO  [23:35:31.681] [mlr3] Applying learner 'classif.ranger' on task 'german_credit' (iter 1/3) 
## INFO  [23:35:31.936] [mlr3] Applying learner 'classif.rpart' on task 'sonar' (iter 2/3) 
## INFO  [23:35:31.961] [mlr3] Applying learner 'classif.rpart' on task 'spam' (iter 3/3) 
## INFO  [23:35:32.032] [mlr3] Applying learner 'classif.ranger' on task 'spam' (iter 2/3) 
## INFO  [23:35:33.406] [mlr3] Applying learner 'classif.featureless' on task 'german_credit' (iter 3/3) 
## INFO  [23:35:33.419] [mlr3] Applying learner 'classif.featureless' on task 'sonar' (iter 3/3) 
## INFO  [23:35:33.429] [mlr3] Applying learner 'classif.featureless' on task 'spam' (iter 3/3) 
## INFO  [23:35:33.452] [mlr3] Applying learner 'classif.rpart' on task 'german_credit' (iter 3/3) 
## INFO  [23:35:33.485] [mlr3] Applying learner 'classif.ranger' on task 'spam' (iter 1/3) 
## INFO  [23:35:34.830] [mlr3] Applying learner 'classif.rpart' on task 'german_credit' (iter 1/3) 
## INFO  [23:35:34.857] [mlr3] Applying learner 'classif.featureless' on task 'spam' (iter 2/3) 
## INFO  [23:35:34.879] [mlr3] Applying learner 'classif.ranger' on task 'sonar' (iter 2/3) 
## INFO  [23:35:34.960] [mlr3] Finished benchmark

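Once the benchmark has finished (depending on the size of your design, this can take quite a while), we can use aggregate() to measure performance on both the training and the test sets: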

measures = list(
  msr("classif.auc", predict_sets = "train", id = "auc_train"),
  msr("classif.auc", id = "auc_test")
)

tab = bmr$aggregate(measures)
print(tab)
##    nr      resample_result       task_id          learner_id resampling_id
## 1:  1 <ResampleResult[22]>          spam      classif.ranger            cv
## 2:  2 <ResampleResult[22]>          spam       classif.rpart            cv
## 3:  3 <ResampleResult[22]>          spam classif.featureless            cv
## 4:  4 <ResampleResult[22]> german_credit      classif.ranger            cv
## 5:  5 <ResampleResult[22]> german_credit       classif.rpart            cv
## 6:  6 <ResampleResult[22]> german_credit classif.featureless            cv
## 7:  7 <ResampleResult[22]>         sonar      classif.ranger            cv
## 8:  8 <ResampleResult[22]>         sonar       classif.rpart            cv
## 9:  9 <ResampleResult[22]>         sonar classif.featureless            cv
##    iters auc_train  auc_test
## 1:     3 0.9994479 0.9849603
## 2:     3 0.9041325 0.8951998
## 3:     3 0.5000000 0.5000000
## 4:     3 0.9983067 0.7986295
## 5:     3 0.8163585 0.7034258
## 6:     3 0.5000000 0.5000000
## 7:     3 1.0000000 0.8857579
## 8:     3 0.9242398 0.7547127
## 9:     3 0.5000000 0.5000000
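All classif.featureless rows score exactly 0.5, on both training and test sets. This follows from the definition of AUC: it is the probability that a randomly chosen positive observation gets a higher score than a randomly chosen negative one, with ties counting one half. A featureless learner predicts the same probabilities for every observation, so every pair is a tie. The base-R sketch below computes AUC with this rank (Mann-Whitney) formulation; `auc_rank` is a hypothetical helper, and mlr3's classif.auc may compute it differently, although the values should agree.

```r
# AUC via average ranks: P(score of a positive > score of a negative),
# where ties count one half.
auc_rank <- function(scores, is_pos) {
  r <- rank(scores)             # tied scores receive their average rank
  n_pos <- sum(is_pos)
  n_neg <- sum(!is_pos)
  (sum(r[is_pos]) - n_pos * (n_pos + 1) / 2) / (n_pos * n_neg)
}

is_pos <- c(TRUE, TRUE, FALSE, FALSE)
auc_rank(c(0.9, 0.8, 0.3, 0.1), is_pos)  # 1: perfect separation
auc_rank(c(0.5, 0.5, 0.5, 0.5), is_pos)  # 0.5: constant scores, as for featureless
```

The same table shows why predicting the training set is useful: classif.ranger reaches a training AUC close to 1 on every task but a clearly lower test AUC (0.998 vs 0.799 on german_credit), a typical sign of overfitting.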

We can summarize the results further. For example, we might want to know which learner performs best across all tasks:

library("data.table")
# group by levels of task_id, return columns:
# - learner_id
# - rank of col '-auc_train' (within each task)
# - rank of col '-auc_test' (within each task)
ranks = tab[, .(learner_id, rank_train = rank(-auc_train), rank_test = rank(-auc_test)), by = task_id]
print(ranks)
##          task_id          learner_id rank_train rank_test
## 1:          spam      classif.ranger          1         1
## 2:          spam       classif.rpart          2         2
## 3:          spam classif.featureless          3         3
## 4: german_credit      classif.ranger          1         1
## 5: german_credit       classif.rpart          2         2
## 6: german_credit classif.featureless          3         3
## 7:         sonar      classif.ranger          1         1
## 8:         sonar       classif.rpart          2         2
## 9:         sonar classif.featureless          3         3
# group by levels of learner_id, return columns:
# - mean rank of col 'rank_train' (per level of learner_id)
# - mean rank of col 'rank_test' (per level of learner_id)
ranks = ranks[, .(mrank_train = mean(rank_train), mrank_test = mean(rank_test)), by = learner_id]

# print the final table, ordered by mean rank of AUC test
ranks[order(mrank_test)]
##             learner_id mrank_train mrank_test
## 1:      classif.ranger           1          1
## 2:       classif.rpart           2          2
## 3: classif.featureless           3          3

We can also visualize the benchmark results:

autoplot(bmr) + ggplot2::theme(axis.text.x = ggplot2::element_text(angle = 45, hjust = 1))

We can also plot ROC (receiver operating characteristic) curves, here for the german_credit task only:

bmr_small = bmr$clone()$filter(task_id = "german_credit")
autoplot(bmr_small, type = "roc")

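Extracting resample results: a BenchmarkResult object is essentially a collection of several ResampleResult objects. Since these are stored in a column of the data.table returned by aggregate(), we can easily extract them: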

tab = bmr$aggregate(measures)
rr = tab[task_id == "german_credit" & learner_id == "classif.ranger"]$resample_result[[1]]
print(rr)
## <ResampleResult> of 3 iterations
## * Task: german_credit
## * Learner: classif.ranger
## * Warnings: 0 in 0 iterations
## * Errors: 0 in 0 iterations

Inspect the AUC:

measure = msr("classif.auc")
rr$aggregate(measure)
## classif.auc 
##   0.7986295
# get the iteration with worst AUC
perf = rr$score(measure)
i = which.min(perf$classif.auc)

# get the corresponding learner and training set
print(rr$learners[[i]])
## <LearnerClassifRanger:classif.ranger>
## * Model: -
## * Parameters: num.threads=1
## * Packages: mlr3, mlr3learners, ranger
## * Predict Type: prob
## * Feature types: logical, integer, numeric, character, factor, ordered
## * Properties: hotstart_backward, importance, multiclass, oob_error,
##   twoclass, weights
head(rr$resampling$train_set(i))
## [1]  1  5  7  8 10 13

Converting and merging

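A ResampleResult can be converted to a BenchmarkResult with the function as_benchmark_result(). We can also combine two BenchmarkResults into one larger result object, for example when two related benchmarks were run on different machines.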

task = tsk("iris")
resampling = rsmp("holdout")$instantiate(task)

rr1 = resample(task, lrn("classif.rpart"), resampling)
## INFO  [23:35:36.425] [mlr3] Applying learner 'classif.rpart' on task 'iris' (iter 1/1)
rr2 = resample(task, lrn("classif.featureless"), resampling)
## INFO  [23:35:36.450] [mlr3] Applying learner 'classif.featureless' on task 'iris' (iter 1/1)
# Cast both ResampleResults to BenchmarkResults
bmr1 = as_benchmark_result(rr1)
bmr2 = as_benchmark_result(rr2)

# Merge 2nd BMR into the first BMR
bmr1$combine(bmr2)

bmr1
## <BenchmarkResult> of 2 rows with 2 resampling runs
##  nr task_id          learner_id resampling_id iters warnings errors
##   1    iris       classif.rpart       holdout     1        0      0
##   2    iris classif.featureless       holdout     1        0      0

Model optimization

Models can be optimized from several angles, for example by selecting better features or by tuning hyperparameters.

Hyperparameter tuning

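Hyperparameter tuning is supported by the mlr3tuning extension package, which is built around two concepts:

  1. TuningInstanceSingleCrit and TuningInstanceMultiCrit describe the tuning problem and store the results, and
  2. Tuner is the base class for implementations of tuning algorithms.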

Let's look at a simple example. First, define the task:

library("mlr3verse")
task = tsk("pima")
print(task)
## <TaskClassif:pima> (768 x 9): Pima Indian Diabetes
## * Target: diabetes
## * Properties: twoclass
## * Features (8):
##   - dbl (8): age, glucose, insulin, mass, pedigree, pregnant, pressure,
##     triceps

Then define the learner and inspect its hyperparameters:

learner = lrn("classif.rpart")
learner$param_set
## <ParamSet>
##                 id    class lower upper nlevels        default value
##  1:             cp ParamDbl     0     1     Inf           0.01      
##  2:     keep_model ParamLgl    NA    NA       2          FALSE      
##  3:     maxcompete ParamInt     0   Inf     Inf              4      
##  4:       maxdepth ParamInt     1    30      30             30      
##  5:   maxsurrogate ParamInt     0   Inf     Inf              5      
##  6:      minbucket ParamInt     1   Inf     Inf <NoDefault[3]>      
##  7:       minsplit ParamInt     1   Inf     Inf             20      
##  8: surrogatestyle ParamInt     0     1       2              0      
##  9:   usesurrogate ParamInt     0     2       3              2      
## 10:           xval ParamInt     0   Inf     Inf             10     0

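Here we choose to tune two hyperparameters:

  1. cp, the complexity parameter: a split is only attempted if it improves the overall fit by at least a factor of cp, so it controls when the learner considers adding another branch.
  2. minsplit, the minimum number of observations that must exist in a node before a split is attempted.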

The tuning space has to be bounded by a lower and an upper limit for each hyperparameter:

search_space = ps(
  cp = p_dbl(lower = 0.001, upper = 0.1),
  minsplit = p_int(lower = 1, upper = 10)
)
search_space
## <ParamSet>
##          id    class lower upper nlevels        default value
## 1:       cp ParamDbl 0.001   0.1     Inf <NoDefault[3]>      
## 2: minsplit ParamInt 1.000  10.0      10 <NoDefault[3]>

Next, we need to specify how the performance of the trained models is evaluated:

hout = rsmp("holdout")
measure = msr("classif.ce")

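Finally, we have to specify when the tuning should stop. This is a crucial step, because exhaustively evaluating all possible hyperparameter configurations is usually infeasible.

mlr3 lets you specify complex termination criteria by choosing one of the available Terminators:

  1. Terminate after a given amount of time (TerminatorClockTime).
  2. Terminate after a given number of evaluations (TerminatorEvals).
  3. Terminate once a specific performance is reached (TerminatorPerfReached).
  4. Terminate when the tuning has not found a better configuration for a given number of iterations (TerminatorStagnation).
  5. Any combination of the above, in ALL or ANY fashion (TerminatorCombo).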
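Terminators are stopping rules checked after each evaluation. The base-R sketch below runs a random search that stops as soon as either of two simplified rules fires: an evaluation budget (analogous to TerminatorEvals) or `patience` consecutive evaluations without improvement (a simplified analogue of TerminatorStagnation). Stopping on whichever rule fires first corresponds to combining them in ANY fashion. The toy objective and all helper names are hypothetical; the sketch does not call mlr3tuning.

```r
# Hypothetical toy objective standing in for "resample and return classif.ce";
# its minimum is at cp = 0.02, minsplit = 5.
toy_objective <- function(cp, minsplit) (cp - 0.02)^2 + 0.001 * (minsplit - 5)^2

random_search <- function(n_evals = 20, patience = 5, seed = 1) {
  set.seed(seed)
  cps <- numeric(0); minsplits <- integer(0); ys <- numeric(0)
  best <- Inf
  since_improvement <- 0
  for (i in seq_len(n_evals)) {          # budget: at most n_evals evaluations
    cp <- runif(1, min = 0.001, max = 0.1)
    minsplit <- sample(1:10, 1)
    y <- toy_objective(cp, minsplit)
    cps <- c(cps, cp); minsplits <- c(minsplits, minsplit); ys <- c(ys, y)
    if (y < best) {
      best <- y
      since_improvement <- 0
    } else {
      since_improvement <- since_improvement + 1
    }
    if (since_improvement >= patience) break  # stagnation rule
  }
  data.frame(cp = cps, minsplit = minsplits, y = ys)
}

archive <- random_search(n_evals = 20, patience = 5)
nrow(archive)                    # evaluations actually spent (at most 20)
archive[which.min(archive$y), ]  # best configuration found
```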
library("mlr3tuning")
## Warning: package 'mlr3tuning' was built under R version 4.1.2
## Loading required package: paradox
## Warning: package 'paradox' was built under R version 4.1.2
evals20 = trm("evals", n_evals = 20)

instance = TuningInstanceSingleCrit$new(
  task = task,
  learner = learner,
  resampling = hout,
  measure = measure,
  search_space = search_space,
  terminator = evals20
)
instance
## <TuningInstanceSingleCrit>
## * State:  Not optimized
## * Objective: <ObjectiveTuning:classif.rpart_on_pima>
## * Search Space:
##          id    class lower upper nlevels
## 1:       cp ParamDbl 0.001   0.1     Inf
## 2: minsplit ParamInt 1.000  10.0      10
## * Terminator: <TerminatorEvals>

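To start tuning, we still need to choose how to optimize; in other words, we need to choose an optimization algorithm via the Tuner class.

The following algorithms are currently implemented in mlr3tuning:

  1. Grid search (TunerGridSearch)
  2. Random search (TunerRandomSearch) (Bergstra and Bengio 2012)
  3. Generalized simulated annealing (TunerGenSA)
  4. Non-linear optimization (TunerNLoptr)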

We will use a simple grid search with a resolution of 5:

tuner = tnr("grid_search", resolution = 5)

Since we only have numeric parameters, TunerGridSearch creates an equidistant grid between the respective lower and upper bounds.
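For cp, resolution 5 means five equidistant points between 0.001 and 0.1, which are exactly the cp values that appear in the tuning log. The base-R sketch below rebuilds that grid; for the integer parameter minsplit it simply takes the values seen in the log (1, 3, 5, 8, 10) rather than reproducing how paradox discretizes integer ranges. The full grid has 25 points, so with n_evals = 20 the terminator stops the search before every point is evaluated. The random visiting order is an assumption consistent with the shuffled order in the log.

```r
# cp grid: 5 equidistant points between the bounds of the search space
cp_grid <- seq(0.001, 0.1, length.out = 5)
cp_grid   # 0.00100 0.02575 0.05050 0.07525 0.10000

# minsplit values as they appear in the tuning log (integer parameter)
minsplit_grid <- c(1, 3, 5, 8, 10)

grid <- expand.grid(cp = cp_grid, minsplit = minsplit_grid)
nrow(grid)  # 25 candidate configurations

# visit the grid in random order and stop after 20 evaluations
set.seed(1)
visited <- grid[sample(nrow(grid), 20), ]
```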

Now we can start tuning:

tuner$optimize(instance)
## INFO  [23:35:36.619] [bbotk] Starting to optimize 2 parameter(s) with '<TunerGridSearch>' and '<TerminatorEvals> [n_evals=20, k=0]' 
## INFO  [23:35:36.622] [bbotk] Evaluating 1 configuration(s) 
## INFO  [23:35:36.630] [mlr3] Running benchmark with 1 resampling iterations 
## INFO  [23:35:36.634] [mlr3] Applying learner 'classif.rpart' on task 'pima' (iter 1/1) 
## INFO  [23:35:36.646] [mlr3] Finished benchmark 
## INFO  [23:35:36.661] [bbotk] Result of batch 1: 
## INFO  [23:35:36.662] [bbotk]       cp minsplit classif.ce warnings errors runtime_learners 
## INFO  [23:35:36.662] [bbotk]  0.02575        1  0.2382812        0      0            0.008 
## INFO  [23:35:36.662] [bbotk]                                 uhash 
## INFO  [23:35:36.662] [bbotk]  bfeccc54-6dbc-433d-a50c-7253c4670497 
## INFO  [23:35:36.663] [bbotk] Evaluating 1 configuration(s) 
## INFO  [23:35:36.670] [mlr3] Running benchmark with 1 resampling iterations 
## INFO  [23:35:36.675] [mlr3] Applying learner 'classif.rpart' on task 'pima' (iter 1/1) 
## INFO  [23:35:36.689] [mlr3] Finished benchmark 
## INFO  [23:35:36.707] [bbotk] Result of batch 2: 
## INFO  [23:35:36.708] [bbotk]   cp minsplit classif.ce warnings errors runtime_learners 
## INFO  [23:35:36.708] [bbotk]  0.1        5  0.2773438        0      0            0.008 
## INFO  [23:35:36.708] [bbotk]                                 uhash 
## INFO  [23:35:36.708] [bbotk]  7e462295-d7b9-4fa6-8f36-177c8ddcf206 
## INFO  [23:35:36.709] [bbotk] Evaluating 1 configuration(s) 
## INFO  [23:35:36.717] [mlr3] Running benchmark with 1 resampling iterations 
## INFO  [23:35:36.722] [mlr3] Applying learner 'classif.rpart' on task 'pima' (iter 1/1) 
## INFO  [23:35:36.738] [mlr3] Finished benchmark 
## INFO  [23:35:36.754] [bbotk] Result of batch 3: 
## INFO  [23:35:36.755] [bbotk]      cp minsplit classif.ce warnings errors runtime_learners 
## INFO  [23:35:36.755] [bbotk]  0.0505        8  0.2773438        0      0            0.009 
## INFO  [23:35:36.755] [bbotk]                                 uhash 
## INFO  [23:35:36.755] [bbotk]  5c8e97eb-e9fc-4280-963d-63631afe7dd6 
## INFO  [23:35:36.756] [bbotk] Evaluating 1 configuration(s) 
## INFO  [23:35:36.763] [mlr3] Running benchmark with 1 resampling iterations 
## INFO  [23:35:36.767] [mlr3] Applying learner 'classif.rpart' on task 'pima' (iter 1/1) 
## INFO  [23:35:36.779] [mlr3] Finished benchmark 
## INFO  [23:35:36.795] [bbotk] Result of batch 4: 
## INFO  [23:35:36.796] [bbotk]       cp minsplit classif.ce warnings errors runtime_learners 
## INFO  [23:35:36.796] [bbotk]  0.02575       10  0.2382812        0      0            0.007 
## INFO  [23:35:36.796] [bbotk]                                 uhash 
## INFO  [23:35:36.796] [bbotk]  54f3a8bc-f1f2-47fa-98ff-d1b8c5466eb9 
## INFO  [23:35:36.797] [bbotk] Evaluating 1 configuration(s) 
## INFO  [23:35:36.804] [mlr3] Running benchmark with 1 resampling iterations 
## INFO  [23:35:36.808] [mlr3] Applying learner 'classif.rpart' on task 'pima' (iter 1/1) 
## INFO  [23:35:36.822] [mlr3] Finished benchmark 
## INFO  [23:35:36.838] [bbotk] Result of batch 5: 
## INFO  [23:35:36.839] [bbotk]       cp minsplit classif.ce warnings errors runtime_learners 
## INFO  [23:35:36.839] [bbotk]  0.07525        8  0.2773438        0      0            0.007 
## INFO  [23:35:36.839] [bbotk]                                 uhash 
## INFO  [23:35:36.839] [bbotk]  6021e15b-fe4c-4d0a-9472-8974d76ac6a8 
## INFO  [23:35:36.840] [bbotk] Evaluating 1 configuration(s) 
## INFO  [23:35:36.847] [mlr3] Running benchmark with 1 resampling iterations 
## INFO  [23:35:36.851] [mlr3] Applying learner 'classif.rpart' on task 'pima' (iter 1/1) 
## INFO  [23:35:36.869] [mlr3] Finished benchmark 
## INFO  [23:35:36.886] [bbotk] Result of batch 6: 
## INFO  [23:35:36.887] [bbotk]       cp minsplit classif.ce warnings errors runtime_learners 
## INFO  [23:35:36.887] [bbotk]  0.07525        1  0.2773438        0      0            0.007 
## INFO  [23:35:36.887] [bbotk]                                 uhash 
## INFO  [23:35:36.887] [bbotk]  f636a86b-7128-4809-9fb0-d3a25161a67f 
## INFO  [23:35:36.888] [bbotk] Evaluating 1 configuration(s) 
## INFO  [23:35:36.895] [mlr3] Running benchmark with 1 resampling iterations 
## INFO  [23:35:36.899] [mlr3] Applying learner 'classif.rpart' on task 'pima' (iter 1/1) 
## INFO  [23:35:36.911] [mlr3] Finished benchmark 
## INFO  [23:35:36.926] [bbotk] Result of batch 7: 
## INFO  [23:35:36.927] [bbotk]       cp minsplit classif.ce warnings errors runtime_learners 
## INFO  [23:35:36.927] [bbotk]  0.02575        8  0.2382812        0      0            0.007 
## INFO  [23:35:36.927] [bbotk]                                 uhash 
## INFO  [23:35:36.927] [bbotk]  3be12953-ca21-4d41-93d5-623bfae1b72d 
## INFO  [23:35:36.928] [bbotk] Evaluating 1 configuration(s) 
## INFO  [23:35:36.935] [mlr3] Running benchmark with 1 resampling iterations 
## INFO  [23:35:36.939] [mlr3] Applying learner 'classif.rpart' on task 'pima' (iter 1/1) 
## INFO  [23:35:36.954] [mlr3] Finished benchmark 
## INFO  [23:35:36.970] [bbotk] Result of batch 8: 
## INFO  [23:35:36.971] [bbotk]     cp minsplit classif.ce warnings errors runtime_learners 
## INFO  [23:35:36.971] [bbotk]  0.001       10  0.2890625        0      0            0.009 
## INFO  [23:35:36.971] [bbotk]                                 uhash 
## INFO  [23:35:36.971] [bbotk]  9f3a7b8f-b6b8-4680-afdf-0ed2f907c66a 
## INFO  [23:35:36.971] [bbotk] Evaluating 1 configuration(s) 
## INFO  [23:35:36.978] [mlr3] Running benchmark with 1 resampling iterations 
## INFO  [23:35:36.982] [mlr3] Applying learner 'classif.rpart' on task 'pima' (iter 1/1) 
## INFO  [23:35:36.994] [mlr3] Finished benchmark 
## INFO  [23:35:37.010] [bbotk] Result of batch 9: 
## INFO  [23:35:37.011] [bbotk]      cp minsplit classif.ce warnings errors runtime_learners 
## INFO  [23:35:37.011] [bbotk]  0.0505        1  0.2773438        0      0            0.007 
## INFO  [23:35:37.011] [bbotk]                                 uhash 
## INFO  [23:35:37.011] [bbotk]  c26e62a2-496e-4093-961a-fbf953cbad76 
## INFO  [23:35:37.012] [bbotk] Evaluating 1 configuration(s) 
## INFO  [23:35:37.019] [mlr3] Running benchmark with 1 resampling iterations 
## INFO  [23:35:37.023] [mlr3] Applying learner 'classif.rpart' on task 'pima' (iter 1/1) 
## INFO  [23:35:37.037] [mlr3] Finished benchmark 
## INFO  [23:35:37.053] [bbotk] Result of batch 10: 
## INFO  [23:35:37.055] [bbotk]      cp minsplit classif.ce warnings errors runtime_learners 
## INFO  [23:35:37.055] [bbotk]  0.0505       10  0.2773438        0      0            0.008 
## INFO  [23:35:37.055] [bbotk]                                 uhash 
## INFO  [23:35:37.055] [bbotk]  48b683c6-e936-45b1-9a75-9c4e87687d7d 
## INFO  [23:35:37.055] [bbotk] Evaluating 1 configuration(s) 
## INFO  [23:35:37.063] [mlr3] Running benchmark with 1 resampling iterations 
## INFO  [23:35:37.068] [mlr3] Applying learner 'classif.rpart' on task 'pima' (iter 1/1) 
## INFO  [23:35:37.084] [mlr3] Finished benchmark 
## INFO  [23:35:37.100] [bbotk] Result of batch 11: 
## INFO  [23:35:37.101] [bbotk]       cp minsplit classif.ce warnings errors runtime_learners 
## INFO  [23:35:37.101] [bbotk]  0.02575        3  0.2382812        0      0            0.008 
## INFO  [23:35:37.101] [bbotk]                                 uhash 
## INFO  [23:35:37.101] [bbotk]  503cd25f-3d6d-4f86-afa3-c4dc127ae332 
## INFO  [23:35:37.102] [bbotk] Evaluating 1 configuration(s) 
## INFO  [23:35:37.109] [mlr3] Running benchmark with 1 resampling iterations 
## INFO  [23:35:37.114] [mlr3] Applying learner 'classif.rpart' on task 'pima' (iter 1/1) 
## INFO  [23:35:37.128] [mlr3] Finished benchmark 
## INFO  [23:35:37.144] [bbotk] Result of batch 12: 
## INFO  [23:35:37.145] [bbotk]     cp minsplit classif.ce warnings errors runtime_learners 
## INFO  [23:35:37.145] [bbotk]  0.001        1  0.3046875        0      0            0.009 
## INFO  [23:35:37.145] [bbotk]                                 uhash 
## INFO  [23:35:37.145] [bbotk]  778da658-0d85-46e1-92cf-6da057a5cf45 
## INFO  [23:35:37.146] [bbotk] Evaluating 1 configuration(s) 
## INFO  [23:35:37.153] [mlr3] Running benchmark with 1 resampling iterations 
## INFO  [23:35:37.157] [mlr3] Applying learner 'classif.rpart' on task 'pima' (iter 1/1) 
## INFO  [23:35:37.170] [mlr3] Finished benchmark 
## INFO  [23:35:37.188] [bbotk] Result of batch 13: 
## INFO  [23:35:37.189] [bbotk]     cp minsplit classif.ce warnings errors runtime_learners 
## INFO  [23:35:37.189] [bbotk]  0.001        3   0.296875        0      0            0.008 
## INFO  [23:35:37.189] [bbotk]                                 uhash 
## INFO  [23:35:37.189] [bbotk]  5a373c68-3d1e-4396-b258-44cd30143129 
## INFO  [23:35:37.189] [bbotk] Evaluating 1 configuration(s) 
## INFO  [23:35:37.197] [mlr3] Running benchmark with 1 resampling iterations 
## INFO  [23:35:37.201] [mlr3] Applying learner 'classif.rpart' on task 'pima' (iter 1/1) 
## INFO  [23:35:37.214] [mlr3] Finished benchmark 
## INFO  [23:35:37.230] [bbotk] Result of batch 14: 
## INFO  [23:35:37.231] [bbotk]     cp minsplit classif.ce warnings errors runtime_learners 
## INFO  [23:35:37.231] [bbotk]  0.001        5  0.2773438        0      0            0.007 
## INFO  [23:35:37.231] [bbotk]                                 uhash 
## INFO  [23:35:37.231] [bbotk]  f6d1f72f-73f6-41f5-bbbe-a2693a7ea8c0 
## INFO  [23:35:37.232] [bbotk] Evaluating 1 configuration(s) 
## INFO  [23:35:37.240] [mlr3] Running benchmark with 1 resampling iterations 
## INFO  [23:35:37.244] [mlr3] Applying learner 'classif.rpart' on task 'pima' (iter 1/1) 
## INFO  [23:35:37.256] [mlr3] Finished benchmark 
## INFO  [23:35:37.273] [bbotk] Result of batch 15: 
## INFO  [23:35:37.274] [bbotk]       cp minsplit classif.ce warnings errors runtime_learners 
## INFO  [23:35:37.274] [bbotk]  0.07525       10  0.2773438        0      0            0.007 
## INFO  [23:35:37.274] [bbotk]                                 uhash 
## INFO  [23:35:37.274] [bbotk]  efa79e53-6047-4ed6-9255-55da95576fa9 
## INFO  [23:35:37.275] [bbotk] Evaluating 1 configuration(s) 
## INFO  [23:35:37.282] [mlr3] Running benchmark with 1 resampling iterations 
## INFO  [23:35:37.291] [mlr3] Applying learner 'classif.rpart' on task 'pima' (iter 1/1) 
## INFO  [23:35:37.303] [mlr3] Finished benchmark 
## INFO  [23:35:37.318] [bbotk] Result of batch 16: 
## INFO  [23:35:37.319] [bbotk]      cp minsplit classif.ce warnings errors runtime_learners 
## INFO  [23:35:37.319] [bbotk]  0.0505        5  0.2773438        0      0            0.007 
## INFO  [23:35:37.319] [bbotk]                                 uhash 
## INFO  [23:35:37.319] [bbotk]  980d871c-4801-4b1e-b0af-276442f1b2f4 
## INFO  [23:35:37.320] [bbotk] Evaluating 1 configuration(s) 
## INFO  [23:35:37.327] [mlr3] Running benchmark with 1 resampling iterations 
## INFO  [23:35:37.332] [mlr3] Applying learner 'classif.rpart' on task 'pima' (iter 1/1) 
## INFO  [23:35:37.345] [mlr3] Finished benchmark 
## INFO  [23:35:37.361] [bbotk] Result of batch 17: 
## INFO  [23:35:37.362] [bbotk]      cp minsplit classif.ce warnings errors runtime_learners 
## INFO  [23:35:37.362] [bbotk]  0.0505        3  0.2773438        0      0            0.008 
## INFO  [23:35:37.362] [bbotk]                                 uhash 
## INFO  [23:35:37.362] [bbotk]  35ece802-7247-4269-a797-76cee3b9a0af 
## INFO  [23:35:37.362] [bbotk] Evaluating 1 configuration(s) 
## INFO  [23:35:37.369] [mlr3] Running benchmark with 1 resampling iterations 
## INFO  [23:35:37.373] [mlr3] Applying learner 'classif.rpart' on task 'pima' (iter 1/1) 
## INFO  [23:35:37.385] [mlr3] Finished benchmark 
## INFO  [23:35:37.401] [bbotk] Result of batch 18: 
## INFO  [23:35:37.402] [bbotk]   cp minsplit classif.ce warnings errors runtime_learners 
## INFO  [23:35:37.402] [bbotk]  0.1       10  0.2773438        0      0            0.007 
## INFO  [23:35:37.402] [bbotk]                                 uhash 
## INFO  [23:35:37.402] [bbotk]  a494383c-9d09-40c5-a537-08791be99302 
## INFO  [23:35:37.403] [bbotk] Evaluating 1 configuration(s) 
## INFO  [23:35:37.411] [mlr3] Running benchmark with 1 resampling iterations 
## INFO  [23:35:37.415] [mlr3] Applying learner 'classif.rpart' on task 'pima' (iter 1/1) 
## INFO  [23:35:37.428] [mlr3] Finished benchmark 
## INFO  [23:35:37.445] [bbotk] Result of batch 19: 
## INFO  [23:35:37.446] [bbotk]   cp minsplit classif.ce warnings errors runtime_learners 
## INFO  [23:35:37.446] [bbotk]  0.1        1  0.2773438        0      0            0.009 
## INFO  [23:35:37.446] [bbotk]                                 uhash 
## INFO  [23:35:37.446] [bbotk]  2824c2f3-e732-4240-9e9c-109dc3b52b3d 
## INFO  [23:35:37.447] [bbotk] Evaluating 1 configuration(s) 
## INFO  [23:35:37.454] [mlr3] Running benchmark with 1 resampling iterations 
## INFO  [23:35:37.458] [mlr3] Applying learner 'classif.rpart' on task 'pima' (iter 1/1) 
## INFO  [23:35:37.471] [mlr3] Finished benchmark 
## INFO  [23:35:37.488] [bbotk] Result of batch 20: 
## INFO  [23:35:37.489] [bbotk]     cp minsplit classif.ce warnings errors runtime_learners 
## INFO  [23:35:37.489] [bbotk]  0.001        8  0.2773438        0      0            0.008 
## INFO  [23:35:37.489] [bbotk]                                 uhash 
## INFO  [23:35:37.489] [bbotk]  b70e2328-6018-4f69-95fb-6c5c1efa9da5 
## INFO  [23:35:37.497] [bbotk] Finished optimizing after 20 evaluation(s) 
## INFO  [23:35:37.498] [bbotk] Result: 
## INFO  [23:35:37.498] [bbotk]       cp minsplit learner_param_vals  x_domain classif.ce 
## INFO  [23:35:37.498] [bbotk]  0.02575        1          <list[3]> <list[2]>  0.2382812
##         cp minsplit learner_param_vals  x_domain classif.ce
## 1: 0.02575        1          <list[3]> <list[2]>  0.2382812

We can inspect the best hyperparameter configuration found and its classification error:

instance$result_learner_param_vals
## $xval
## [1] 0
## 
## $cp
## [1] 0.02575
## 
## $minsplit
## [1] 1
instance$result_y
## classif.ce 
##  0.2382812
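Why is the best error exactly 0.2382812? Here is a small arithmetic check. It assumes the tuning used mlr3's default holdout split (`ratio = 2/3`) on the 768 rows of the pima task, which leaves 256 test observations. The split size is inferred from the printed errors and is not shown in this section. Under that assumption every holdout error is a multiple of 1/256.

```r
# Assumption: default holdout ratio 2/3 on the 768-row pima task
n_total <- 768
n_test <- n_total - round(n_total * 2 / 3)  # 256 test observations

# Best classification error reported in instance$result_y
best_ce <- 0.2382812

# Number of misclassified test observations implied by that error
n_wrong <- round(best_ce * n_test)
n_wrong            # 61
n_wrong / n_test   # 61/256 = 0.23828125
```

Because the error can only move in steps of 1/256, many configurations can end up with identical scores.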

All evaluations that were performed are stored in the archive:

as.data.table(instance$archive)
##          cp minsplit classif.ce x_domain_cp x_domain_minsplit runtime_learners
##  1: 0.02575        1  0.2382812     0.02575                 1            0.008
##  2: 0.10000        5  0.2773438     0.10000                 5            0.008
##  3: 0.05050        8  0.2773438     0.05050                 8            0.009
##  4: 0.02575       10  0.2382812     0.02575                10            0.007
##  5: 0.07525        8  0.2773438     0.07525                 8            0.007
##  6: 0.07525        1  0.2773438     0.07525                 1            0.007
##  7: 0.02575        8  0.2382812     0.02575                 8            0.007
##  8: 0.00100       10  0.2890625     0.00100                10            0.009
##  9: 0.05050        1  0.2773438     0.05050                 1            0.007
## 10: 0.05050       10  0.2773438     0.05050                10            0.008
## 11: 0.02575        3  0.2382812     0.02575                 3            0.008
## 12: 0.00100        1  0.3046875     0.00100                 1            0.009
## 13: 0.00100        3  0.2968750     0.00100                 3            0.008
## 14: 0.00100        5  0.2773438     0.00100                 5            0.007
## 15: 0.07525       10  0.2773438     0.07525                10            0.007
## 16: 0.05050        5  0.2773438     0.05050                 5            0.007
## 17: 0.05050        3  0.2773438     0.05050                 3            0.008
## 18: 0.10000       10  0.2773438     0.10000                10            0.007
## 19: 0.10000        1  0.2773438     0.10000                 1            0.009
## 20: 0.00100        8  0.2773438     0.00100                 8            0.008
##               timestamp batch_nr warnings errors      resample_result
##  1: 2022-04-27 23:35:36        1        0      0 <ResampleResult[22]>
##  2: 2022-04-27 23:35:36        2        0      0 <ResampleResult[22]>
##  3: 2022-04-27 23:35:36        3        0      0 <ResampleResult[22]>
##  4: 2022-04-27 23:35:36        4        0      0 <ResampleResult[22]>
##  5: 2022-04-27 23:35:36        5        0      0 <ResampleResult[22]>
##  6: 2022-04-27 23:35:36        6        0      0 <ResampleResult[22]>
##  7: 2022-04-27 23:35:36        7        0      0 <ResampleResult[22]>
##  8: 2022-04-27 23:35:36        8        0      0 <ResampleResult[22]>
##  9: 2022-04-27 23:35:37        9        0      0 <ResampleResult[22]>
## 10: 2022-04-27 23:35:37       10        0      0 <ResampleResult[22]>
## 11: 2022-04-27 23:35:37       11        0      0 <ResampleResult[22]>
## 12: 2022-04-27 23:35:37       12        0      0 <ResampleResult[22]>
## 13: 2022-04-27 23:35:37       13        0      0 <ResampleResult[22]>
## 14: 2022-04-27 23:35:37       14        0      0 <ResampleResult[22]>
## 15: 2022-04-27 23:35:37       15        0      0 <ResampleResult[22]>
## 16: 2022-04-27 23:35:37       16        0      0 <ResampleResult[22]>
## 17: 2022-04-27 23:35:37       17        0      0 <ResampleResult[22]>
## 18: 2022-04-27 23:35:37       18        0      0 <ResampleResult[22]>
## 19: 2022-04-27 23:35:37       19        0      0 <ResampleResult[22]>
## 20: 2022-04-27 23:35:37       20        0      0 <ResampleResult[22]>
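The reported optimum is the configuration with the lowest `classif.ce` in this archive. The following minimal base-R sketch uses a few rows copied from the table above. Several configurations tie at 0.2382812, and `which.min()` returns the first of them, which here matches `instance$result`. This only mimics the selection for illustration. It is not bbotk's internal code.

```r
# A subset of the archive above (rows 1, 2, 4, 7, 8, 11, 12)
archive <- data.frame(
  cp         = c(0.02575, 0.10000, 0.02575, 0.02575, 0.00100, 0.02575, 0.00100),
  minsplit   = c(1, 5, 10, 8, 10, 3, 1),
  classif.ce = c(0.2382812, 0.2773438, 0.2382812, 0.2382812,
                 0.2890625, 0.2382812, 0.3046875)
)

# All configurations that reach the minimum error (ties)
archive[archive$classif.ce == min(archive$classif.ce), ]

# which.min() returns the first row among the ties
best <- archive[which.min(archive$classif.ce), ]
best
```

In the full archive, all four rows with `cp = 0.02575` reach the same error regardless of `minsplit`. This suggests that `cp` dominates on this particular split.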

The corresponding resampling iterations can be accessed through the BenchmarkResult stored in the tuning instance's archive:

instance$archive$benchmark_result
## <BenchmarkResult> of 20 rows with 20 resampling runs
##  nr task_id    learner_id resampling_id iters warnings errors
##   1    pima classif.rpart       holdout     1        0      0
##   2    pima classif.rpart       holdout     1        0      0
##   3    pima classif.rpart       holdout     1        0      0
##   4    pima classif.rpart       holdout     1        0      0
##   5    pima classif.rpart       holdout     1        0      0
##   6    pima classif.rpart       holdout     1        0      0
##   7    pima classif.rpart       holdout     1        0      0
##   8    pima classif.rpart       holdout     1        0      0
##   9    pima classif.rpart       holdout     1        0      0
##  10    pima classif.rpart       holdout     1        0      0
##  11    pima classif.rpart       holdout     1        0      0
##  12    pima classif.rpart       holdout     1        0      0
##  13    pima classif.rpart       holdout     1        0      0
##  14    pima classif.rpart       holdout     1        0      0
##  15    pima classif.rpart       holdout     1        0      0
##  16    pima classif.rpart       holdout     1        0      0
##  17    pima classif.rpart       holdout     1        0      0
##  18    pima classif.rpart       holdout     1        0      0
##  19    pima classif.rpart       holdout     1        0      0
##  20    pima classif.rpart       holdout     1        0      0
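A BenchmarkResult can be summarized per resampling run with `$aggregate()`. The tuning instance cannot be rebuilt from this section alone. Instead, the sketch below runs a small stand-alone benchmark on the same pima task and compares `classif.rpart` with the `classif.featureless` baseline. That learner choice is illustrative only.

```r
library(mlr3)

set.seed(1)
task <- tsk("pima")

# Two learners evaluated on one holdout split, like the tuning runs above
design <- benchmark_grid(
  tasks = task,
  learners = lrns(c("classif.rpart", "classif.featureless")),
  resamplings = rsmp("holdout")
)
bmr <- benchmark(design)

# One row per resampling run, with the aggregated classification error
agg <- bmr$aggregate(msr("classif.ce"))
agg[, c("learner_id", "classif.ce")]
```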

Score the predictions of each evaluated model, here using classification accuracy (`classif.acc`):

instance$archive$benchmark_result$score(msr("classif.acc"))
##                                    uhash nr              task task_id
##  1: bfeccc54-6dbc-433d-a50c-7253c4670497  1 <TaskClassif[50]>    pima
##  2: 7e462295-d7b9-4fa6-8f36-177c8ddcf206  2 <TaskClassif[50]>    pima
##  3: 5c8e97eb-e9fc-4280-963d-63631afe7dd6  3 <TaskClassif[50]>    pima
##  4: 54f3a8bc-f1f2-47fa-98ff-d1b8c5466eb9  4 <TaskClassif[50]>    pima
##  5: 6021e15b-fe4c-4d0a-9472-8974d76ac6a8  5 <TaskClassif[50]>    pima
##  6: f636a86b-7128-4809-9fb0-d3a25161a67f  6 <TaskClassif[50]>    pima
##  7: 3be12953-ca21-4d41-93d5-623bfae1b72d  7 <TaskClassif[50]>    pima
##  8: 9f3a7b8f-b6b8-4680-afdf-0ed2f907c66a  8 <TaskClassif[50]>    pima
##  9: c26e62a2-496e-4093-961a-fbf953cbad76  9 <TaskClassif[50]>    pima
## 10: 48b683c6-e936-45b1-9a75-9c4e87687d7d 10 <TaskClassif[50]>    pima
## 11: 503cd25f-3d6d-4f86-afa3-c4dc127ae332 11 <TaskClassif[50]>    pima
## 12: 778da658-0d85-46e1-92cf-6da057a5cf45 12 <TaskClassif[50]>    pima
## 13: 5a373c68-3d1e-4396-b258-44cd30143129 13 <TaskClassif[50]>    pima
## 14: f6d1f72f-73f6-41f5-bbbe-a2693a7ea8c0 14 <TaskClassif[50]>    pima
## 15: efa79e53-6047-4ed6-9255-55da95576fa9 15 <TaskClassif[50]>    pima
## 16: 980d871c-4801-4b1e-b0af-276442f1b2f4 16 <TaskClassif[50]>    pima
## 17: 35ece802-7247-4269-a797-76cee3b9a0af 17 <TaskClassif[50]>    pima
## 18: a494383c-9d09-40c5-a537-08791be99302 18 <TaskClassif[50]>    pima
## 19: 2824c2f3-e732-4240-9e9c-109dc3b52b3d 19 <TaskClassif[50]>    pima
## 20: b70e2328-6018-4f69-95fb-6c5c1efa9da5 20 <TaskClassif[50]>    pima
##                       learner    learner_id              resampling
##  1: <LearnerClassifRpart[38]> classif.rpart <ResamplingHoldout[20]>
##  2: <LearnerClassifRpart[38]> classif.rpart <ResamplingHoldout[20]>
##  3: <LearnerClassifRpart[38]> classif.rpart <ResamplingHoldout[20]>
##  4: <LearnerClassifRpart[38]> classif.rpart <ResamplingHoldout[20]>
##  5: <LearnerClassifRpart[38]> classif.rpart <ResamplingHoldout[20]>
##  6: <LearnerClassifRpart[38]> classif.rpart <ResamplingHoldout[20]>
##  7: <LearnerClassifRpart[38]> classif.rpart <ResamplingHoldout[20]>
##  8: <LearnerClassifRpart[38]> classif.rpart <ResamplingHoldout[20]>
##  9: <LearnerClassifRpart[38]> classif.rpart <ResamplingHoldout[20]>
## 10: <LearnerClassifRpart[38]> classif.rpart <ResamplingHoldout[20]>
## 11: <LearnerClassifRpart[38]> classif.rpart <ResamplingHoldout[20]>
## 12: <LearnerClassifRpart[38]> classif.rpart <ResamplingHoldout[20]>
## 13: <LearnerClassifRpart[38]> classif.rpart <ResamplingHoldout[20]>
## 14: <LearnerClassifRpart[38]> classif.rpart <ResamplingHoldout[20]>
## 15: <LearnerClassifRpart[38]> classif.rpart <ResamplingHoldout[20]>
## 16: <LearnerClassifRpart[38]> classif.rpart <ResamplingHoldout[20]>
## 17: <LearnerClassifRpart[38]> classif.rpart <ResamplingHoldout[20]>
## 18: <LearnerClassifRpart[38]> classif.rpart <ResamplingHoldout[20]>
## 19: <LearnerClassifRpart[38]> classif.rpart <ResamplingHoldout[20]>
## 20: <LearnerClassifRpart[38]> classif.rpart <ResamplingHoldout[20]>
##     resampling_id iteration              prediction classif.acc
##  1:       holdout         1 <PredictionClassif[20]>   0.7617188
##  2:       holdout         1 <PredictionClassif[20]>   0.7226562
##  3:       holdout         1 <PredictionClassif[20]>   0.7226562
##  4:       holdout         1 <PredictionClassif[20]>   0.7617188
##  5:       holdout         1 <PredictionClassif[20]>   0.7226562
##  6:       holdout         1 <PredictionClassif[20]>   0.7226562
##  7:       holdout         1 <PredictionClassif[20]>   0.7617188
##  8:       holdout         1 <PredictionClassif[20]>   0.7109375
##  9:       holdout         1 <PredictionClassif[20]>   0.7226562
## 10:       holdout         1 <PredictionClassif[20]>   0.7226562
## 11:       holdout         1 <PredictionClassif[20]>   0.7617188
## 12:       holdout         1 <PredictionClassif[20]>   0.6953125
## 13:       holdout         1 <PredictionClassif[20]>   0.7031250
## 14:       holdout         1 <PredictionClassif[20]>   0.7226562
## 15:       holdout         1 <PredictionClassif[20]>   0.7226562
## 16:       holdout         1 <PredictionClassif[20]>   0.7226562
## 17:       holdout         1 <PredictionClassif[20]>   0.7226562
## 18:       holdout         1 <PredictionClassif[20]>   0.7226562
## 19:       holdout         1 <PredictionClassif[20]>   0.7226562
## 20:       holdout         1 <PredictionClassif[20]>   0.7226562
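Accuracy and classification error are complements on the same test set (`classif.acc = 1 - classif.ce`), so this table ranks the configurations exactly as the archive did. The quick check below uses values copied from rows 1, 8 and 12 of the two tables:

```r
# classif.ce from the archive and classif.acc from the score table (rows 1, 8, 12)
ce  <- c(0.2382812, 0.2890625, 0.3046875)
acc <- c(0.7617188, 0.7109375, 0.6953125)

# Each pair sums to 1, up to the 7 printed digits
ce + acc
all.equal(acc, 1 - ce, tolerance = 1e-6)
```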

Now we can take the optimized hyperparameters, set them on the previously created Learner, and train it on the full dataset.

learner$param_set$values = instance$result_learner_param_vals
learner$train(task)
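If the tuning instance is no longer available, you can fit the same final model by setting the values directly. In this self-contained sketch, the three values are copied from `instance$result_learner_param_vals` above. After training, the fitted rpart object is stored in `learner$model`.

```r
library(mlr3)

task <- tsk("pima")
learner <- lrn("classif.rpart")

# Values copied from instance$result_learner_param_vals shown above
learner$param_set$values <- list(xval = 0L, cp = 0.02575, minsplit = 1L)

# Train on all rows of the task
learner$train(task)

# The underlying model fitted by rpart::rpart()
class(learner$model)
```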

The trained model can now be used to make predictions on new, external data.
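In mlr3 this is done with `$predict_newdata()`, which accepts a data.frame containing the task's feature columns. No real external data is available here, so the sketch uses the first five rows of pima as hypothetical "new" observations. In practice, replace `newdata` with your own data.

```r
library(mlr3)

task <- tsk("pima")
learner <- lrn("classif.rpart", cp = 0.02575, minsplit = 1L)
learner$train(task)

# Hypothetical "new" data: first 5 rows, feature columns only (no target)
newdata <- task$data(rows = 1:5, cols = task$feature_names)

prediction <- learner$predict_newdata(newdata)
prediction$response  # predicted classes, one of task$class_names
```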