The code below comes entirely from online sources. I understand why the Gini index and entropy are used as measures of node purity (pp. 335-336, 2nd edition), but I was lacking when it came to doing the calculations and plotting them. Searching RPubs and GitHub produced various solutions; below is the simplest and most straightforward one I found. I am not taking credit for this solution; I just wanted a working solution and a plot.
# data = OJ$Purchase  # Purchase column of the OJ dataset: a factor with two classes, CH or MM (not used below)
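p = seq(0, 1, 0.01)                      # proportion of observations in one class
gini = 2 * p * (1 - p)                   # Gini index for two classes
classerror = 1 - pmax(p, 1 - p)          # classification error rate
xlogx = function(x) ifelse(x > 0, x * log(x), 0)   # convention: 0 * log(0) = 0
crossentropy = -(xlogx(p) + xlogx(1 - p))          # entropy, finite at p = 0 and p = 1
plot(NA, NA, xlim = c(0, 1), ylim = c(0, 1), xlab = 'p', ylab = 'f')
lines(p, gini, type = 'l')
lines(p, classerror, col = 'blue')
lines(p, crossentropy, col = 'red')
legend(x = 'top', legend = c('gini', 'class error', 'cross entropy'),
       col = c('black', 'blue', 'red'), lty = 1, text.width = 0.22)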
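The curves above treat p as a continuous proportion. For an actual node, the same formulas are applied to that node's class proportions. The sketch below (base R only; the helper name node_impurity is hypothetical) computes the three measures from a node's class counts, using as an example the root node of the OJ classification tree fitted later in this document (485 CH and 315 MM purchases):

```r
# Node impurity measures computed from the class counts in a single node
node_impurity <- function(counts) {
  p  <- counts / sum(counts)          # class proportions p_mk
  nz <- p[p > 0]                      # treat 0 * log(0) as 0
  c(gini        = sum(p * (1 - p)),   # sum_k p_mk (1 - p_mk)
    entropy     = -sum(nz * log(nz)), # -sum_k p_mk log(p_mk)
    class_error = 1 - max(p))         # 1 - max_k p_mk
}

# Root node of the OJ training tree: 485 CH and 315 MM purchases
root <- node_impurity(c(CH = 485, MM = 315))
root

# A pure node has zero impurity under all three measures
pure <- node_impurity(c(CH = 300, MM = 0))
pure
```

The class error of 0.39375 is the same number that printcp(tree.oj2) reports later as "Root node error: 315/800 = 0.39375".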
Look at the Carseats dataset:
library(ISLR)
## Warning: package 'ISLR' was built under R version 4.1.2
str(Carseats)
## 'data.frame': 400 obs. of 11 variables:
## $ Sales : num 9.5 11.22 10.06 7.4 4.15 ...
## $ CompPrice : num 138 111 113 117 141 124 115 136 132 132 ...
## $ Income : num 73 48 35 100 64 113 105 81 110 113 ...
## $ Advertising: num 11 16 10 4 3 13 0 15 0 0 ...
## $ Population : num 276 260 269 466 340 501 45 425 108 131 ...
## $ Price : num 120 83 80 97 128 72 108 120 124 124 ...
## $ ShelveLoc : Factor w/ 3 levels "Bad","Good","Medium": 1 2 3 3 1 1 3 2 3 3 ...
## $ Age : num 42 65 59 55 38 78 71 67 76 76 ...
## $ Education : num 17 10 12 14 13 16 15 10 10 17 ...
## $ Urban : Factor w/ 2 levels "No","Yes": 2 2 2 2 2 1 2 2 1 1 ...
## $ US : Factor w/ 2 levels "No","Yes": 2 2 2 2 1 2 1 2 1 2 ...
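set.seed(22)
# assign each row at random to the training set (TRUE) or the test set (FALSE), roughly 50/50
train = sample(c(TRUE, FALSE), nrow(Carseats), replace = TRUE)
test = !train
Carseats.train = Carseats[train, ]
Carseats.test = Carseats[test, ]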
NOTES: To grow a tree, use rpart(formula, data=, method=, control=), where formula has the form outcome ~ predictor1 + predictor2 + predictor3 + etc., data= specifies the data frame, method="class" gives a classification tree ("anova" gives a regression tree), and control= sets optional parameters that control tree growth. For example, control=rpart.control(minsplit=15, cp=0.01) requires that a node contain at least 15 observations before a split is attempted, and that a split decrease the overall lack of fit by a factor of 0.01 (the cost complexity factor) before it is attempted.
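The sketch below shows these arguments in use, and also how the parms argument chooses between the two purity measures discussed at the top of this document when method = "class" (for method = "anova", as used for Carseats, splits are chosen to reduce the sum of squared errors instead). It uses the kyphosis data that ships with rpart only so the example is self-contained; the object names are illustrative.

```r
library(rpart)  # rpart ships with R as a recommended package

# Growth controls described above: at least 15 observations in a node before a
# split is attempted, and a split must improve the fit by a factor of cp = 0.01
ctrl <- rpart.control(minsplit = 15, cp = 0.01)

# For method = "class", parms selects the impurity measure used to choose splits:
# "gini" (the default) or "information" (entropy)
fit_gini <- rpart(Kyphosis ~ Age + Number + Start, data = kyphosis,
                  method = "class", control = ctrl, parms = list(split = "gini"))
fit_info <- rpart(Kyphosis ~ Age + Number + Start, data = kyphosis,
                  method = "class", control = ctrl, parms = list(split = "information"))

printcp(fit_gini)
printcp(fit_info)
```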
library(rpart)
## Warning: package 'rpart' was built under R version 4.1.3
# fit a regression tree (method = "anova") on the training data
tree.carseats = rpart(Sales~., data=Carseats.train, method = "anova", control=rpart.control(minsplit=15, cp=0.01))
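For an rpart fit, summary() prints the CP table, the variable importance and a detailed description of every node (the tree package's summary() instead lists the variables used as internal nodes and the number of terminal nodes). It is commented out here because its output is long.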
#summary(tree.carseats)
fancyRpartPlot() from the rattle package plots an rpart decision tree using the pretty rpart.plot plotter. You can read more about fancyRpartPlot at https://www.rdocumentation.org/packages/rattle/versions/5.3.0/topics/fancyRpartPlot.
library(rattle)
## Warning: package 'rattle' was built under R version 4.1.3
## Loading required package: tibble
## Loading required package: bitops
## Rattle: A free graphical interface for data science with R.
## Version 5.5.1 Copyright (c) 2006-2021 Togaware Pty Ltd.
## Type 'rattle()' to shake, rattle, and roll your data.
fancyRpartPlot(tree.carseats)
tree.carseats
## n= 201
##
## node), split, n, deviance, yval
## * denotes terminal node
##
## 1) root 201 1620.437000 7.456915
## 2) ShelveLoc=Bad,Medium 160 1002.948000 6.755625
## 4) Price>=101.5 116 545.315100 6.030431
## 8) ShelveLoc=Bad 33 135.555900 4.219697
## 16) CompPrice< 144.5 27 73.927470 3.852222 *
## 17) CompPrice>=144.5 6 41.575330 5.873333 *
## 9) ShelveLoc=Medium 83 258.541300 6.750361
## 18) CompPrice< 123.5 33 68.249620 5.901515
## 36) Age>=49.5 23 25.459970 5.325652 *
## 37) Age< 49.5 10 17.619840 7.226000 *
## 19) CompPrice>=123.5 50 150.820500 7.310600
## 38) Price>=115.5 40 102.489100 6.860750
## 76) Education>=15.5 14 17.873690 5.737143 *
## 77) Education< 15.5 26 57.423230 7.465769 *
## 39) Price< 115.5 10 7.858400 9.110000 *
## 5) Price< 101.5 44 235.795800 8.667500
## 10) ShelveLoc=Bad 16 49.146740 6.966875
## 20) Income< 54 6 13.198280 5.388333 *
## 21) Income>=54 10 12.027240 7.914000 *
## 11) ShelveLoc=Medium 28 113.932800 9.639286
## 22) Price>=86.5 17 58.849890 8.709412
## 44) CompPrice< 118.5 8 17.491400 7.380000 *
## 45) CompPrice>=118.5 9 14.652090 9.891111 *
## 23) Price< 86.5 11 17.666450 11.076360 *
## 3) ShelveLoc=Good 41 231.719600 10.193660
## 6) Price>=113 23 75.742720 8.863478
## 12) Price>=144 6 6.724883 6.841667 *
## 13) Price< 144 17 35.835150 9.577059 *
## 7) Price< 113 18 63.281000 11.893330
## 14) CompPrice< 118 7 4.362286 10.278570 *
## 15) CompPrice>=118 11 29.051490 12.920910 *
Test MSE of the regression tree:
preds_tree <- predict(tree.carseats, newdata = Carseats.test)
sprintf("Tree MSE: %f", mean((preds_tree - Carseats.test$Sales)^2))
## [1] "Tree MSE: 4.912420"
Does pruning the tree improve the test MSE? No. The cp with the lowest cross-validated error (xerror) is 0.01, the same value used to grow the full tree, so pruning at that cp returns the same tree and the test MSE does not change.
NOTES: Without caret, rpart's built-in functions can be used to examine the cross-validation error: rpart() already runs 10-fold cross-validation while growing the tree (xval = 10 in rpart.control() by default), so no separate validation dataset is needed. printcp() prints the complexity parameter (CP) table of an rpart fit and plotcp() plots it; the table gives the optimal prunings for ranges of cp values.
Prune the tree to avoid overfitting the data. The usual convention is to choose a small tree with the least cross-validated error, given in the xerror column of printcp().
printcp(tree.carseats)
##
## Regression tree:
## rpart(formula = Sales ~ ., data = Carseats.train, method = "anova",
## control = rpart.control(minsplit = 15, cp = 0.01))
##
## Variables actually used in tree construction:
## [1] Age CompPrice Education Income Price ShelveLoc
##
## Root node error: 1620.4/201 = 8.0619
##
## n= 201
##
## CP nsplit rel error xerror xstd
## 1 0.238065 0 1.00000 1.00865 0.098105
## 2 0.136899 1 0.76193 0.86463 0.073996
## 3 0.093319 2 0.62504 0.79618 0.069726
## 4 0.057204 3 0.53172 0.65382 0.058288
## 5 0.044875 4 0.47451 0.59442 0.051413
## 6 0.024667 5 0.42964 0.55395 0.049036
## 7 0.023090 7 0.38030 0.55877 0.052002
## 8 0.020478 8 0.35721 0.54757 0.050513
## 9 0.018432 9 0.33673 0.57055 0.052642
## 10 0.016781 10 0.31830 0.56625 0.052416
## 11 0.016481 11 0.30152 0.56498 0.052357
## 12 0.015533 12 0.28504 0.55766 0.052338
## 13 0.014762 13 0.26951 0.56725 0.052983
## 14 0.012375 14 0.25475 0.55265 0.052909
## 15 0.010000 15 0.24237 0.54082 0.052013
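The xerror column can be turned into a pruning choice in two ways: take the cp with the smallest xerror (what which.min() does below), or apply the one-standard-error (1-SE) rule, which picks the smallest tree whose xerror is within one xstd of the minimum. A minimal sketch of both, assuming a matrix with the same column names as the CP table; the helper choose_cp is hypothetical and the values are copied from the printcp() output above:

```r
# Pick cp from an rpart-style CP table: minimum-xerror rule and 1-SE rule.
# Rows are assumed to be ordered from the smallest tree to the largest, as in printcp().
choose_cp <- function(cptable) {
  best      <- which.min(cptable[, "xerror"])                        # smallest CV error
  threshold <- cptable[best, "xerror"] + cptable[best, "xstd"]       # minimum + 1 SE
  one_se    <- min(which(cptable[, "xerror"] <= threshold))          # smallest tree within 1 SE
  c(min_xerror_cp     = cptable[best, "CP"],   one_se_cp     = cptable[one_se, "CP"],
    min_xerror_nsplit = cptable[best, "nsplit"], one_se_nsplit = cptable[one_se, "nsplit"])
}

# CP table copied from the printcp(tree.carseats) output above
cptab <- cbind(
  CP     = c(0.238065, 0.136899, 0.093319, 0.057204, 0.044875, 0.024667, 0.023090, 0.020478,
             0.018432, 0.016781, 0.016481, 0.015533, 0.014762, 0.012375, 0.010000),
  nsplit = c(0, 1, 2, 3, 4, 5, 7, 8, 9, 10, 11, 12, 13, 14, 15),
  xerror = c(1.00865, 0.86463, 0.79618, 0.65382, 0.59442, 0.55395, 0.55877, 0.54757,
             0.57055, 0.56625, 0.56498, 0.55766, 0.56725, 0.55265, 0.54082),
  xstd   = c(0.098105, 0.073996, 0.069726, 0.058288, 0.051413, 0.049036, 0.052002, 0.050513,
             0.052642, 0.052416, 0.052357, 0.052338, 0.052983, 0.052909, 0.052013))

res <- choose_cp(cptab)
res
```

On this table the minimum-xerror rule keeps all 15 splits (cp = 0.01), while the 1-SE rule would prune back to 5 splits (cp = 0.024667). With the fitted object the call would be choose_cp(tree.carseats$cptable); because xerror comes from random cross-validation folds, a run with a different seed may give slightly different values.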
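plotcp() gives a graphical view of the cross-validation results: it plots the cross-validated relative error (xerror, with bars of plus or minus one standard error) against cp, where the cp values on the x-axis are geometric means of adjacent CP values in the table. The dotted horizontal line is drawn one standard error above the minimum xerror; a common choice is the leftmost cp whose point lies below that line.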
plotcp(tree.carseats)
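The x-axis labels of plotcp() are therefore not the CP values themselves. A small sketch of the geometric-mean calculation, assuming each point sits at the geometric mean of its own CP value and the CP value in the row above (Inf for the first row), which is how I read the plotcp() documentation; the values are the first five CP values from the table above:

```r
cp0 <- c(0.238065, 0.136899, 0.093319, 0.057204, 0.044875)  # CP column, first 5 rows
cp_axis <- sqrt(cp0 * c(Inf, cp0[-length(cp0)]))            # geometric means of adjacent CPs
signif(cp_axis, 2)                                          # rounded, as on the plot axis
```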
The following expression returns the cp value associated with the minimum cross-validated error:
tree.carseats$cptable[which.min(tree.carseats$cptable[,"xerror"]),"CP"]
## [1] 0.01
carseats.prune=prune(tree.carseats,cp=tree.carseats$cptable[which.min(tree.carseats$cptable[,"xerror"]),"CP"])
fancyRpartPlot(carseats.prune, uniform=TRUE, main="Pruned Regression Tree")
Using caret for cross-validation: fit the model within caret with 10-fold CV, set up by trainControl("cv", number = 10).
library(caret)
## Warning: package 'caret' was built under R version 4.1.2
## Loading required package: ggplot2
## Warning: package 'ggplot2' was built under R version 4.1.3
## Loading required package: lattice
set.seed(22)
train_control <- trainControl("cv", number = 10)
tree.carseats.caret = train(Sales~.,data=Carseats.train, trControl=train_control, method="rpart")
## Warning in nominalTrainWorkflow(x = x, y = y, wts = weights, info = trainInfo, :
## There were missing values in resampled performance measures.
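Look at the model fitted with caret's 10-fold CV. (The warning above about missing values in the resampled performance measures most likely comes from folds in which the largest cp leaves a tree with no splits; its predictions are then constant, so R-squared cannot be computed for that fold.)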
tree.carseats.caret
## CART
##
## 201 samples
## 10 predictor
##
## No pre-processing
## Resampling: Cross-Validated (10 fold)
## Summary of sample sizes: 180, 181, 182, 180, 181, 181, ...
## Resampling results across tuning parameters:
##
## cp RMSE Rsquared MAE
## 0.09331921 2.352007 0.3398999 1.962920
## 0.13689940 2.453457 0.2667407 2.006134
## 0.23806518 2.787435 0.1021585 2.277097
##
## RMSE was used to select the optimal model using the smallest value.
## The final value used for the model was cp = 0.09331921.
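Conceptually, train() with trainControl("cv", number = 10) splits the training data into 10 folds, fits a tree for each candidate cp on 9 folds, measures RMSE on the held-out fold, averages the 10 RMSEs, and refits the best cp on all of the training data. The sketch below does this by hand with rpart on the built-in mtcars data and a hypothetical cp grid, so that it is self-contained; caret's internal implementation may differ in details.

```r
library(rpart)

set.seed(1)
k <- 10
folds   <- sample(rep(1:k, length.out = nrow(mtcars)))  # random fold label for each row
cp_grid <- c(0.01, 0.05, 0.1)                           # hypothetical candidate cp values

# Average held-out RMSE over the k folds, for each candidate cp
cv_rmse <- sapply(cp_grid, function(cp) {
  fold_rmse <- sapply(1:k, function(i) {
    fit  <- rpart(mpg ~ ., data = mtcars[folds != i, ], method = "anova",
                  control = rpart.control(cp = cp, minsplit = 5, xval = 0))
    pred <- predict(fit, newdata = mtcars[folds == i, ])
    sqrt(mean((pred - mtcars$mpg[folds == i])^2))        # RMSE on the held-out fold
  })
  mean(fold_rmse)
})
names(cv_rmse) <- cp_grid
cv_rmse

# Like caret's finalModel: refit on all the data with the cp that had the lowest CV RMSE
best_cp   <- cp_grid[which.min(cv_rmse)]
final_fit <- rpart(mpg ~ ., data = mtcars, method = "anova",
                   control = rpart.control(cp = best_cp, minsplit = 5))
```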
tree.carseats.caret$finalModel
## n= 201
##
## node), split, n, deviance, yval
## * denotes terminal node
##
## 1) root 201 1620.4370 7.456915
## 2) ShelveLocGood< 0.5 160 1002.9480 6.755625
## 4) Price>=101.5 116 545.3151 6.030431 *
## 5) Price< 101.5 44 235.7958 8.667500 *
## 3) ShelveLocGood>=0.5 41 231.7196 10.193660 *
NOTES: The output indicates that the final value used for the model was cp = 0.09331921. caret implements the rpart method with cp as the tuning parameter, and by default it evaluates cp over a small default grid (here the three largest CP values from the printcp() table above) and prunes the tree at the best value, even if you do not supply tuneGrid or trControl when training the model.
Also notice that the finalModel output shows that only two variables, ShelveLoc and Price, are used to construct the tree. For a regression tree, the deviance of a node is simply the sum of squared errors (SSE) of its observations around the node's mean.
Plot the tree. We use the rpart.plot() function from the rpart.plot package, since the output of caret's train() function does not work with fancyRpartPlot().
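The deviance-equals-SSE statement is easy to check. A minimal sketch on the built-in mtcars data (not Carseats): the root node's deviance equals the total sum of squares around the mean, the deviances of the terminal nodes add up to the training SSE of the tree, and deviance divided by n is what printcp() reports as "Root node error" (1620.4/201 = 8.0619 for tree.carseats).

```r
library(rpart)

fit <- rpart(mpg ~ ., data = mtcars, method = "anova")

root_dev <- fit$frame$dev[1]                          # deviance of the root node
sse      <- sum((mtcars$mpg - mean(mtcars$mpg))^2)    # SSE around the root prediction (the mean)

is_leaf  <- fit$frame$var == "<leaf>"
leaf_dev <- sum(fit$frame$dev[is_leaf])               # total deviance of the terminal nodes
tree_sse <- sum((mtcars$mpg - predict(fit))^2)        # training SSE of the tree's predictions

root_node_error <- root_dev / nrow(mtcars)            # printcp() prints this as "Root node error"
c(root_dev = root_dev, sse = sse, leaf_dev = leaf_dev, tree_sse = tree_sse,
  root_node_error = root_node_error)
```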
library(rpart.plot)
## Warning: package 'rpart.plot' was built under R version 4.1.3
rpart.plot(tree.carseats.caret$finalModel)
8(d) Use the bagging approach in order to analyze this data. What test MSE do you obtain? Use the importance() function to determine which variables are most important.
Answer: the bagging test MSE is 2.966265. The three most important variables are ShelveLoc, Price and CompPrice.
library(randomForest)
## Warning: package 'randomForest' was built under R version 4.1.3
## randomForest 4.7-1
## Type rfNews() to see new features/changes/bug fixes.
##
## Attaching package: 'randomForest'
## The following object is masked from 'package:ggplot2':
##
## margin
## The following object is masked from 'package:rattle':
##
## importance
set.seed(22)
bag.carseats = randomForest(Sales~.,data= Carseats.train, mtry = 10, trControl = train_control, importance=TRUE)
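# mtry = 10 means all 10 predictors are considered at every split, so this random forest is bagging.
# trControl is a caret argument, not a randomForest() argument, and has no effect here.
# Print the fitted model (out-of-bag MSE and % of variance explained):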
bag.carseats
##
## Call:
## randomForest(formula = Sales ~ ., data = Carseats.train, mtry = 10, trControl = train_control, importance = TRUE)
## Type of random forest: regression
## Number of trees: 500
## No. of variables tried at each split: 10
##
## Mean of squared residuals: 2.549706
## % Var explained: 68.37
preds_bag <- predict(bag.carseats, newdata = Carseats.test)
sprintf("Bag MSE: %f", mean((preds_bag - Carseats.test$Sales)^2))
## [1] "Bag MSE: 2.966265"
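Bagging itself is simple enough to write by hand: draw B bootstrap samples, grow a deep regression tree on each, and average their predictions. randomForest() with mtry equal to the number of predictors does essentially this, and additionally uses the out-of-bag observations to compute the "Mean of squared residuals" printed above. The sketch below uses rpart on the built-in mtcars data; bag_rpart is a hypothetical helper, and the tree-growth settings are assumptions rather than randomForest's own defaults.

```r
library(rpart)

# Bagging by hand: fit B deep regression trees to bootstrap samples and average them
bag_rpart <- function(formula, data, newdata, B = 100) {
  preds <- matrix(NA_real_, nrow = nrow(newdata), ncol = B)
  for (b in seq_len(B)) {
    idx <- sample(nrow(data), replace = TRUE)                  # bootstrap sample of rows
    fit <- rpart(formula, data = data[idx, ], method = "anova",
                 control = rpart.control(cp = 0, minsplit = 2, xval = 0))  # grow deep, no pruning
    preds[, b] <- predict(fit, newdata = newdata)
  }
  rowMeans(preds)                                               # average over the B trees
}

set.seed(1)
train_idx <- sample(nrow(mtcars), 22)
pred_bag  <- bag_rpart(mpg ~ ., mtcars[train_idx, ], mtcars[-train_idx, ], B = 50)
mse_bag   <- mean((pred_bag - mtcars$mpg[-train_idx])^2)       # test MSE of the bagged trees
mse_bag
```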
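Variable importance, computed with caret's varImp() (randomForest's own importance() function, which the exercise mentions, reports %IncMSE and IncNodePurity for each variable):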
varImp(bag.carseats)
## Overall
## CompPrice 21.743185
## Income 6.761682
## Advertising 9.854972
## Population 1.635470
## Price 55.380365
## ShelveLoc 67.243200
## Age 15.478874
## Education 5.988575
## Urban -1.914082
## US -2.105608
varImpPlot(bag.carseats)
Describe the effect of m, the number of variables considered at each split, on the error rate obtained. Answer: the error rate goes down as m approaches its optimal value. Because a random forest considers only a random subset of m predictors at each split, it "de-correlates" the bagged trees, and averaging de-correlated trees reduces the variance more than averaging highly correlated ones.
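The variance argument can be made precise. If B trees each have prediction variance sigma^2 and pairwise correlation rho, the variance of their average is rho*sigma^2 + (1 - rho)*sigma^2/B. As B grows the second term vanishes but the first does not, so lowering the correlation between trees (by considering only m < p predictors per split) is what reduces the variance further. A toy base-R simulation of this formula, with synthetic normal "predictions" rather than the Carseats fits:

```r
# Variance of the average of B equally distributed predictions with variance sigma2
# and pairwise correlation rho
var_of_average <- function(rho, sigma2 = 1, B = 500) {
  rho * sigma2 + (1 - rho) * sigma2 / B
}

# Monte Carlo check: X_b = sqrt(rho) * Z + sqrt(1 - rho) * E_b has variance 1 and
# pairwise correlation rho, because Z is shared by all B "trees" in a replicate
simulate_var_of_average <- function(rho, B, reps = 20000) {
  Z <- rnorm(reps)                            # shared component, one per replicate
  E <- matrix(rnorm(reps * B), nrow = reps)   # independent component for each tree
  X <- sqrt(rho) * Z + sqrt(1 - rho) * E      # each row: B correlated predictions
  var(rowMeans(X))
}

set.seed(1)
sapply(c(0.8, 0.5, 0.2), function(r)
  c(rho = r, theory = var_of_average(r, B = 50), simulated = simulate_var_of_average(r, B = 50)))
```

Lower correlation gives a lower variance for the same number of trees, which is the mechanism behind choosing m < p.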
Fit random forest model and identify best tuning parameter
rf.carseats = train(Sales~.,data= Carseats.train, method = 'rf', trControl = train_control, importance=TRUE)
#best tuning parameter
rf.carseats$bestTune
## mtry
## 3 11
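Final RF model. Note that caret's formula interface expands the factors into dummy variables (ShelveLocGood, ShelveLocMedium, UrbanYes, USYes), giving 11 predictor columns, so the selected mtry = 11 considers every predictor at each split and this "random forest" is effectively a bagged model.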
rf.carseats$finalModel
##
## Call:
## randomForest(x = x, y = y, mtry = min(param$mtry, ncol(x)), importance = TRUE)
## Type of random forest: regression
## Number of trees: 500
## No. of variables tried at each split: 11
##
## Mean of squared residuals: 2.61214
## % Var explained: 67.6
Test MSE of the random forest:
preds_rf <- predict(rf.carseats, newdata = Carseats.test)
sprintf("RF MSE: %f", mean((preds_rf - Carseats.test$Sales)^2))
## [1] "RF MSE: 2.922798"
Variable importance (caret's varImp(), scaled so that the most important variable scores 100):
varImp(rf.carseats)
## rf variable importance
##
## Overall
## ShelveLocGood 100.000
## Price 86.126
## ShelveLocMedium 61.342
## CompPrice 41.543
## Age 22.366
## Advertising 17.738
## Income 14.741
## Education 13.871
## Population 6.544
## USYes 2.704
## UrbanYes 0.000
plot(varImp(rf.carseats))
Look at the OJ dataset:
str(OJ)
## 'data.frame': 1070 obs. of 18 variables:
## $ Purchase : Factor w/ 2 levels "CH","MM": 1 1 1 2 1 1 1 1 1 1 ...
## $ WeekofPurchase: num 237 239 245 227 228 230 232 234 235 238 ...
## $ StoreID : num 1 1 1 1 7 7 7 7 7 7 ...
## $ PriceCH : num 1.75 1.75 1.86 1.69 1.69 1.69 1.69 1.75 1.75 1.75 ...
## $ PriceMM : num 1.99 1.99 2.09 1.69 1.69 1.99 1.99 1.99 1.99 1.99 ...
## $ DiscCH : num 0 0 0.17 0 0 0 0 0 0 0 ...
## $ DiscMM : num 0 0.3 0 0 0 0 0.4 0.4 0.4 0.4 ...
## $ SpecialCH : num 0 0 0 0 0 0 1 1 0 0 ...
## $ SpecialMM : num 0 1 0 0 0 1 1 0 0 0 ...
## $ LoyalCH : num 0.5 0.6 0.68 0.4 0.957 ...
## $ SalePriceMM : num 1.99 1.69 2.09 1.69 1.69 1.99 1.59 1.59 1.59 1.59 ...
## $ SalePriceCH : num 1.75 1.75 1.69 1.69 1.69 1.69 1.69 1.75 1.75 1.75 ...
## $ PriceDiff : num 0.24 -0.06 0.4 0 0 0.3 -0.1 -0.16 -0.16 -0.16 ...
## $ Store7 : Factor w/ 2 levels "No","Yes": 1 1 1 1 2 2 2 2 2 2 ...
## $ PctDiscMM : num 0 0.151 0 0 0 ...
## $ PctDiscCH : num 0 0 0.0914 0 0 ...
## $ ListPriceDiff : num 0.24 0.24 0.23 0 0 0.3 0.3 0.24 0.24 0.24 ...
## $ STORE : num 1 1 1 1 0 0 0 0 0 0 ...
9 (a) Create a training set containing a random sample of 800 observations, and a test set containing the remaining observations.
set.seed(22)
train.oj=sample(1:nrow(OJ), 800)
OJ.train = OJ[train.oj,]
OJ.test = OJ[-train.oj,]
library(tree)
## Warning: package 'tree' was built under R version 4.1.3
## Registered S3 method overwritten by 'tree':
## method from
## print.tree cli
tree.oj = tree(Purchase~.,data=OJ.train)
summary(tree.oj)
##
## Classification tree:
## tree(formula = Purchase ~ ., data = OJ.train)
## Variables actually used in tree construction:
## [1] "LoyalCH" "ListPriceDiff" "PriceDiff"
## Number of terminal nodes: 6
## Residual mean deviance: 0.7857 = 623.8 / 794
## Misclassification error rate: 0.1638 = 131 / 800
NOTE: Code from RLab
tree.oj2 = rpart(Purchase~.,data=OJ.train, method = "class")
summary(tree.oj2)
## Call:
## rpart(formula = Purchase ~ ., data = OJ.train, method = "class")
## n= 800
##
## CP nsplit rel error xerror xstd
## 1 0.52698413 0 1.0000000 1.0000000 0.04387030
## 2 0.02063492 1 0.4730159 0.4730159 0.03495651
## 3 0.01269841 3 0.4317460 0.4730159 0.03495651
## 4 0.01000000 4 0.4190476 0.4634921 0.03468245
##
## Variable importance
## LoyalCH ListPriceDiff PriceMM PriceDiff WeekofPurchase
## 74 6 4 4 4
## SalePriceMM DiscMM PctDiscMM PriceCH
## 3 2 2 1
##
## Node number 1: 800 observations, complexity param=0.5269841
## predicted class=CH expected loss=0.39375 P(node) =1
## class counts: 485 315
## probabilities: 0.606 0.394
## left son=2 (506 obs) right son=3 (294 obs)
## Primary splits:
## LoyalCH < 0.48285 to the right, improve=140.35880, (0 missing)
## StoreID < 3.5 to the right, improve= 38.56290, (0 missing)
## Store7 splits as RL, improve= 18.74697, (0 missing)
## STORE < 0.5 to the left, improve= 18.74697, (0 missing)
## PriceDiff < 0.015 to the right, improve= 15.00298, (0 missing)
## Surrogate splits:
## PriceMM < 1.89 to the right, agree=0.641, adj=0.024, (0 split)
## WeekofPurchase < 227.5 to the right, agree=0.639, adj=0.017, (0 split)
## DiscMM < 0.77 to the left, agree=0.637, adj=0.014, (0 split)
## SalePriceMM < 1.385 to the right, agree=0.637, adj=0.014, (0 split)
## PriceDiff < -0.575 to the right, agree=0.637, adj=0.014, (0 split)
##
## Node number 2: 506 observations, complexity param=0.02063492
## predicted class=CH expected loss=0.1679842 P(node) =0.6325
## class counts: 421 85
## probabilities: 0.832 0.168
## left son=4 (300 obs) right son=5 (206 obs)
## Primary splits:
## LoyalCH < 0.705699 to the right, improve=18.262560, (0 missing)
## PriceDiff < 0.015 to the right, improve=10.177560, (0 missing)
## SalePriceMM < 1.84 to the right, improve= 9.143175, (0 missing)
## ListPriceDiff < 0.18 to the right, improve= 7.826746, (0 missing)
## PriceMM < 1.74 to the right, improve= 5.069620, (0 missing)
## Surrogate splits:
## WeekofPurchase < 236.5 to the right, agree=0.632, adj=0.097, (0 split)
## PriceCH < 1.755 to the right, agree=0.625, adj=0.078, (0 split)
## PriceMM < 2.04 to the right, agree=0.625, adj=0.078, (0 split)
## SalePriceMM < 1.64 to the right, agree=0.615, adj=0.053, (0 split)
## PctDiscMM < 0.1961965 to the left, agree=0.607, adj=0.034, (0 split)
##
## Node number 3: 294 observations
## predicted class=MM expected loss=0.2176871 P(node) =0.3675
## class counts: 64 230
## probabilities: 0.218 0.782
##
## Node number 4: 300 observations
## predicted class=CH expected loss=0.05666667 P(node) =0.375
## class counts: 283 17
## probabilities: 0.943 0.057
##
## Node number 5: 206 observations, complexity param=0.02063492
## predicted class=CH expected loss=0.3300971 P(node) =0.2575
## class counts: 138 68
## probabilities: 0.670 0.330
## left son=10 (153 obs) right son=11 (53 obs)
## Primary splits:
## ListPriceDiff < 0.155 to the right, improve=12.214210, (0 missing)
## PriceDiff < 0.015 to the right, improve=10.320220, (0 missing)
## SalePriceMM < 1.84 to the right, improve= 7.652049, (0 missing)
## PriceMM < 2.04 to the right, improve= 4.651385, (0 missing)
## WeekofPurchase < 229.5 to the right, improve= 4.434177, (0 missing)
## Surrogate splits:
## PriceMM < 1.89 to the right, agree=0.845, adj=0.396, (0 split)
## WeekofPurchase < 230.5 to the right, agree=0.816, adj=0.283, (0 split)
## PriceDiff < -0.185 to the right, agree=0.796, adj=0.208, (0 split)
## PriceCH < 1.72 to the right, agree=0.772, adj=0.113, (0 split)
## LoyalCH < 0.6942745 to the left, agree=0.752, adj=0.038, (0 split)
##
## Node number 10: 153 observations, complexity param=0.01269841
## predicted class=CH expected loss=0.2287582 P(node) =0.19125
## class counts: 118 35
## probabilities: 0.771 0.229
## left son=20 (143 obs) right son=21 (10 obs)
## Primary splits:
## PriceDiff < -0.165 to the right, improve=4.751963, (0 missing)
## PriceMM < 2.04 to the right, improve=3.176309, (0 missing)
## StoreID < 2.5 to the right, improve=2.911818, (0 missing)
## PriceCH < 1.775 to the right, improve=2.681909, (0 missing)
## LoyalCH < 0.574494 to the right, improve=2.364922, (0 missing)
## Surrogate splits:
## DiscMM < 0.57 to the left, agree=0.974, adj=0.6, (0 split)
## SalePriceMM < 1.585 to the right, agree=0.974, adj=0.6, (0 split)
## PctDiscMM < 0.264375 to the left, agree=0.974, adj=0.6, (0 split)
## ListPriceDiff < 0.18 to the right, agree=0.954, adj=0.3, (0 split)
##
## Node number 11: 53 observations
## predicted class=MM expected loss=0.3773585 P(node) =0.06625
## class counts: 20 33
## probabilities: 0.377 0.623
##
## Node number 20: 143 observations
## predicted class=CH expected loss=0.1958042 P(node) =0.17875
## class counts: 115 28
## probabilities: 0.804 0.196
##
## Node number 21: 10 observations
## predicted class=MM expected loss=0.3 P(node) =0.0125
## class counts: 3 7
## probabilities: 0.300 0.700
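The improve values in this summary can be reproduced by hand. For the default Gini criterion, they are consistent with improve being the decrease in n-weighted Gini impurity, n_parent*G_parent - n_left*G_left - n_right*G_right, where for a node with class proportions p_k, n*G = n*sum_k p_k(1 - p_k). A minimal sketch (weighted_gini is a hypothetical helper) using the class counts of nodes 1, 2 and 3 above:

```r
# n * Gini impurity of a node from its class counts: n * sum_k p_k (1 - p_k) = n - sum_k c_k^2 / n
weighted_gini <- function(counts) {
  n <- sum(counts)
  n - sum(counts^2) / n
}

parent <- c(CH = 485, MM = 315)   # node 1 (root)
left   <- c(CH = 421, MM = 85)    # node 2: LoyalCH >= 0.48285
right  <- c(CH = 64,  MM = 230)   # node 3: LoyalCH <  0.48285

improve <- weighted_gini(parent) - weighted_gini(left) - weighted_gini(right)
improve
```

This gives about 140.3588, matching improve=140.35880 for the LoyalCH split of node 1.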
# View(OJ)  # opens the data viewer; only useful in an interactive session
tree.oj
## node), split, n, deviance, yval, (yprob)
## * denotes terminal node
##
## 1) root 800 1073.000 CH ( 0.60625 0.39375 )
## 2) LoyalCH < 0.48285 294 308.100 MM ( 0.21769 0.78231 )
## 4) LoyalCH < 0.035047 51 9.844 MM ( 0.01961 0.98039 ) *
## 5) LoyalCH > 0.035047 243 278.100 MM ( 0.25926 0.74074 ) *
## 3) LoyalCH > 0.48285 506 458.100 CH ( 0.83202 0.16798 )
## 6) LoyalCH < 0.753545 245 301.800 CH ( 0.69388 0.30612 )
## 12) ListPriceDiff < 0.18 66 89.300 MM ( 0.40909 0.59091 ) *
## 13) ListPriceDiff > 0.18 179 179.700 CH ( 0.79888 0.20112 )
## 26) PriceDiff < -0.165 8 6.028 MM ( 0.12500 0.87500 ) *
## 27) PriceDiff > -0.165 171 155.700 CH ( 0.83041 0.16959 ) *
## 7) LoyalCH > 0.753545 261 84.850 CH ( 0.96169 0.03831 ) *
tree.oj2
## n= 800
##
## node), split, n, loss, yval, (yprob)
## * denotes terminal node
##
## 1) root 800 315 CH (0.60625000 0.39375000)
## 2) LoyalCH>=0.48285 506 85 CH (0.83201581 0.16798419)
## 4) LoyalCH>=0.705699 300 17 CH (0.94333333 0.05666667) *
## 5) LoyalCH< 0.705699 206 68 CH (0.66990291 0.33009709)
## 10) ListPriceDiff>=0.155 153 35 CH (0.77124183 0.22875817)
## 20) PriceDiff>=-0.165 143 28 CH (0.80419580 0.19580420) *
## 21) PriceDiff< -0.165 10 3 MM (0.30000000 0.70000000) *
## 11) ListPriceDiff< 0.155 53 20 MM (0.37735849 0.62264151) *
## 3) LoyalCH< 0.48285 294 64 MM (0.21768707 0.78231293) *
LoyalCH is the most important predictor. Reading the tree.oj output from the top:
- If LoyalCH < 0.48, the prediction is MM (node 2 splits again at LoyalCH = 0.035, but both children predict MM).
- If LoyalCH >= 0.48, go to node 3, which splits on LoyalCH again:
- If LoyalCH >= 0.75, the prediction is CH.
- If LoyalCH < 0.75, split on ListPriceDiff:
- If ListPriceDiff < 0.18, the prediction is MM.
- If ListPriceDiff >= 0.18, split on PriceDiff: if PriceDiff < -0.165 the prediction is MM, otherwise it is CH.
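These decision rules can be written directly as code. A sketch (predict_oj_rule is a hypothetical helper) whose thresholds are copied from the tree.oj printout; it reproduces only the unpruned tree package fit, and it sends a value exactly equal to a split point to the right-hand branch (an assumption; exact ties are unlikely with these data):

```r
# Hand-coded version of the unpruned tree.oj decision rules (vectorised)
predict_oj_rule <- function(LoyalCH, ListPriceDiff, PriceDiff) {
  ifelse(LoyalCH < 0.48285, "MM",                 # low loyalty to CH
    ifelse(LoyalCH >= 0.753545, "CH",             # high loyalty to CH
      ifelse(ListPriceDiff < 0.18, "MM",          # middling loyalty, small list price gap
        ifelse(PriceDiff < -0.165, "MM", "CH")))) # MM sale price clearly below CH
}

predict_oj_rule(LoyalCH       = c(0.30, 0.90, 0.60, 0.60, 0.60),
                ListPriceDiff = c(0.20, 0.00, 0.10, 0.24, 0.24),
                PriceDiff     = c(0.00, 0.00, 0.00, -0.20, 0.00))
```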
plot(tree.oj)
text(tree.oj, pretty = 0)
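Plot the rpart tree (NOTE: code from the R lab). The plot is similar to the one above, except that the left and right branches are swapped and nicer labels are added; note that the rpart fit (tree.oj2) uses slightly different split points and has five terminal nodes rather than six.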
fancyRpartPlot(tree.oj2)
NOTE: pg 355 2nd Edition
preds_oj <- predict(tree.oj, OJ.test, type = "class")
table(preds_oj, OJ.test$Purchase)
##
## preds_oj CH MM
## CH 132 14
## MM 36 88
oj_correct = sum(diag(table(preds_oj, OJ.test$Purchase))) / nrow(OJ.test)  # (132 + 88) / 270
test_error_oj = 1 - oj_correct
test_error_oj
## [1] 0.1851852
cv.oj = cv.tree(tree.oj)
cv.oj
## $size
## [1] 6 5 4 3 2 1
##
## $dev
## [1] 719.8520 736.5305 736.5349 755.6974 811.2094 1076.5794
##
## $k
## [1] -Inf 17.97797 20.11905 32.82322 71.43514 306.43464
##
## $method
## [1] "deviance"
##
## attr(,"class")
## [1] "prune" "tree.sequence"
plot(cv.oj$size, cv.oj$dev, type = "b", xlab = "Tree size (terminal nodes)", ylab = "CV deviance")
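The tree size with the lowest cross-validated deviance can also be read off programmatically. A minimal sketch with the values copied from the cv.oj printout above (with the live object this is cv.oj$size[which.min(cv.oj$dev)]):

```r
# Values copied from the cv.oj printout above
cv_size <- c(6, 5, 4, 3, 2, 1)
cv_dev  <- c(719.8520, 736.5305, 736.5349, 755.6974, 811.2094, 1076.5794)

best_size <- cv_size[which.min(cv_dev)]   # size with the lowest CV deviance
best_size
```

Here the best size is 6, the full unpruned tree, so cross-validation does not select a pruned tree. Note that cv.tree() uses deviance by default; cv.tree(tree.oj, FUN = prune.misclass) would instead cross-validate the number of misclassifications, which matches a question about the classification error rate more closely.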
plotcp(tree.oj2)
printcp(tree.oj2)
##
## Classification tree:
## rpart(formula = Purchase ~ ., data = OJ.train, method = "class")
##
## Variables actually used in tree construction:
## [1] ListPriceDiff LoyalCH PriceDiff
##
## Root node error: 315/800 = 0.39375
##
## n= 800
##
## CP nsplit rel error xerror xstd
## 1 0.526984 0 1.00000 1.00000 0.043870
## 2 0.020635 1 0.47302 0.47302 0.034957
## 3 0.012698 3 0.43175 0.47302 0.034957
## 4 0.010000 4 0.41905 0.46349 0.034682
tree.oj2$cptable[which.min(tree.oj2$cptable[,"xerror"]),"CP"]
## [1] 0.01
oj.pruned = prune(tree.oj2, cp = tree.oj2$cptable[which.min(tree.oj2$cptable[,"xerror"]),"CP"])
fancyRpartPlot(oj.pruned, uniform = TRUE, main="Pruned Classification Tree")
If cross-validation does not lead to selection of a pruned tree, then create a pruned tree with five terminal nodes.
oj.pruned5 = prune.tree(tree.oj, best = 5)
plot(oj.pruned5)
text(oj.pruned5, pretty = 0)
summary(tree.oj)
##
## Classification tree:
## tree(formula = Purchase ~ ., data = OJ.train)
## Variables actually used in tree construction:
## [1] "LoyalCH" "ListPriceDiff" "PriceDiff"
## Number of terminal nodes: 6
## Residual mean deviance: 0.7857 = 623.8 / 794
## Misclassification error rate: 0.1638 = 131 / 800
summary(oj.pruned5)
##
## Classification tree:
## snip.tree(tree = tree.oj, nodes = 13L)
## Variables actually used in tree construction:
## [1] "LoyalCH" "ListPriceDiff"
## Number of terminal nodes: 5
## Residual mean deviance: 0.8073 = 641.8 / 795
## Misclassification error rate: 0.1713 = 137 / 800
Training error rate: unpruned 0.1638, pruned 0.1713 (higher). Test error rate: unpruned 0.1852, pruned 0.2000 (higher).
preds_oj <- predict(tree.oj, OJ.test, type = "class")
table(preds_oj, OJ.test$Purchase)
##
## preds_oj CH MM
## CH 132 14
## MM 36 88
preds_oj5 <- predict(oj.pruned5, OJ.test, type = "class")
table(preds_oj5, OJ.test$Purchase)
##
## preds_oj5 CH MM
## CH 132 18
## MM 36 84
oj_correct = sum(diag(table(preds_oj, OJ.test$Purchase))) / nrow(OJ.test)  # (132 + 88) / 270
test_error_oj = 1 - oj_correct
test_error_oj
## [1] 0.1851852
oj5_correct = sum(diag(table(preds_oj5, OJ.test$Purchase))) / nrow(OJ.test)  # (132 + 84) / 270
test_error_oj5 = 1 - oj5_correct
test_error_oj5
## [1] 0.2
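Splitting the test error by true class shows where the pruned tree loses accuracy. A sketch (error_rates is a hypothetical helper) with the counts copied from the two confusion matrices above; with the live objects, error_rates(table(preds_oj, OJ.test$Purchase)) should give the same result:

```r
# Overall and per-class error from a confusion matrix (rows = predicted, columns = true class)
error_rates <- function(conf) {
  c(overall = 1 - sum(diag(conf)) / sum(conf),
    CH      = 1 - conf["CH", "CH"] / sum(conf[, "CH"]),   # true CH predicted as MM
    MM      = 1 - conf["MM", "MM"] / sum(conf[, "MM"]))   # true MM predicted as CH
}

# Counts copied from the two tables above (matrix() fills column by column)
labs        <- list(pred = c("CH", "MM"), truth = c("CH", "MM"))
conf_full   <- matrix(c(132, 36, 14, 88), nrow = 2, dimnames = labs)
conf_pruned <- matrix(c(132, 36, 18, 84), nrow = 2, dimnames = labs)

rbind(unpruned = error_rates(conf_full), pruned = error_rates(conf_pruned))
```

Both trees misclassify 36 of the 168 CH purchases; pruning only increases the errors on MM purchases (from 14 to 18 of 102), which accounts for the whole rise in test error from 0.1852 to 0.2.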