In this project, I am going to implement decision tree models on the OJ (Orange Juice) data set from the ISLR package.
Data:
The OJ (Orange Juice) data frame contains 1070 observations on the following 18 variables.
Purchase: A factor with levels CH and MM indicating whether the customer purchased Citrus Hill or Minute Maid Orange Juice
WeekofPurchase: Week of purchase
StoreID: Store ID
PriceCH: Price charged for CH
PriceMM: Price charged for MM
DiscCH: Discount offered for CH
DiscMM: Discount offered for MM
SpecialCH: Indicator of special on CH
SpecialMM: Indicator of special on MM
LoyalCH: Customer brand loyalty for CH
SalePriceMM: Sale price for MM
SalePriceCH: Sale price for CH
PriceDiff: Sale price of MM less sale price of CH
Store7: A factor with levels No and Yes indicating whether the sale is at Store 7
PctDiscMM: Percentage discount for MM
PctDiscCH: Percentage discount for CH
ListPriceDiff: List price of MM less list price of CH
STORE: Which of 5 possible stores the sale occurred at
Objective:
Fit a decision tree and Find Test Error Rate
Report important variables
Plot the tree and interpret the results
Show Tree Object and Interpret
Perform Cross Validation and Prune Tree
Compare Results: Unpruned Vs. Pruned Tree
Predict the response on the test data set and create Confusion Matrix
Cross Validation with rpart and Interpretation
Note that I will be using the rpart library. rpart is a powerful machine learning package in R that is used for building classification and regression trees.
Loading Libraries
#Loading necessary libraries
packages <- c('caret', 'randomForest', 'formattable', 'pls','ISLR', 'tree', 'rpart', 'rpart.plot')
sapply(packages, require, character.only=T)
## Loading required package: caret
## Loading required package: lattice
## Loading required package: ggplot2
## Loading required package: randomForest
## randomForest 4.6-14
## Type rfNews() to see new features/changes/bug fixes.
##
## Attaching package: 'randomForest'
## The following object is masked from 'package:ggplot2':
##
## margin
## Loading required package: formattable
## Loading required package: pls
##
## Attaching package: 'pls'
## The following object is masked from 'package:caret':
##
## R2
## The following object is masked from 'package:stats':
##
## loadings
## Loading required package: ISLR
## Loading required package: tree
## Loading required package: rpart
## Loading required package: rpart.plot
## caret randomForest formattable pls ISLR tree
## TRUE TRUE TRUE TRUE TRUE TRUE
## rpart rpart.plot
## TRUE TRUE
data(OJ)
Data Exploration
head(OJ)
## Purchase WeekofPurchase StoreID PriceCH PriceMM DiscCH DiscMM SpecialCH
## 1 CH 237 1 1.75 1.99 0.00 0.0 0
## 2 CH 239 1 1.75 1.99 0.00 0.3 0
## 3 CH 245 1 1.86 2.09 0.17 0.0 0
## 4 MM 227 1 1.69 1.69 0.00 0.0 0
## 5 CH 228 7 1.69 1.69 0.00 0.0 0
## 6 CH 230 7 1.69 1.99 0.00 0.0 0
## SpecialMM LoyalCH SalePriceMM SalePriceCH PriceDiff Store7 PctDiscMM
## 1 0 0.500000 1.99 1.75 0.24 No 0.000000
## 2 1 0.600000 1.69 1.75 -0.06 No 0.150754
## 3 0 0.680000 2.09 1.69 0.40 No 0.000000
## 4 0 0.400000 1.69 1.69 0.00 No 0.000000
## 5 0 0.956535 1.69 1.69 0.00 Yes 0.000000
## 6 1 0.965228 1.99 1.69 0.30 Yes 0.000000
## PctDiscCH ListPriceDiff STORE
## 1 0.000000 0.24 1
## 2 0.000000 0.24 1
## 3 0.091398 0.23 1
## 4 0.000000 0.00 1
## 5 0.000000 0.00 0
## 6 0.000000 0.30 0
str(OJ)
## 'data.frame': 1070 obs. of 18 variables:
## $ Purchase : Factor w/ 2 levels "CH","MM": 1 1 1 2 1 1 1 1 1 1 ...
## $ WeekofPurchase: num 237 239 245 227 228 230 232 234 235 238 ...
## $ StoreID : num 1 1 1 1 7 7 7 7 7 7 ...
## $ PriceCH : num 1.75 1.75 1.86 1.69 1.69 1.69 1.69 1.75 1.75 1.75 ...
## $ PriceMM : num 1.99 1.99 2.09 1.69 1.69 1.99 1.99 1.99 1.99 1.99 ...
## $ DiscCH : num 0 0 0.17 0 0 0 0 0 0 0 ...
## $ DiscMM : num 0 0.3 0 0 0 0 0.4 0.4 0.4 0.4 ...
## $ SpecialCH : num 0 0 0 0 0 0 1 1 0 0 ...
## $ SpecialMM : num 0 1 0 0 0 1 1 0 0 0 ...
## $ LoyalCH : num 0.5 0.6 0.68 0.4 0.957 ...
## $ SalePriceMM : num 1.99 1.69 2.09 1.69 1.69 1.99 1.59 1.59 1.59 1.59 ...
## $ SalePriceCH : num 1.75 1.75 1.69 1.69 1.69 1.69 1.69 1.75 1.75 1.75 ...
## $ PriceDiff : num 0.24 -0.06 0.4 0 0 0.3 -0.1 -0.16 -0.16 -0.16 ...
## $ Store7 : Factor w/ 2 levels "No","Yes": 1 1 1 1 2 2 2 2 2 2 ...
## $ PctDiscMM : num 0 0.151 0 0 0 ...
## $ PctDiscCH : num 0 0 0.0914 0 0 ...
## $ ListPriceDiff : num 0.24 0.24 0.23 0 0 0.3 0.3 0.24 0.24 0.24 ...
## $ STORE : num 1 1 1 1 0 0 0 0 0 0 ...
Data Partition
set.seed(111)
ind <- sample(2,nrow(OJ), replace=TRUE, prob = c(0.8,0.2))
train <- OJ[ind==1,]
test <- OJ[ind==2,]
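Before fitting anything, a quick sanity check on the split (a small sketch, using the train and test objects created above):
#Check the sizes of the split and the class balance in each part
dim(train); dim(test)
prop.table(table(train$Purchase))   #proportion of CH vs. MM in the training set
prop.table(table(test$Purchase))    #proportion of CH vs. MM in the test set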
Fit Model
I am going to fit two models, one with the tree() function and one with the rpart library, and compare the results.
#Decision Tree
model_fit1 <-tree(Purchase~., data = train)
model_fit2 <- rpart(Purchase ~ ., data = train, method = 'class',
control = rpart.control(cp = 0))
Test Error Rate
#Prediction 1
tree.pred1 <- predict(model_fit1, test, type = "class")
result_test1 <-confusionMatrix(data=tree.pred1,reference=test$Purchase)
print(result_test1)
## Confusion Matrix and Statistics
##
## Reference
## Prediction CH MM
## CH 120 19
## MM 13 47
##
## Accuracy : 0.8392
## 95% CI : (0.7806, 0.8873)
## No Information Rate : 0.6683
## P-Value [Acc > NIR] : 4.365e-08
##
## Kappa : 0.6288
##
## Mcnemar's Test P-Value : 0.3768
##
## Sensitivity : 0.9023
## Specificity : 0.7121
## Pos Pred Value : 0.8633
## Neg Pred Value : 0.7833
## Prevalence : 0.6683
## Detection Rate : 0.6030
## Detection Prevalence : 0.6985
## Balanced Accuracy : 0.8072
##
## 'Positive' Class : CH
##
#Prediction 2 (Prediction with rpart)
tree.pred2 <- predict(model_fit2, test, type = "class")
result_test2 <-confusionMatrix(data=tree.pred2,reference=test$Purchase)
print(result_test2)
## Confusion Matrix and Statistics
##
## Reference
## Prediction CH MM
## CH 115 15
## MM 18 51
##
## Accuracy : 0.8342
## 95% CI : (0.7751, 0.883)
## No Information Rate : 0.6683
## P-Value [Acc > NIR] : 1.121e-07
##
## Kappa : 0.6302
##
## Mcnemar's Test P-Value : 0.7277
##
## Sensitivity : 0.8647
## Specificity : 0.7727
## Pos Pred Value : 0.8846
## Neg Pred Value : 0.7391
## Prevalence : 0.6683
## Detection Rate : 0.5779
## Detection Prevalence : 0.6533
## Balanced Accuracy : 0.8187
##
## 'Positive' Class : CH
##
names(result_test2)
## [1] "positive" "table" "overall" "byClass" "mode" "dots"
#Accuracy Rate
acc <- result_test2$overall[1]
acc
## Accuracy
## 0.8341709
#Test Error Rate
1-0.8341709
## [1] 0.1658291
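Both test error rates can also be pulled programmatically from the confusionMatrix objects (a small sketch using result_test1 and result_test2 from above):
#Test error rates computed directly from the confusionMatrix objects
err1 <- 1 - result_test1$overall["Accuracy"]   #tree() model
err2 <- 1 - result_test2$overall["Accuracy"]   #rpart model
c(tree = unname(err1), rpart = unname(err2))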
#Based on Model1
summary(model_fit1)
##
## Classification tree:
## tree(formula = Purchase ~ ., data = train)
## Variables actually used in tree construction:
## [1] "LoyalCH" "PriceDiff" "ListPriceDiff"
## Number of terminal nodes: 7
## Residual mean deviance: 0.7759 = 670.4 / 864
## Misclassification error rate: 0.1642 = 143 / 871
The summary indicates that only 3 variables, “LoyalCH”, “PriceDiff”, and “ListPriceDiff”, have been used in constructing the tree.
The fitted tree has 7 terminal nodes and a training error rate of 0.164.
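The reported training error can be reproduced directly from the fitted tree (a quick sketch, using model_fit1 and train from above):
#Recompute the training misclassification rate for model_fit1
train.pred1 <- predict(model_fit1, train, type = "class")
mean(train.pred1 != train$Purchase)   #should match 143/871 = 0.1642 from summary()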
#Based on Model2
Value<-model_fit2$variable.importance
ndf<-as.data.frame(Value)
ggplot(data = ndf, aes(x = row.names(ndf), y = Value)) +
  geom_bar(stat = "identity", width = 0.6, fill = "steelblue") +
  ggtitle("Variable Importance") +
  theme_minimal() +
  coord_flip()
Model 2 indicates that the most important variable is by far “LoyalCH”, followed by “PriceDiff” and “SalePriceMM”.
rpart.plot(model_fit2)
It is clear that the most important variable for the purchasing decision is LoyalCH, since the top three nodes all split on LoyalCH.
The plot tells us how a decision is made. For example, at the top node, if a customer’s LoyalCH is at least 0.48 the tree moves to the CH side; if LoyalCH is below 0.48 it moves to the MM side. Further down the CH side, the price variables come into play: a small ListPriceDiff tips the prediction toward MM, and store-related variables such as StoreID and STORE only enter near the bottom of the tree. The tree diagram is very useful for understanding how these decisions drive product selection.
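To see one of these paths in action, here is a small sketch that sends a single test-set customer through the fitted rpart tree (using model_fit2 and test from above):
#Follow one customer through the tree
one_cust <- test[1, ]
one_cust$LoyalCH                               #which side of the top LoyalCH split the customer falls on
predict(model_fit2, one_cust, type = "prob")   #class probabilities at the leaf that is reached
predict(model_fit2, one_cust, type = "class")  #final predicted purchase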
model_fit2
## n= 871
##
## node), split, n, loss, yval, (yprob)
## * denotes terminal node
##
## 1) root 871 351 CH (0.59701493 0.40298507)
## 2) LoyalCH>=0.48285 542 95 CH (0.82472325 0.17527675)
## 4) LoyalCH>=0.7645725 281 14 CH (0.95017794 0.04982206) *
## 5) LoyalCH< 0.7645725 261 81 CH (0.68965517 0.31034483)
## 10) PriceDiff>=-0.165 220 50 CH (0.77272727 0.22727273)
## 20) ListPriceDiff>=0.135 181 29 CH (0.83977901 0.16022099)
## 40) SalePriceMM>=2.125 91 7 CH (0.92307692 0.07692308) *
## 41) SalePriceMM< 2.125 90 22 CH (0.75555556 0.24444444)
## 82) StoreID>=2.5 50 8 CH (0.84000000 0.16000000) *
## 83) StoreID< 2.5 40 14 CH (0.65000000 0.35000000)
## 166) WeekofPurchase>=236.5 32 9 CH (0.71875000 0.28125000) *
## 167) WeekofPurchase< 236.5 8 3 MM (0.37500000 0.62500000) *
## 21) ListPriceDiff< 0.135 39 18 MM (0.46153846 0.53846154)
## 42) LoyalCH>=0.51 30 14 CH (0.53333333 0.46666667)
## 84) WeekofPurchase< 228.5 7 1 CH (0.85714286 0.14285714) *
## 85) WeekofPurchase>=228.5 23 10 MM (0.43478261 0.56521739)
## 170) StoreID>=3.5 12 5 CH (0.58333333 0.41666667) *
## 171) StoreID< 3.5 11 3 MM (0.27272727 0.72727273) *
## 43) LoyalCH< 0.51 9 2 MM (0.22222222 0.77777778) *
## 11) PriceDiff< -0.165 41 10 MM (0.24390244 0.75609756) *
## 3) LoyalCH< 0.48285 329 73 MM (0.22188450 0.77811550)
## 6) LoyalCH>=0.2761415 146 54 MM (0.36986301 0.63013699)
## 12) PriceDiff>=0.065 82 40 MM (0.48780488 0.51219512)
## 24) LoyalCH< 0.3084325 8 0 CH (1.00000000 0.00000000) *
## 25) LoyalCH>=0.3084325 74 32 MM (0.43243243 0.56756757)
## 50) PriceDiff>=0.31 35 16 CH (0.54285714 0.45714286)
## 100) STORE< 1.5 23 8 CH (0.65217391 0.34782609) *
## 101) STORE>=1.5 12 4 MM (0.33333333 0.66666667) *
## 51) PriceDiff< 0.31 39 13 MM (0.33333333 0.66666667)
## 102) LoyalCH>=0.390304 25 11 MM (0.44000000 0.56000000)
## 204) PriceCH>=1.94 7 2 CH (0.71428571 0.28571429) *
## 205) PriceCH< 1.94 18 6 MM (0.33333333 0.66666667) *
## 103) LoyalCH< 0.390304 14 2 MM (0.14285714 0.85714286) *
## 13) PriceDiff< 0.065 64 14 MM (0.21875000 0.78125000)
## 26) SpecialCH>=0.5 7 3 CH (0.57142857 0.42857143) *
## 27) SpecialCH< 0.5 57 10 MM (0.17543860 0.82456140) *
## 7) LoyalCH< 0.2761415 183 19 MM (0.10382514 0.89617486)
## 14) LoyalCH>=0.136344 71 14 MM (0.19718310 0.80281690)
## 28) PriceDiff>=0.31 9 4 CH (0.55555556 0.44444444) *
## 29) PriceDiff< 0.31 62 9 MM (0.14516129 0.85483871)
## 58) LoyalCH< 0.1990905 23 6 MM (0.26086957 0.73913043)
## 116) STORE< 1.5 9 4 CH (0.55555556 0.44444444) *
## 117) STORE>=1.5 14 1 MM (0.07142857 0.92857143) *
## 59) LoyalCH>=0.1990905 39 3 MM (0.07692308 0.92307692) *
## 15) LoyalCH< 0.136344 112 5 MM (0.04464286 0.95535714) *
In the tree object, we can see the total number of nodes and detailed information about each node.
For example, node 4 is reached when LoyalCH >= 0.7645725; it contains 281 observations, 14 of which are misclassified (the loss). Its class probabilities (0.95017794 0.04982206) mean that about 95% of its observations take the value CH (Citrus Hill) and about 4.98% take the value MM (Minute Maid).
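The same per-node information can also be extracted from the rpart object itself: its frame component stores one row per node with the split variable, node size, and loss (a sketch using model_fit2 from above):
#One row per node: split variable, n, loss (dev), and fitted class code (yval)
head(model_fit2$frame[, c("var", "n", "dev", "yval")])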
#Cross Validation Method
cv.OJ <- cv.tree(model_fit1, FUN=prune.misclass)
names(cv.OJ)
## [1] "size" "dev" "k" "method"
cv.OJ$size
## [1] 7 6 2 1
#candidate tree sizes considered by cross-validation: 7, 6, 2, 1
#Plot CV
plot(cv.OJ$size, cv.OJ$dev,xlab = "Size of Tree", ylab = "Deviance", type = "b")
tree.min <- which.min(cv.OJ$dev)
df <- data.frame(cv.OJ$size, cv.OJ$dev)
min_value<-df[which.min(cv.OJ$dev),1]
points(min_value, cv.OJ$dev[tree.min], col = "green", cex = 2, pch = 20)
We get the minimum cross-validation error when the size of the tree is 6. Note that because FUN = prune.misclass was used, the dev component reports the number of cross-validation misclassifications rather than the deviance.
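For reference, the cross-validation results behind the plot can also be listed directly (a small sketch using the cv.OJ and df objects created above):
#Cross-validation misclassification counts for each candidate tree size
setNames(df, c("size", "cv_misclass"))
cv.OJ$size[which.min(cv.OJ$dev)]   #tree size with the lowest cross-validation error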
#Pruning Method
prune.OJ <- prune.misclass(model_fit1, best = 6)
plot(prune.OJ)
text(prune.OJ, pretty = 0)
#Pruning the tree makes the plot much easier to read.
#Summary of Unpruned Tree
summary(model_fit1)
##
## Classification tree:
## tree(formula = Purchase ~ ., data = train)
## Variables actually used in tree construction:
## [1] "LoyalCH" "PriceDiff" "ListPriceDiff"
## Number of terminal nodes: 7
## Residual mean deviance: 0.7759 = 670.4 / 864
## Misclassification error rate: 0.1642 = 143 / 871
#Summary of Pruned Tree
summary(prune.OJ)
##
## Classification tree:
## snip.tree(tree = model_fit1, nodes = 13L)
## Variables actually used in tree construction:
## [1] "LoyalCH" "PriceDiff"
## Number of terminal nodes: 6
## Residual mean deviance: 0.7935 = 686.3 / 865
## Misclassification error rate: 0.1642 = 143 / 871
The misclassification error rate is identical, so the pruned tree did not reduce the training misclassification error rate.
# Unpruned Tree
tree.pred <- predict(model_fit1, test, type = "class")
unpr_tb <- table(tree.pred, test$Purchase)
unpr_tb
##
## tree.pred CH MM
## CH 120 19
## MM 13 47
1-(sum(diag(unpr_tb))/ sum(unpr_tb))
## [1] 0.160804
# Pruned Tree
prune.pred <- predict(prune.OJ, test, type = "class")
pr_tb <-table(prune.pred, test$Purchase)
pr_tb
##
## prune.pred CH MM
## CH 120 19
## MM 13 47
1-(sum(diag(pr_tb))/ sum(pr_tb))
## [1] 0.160804
The pruned tree did not reduce the test error rate because we barely pruned the tree at all: it had 7 terminal nodes and we only dropped it to 6, and this small reduction did not change any of the test-set predictions.
# define training control
train_control <- trainControl(method = "cv", number = 10)
# train the model on training set
CV_treemodel <- train(Purchase ~ .,
data = OJ,
trControl = train_control,
method = 'rpart')
CV_treemodel
## CART
##
## 1070 samples
## 17 predictor
## 2 classes: 'CH', 'MM'
##
## No pre-processing
## Resampling: Cross-Validated (10 fold)
## Summary of sample sizes: 962, 963, 963, 964, 963, 963, ...
## Resampling results across tuning parameters:
##
## cp Accuracy Kappa
## 0.009592326 0.8065966 0.5927743
## 0.017985612 0.8140732 0.6125810
## 0.510791367 0.7017119 0.2824111
##
## Accuracy was used to select the optimal model using the largest value.
## The final value used for the model was cp = 0.01798561.
We get the highest accuracy rate when cp = 0.02 (more precisely, cp = 0.01798561). The complexity parameter (cp) in rpart is the minimum improvement in fit required at each node: a split is kept only if it decreases the overall lack of fit by at least cp.
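As a cross-check on the cp choice, rpart keeps its own complexity table inside the fitted object (a sketch using model_fit2, which was grown with cp = 0):
#Complexity parameter table with cross-validated error (xerror) at each subtree size
printcp(model_fit2)
plotcp(model_fit2)   #cross-validated error plotted against cp / tree size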
pruned_model <- rpart(Purchase ~ ., data = train, method = 'class',
control = rpart.control(cp = 0.02))
pruned_model
## n= 871
##
## node), split, n, loss, yval, (yprob)
## * denotes terminal node
##
## 1) root 871 351 CH (0.59701493 0.40298507)
## 2) LoyalCH>=0.48285 542 95 CH (0.82472325 0.17527675)
## 4) LoyalCH>=0.7645725 281 14 CH (0.95017794 0.04982206) *
## 5) LoyalCH< 0.7645725 261 81 CH (0.68965517 0.31034483)
## 10) PriceDiff>=-0.165 220 50 CH (0.77272727 0.22727273) *
## 11) PriceDiff< -0.165 41 10 MM (0.24390244 0.75609756) *
## 3) LoyalCH< 0.48285 329 73 MM (0.22188450 0.77811550) *
Plot Pruned Tree
rpart.plot(pruned_model)
Model Comparison
postResample(predict(model_fit2,
train,
type = 'class'), train$Purchase)
## Accuracy Kappa
## 0.8691160 0.7259615
postResample(predict(pruned_model,
train,
type = 'class'), train$Purchase)
## Accuracy Kappa
## 0.8312285 0.6523105
The pruned model is more interpretable but slightly less accurate on the training data, whereas the unpruned model is more complex with higher training accuracy.
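The comparison above is on the training set; a fairer basis for choosing between the two trees is the held-out test set (a sketch using the test split defined earlier):
#Same comparison on the test set
postResample(predict(model_fit2, test, type = 'class'), test$Purchase)    #unpruned rpart tree
postResample(predict(pruned_model, test, type = 'class'), test$Purchase)  #pruned tree (cp = 0.02)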