# Titanic data (Japanese column names); drop rows with missing values
d <- na.omit(read.csv('https://stats.dip.jp/01_ds/data/titanic_data_jp.csv'))
n <- nrow(d)

library(DT)
datatable(d, options = list(pageLength = 5))
# Semi-transparent plotting colours: red, blue, green, grey
COL <- c(rgb(255,   0,   0, 105, maxColorValue = 255),
         rgb(  0,   0, 255, 105, maxColorValue = 255),
         rgb(  0, 155,   0, 105, maxColorValue = 255),
         rgb(100, 100, 100,  55, maxColorValue = 255))

library(rpart)
library(rpart.plot)
# Classification tree for survival (生死); cp = 0.04 keeps the tree small
tree <- rpart(生死 ~ 年齢 + 性別 + 客室等級 + 運賃 + 兄弟配偶者数 + 親子数 + 乗船地,
              data = d, method = 'class', cp = 0.04)
rpart.plot(tree, type = 5)

rpart.plot(tree, branch.type = 5)
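
# Quick sanity check (an added sketch, not in the original notes): compare the
# tree's fitted classes with the observed 生死 labels on the training data.
pred <- predict(tree, newdata = d, type = 'class')
table(observed = d$生死, predicted = pred)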

plotcp(tree)
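
# Optional addition: the cp table behind plotcp() can also be printed;
# xerror is the cross-validated relative error used to choose a pruning cp.
printcp(tree)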

# Prune aggressively: cp = 0.5 removes almost every split
tree2 <- prune(tree, cp = 0.5)

rpart.plot(tree2, branch.type = 5)

# Prune more gently with cp = 0.07
tree3 <- prune(tree, cp = 0.07)

rpart.plot(tree3, branch.type = 5)
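
# Optional comparison (not in the original notes): training accuracy of the
# original tree and the two pruned trees.
sapply(list(cp0.04 = tree, cp0.5 = tree2, cp0.07 = tree3),
       function(m) mean(predict(m, newdata = d, type = 'class') == d$生死))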

# A hypothetical new passenger to classify
d.new <- data.frame(年齢         = 5,
                    性別         = '女性', 
                    客室等級     = '3等', 
                    運賃         = 10,
                    兄弟配偶者数 = 5,
                    親子数       = 7,
                    乗船地       = 'S')

rpart.predict(tree, rules = TRUE, newdata = d.new)
##        死亡      生存                     
## 1 0.2452107 0.7547893 because 性別 is 女性
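
# The same prediction returned as a single class label, via predict() from
# rpart (a minimal alternative sketch).
predict(tree, newdata = d.new, type = 'class')
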
# Toy data: four clusters of four points each, labelled z = 1, 2, 3, 4
xs <- c(1, 2, 1, 2)
ys <- c(1, 1, 2, 2)

d0 <- data.frame(x = c(xs, xs,   xs + 5, xs + 5),
                 y = c(ys, ys + 5, ys + 3, ys + 9),
                 z = c(rep(1, 4), rep(2, 4), rep(3, 4), rep(4, 4)))

matplot(x = d0$x, y = d0$y, pch = 1, cex = 2, col = COL[1])
text(d0$x, d0$y, d0$z)

# Regression tree on the toy data; tiny minsplit/minbucket let every cluster be isolated
tree <- rpart(z ~ x + y, data = d0, minsplit = 2, minbucket = 1)
rpart.plot(tree)

# Redraw the points and overlay the split boundaries found by the tree
matplot(x = d0$x, y = d0$y, pch = 1, cex = 2, col = COL[1])
segments(4, 0, 4, 12, lty = 3, col = COL[2])
segments(4, 8, 8,  8, lty = 3, col = COL[2])
segments(0, 4, 4,  4, lty = 3, col = COL[2])
text(d0$x, d0$y, d0$z)

plotcp(tree)
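
# Optional visualization (not in the original code): predict z over a grid of
# (x, y) points to show the rectangular regions the tree has carved out.
# plot() is used instead of matplot() so col is recycled per point.
grid <- expand.grid(x = seq(0, 8, by = 0.25), y = seq(0, 12, by = 0.25))
grid$z <- predict(tree, newdata = grid)
plot(grid$x, grid$y, pch = 15, cex = 0.8, col = COL[round(grid$z)])
text(d0$x, d0$y, d0$z)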

# Regression tree: predict age (年齢) from the remaining variables
tree <- rpart(年齢 ~ 性別 + 客室等級 + 運賃 + 兄弟配偶者数 + 親子数 + 乗船地, data = d)

rpart.plot(tree, type = 5)

rpart.plot(tree, branch.type = 5)
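
# Rough fit check (an added sketch, not in the original notes): root-mean-squared
# error of the predicted ages on the training data.
age.hat <- predict(tree, newdata = d)
sqrt(mean((d$年齢 - age.hat)^2))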

# The same analysis in Python with scikit-learn
import pandas as pd
from sklearn import tree
from sklearn.tree import plot_tree

d = pd.read_csv('https://stats.dip.jp/01_ds/data/titanic_data.csv').dropna(axis = 'rows')
d
##      PassengerId  Survived  Pclass       Name  ... SibSp Parch     Fare  Embarked
## 0              1         0       3     Braund  ...     1     0   7.2500         S
## 1              2         1       1    Cumings  ...     1     0  71.2833         C
## 2              3         1       3  Heikkinen  ...     0     0   7.9250         S
## 3              4         1       1   Futrelle  ...     1     0  53.1000         S
## 4              5         0       3      Allen  ...     0     0   8.0500         S
## ..           ...       ...     ...        ...  ...   ...   ...      ...       ...
## 886          887         0       2   Montvila  ...     0     0  13.0000         S
## 887          888         1       1     Graham  ...     0     0  30.0000         S
## 888          889         0       3   Johnston  ...     1     2  23.4500         S
## 889          890         1       1       Behr  ...     0     0  30.0000         C
## 890          891         0       3     Dooley  ...     0     0   7.7500         Q
## 
## [889 rows x 10 columns]
d.columns
## Index(['PassengerId', 'Survived', 'Pclass', 'Name', 'Sex', 'Age', 'SibSp',
##        'Parch', 'Fare', 'Embarked'],
##       dtype='object')
y = d['Survived']
x = d[['Pclass', 'Sex', 'Age', 'SibSp', 'Parch', 'Fare', 'Embarked']]

# One-hot encode the non-numeric columns
x = pd.get_dummies(x)
y
## 0      0
## 1      1
## 2      1
## 3      1
## 4      0
##       ..
## 886    0
## 887    1
## 888    0
## 889    1
## 890    0
## Name: Survived, Length: 889, dtype: int64
x
##      Pclass  SibSp  ...  Embarked_       Q  Embarked_       S
## 0         3      1  ...              False               True
## 1         1      1  ...              False              False
## 2         3      0  ...              False               True
## 3         1      1  ...              False               True
## 4         3      0  ...              False               True
## ..      ...    ...  ...                ...                ...
## 886       2      0  ...              False               True
## 887       1      0  ...              False               True
## 888       3      1  ...              False               True
## 889       1      0  ...              False              False
## 890       3      0  ...               True              False
## 
## [889 rows x 98 columns]
# Depth-limited classification tree; random_state fixed for reproducibility
model = tree.DecisionTreeClassifier(max_depth = 3, random_state = 0)
model.fit(x, y)
## DecisionTreeClassifier(max_depth=3, random_state=0)

plot_tree(model, feature_names = x.columns.tolist(), class_names = ['died', 'survived'], filled = True)