Decision tree

  1. The R package for tree methods is ‘tree’. We will also use the ‘ISLR’ package to access the data.
library(ISLR)
library(tree)

We use the Hitters data to illustrate the tree model. In the following code, ‘dim()’ gives the number of rows and columns in the data, and ‘names()’ outputs the column names (the variables, i.e., the predictors and the response ‘Salary’).

  1. Goal: predict Salary. The data have 322 rows and 20 columns; only some of the columns will be used as predictors, and the response Y is Salary.
Hitters[1:5,] # first five rows
dim(Hitters)
## [1] 322  20
names(Hitters)
##  [1] "AtBat"     "Hits"      "HmRun"     "Runs"      "RBI"      
##  [6] "Walks"     "Years"     "CAtBat"    "CHits"     "CHmRun"   
## [11] "CRuns"     "CRBI"      "CWalks"    "League"    "Division" 
## [16] "PutOuts"   "Assists"   "Errors"    "Salary"    "NewLeague"
  1. Summary

We will first use the log of ‘Salary’ as the response and ‘Years’ and ‘Hits’ as the two predictors. The syntax of ‘tree()’ is quite similar to that of ‘lm()’, and ‘summary()’ summarizes the fitted tree.

mytree=tree(log(Salary)~ Years+Hits, data=Hitters) # log-transform Salary, which takes large values, so its distribution is less skewed
summary(mytree) # the fitted tree is saved as mytree
## 
## Regression tree:
## tree(formula = log(Salary) ~ Years + Hits, data = Hitters)
## Number of terminal nodes:  8 
## Residual mean deviance:  0.2708 = 69.06 / 255 
## Distribution of residuals:
##    Min. 1st Qu.  Median    Mean 3rd Qu.    Max. 
## -2.2400 -0.2980 -0.0365  0.0000  0.3233  2.1520

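The above summary information includes the number of terminal nodes, which is 8 in our case, and the residual mean deviance. For a regression tree the deviance is the residual sum of squares (RSS), and the residual mean deviance is the RSS divided by n minus the number of terminal nodes: 69.06 / (263 - 8) = 0.2708, where 263 is the number of players with a recorded salary.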
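The residual mean deviance can be computed by hand from the responses and the terminal node each observation falls in. The helper `resid_mean_dev()` below is hypothetical (not part of the ‘tree’ package); it is a minimal sketch on toy data.

```r
# Hypothetical helper (not part of the 'tree' package): residual mean
# deviance of a regression tree, given the responses y and the terminal
# node (leaf) that each observation falls into.
resid_mean_dev <- function(y, node) {
  fitted <- ave(y, node)              # each leaf predicts the mean of its responses
  rss <- sum((y - fitted)^2)          # deviance of a regression tree = RSS
  n_leaves <- length(unique(node))
  rss / (length(y) - n_leaves)        # divide by n - (number of terminal nodes)
}

# Toy example: two leaves, five observations
y    <- c(1, 2, 3, 10, 12)
node <- c("a", "a", "a", "b", "b")
resid_mean_dev(y, node)   # RSS = 4 and n - leaves = 3, so 4/3

# The Hitters summary above: 69.06 / (263 - 8)
69.06 / (263 - 8)
```

Since ‘tree()’ stores the responses and leaf assignments by default (as `mytree$y` and `mytree$where`), we would expect `resid_mean_dev(mytree$y, mytree$where)` to reproduce the 0.2708 reported by ‘summary()’.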

  1. Plot the tree with 8 terminal nodes

One of the most attractive properties of trees is that they can be displayed graphically. We use the ‘plot()’ function to display the tree structure, the ‘text()’ function to display the node labels, and the ‘title()’ function to add a title to the plot.

plot(mytree)
text(mytree)
title("This is a tree!")

At a given internal node, a label of the form Xj < c (e.g., Years < 4.5) indicates the left-hand branch issued from that split, and the right-hand branch corresponds to Xj ≥ c (Years ≥ 4.5). The above tree has 8 terminal nodes. Based on the tree plot, ‘Years’ is the most important factor in determining ‘Salary’, and players with less experience earn lower salaries. To predict the ‘log(Salary)’ of a new player, we only need to check which region (terminal node) the player falls into; the prediction is the mean response of the training players in that region. For instance, a new player with ‘Years’ < 3.5 and ‘Hits’ < 40.5 will have a predicted ‘log(Salary)’ of 5.511.
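The prediction rule can be sketched as a walk from the root: go left when x < c holds at a node, right otherwise, until a terminal node is reached. The tree below (split points and leaf values) is hypothetical and is not read off the plot above; `leaf()`, `split_node()` and `predict_tree()` are illustrative helpers, not part of the ‘tree’ package.

```r
# Hypothetical node constructors: a leaf stores a predicted value,
# an internal node stores a split variable, a cutpoint, and two children.
leaf <- function(value) list(value = value)
split_node <- function(var, cut, left, right) {
  list(var = var, cut = cut, left = left, right = right)
}

# Descend from the root: left branch if x[var] < cut, else right branch.
predict_tree <- function(node, x) {
  while (is.null(node$value)) {
    node <- if (x[[node$var]] < node$cut) node$left else node$right
  }
  node$value
}

# Hypothetical 3-leaf tree on the log(Salary) scale (illustrative numbers)
toy_tree <- split_node("Years", 4.5,
  left  = leaf(5.1),
  right = split_node("Hits", 117.5, left = leaf(6.0), right = leaf(6.7)))

predict_tree(toy_tree, list(Years = 3, Hits = 50))    # 5.1
predict_tree(toy_tree, list(Years = 10, Hits = 150))  # 6.7
```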

  1. Prune the tree (‘best’ = number of terminal nodes to keep)

If we wish to prune the tree, we can do so with the ‘prune.tree()’ function. The first argument is the fitted tree, and the second argument, ‘best=3’, says we would like to keep only 3 terminal nodes. We then plot the pruned tree. As you can see, the pruned tree is a subtree of the previous full-size tree. This tree is the one shown on page 2 of our lecture note.

prune.mytree = prune.tree(mytree, best=3)
plot(prune.mytree)
text(prune.mytree)
title("This is a pruned tree with 3 terminal nodes!")

  1. Next, we fit a tree model to the Hitters data using all 19 predictors to predict ‘log(Salary)’. We randomly split the data in half, one half for the training set and the other for the test set, build a tree model on the training set, and make predictions on the test set.
set.seed(1)
train = sample(1:nrow(Hitters), nrow(Hitters)/2)
tree.Hitters=tree(log(Salary)~., data=Hitters, subset=train)
summary(tree.Hitters)
## 
## Regression tree:
## tree(formula = log(Salary) ~ ., data = Hitters, subset = train)
## Variables actually used in tree construction:
## [1] "CHits"   "Hits"    "Walks"   "RBI"     "PutOuts" "Years"   "Assists"
## [8] "CAtBat" 
## Number of terminal nodes:  11 
## Residual mean deviance:  0.185 = 22.02 / 119 
## Distribution of residuals:
##      Min.   1st Qu.    Median      Mean   3rd Qu.      Max. 
## -1.891000 -0.198100  0.005742  0.000000  0.224700  2.217000
plot(tree.Hitters)
text(tree.Hitters)
title("This is a tree with all features!")

The output of ‘summary()’ now also lists the variables actually used in tree construction, as well as the tree size and the residual mean deviance discussed before. My result is slightly different from the one shown in the textbook, most likely because the textbook used a different random seed, and hence a different training split. Using a different seed leads to a different tree. This illustrates a drawback of a single tree: it is not stable. We will return to this point later.

  1. Select the tree size by cross-validation (choose the size with the smallest CV error, ‘dev’)

Next, we use the cross-validation function ‘cv.tree()’ to determine the optimal level of tree complexity, i.e., the best tree size; cost-complexity pruning is used to select a sequence of trees for consideration. ‘cv.tree()’ reports the number of terminal nodes of each tree considered (size), the corresponding cross-validation error (dev), and the value of the cost-complexity parameter used (k, which corresponds to \(\alpha\) in the lecture note).
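For reference, the standard cost-complexity criterion (as in ISLR; we assume the lecture note uses the same notation) is shown below. For each value of \(\alpha\), the selected subtree \(T\) of the full tree \(T_0\) minimizes the RSS over its \(|T|\) terminal nodes plus the penalty \(\alpha|T|\), where \(R_m\) is the region of the \(m\)-th terminal node and \(\hat{y}_{R_m}\) is the mean training response in it. Larger \(\alpha\) favors smaller trees, which is why k increases as size decreases in the output.

```latex
C_\alpha(T) \;=\; \sum_{m=1}^{|T|} \sum_{i:\, x_i \in R_m} \left( y_i - \hat{y}_{R_m} \right)^2 \;+\; \alpha\,|T|, \qquad T \subseteq T_0
```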

set.seed(2)
cv.tree.Hitters=cv.tree(tree.Hitters) # cross-validate to select the best tree size (8 here)
cv.tree.Hitters # k is the cost-complexity parameter alpha
## $size
##  [1] 11 10  9  8  6  5  4  3  2  1
## 
## $dev
##  [1]  48.79574  49.43083  49.59175  47.49789  47.94292  47.86271  49.23235
##  [8]  52.78509  54.33486 103.58266
## 
## $k
##  [1]      -Inf  1.036220  1.242826  1.301227  1.477342  1.945330  3.045800
##  [8]  6.016269  7.682035 54.859249
## 
## $method
## [1] "deviance"
## 
## attr(,"class")
## [1] "prune"         "tree.sequence"

Note that, despite the name, ‘dev’ is the cross-validation error here: the cross-validated deviance, which for a regression tree is the RSS of the held-out predictions. We then plot the CV error as a function of the tree size.

plot(cv.tree.Hitters$size,cv.tree.Hitters$dev,type="b") # x-axis: tree size, y-axis: CV error

# type="b" draws both points and lines; alternatives:
#plot(cv.tree.Hitters$size,cv.tree.Hitters$dev,type="l") # lines only
#plot(cv.tree.Hitters$size,cv.tree.Hitters$dev,type="o") # points overplotted on lines
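Instead of reading the minimum off the plot, the size with the smallest CV error can be picked with ‘which.min()’. This sketch copies the numbers printed by ‘cv.tree()’ above; with the fitted object, `cv.tree.Hitters$size[which.min(cv.tree.Hitters$dev)]` does the same.

```r
# CV results copied from the cv.tree() output above
size <- c(11, 10, 9, 8, 6, 5, 4, 3, 2, 1)
dev  <- c(48.79574, 49.43083, 49.59175, 47.49789, 47.94292,
          47.86271, 49.23235, 52.78509, 54.33486, 103.58266)

best_size <- size[which.min(dev)]   # size with the smallest CV error
best_size
```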
  1. Use the size of 8 chosen by cross-validation (the tree starts with a split on hits)

We now apply the ‘prune.tree()’ function in order to prune the tree to obtain the best sized tree chosen by CV.

prune.cv.mytree = prune.tree(tree.Hitters, best=8)
plot(prune.cv.mytree)
text(prune.cv.mytree)
title("Pruned tree with tree size tuned by CV!")

Finally, we can make predictions using this pruned tree. The ‘predict()’ function works as it does for regression models. Recall that we use ‘log(Salary)’ as the response, so y.test is computed by taking the log. The ‘plot()’ call puts the predicted values on the x-axis and the true responses on the y-axis, and ‘abline(0,1)’ adds the diagonal line y = x; if all points fell on this line, the predictions would be perfect. We then compute the mean squared error (MSE) of the predictions. Since some salaries in the data are missing, we use ‘na.rm = TRUE’ to remove those missing values when computing the MSE.

yhat=predict(prune.cv.mytree,newdata=Hitters[-train,]) # -train selects the rows not in the training set

y.test=log(Hitters[-train,"Salary"]) # true responses on the log scale
plot(yhat,y.test) # predicted values (x-axis) vs. true values (y-axis)
abline(0,1) # diagonal line: perfect predictions would fall on it

mean((yhat-y.test)^2, na.rm = TRUE) # test MSE
## [1] 0.2668873
# yhat and y.test are vectors, so (yhat-y.test)^2 squares each prediction error before averaging
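The MSE with missing responses can also be written with the incomplete pairs dropped explicitly. `test_mse()` is a hypothetical helper and the salaries below are made up; it should agree with `mean(..., na.rm = TRUE)`.

```r
# Hypothetical helper: test MSE over the pairs where both the prediction
# and the true response are observed.
test_mse <- function(yhat, y) {
  ok <- !is.na(yhat) & !is.na(y)       # keep only complete pairs
  mean((yhat[ok] - y[ok])^2)
}

salary <- c(500, NA, 1000, 250)        # made-up salaries, one missing
y      <- log(salary)                  # log(NA) is NA, so the gap carries through
yhat   <- log(c(400, 700, 1000, 300))  # made-up predictions on the log scale

test_mse(yhat, y)
mean((yhat - y)^2, na.rm = TRUE)       # same value
```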

Next example: four trees fitted on different sample sizes

However, a single tree is not stable. We next fit and plot the tree model using four different training sets: the first 100, 150, and 200 rows, and all the data.

par(mfrow=c(2,2)) ## 4 figures in one plot
mytree = tree(Salary~ ., data=Hitters[1:100,])
plot(mytree);text(mytree);title("Tree: first 100 data")
mytree = tree(Salary~ ., data=Hitters[1:150,])
plot(mytree);text(mytree);title("Tree: first 150 data")
mytree = tree(Salary~ ., data=Hitters[1:200,])
plot(mytree);text(mytree);title("Tree: first 200 data")
mytree = tree(Salary~ ., data=Hitters)
plot(mytree);text(mytree);title("Tree: all data")

As you can see from the output, these 4 trees are very different. A method whose output changes drastically after a small change to the data is not ideal, and this instability is a major drawback of the single tree model. Next, we discuss other tree-based methods that address this issue.

Bagging and random forests

We will use the ‘randomForest’ package for both bagging and random forests. Its syntax is essentially the same as that of ‘tree()’. By default, ‘randomForest()’ cannot handle missing values, so we remove the rows with NAs first. An important argument of ‘randomForest()’ is ‘mtry’, the number of features randomly sampled as split candidates at each split. The default values are mtry=\(\lfloor\sqrt{p}\rfloor\) for classification and mtry=\(\lfloor p/3\rfloor\) for regression. Bagging is simply the special case of random forests with mtry = \(p\).
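The default rule can be written as a small function. `default_mtry()` is a hypothetical helper, not part of ‘randomForest’; as far as we know it matches the package's defaults (rounding down, never below 1), and it gives the 6 variables per split reported for Hitters in the random forest output.

```r
# Hypothetical helper mirroring the default mtry rule described above:
# floor(p/3) for regression (at least 1), floor(sqrt(p)) for classification.
# Bagging corresponds to setting mtry = p instead.
default_mtry <- function(p, regression = TRUE) {
  if (regression) max(floor(p / 3), 1) else floor(sqrt(p))
}

default_mtry(19)                      # Hitters, regression: 19 predictors -> 6
default_mtry(13)                      # Boston, regression: 13 predictors -> 4
default_mtry(19, regression = FALSE)  # a classification problem with 19 predictors -> 4
```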

library(randomForest);  ##install.packages("randomForest")
## randomForest 4.6-14
## Type rfNews() to see new features/changes/bug fixes.
Hitters=na.omit(Hitters) ## remove missing data 
p = ncol(Hitters)-1 # number of predictors: 20 columns minus 1 response = 19


dim(Hitters) # 59 rows with missing values were removed (322 -> 263)
## [1] 263  20
# bagging and random forests differ only in mtry, the number of predictors tried at each split
# mtry = p -> bagging
# mtry < p -> random forest
# Bagging: every split considers all p predictors
set.seed(1)
mybagging = randomForest(log(Salary) ~ ., mtry=p, data=Hitters)
mybagging
## 
## Call:
##  randomForest(formula = log(Salary) ~ ., data = Hitters, mtry = p) 
##                Type of random forest: regression
##                      Number of trees: 500
## No. of variables tried at each split: 19
## 
##           Mean of squared residuals: 0.1911709
##                     % Var explained: 75.73
# Random forest: at each split, randomly select mtry predictors as split candidates
# default for regression: floor(p/3) = floor(19/3) = 6
myrf = randomForest(log(Salary) ~ ., data=Hitters) 
myrf
## 
## Call:
##  randomForest(formula = log(Salary) ~ ., data = Hitters) 
##                Type of random forest: regression
##                      Number of trees: 500
## No. of variables tried at each split: 6
## 
##           Mean of squared residuals: 0.1837426
##                     % Var explained: 76.67

The output of bagging and random forests gives the number of variables tried at each split, the mean of squared residuals (computed from the out-of-bag predictions), and the percentage of variance explained by the model.
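The last two numbers are linked: the percentage of variance explained is 100 × (1 - MSE / Var(y)), with MSE the out-of-bag mean of squared residuals. This sketch assumes, as we believe ‘randomForest’ does, that Var(y) uses divisor n rather than n - 1; `pct_var_explained()` is a hypothetical helper.

```r
# Hypothetical helper: % variance explained from an (out-of-bag) MSE,
# using the variance of y with divisor n (an assumption, see above).
pct_var_explained <- function(mse, y) {
  n <- length(y)
  100 * (1 - mse / (var(y) * (n - 1) / n))
}

y <- c(2, 4, 6, 8)          # toy responses: variance with divisor n is 5
pct_var_explained(1.25, y)  # 100 * (1 - 1.25/5) = 75
```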

Bagging and random forests can also measure the importance of each predictor. Using the ‘importance()’ function, we can view the importance of each variable.

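# importance=TRUE is needed to compute the permutation-based %IncMSE
# first column (%IncMSE): increase in out-of-bag MSE when the variable's values are permuted
# second column (IncNodePurity): total decrease in node impurity (RSS) from splits on the variable
# note: the response here is Salary, not log(Salary)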
myrf=randomForest(Salary ~.,data=Hitters,importance=TRUE)
importance(myrf)
##              %IncMSE IncNodePurity
## AtBat      9.9816613     2531073.8
## Hits       7.2597161     3102464.3
## HmRun      4.9337753     1257497.1
## Runs       7.3579674     2807166.8
## RBI        6.0627356     3724398.3
## Walks      5.0695969     3991651.5
## Years      7.5586982     1327721.5
## CAtBat    12.9679474     4762352.5
## CHits     13.0278963     6573898.3
## CHmRun    10.1691602     2844038.8
## CRuns     12.0430468     5566461.6
## CRBI      11.6116657     6046523.4
## CWalks     8.1917595     3116466.1
## League    -0.2225893      102073.6
## Division   1.5194866      164176.1
## PutOuts    2.7587697     2394619.1
## Assists    0.6373521      667391.8
## Errors     3.3421767      724359.8
## NewLeague  1.8512527      115823.6

Two different measures of variable importance are reported. The former (%IncMSE) is based on the mean decrease in prediction accuracy on the out-of-bag samples when the values of a given variable are randomly permuted. The latter (IncNodePurity) is the total decrease in node impurity that results from splits over that variable, averaged over all trees. For regression trees, node impurity is measured by the training RSS; for classification trees, ‘randomForest’ uses the Gini index.
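The permutation idea behind %IncMSE fits in a few lines. This sketch swaps in simplifications: a linear model instead of a random forest, the built-in mtcars data instead of Hitters, and a single hold-out set instead of out-of-bag samples. ‘randomForest’ additionally averages over trees and, by default, ‘importance()’ divides by a standard error (scale = TRUE).

```r
# Permutation importance sketch: shuffle one predictor in the hold-out
# data and measure how much the prediction MSE increases.
set.seed(1)
idx      <- sample(nrow(mtcars), 22)
mt_train <- mtcars[idx, ]
mt_test  <- mtcars[-idx, ]
fit <- lm(mpg ~ wt + hp + qsec, data = mt_train)

holdout_mse <- function(d) mean((d$mpg - predict(fit, newdata = d))^2)
base_mse <- holdout_mse(mt_test)

perm_importance <- sapply(c("wt", "hp", "qsec"), function(v) {
  shuffled <- mt_test
  shuffled[[v]] <- sample(shuffled[[v]])   # break the link between v and mpg
  holdout_mse(shuffled) - base_mse         # increase in MSE after permuting v
})
perm_importance
```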

Plots of these importance measures can be produced using the ‘varImpPlot()’ function.

myrf=randomForest(Salary ~.,data=Hitters,importance=TRUE)
varImpPlot(myrf)

The results indicate that, across all the trees in the random forest, ‘CHits’ is the most important predictor under both criteria.

Next, we use the Boston housing data (from the ‘MASS’ package) to compare all three methods: a single tree, bagging, and random forests. Again, we use half of the data for training and half for testing, and we set the seed for reproducibility.

library(tree);
library(ISLR);
library(MASS);
set.seed(1); 
train = sample(1:nrow(Boston), nrow(Boston)/2);
tree.boston=tree(medv~.,Boston,subset=train); # predict medv, the median home value (in $1000s)
summary(tree.boston);
## 
## Regression tree:
## tree(formula = medv ~ ., data = Boston, subset = train)
## Variables actually used in tree construction:
## [1] "lstat" "rm"    "dis"  
## Number of terminal nodes:  8 
## Residual mean deviance:  12.65 = 3099 / 245 
## Distribution of residuals:
##      Min.   1st Qu.    Median      Mean   3rd Qu.      Max. 
## -14.10000  -2.04200  -0.05357   0.00000   1.96000  12.60000

The tree uses only three variables (‘lstat’, ‘rm’, and ‘dis’) and has 8 terminal nodes.

plot(tree.boston);
text(tree.boston)

Here ‘lstat’ is the percentage of individuals with lower socioeconomic status. The tree indicates that lower values of ‘lstat’ correspond to more expensive houses. For example, it predicts a median house price of $46,380 for larger homes (rm > 7.437, where ‘rm’ is the average number of rooms) in neighborhoods with high socioeconomic status (lstat < 9.715).

Next, we use ‘cv.tree()’ to perform cross-validation and determine the optimal level of tree complexity. It outputs the number of terminal nodes of each tree (size) and the corresponding cross-validation error (dev). Cross-validation selects the most complex tree, the one with size 8, i.e., the whole tree (although size 7 has almost the same error).

cv.boston=cv.tree(tree.boston)
cv.boston
## $size
## [1] 8 7 6 5 4 3 2 1
## 
## $dev
## [1]  5226.322  5228.360  6462.626  6692.615  6397.438  7529.846 11958.691
## [8] 21118.139
## 
## $k
## [1]      -Inf  255.6581  451.9272  768.5087  818.8885 1559.1264 4276.5803
## [8] 9665.3582
## 
## $method
## [1] "deviance"
## 
## attr(,"class")
## [1] "prune"         "tree.sequence"
plot(cv.boston$size,cv.boston$dev,type='b')

Next, we make predictions using this tree model. ‘predict()’ returns the estimated median housing price, and we then compute the test MSE.

yhat=predict(tree.boston,newdata=Boston[-train,])
boston.test=Boston[-train,"medv"]
plot(yhat,boston.test)
abline(0,1) 

mean((yhat-boston.test)^2) #MSE
## [1] 25.04559
#MSE Based on single tree 25.04559

Bagging

Next, we apply bagging to the same training data and compute the prediction MSE on the same test data. Since ‘randomForest()’ involves randomness (bootstrap samples), we set the seed for reproducibility.

library(randomForest)
## Bagging
set.seed(1) 
bag.boston=randomForest(medv~.,data=Boston,subset=train,
    mtry=13,importance=TRUE) # mtry = 13 uses all 13 predictors in Boston, i.e., bagging
yhat.bag = predict(bag.boston,newdata=Boston[-train,])
mean((yhat.bag-boston.test)^2)
## [1] 13.50808
# bagging test MSE 13.50808, about half that of the single tree (25.04559)

Compared with the single tree, the bagged estimates of the median housing price are much closer to the true prices: the test MSE of bagging is roughly half that obtained using a single tree!
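The averaging step of bagging can be sketched in base R. To keep it short, the sketch uses single-split regression stumps on one predictor of the built-in mtcars data instead of full trees on Boston, so it illustrates the bootstrap-and-average idea rather than reimplementing ‘randomForest()’; all function names are hypothetical.

```r
# Hypothetical helper: a regression stump (tree with one split) on one predictor.
fit_stump <- function(x, y) {
  v <- sort(unique(x))
  if (length(v) < 2) return(list(cut = Inf, left = mean(y), right = mean(y)))
  cuts <- (head(v, -1) + tail(v, -1)) / 2      # midpoints between distinct x values
  rss <- sapply(cuts, function(cp) {
    left <- y[x < cp]; right <- y[x >= cp]
    sum((left - mean(left))^2) + sum((right - mean(right))^2)
  })
  cp <- cuts[which.min(rss)]                   # split with the smallest RSS
  list(cut = cp, left = mean(y[x < cp]), right = mean(y[x >= cp]))
}
predict_stump <- function(s, x) ifelse(x < s$cut, s$left, s$right)

# Bagging: fit one stump per bootstrap sample, then average the predictions.
bag_stumps <- function(x, y, B = 100) {
  lapply(seq_len(B), function(b) {
    i <- sample(length(y), replace = TRUE)     # bootstrap sample of row indices
    fit_stump(x[i], y[i])
  })
}
predict_bag <- function(stumps, x) {
  preds <- vapply(stumps, predict_stump, numeric(length(x)), x = x)
  rowMeans(matrix(preds, nrow = length(x)))    # average over the B stumps
}

set.seed(1)
stumps <- bag_stumps(mtcars$wt, mtcars$mpg, B = 200)
pred <- predict_bag(stumps, mtcars$wt)
mean((pred - mtcars$mpg)^2)                    # training MSE of the bagged stumps
```

Averaging many stumps that split at slightly different points gives a smoother fit than any single stump, which is the variance reduction that makes bagged trees more stable than one tree.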

Next, we apply random forests to the same data to compute the prediction MSE. Recall that by default, random forests use \(\lfloor p/3\rfloor\) variables at each split for regression trees; here we set mtry=5 explicitly.

set.seed(1)
rf.boston=randomForest(medv~.,data=Boston,subset=train,
    mtry=5,importance=TRUE) # default would be floor(13/3) = 4; here we use 5



yhat.rf = predict(rf.boston,newdata=Boston[-train,])
mean((yhat.rf-boston.test)^2)
## [1] 10.94181
#11.66467

So far, the random forest gives the best prediction accuracy on the Boston data (test MSE 10.94, versus 13.51 for bagging and 25.05 for a single tree).

Boosting

Here we use the ‘gbm’ package to fit boosted regression trees to the Boston data set. We run ‘gbm()’ function with the option distribution=“gaussian” since this is a regression problem; if it were a binary classification problem, we would use distribution=“bernoulli”. The argument n.trees=10000 indicates that we want 10000 trees. The ‘summary()’ function produces a relative influence plot and also outputs the relative influence statistics.

library(gbm);  ##install.packages("gbm")
## Loaded gbm 2.1.4
set.seed (1)
boost.boston=gbm(medv~.,data=Boston[train,],distribution= "gaussian",n.trees=10000)
summary(boost.boston)

We see that ‘lstat’ and ‘rm’ are by far the most important variables.

We now use the boosted model to predict ‘medv’ on the test set:

yhat.boost=predict(boost.boston,newdata=Boston[-train,],n.trees=10000)
mean((yhat.boost -boston.test)^2) 
## [1] 15.74614

Try this in the lab: we can also modify the parameters of ‘gbm()’ to see whether we can obtain better prediction accuracy. In the following call, the option ‘interaction.depth=4’ limits the depth of each tree; the default is ‘interaction.depth = 1’. The option ‘shrinkage=0.2’ specifies the shrinkage parameter \(\lambda\); in recent versions of ‘gbm’ (including the 2.1.4 loaded above) its default is 0.1, while older versions used 0.001.

boost.boston=gbm(medv~.,data=Boston[train,],distribution="gaussian",n.trees=5000,
                 interaction.depth=4,shrinkage=0.2)
yhat.boost=predict(boost.boston,newdata=Boston[-train,],n.trees=5000)
mean((yhat.boost -boston.test)^2) 
## [1] 10.89474

The above gradient boosted model gives the best prediction accuracy so far, slightly better than the random forest (test MSE 10.89 versus 10.94).
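To see what ‘n.trees’ and ‘shrinkage’ control, here is a minimal base-R sketch of boosting for squared-error loss: each new stump is fitted to the current residuals and added to the model after being shrunk by \(\lambda\). Stumps correspond to ‘interaction.depth = 1’. The sketch leaves out details of ‘gbm’ such as its subsampling of the training data (‘bag.fraction’), and uses the mtcars data and hypothetical helper names.

```r
# Hypothetical helper: a regression stump on one predictor, fitted to r.
fit_stump <- function(x, r) {
  v <- sort(unique(x))
  cuts <- (head(v, -1) + tail(v, -1)) / 2      # midpoints between distinct x values
  rss <- sapply(cuts, function(cp) {
    left <- r[x < cp]; right <- r[x >= cp]
    sum((left - mean(left))^2) + sum((right - mean(right))^2)
  })
  cp <- cuts[which.min(rss)]
  list(cut = cp, left = mean(r[x < cp]), right = mean(r[x >= cp]))
}
predict_stump <- function(s, x) ifelse(x < s$cut, s$left, s$right)

# Boosting: the residual y - f is the negative gradient of (y - f)^2 / 2,
# so each stump is fitted to the residuals and added with weight shrinkage.
boost_stumps <- function(x, y, n_trees = 500, shrinkage = 0.1) {
  f0 <- mean(y)                                # initial constant fit
  fhat <- rep(f0, length(y))
  stumps <- vector("list", n_trees)
  for (b in seq_len(n_trees)) {
    r <- y - fhat                              # current residuals
    stumps[[b]] <- fit_stump(x, r)
    fhat <- fhat + shrinkage * predict_stump(stumps[[b]], x)
  }
  list(f0 = f0, stumps = stumps, shrinkage = shrinkage)
}
predict_boost <- function(model, x) {
  pred <- rep(model$f0, length(x))
  for (s in model$stumps) pred <- pred + model$shrinkage * predict_stump(s, x)
  pred
}

model <- boost_stumps(mtcars$wt, mtcars$mpg, n_trees = 200, shrinkage = 0.1)
mean((predict_boost(model, mtcars$wt) - mtcars$mpg)^2)   # training MSE
```

A smaller shrinkage makes each step smaller, so more trees are typically needed to reach the same training error; this is the trade-off between ‘n.trees’ and ‘shrinkage’ in ‘gbm()’.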