#广州大学公共管理学院社会学系创新班 机器学习入门
library(rpart)
## Warning: package 'rpart' was built under R version 4.0.3
library(titanic)
## Warning: package 'titanic' was built under R version 4.0.3
library(rpart.plot)
## Warning: package 'rpart.plot' was built under R version 4.0.3
# Inspect the training data. View() opens an interactive data viewer, so it
# is only called when a person is at the console; interactive() is FALSE when
# the script is knitted or run with Rscript, where a viewer makes no sense.
if (interactive()) {
  View(titanic_train)
}
names(titanic_train)
##  [1] "PassengerId" "Survived"    "Pclass"      "Name"        "Sex"        
##  [6] "Age"         "SibSp"       "Parch"       "Ticket"      "Fare"       
## [11] "Cabin"       "Embarked"
# Model 1: basic rpart tree ----
# `Survived` is numeric here (presumably 0/1), so rpart grows a regression
# ("anova") tree: the printout below reports a deviance per node, and `yval`
# is the node mean of Survived, i.e. the share of passengers who survived.
fit <- rpart(
  formula = Survived ~ Age + Sex + Pclass + Fare,
  data = titanic_train
)
print(fit) # text view of the fitted tree
## n= 891 
## 
## node), split, n, deviance, yval
##       * denotes terminal node
## 
##  1) root 891 210.727300 0.3838384  
##    2) Sex=male 577  88.409010 0.1889081  
##      4) Age>=6.5 553  77.359860 0.1681736  
##        8) Pclass>=1.5 433  44.226330 0.1154734 *
##        9) Pclass< 1.5 120  27.591670 0.3583333 *
##      5) Age< 6.5 24   5.333333 0.6666667 *
##    3) Sex=female 314  60.105100 0.7420382  
##      6) Pclass>=2.5 144  36.000000 0.5000000  
##       12) Fare>=23.35 27   2.666667 0.1111111 *
##       13) Fare< 23.35 117  28.307690 0.5897436 *
##      7) Pclass< 2.5 170   8.523529 0.9470588 *

# Base-graphics drawing of the tree; use.n = TRUE adds node counts to labels.
plot(x = fit)
text(x = fit, use.n = TRUE)

# Nicer drawing of the same tree with rpart.plot
rpart.plot(
  fit,
  main = "titanic survived\n(binary response)"
)

# Model 2: tree on rpart.plot's `ptitanic` data ----
# Look at the data before modelling it. View() opens a GUI viewer, so it is
# skipped in non-interactive sessions (knitting, Rscript).
if (interactive()) {
  View(ptitanic)
}
names(ptitanic)
## [1] "pclass"   "survived" "sex"      "age"      "sibsp"    "parch"

# Model `survived` on every other column. NOTE(review): `survived` is
# presumably a factor in ptitanic, which would make this a classification
# tree -- confirm with str(ptitanic).
# cp = .02 (above rpart's default of 0.01) stops splitting early so the demo
# tree stays small.
binary.model <- rpart(survived ~ ., data = ptitanic, cp = .02)
rpart.plot(binary.model,
           main = "titanic survived\n(binary response)")
# (The second rpart.plot(fit, ...) call that used to sit here redrew the exact
# figure already produced for model 1, so it was dropped.)
# Detailed report for model 1: the complexity-parameter (CP) table,
# variable importance, and the primary and surrogate splits at every node.
summary(object = fit)
## Call:
## rpart(formula = Survived ~ Age + Sex + Pclass + Fare, data = titanic_train)
##   n= 891 
## 
##           CP nsplit rel error    xerror       xstd
## 1 0.29523072      0 1.0000000 1.0025204 0.01606048
## 2 0.07394186      1 0.7047693 0.7081410 0.03325318
## 3 0.02712427      2 0.6308274 0.6351655 0.03170495
## 4 0.02629874      3 0.6037031 0.6381571 0.03273129
## 5 0.02384903      4 0.5774044 0.6350682 0.03290604
## 6 0.01000000      5 0.5535554 0.5777702 0.03231963
## 
## Variable importance
##    Sex   Fare Pclass    Age 
##     55     20     19      5 
## 
## Node number 1: 891 observations,    complexity param=0.2952307
##   mean=0.3838384, MSE=0.2365065 
##   left son=2 (577 obs) right son=3 (314 obs)
##   Primary splits:
##       Sex    splits as  RL,           improve=0.29523070, (0 missing)
##       Pclass < 2.5      to the right, improve=0.10388270, (0 missing)
##       Fare   < 10.48125 to the left,  improve=0.09002618, (0 missing)
##       Age    < 6.5      to the right, improve=0.02091370, (177 missing)
##   Surrogate splits:
##       Fare < 77.6229  to the left,  agree=0.679, adj=0.089, (0 split)
## 
## Node number 2: 577 observations,    complexity param=0.02712427
##   mean=0.1889081, MSE=0.1532219 
##   left son=4 (553 obs) right son=5 (24 obs)
##   Primary splits:
##       Age    < 6.5      to the right, improve=0.06101713, (124 missing)
##       Fare   < 26.26875 to the left,  improve=0.05778096, (0 missing)
##       Pclass < 1.5      to the right, improve=0.05666357, (0 missing)
## 
## Node number 3: 314 observations,    complexity param=0.07394186
##   mean=0.7420382, MSE=0.1914175 
##   left son=6 (144 obs) right son=7 (170 obs)
##   Primary splits:
##       Pclass < 2.5      to the right, improve=0.25923870, (0 missing)
##       Fare   < 48.2     to the left,  improve=0.08413772, (0 missing)
##       Age    < 12       to the left,  improve=0.01573647, (53 missing)
##   Surrogate splits:
##       Fare < 25.69795 to the left,  agree=0.799, adj=0.563, (0 split)
##       Age  < 18.5     to the left,  agree=0.564, adj=0.049, (0 split)
## 
## Node number 4: 553 observations,    complexity param=0.02629874
##   mean=0.1681736, MSE=0.1398912 
##   left son=8 (433 obs) right son=9 (120 obs)
##   Primary splits:
##       Pclass < 1.5      to the right, improve=0.071637420, (0 missing)
##       Fare   < 26.26875 to the left,  improve=0.068071850, (0 missing)
##       Age    < 24.75    to the left,  improve=0.007985322, (124 missing)
##   Surrogate splits:
##       Fare < 26.26875 to the left,  agree=0.911, adj=0.592, (0 split)
## 
## Node number 5: 24 observations
##   mean=0.6666667, MSE=0.2222222 
## 
## Node number 6: 144 observations,    complexity param=0.02384903
##   mean=0.5, MSE=0.25 
##   left son=12 (27 obs) right son=13 (117 obs)
##   Primary splits:
##       Fare < 23.35    to the right, improve=0.13960110, (0 missing)
##       Age  < 38.5     to the right, improve=0.05382171, (42 missing)
## 
## Node number 7: 170 observations
##   mean=0.9470588, MSE=0.05013841 
## 
## Node number 8: 433 observations
##   mean=0.1154734, MSE=0.1021393 
## 
## Node number 9: 120 observations
##   mean=0.3583333, MSE=0.2299306 
## 
## Node number 12: 27 observations
##   mean=0.1111111, MSE=0.09876543 
## 
## Node number 13: 117 observations
##   mean=0.5897436, MSE=0.2419461
# predict(binary.model)  # uncomment to see model 2's fitted values

library(vip) 
## Warning: package 'vip' was built under R version 4.0.3
## 
## Attaching package: 'vip'
## The following object is masked from 'package:utils':
## 
##     vi

vip(object = fit, num_features = 4) # variable-importance plot, top 4 predictors

library(caret) #机器学习算法集成包
## Warning: package 'caret' was built under R version 4.0.3
## Loading required package: lattice
## Loading required package: ggplot2
# Model 3: the same rpart learner, tuned through caret ----
# train() scores candidate cp values on random resamples of the data (see
# ?trainControl for the default scheme). Fix the RNG seed so the selected
# model, and the importance plot below, are reproducible from run to run.
set.seed(123)
tt <- train(
  survived ~ .,
  data = ptitanic,
  method = "rpart",
  # drop rows with any missing value (e.g. missing ages) before fitting
  na.action = na.omit
)

# Variable importance of the caret-tuned tree, top 4 predictors
vip(tt, num_features = 4)