library(ISLR)
library(tree)
We use the Hitters data to illustrate the tree model. In the following code, ‘dim()’ gives the number of rows and columns of the data, and ‘names()’ outputs the column names of the data (the predictors and the response).
Hitters[1:5,] # first five rows
dim(Hitters)
## [1] 322 20
names(Hitters)
## [1] "AtBat" "Hits" "HmRun" "Runs" "RBI"
## [6] "Walks" "Years" "CAtBat" "CHits" "CHmRun"
## [11] "CRuns" "CRBI" "CWalks" "League" "Division"
## [16] "PutOuts" "Assists" "Errors" "Salary" "NewLeague"
We will first use the log of ‘Salary’ as the response and ‘Years’ and ‘Hits’ as the two predictors. The syntax of the ‘tree()’ function is quite similar to that of the ‘lm()’ function. The ‘summary()’ function summarizes the fitted tree.
mytree=tree(log(Salary)~ Years+Hits, data=Hitters) # model log(Salary) since salaries are large and highly skewed
summary(mytree) # summary of the fitted tree, saved as 'mytree'
##
## Regression tree:
## tree(formula = log(Salary) ~ Years + Hits, data = Hitters)
## Number of terminal nodes: 8
## Residual mean deviance: 0.2708 = 69.06 / 255
## Distribution of residuals:
## Min. 1st Qu. Median Mean 3rd Qu. Max.
## -2.2400 -0.2980 -0.0365 0.0000 0.3233 2.1520
The above summary information includes the number of terminal nodes, which is 8 in our case, and the residual mean deviance, which is the residual sum of squares (RSS) of the regression tree divided by the number of observations minus the number of terminal nodes.
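For instance, 263 players have a non-missing ‘Salary’ and the tree has 8 terminal nodes, so the reported value can be recovered by a quick hand calculation:
69.06 / (263 - 8) # RSS / (n - number of terminal nodes) = 0.2708, matching the summary above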
One of the most attractive properties of trees is that they can be graphically displayed. We use the ‘plot()’ function to display the tree structure, the ‘text()’ function to display the node labels, and the ‘title()’ function to add a title to the plot.
plot(mytree)
text(mytree)
title("This is a tree!")
At a given internal node, a label of the form Xj < c (e.g., Years < 4.5) indicates the left-hand branch emanating from that split, and the right-hand branch corresponds to Xj ≥ c (Years ≥ 4.5). The above tree has 8 terminal nodes. Based on the tree plot, ‘Years’ is the most important factor in determining ‘Salary’, and players with less experience earn lower salaries. To predict the ‘log(Salary)’ of a new player, we only need to check which region the new player belongs to. For instance, a new player with ‘Years’ < 3.5 and ‘Hits’ < 40.5 will have a predicted ‘log(Salary)’ of 5.511.
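For example, we can obtain this prediction directly with ‘predict()’. The values below (2 years, 30 hits) are made up for illustration; any player in the same region gets the same predicted value.
newplayer = data.frame(Years=2, Hits=30) # a hypothetical player with Years < 3.5 and Hits < 40.5
predict(mytree, newdata=newplayer) # returns the mean log(Salary) of that terminal node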
If we wish to prune the tree, we could do so as follows, using the ‘prune.tree()’ function. The first argument is the tree fit; the second argument, ‘best=3’, says that we would like to keep only 3 terminal nodes. We then plot the pruned tree. As you can see, the pruned tree is a subtree of the previous full-size tree; it is the one shown on page 2 of our lecture notes.
prune.mytree = prune.tree(mytree, best=3)
plot(prune.mytree)
text(prune.mytree)
title("This is a pruned tree with 3 terminal nodes!")
Next, we fit a regression tree using all predictors, randomly selecting half of the observations as the training set.
set.seed(1)
train = sample(1:nrow(Hitters), nrow(Hitters)/2)
tree.Hitters=tree(log(Salary)~., data=Hitters, subset=train)
summary(tree.Hitters)
##
## Regression tree:
## tree(formula = log(Salary) ~ ., data = Hitters, subset = train)
## Variables actually used in tree construction:
## [1] "CHits" "Hits" "Walks" "RBI" "PutOuts" "Years" "Assists"
## [8] "CAtBat"
## Number of terminal nodes: 11
## Residual mean deviance: 0.185 = 22.02 / 119
## Distribution of residuals:
## Min. 1st Qu. Median Mean 3rd Qu. Max.
## -1.891000 -0.198100 0.005742 0.000000 0.224700 2.217000
plot(tree.Hitters)
text(tree.Hitters)
title("This is a tree with all features!")
The output of ‘summary()’ now includes the variables actually used in tree construction, as well as the tree size and RSS as discussed before. My result is slightly different from the one shown in the textbook because I do not know which seed they used when fitting their tree; using a different seed leads to a different result. This reflects a drawback of the single tree model: a single tree is not stable. We will discuss this point again later on.
Next, we use the cross-validation function ‘cv.tree()’ to determine the optimal level of tree complexity, i.e., the best tree size; cost complexity pruning is used in order to select a sequence of trees for consideration. The ‘cv.tree()’ function reports the number of terminal nodes of each tree considered (size) as well as the corresponding error rate (dev) and the value of the cost-complexity parameter used (k, which corresponds to \(\alpha\) in the lecture notes).
set.seed(2)
cv.tree.Hitters=cv.tree(tree.Hitters) # select the best tree size by cross-validation
cv.tree.Hitters # the component 'k' corresponds to alpha
## $size
## [1] 11 10 9 8 6 5 4 3 2 1
##
## $dev
## [1] 48.79574 49.43083 49.59175 47.49789 47.94292 47.86271 49.23235
## [8] 52.78509 54.33486 103.58266
##
## $k
## [1] -Inf 1.036220 1.242826 1.301227 1.477342 1.945330 3.045800
## [8] 6.016269 7.682035 54.859249
##
## $method
## [1] "deviance"
##
## attr(,"class")
## [1] "prune" "tree.sequence"
Note that, despite the name, ‘dev’ corresponds to the cross-validation error in this instance. We then plot this error as a function of the tree size.
plot(cv.tree.Hitters$size,cv.tree.Hitters$dev,type="b")
#plot(cv.tree.Hitters$size,cv.tree.Hitters$dev,type="l") # lines only
#plot(cv.tree.Hitters$size,cv.tree.Hitters$dev,type="o") # points over-plotted on lines
# arguments: x values (size), y values (dev), plot type ("b" = both points and lines)
We now apply the ‘prune.tree()’ function in order to prune the tree to obtain the best sized tree chosen by CV.
prune.cv.mytree = prune.tree(tree.Hitters, best=8)
plot(prune.cv.mytree)
text(prune.cv.mytree)
title("Pruned tree with tree size tuned by CV!")
Finally, we can make predictions using this pruned tree. The function ‘predict()’ is similar to the one we used for regression models. Recall that we use ‘log(Salary)’ as the response; therefore, y.test is computed by taking the log. The ‘plot()’ puts the predicted values on the x-axis and the true responses on the y-axis. The ‘abline(0,1)’ adds a diagonal line to the plot; if all points fell on this line, the prediction would be perfect. We can then compute the mean squared error (MSE) of the prediction. Since some salaries in the data are missing, we use ‘na.rm = TRUE’ to remove those missing values when computing the MSE.
yhat=predict(prune.cv.mytree,newdata=Hitters[-train,]) # predict on the test set (rows not in 'train')
y.test=log(Hitters[-train,"Salary"]) # true log salaries of the test observations
plot(yhat,y.test) # predicted values (x-axis) versus true values (y-axis)
abline(0,1) # add the 45-degree reference line
mean((yhat-y.test)^2, na.rm = TRUE) # test MSE, removing missing values
## [1] 0.2668873
# yhat and y.test are vectors, so (yhat-y.test)^2 gives the squared error for each test observation before averaging
However, a single tree is not stable. We next plot the fitted trees for four different training set sizes: the first 100, 150, and 200 observations, and all of the data.
par(mfrow=c(2,2)) ## 4 figures in one plot
mytree = tree(Salary~ ., data=Hitters[1:100,])
plot(mytree);text(mytree);title("Tree: first 100 data")
mytree = tree(Salary~ ., data=Hitters[1:150,])
plot(mytree);text(mytree);title("Tree: first 150 data")
mytree = tree(Salary~ ., data=Hitters[1:200,])
plot(mytree);text(mytree);title("Tree: first 200 data")
mytree = tree(Salary~ ., data=Hitters)
plot(mytree);text(mytree);title("Tree: all data")
As you can see from the output, these four trees are very different. It is not ideal that a small change to the data can lead to a totally different fitted tree; this is a major drawback of the single tree model. Next, we discuss other tree-based methods that address this issue.
We will use the ‘randomForest’ package for both the bagging and random forests methods. Its syntax is essentially the same as ‘tree()’. The package is not able to handle missing values, so we remove the NAs in the data first. In the ‘randomForest()’ function, an important argument is ‘mtry’, the number of features randomly sampled at each split in random forests. The default values are mtry=\(\sqrt{p}\) for classification and mtry=\(p/3\) for regression. Bagging is just a special case of random forests with mtry = \(p\).
library(randomForest); ##install.packages("randomForest")
## randomForest 4.6-14
## Type rfNews() to see new features/changes/bug fixes.
Hitters=na.omit(Hitters) ## remove missing data
p = ncol(Hitters)-1 # number of predictors (20 columns, minus 1 for the response Salary, so p = 19)
dim(Hitters) # the rows with missing values have been removed (322 - 263 = 59 rows)
## [1] 263 20
## mtry = p: bagging (all p predictors are candidates at each split)
## mtry < p: random forest
set.seed(1)
mybagging = randomForest(log(Salary) ~ ., mtry=p, data=Hitters)
mybagging
##
## Call:
## randomForest(formula = log(Salary) ~ ., data = Hitters, mtry = p)
## Type of random forest: regression
## Number of trees: 500
## No. of variables tried at each split: 19
##
## Mean of squared residuals: 0.1911709
## % Var explained: 75.73
## Random forest: randomly select mtry predictors as split candidates at each split;
## for regression the default is mtry = p/3 = 19/3, i.e., 6
myrf = randomForest(log(Salary) ~ ., data=Hitters)
myrf
##
## Call:
## randomForest(formula = log(Salary) ~ ., data = Hitters)
## Type of random forest: regression
## Number of trees: 500
## No. of variables tried at each split: 6
##
## Mean of squared residuals: 0.1837426
## % Var explained: 76.67
The output of bagging and random forests gives the number of variables tried at each split, the mean of squared residuals, and the percentage of variance explained by the model.
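These quantities can also be extracted from the fitted object; a regression fit from ‘randomForest()’ stores ‘mse’ and ‘rsq’ components with one value per tree, so the last entries correspond to the full ensemble (a small illustrative check):
tail(mybagging$mse, 1) # out-of-bag mean of squared residuals after all 500 trees
tail(mybagging$rsq, 1) # proportion of variance explained, printed above as a percentage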
Bagging/RF can also evaluate each predictor’s performance. Using the ‘importance()’ function, we can view the importance of each variable.
# make sure to set importance=TRUE; the first column of the output is %IncMSE and the second is IncNodePurity
myrf=randomForest(Salary ~.,data=Hitters,importance=TRUE)
importance(myrf)
## %IncMSE IncNodePurity
## AtBat 9.9816613 2531073.8
## Hits 7.2597161 3102464.3
## HmRun 4.9337753 1257497.1
## Runs 7.3579674 2807166.8
## RBI 6.0627356 3724398.3
## Walks 5.0695969 3991651.5
## Years 7.5586982 1327721.5
## CAtBat 12.9679474 4762352.5
## CHits 13.0278963 6573898.3
## CHmRun 10.1691602 2844038.8
## CRuns 12.0430468 5566461.6
## CRBI 11.6116657 6046523.4
## CWalks 8.1917595 3116466.1
## League -0.2225893 102073.6
## Division 1.5194866 164176.1
## PutOuts 2.7587697 2394619.1
## Assists 0.6373521 667391.8
## Errors 3.3421767 724359.8
## NewLeague 1.8512527 115823.6
Two different measures of variable importance are reported. The former is based upon the mean decrease of accuracy in predictions on the out-of-bag samples when a given variable is permuted (effectively excluded from the model). The latter is a measure of the total decrease in node impurity that results from splits over that variable, averaged over all trees. In the case of regression trees, the node impurity is measured by the training RSS, and for classification trees by the deviance.
Plots of these importance measures can be produced using the ‘varImpPlot()’ function.
myrf=randomForest(Salary ~.,data=Hitters,importance=TRUE)
varImpPlot(myrf)
The results indicate that, across all of the trees considered in the random forest, ‘CHits’ is the most important predictor based on both criteria.
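If we prefer to rank the predictors programmatically instead of reading them off the plot, one way (a small illustrative snippet) is to sort the importance matrix by its ‘%IncMSE’ column:
imp = importance(myrf) # matrix with columns %IncMSE and IncNodePurity
head(imp[order(imp[,"%IncMSE"], decreasing=TRUE),], 3) # top three predictors by %IncMSE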
Next, we use the Boston housing data set to compare these three methods: a single tree, bagging, and random forests. Again, we use half of the data for training and half for testing, and we set the seed for reproducibility.
library(tree);
library(ISLR);
library(MASS);
set.seed(1);
train = sample(1:nrow(Boston), nrow(Boston)/2);
tree.boston=tree(medv~.,Boston,subset=train); #PREDICT MEDIAN PRICE
summary(tree.boston);
##
## Regression tree:
## tree(formula = medv ~ ., data = Boston, subset = train)
## Variables actually used in tree construction:
## [1] "lstat" "rm" "dis"
## Number of terminal nodes: 8
## Residual mean deviance: 12.65 = 3099 / 245
## Distribution of residuals:
## Min. 1st Qu. Median Mean 3rd Qu. Max.
## -14.10000 -2.04200 -0.05357 0.00000 1.96000 12.60000
Only three variables (‘lstat’, ‘rm’, and ‘dis’) are actually used in the tree construction, and the tree has 8 terminal nodes.
plot(tree.boston);
text(tree.boston)
The variable ‘lstat’ measures the percentage of individuals with lower socioeconomic status. The tree indicates that lower values of ‘lstat’ correspond to more expensive houses. It predicts a median house price of $46,380 for larger homes in areas where residents have high socioeconomic status (rm > 7.437 and lstat < 9.715).
Next, we use the function ‘cv.tree()’ to perform cross-validation to determine the optimal level of tree complexity. It outputs the number of terminal nodes of each tree considered (size), as well as the corresponding error rate (dev). Here CV chooses the most complex tree, the one with size 8, i.e., the unpruned tree.
cv.boston=cv.tree(tree.boston)
cv.boston
## $size
## [1] 8 7 6 5 4 3 2 1
##
## $dev
## [1] 5226.322 5228.360 6462.626 6692.615 6397.438 7529.846 11958.691
## [8] 21118.139
##
## $k
## [1] -Inf 255.6581 451.9272 768.5087 818.8885 1559.1264 4276.5803
## [8] 9665.3582
##
## $method
## [1] "deviance"
##
## attr(,"class")
## [1] "prune" "tree.sequence"
plot(cv.boston$size,cv.boston$dev,type='b')
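Since CV selects the full tree, no pruning is needed here. If we nevertheless wished to prune, ‘prune.tree()’ works exactly as before; a sketch with an arbitrary, illustrative size of 5 terminal nodes:
prune.boston = prune.tree(tree.boston, best=5) # keep only 5 terminal nodes (illustrative choice)
plot(prune.boston)
text(prune.boston)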
Next we make prediction using this tree model. The ‘predict()’ outputs the estimated median housing price. We then compute the prediction MSE.
yhat=predict(tree.boston,newdata=Boston[-train,])
boston.test=Boston[-train,"medv"]
plot(yhat,boston.test)
abline(0,1)
mean((yhat-boston.test)^2) #MSE
## [1] 25.04559
# test MSE based on a single tree: 25.04559
Next we apply bagging to the same training data and then compute the prediction MSE on the same testing data. Since the ‘randomForest()’ function has randomness, we need to set the seed for reproducibility.
library(randomForest)
## Bagging
set.seed(1)
bag.boston=randomForest(medv~.,data=Boston,subset=train,
mtry=13,importance=TRUE) # mtry = 13 = p, the number of predictors in the Boston data, so this is bagging
yhat.bag = predict(bag.boston,newdata=Boston[-train,])
mean((yhat.bag-boston.test)^2)
## [1] 13.50808
## bagging test MSE: 13.50808, about half of the single-tree MSE
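The comparison below refers to a plot of the bagging predictions against the true prices; the corresponding plotting commands, mirroring the single-tree plot above, would be:
plot(yhat.bag,boston.test) # predicted (x-axis) versus true (y-axis) median prices for bagging
abline(0,1) # 45-degree reference line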
Comparing this plot with the one from the single tree, the estimated median housing prices from bagging are closer to the true prices. The test MSE of bagging is almost half that obtained using a single tree!
Next, we apply random forests to the same data to compute the prediction MSE. Recall that random forests use \(p/3\) variables (by default) for regression trees.
set.seed(1)
rf.boston=randomForest(medv~.,data=Boston,subset=train,
mtry=5,importance=TRUE) # mtry = 5 here; the regression default would be floor(13/3) = 4
yhat.rf = predict(rf.boston,newdata=Boston[-train,])
mean((yhat.rf-boston.test)^2)
## [1] 10.94181
The random forest gives the best prediction accuracy in the Boston data.
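As a quick recap, we can collect the three test MSEs computed above into one named vector, reusing the objects already in the workspace:
c(single.tree = mean((yhat-boston.test)^2),
  bagging = mean((yhat.bag-boston.test)^2),
  random.forest = mean((yhat.rf-boston.test)^2))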
Here we use the ‘gbm’ package to fit boosted regression trees to the Boston data set. We run the ‘gbm()’ function with the option distribution=“gaussian” since this is a regression problem; if it were a binary classification problem, we would use distribution=“bernoulli”. The argument n.trees=10000 indicates that we want 10000 trees. The ‘summary()’ function produces a relative influence plot and also outputs the relative influence statistics.
library(gbm); ##install.packages("gbm")
## Loaded gbm 2.1.4
set.seed(1)
boost.boston=gbm(medv~.,data=Boston[train,],distribution= "gaussian",n.trees=10000)
summary(boost.boston)
We see that ‘lstat’ and ‘rm’ are by far the most important variables.
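For these two variables we can also produce partial dependence plots using the ‘plot()’ method for gbm fits; the sketch below uses the ‘i.var’ argument of ‘plot.gbm()’ to select the variable.
par(mfrow=c(1,2))
plot(boost.boston, i.var="rm") # partial dependence of medv on rm
plot(boost.boston, i.var="lstat") # partial dependence of medv on lstat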
We now use the boosted model to predict ‘medv’ on the test set:
yhat.boost=predict(boost.boston,newdata=Boston[-train,],n.trees=10000)
mean((yhat.boost -boston.test)^2)
## [1] 15.74614
We can also modify the parameters of the ‘gbm()’ function to see whether we can obtain better prediction accuracy; you are encouraged to play around with these settings in the lab. In the following ‘gbm()’ call, the option ‘interaction.depth=4’ limits the depth of each tree (the default is ‘interaction.depth = 1’). The ‘shrinkage=0.2’ specifies the shrinkage parameter \(\lambda\), whose default value is 0.001.
boost.boston=gbm(medv~.,data=Boston[train,],distribution="gaussian",n.trees=5000,
interaction.depth=4,shrinkage=0.2)
yhat.boost=predict(boost.boston,newdata=Boston[-train,],n.trees=5000)
mean((yhat.boost -boston.test)^2)
## [1] 10.89474
The above gradient boosted model gives the best prediction accuracy so far.