This material uses the Hitters dataset to construct several regression and classification tree models, assess their accuracy, and walk through the pruning process, with worked examples along the way. (Note that some comments on the results may differ slightly from your own run because of the random sampling of the data.)
First, some brief background on regression trees.
REGIONS -> In any decision tree, observations are often grouped into regions R1,R2,…Rn.
TERMINAL NODES -> These regions are commonly referred to as terminal nodes or leaves of the tree.
Decision trees are generally plotted upside down, in the sense that the terminal nodes are at the bottom of the tree.
INTERNAL NODES -> The points along the decision tree where the predictor space is split are referred to as internal nodes.
All observations pass through the tree and are assessed at each node: they proceed to the left if the answer to the node's condition is “yes”, or to the right if the answer is “no”.
These are the libraries used in this material.
library(rpart)
library(rpart.plot)
library(ISLR)
library(dplyr)
library(tidyverse)
library(explore)# explore relations
hitters <- Hitters
head(hitters,10)
str(hitters)
'data.frame': 322 obs. of 20 variables:
$ AtBat : int 293 315 479 496 321 594 185 298 323 401 ...
$ Hits : int 66 81 130 141 87 169 37 73 81 92 ...
$ HmRun : int 1 7 18 20 10 4 1 0 6 17 ...
$ Runs : int 30 24 66 65 39 74 23 24 26 49 ...
$ RBI : int 29 38 72 78 42 51 8 24 32 66 ...
$ Walks : int 14 39 76 37 30 35 21 7 8 65 ...
$ Years : int 1 14 3 11 2 11 2 3 2 13 ...
$ CAtBat : int 293 3449 1624 5628 396 4408 214 509 341 5206 ...
$ CHits : int 66 835 457 1575 101 1133 42 108 86 1332 ...
$ CHmRun : int 1 69 63 225 12 19 1 0 6 253 ...
$ CRuns : int 30 321 224 828 48 501 30 41 32 784 ...
$ CRBI : int 29 414 266 838 46 336 9 37 34 890 ...
$ CWalks : int 14 375 263 354 33 194 24 12 8 866 ...
$ League : Factor w/ 2 levels "A","N": 1 2 1 2 2 1 2 1 2 1 ...
$ Division : Factor w/ 2 levels "E","W": 1 2 2 1 1 2 1 2 2 1 ...
$ PutOuts : int 446 632 880 200 805 282 76 121 143 0 ...
$ Assists : int 33 43 82 11 40 421 127 283 290 0 ...
$ Errors : int 20 10 14 3 4 25 7 9 19 0 ...
$ Salary : num NA 475 480 500 91.5 750 70 100 75 1100 ...
$ NewLeague: Factor w/ 2 levels "A","N": 1 2 1 2 2 1 1 1 2 1 ...
hitters %>% explore(Salary)# explore Salary variable
hitters %>% explore(Years)# explore Years variable
hitters %>% explore(League)# explore league variable
# Plot scatterplots of the dependent variable Salary (target) against a selection of explanatory variables from the dataset
hitters %>% select(Salary,Hits, Runs, Walks, Years)%>% explore_all(target = Salary)
# To be run later!
# hitters %>% explore()# opens an interactive window where you can explore the data and decision trees
# hitters %>% explore_all()# explore all variables; not very practical here because of the large number of variables
hitters %>% explore(Salary)
hitters %>% explore(Years)
hitters %>% explore(Hits)
Now let’s grow a decision tree, starting with summary statistics for the variable Salary. Below we also compute some correlations to understand the relationship of our dependent variable Salary with the other numerical variables; this gives an idea of which variables the tree is likely to split on first. Let’s see:
summary(hitters$Salary)
Min. 1st Qu. Median Mean 3rd Qu. Max. NA's
67.5 190.0 425.0 535.9 750.0 2460.0 59
hitters.1=na.omit(hitters)
# use some correlation libraries to obtain a better view of the correlations
library(corrplot)
library("PerformanceAnalytics")
chart.Correlation(hitters.1[,c(1,2,5,6,7:9,13,19)],histogram=TRUE, pch=19)
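The same information can be read off numerically. A minimal sketch (assuming hitters.1 from above; the object name num.vars is just illustrative) that sorts the correlation of each numeric variable with Salary:
num.vars <- sapply(hitters.1, is.numeric)          # keep only the numeric columns
sort(cor(hitters.1[, num.vars])[, "Salary"], decreasing = TRUE)   # correlations with Salary, strongest first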
summary(hitters$CHits)
Min. 1st Qu. Median Mean 3rd Qu. Max.
4.0 209.0 508.0 717.6 1059.2 4256.0
summary(hitters$CAtBat)
Min. 1st Qu. Median Mean 3rd Qu. Max.
19.0 816.8 1928.0 2648.7 3924.2 14053.0
summary(hitters$Walks)
Min. 1st Qu. Median Mean 3rd Qu. Max.
0.00 22.00 35.00 38.74 53.00 105.00
# the splits tend to start with the variables most highly correlated with the target variable.
hitters%>% explain_tree(target=Salary)
rpart() uses a formula interface, similar to the linear model function lm().
# library(rpart)
mod.1<-rpart(Salary~.,data=hitters,method="anova")# method="anova" because the target variable (Salary) is numeric
mod.1
n=263 (59 observations deleted due to missingness)
node), split, n, deviance, yval
* denotes terminal node
1) root 263 53319110.0 535.9259
2) CHits< 450 117 5931094.0 227.8547
4) Walks>=10 110 1754378.0 207.4470
8) CRBI< 114.5 72 284426.4 141.6343 *
9) CRBI>=114.5 38 567215.0 332.1447 *
5) Walks< 10 7 3410996.0 548.5476 *
3) CHits>=450 146 27385210.0 782.8048
6) Walks< 61 104 9469906.0 649.6232
12) AtBat< 395.5 53 2859476.0 510.0157 *
13) AtBat>=395.5 51 4503956.0 794.7054
26) PutOuts< 709 44 2358329.0 746.3631 *
27) PutOuts>=709 7 1396458.0 1098.5710 *
7) Walks>=61 42 11502830.0 1112.5880
14) RBI< 73.5 22 3148182.0 885.2651
28) PutOuts>=239.5 15 656292.3 758.8889 *
29) PutOuts< 239.5 7 1738973.0 1156.0710 *
15) RBI>=73.5 20 5967231.0 1362.6430
30) CRuns< 788 9 581309.7 1114.4440 *
31) CRuns>=788 11 4377879.0 1565.7150 *
We observe that we start with 263 observations (59 were deleted because Salary is missing for them). The root node contains all of the observations; the data are first split on the CHits variable, and each resulting branch is then split again into sub-branches, and so on.
Branches that lead to terminal nodes (the leaves, i.e. the end points at the bottom of the tree) are indicated using asterisks (*).
Let us visualize the tree.
rpart.plot(mod.1, type = 3, digits=3, fallen.leaves = T)
So, how do we interpret this tree? Take an individual (player): if CHits < 450, we next check whether Walks >= 10, and then whether CRBI < 115; the 27.4% of players who follow this path have a predicted Salary of 142 (thousand $).
Observe the bottom row of the plot: the percentages in the leaves sum to 100%.
Can you find the predicted Salary for a player with CHits=480 & Walks=64 & RBI=76 & CRuns =720? (Follow the branches of the tree. Ans=1114)
The predictions are the average values of the Salary of those individuals (players) which fall in that group (which is a subset of the total observations).
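One way to check the answer to the question above is to let predict() follow the branches for us. A small sketch, where new.player is a hypothetical player built by copying a complete row of the data and overwriting the four variables in the question:
new.player <- hitters.1[1, ]   # any complete row will do; only the split variables matter here
new.player$CHits <- 480
new.player$Walks <- 64
new.player$RBI   <- 76
new.player$CRuns <- 720
predict(mod.1, newdata = new.player) # follows CHits>=450, Walks>=61, RBI>=73.5, CRuns<788, so about 1114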
We can check how much each variable contributes to predicting the dependent variable using the following code. The higher the value, the more important the variable.
mod.1$variable.importance
CHits CAtBat CRuns CRBI CWalks Years Walks Runs
22137035 21027447 20807517 20128145 18963586 14675213 8210782 5117522
RBI HmRun AtBat CHmRun Hits PutOuts
4513820 3697913 3270236 3148716 2145516 1502086
What if we already know which variables to take into consideration? In this situation we will go through these steps:
reg.tree <- rpart(Salary ~ Years + Hits, data = hitters)
rpart.plot(reg.tree, type=3, digits=3, fallen.leaves = T) # you may try type=4
What is the predicted Salary of a player with characteristics: a) Years=6 and Hits= 150? b) Years=3 and Hits= 50?
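Both answers can be checked with predict(). A small sketch (new.players is a hypothetical data frame holding only the two predictors the model uses):
new.players <- data.frame(Years = c(6, 3), Hits = c(150, 50))
predict(reg.tree, newdata = new.players)   # one predicted Salary per row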
We can interpret the above regression tree as follows: because the first split is on “Years” (i.e. experience), this is the most important factor in determining a player’s salary.
“Hits” also plays a role in determining a player’s salary (after all, we did include it in our equation).
We can again check the importance of each variable using the following code.
reg.tree$variable.importance
Years Hits
15696202 14259635
It is evident that Years plays a relatively more important role than Hits in determining a player’s salary (at least based on our simple example).
## Building a Regression Tree
When we work with regression trees it is important to have training and test datasets.
library(MASS)
Attaching package: 'MASS'
The following object is masked _by_ '.GlobalEnv':
    area
The following object is masked from 'package:dplyr':
    select
library(rpart)
set.seed(512)
train <- sample(1:nrow(hitters), nrow(hitters)/2)
mod.2 <- rpart(Salary ~ Hits + Runs + RBI + Walks + Years, subset = train, data = hitters)
summary(mod.2)
Call:
rpart(formula = Salary ~ Hits + Runs + RBI + Walks + Years, data = hitters,
subset = train)
n=134 (27 observations deleted due to missingness)
CP nsplit rel error xerror xstd
1 0.28936005 0 1.0000000 1.0148198 0.1861855
2 0.15263221 1 0.7106399 0.8452857 0.1359201
3 0.08426616 2 0.5580077 0.8508802 0.1465145
4 0.07136647 3 0.4737416 0.7974691 0.1445832
5 0.04274346 4 0.4023751 0.8546806 0.1566366
6 0.03257815 5 0.3596316 0.7886644 0.1554360
7 0.01614981 6 0.3270535 0.6770736 0.1318538
8 0.01000000 7 0.3109037 0.6743197 0.1318870
Variable importance
Hits Runs Years RBI Walks
26 24 17 16 16
Node number 1: 134 observations, complexity param=0.2893601
mean=571.3543, MSE=236102.3
left son=2 (72 obs) right son=3 (62 obs)
Primary splits:
Hits < 117.5 to the left, improve=0.2893601, (0 missing)
RBI < 51.5 to the left, improve=0.2566073, (0 missing)
Years < 4.5 to the left, improve=0.2362867, (0 missing)
Walks < 67 to the left, improve=0.2077635, (0 missing)
Runs < 55.5 to the left, improve=0.2066041, (0 missing)
Surrogate splits:
Runs < 55.5 to the left, agree=0.903, adj=0.790, (0 split)
RBI < 43.5 to the left, agree=0.806, adj=0.581, (0 split)
Walks < 39.5 to the left, agree=0.761, adj=0.484, (0 split)
Years < 4.5 to the left, agree=0.582, adj=0.097, (0 split)
Node number 2: 72 observations, complexity param=0.04274346
mean=328.8056, MSE=52471.68
left son=4 (37 obs) right son=5 (35 obs)
Primary splits:
Years < 6.5 to the left, improve=0.35794580, (0 missing)
RBI < 44 to the left, improve=0.23447110, (0 missing)
Walks < 21.5 to the left, improve=0.12034070, (0 missing)
Hits < 76.5 to the left, improve=0.09738826, (0 missing)
Runs < 36 to the left, improve=0.09266866, (0 missing)
Surrogate splits:
Runs < 22.5 to the right, agree=0.611, adj=0.200, (0 split)
Hits < 111 to the right, agree=0.569, adj=0.114, (0 split)
RBI < 16 to the left, agree=0.542, adj=0.057, (0 split)
Walks < 11.5 to the left, agree=0.542, adj=0.057, (0 split)
Node number 3: 62 observations, complexity param=0.1526322
mean=853.0238, MSE=301694.5
left son=6 (9 obs) right son=7 (53 obs)
Primary splits:
Years < 3.5 to the left, improve=0.2581619, (0 missing)
Walks < 69.5 to the left, improve=0.2346518, (0 missing)
Hits < 170.5 to the left, improve=0.1447259, (0 missing)
Runs < 96.5 to the left, improve=0.1421208, (0 missing)
RBI < 51 to the left, improve=0.1055367, (0 missing)
Node number 4: 37 observations
mean=195.5135, MSE=12916.11
Node number 5: 35 observations, complexity param=0.03257815
mean=469.7143, MSE=55650.27
left son=10 (23 obs) right son=11 (12 obs)
Primary splits:
RBI < 43.5 to the left, improve=0.52917120, (0 missing)
Hits < 76.5 to the left, improve=0.29975240, (0 missing)
Walks < 21 to the left, improve=0.27715470, (0 missing)
Runs < 36 to the left, improve=0.24992870, (0 missing)
Years < 12.5 to the left, improve=0.04763565, (0 missing)
Surrogate splits:
Runs < 43.5 to the left, agree=0.857, adj=0.583, (0 split)
Hits < 76.5 to the left, agree=0.829, adj=0.500, (0 split)
Walks < 34.5 to the left, agree=0.829, adj=0.500, (0 split)
Years < 7.5 to the right, agree=0.714, adj=0.167, (0 split)
Node number 6: 9 observations
mean=175.7778, MSE=717.0617
Node number 7: 53 observations, complexity param=0.08426616
mean=968.0278, MSE=261691.9
left son=14 (45 obs) right son=15 (8 obs)
Primary splits:
Runs < 96.5 to the left, improve=0.19221710, (0 missing)
Walks < 69.5 to the left, improve=0.19113900, (0 missing)
RBI < 80.5 to the left, improve=0.16038410, (0 missing)
Hits < 181 to the left, improve=0.11340640, (0 missing)
Years < 7.5 to the left, improve=0.03062666, (0 missing)
Surrogate splits:
Hits < 185 to the left, agree=0.943, adj=0.625, (0 split)
RBI < 97 to the left, agree=0.906, adj=0.375, (0 split)
Node number 10: 23 observations
mean=345.7609, MSE=22286.65
Node number 11: 12 observations
mean=707.2917, MSE=33705.69
Node number 14: 45 observations, complexity param=0.07136647
mean=873.463, MSE=199386
left son=28 (34 obs) right son=29 (11 obs)
Primary splits:
Walks < 70 to the left, improve=0.25164720, (0 missing)
RBI < 80.5 to the left, improve=0.10844460, (0 missing)
Years < 6.5 to the left, improve=0.09938133, (0 missing)
Hits < 151.5 to the right, improve=0.03744486, (0 missing)
Runs < 76.5 to the right, improve=0.02262225, (0 missing)
Node number 15: 8 observations
mean=1499.955, MSE=278914.3
Node number 28: 34 observations, complexity param=0.01614981
mean=746.0539, MSE=88308.26
left son=56 (12 obs) right son=57 (22 obs)
Primary splits:
Years < 6.5 to the left, improve=0.17017360, (0 missing)
Walks < 46.5 to the right, improve=0.06033537, (0 missing)
RBI < 80.5 to the left, improve=0.04926630, (0 missing)
Runs < 81 to the right, improve=0.04196110, (0 missing)
Hits < 143.5 to the right, improve=0.02594435, (0 missing)
Surrogate splits:
Hits < 174 to the right, agree=0.765, adj=0.333, (0 split)
Runs < 78.5 to the right, agree=0.706, adj=0.167, (0 split)
Walks < 64 to the right, agree=0.706, adj=0.167, (0 split)
Node number 29: 11 observations
mean=1267.273, MSE=337456.2
Node number 56: 12 observations
mean=580.0694, MSE=42897.95
Node number 57: 22 observations
mean=836.5909, MSE=89852.84
library(rpart.plot)
rpart.plot(mod.2, type=3, digits = 3)
The most important factors in determining salary (according to the features used in our model) are
mod.2$variable.importance
Hits Runs Years RBI Walks
11661145 10858006 7749903 7423345 7365341
Now let’s suppose we have a test dataset and we want to predict with the model above (mod.2).
# testing dataset
test<-hitters[1:10,c(2,4:7)]
test
pred.2<-predict(mod.2,test)
pred.2 # these are the predicted salary for the first 10 individuals using the values of other variables used in the model (mod.2)
-Andy Allanson -Alan Ashby -Alvin Davis -Andre Dawson
195.5135 345.7609 175.7778 836.5909
-Andres Galarraga -Alfredo Griffin -Al Newman -Argenis Salazar
195.5135 836.5909 195.5135 195.5135
-Andres Thomas -Andre Thornton
195.5135 707.2917
real.sal<-hitters$Salary[1:10]# these are the real Salary for those individuals
real.sal
[1] NA 475.0 480.0 500.0 91.5 750.0 70.0 100.0 75.0 1100.0
How do we measure the accuracy of the model? One way is to calculate error metrics such as the MAE (Mean Absolute Error) or the RMSE (Root Mean Squared Error).
MAE.mod.2<- mean(abs(pred.2-real.sal),na.rm=T)# remove missing values
MAE.mod.2
[1] 188.3228
RMSE.mod.2<-sqrt(mean((pred.2-real.sal)^2,na.rm=T))
RMSE.mod.2
[1] 219.718
Try to fit other models and calculate errors.
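For example, a minimal sketch of one alternative model (the choice of predictors and the object names are just illustrative), fitted on the same training rows and evaluated on the same 10 players:
mod.2b  <- rpart(Salary ~ Hits + Walks + Years + PutOuts, subset = train, data = hitters, method = "anova")
pred.2b <- predict(mod.2b, hitters[1:10, ])                 # full rows, so PutOuts is available
MAE.2b  <- mean(abs(pred.2b - real.sal), na.rm = TRUE)      # mean absolute error
RMSE.2b <- sqrt(mean((pred.2b - real.sal)^2, na.rm = TRUE)) # root mean squared error
c(MAE = MAE.2b, RMSE = RMSE.2b)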
The tree library (used further below) can also be used to construct classification and regression trees.
# create a vector of random variable from uniform distribution (0,1)
set.seed(512)
v<-runif(nrow(hitters),0,1)
# order the dataset by v to obtain a shuffled dataset
hitters.rand<-hitters[order(v),]
The syntax used is rpart(target ~ predictors). We are going to construct a classification tree with “League” as the target variable and “Hits”, “Years”, “Runs”, “Walks” as the explanatory variables.
mod.3<-rpart(League~Hits+Years+Runs+Walks, data=hitters.rand[1:250,],method = "class")
mod.3
n= 250
node), split, n, loss, yval, (yprob)
* denotes terminal node
1) root 250 112 A (0.55200000 0.44800000)
2) Hits>=100.5 122 44 A (0.63934426 0.36065574)
4) Hits< 103.5 8 0 A (1.00000000 0.00000000) *
5) Hits>=103.5 114 44 A (0.61403509 0.38596491)
10) Runs>=73.5 55 16 A (0.70909091 0.29090909)
20) Runs< 88.5 27 4 A (0.85185185 0.14814815) *
21) Runs>=88.5 28 12 A (0.57142857 0.42857143)
42) Runs>=97.5 11 2 A (0.81818182 0.18181818) *
43) Runs< 97.5 17 7 N (0.41176471 0.58823529) *
11) Runs< 73.5 59 28 A (0.52542373 0.47457627)
22) Hits< 131.5 37 14 A (0.62162162 0.37837838)
44) Walks>=47.5 11 1 A (0.90909091 0.09090909) *
45) Walks< 47.5 26 13 A (0.50000000 0.50000000)
90) Hits>=111 19 8 A (0.57894737 0.42105263) *
91) Hits< 111 7 2 N (0.28571429 0.71428571) *
23) Hits>=131.5 22 8 N (0.36363636 0.63636364)
46) Hits>=149.5 7 2 A (0.71428571 0.28571429) *
47) Hits< 149.5 15 3 N (0.20000000 0.80000000) *
3) Hits< 100.5 128 60 N (0.46875000 0.53125000)
6) Walks< 13.5 20 5 A (0.75000000 0.25000000) *
7) Walks>=13.5 108 45 N (0.41666667 0.58333333)
14) Years>=4.5 71 34 N (0.47887324 0.52112676)
28) Runs>=48.5 7 1 A (0.85714286 0.14285714) *
29) Runs< 48.5 64 28 N (0.43750000 0.56250000)
58) Walks< 23.5 29 13 A (0.55172414 0.44827586)
116) Hits< 70.5 21 8 A (0.61904762 0.38095238) *
117) Hits>=70.5 8 3 N (0.37500000 0.62500000) *
59) Walks>=23.5 35 12 N (0.34285714 0.65714286) *
15) Years< 4.5 37 11 N (0.29729730 0.70270270)
30) Years< 3.5 26 11 N (0.42307692 0.57692308)
60) Runs< 33 15 6 A (0.60000000 0.40000000) *
61) Runs>=33 11 2 N (0.18181818 0.81818182) *
31) Years>=3.5 11 0 N (0.00000000 1.00000000) *
# we are using only the first 250 observations, leaving the rest as a testing dataset to evaluate our model (see the sketch below)
# method = "class" indicates that our target variable is a categorical variable
Let us visualize the tree
rpart.plot(mod.3, type=3, digits = 3)
library(tree)
mod.4<-tree(League~Hits+Years+Runs+Walks, data=hitters.rand[1:250,],method = "class")
mod.4
node), split, n, deviance, yval, (yprob)
* denotes terminal node
1) root 250 343.900 A ( 0.55200 0.44800 )
2) Hits < 100.5 128 176.900 N ( 0.46875 0.53125 )
4) Walks < 13.5 20 22.490 A ( 0.75000 0.25000 ) *
5) Walks > 13.5 108 146.700 N ( 0.41667 0.58333 ) *
3) Hits > 100.5 122 159.500 A ( 0.63934 0.36066 )
6) Hits < 103.5 8 0.000 A ( 1.00000 0.00000 ) *
7) Hits > 103.5 114 152.100 A ( 0.61404 0.38596 )
14) Runs < 73.5 59 81.640 A ( 0.52542 0.47458 )
28) Hits < 131.5 37 49.080 A ( 0.62162 0.37838 )
56) Walks < 47.5 26 36.040 N ( 0.50000 0.50000 ) *
57) Walks > 47.5 11 6.702 A ( 0.90909 0.09091 ) *
29) Hits > 131.5 22 28.840 N ( 0.36364 0.63636 )
58) Hits < 149.5 15 15.010 N ( 0.20000 0.80000 )
116) Walks < 48.5 8 10.590 N ( 0.37500 0.62500 ) *
117) Walks > 48.5 7 0.000 N ( 0.00000 1.00000 ) *
59) Hits > 149.5 7 8.376 A ( 0.71429 0.28571 ) *
15) Runs > 73.5 55 66.330 A ( 0.70909 0.29091 )
30) Runs < 88.5 27 22.650 A ( 0.85185 0.14815 )
60) Hits < 129 7 9.561 A ( 0.57143 0.42857 ) *
61) Hits > 129 20 7.941 A ( 0.95000 0.05000 ) *
31) Runs > 88.5 28 38.240 A ( 0.57143 0.42857 )
62) Runs < 97.5 17 23.030 N ( 0.41176 0.58824 )
124) Years < 4.5 5 0.000 N ( 0.00000 1.00000 ) *
125) Years > 4.5 12 16.300 A ( 0.58333 0.41667 ) *
63) Runs > 97.5 11 10.430 A ( 0.81818 0.18182 ) *
summary(mod.4)
Classification tree:
tree(formula = League ~ Hits + Years + Runs + Walks, data = hitters.rand[1:250,
], method = "class")
Number of terminal nodes: 13
Residual mean deviance: 1.161 = 275.1 / 237
Misclassification error rate: 0.32 = 80 / 250
Here we see that the training (misclassification) error rate is 0.32, i.e. 32%. We use the plot() function to display the tree structure, and the text() function to display the node labels. The argument pretty=0 instructs R to include the category names for any qualitative predictors, rather than simply displaying a letter for each category.
plot(mod.4)
text(mod.4,pretty =0)
The most important indicator of “League” appears to be “Hits” , since the first branch differentiates Hits<100.5 and >=100.5.
In order to properly evaluate the performance of a classification tree on these data, we must estimate the test error rather than simply computing the training error. We split the observations into a training set and a test set, build the tree using the training set, and evaluate its performance on the test data.
The predict() function can be used for this purpose. In the case of a classification tree, the argument type=“class” instructs R to return the actual class prediction.
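For instance, a brief sketch using mod.4 and the rows of hitters.rand that were not used to grow it (object names are illustrative):
hold.out <- hitters.rand[251:nrow(hitters.rand), ]
pred.4   <- predict(mod.4, hold.out, type = "class")     # class labels rather than probabilities
mean(pred.4 != hold.out$League)                          # estimated test error rate for mod.4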
Next, we consider whether pruning the tree might lead to improved results.
The function cv.tree() performs cross-validation in order to determine the optimal level of tree complexity; cost complexity pruning is used in order to select a sequence of trees for consideration.
We use the argument FUN=prune.misclass in order to indicate that we want the classification error rate to guide the cross-validation and pruning process, rather than the default for the cv.tree() function, which is deviance.
The cv.tree() function reports the number of terminal nodes of each tree considered (size) as well as the corresponding error rate and the value of the cost-complexity parameter used (k).
set.seed(34)
cv.hitters =cv.tree(mod.4,FUN=prune.misclass)
names(cv.hitters )
[1] "size" "dev" "k" "method"
cv.hitters
$size
[1] 13 10 7 3 1
$dev
[1] 109 109 107 109 112
$k
[1] -Inf 0.000000 1.666667 2.250000 9.000000
$method
[1] "misclass"
attr(,"class")
[1] "prune" "tree.sequence"
Note that, despite the name, dev corresponds to the cross-validation error rate in this instance. The tree with 7 terminal nodes results in the lowest cross-validation error rate, with 107 cross-validation errors. We plot the error rate as a function of both size and k.
par(mfrow=c(1,2))
plot(cv.hitters$size ,cv.hitters$dev ,type="b")
plot(cv.hitters$k ,cv.hitters$dev ,type="b")
We now apply the prune.misclass() function in order to prune the tree and obtain the seven-node tree.
prune.hitters =prune.misclass (mod.4 ,best=7)
plot(prune.hitters)
text(prune.hitters ,pretty =0)
Try to calculate the errors for this model, prune.hitters (see above for how the calculations are done); a sketch follows below.
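A sketch of that calculation, using the rows of hitters.rand that were not used to grow mod.4 (object names are illustrative):
hold.out   <- hitters.rand[251:nrow(hitters.rand), ]
pred.prune <- predict(prune.hitters, hold.out, type = "class")
mean(pred.prune != hold.out$League)                      # test misclassification rate of the 7-node tree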
If we increase the value of the argument “best”, we obtain a larger pruned tree with lower classification accuracy (you may check it, as sketched after the plot below):
prune.2 =prune.misclass (mod.4,best=13)
plot(prune.2)
text(prune.2 ,pretty =0)
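A sketch of that check, comparing the 7-node and 13-node trees on the same held-out rows (the helper acc() is just for illustration):
hold.out <- hitters.rand[251:nrow(hitters.rand), ]
acc <- function(tr) mean(predict(tr, hold.out, type = "class") == hold.out$League)
c(pruned_7 = acc(prune.hitters), pruned_13 = acc(prune.2))   # hold-out accuracy of each tree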
Next, we create train and test data.
library(tree)
set.seed(1013)
head(hitters,10)
train.h<-sample(dim(hitters)[1],250)
hitters.train<-hitters[train.h,]
hitters.test<-hitters[-train.h,]
hitt.tree<-tree(League~.,data=hitters.train)
summary(hitt.tree)
Classification tree:
tree(formula = League ~ ., data = hitters.train)
Variables actually used in tree construction:
[1] "NewLeague" "Hits" "Walks" "Assists" "CRBI" "PutOuts"
[7] "CHmRun"
Number of terminal nodes: 11
Residual mean deviance: 0.204 = 40.19 / 197
Misclassification error rate: 0.05288 = 11 / 208
We see that only 7 variables were used in the tree construction. There are 11 terminal nodes. The misclassification error rate is about 5.3%, which is good.
hitt.tree
node), split, n, deviance, yval, (yprob)
* denotes terminal node
1) root 208 287.900 A ( 0.52404 0.47596 )
2) NewLeague: A 111 57.490 A ( 0.92793 0.07207 )
4) Hits < 84.5 41 37.480 A ( 0.82927 0.17073 )
8) Walks < 20.5 19 0.000 A ( 1.00000 0.00000 ) *
9) Walks > 20.5 22 27.520 A ( 0.68182 0.31818 )
18) Assists < 4.5 6 0.000 A ( 1.00000 0.00000 ) *
19) Assists > 4.5 16 21.930 A ( 0.56250 0.43750 )
38) CRBI < 140 6 5.407 A ( 0.83333 0.16667 ) *
39) CRBI > 140 10 13.460 N ( 0.40000 0.60000 ) *
5) Hits > 84.5 70 10.480 A ( 0.98571 0.01429 )
10) PutOuts < 89.5 9 6.279 A ( 0.88889 0.11111 ) *
11) PutOuts > 89.5 61 0.000 A ( 1.00000 0.00000 ) *
3) NewLeague: N 97 45.020 N ( 0.06186 0.93814 )
6) Assists < 50.5 45 35.340 N ( 0.13333 0.86667 )
12) CHmRun < 3.5 6 8.318 A ( 0.50000 0.50000 ) *
13) CHmRun > 3.5 39 21.150 N ( 0.07692 0.92308 )
26) CRBI < 550.5 29 0.000 N ( 0.00000 1.00000 ) *
27) CRBI > 550.5 10 12.220 N ( 0.30000 0.70000 )
54) CRBI < 830 5 6.730 A ( 0.60000 0.40000 ) *
55) CRBI > 830 5 0.000 N ( 0.00000 1.00000 ) *
7) Assists > 50.5 52 0.000 N ( 0.00000 1.00000 ) *
plot(hitt.tree)
text(hitt.tree, pretty = 0)
NewLeague appears to be the most important variable in the tree, since it is used for the first split. Let us now predict the response on the test data, and produce a confusion matrix comparing the true test labels to the predicted test labels.
hitt.pred = predict(hitt.tree, hitters.test, type = "class")
table(hitters.test$League, hitt.pred)
hitt.pred
A N
A 37 5
N 3 27
# calculate
(37+27)/72
[1] 0.8888889
We obtain a classification accuracy of about 88.9%, in other words, a test error rate of about 1 - 0.889 = 11.1%.
Now we will apply the cv.tree() function to the training set in order to determine the optimal tree size.
cv.hitt= cv.tree(hitt.tree, FUN = prune.tree)
cv.hitt
$size
[1] 11 10 9 8 7 5 4 2 1
$dev
[1] 213.1705 213.1705 189.5949 186.9586 180.9159 176.6814 118.7734 118.6956
[9] 290.5079
$k
[1] -Inf 3.063087 4.203659 5.487169 5.591565 7.402825 9.675684
[8] 9.743106 185.361418
$method
[1] "deviance"
attr(,"class")
[1] "prune" "tree.sequence"
The optimal tree size is 2 or 4 terminal nodes, with a cross-validation deviance of about 118.7.
Cross-validation plot of the error rate as a function of both tree size and k:
par(mfrow=c(1,2))
plot(cv.hitt$size, cv.hitt$dev, type = "b", xlab = "Tree Size", ylab = "Deviance")
plot(cv.hitt$k, cv.hitt$dev, type = "b", xlab = "Cost-Complexity Parameter (k)", ylab = "Deviance")
We see that tree sizes of 2 to 4 have almost the same cross-validation error (see the first plot) and do not differ too much from those of greater size.
hitt.pruned.4 = prune.tree(hitt.tree, best = 4)
summary(hitt.pruned.4)
Classification tree:
snip.tree(tree = hitt.tree, nodes = c(5L, 9L, 3L))
Variables actually used in tree construction:
[1] "NewLeague" "Hits" "Walks"
Number of terminal nodes: 4
Residual mean deviance: 0.407 = 83.02 / 204
Misclassification error rate: 0.06731 = 14 / 208
hitt.pruned.2 = prune.tree(hitt.tree, best = 2)
summary(hitt.pruned.2)
Classification tree:
snip.tree(tree = hitt.tree, nodes = 3:2)
Variables actually used in tree construction:
[1] "NewLeague"
Number of terminal nodes: 2
Residual mean deviance: 0.4976 = 102.5 / 206
Misclassification error rate: 0.06731 = 14 / 208
plot(hitt.pruned.4)
text(hitt.pruned.4, pretty = 0)
plot(hitt.pruned.2)
text(hitt.pruned.2, pretty = 0)
Now how do the test errors for the unpruned and pruned trees compare:
pred.unpruned = predict(hitt.tree, hitters.test, type = "class")
misclass.unpruned = sum(hitters.test$League!= pred.unpruned)
misclass.unpruned/length(pred.unpruned)
[1] 0.1388889
pred.pruned = predict(hitt.pruned.4, hitters.test, type = "class")
misclass.pruned = sum(hitters.test$League!=pred.pruned)
misclass.pruned/length(pred.pruned)
[1] 0.05555556
We see that the test error rates for the unpruned and pruned trees are about 13.9% and 5.6% respectively. Pruning the tree lowered the test error rate, which suggests that the pruned tree can perform better for prediction.
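Out of curiosity, the same calculation can be repeated for the 2-node tree built above (a small sketch; object names are illustrative):
pred.pruned.2 = predict(hitt.pruned.2, hitters.test, type = "class")
sum(hitters.test$League != pred.pruned.2)/length(pred.pruned.2)   # test error rate of the 2-node tree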
Reference:
https://rstudio-pubs-static.s3.amazonaws.com/330056_afe6136844804ec3a00586ccb73bc554.html
https://rdrr.io/cran/ISLR/man/Hitters.html
https://www.youtube.com/watch?v=MoBw5PiW56k
https://cran.r-project.org/web/packages/explore/vignettes/explore_mtcars.html
https://uc-r.github.io/regression_trees