if (!require(rpart))      install.packages("rpart")
if (!require(rpart.plot)) install.packages("rpart.plot")
if (!require(DT))         install.packages("DT")
if (!require(kableExtra)) install.packages("kableExtra")

1 データ

タイタニック号の乗客データ 【kaggle】Titanic: cleaned data(一部竹田改変)

説明変数 内容
乗客番号
生死 {生存,死亡}
客室等級 {1, 2, 3}等
乗客名
性別 {男性,女性}
年齢
兄弟配偶者数 同乗している兄弟・配偶者の数
親子数 同乗している親・子供の数
運賃
乗船地 {C:Cherbourg,Q:Queenstown,S:Southampton}
d <- na.omit(read.csv('https://stats.dip.jp/01_ds/data/titanic_data_jp.csv'))
n <- nrow(d)

datatable(d, options = list(pageLength = 5))

1.1 グラフ

# カラーパレット
COL <- c(rgb(255,   0,   0,  105, max = 255), # 赤
         rgb(  0,   0, 255,  105, max = 255), # 青
         rgb(  0, 155,   0,  105, max = 255), # 緑
         rgb(100, 100, 100,   55, max = 255)) # 灰

RGB_Color

2 分類木

tree <- rpart(生死 ~ 年齢 + 性別 + 客室等級 + 運賃 + 兄弟配偶者数 + 親子数 + 乗船地,
              data = d, method = 'class', cp = 0.02)
rpart.plot(tree, type = 5)

各ノードでは,上から順に予測されたクラス{生 or 死},生存確率, ノードに含まれるデータサイズが全体に対する割合[%]が表示されている。

1,2等客室の女性や幼い子どもが優先的に救命ボートに乗せられた。

r <- rpart.rules(tree, cover = T)
kable(r)
生死 cover
10 0.11 when 性別 is 男性 & 年齢 < 6.5 & 兄弟配偶者数 >= 3 1%
12 0.13 when 性別 is 女性 & 客室等級 is 3等 & 運賃 >= 21 3%
4 0.18 when 性別 is 男性 & 年齢 >= 6.5 60%
13 0.56 when 性別 is 女性 & 客室等級 is 3等 & 運賃 < 21 11%
7 0.94 when 性別 is 女性 & 客室等級 is 1等 or 2等 22%
11 1.00 when 性別 is 男性 & 年齢 < 6.5 & 兄弟配偶者数 < 3 2%

様々な種類の樹形図を描くことができるので, 詳細はrpart.plot を参照のこと。

「branch.type = 5」にすると%を幹の太さ(灰色)で表現してくれる。

rpart.plot(tree, branch.type = 5)

2.1 剪定

CP(複雑度パラメータ)を与えて枝を剪定(せんてい)することができる。 CPの値を小さくすると分岐が多くなる。 CP図(plotcp)はCPが大きい方から順に並べてあり,CPごとの相対誤差を示す。 枝数が少ない方がルールとしては分かりやすい。 また,不必要に枝数が多いと過学習になる。 最適なCPは,CP図内の横点線下側に最も近い分岐数に対応するものを選ぶとよい。

2.1.1 CP図

plotcp(tree)

plotcp(tree)

tree2 <- prune(tree, cp = 0.1)

rpart.plot(tree2, branch.type = 5)

tree3 <- prune(tree, cp = 0.02)

rpart.plot(tree3, branch.type = 5)

2.2 予測

データを与えると死亡率(または生存率)が出力される。

d.new <- data.frame(年齢         = 5,
                    性別         = '女性', 
                    客室等級     = '3等', 
                    運賃         = 10,
                    兄弟配偶者数 = 5,
                    親子数       = 7,
                    乗船地       = 'S')

rpart.predict(tree, newdata = d.new)
##       死亡     生存
## 1 0.443038 0.556962

3 回帰木

3.1 テストデータ

xs <- c(1, 2, 1, 2)
ys <- c(1, 1, 2, 2)

d0 <- data.frame(x = c(xs, xs,   xs+5, xs+5),
                 y = c(ys, ys+5, ys+3, ys+9),
                 z = c(rep(1, 4), rep(2, 4), rep(3, 4), rep(4, 4)))
kable(d0)
x y z
1 1 1
2 1 1
1 2 1
2 2 1
1 6 2
2 6 2
1 7 2
2 7 2
6 4 3
7 4 3
6 5 3
7 5 3
6 10 4
7 10 4
6 11 4
7 11 4
matplot(x = d0$x, y = d0$y, pch = 1, cex = 2, col = COL[1])

3.2 回帰木分析

回帰木は,オプション「method = “anova”」にする。

tree <- rpart(z ~ x + y, data = d0, method = "anova", minsplit = 2, minbucket = 1, model = T)

3.3 樹形図

rpart.plot(tree)

z=1,2,3,4の平均は2.5(分岐前),x<4だとz=1, 2の平均1.5, x>= 4だとz=3, 4の平均3.5になる。 ノードに含まれるデータサイズが全体に対する割合[%]で表示されている。

r <- rpart.rules(tree, cover = T)
kable(r)
z cover
4 1 when x < 4 & y < 4 25%
5 2 when x < 4 & y >= 4 25%
6 3 when x >= 4 & y < 8 25%
7 4 when x >= 4 & y >= 8 25%

3.4 予測

d.new <- data.frame(x = c(2.5, 4.5, 3.5),
                    y = c(8.0, 5.4, 3.0))

zhat <- rpart.predict(tree, newdata = d.new)
zhat
## 1 2 3 
## 2 3 1

3.5 樹形図(分類後)

matplot(x = d0$x, y = d0$y, pch = 1, cex = 2, col = COL[1])
text(tree$model$x, tree$model$y, tree$model$z) # データ分類番号表示
text(d.new$x, d.new$y, zhat, col = "blue")     # 予測分類番号表示

# 分岐ルールを青点線で示す。
segments(4, 0, 4,12, lty = 3, col = COL[2])
segments(4, 8, 8, 8, lty = 3, col = COL[2])
segments(0, 4, 4, 4, lty = 3, col = COL[2])

3.6 CP図

plotcp(tree)

3.7 インタラクティブグラフ

非推奨

library(plotly)
library(ggplot2)
library(cluster)
library(ggdendro)
data <- dendro_data(tree)
p <- ggplot() + 
      geom_segment(data = data$segments, 
                   aes(x = x, y = y, xend = xend, yend = yend)) + 
      geom_text(data = data$labels, 
                aes(x = x, y = y, label = label), size = 5, vjust = 0) +
      geom_text(data = data$leaf_labels, 
                aes(x = x, y = y, label = label), size = 5, vjust = 1) +
      theme_dendro()

ggplotly(p)

4 Python

非推奨(日本語設定が必要)

import pandas as pd
from sklearn import tree
from sklearn.tree import plot_tree

d = pd.read_csv('https://stats.dip.jp/01_ds/data/titanic_data.csv').dropna(axis = 'rows')
d
##      PassengerId  Survived  Pclass       Name  ... SibSp Parch     Fare  Embarked
## 0              1         0       3     Braund  ...     1     0   7.2500         S
## 1              2         1       1    Cumings  ...     1     0  71.2833         C
## 2              3         1       3  Heikkinen  ...     0     0   7.9250         S
## 3              4         1       1   Futrelle  ...     1     0  53.1000         S
## 4              5         0       3      Allen  ...     0     0   8.0500         S
## ..           ...       ...     ...        ...  ...   ...   ...      ...       ...
## 886          887         0       2   Montvila  ...     0     0  13.0000         S
## 887          888         1       1     Graham  ...     0     0  30.0000         S
## 888          889         0       3   Johnston  ...     1     2  23.4500         S
## 889          890         1       1       Behr  ...     0     0  30.0000         C
## 890          891         0       3     Dooley  ...     0     0   7.7500         Q
## 
## [889 rows x 10 columns]
y = d['Survived']
x = d[['Pclass', 'Sex', 'Age', 'SibSp', 'Parch', 'Fare', 'Embarked']]

x = pd.get_dummies(x)

model = tree.DecisionTreeClassifier(max_depth = 3, random_state = 0)
model.fit(x, y)
DecisionTreeClassifier(max_depth=3, random_state=0)
In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook.
On GitHub, the HTML representation is unable to render, please try loading this page with nbviewer.org.

# 日本語は豆腐になるので,
# export.py でデフォルトのフォントから日本語フォントへの変更が必要。fontname='helvetica'を変更
# import sklearn.tree._export
# print(sklearn.tree._export.__file__)

plot_tree(model, feature_names = x.columns.tolist(), class_names = ['male', 'female'], filled = True)