if (!require(rpart)) install.packages("rpart")
if (!require(rpart.plot)) install.packages("rpart.plot")
if (!require(DT)) install.packages("DT")
if (!require(kableExtra)) install.packages("kableExtra")
タイタニック号の乗客データ 【kaggle】Titanic: cleaned data(一部竹田改変)
説明変数 | 内容 |
---|---|
乗客番号 | |
生死 | {生存,死亡} |
客室等級 | {1, 2, 3}等 |
乗客名 | |
性別 | {男性,女性} |
年齢 | |
兄弟配偶者数 | 同乗している兄弟・配偶者の数 |
親子数 | 同乗している親・子供の数 |
運賃 | |
乗船地 | {C:Cherbourg,Q:Queenstown,S:Southampton} |
d <- na.omit(read.csv('https://stats.dip.jp/01_ds/data/titanic_data_jp.csv'))
n <- nrow(d)
datatable(d, options = list(pageLength = 5))
# カラーパレット
COL <- c(rgb(255, 0, 0, 105, max = 255), # 赤
rgb( 0, 0, 255, 105, max = 255), # 青
rgb( 0, 155, 0, 105, max = 255), # 緑
rgb(100, 100, 100, 55, max = 255)) # 灰
tree <- rpart(生死 ~ 年齢 + 性別 + 客室等級 + 運賃 + 兄弟配偶者数 + 親子数 + 乗船地,
data = d, method = 'class', cp = 0.02)
rpart.plot(tree, type = 5)
各ノードでは,上から順に予測されたクラス{生 or 死},生存確率, ノードに含まれるデータサイズが全体に対する割合[%]が表示されている。
1,2等客室の女性や幼い子どもが優先的に救命ボートに乗せられた。
r <- rpart.rules(tree, cover = T)
kable(r)
生死 | cover | |||||||||||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
10 | 0.11 | when | 性別 | is | 男性 | & | 年齢 | < | 6.5 | & | 兄弟配偶者数 | >= | 3 | 1% | ||||||||
12 | 0.13 | when | 性別 | is | 女性 | & | 客室等級 | is | 3等 | & | 運賃 | >= | 21 | 3% | ||||||||
4 | 0.18 | when | 性別 | is | 男性 | & | 年齢 | >= | 6.5 | 60% | ||||||||||||
13 | 0.56 | when | 性別 | is | 女性 | & | 客室等級 | is | 3等 | & | 運賃 | < | 21 | 11% | ||||||||
7 | 0.94 | when | 性別 | is | 女性 | & | 客室等級 | is | 1等 or 2等 | 22% | ||||||||||||
11 | 1.00 | when | 性別 | is | 男性 | & | 年齢 | < | 6.5 | & | 兄弟配偶者数 | < | 3 | 2% |
様々な種類の樹形図を描くことができるので, 詳細はrpart.plot を参照のこと。
「branch.type = 5」にすると%を幹の太さ(灰色)で表現してくれる。
rpart.plot(tree, branch.type = 5)
CP(複雑度パラメータ)を与えて枝を剪定(せんてい)することができる。 CPの値を小さくすると分岐が多くなる。 CP図(plotcp)はCPが大きい方から順に並べてあり,CPごとの相対誤差を示す。 枝数が少ない方がルールとしては分かりやすい。 また,不必要に枝数が多いと過学習になる。 最適なCPは,CP図内の横点線下側に最も近い分岐数に対応するものを選ぶとよい。
plotcp(tree)
plotcp(tree)
tree2 <- prune(tree, cp = 0.1)
rpart.plot(tree2, branch.type = 5)
tree3 <- prune(tree, cp = 0.02)
rpart.plot(tree3, branch.type = 5)
データを与えると死亡率(または生存率)が出力される。
d.new <- data.frame(年齢 = 5,
性別 = '女性',
客室等級 = '3等',
運賃 = 10,
兄弟配偶者数 = 5,
親子数 = 7,
乗船地 = 'S')
rpart.predict(tree, newdata = d.new)
## 死亡 生存
## 1 0.443038 0.556962
xs <- c(1, 2, 1, 2)
ys <- c(1, 1, 2, 2)
d0 <- data.frame(x = c(xs, xs, xs+5, xs+5),
y = c(ys, ys+5, ys+3, ys+9),
z = c(rep(1, 4), rep(2, 4), rep(3, 4), rep(4, 4)))
kable(d0)
x | y | z |
---|---|---|
1 | 1 | 1 |
2 | 1 | 1 |
1 | 2 | 1 |
2 | 2 | 1 |
1 | 6 | 2 |
2 | 6 | 2 |
1 | 7 | 2 |
2 | 7 | 2 |
6 | 4 | 3 |
7 | 4 | 3 |
6 | 5 | 3 |
7 | 5 | 3 |
6 | 10 | 4 |
7 | 10 | 4 |
6 | 11 | 4 |
7 | 11 | 4 |
matplot(x = d0$x, y = d0$y, pch = 1, cex = 2, col = COL[1])
回帰木は,オプション「method = “anova”」にする。
tree <- rpart(z ~ x + y, data = d0, method = "anova", minsplit = 2, minbucket = 1, model = T)
rpart.plot(tree)
z=1,2,3,4の平均は2.5(分岐前),x<4だとz=1, 2の平均1.5, x>= 4だとz=3, 4の平均3.5になる。 ノードに含まれるデータサイズが全体に対する割合[%]で表示されている。
r <- rpart.rules(tree, cover = T)
kable(r)
z | cover | |||||||||
---|---|---|---|---|---|---|---|---|---|---|
4 | 1 | when | x | < | 4 | & | y | < | 4 | 25% |
5 | 2 | when | x | < | 4 | & | y | >= | 4 | 25% |
6 | 3 | when | x | >= | 4 | & | y | < | 8 | 25% |
7 | 4 | when | x | >= | 4 | & | y | >= | 8 | 25% |
d.new <- data.frame(x = c(2.5, 4.5, 3.5),
y = c(8.0, 5.4, 3.0))
zhat <- rpart.predict(tree, newdata = d.new)
zhat
## 1 2 3
## 2 3 1
matplot(x = d0$x, y = d0$y, pch = 1, cex = 2, col = COL[1])
text(tree$model$x, tree$model$y, tree$model$z) # データ分類番号表示
text(d.new$x, d.new$y, zhat, col = "blue") # 予測分類番号表示
# 分岐ルールを青点線で示す。
segments(4, 0, 4,12, lty = 3, col = COL[2])
segments(4, 8, 8, 8, lty = 3, col = COL[2])
segments(0, 4, 4, 4, lty = 3, col = COL[2])
plotcp(tree)
非推奨
library(plotly)
library(ggplot2)
library(cluster)
library(ggdendro)
data <- dendro_data(tree)
p <- ggplot() +
geom_segment(data = data$segments,
aes(x = x, y = y, xend = xend, yend = yend)) +
geom_text(data = data$labels,
aes(x = x, y = y, label = label), size = 5, vjust = 0) +
geom_text(data = data$leaf_labels,
aes(x = x, y = y, label = label), size = 5, vjust = 1) +
theme_dendro()
ggplotly(p)
非推奨(日本語設定が必要)
import pandas as pd
from sklearn import tree
from sklearn.tree import plot_tree
d = pd.read_csv('https://stats.dip.jp/01_ds/data/titanic_data.csv').dropna(axis = 'rows')
d
## PassengerId Survived Pclass Name ... SibSp Parch Fare Embarked
## 0 1 0 3 Braund ... 1 0 7.2500 S
## 1 2 1 1 Cumings ... 1 0 71.2833 C
## 2 3 1 3 Heikkinen ... 0 0 7.9250 S
## 3 4 1 1 Futrelle ... 1 0 53.1000 S
## 4 5 0 3 Allen ... 0 0 8.0500 S
## .. ... ... ... ... ... ... ... ... ...
## 886 887 0 2 Montvila ... 0 0 13.0000 S
## 887 888 1 1 Graham ... 0 0 30.0000 S
## 888 889 0 3 Johnston ... 1 2 23.4500 S
## 889 890 1 1 Behr ... 0 0 30.0000 C
## 890 891 0 3 Dooley ... 0 0 7.7500 Q
##
## [889 rows x 10 columns]
y = d['Survived']
x = d[['Pclass', 'Sex', 'Age', 'SibSp', 'Parch', 'Fare', 'Embarked']]
x = pd.get_dummies(x)
model = tree.DecisionTreeClassifier(max_depth = 3, random_state = 0)
model.fit(x, y)
DecisionTreeClassifier(max_depth=3, random_state=0)In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook.
DecisionTreeClassifier(max_depth=3, random_state=0)
# 日本語は豆腐になるので,
# export.py でデフォルトのフォントから日本語フォントへの変更が必要。fontname='helvetica'を変更
# import sklearn.tree._export
# print(sklearn.tree._export.__file__)
plot_tree(model, feature_names = x.columns.tolist(), class_names = ['male', 'female'], filled = True)