# Read the Japanese-labelled Titanic data, keep only complete rows, record the
# row count, and show an interactive table five rows per page.
d <- read.csv("https://stats.dip.jp/01_ds/data/titanic_data_jp.csv")
d <- na.omit(d)
n <- NROW(d)
library(DT)
datatable(d, options = list(pageLength = 5))
# Semi-transparent plotting colours on a 0-255 scale:
# red, blue, green (155) and a fainter grey (alpha 55).
# rgb() is vectorized, so one call builds all four; `maxColorValue` is spelled
# out instead of relying on partial matching of `max =`.
COL <- rgb(red   = c(255,   0,   0, 100),
           green = c(  0,   0, 155, 100),
           blue  = c(  0, 255,   0, 100),
           alpha = c(105, 105, 105,  55),
           maxColorValue = 255)
library(rpart)
library(rpart.plot)

# Classification tree for survival (生死) from age, sex, cabin class, fare,
# siblings/spouses, parents/children and port of embarkation, grown with
# complexity parameter cp = 0.04.
tree <- rpart(
  formula = 生死 ~ 年齢 + 性別 + 客室等級 + 運賃 + 兄弟配偶者数 + 親子数 + 乗船地,
  data    = d,
  method  = "class",
  cp      = 0.04
)
# The same tree drawn in two display variants: type = 5, then branch.type = 5.
rpart.plot(tree, type = 5)

rpart.plot(tree, branch.type = 5)

# Complexity-parameter plot for the survival tree (previously drawn twice by a
# duplicated call; once is enough).
plotcp(tree)

# Prune the survival tree at two complexity thresholds, then draw both.
# NOTE(review): cp = 0.5 is far above the cp = 0.04 used to grow the tree and
# may leave little beyond the root — confirm that is the intended contrast.
tree2 <- prune(tree, cp = 0.5)
tree3 <- prune(tree, cp = 0.07)
invisible(lapply(list(tree2, tree3), rpart.plot, branch.type = 5))

# Predict survival for one hypothetical passenger: a 5-year-old female in
# third class (3等), fare 10, 5 siblings/spouses and 7 parents/children
# aboard, embarked at S. (newdata could equally be passed inline.)
d.new <- data.frame(年齢 = 5,
                    性別 = '女性',
                    客室等級 = '3等',
                    運賃 = 10,
                    兄弟配偶者数 = 5,
                    親子数 = 7,
                    乗船地 = 'S')
# rules = TRUE (not the reassignable `T`) appends the rule behind each
# prediction, as the recorded output shows ("because 性別 is 女性").
rpart.predict(tree, rules = TRUE, newdata = d.new)
## 死亡 生存
## 1 0.2452107 0.7547893 because 性別 is 女性
# Toy 2-D data: four clusters of four points each, labelled z = 1..4.
# Each cluster is the unit square pattern (xs, ys) shifted by an offset.
xs <- rep(c(1, 2), times = 2)
ys <- rep(c(1, 2), each = 2)
d0 <- data.frame(x = rep(c(0, 0, 5, 5), each = 4) + xs,
                 y = rep(c(0, 5, 3, 9), each = 4) + ys,
                 z = rep(c(1, 2, 3, 4), each = 4))
# Scatter the toy points (open circles in the first palette colour) and
# label each one with its class z.
matplot(x = d0$x, y = d0$y, pch = 1, cex = 2, col = COL[1])
text(d0$x, d0$y, labels = d0$z)

# Fit a tree predicting z from the coordinates; minsplit = 2 and
# minbucket = 1 allow splitting down to very small nodes.
tree <- rpart(z ~ x + y,
              data = d0,
              minsplit = 2,
              minbucket = 1)
rpart.plot(tree)

# Redraw the toy points with the fitted tree's partition overlaid as dotted
# lines. rpart places each cut point midway between adjacent observed values:
#   x: 2 | 6           -> x = 4
#   left  half, y: 2 | 6  -> y = 4
#   right half, y: 5 | 10 -> y = 7.5
matplot(x = d0$x, y = d0$y, pch = 1, cex = 2, col = COL[1])
segments(4, 0,   4, 12,  lty = 3, col = COL[2])  # vertical split at x = 4
segments(4, 7.5, 8, 7.5, lty = 3, col = COL[2])  # right half: y = 7.5
segments(0, 4,   4, 4,   lty = 3, col = COL[2])  # left half: y = 4
text(d0$x, d0$y, d0$z)

# Complexity-parameter plot for the toy tree.
plotcp(tree)

# Tree predicting age (年齢) from the other passenger attributes, drawn in
# the same two display variants as the survival tree.
tree <- rpart(formula = 年齢 ~ 性別 + 客室等級 + 運賃 + 兄弟配偶者数 + 親子数 + 乗船地,
              data = d)
rpart.plot(tree, type = 5)

rpart.plot(tree, branch.type = 5)

import pandas as pd
from sklearn import tree
from sklearn.tree import plot_tree
# Load the English-labelled Titanic data and drop every row containing NaN.
d = pd.read_csv('https://stats.dip.jp/01_ds/data/titanic_data.csv')
d = d.dropna(axis = 0)
d
## PassengerId Survived Pclass Name ... SibSp Parch Fare Embarked
## 0 1 0 3 Braund ... 1 0 7.2500 S
## 1 2 1 1 Cumings ... 1 0 71.2833 C
## 2 3 1 3 Heikkinen ... 0 0 7.9250 S
## 3 4 1 1 Futrelle ... 1 0 53.1000 S
## 4 5 0 3 Allen ... 0 0 8.0500 S
## .. ... ... ... ... ... ... ... ... ...
## 886 887 0 2 Montvila ... 0 0 13.0000 S
## 887 888 1 1 Graham ... 0 0 30.0000 S
## 888 889 0 3 Johnston ... 1 2 23.4500 S
## 889 890 1 1 Behr ... 0 0 30.0000 C
## 890 891 0 3 Dooley ... 0 0 7.7500 Q
##
## [889 rows x 10 columns]
d.columns
## Index(['PassengerId', 'Survived', 'Pclass', 'Name', 'Sex', 'Age', 'SibSp',
## 'Parch', 'Fare', 'Embarked'],
## dtype='object')
# Age is not parsed as a number in this CSV: the earlier printout of x had
# 98 columns and no numeric Age column after Pclass, i.e. get_dummies()
# one-hot encoded every distinct age, and dropna() removed only 2 rows.
# Coerce Age to numeric (unparseable entries -> NaN) and drop rows without
# an age, mirroring na.omit() on the R side.
d = d.assign(Age = pd.to_numeric(d['Age'], errors = 'coerce')).dropna(subset = ['Age'])
y = d['Survived']
x = d[['Pclass', 'Sex', 'Age', 'SibSp', 'Parch', 'Fare', 'Embarked']]
# One-hot encode the remaining text columns (Sex, Embarked).
x = pd.get_dummies(x)
y
## (Printed output omitted: it predated the Age fix above and showed
##  Length: 889. Re-run the chunk to refresh it.)
x
## (Printed output omitted: before the Age fix it showed
##  [889 rows x 98 columns], with Age expanded into one dummy column per
##  distinct value. Re-run the chunk to refresh it.)
# Depth-3 classification tree for Survived (0/1) from the encoded features.
model = tree.DecisionTreeClassifier(max_depth = 3, random_state = 0)
model.fit(x, y)
## DecisionTreeClassifier(max_depth=3, random_state=0)
## In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook.
## On GitHub, the HTML representation is unable to render, please try loading this page with nbviewer.org.
# class_names are matched to model.classes_ in ascending order (0, 1), so they
# must name the Survived outcomes, not the passengers' sexes.
plot_tree(model, feature_names = x.columns.tolist(), class_names = ['died', 'survived'], filled = True)
