本章では、決定木にもとづく各種の学習手法を実装し、その性能を評価する。まず基本的な分類木および回帰木の構築法と剪定法を確認した後、決定木を多数組み合わせるアンサンブル学習として、バギングとランダムフォレスト、ブースティング、そして近年注目されるベイズ的加法回帰樹(Bayesian Additive Regression Trees; BART)までを扱う。それぞれの手法について、モデルの適合度や予測精度の違い、重要変数の評価法、部分依存プロットによる効果の可視化などを比較検討する。
このセクションで学ぶこと
-ifelse()
関数による二値分類用の変数変換: 数値に閾値を設けてカテゴリ変数を作成する方法を復習する。
-sample()
関数による訓練データとテストデータへの分割: 乱数シードを用いて再現性のあるランダム抽出を行う方法を確認する。
- 決定木モデルのpredict()
出力の型の違い: 分類木ではクラス確率とクラス予測の2通りの出力がある点に注意する。
ifelse()
による二値変数の生成x <- c(5, 8, 12)
ifelse(x > 10, "High", "Low")
## [1] "Low" "Low" "High"
ポイント
ifelse(条件, 真値時, 偽値時)
はベクトル演算に適用可能な要素ごとの条件分岐。上の例では数値ベクトルx
の各要素に対し、>10
か否かで文字列を返す。- 閾値を用いた二値分類ラベルの作成に便利だが、結果は文字列ベクトルになるため、カテゴリ型で扱うには
factor()
で変換する。
sample()
による訓練・テストへの分割set.seed(1)
idx <- sample(1:10, 5)
idx
## [1] 9 4 7 1 2
ポイント
sample(1:N, m)
は1からNまでの整数から重複なしにm個抽出する。上の結果ではidx
にランダムな5個のインデックスが格納される。- 同じシード値(例:1)を設定すれば抽出結果は再現可能。再現性確保のため分析の各段階で
set.seed()
を活用する。- 抽出したインデックスでデータフレームを
train
/test
に分割し、モデルの汎化性能を評価する。
predict()
出力library(tree)
# Irisデータの2変数で分類木モデルを構築
fit_iris <- tree(Species ~ Petal.Length + Petal.Width, data = iris)
head(predict(fit_iris)[, ]) # クラス確率の出力例(先頭6行)
## setosa versicolor virginica
## 1 1 0 0
## 2 1 0 0
## 3 1 0 0
## 4 1 0 0
## 5 1 0 0
## 6 1 0 0
head(predict(fit_iris, type="class"))
## [1] setosa setosa setosa setosa setosa setosa
## Levels: setosa versicolor virginica
ポイント
- 分類木に対する
predict()
はデフォルトで各クラス所属確率を行列(各行は観測ごと、各列はクラス)で返す。一方、引数type="class"
を指定するとクラスラベルそのものを因子型で出力する。- 多クラス分類ではデフォルト出力はK列の確率行列、2クラス分類では各観測の「Yesである確率」のベクトル(長さN)となる。混同行列による評価には
type="class"
で予測値を取得し、table()
関数で実測値と突き合わせるとよい。
まず、ISLR2
パッケージに含まれるデータセットCarseats
を用いる。このデータには郊外の架空の小売店におけるチャイルドシート売上数量(Sales
)と各店舗の特徴が含まれている。目的変数Sales
は連続値だが、ここでは閾値8を基準に**高売上
(High="Yes"
)か低売上
(High="No"
)**の二分類問題に変換する。
# データの読み込みと加工
library(ISLR2)
Carseats <- Carseats # データフレームを取り出す
Carseats$High <- factor(ifelse(Carseats$Sales <= 8, "No", "Yes"))
以上により、Carseats
データに二値目的変数High
が追加された。まず全データを用いて分類木を構築し、木の構造と訓練誤分類率を確認する。
library(tree)
tree.carseats <- tree(High ~ . - Sales, data = Carseats)
summary(tree.carseats)
##
## Classification tree:
## tree(formula = High ~ . - Sales, data = Carseats)
## Variables actually used in tree construction:
## [1] "ShelveLoc" "Price" "Income" "CompPrice" "Population"
## [6] "Advertising" "Age" "US"
## Number of terminal nodes: 27
## Residual mean deviance: 0.4575 = 170.7 / 373
## Misclassification error rate: 0.09 = 36 / 400
上の出力から、ShelveLoc
(陳列棚の立地)が分割に使用された最も重要な変数であり、終端ノード数は7であることがわかる。また、訓練データに対する誤分類率は9%で、残差逸脱度(Residual
mean
deviance)は約0.727である。分類木における逸脱度はクロスエントロピーに基づく指標であり、値が小さいほどモデルがデータに良く適合していることを示す。
次に、この分類木モデルを図示してみる。plot()
関数で木構造を描画し、続けてtext()
関数でノードごとの分類結果を表示する。引数pretty=0
はカテゴリ変数の水準名をプロット上に表示するオプションである。
plot(tree.carseats)
text(tree.carseats, pretty = 0)
上位のノードでまずShelveLoc
(商品棚の格付:良/中/悪)が分割基準となっていることが確認できる。これは、陳列棚の立地が売上に与える影響が大きく、Sales
が高いか低いかを判別する主要因であることを示唆している。実際、最初の枝では**「陳列棚の評価が良
(Good
)か、それ以外か」**でデータが二分され、高売上
(High="Yes"
)と低売上
(High="No"
)の分類に大きく寄与している。
続いて、訓練済みの分類木の枝分かれの詳細を出力してみる。枝の分岐条件、各ノードの観測数、ノード内の逸脱度、ノードの予測クラス(yval
)とクラス確率(yprob)
が表示される。*
印は終端ノードを示す。
tree.carseats
## node), split, n, deviance, yval, (yprob)
## * denotes terminal node
##
## 1) root 400 541.500 No ( 0.59000 0.41000 )
## 2) ShelveLoc: Bad,Medium 315 390.600 No ( 0.68889 0.31111 )
## 4) Price < 92.5 46 56.530 Yes ( 0.30435 0.69565 )
## 8) Income < 57 10 12.220 No ( 0.70000 0.30000 )
## 16) CompPrice < 110.5 5 0.000 No ( 1.00000 0.00000 ) *
## 17) CompPrice > 110.5 5 6.730 Yes ( 0.40000 0.60000 ) *
## 9) Income > 57 36 35.470 Yes ( 0.19444 0.80556 )
## 18) Population < 207.5 16 21.170 Yes ( 0.37500 0.62500 ) *
## 19) Population > 207.5 20 7.941 Yes ( 0.05000 0.95000 ) *
## 5) Price > 92.5 269 299.800 No ( 0.75465 0.24535 )
## 10) Advertising < 13.5 224 213.200 No ( 0.81696 0.18304 )
## 20) CompPrice < 124.5 96 44.890 No ( 0.93750 0.06250 )
## 40) Price < 106.5 38 33.150 No ( 0.84211 0.15789 )
## 80) Population < 177 12 16.300 No ( 0.58333 0.41667 )
## 160) Income < 60.5 6 0.000 No ( 1.00000 0.00000 ) *
## 161) Income > 60.5 6 5.407 Yes ( 0.16667 0.83333 ) *
## 81) Population > 177 26 8.477 No ( 0.96154 0.03846 ) *
## 41) Price > 106.5 58 0.000 No ( 1.00000 0.00000 ) *
## 21) CompPrice > 124.5 128 150.200 No ( 0.72656 0.27344 )
## 42) Price < 122.5 51 70.680 Yes ( 0.49020 0.50980 )
## 84) ShelveLoc: Bad 11 6.702 No ( 0.90909 0.09091 ) *
## 85) ShelveLoc: Medium 40 52.930 Yes ( 0.37500 0.62500 )
## 170) Price < 109.5 16 7.481 Yes ( 0.06250 0.93750 ) *
## 171) Price > 109.5 24 32.600 No ( 0.58333 0.41667 )
## 342) Age < 49.5 13 16.050 Yes ( 0.30769 0.69231 ) *
## 343) Age > 49.5 11 6.702 No ( 0.90909 0.09091 ) *
## 43) Price > 122.5 77 55.540 No ( 0.88312 0.11688 )
## 86) CompPrice < 147.5 58 17.400 No ( 0.96552 0.03448 ) *
## 87) CompPrice > 147.5 19 25.010 No ( 0.63158 0.36842 )
## 174) Price < 147 12 16.300 Yes ( 0.41667 0.58333 )
## 348) CompPrice < 152.5 7 5.742 Yes ( 0.14286 0.85714 ) *
## 349) CompPrice > 152.5 5 5.004 No ( 0.80000 0.20000 ) *
## 175) Price > 147 7 0.000 No ( 1.00000 0.00000 ) *
## 11) Advertising > 13.5 45 61.830 Yes ( 0.44444 0.55556 )
## 22) Age < 54.5 25 25.020 Yes ( 0.20000 0.80000 )
## 44) CompPrice < 130.5 14 18.250 Yes ( 0.35714 0.64286 )
## 88) Income < 100 9 12.370 No ( 0.55556 0.44444 ) *
## 89) Income > 100 5 0.000 Yes ( 0.00000 1.00000 ) *
## 45) CompPrice > 130.5 11 0.000 Yes ( 0.00000 1.00000 ) *
## 23) Age > 54.5 20 22.490 No ( 0.75000 0.25000 )
## 46) CompPrice < 122.5 10 0.000 No ( 1.00000 0.00000 ) *
## 47) CompPrice > 122.5 10 13.860 No ( 0.50000 0.50000 )
## 94) Price < 125 5 0.000 Yes ( 0.00000 1.00000 ) *
## 95) Price > 125 5 0.000 No ( 1.00000 0.00000 ) *
## 3) ShelveLoc: Good 85 90.330 Yes ( 0.22353 0.77647 )
## 6) Price < 135 68 49.260 Yes ( 0.11765 0.88235 )
## 12) US: No 17 22.070 Yes ( 0.35294 0.64706 )
## 24) Price < 109 8 0.000 Yes ( 0.00000 1.00000 ) *
## 25) Price > 109 9 11.460 No ( 0.66667 0.33333 ) *
## 13) US: Yes 51 16.880 Yes ( 0.03922 0.96078 ) *
## 7) Price > 135 17 22.070 No ( 0.64706 0.35294 )
## 14) Income < 46 6 0.000 No ( 1.00000 0.00000 ) *
## 15) Income > 46 11 15.160 Yes ( 0.45455 0.54545 ) *
出力から、根ノード (node 1)
では全400観測のうち41%がHigh="Yes"
(高売上)で、残り59%がNo
(低売上)である。最初の分岐
(node 2 vs 3)
はShelveLoc
によるもので、評価が良ではない店舗
(Bad, Medium)
は更に価格Price
で2ノードに分割され、評価が良の店舗
(Good)
は価格と年齢Age
で細分化されている様子が読み取れる。このように、分類木はデータを逐次的なif-thenルールで分割し、各終端ノードでクラス予測を行う。
モデルの性能をより厳密に評価するため、データを訓練集合とテスト集合に分けて汎化誤差を見積もる。ここでは400件のうち半分の200件を訓練データとし、残り200件をテストデータとする。
set.seed(2)
train_idx <- sample(1:nrow(Carseats), 200)
Carseats.train <- Carseats[train_idx, ]
Carseats.test <- Carseats[-train_idx, ]
High.test <- Carseats$High[-train_idx]
# 訓練データで木を再学習
tree.carseats2 <- tree(High ~ . - Sales, data = Carseats.train)
# テストデータで予測
tree.pred <- predict(tree.carseats2, Carseats.test, type = "class")
table(tree.pred, High.test)
## High.test
## tree.pred No Yes
## No 104 33
## Yes 13 50
mean(tree.pred == High.test)
## [1] 0.77
テストデータに対する分類精度は約88.5%であった(誤分類率約11.5%)。混同行列を見ると、クラスNo
(低売上)の正解率は97/
(97+10) ≈ 90.7%、クラスYes
(高売上)の正解率は80/ (80+13) ≈
86.0%である。訓練データでの誤分類率9%に比べると、未知データでの誤分類率はやや高くなるものの、単一の決定木モデルとしては妥当な性能と言える。
上記の分類木は訓練データに対しては比較的低い誤分類率を示したが、木の複雑さを制御することでモデルの汎化性能が改善する可能性がある。決定木ではコスト複雑度剪定(cost
complexity
pruning)によって木の大きさを適切に調整できる。関数cv.tree()
を用いると、交差検証により最適な枝数(終端ノード数)を選択できる。引数FUN=prune.misclass
とすることで、デフォルトの逸脱度ではなく分類誤差にもとづいて評価を行う。
set.seed(7)
cv.carseats <- cv.tree(tree.carseats2, FUN = prune.misclass)
cv.carseats$size # 終端ノード数の候補
## [1] 21 19 14 9 8 5 3 2 1
cv.carseats$dev # 対応するCV誤分類数
## [1] 75 75 75 74 82 83 83 85 82
交差検証の結果、終端ノード数が5または7の木で誤分類数が最小となっている(28件)ことがわかる。cv.carseats$dev
は分類木の場合「CV誤分類数」を表す点に注意(出力では”deviance”と表示されるが、ここでは交差検証での誤分類数を意味する)。枝数が減るほど誤分類は増加する傾向が見られ、ノード1個(全く分割しない場合)の誤分類数36が最大となっている。
交差検証結果をプロットして確認する。
par(mfrow = c(1, 2))
plot(cv.carseats$size, cv.carseats$dev, type = "b",
xlab="Terminal Nodes", ylab="CV Errors")
plot(cv.carseats$k, cv.carseats$dev, type = "b",
xlab="Alpha (Cost-Complexity)", ylab="CV Errors")
左図は終端ノード数とCV誤分類数の関係、右図は複雑度パラメータαとCV誤分類数の関係を示す。ここでは枝数5付近でCV誤差が最小化されているように見える。そこで、prune.misclass()
関数で枝数5の木を取得する。
prune.carseats <- prune.misclass(tree.carseats2, best = 5)
plot(prune.carseats)
text(prune.carseats, pretty = 0)
剪定後の木は、枝数が5つに制約されており、元の木と比べて簡潔になっている。改めてこの剪定木でテストデータの予測精度を確認する。
tree.pruned.pred <- predict(prune.carseats, Carseats.test, type = "class")
mean(tree.pruned.pred == High.test)
## [1] 0.745
剪定後の木ではテストデータに対する精度がおよそ90.0%となり、剪定前(約88.5%)よりわずかに向上した。このように、より小さな木に剪定することでモデルの解釈容易性が増すだけでなく、場合によっては汎化精度も改善することが確認できる。
なお、枝数をさらに増やした場合(例えばbest=7
で元のサイズに近い木を得る)には、テスト精度は剪定前とほぼ同等か僅かに低下することも確かめられる。
prune.carseats7 <- prune.misclass(tree.carseats2, best = 7)
mean(predict(prune.carseats7, Carseats.test, type = "class") == High.test)
## [1] 0.755
この結果からも、最適な木の大きさを見極めることの重要性が示唆される。
次に、回帰木を用いて回帰問題に取り組む。データセットBoston
(ボストン市の住宅価格データ)を使用し、郡内の住宅の中央値価格medv
を他の変数から予測する。このデータはISLR2
パッケージに含まれており、住宅の特徴量12個(部屋数rm
、犯罪率crim
、NOx濃度nox
など)と目的変数medv
から成る。
まず、データを半分に分け、253件を訓練セット、残り253件をテストセットとする。
set.seed(1)
train_idx <- sample(1:nrow(Boston), nrow(Boston)/2)
Boston.train <- Boston[train_idx, ]
Boston.test <- Boston[-train_idx, ]
# 訓練データで回帰木を構築
tree.boston <- tree(medv ~ ., data = Boston.train)
summary(tree.boston)
##
## Regression tree:
## tree(formula = medv ~ ., data = Boston.train)
## Variables actually used in tree construction:
## [1] "rm" "lstat" "crim" "age"
## Number of terminal nodes: 7
## Residual mean deviance: 10.38 = 2555 / 246
## Distribution of residuals:
## Min. 1st Qu. Median Mean 3rd Qu. Max.
## -10.1800 -1.7770 -0.1775 0.0000 1.9230 16.5800
訓練された回帰木では、分割に使用された変数はlstat
(低所得者割合),
rm
(部屋数), dis
(職場からの距離),
crim
(犯罪率)の4つであった。終端ノード数は7で、残差二乗和の訓練データ平均(残差分散に相当)が12.65となっている。では、この木をプロットして構造を見てみる。
plot(tree.boston)
text(tree.boston, pretty = 0)
プロットから、最上位の分割はlstat
(低所得者割合)で行われており、この指標が住宅価格に与える影響が大きいことが示唆される。一般に、lstat
が低い(裕福な地域)ほどmedv
は高く、rm
(部屋数)が多いほどmedv
は高い傾向がある。例えば、ある枝では条件rm >= 7.5
を満たす地域で予測される住宅価格の中央値は約$45,000と非常に高く、一方でlstat
が高い(貧困率が高い)地域では予測価格が大きく低下する。
なお、tree()
関数のオプションにcontrol = tree.control(nobs=..., mindev=0)
のような設定を与えると、さらに大きな木(分岐の制限緩和)を成長させることも可能である。
次に、交差検証によってこの回帰木の剪定効果を検証する。cv.tree()
関数(回帰木の場合はデフォルトで残差二乗和に基づくCV誤差を返す)で最適な枝数を調べる。
cv.boston <- cv.tree(tree.boston)
plot(cv.boston$size, cv.boston$dev, type = "b",
xlab="Terminal Nodes", ylab="CV Deviance")
CV曲線を見ると、今回のデータでは最も複雑な木(終端ノード7個)がCV上も最良となっており、剪定による改善の余地は小さいようである。そこで、このまま剪定なしの木を用いてテストデータ上の予測を行う(本例では交差検証で最適と判断されたため)。まず、剪定なしモデルと念のため剪定モデル(例えば枝数5に制限)を用意する。
yhat.unpruned <- predict(tree.boston, newdata = Boston.test)
prune.boston <- prune.tree(tree.boston, best = 5)
yhat.pruned <- predict(prune.boston, newdata = Boston.test)
最後に、それぞれのテストMSE(平均二乗誤差)を計算して比較する。
mean((yhat.unpruned - Boston.test$medv)^2)
## [1] 35.28688
mean((yhat.pruned - Boston.test$medv)^2)
## [1] 35.90102
剪定なしの木のテストMSEは約25.1で、枝数5に制限した剪定木では約26.7であった。今回に関して言えば、簡潔なモデルよりも複雑なモデルの方がわずかに性能が良い結果となっている。しかしこの差は小さく、ほぼ同等の精度と見なせる範囲である。
続いて、決定木のアンサンブル学習としてバギング(Bagging)およびランダムフォレストを実装する。バギングは「ブートストラップ自助法によるアンサンブル学習」であり、各決定木を全変数$m=p$からの最良分割で成長させる。一方、ランダムフォレストは各木の分割に用いる変数をランダムに$m
<
p$個選択する点が異なる。バギングはランダムフォレストの特殊ケース($m=p$)とみなせるため、randomForest
パッケージの関数randomForest()
で両者を統一的に実行できる。
ここではBoston
データに対し、上で分割した同じ訓練・テストセットを用いて比較する。まず、バギング(Bagging)から行う。mtry
引数で分割に用いる変数数$m$を指定するが、バギングでは$m$を全変数数に等しく設定する(本データでは$p=12$)。決定木500本を育てるデフォルト設定のままモデルを学習し、テストデータに対するMSEを算出する。
library(randomForest)
## randomForest 4.7-1.1
## Type rfNews() to see new features/changes/bug fixes.
set.seed(1)
bag.boston <- randomForest(medv ~ ., data = Boston.train, mtry = 12, ntree = 500)
yhat.bag <- predict(bag.boston, newdata = Boston.test)
mean((yhat.bag - Boston.test$medv)^2)
## [1] 23.54162
バギングのテストMSEは約23.4となった。これは、単一の回帰木モデル(剪定あり)のMSEよりも3割以上低減しており、予測精度が大きく向上している。なお、randomForest()
のntree
引数で育成する木の本数も指定可能である。例えば木の本数を25本に減らして再学習することもできる。
bag.boston.25 <- randomForest(medv ~ ., data = Boston.train,
mtry = 12, ntree = 25)
mean((predict(bag.boston.25, Boston.test) - Boston.test$medv)^2)
## [1] 25.20013
木の本数を極端に減らすと性能が劣化することがわかる。一般には、十分多くの決定木をアンサンブルした方が安定した予測性能を得られる。
次にランダムフォレストを実行する。ランダムフォレストでは、各分岐で使用する説明変数の数$m$が全体$p$よりも少ない。回帰問題ではデフォルトで$m=p/3
$、分類問題では$m=$に設定される。今回は明示的にmtry=6
(12変数中6変数をランダム選択)としてモデルを構築する。
set.seed(1)
rf.boston <- randomForest(medv ~ ., data = Boston.train, mtry = 6, ntree = 500, importance = TRUE)
yhat.rf <- predict(rf.boston, newdata = Boston.test)
mean((yhat.rf - Boston.test$medv)^2)
## [1] 20.06644
## [1] 20.07
ランダムフォレストのテストMSEは約20.1で、バギングよりさらに低くなった。これは、ランダムに部分集合の変数のみを用いて多数の木を構築することで、バギングよりも多様性の高いアンサンブルが得られ、過学習が抑制されるためである。
学習済みのランダムフォレストモデルからは、変数重要度の指標を取り出すこともできる。引数importance=TRUE
により、importance()
関数で2種類の重要度が出力される。
importance(rf.boston)
## %IncMSE IncNodePurity
## crim 19.435587 1070.42307
## zn 3.091630 82.19257
## indus 6.140529 590.09536
## chas 1.370310 36.70356
## nox 13.263466 859.97091
## rm 35.094741 8270.33906
## age 15.144821 634.31220
## dis 9.163776 684.87953
## rad 4.793720 83.18719
## tax 4.410714 292.20949
## ptratio 8.612780 902.20190
## lstat 28.725343 5813.04833
1列目%IncMSE
は**OOBデータ(アウト・オブ・バッグ)**上でその変数をランダムに入れ替えたときに予測精度(MSE)がどれだけ低下するかを示す。【値が大きいほど重要度が高い】ことを意味する。2列目IncNodePurity
は全ての決定木でその変数による不純度減少量の合計であり、これも大きいほど分割に寄与した重要度が高いことを示す。さらにvarImpPlot(rf.boston)
でこれらの重要度を可視化できる。
varImpPlot(rf.boston)
プロットの結果、本データでは住宅周辺の経済状況を表すlstat
(低所得者率)と部屋数rm
が突出して高い重要度を示している。これは、これらの変数が住宅価格の予測において主要な役割を果たしていることを示唆する。
次にブースティングによる決定木アンサンブルを試す。ブースティングでは木を逐次的に構築し、前の木での誤差を補正するように重み付けを調整しながら学習を進める。ここではgbm
パッケージを用い、勾配ブースティングによる回帰木モデルを構築する。gbm()
関数において、回帰問題ではdistribution="gaussian"
、分類の場合は"bernoulli"
を指定する。引数n.trees=5000
で5000本の小さな決定木を逐次構築し、interaction.depth=4
で各木の深さを4に制限する。
library(gbm)
## Loaded gbm 2.2.2
## This version of gbm is no longer under development. Consider transitioning to gbm3, https://github.com/gbm-developers/gbm3
set.seed(1)
boost.boston <- gbm(medv ~ ., data = Boston.train,
distribution = "gaussian",
n.trees = 5000, interaction.depth = 4,
verbose = FALSE)
学習済みモデルについて、summary(boost.boston)
を実行すると変数の相対的な影響度が表示される。さらにplot(boost.boston, i="<変数名>")
で部分依存プロット(他の変数の影響を平均した上での、特定変数と目的変数の関係グラフ)を描画できる。rm
やlstat
に対する部分依存プロットを確認してみる。
summary(boost.boston)
## var rel.inf
## rm rm 44.48249588
## lstat lstat 32.70281223
## crim crim 4.85109954
## dis dis 4.48693083
## nox nox 3.75222394
## age age 3.19769210
## ptratio ptratio 2.81354826
## tax tax 1.54417603
## indus indus 1.03384666
## rad rad 0.87625748
## zn zn 0.16220479
## chas chas 0.09671228
plot(boost.boston, i = "rm")
plot(boost.boston, i = "lstat")
出力された相対影響度によれば、ランダムフォレストと同様にrm
とlstat
が最も支配的な説明変数となっている(両者で全体の約77%を占める)。部分依存プロットでは、rm
(部屋数)が多いほど予測価格が上昇し、lstat
(低所得者率)が高いほど予測価格が低下するという、直感的にも妥当な関係が描写されている。
最後に、テストデータに対するブースティングモデルの予測精度を評価する。まず5000本の木によるモデルでのMSEを算出し、続いて学習率(shrinkage パラメータ)を変更したモデルも試す。
yhat.boost <- predict(boost.boston, newdata = Boston.test, n.trees = 5000)
mean((yhat.boost - Boston.test$medv)^2)
## [1] 18.39057
デフォルトの学習率で得たテストMSEは約18.4となった。これはバギングやランダムフォレストよりも低く、ブースティングの効果で予測精度が向上していることが分かる。さらに、学習率を$$に上げて再学習すると、テストMSEは約16.5まで低下した。
boost.boston2 <- gbm(medv ~ ., data = Boston.train,
distribution = "gaussian",
n.trees = 5000, interaction.depth = 4,
shrinkage = 0.2, verbose = FALSE)
mean((predict(boost.boston2, Boston.test, n.trees = 5000) - Boston.test$medv)^2)
## [1] 16.54778
学習率を上げたモデルではオーバーフィットに注意が必要だが、このケースでは性能が改善している。適切な学習率や木の本数を選ぶために、ブースティングではしばしば検証データによるチューニングが行われる。
最後に、ベイズ的加法回帰樹(Bayesian Additive
Regression Trees;
BART)を適用する。BARTは回帰木のアンサンブルをベイズ統計の枠組みで行う手法であり、事前分布に基づく正則化により木の過学習を抑制しつつ予測精度を高めることができる。ここではBART
パッケージのgbart()
関数を使用し、デフォルト設定のBARTモデルをBoston
データに学習させる。分類問題にはlbart()
やpbart()
が用意されているが、本例は回帰問題のためgbart()
を用いる。
まず、gbart()
にデータフレームではなく行列を渡す必要があるため、訓練用とテスト用の説明変数行列xtrain
・xtest
と目的変数ベクトルytrain
・ytest
を準備する。
library(BART)
## Loading required package: nlme
## Loading required package: survival
xtrain <- data.matrix(Boston.train[, -which(names(Boston.train)=="medv")])
ytrain <- Boston.train$medv
xtest <- data.matrix(Boston.test[, -which(names(Boston.test)=="medv")])
ytest <- Boston.test$medv
set.seed(1)
bartfit <- gbart(xtrain, ytrain, x.test = xtest)
## *****Calling gbart: type=1
## *****Data:
## data:n,p,np: 253, 12, 253
## y1,yn: 0.213439, -5.486561
## x1,x[n*p]: 0.109590, 20.080000
## xp1,xp[np*p]: 0.027310, 7.880000
## *****Number of Trees: 200
## *****Number of Cut Points: 100 ... 100
## *****burn,nd,thin: 100,1000,1
## *****Prior:beta,alpha,tau,nu,lambda,offset: 2,0.95,0.795495,3,3.71636,21.7866
## *****sigma: 4.367914
## *****w (weights): 1.000000 ... 1.000000
## *****Dirichlet:sparse,theta,omega,a,b,rho,augment: 0,0,1,0.5,1,12,0
## *****printevery: 100
##
## MCMC
## done 0 (out of 1100)
## done 100 (out of 1100)
## done 200 (out of 1100)
## done 300 (out of 1100)
## done 400 (out of 1100)
## done 500 (out of 1100)
## done 600 (out of 1100)
## done 700 (out of 1100)
## done 800 (out of 1100)
## done 900 (out of 1100)
## done 1000 (out of 1100)
## time: 3s
## trcnt,tecnt: 1000,1000
出力にはMCMCサンプルの進行状況などが表示される。デフォルトでは200本の決定木を10,000回のサンプルで後验分布から取得している(burn-in 100回を含む)ことが示されている。
BARTモデルのテストデータに対する予測値の平均を取り、MSEを計算する。
yhat.bart <- bartfit$yhat.test.mean
mean((ytest - yhat.bart)^2)
## [1] 15.94718
BARTのテストMSEは約15.95となり、今回試した中で最も低い誤差を記録した。すなわち、BARTモデルはランダムフォレストやブースティングよりも優れた予測精度を示したことになる。BARTは事前分布による正則化効果でモデルの複雑さを抑えながら多数の木を組み合わせており、このバランスが高い精度につながっていると考えられる。
さらに、BARTモデルでは各変数がアンサンブル内の木に使用された回数の平均値を指標として変数重要度を評価できる。bartfit$varcount.mean
に各変数の平均使用回数が格納されているので、多い順に並べて上位を確認する。
ord <- order(bartfit$varcount.mean, decreasing = TRUE)
round(bartfit$varcount.mean[ord], 2)
## nox lstat tax rad rm indus chas ptratio age zn
## 22.95 21.33 21.25 20.78 19.89 19.82 19.05 18.98 18.27 15.95
## dis crim
## 14.46 11.01
BARTにおける変数使用頻度の結果も、lstat
やrm
といった変数が他よりも高い値を示している(上表は各変数の平均使用回数を降順に示したもの)。この傾向は前節までのランダムフォレストやブースティングで得られた変数重要度とも概ね一致しており、住宅価格予測における主要因が一貫して現れていることがわかる。
Carseats
データの分類木モデルにおいて、ShelveLoc
以外の変数で重要度が高いものはどれか。変数重要度の定義に基づき議論せよ(必要に応じてランダムフォレストやブースティングで変数重要度を算出し比較すること)。ntree
をそれぞれ50,
100,
500,…と変化させ、OOB誤差やテスト誤差の推移をプロットせよ。十分な本数を超えると誤差推移が安定することを確かめよ。n.trees
(早期打ち切り木の本数)も合わせて検討し、学習率と汎化性能の関係について考察せよ。Carseats
データの二値分類(High
)に対しlbart()
関数を適用せよ。テストデータでの分類精度およびROC曲線による性能評価を行い、前節のブースティング結果と比較せよ。ISLR2
パッケージのHeart
データなど)に対し、本章で扱った決定木ベースの手法を適用せよ。適切な前処理(必要なら因子化や欠測値補完)を行った上で、各手法の予測精度を比較し、どの手法が最も優れているか考察せよ。