#广州大学公共管理学院社会学系创新班 机器学习入门
library(rpart)
## Warning: package 'rpart' was built under R version 4.0.3
library(titanic)
## Warning: package 'titanic' was built under R version 4.0.3
library(rpart.plot)
## Warning: package 'rpart.plot' was built under R version 4.0.3
# Inspect the Titanic training data.
# View() opens a spreadsheet-style data viewer and therefore needs a GUI
# session; guard it so the script also runs under Rscript / knitting.
if (interactive()) {
  View(titanic_train)
}
# List the column names (recorded output follows).
names(titanic_train)
## [1] "PassengerId" "Survived" "Pclass" "Name" "Sex"
## [6] "Age" "SibSp" "Parch" "Ticket" "Fare"
## [11] "Cabin" "Embarked"
# Basic rpart tree: model Survived from Age, Sex, Pclass and Fare.
# Note: the printed tree below reports deviance/yval and summary(fit) reports
# MSE, i.e. this is a regression-style fit on the 0/1 Survived column.
# Fix the RNG seed first: the cross-validated error columns (xerror, xstd)
# shown by summary(fit) depend on random fold assignment, so without a seed
# they change from run to run (the recorded values were produced unseeded).
set.seed(123)
fit <- rpart(Survived ~ Age + Sex + Pclass + Fare, data = titanic_train)
fit # print the fitted tree
## n= 891
##
## node), split, n, deviance, yval
## * denotes terminal node
##
## 1) root 891 210.727300 0.3838384
## 2) Sex=male 577 88.409010 0.1889081
## 4) Age>=6.5 553 77.359860 0.1681736
## 8) Pclass>=1.5 433 44.226330 0.1154734 *
## 9) Pclass< 1.5 120 27.591670 0.3583333 *
## 5) Age< 6.5 24 5.333333 0.6666667 *
## 3) Sex=female 314 60.105100 0.7420382
## 6) Pclass>=2.5 144 36.000000 0.5000000
## 12) Fare>=23.35 27 2.666667 0.1111111 *
## 13) Fare< 23.35 117 28.307690 0.5897436 *
## 7) Pclass< 2.5 170 8.523529 0.9470588 *
# Base-graphics drawing of the tree; text() adds the node labels, and
# use.n = TRUE includes the observation counts in those labels.
plot(x = fit)
text(x = fit, use.n = TRUE)

# Same tree drawn with rpart.plot for a more readable figure.
rpart.plot(
  x = fit,
  main = "titanic survived\n(binary response)"
)

# Model 2: tree for `survived` on the ptitanic data, using all remaining
# columns as predictors.
# cp = .02 for small demo tree
binary.model <- rpart(survived ~ ., data = ptitanic, cp = .02)
# Inspect the data. View() needs a GUI session, so only call it interactively
# to keep the script runnable under Rscript / knitting.
if (interactive()) {
  View(ptitanic)
}
names(ptitanic)
## [1] "pclass" "survived" "sex" "age" "sibsp" "parch"
# Draw the fitted tree.
rpart.plot(
  binary.model,
  main = "titanic survived\n(binary response)"
)

# Redraw the first tree (`fit`), then print its detailed summary
# (CP table, variable importance and per-node split information).
rpart.plot(
  x = fit,
  main = "titanic survived\n(binary response)"
)
summary(object = fit)
## Call:
## rpart(formula = Survived ~ Age + Sex + Pclass + Fare, data = titanic_train)
## n= 891
##
## CP nsplit rel error xerror xstd
## 1 0.29523072 0 1.0000000 1.0025204 0.01606048
## 2 0.07394186 1 0.7047693 0.7081410 0.03325318
## 3 0.02712427 2 0.6308274 0.6351655 0.03170495
## 4 0.02629874 3 0.6037031 0.6381571 0.03273129
## 5 0.02384903 4 0.5774044 0.6350682 0.03290604
## 6 0.01000000 5 0.5535554 0.5777702 0.03231963
##
## Variable importance
## Sex Fare Pclass Age
## 55 20 19 5
##
## Node number 1: 891 observations, complexity param=0.2952307
## mean=0.3838384, MSE=0.2365065
## left son=2 (577 obs) right son=3 (314 obs)
## Primary splits:
## Sex splits as RL, improve=0.29523070, (0 missing)
## Pclass < 2.5 to the right, improve=0.10388270, (0 missing)
## Fare < 10.48125 to the left, improve=0.09002618, (0 missing)
## Age < 6.5 to the right, improve=0.02091370, (177 missing)
## Surrogate splits:
## Fare < 77.6229 to the left, agree=0.679, adj=0.089, (0 split)
##
## Node number 2: 577 observations, complexity param=0.02712427
## mean=0.1889081, MSE=0.1532219
## left son=4 (553 obs) right son=5 (24 obs)
## Primary splits:
## Age < 6.5 to the right, improve=0.06101713, (124 missing)
## Fare < 26.26875 to the left, improve=0.05778096, (0 missing)
## Pclass < 1.5 to the right, improve=0.05666357, (0 missing)
##
## Node number 3: 314 observations, complexity param=0.07394186
## mean=0.7420382, MSE=0.1914175
## left son=6 (144 obs) right son=7 (170 obs)
## Primary splits:
## Pclass < 2.5 to the right, improve=0.25923870, (0 missing)
## Fare < 48.2 to the left, improve=0.08413772, (0 missing)
## Age < 12 to the left, improve=0.01573647, (53 missing)
## Surrogate splits:
## Fare < 25.69795 to the left, agree=0.799, adj=0.563, (0 split)
## Age < 18.5 to the left, agree=0.564, adj=0.049, (0 split)
##
## Node number 4: 553 observations, complexity param=0.02629874
## mean=0.1681736, MSE=0.1398912
## left son=8 (433 obs) right son=9 (120 obs)
## Primary splits:
## Pclass < 1.5 to the right, improve=0.071637420, (0 missing)
## Fare < 26.26875 to the left, improve=0.068071850, (0 missing)
## Age < 24.75 to the left, improve=0.007985322, (124 missing)
## Surrogate splits:
## Fare < 26.26875 to the left, agree=0.911, adj=0.592, (0 split)
##
## Node number 5: 24 observations
## mean=0.6666667, MSE=0.2222222
##
## Node number 6: 144 observations, complexity param=0.02384903
## mean=0.5, MSE=0.25
## left son=12 (27 obs) right son=13 (117 obs)
## Primary splits:
## Fare < 23.35 to the right, improve=0.13960110, (0 missing)
## Age < 38.5 to the right, improve=0.05382171, (42 missing)
##
## Node number 7: 170 observations
## mean=0.9470588, MSE=0.05013841
##
## Node number 8: 433 observations
## mean=0.1154734, MSE=0.1021393
##
## Node number 9: 120 observations
## mean=0.3583333, MSE=0.2299306
##
## Node number 12: 27 observations
## mean=0.1111111, MSE=0.09876543
##
## Node number 13: 117 observations
## mean=0.5897436, MSE=0.2419461
#predict(binary.model)
library(vip)
## Warning: package 'vip' was built under R version 4.0.3
##
## Attaching package: 'vip'
## The following object is masked from 'package:utils':
##
## vi

# Variable-importance plot for `fit`, limited to 4 features.
vip(object = fit, num_features = 4)

library(caret) # unified interface to many machine-learning algorithms
## Warning: package 'caret' was built under R version 4.0.3
## Loading required package: lattice
## Loading required package: ggplot2
# train() picks the rpart tuning parameter by resampling, which draws random
# samples; fix the seed so the selected model and the importance plot below
# are repeatable across runs.
set.seed(123)
# Fit `survived` on all other ptitanic columns; rows with missing values are
# dropped via na.action = na.omit.
tt <- train(
  survived ~ .,
  data = ptitanic,
  method = "rpart",
  na.action = na.omit
)
# Variable-importance plot for the caret-tuned model, limited to 4 features.
vip(tt, num_features = 4)
