はじめに

Rのmlrパッケージのチュートリアルにあるコードの写経をする. 説明もほとんど丸写しになるかもしれないが,翻訳ではないし, この文書の著者自身のメモや意見が含まれている.

mlrのバージョンは2015年10月現在の最新版である2.4を使う.

公式な情報源は以下の通り.

チュートリアルとヘルプは非常に充実しているので, 英語が苦にならない人はそちらを読むべきである.

mlrの特徴

Rには機械学習のための標準インターフェイスがない. mlrの目的は分類・回帰・クラスタリングや生存時間分析といった 機械学習タスクへの統一されたインターフェイスを提供することだ.

mlrは以下のような特徴を持つ.

インストール

いつも通りやればよい.

# CRAN
# install.packages("mlr")
# 開発版
# devtools::install_github("mlr-org/mlr")
library(mlr)
## Loading required package: BBmisc
## Loading required package: ggplot2
## Loading required package: ParamHelpers

まずは使ってみる

irisの判別分析をしてクロスバリデーションで誤分類率を評価する.

data(iris)

## タスクの定義
task = makeClassifTask(id = "tutorial", data = iris, target = "Species")
## 学習器の定義
lrn = makeLearner("classif.lda")
## リサンプリング法の定義
rdesc = makeResampleDesc(method = "CV", stratify = TRUE)
## 実行!
r = resample(learner = lrn, task = task, resampling = rdesc, show.info = FALSE)
## 誤分類率
r$aggr
## mmce.test.mean 
##           0.02

この例を見るだけでも非常に整理されたインターフェイスが提供されていそうなことが分かる.

基本的な機能

学習タスク

タスクの種類

タスクは学習に使うデータや問題の設定を持っておくためのオブジェクト. 具体的なタスクのクラスは以下のように色々あるが,皆Taskクラスを継承している.

  • 分類: ClassifTask
  • 回帰: RegrTask
  • 生存時間分析: SuvTask
  • コスト考慮型分類: CostSensTask
  • クラスタリング: ClusterTask

タスクを作成するには,makeClassifTaskのような関数を呼べばよい. make*関数には,学習データのデータフレームをdata引数で渡す必要がある. id引数でタスクの識別子になる文字列を指定できるが,必須ではない.

以下,具体例を見ていく.

回帰

教師あり学習では,make*関数のtarget引数で目的変数を指定する.

data(BostonHousing, package = "mlbench")
regr.task = makeRegrTask(id = "bh", data = BostonHousing, target = "medv")
regr.task
## Supervised task: bh
## Type: regr
## Target: medv
## Observations: 506
## Features:
## numerics  factors  ordered 
##       12        1        0 
## Missings: FALSE
## Has weights: FALSE
## Has blocking: FALSE

タスクオブジェクトがデータや問題の情報を持っていることが分かる.

分類

分類タスクの場合はtargetにする変数はファクタ型でなければならない.

data(BreastCancer, package = "mlbench")
df = BreastCancer
df$Id = NULL
classif.task = makeClassifTask(id = "BreastCancer", data = df, target = "Class")
# positiveなクラスを自分で決めたい場合
# classif.task = makeClassifTask(id = "BreastCancer", data = df, target = "Class", positive = "malignant")
classif.task
## Supervised task: BreastCancer
## Type: classif
## Target: Class
## Observations: 699
## Features:
## numerics  factors  ordered 
##        0        4        5 
## Missings: TRUE
## Has weights: FALSE
## Has blocking: FALSE
## Classes: 2
##    benign malignant 
##       458       241 
## Positive class: benign

データ中の各クラスの件数も集計してくれている.

クラスタリング

教師なし学習なのでtarget引数が不要.

data(mtcars, package = "datasets")
cluster.task = makeClusterTask(data = mtcars)
cluster.task
## Unsupervised task: mtcars
## Type: cluster
## Observations: 32
## Features:
## numerics  factors  ordered 
##       11        0        0 
## Missings: FALSE
## Has weights: FALSE
## Has blocking: FALSE

生存時間分析

生存時間分析では,targetに生存時間と打ち切り有無の2変数を渡す必要がある. 打ち切り有無の変数は論理型でなければならない.

data(lung, package = "survival")
lung$status = (lung$status == 2)
surv.task = makeSurvTask(data = lung, target = c("time", "status"))
surv.task
## Supervised task: lung
## Type: surv
## Target: time,status
## Observations: 228
## Features:
## numerics  factors  ordered 
##        8        0        0 
## Missings: TRUE
## Has weights: FALSE
## Has blocking: FALSE

コスト考慮型分類

ここでのコスト考慮型分類とは,データの1ケース毎に誤分類コストを与えて学習させること. したがって,makeCostSensTaskにはデータdataとコスト行列costを与える必要がある.

コスト行列は(データ件数N)×(クラス数K)の行列として与える.

コスト行列の各行毎に,コスト最小の列に対応するクラスが正しいクラスだと仮定されるので, makeCostSensTaskにはtarget引数を与える必要はない.

df = iris
# 架空のコスト行列
cost = matrix(runif(150 * 3, 0, 2000), 150) * (1 - diag(3))[df$Species,]
colnames(cost) = levels(df$Species)
df$Species = NULL

costsens.task = makeCostSensTask(data = df, cost = cost)
costsens.task
## Supervised task: df
## Type: costsens
## Observations: 150
## Features:
## numerics  factors  ordered 
##        4        0        0 
## Missings: FALSE
## Has blocking: FALSE
## Classes: 3
## setosa, versicolor, virginica

タスクへのアクセス方法

タスクオブジェクトから情報を抜き出すためのアクセサメソッドが色々と提供されている (実際にはTaskオブジェクトの実体はリストなので,必ずしもこれらの関数を使わなくてもタスク情報にアクセスできる).

# タスクの概要
getTaskDescription(regr.task)
## $id
## [1] "bh"
## 
## $type
## [1] "regr"
## 
## $target
## [1] "medv"
## 
## $size
## [1] 506
## 
## $n.feat
## numerics  factors  ordered 
##       12        1        0 
## 
## $has.missings
## [1] FALSE
## 
## $has.weights
## [1] FALSE
## 
## $has.blocking
## [1] FALSE
## 
## attr(,"class")
## [1] "TaskDescRegr" "TaskDesc"
# タスクで使うデータ
str(getTaskData(classif.task))
## 'data.frame':    699 obs. of  10 variables:
##  $ Cl.thickness   : Ord.factor w/ 10 levels "1"<"2"<"3"<"4"<..: 5 5 3 6 4 8 1 2 2 4 ...
##  $ Cell.size      : Ord.factor w/ 10 levels "1"<"2"<"3"<"4"<..: 1 4 1 8 1 10 1 1 1 2 ...
##  $ Cell.shape     : Ord.factor w/ 10 levels "1"<"2"<"3"<"4"<..: 1 4 1 8 1 10 1 2 1 1 ...
##  $ Marg.adhesion  : Ord.factor w/ 10 levels "1"<"2"<"3"<"4"<..: 1 5 1 1 3 8 1 1 1 1 ...
##  $ Epith.c.size   : Ord.factor w/ 10 levels "1"<"2"<"3"<"4"<..: 2 7 2 3 2 7 2 2 2 2 ...
##  $ Bare.nuclei    : Factor w/ 10 levels "1","2","3","4",..: 1 10 2 4 1 10 10 1 1 1 ...
##  $ Bl.cromatin    : Factor w/ 10 levels "1","2","3","4",..: 3 3 3 3 3 9 3 3 1 2 ...
##  $ Normal.nucleoli: Factor w/ 10 levels "1","2","3","4",..: 1 2 1 7 1 7 1 1 1 1 ...
##  $ Mitoses        : Factor w/ 9 levels "1","2","3","4",..: 1 1 1 1 1 1 1 1 5 1 ...
##  $ Class          : Factor w/ 2 levels "benign","malignant": 1 1 1 1 1 2 1 1 1 1 ...
# データの件数
getTaskSize(classif.task)
## [1] 699
# 特徴量の数
getTaskNFeats(cluster.task)
## [1] 11
# 特徴量の名前
getTaskFeatureNames(cluster.task)
##  [1] "mpg"  "cyl"  "disp" "hp"   "drat" "wt"   "qsec" "vs"   "am"   "gear"
## [11] "carb"
# 目的変数の名前
getTaskTargetNames(surv.task)
## [1] "time"   "status"
# 目的変数の値
head(getTaskTargets(regr.task))
## [1] 24.0 21.6 34.7 33.4 36.2 28.7
# コスト行列
head(getTaskCosts(costsens.task))
##      setosa versicolor virginica
## [1,]      0  1149.4722 1358.3272
## [2,]      0   786.5157  392.2210
## [3,]      0   251.0483 1318.4457
## [4,]      0   156.6110 1871.8594
## [5,]      0   763.1985 1070.2518
## [6,]      0  1128.9268  304.4168
# タスクのformula
getTaskFormula(surv.task)
## Surv(time, status, type = "right") ~ .
## <environment: 0x000000000ba43148>
getTaskFormulaAsString(surv.task)
## [1] "Surv(time, status, type = \"right\") ~ ."

タスクの修正や変更

既存のタスクに修正や変更を加えて新しいタスクを作ることができる.

また,基本的なデータクリーニングのための関数も用意されている.

# データの一部だけ使う
subsetTask(cluster.task, subset = 4:17)
## Unsupervised task: mtcars
## Type: cluster
## Observations: 14
## Features:
## numerics  factors  ordered 
##       11        0        0 
## Missings: FALSE
## Has weights: FALSE
## Has blocking: FALSE
# 定数値の特徴量を捨てる
removeConstantFeatures(cluster.task)
## Unsupervised task: mtcars
## Type: cluster
## Observations: 32
## Features:
## numerics  factors  ordered 
##       11        0        0 
## Missings: FALSE
## Has weights: FALSE
## Has blocking: FALSE
# 指定した特徴量を捨てる
dropFeatures(surv.task, c("meal.cal", "wt.loss"))
## Supervised task: lung
## Type: surv
## Target: time,status
## Observations: 228
## Features:
## numerics  factors  ordered 
##        6        0        0 
## Missings: TRUE
## Has weights: FALSE
## Has blocking: FALSE
# 変数の正規化
## (method="range"は値が区間[0, 1]に収まるようスケールする)
summary(getTaskData(normalizeFeatures(cluster.task, method = "range")))
##       mpg              cyl              disp              hp        
##  Min.   :0.0000   Min.   :0.0000   Min.   :0.0000   Min.   :0.0000  
##  1st Qu.:0.2138   1st Qu.:0.0000   1st Qu.:0.1240   1st Qu.:0.1572  
##  Median :0.3745   Median :0.5000   Median :0.3123   Median :0.2509  
##  Mean   :0.4124   Mean   :0.5469   Mean   :0.3982   Mean   :0.3346  
##  3rd Qu.:0.5277   3rd Qu.:1.0000   3rd Qu.:0.6358   3rd Qu.:0.4523  
##  Max.   :1.0000   Max.   :1.0000   Max.   :1.0000   Max.   :1.0000  
##       drat              wt              qsec              vs        
##  Min.   :0.0000   Min.   :0.0000   Min.   :0.0000   Min.   :0.0000  
##  1st Qu.:0.1475   1st Qu.:0.2731   1st Qu.:0.2848   1st Qu.:0.0000  
##  Median :0.4309   Median :0.4633   Median :0.3821   Median :0.0000  
##  Mean   :0.3855   Mean   :0.4358   Mean   :0.3987   Mean   :0.4375  
##  3rd Qu.:0.5346   3rd Qu.:0.5362   3rd Qu.:0.5238   3rd Qu.:1.0000  
##  Max.   :1.0000   Max.   :1.0000   Max.   :1.0000   Max.   :1.0000  
##        am              gear             carb       
##  Min.   :0.0000   Min.   :0.0000   Min.   :0.0000  
##  1st Qu.:0.0000   1st Qu.:0.0000   1st Qu.:0.1429  
##  Median :0.0000   Median :0.5000   Median :0.1429  
##  Mean   :0.4062   Mean   :0.3438   Mean   :0.2589  
##  3rd Qu.:1.0000   3rd Qu.:0.5000   3rd Qu.:0.4286  
##  Max.   :1.0000   Max.   :1.0000   Max.   :1.0000
## (method="standardize"は平均0,分散1に正規化する)
summary(getTaskData(normalizeFeatures(regr.task, method = "standardize")))
##       crim                 zn               indus         chas   
##  Min.   :-0.419367   Min.   :-0.48724   Min.   :-1.5563   0:471  
##  1st Qu.:-0.410563   1st Qu.:-0.48724   1st Qu.:-0.8668   1: 35  
##  Median :-0.390280   Median :-0.48724   Median :-0.2109          
##  Mean   : 0.000000   Mean   : 0.00000   Mean   : 0.0000          
##  3rd Qu.: 0.007389   3rd Qu.: 0.04872   3rd Qu.: 1.0150          
##  Max.   : 9.924110   Max.   : 3.80047   Max.   : 2.4202          
##       nox                rm               age               dis         
##  Min.   :-1.4644   Min.   :-3.8764   Min.   :-2.3331   Min.   :-1.2658  
##  1st Qu.:-0.9121   1st Qu.:-0.5681   1st Qu.:-0.8366   1st Qu.:-0.8049  
##  Median :-0.1441   Median :-0.1084   Median : 0.3171   Median :-0.2790  
##  Mean   : 0.0000   Mean   : 0.0000   Mean   : 0.0000   Mean   : 0.0000  
##  3rd Qu.: 0.5981   3rd Qu.: 0.4823   3rd Qu.: 0.9059   3rd Qu.: 0.6617  
##  Max.   : 2.7296   Max.   : 3.5515   Max.   : 1.1164   Max.   : 3.9566  
##       rad               tax             ptratio              b          
##  Min.   :-0.9819   Min.   :-1.3127   Min.   :-2.7047   Min.   :-3.9033  
##  1st Qu.:-0.6373   1st Qu.:-0.7668   1st Qu.:-0.4876   1st Qu.: 0.2049  
##  Median :-0.5225   Median :-0.4642   Median : 0.2746   Median : 0.3808  
##  Mean   : 0.0000   Mean   : 0.0000   Mean   : 0.0000   Mean   : 0.0000  
##  3rd Qu.: 1.6596   3rd Qu.: 1.5294   3rd Qu.: 0.8058   3rd Qu.: 0.4332  
##  Max.   : 1.6596   Max.   : 1.7964   Max.   : 1.6372   Max.   : 0.4406  
##      lstat              medv      
##  Min.   :-1.5296   Min.   : 5.00  
##  1st Qu.:-0.7986   1st Qu.:17.02  
##  Median :-0.1811   Median :21.20  
##  Mean   : 0.0000   Mean   :22.53  
##  3rd Qu.: 0.6024   3rd Qu.:25.00  
##  Max.   : 3.5453   Max.   :50.00

学習器

mlrでは,学習器はLearnerというクラスに抽象化されている.

mlrに統合済みの機械学習パッケージ一覧はここで見られる.

よく知られたパッケージはかなり入っているが,最近イケイケな感じのパッケージで入っていないものもある (例えば xgboostrangerrborist あたり.).

新たなLearnerの追加は容易にできるらしいので,やる気のある人は開発されたし.

学習器の作成

Learnerオブジェクトを作成するにはmakeLearner関数を呼べばよい. Learnerには以下のような情報を与える.

  • 学習器(これは必須)
  • ハイパーパラメータの値
  • 出力の種類(ただの分類か,確率も出力するのか等)
  • 学習器の識別子

ともかく例を見よう.

# ランダムフォレストで確率を出力
# (fix.factors.prediction=TRUEはファクタの水準数が学習データとテストデータで異なる場合に生じる問題をうまく処理してくれる)
classif.lrn = makeLearner("classif.randomForest", predict.type = "prob", fix.factors.prediction = TRUE)
# 勾配ブースティングでハイパーパラメータを指定
regr.lrn = makeLearner("regr.gbm", par.vals = list(n.trees = 500, interaction.depth = 3))
# Cox比例ハザードモデル,識別子付き
surv.lrn = makeLearner("surv.coxph", id = "cph")
# クラスタ数5でK-means
cluster.lrn = makeLearner("cluster.kmeans", centers = 5)

classif.lrn
## Learner classif.randomForest from package randomForest
## Type: classif
## Name: Random Forest; Short name: rf
## Class: classif.randomForest
## Properties: twoclass,multiclass,numerics,factors,ordered,prob
## Predict-Type: prob
## Hyperparameters:
cluster.lrn
## Learner cluster.kmeans from package stats,clue
## Type: cluster
## Name: K-Means; Short name: kmeans
## Class: cluster.kmeans
## Properties: numerics
## Predict-Type: response
## Hyperparameters: centers=<numeric>

一見してわかるように,学習器の名前は<タスクの種類>.<Rの関数名>で統一されている. また扱える特徴量の型や出力の種類といった学習器の性質がプロパティ(Properties)として抽象化されていることが分かる.

学習器へのアクセス

LearnerTaskと同様,実体はリストだが色々なアクセサメソッドが提供されている.

# 設定済みハイパーパラメータ
getHyperPars(cluster.lrn)
## $centers
## [1] 5
# 設定可能なハイパーパラメータ一覧
getParamSet(cluster.lrn)
##               Type len           Def                             Constr
## centers    untyped   -             -                                  -
## iter.max   integer   -            10                           1 to Inf
## nstart     integer   -             1                           1 to Inf
## algorithm discrete   - Hartigan-Wong Hartigan-Wong,Lloyd,Forgy,MacQueen
## trace      logical   -             -                                  -
##           Req Trafo
## centers     -     -
## iter.max    -     -
## nstart      -     -
## algorithm   -     -
## trace       -     -
# ハイパーパラメータ一覧はLearnerを作成しなくても学習器名から取得できる
getParamSet("classif.randomForest")
##                     Type  len   Def   Constr Req Trafo
## ntree            integer    -   500 1 to Inf   -     -
## mtry             integer    -     - 1 to Inf   -     -
## replace          logical    -  TRUE        -   -     -
## classwt    numericvector <NA>     - 0 to Inf   -     -
## cutoff     numericvector <NA>     -   0 to 1   -     -
## sampsize   integervector <NA>     - 0 to Inf   -     -
## nodesize         integer    -     1 1 to Inf   -     -
## maxnodes         integer    -     - 1 to Inf   -     -
## importance       logical    - FALSE        -   -     -
## localImp         logical    - FALSE        -   -     -
## norm.votes       logical    -  TRUE        -   -     -
## keep.inbag       logical    - FALSE        -   -     -

学習器の修正や変更

既存のLernerを修正・変更して新しいLearnerを作成できる.

## 学習器の識別子を設定
surv.lrn = setId(surv.lrn, "CoxModel")
## 出力を確率からクラスラベルに変更
classif.lrn = setPredictType(classif.lrn, "response")
## ハイパーパラメータの値を変更
cluster.lrn = setHyperPars(cluster.lrn, centers = 4)
## ハイパーパラメータの値をデフォルト値にする
regr.lrn = removeHyperPars(regr.lrn, c("n.trees", "interaction.depth"))

学習器を探す

listLearners関数で,利用可能な学習器の一覧が見られる. 指定したタスクの種類やプロパティを持つ学習器だけ探すことも可能.

# 利用可能な学習器全部
head(listLearners())
##           classif.ada   classif.bartMachine           classif.bdk 
##         "classif.ada" "classif.bartMachine"         "classif.bdk" 
##      classif.binomial      classif.boosting    classif.extraTrees 
##    "classif.binomial"    "classif.boosting"  "classif.extraTrees"
# 確率が出力できる分類器だけ
head(listLearners("classif", properties = "prob"))
##           classif.ada   classif.bartMachine           classif.bdk 
##         "classif.ada" "classif.bartMachine"         "classif.bdk" 
##      classif.binomial      classif.boosting    classif.extraTrees 
##    "classif.binomial"    "classif.boosting"  "classif.extraTrees"

学習の実行

train関数にLearnerオブジェクトとTaskオブジェクトを渡せば学習が実行される.

# 線形判別分析でirisを分類
lrn = makeLearner("classif.lda")
mod = train(lrn, iris.task)
mod
## Model for learner.id=classif.lda; learner.class=classif.lda
## Trained on: task.id = iris-example; obs = 150; features = 4
## Hyperparameters:
# Learnerを明示的に作らなくても学習できる
train("classif.randomForest", iris.task)
## Model for learner.id=classif.randomForest; learner.class=classif.randomForest
## Trained on: task.id = iris-example; obs = 150; features = 4
## Hyperparameters:
# 学習に使うデータを指定できる
n = getTaskSize(bh.task)
train.set = sample(n, size = n/3)
train("regr.lm", bh.task, subset = train.set)
## Model for learner.id=regr.lm; learner.class=regr.lm
## Trained on: task.id = BostonHousing-example; obs = 168; features = 13
## Hyperparameters:
# データの重み付けもできる(Taskで設定した値があれば上書きする)
target = getTaskTargets(bc.task)
tab = as.numeric(table(target))
w = 1/tab[target]
train("classif.rpart", task = bc.task, weights = w)
## Model for learner.id=classif.rpart; learner.class=classif.rpart
## Trained on: task.id = BreastCancer-example; obs = 683; features = 9
## Hyperparameters: xval=0

train関数が返すオブジェクトはWrappedModelというクラスであり, 名前の通り,元々のRの関数が返すモデルオブジェクトをラップしたものになっている.

元のRのモデルオブジェクトが欲しいときはgetLearnerModel関数を呼ぶ.

getLearnerModel(mod)
## Call:
## lda(f, data = getTaskData(.task, .subset))
## 
## Prior probabilities of groups:
##     setosa versicolor  virginica 
##  0.3333333  0.3333333  0.3333333 
## 
## Group means:
##            Sepal.Length Sepal.Width Petal.Length Petal.Width
## setosa            5.006       3.428        1.462       0.246
## versicolor        5.936       2.770        4.260       1.326
## virginica         6.588       2.974        5.552       2.026
## 
## Coefficients of linear discriminants:
##                     LD1         LD2
## Sepal.Length  0.8293776  0.02410215
## Sepal.Width   1.5344731  2.16452123
## Petal.Length -2.2012117 -0.93192121
## Petal.Width  -2.8104603  2.83918785
## 
## Proportion of trace:
##    LD1    LD2 
## 0.9912 0.0088

予測

train関数で作成したモデルで新しいデータに対する予測をしたい場合はpredict関数を使う.

predict関数へのデータの与え方は2通りある.

  • task引数にTaskオブジェクトを渡す方法
  • newdata引数にデータフレームを直接渡す方法
# taskでデータを渡す例
n = getTaskSize(bh.task)
train.set = seq(1, n, by = 2)
test.set = seq(2, n, by = 2)
lrn = makeLearner("regr.gbm", n.trees = 100)
mod = train(lrn, bh.task, subset = train.set)

task.pred = predict(mod, task = bh.task, subset = test.set)
task.pred
## Prediction: 253 observations
## predict.type: response
## threshold: 
## time: 0.00
##    id truth response
## 2   2  21.6 22.26409
## 4   4  33.4 23.24189
## 6   6  28.7 22.39762
## 8   8  27.1 22.13458
## 10 10  18.9 22.13458
## 12 12  18.9 22.14528
# newdataでデータを渡す例
n = nrow(iris)
iris.train = iris[seq(1, n, by = 2), -5]
iris.test = iris[seq(2, n, by = 2), -5]
task = makeClusterTask(data = iris.train)
mod = train("cluster.kmeans", task)

newdata.pred = predict(mod, newdata = iris.test)
newdata.pred
## Prediction: 75 observations
## predict.type: response
## threshold: 
## time: 0.00
##    response
## 2         1
## 4         1
## 6         1
## 8         1
## 10        1
## 12        1

predict関数はPredictionオブジェクトを返す. このオブジェクトには予測に関する情報が色々入っているが,大事なのはdata要素で,ここに予測値が入っている.

Predictionオブジェクトに対しても種々のアクセサメソッドや便利な関数が提供されている.

# 確率を出力するランダムフォレスト
lrn = makeLearner("classif.randomForest", predict.type = "prob")
mod = train(lrn, iris.task)
pred = predict(mod, newdata = iris)
# 予測値
head(pred$data)
##    truth prob.setosa prob.versicolor prob.virginica response
## 1 setosa           1               0              0   setosa
## 2 setosa           1               0              0   setosa
## 3 setosa           1               0              0   setosa
## 4 setosa           1               0              0   setosa
## 5 setosa           1               0              0   setosa
## 6 setosa           1               0              0   setosa
# getProbabilitiesで確率だけ取得
head(getProbabilities(pred))
##   setosa versicolor virginica
## 1      1          0         0
## 2      1          0         0
## 3      1          0         0
## 4      1          0         0
## 5      1          0         0
## 6      1          0         0
# confusion matrix
getConfMatrix(pred)
##             predicted
## true         setosa versicolor virginica -SUM-
##   setosa         50          0         0     0
##   versicolor      0         50         0     0
##   virginica       0          0        50     0
##   -SUM-           0          0         0     0

閾値の調整

確率を出力する(2値)分類器の場合,デフォルトでは確率を閾値0.5で切って所属クラスを定めるようになっている. setThreshold関数を使うとこの閾値を調整することができる. (多クラス分類でも閾値調整はできるが,省略.)

# 決定木でソナー信号を分類
lrn = makeLearner("classif.rpart", predict.type = "prob")
mod = train(lrn, task = sonar.task)

## デフォルト閾値は0.5
pred1 = predict(mod, sonar.task)
pred1$threshold
##   M   R 
## 0.5 0.5
## 閾値を0.9に変更
pred2 = setThreshold(pred1, 0.9)
pred2$threshold
##   M   R 
## 0.9 0.1
# 閾値変更前後のconfusion matrixを比較
getConfMatrix(pred1)
##        predicted
## true     M  R -SUM-
##   M     95 16    16
##   R     10 87    10
##   -SUM- 10 16    26
getConfMatrix(pred2)
##        predicted
## true     M  R -SUM-
##   M     84 27    27
##   R      6 91     6
##   -SUM-  6 27    33

予測の可視化

plotLearnerPrediction関数にLearnerオブジェクトとTaskオブジェクトを渡すと, 適当な特徴量2つで学習させたモデルをggplot2で可視化してくれる.

なぜ構築した本物のモデルではなく,2変数で作ったモデルしかプロットしてくれないのか謎だが, 現在議論中のようだ.

# クラスタリング
lrn = makeLearner("cluster.kmeans")
plotLearnerPrediction(lrn, task = mtcars.task, features = c("disp", "drat"), cv = 0)

# 分類
lrn = makeLearner("classif.rpart", id = "CART")
plotLearnerPrediction(lrn, task = iris.task)

# 回帰(1次元)
plotLearnerPrediction("regr.lm", features = "lstat", task = bh.task)

# 回帰(2次元)
plotLearnerPrediction("regr.lm", features = c("lstat", "rm"), task = bh.task)

ここまでのまとめ

mlrチュートリアルの写経(その2)に続く……