Regression and Classification Trees

Dimitrios Zacharatos

This shows the output of PART RPART functions from workingfunctions.
Installation instructions for workingfunctions can be found here. These functions are not installed by default; they are located in /workingfunctions/working_functions/OTHER/ML_TREE_RPART.R

Rpart Classification

# Prune the fitted classification tree at the complexity parameter (CP) whose
# cross-validated error (xerror) is the smallest in the cptable.
result <- data.frame(rtree_classification$cptable)
result$nsplit <- factor(result$nsplit + 1)
# Tree size (nsplit + 1) of the cptable row with the lowest xerror.
minimun_size <- as.numeric(as.character(result$nsplit[which.min(result$xerror)]))
# Keep a handle on the unpruned tree before `model` is created.
initial_model <- rtree_classification
model <- rpart::prune(
  initial_model,
  cp = initial_model$cptable[which.min(initial_model$cptable[, "xerror"]), "CP"]
)

# Variable-importance bar chart for the pruned tree.
# `model$variable.importance` is a named numeric vector: the names become the
# bar labels and the values the bar lengths.
importance <- model$variable.importance
importance <- data.frame(names = names(importance), importance = importance)
# Reverse the factor levels so that, after coord_flip(), the predictors appear
# top-to-bottom in the order they are stored in `model$variable.importance`.
importance$names <- factor(
  importance$names,
  levels = rev(as.character(importance$names))
)
# The factor levels set above fully determine the bar order, so no explicit
# scale_x_discrete(limits = ...) is used. A `levels(names)` call outside aes()
# would look up the base function `names()` (levels are NULL) rather than the
# `names` column, and therefore would not affect the axis.
plot_importance <- ggplot(importance, aes(x = names, y = importance)) +
  geom_bar(stat = "identity") +
  labs(title = "Importance Plot", y = "Relative Influence", x = "Predictor") +
  theme_bw(base_size = 10) +
  coord_flip()
plot_importance

# Auto-print the unpruned classification tree (node, split, n, loss, yval, yprob).
rtree_classification
## n= 223 
## 
## node), split, n, loss, yval, (yprob)
##       * denotes terminal node
## 
##   1) root 223 74 0 (0.6681614 0.3318386)  
##     2) spontaneous< 0.5 124 24 0 (0.8064516 0.1935484) *
##     3) spontaneous>=0.5 99 49 1 (0.4949495 0.5050505)  
##       6) parity>=3.5 20  5 0 (0.7500000 0.2500000) *
##       7) parity< 3.5 79 34 1 (0.4303797 0.5696203)  
##        14) age< 30.5 46 21 0 (0.5434783 0.4565217)  
##          28) age>=28.5 9  2 0 (0.7777778 0.2222222) *
##          29) age< 28.5 37 18 1 (0.4864865 0.5135135)  
##            58) spontaneous< 1.5 27 12 0 (0.5555556 0.4444444)  
##             116) parity>=1.5 12  3 0 (0.7500000 0.2500000) *
##             117) parity< 1.5 15  6 1 (0.4000000 0.6000000) *
##            59) spontaneous>=1.5 10  3 1 (0.3000000 0.7000000) *
##        15) age>=30.5 33  9 1 (0.2727273 0.7272727) *
# Plot the complexity-parameter table of the pruned model.
rpart::plotcp(model)

# Print the cp table and draw the R-square style plots for the pruned model.
# NOTE(review): on this classification tree the call emits the warning
# "may not be applicable for this method" (see the output below).
rpart::rsq.rpart(model)
## 
## Classification tree:
## rpart::rpart(formula = infert_formula, data = train_test_classification$f$train$f1, 
##     model = TRUE, x = TRUE, y = TRUE)
## 
## Variables actually used in tree construction:
## [1] age         parity      spontaneous
## 
## Root node error: 74/223 = 0.33184
## 
## n= 223 
## 
##         CP nsplit rel error  xerror     xstd
## 1 0.074324      0   1.00000 1.00000 0.095022
## 2 0.054054      2   0.85135 0.91892 0.092904
## 3 0.031532      3   0.79730 0.89189 0.092117
## 4 0.010000      6   0.70270 0.87838 0.091707
## Warning in rpart::rsq.rpart(model): may not be applicable for this method

# Error profile of the pruned tree: every cptable column is drawn as its own
# line against the tree size (nsplit + 1, shown on the axis as "Size of Tree").
error <- data.frame(model$cptable)
error$nsplit <- factor(error$nsplit + 1)
# Size with the smallest cross-validated error; reported in the plot title.
tree_size <- error$nsplit[which.min(error$xerror)]
# Long format: one row per (size, metric) pair.
error <- reshape2::melt(error, id.vars = "nsplit")
colnames(error) <- c("Split", "Metric", "value")
plot_prune <- ggplot(error, aes(x = Split, y = value, color = Metric)) +
  geom_line(aes(group = Metric)) +
  geom_point() +
  labs(
    title = paste("Error Plot", "Suggested Size:", tree_size),
    y = "Metric value",
    x = "Size of Tree"
  ) +
  theme_bw(base_size = 10)
plot_prune

# Draw the pruned classification tree with rpart.plot (layout type 1).
rpart.plot::rpart.plot(model,type=1)

# Unpack the components of the pruned rpart object into data frames and
# bundle them into a single named list for reporting.
frame <- data.frame(model$frame)
cp <- data.frame(model$cptable)
parameters <- data.frame(parameters = unlist(model$control))
# Split names are the row names of the splits matrix; expose them as a column.
splits <- data.frame(
  name = rownames(model$splits),
  model$splits,
  row.names = NULL
)
importance <- data.frame(importance = model$variable.importance)
ordered <- data.frame(ordered = model$ordered)
# Combines model$y, model$x and model$model; note that `data` is not added to
# `result` below.
data <- data.frame(y = model$y, x = model$x, model = model$model)
call <- data.frame(call = call_to_string(model))
result <- list(
  frame = frame,
  cp = cp,
  parameters = parameters,
  splits = splits,
  importance = importance,
  ordered = ordered,
  call = call
)

# Print every component of the report list (frame, cp, parameters, splits,
# importance, ordered, call).
print(result)
## $frame
##             var   n  wt dev yval  complexity ncompete nsurrogate     yval2.V1     yval2.V2     yval2.V3     yval2.V4     yval2.V5 yval2.nodeprob
## 1   spontaneous 223 223  74    1 0.074324324        4          4   1.00000000 149.00000000  74.00000000   0.66816143   0.33183857     1.00000000
## 2         124 124  24    1 0.000000000        0          0   1.00000000 100.00000000  24.00000000   0.80645161   0.19354839     0.55605381
## 3        parity  99  99  49    2 0.074324324        4          2   2.00000000  49.00000000  50.00000000   0.49494949   0.50505051     0.44394619
## 6          20  20   5    1 0.000000000        0          0   1.00000000  15.00000000   5.00000000   0.75000000   0.25000000     0.08968610
## 7           age  79  79  34    2 0.054054054        4          0   2.00000000  34.00000000  45.00000000   0.43037975   0.56962025     0.35426009
## 14          age  46  46  21    1 0.031531532        4          0   1.00000000  25.00000000  21.00000000   0.54347826   0.45652174     0.20627803
## 28          9   9   2    1 0.010000000        0          0   1.00000000   7.00000000   2.00000000   0.77777778   0.22222222     0.04035874
## 29  spontaneous  37  37  18    2 0.031531532        4          0   2.00000000  18.00000000  19.00000000   0.48648649   0.51351351     0.16591928
## 58       parity  27  27  12    1 0.031531532        3          3   1.00000000  15.00000000  12.00000000   0.55555556   0.44444444     0.12107623
## 116        12  12   3    1 0.010000000        0          0   1.00000000   9.00000000   3.00000000   0.75000000   0.25000000     0.05381166
## 117        15  15   6    2 0.010000000        0          0   2.00000000   6.00000000   9.00000000   0.40000000   0.60000000     0.06726457
## 59         10  10   3    2 0.010000000        0          0   2.00000000   3.00000000   7.00000000   0.30000000   0.70000000     0.04484305
## 15         33  33   9    2 0.006756757        0          0   2.00000000   9.00000000  24.00000000   0.27272727   0.72727273     0.14798206
## 
## $cp
##           CP nsplit rel.error    xerror       xstd
## 1 0.07432432      0 1.0000000 1.0000000 0.09502215
## 2 0.05405405      2 0.8513514 0.9189189 0.09290437
## 3 0.03153153      3 0.7972973 0.8918919 0.09211655
## 4 0.01000000      6 0.7027027 0.8783784 0.09170670
## 
## $parameters
##                parameters
## minsplit            20.00
## minbucket            7.00
## cp                   0.01
## maxcompete           4.00
## maxsurrogate         5.00
## usesurrogate         2.00
## surrogatestyle       0.00
## maxdepth            30.00
## xval                10.00
## 
## $splits
##           name count ncat      improve index        adj
## 1  spontaneous   223   -1 1.068327e+01   0.5 0.00000000
## 2          age   223    1 7.907546e-02  35.5 0.00000000
## 3      induced   223   -1 3.572929e-02   1.5 0.00000000
## 4       parity   223   -1 2.339838e-02   4.5 0.00000000
## 5    education   223    3 2.122571e-02   1.0 0.00000000
## 6       parity     0   -1 6.367713e-01   2.5 0.18181818
## 7      induced     0    1 6.233184e-01   0.5 0.15151515
## 8          age     0    1 5.695067e-01  24.5 0.03030303
## 9    education     0    3 5.605381e-01   2.0 0.01010101
## 10      parity    99    1 3.260772e+00   3.5 0.00000000
## 11 spontaneous    99   -1 2.585859e+00   1.5 0.00000000
## 12         age    99   -1 1.487542e+00  30.5 0.00000000
## 13     induced    99    1 2.727273e-01   0.5 0.00000000
## 14   education    99    3 2.430976e-01   3.0 0.00000000
## 15     induced     0    1 8.282828e-01   1.5 0.15000000
## 16   education     0    3 8.181818e-01   4.0 0.10000000
## 17         age    79   -1 2.817181e+00  30.5 0.00000000
## 18 spontaneous    79   -1 2.651214e+00   1.5 0.00000000
## 19      parity    79    1 5.822955e-01   1.5 0.00000000
## 20     induced    79    1 4.876055e-02   0.5 0.00000000
## 21   education    79    3 1.240302e-02   5.0 0.00000000
## 22         age    46    1 1.228489e+00  28.5 0.00000000
## 23      parity    46    1 9.356204e-01   1.5 0.00000000
## 24 spontaneous    46   -1 9.146650e-01   1.5 0.00000000
## 25   education    46    3 5.786225e-01   6.0 0.00000000
## 26     induced    46    1 1.129305e-04   0.5 0.00000000
## 27 spontaneous    37   -1 9.531532e-01   1.5 0.00000000
## 28         age    37   -1 7.007722e-01  26.5 0.00000000
## 29      parity    37    1 3.773956e-01   1.5 0.00000000
## 30   education    37    3 6.486486e-03   7.0 0.00000000
## 31     induced    37    1 5.005005e-03   0.5 0.00000000
## 32      parity    27    1 1.633333e+00   1.5 0.00000000
## 33         age    27   -1 1.096491e-01  24.5 0.00000000
## 34     induced    27    1 1.096491e-01   0.5 0.00000000
## 35   education    27    3 9.803922e-02   8.0 0.00000000
## 36     induced     0    1 8.518519e-01   0.5 0.66666667
## 37         age     0    1 7.037037e-01  23.5 0.33333333
## 38   education     0    3 6.296296e-01   9.0 0.16666667
## 
## $importance
##             importance
## spontaneous 11.6364186
## parity       6.8365175
## age          4.9138503
## induced      3.1966813
## education    0.7062112
## 
## $ordered
##             ordered
## age           FALSE
## parity        FALSE
## education     FALSE
## spontaneous   FALSE
## induced       FALSE
## 
## $call
##                                                                                                       call
## 1 rpart::rpart(formula=infert_formula,data=train_test_classification$f$train$f1,,model=TRUE,x=TRUE,y=TRUE)

Rpart Regression

# Prune the fitted regression tree at the complexity parameter (CP) whose
# cross-validated error (xerror) is the smallest in the cptable.
result <- data.frame(rtree_regression$cptable)
result$nsplit <- factor(result$nsplit + 1)
# Tree size (nsplit + 1) of the cptable row with the lowest xerror.
minimun_size <- as.numeric(as.character(result$nsplit[which.min(result$xerror)]))
# Keep a handle on the unpruned tree before `model` is created.
initial_model <- rtree_regression
model <- rpart::prune(
  initial_model,
  cp = initial_model$cptable[which.min(initial_model$cptable[, "xerror"]), "CP"]
)

# Variable-importance bar chart for the pruned regression tree.
# `model$variable.importance` is a named numeric vector: the names become the
# bar labels and the values the bar lengths.
importance <- model$variable.importance
importance <- data.frame(names = names(importance), importance = importance)
# Reverse the factor levels so that, after coord_flip(), the predictors appear
# top-to-bottom in the order they are stored in `model$variable.importance`.
importance$names <- factor(
  importance$names,
  levels = rev(as.character(importance$names))
)
# The factor levels set above fully determine the bar order, so no explicit
# scale_x_discrete(limits = ...) is used. A `levels(names)` call outside aes()
# would look up the base function `names()` (levels are NULL) rather than the
# `names` column, and therefore would not affect the axis.
plot_importance <- ggplot(importance, aes(x = names, y = importance)) +
  geom_bar(stat = "identity") +
  labs(title = "Importance Plot", y = "Relative Influence", x = "Predictor") +
  theme_bw(base_size = 10) +
  coord_flip()
plot_importance

# Auto-print the unpruned regression tree (node, split, n, deviance, yval).
rtree_regression
## n= 455 
## 
## node), split, n, deviance, yval
##       * denotes terminal node
## 
##  1) root 455 36882.4700 22.27934  
##    2) rm< 6.941 392 15593.1200 19.86454  
##      4) lstat>=14.4 160  3093.8000 14.85125  
##        8) nox>=0.607 95  1158.4370 12.65263  
##         16) lstat>=19.645 53   469.1419 10.56415 *
##         17) lstat< 19.645 42   166.4040 15.28810 *
##        9) nox< 0.607 65   804.9686 18.06462 *
##      5) lstat< 14.4 232  5704.7180 23.32198  
##       10) rm< 6.543 179  3125.3570 21.96313  
##         20) crim< 7.24712 172  1449.2880 21.64651 *
##         21) crim>=7.24712 7  1235.1570 29.74286 *
##       11) rm>=6.543 53  1132.5530 27.91132 *
##    3) rm>=6.941 63  4780.4090 37.30476  
##      6) rm< 7.4545 39  1376.0590 32.40513  
##       12) lstat>=5.495 18   603.9711 28.97778 *
##       13) lstat< 5.495 21   379.4114 35.34286 *
##      7) rm>=7.4545 24   946.6933 45.26667 *
# Plot the complexity-parameter table of the pruned model.
rpart::plotcp(model)

# Print the cp table and draw the R-square style plots for the pruned
# regression tree.
rpart::rsq.rpart(model)
## 
## Regression tree:
## rpart::rpart(formula = boston_formula, data = train_test_regression$f$train$f1, 
##     model = TRUE, x = TRUE, y = TRUE)
## 
## Variables actually used in tree construction:
## [1] crim  lstat nox   rm   
## 
## Root node error: 36882/455 = 81.06
## 
## n= 455 
## 
##         CP nsplit rel error  xerror     xstd
## 1 0.447609      0   1.00000 1.00304 0.089852
## 2 0.184223      1   0.55239 0.63906 0.060331
## 3 0.066635      2   0.36817 0.39400 0.045195
## 4 0.039228      3   0.30153 0.35010 0.043677
## 5 0.030649      4   0.26231 0.31264 0.044314
## 6 0.014177      5   0.23166 0.26498 0.042395
## 7 0.011955      6   0.21748 0.25069 0.039729
## 8 0.010647      7   0.20552 0.24530 0.039714
## 9 0.010000      8   0.19488 0.24468 0.039712

# Error profile of the pruned tree: every cptable column is drawn as its own
# line against the tree size (nsplit + 1, shown on the axis as "Size of Tree").
error <- data.frame(model$cptable)
error$nsplit <- factor(error$nsplit + 1)
# Size with the smallest cross-validated error; reported in the plot title.
tree_size <- error$nsplit[which.min(error$xerror)]
# Long format: one row per (size, metric) pair.
error <- reshape2::melt(error, id.vars = "nsplit")
colnames(error) <- c("Split", "Metric", "value")
plot_prune <- ggplot(error, aes(x = Split, y = value, color = Metric)) +
  geom_line(aes(group = Metric)) +
  geom_point() +
  labs(
    title = paste("Error Plot", "Suggested Size:", tree_size),
    y = "Metric value",
    x = "Size of Tree"
  ) +
  theme_bw(base_size = 10)
plot_prune

# Draw the pruned regression tree with rpart.plot (layout type 1).
rpart.plot::rpart.plot(model,type=1)

# Unpack the components of the pruned rpart object into data frames and
# bundle them into a single named list for reporting.
frame <- data.frame(model$frame)
cp <- data.frame(model$cptable)
parameters <- data.frame(parameters = unlist(model$control))
# Split names are the row names of the splits matrix; expose them as a column.
splits <- data.frame(
  name = rownames(model$splits),
  model$splits,
  row.names = NULL
)
importance <- data.frame(importance = model$variable.importance)
ordered <- data.frame(ordered = model$ordered)
# Combines model$y, model$x and model$model; note that `data` is not added to
# `result` below.
data <- data.frame(y = model$y, x = model$x, model = model$model)
call <- data.frame(call = call_to_string(model))
result <- list(
  frame = frame,
  cp = cp,
  parameters = parameters,
  splits = splits,
  importance = importance,
  ordered = ordered,
  call = call
)

# Print every component of the report list (frame, cp, parameters, splits,
# importance, ordered, call).
print(result)
## $frame
##       var   n  wt        dev     yval  complexity ncompete nsurrogate
## 1      rm 455 455 36882.4658 22.27934 0.447609447        4          5
## 2   lstat 392 392 15593.1171 19.86454 0.184223027        4          5
## 4     nox 160 160  3093.7998 14.85125 0.030648555        4          5
## 8   lstat  95  95  1158.4368 12.65263 0.014177222        4          5
## 16   53  53   469.1419 10.56415 0.004942787        0          0
## 17   42  42   166.4040 15.28810 0.010000000        0          0
## 9    65  65   804.9686 18.06462 0.006818168        0          0
## 5      rm 232 232  5704.7179 23.32198 0.039227530        4          5
## 10   crim 179 179  3125.3566 21.96313 0.011954504        4          2
## 20  172 172  1449.2879 21.64651 0.007874345        0          0
## 21    7   7  1235.1571 29.74286 0.010000000        0          0
## 11   53  53  1132.5532 27.91132 0.008671637        0          0
## 3      rm  63  63  4780.4086 37.30476 0.066634814        4          5
## 6   lstat  39  39  1376.0590 32.40513 0.010646697        4          5
## 12   18  18   603.9711 28.97778 0.010000000        0          0
## 13   21  21   379.4114 35.34286 0.002712095        0          0
## 7    24  24   946.6933 45.26667 0.005777875        0          0
## 
## $cp
##           CP nsplit rel.error    xerror       xstd
## 1 0.44760945      0 1.0000000 1.0030382 0.08985226
## 2 0.18422303      1 0.5523906 0.6390649 0.06033078
## 3 0.06663481      2 0.3681675 0.3940014 0.04519524
## 4 0.03922753      3 0.3015327 0.3501001 0.04367740
## 5 0.03064856      4 0.2623052 0.3126368 0.04431441
## 6 0.01417722      5 0.2316566 0.2649762 0.04239499
## 7 0.01195450      6 0.2174794 0.2506874 0.03972945
## 8 0.01064670      7 0.2055249 0.2453018 0.03971366
## 9 0.01000000      8 0.1948782 0.2446765 0.03971197
## 
## $parameters
##                parameters
## minsplit            20.00
## minbucket            7.00
## cp                   0.01
## maxcompete           4.00
## maxsurrogate         5.00
## usesurrogate         2.00
## surrogatestyle       0.00
## maxdepth            30.00
## xval                10.00
## 
## $splits
##       name count ncat    improve      index        adj
## 1       rm   455   -1 0.44760945   6.941000 0.00000000
## 2    lstat   455    1 0.43832216   9.950000 0.00000000
## 3    indus   455    1 0.27457781   6.660000 0.00000000
## 4  ptratio   455    1 0.24752787  19.650000 0.00000000
## 5      nox   455    1 0.22601220   0.669500 0.00000000
## 6    lstat     0    1 0.89890110   4.830000 0.26984127
## 7  ptratio     0    1 0.88131868  14.150000 0.14285714
## 8       zn     0   -1 0.87252747  87.500000 0.07936508
## 9    indus     0    1 0.87252747   1.605000 0.07936508
## 10    crim     0    1 0.86373626   0.013355 0.01587302
## 11   lstat   392    1 0.43574350  14.400000 0.00000000
## 12     nox   392    1 0.28729760   0.669500 0.00000000
## 13    crim   392    1 0.25216185   5.848030 0.00000000
## 14 ptratio   392    1 0.23720257  19.900000 0.00000000
## 15     age   392    1 0.22884396  75.750000 0.00000000
## 16     age     0    1 0.82142857  84.600000 0.56250000
## 17     dis     0   -1 0.77806122   2.239350 0.45625000
## 18     nox     0    1 0.77551020   0.576500 0.45000000
## 19   indus     0    1 0.77295918  16.570000 0.44375000
## 20     tax     0    1 0.76275510 434.500000 0.41875000
## 21     nox   160    1 0.36537410   0.607000 0.00000000
## 22    crim   160    1 0.35899922   5.781900 0.00000000
## 23     tax   160    1 0.31905818 567.500000 0.00000000
## 24     dis   160   -1 0.29272223   1.990400 0.00000000
## 25 ptratio   160    1 0.29088455  19.900000 0.00000000
## 26   indus     0    1 0.87500000  16.010000 0.69230769
## 27     tax     0    1 0.86875000 397.000000 0.67692308
## 28     dis     0   -1 0.83750000   2.384050 0.60000000
## 29    crim     0    1 0.82500000   1.400920 0.56923077
## 30     rad     0    1 0.75000000  16.000000 0.38461538
## 31   lstat    95    1 0.45137628  19.645000 0.00000000
## 32    crim    95    1 0.42470492  10.452400 0.00000000
## 33     tax    95    1 0.21801780 551.500000 0.00000000
## 34     dis    95   -1 0.19770744   1.951250 0.00000000
## 35     rad    95    1 0.14739896  14.500000 0.00000000
## 36     dis     0   -1 0.78947368   1.842250 0.52380952
## 37    crim     0    1 0.74736842   8.951885 0.42857143
## 38      rm     0   -1 0.69473684   5.627500 0.30952381
## 39     nox     0   -1 0.67368421   0.706500 0.26190476
## 40     age     0    1 0.63157895  98.850000 0.16666667
## 41      rm   232   -1 0.25361605   6.543000 0.00000000
## 42   lstat   232    1 0.20737698   7.685000 0.00000000
## 43     dis   232    1 0.13514617   1.600900 0.00000000
## 44    chas   232   -1 0.07354552   0.500000 0.00000000
## 45   indus   232    1 0.06337127   4.100000 0.00000000
## 46   lstat     0    1 0.84913793   5.055000 0.33962264
## 47      zn     0   -1 0.79741379  31.500000 0.11320755
## 48    crim     0    1 0.78879310   0.017895 0.07547170
## 49   indus     0    1 0.78017241   4.010000 0.03773585
## 50     dis     0   -1 0.77586207  10.648000 0.01886792
## 51    crim   179   -1 0.14107561   7.247120 0.00000000
## 52   lstat   179    1 0.11894361   9.660000 0.00000000
## 53     dis   179    1 0.11625341   1.615600 0.00000000
## 54      rm   179   -1 0.07798460   6.142000 0.00000000
## 55 ptratio   179    1 0.05518417  20.550000 0.00000000
## 56     dis     0    1 0.98882682   1.551100 0.71428571
## 57     age     0   -1 0.96648045  99.450000 0.14285714
## 58      rm    63   -1 0.51411009   7.454500 0.00000000
## 59   lstat    63    1 0.29110784   5.185000 0.00000000
## 60 ptratio    63    1 0.20314402  19.700000 0.00000000
## 61   black    63    1 0.09352138 392.805000 0.00000000
## 62    crim    63    1 0.08962926   1.676395 0.00000000
## 63   lstat     0    1 0.74603175   3.990000 0.33333333
## 64    crim     0   -1 0.69841270   0.112760 0.20833333
## 65 ptratio     0    1 0.69841270  14.750000 0.20833333
## 66   black     0    1 0.68253968 390.065000 0.16666667
## 67   indus     0   -1 0.65079365  18.840000 0.08333333
## 68   lstat    39    1 0.28536309   5.495000 0.00000000
## 69     nox    39    1 0.14197199   0.488500 0.00000000
## 70     tax    39    1 0.11804967 267.000000 0.00000000
## 71     rad    39    1 0.11441342   7.500000 0.00000000
## 72 ptratio    39    1 0.10870099  18.900000 0.00000000
## 73     nox     0    1 0.74358974   0.470500 0.44444444
## 74     age     0    1 0.74358974  58.950000 0.44444444
## 75     dis     0   -1 0.74358974   3.546600 0.44444444
## 76    crim     0    1 0.69230769   0.102345 0.33333333
## 77      zn     0   -1 0.64102564  33.500000 0.22222222
## 
## $importance
##         importance
## rm      20575.2516
## lstat   13475.5477
## indus    5367.3187
## dis      4568.9258
## nox      4499.4345
## age      4146.6209
## tax      3610.4285
## ptratio  2870.4317
## crim     2322.6065
## zn       1561.2844
## rad       434.7670
## black     409.6094
## 
## $ordered
##         ordered
## crim      FALSE
## zn        FALSE
## indus     FALSE
## chas      FALSE
## nox       FALSE
## rm        FALSE
## age       FALSE
## dis       FALSE
## rad       FALSE
## tax       FALSE
## ptratio   FALSE
## black     FALSE
## lstat     FALSE
## 
## $call
##                                                                                                   call
## 1 rpart::rpart(formula=boston_formula,data=train_test_regression$f$train$f1,,model=TRUE,x=TRUE,y=TRUE)