Based on: https://www.datacamp.com/community/tutorials/decision-trees-R. See the accompanying text for interpretation of the results.
library(ggplot2)
library(ISLR)
#data(package="ISLR")
carseats <- Carseats
#install.packages("tree")
library(tree)
# High must be a factor for tree() to fit a classification tree
# (since R 4.0, data.frame() no longer converts strings to factors automatically)
High = as.factor(ifelse(carseats$Sales <= 8, "No", "Yes"))
carseats = data.frame(carseats, High)
tree.carseats = tree(High~.-Sales, data=carseats)
summary(tree.carseats)
##
## Classification tree:
## tree(formula = High ~ . - Sales, data = carseats)
## Variables actually used in tree construction:
## [1] "ShelveLoc" "Price" "Income" "CompPrice" "Population"
## [6] "Advertising" "Age" "US"
## Number of terminal nodes: 27
## Residual mean deviance: 0.4575 = 170.7 / 373
## Misclassification error rate: 0.09 = 36 / 400
plot(tree.carseats)
text(tree.carseats, pretty = 0)
dim(carseats)
## [1] 400 12
set.seed(101)
train=sample(1:nrow(carseats), 250)
test = carseats[-train,]
tree.carseats = tree(High~.-Sales, carseats, subset=train)
plot(tree.carseats)
text(tree.carseats, pretty=0)
tree.pred = predict(tree.carseats, carseats[-train,], type="class")
with(carseats[-train,], table(tree.pred, High))
## High
## tree.pred No Yes
## No 72 16
## Yes 19 43
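A quick sanity check before loading caret: accuracy can be read directly off this table.
# accuracy by hand: correct test predictions over total test cases
(72 + 43) / (72 + 16 + 19 + 43)
## [1] 0.7666667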
#install.packages("caret")
library(caret)
## Loading required package: lattice
#install.packages("e1071")
library(e1071)
confusionMatrix(table(tree.pred, test$High))
## Confusion Matrix and Statistics
##
##
## tree.pred No Yes
## No 72 16
## Yes 19 43
##
## Accuracy : 0.7667
## 95% CI : (0.6907, 0.8318)
## No Information Rate : 0.6067
## P-Value [Acc > NIR] : 2.476e-05
##
## Kappa : 0.5154
##
## Mcnemar's Test P-Value : 0.7353
##
## Sensitivity : 0.7912
## Specificity : 0.7288
## Pos Pred Value : 0.8182
## Neg Pred Value : 0.6935
## Prevalence : 0.6067
## Detection Rate : 0.4800
## Detection Prevalence : 0.5867
## Balanced Accuracy : 0.7600
##
## 'Positive' Class : No
##
The decision tree reaches 77% accuracy on the test set, a significant improvement over the 61% no-information rate (p < 0.001).
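The tree above was grown to full depth and may overfit; a pruned tree is worth checking. A sketch using cv.tree() and prune.misclass() from the tree package (the `best` size below is illustrative; in a real run, read it off the CV plot):
cv.carseats = cv.tree(tree.carseats, FUN = prune.misclass)   # cross-validate over tree sizes
plot(cv.carseats$size, cv.carseats$dev, type = "b",
     xlab = "Tree size", ylab = "CV misclassifications")
prune.carseats = prune.misclass(tree.carseats, best = 10)    # size at the CV minimum (illustrative)
prune.pred = predict(prune.carseats, carseats[-train,], type = "class")
with(carseats[-train,], table(prune.pred, High))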
str(carseats)
## 'data.frame': 400 obs. of 12 variables:
## $ Sales : num 9.5 11.22 10.06 7.4 4.15 ...
## $ CompPrice : num 138 111 113 117 141 124 115 136 132 132 ...
## $ Income : num 73 48 35 100 64 113 105 81 110 113 ...
## $ Advertising: num 11 16 10 4 3 13 0 15 0 0 ...
## $ Population : num 276 260 269 466 340 501 45 425 108 131 ...
## $ Price : num 120 83 80 97 128 72 108 120 124 124 ...
## $ ShelveLoc : Factor w/ 3 levels "Bad","Good","Medium": 1 2 3 3 1 1 3 2 3 3 ...
## $ Age : num 42 65 59 55 38 78 71 67 76 76 ...
## $ Education : num 17 10 12 14 13 16 15 10 10 17 ...
## $ Urban : Factor w/ 2 levels "No","Yes": 2 2 2 2 2 1 2 2 1 1 ...
## $ US : Factor w/ 2 levels "No","Yes": 2 2 2 2 1 2 1 2 1 2 ...
## $ High : Factor w/ 2 levels "No","Yes": 2 2 2 1 1 2 1 2 1 1 ...
# rebuild High as a character vector so it can be recoded to 0/1 for logistic regression
carseats$High = ifelse(carseats$Sales<=8, "No", "Yes")
length(carseats$High)
## [1] 400
class(carseats$High)
## [1] "character"
carseats$High[carseats$High == "Yes"] <- "1"
carseats$High[carseats$High == "No"] <- "0"
table(carseats$High)
##
## 0 1
## 236 164
class(carseats$High)
## [1] "character"
carseats$High <- as.numeric(carseats$High)
table(carseats$High)
##
## 0 1
## 236 164
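The same 0/1 recoding can be done in one step (a sketch equivalent to the three steps above):
# TRUE/FALSE coerces to 1/0; Sales > 8 corresponds to High == "Yes"
carseats$High <- as.numeric(carseats$Sales > 8)
table(carseats$High)   # same 236/164 split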
set.seed(101)
train = sample(1:nrow(carseats), 250)
log.carseats = glm(High ~ ShelveLoc + Price + Age + CompPrice + Income + Advertising + Population,
family = "binomial", data = carseats[train,], na.action = na.omit)
summary(log.carseats)
##
## Call:
## glm(formula = High ~ ShelveLoc + Price + Age + CompPrice + Income +
## Advertising + Population, family = "binomial", data = carseats[train,
## ], na.action = na.omit)
##
## Deviance Residuals:
## Min 1Q Median 3Q Max
## -2.48378 -0.29120 -0.04611 0.19053 2.51107
##
## Coefficients:
## Estimate Std. Error z value Pr(>|z|)
## (Intercept) -5.3335196 2.6481922 -2.014 0.04401 *
## ShelveLocGood 9.1386202 1.3993786 6.530 6.56e-11 ***
## ShelveLocMedium 4.0923411 0.8466083 4.834 1.34e-06 ***
## Price -0.1774524 0.0260864 -6.802 1.03e-11 ***
## Age -0.0723547 0.0175675 -4.119 3.81e-05 ***
## CompPrice 0.1669547 0.0289113 5.775 7.71e-09 ***
## Income 0.0334110 0.0099544 3.356 0.00079 ***
## Advertising 0.2377630 0.0472728 5.030 4.92e-07 ***
## Population -0.0001372 0.0017979 -0.076 0.93917
## ---
## Signif. codes: 0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1
##
## (Dispersion parameter for binomial family taken to be 1)
##
## Null deviance: 340.15 on 249 degrees of freedom
## Residual deviance: 112.51 on 241 degrees of freedom
## AIC: 130.51
##
## Number of Fisher Scoring iterations: 7
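The two deviances in the summary give a quick overall test: the drop from the null deviance (340.15 on 249 df) to the residual deviance (112.51 on 241 df) is approximately chi-squared with 8 degrees of freedom if the predictors are jointly useless. A sketch:
# likelihood-ratio test of the model against the intercept-only model
lr.stat = log.carseats$null.deviance - log.carseats$deviance   # 340.15 - 112.51
lr.df = log.carseats$df.null - log.carseats$df.residual        # 249 - 241 = 8
pchisq(lr.stat, df = lr.df, lower.tail = FALSE)                # tiny p-value: the predictors matter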
library(tidyverse)
library(magrittr)
coefsExp <- coef(log.carseats) %>% exp() %>% round(2)
coefsExp
## (Intercept) ShelveLocGood ShelveLocMedium Price
## 0.00 9307.91 59.88 0.84
## Age CompPrice Income Advertising
## 0.93 1.18 1.03 1.27
## Population
## 1.00
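Reading these as odds ratios: a good shelf location multiplies the odds of high sales by roughly 9300 relative to a bad one, and each $1 increase in Price multiplies them by 0.84. For a larger step, scale the coefficient before exponentiating; a sketch:
# odds multiplier for a $10 price increase
exp(10 * coef(log.carseats)["Price"])   # about 0.17, i.e. the odds drop by roughly 83%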
library(margins)
m <- margins(log.carseats)
## Warning in warn_for_weights(model): 'weights' used in model estimation are
## currently ignored!
summary(m)
## factor AME SE z p lower upper
## Advertising 0.0163 0.0024 6.8543 0.0000 0.0117 0.0210
## Age -0.0050 0.0010 -5.0219 0.0000 -0.0069 -0.0030
## CompPrice 0.0115 0.0012 9.4171 0.0000 0.0091 0.0139
## Income 0.0023 0.0006 3.7643 0.0002 0.0011 0.0035
## Population -0.0000 0.0001 -0.0763 0.9392 -0.0003 0.0002
## Price -0.0122 0.0007 -18.1992 0.0000 -0.0135 -0.0109
## ShelveLocGood 0.7202 0.0451 15.9661 0.0000 0.6318 0.8087
## ShelveLocMedium 0.2792 0.0379 7.3677 0.0000 0.2049 0.3535
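For a logit model, the average marginal effect of a continuous predictor is its coefficient scaled by the mean logistic density of the linear predictor, so the margins() output can be sanity-checked by hand; a sketch for Price:
eta = predict(log.carseats, type = "link")          # linear predictor on the training data
mean(dlogis(eta) * coef(log.carseats)["Price"])     # should be close to the -0.0122 above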
# classify test cases with a 0.5 threshold on the predicted probability
log.prob = predict(log.carseats, newdata = carseats[-train,], type = "response")
log.pred = ifelse(log.prob > 0.5, "1", "0")
options(scipen = 30)
confusionMatrix(table(log.pred, carseats[-train,]$High))
## Confusion Matrix and Statistics
##
##
## log.pred 0 1
## 0 84 9
## 1 7 50
##
## Accuracy : 0.8933
## 95% CI : (0.8326, 0.9378)
## No Information Rate : 0.6067
## P-Value [Acc > NIR] : 0.000000000000004513
##
## Kappa : 0.7752
##
## Mcnemar's Test P-Value : 0.8026
##
## Sensitivity : 0.9231
## Specificity : 0.8475
## Pos Pred Value : 0.9032
## Neg Pred Value : 0.8772
## Prevalence : 0.6067
## Detection Rate : 0.5600
## Detection Prevalence : 0.6200
## Balanced Accuracy : 0.8853
##
## 'Positive' Class : 0
##
Logistic regression reaches 89% accuracy on the test set, again a significant improvement over the 61% no-information rate (p < 0.001), and it beats the decision tree here. There is no overall winner: use both methods and compare their performance.
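Accuracy also depends on the 0.5 cutoff; a threshold-free comparison such as AUC is worth a look. A sketch using the pROC package (not loaded above, so an assumption):
#install.packages("pROC")
library(pROC)
roc.log = roc(carseats[-train,]$High, log.prob)   # log.prob was computed above
auc(roc.log)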
What about methods that are less interpretable but stronger in prediction?
#install.packages("randomForest")
library(randomForest)
rand.carseats = randomForest(Sales~. -High, data = carseats, subset = train)
rand.carseats
##
## Call:
## randomForest(formula = Sales ~ . - High, data = carseats, subset = train)
## Type of random forest: regression
## Number of trees: 500
## No. of variables tried at each split: 3
##
## Mean of squared residuals: 3.124066
## % Var explained: 61.1
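A random forest is harder to read than a single tree, but importance() and varImpPlot() from randomForest recover which predictors matter:
importance(rand.carseats)    # IncNodePurity: total decrease in node impurity per variable
varImpPlot(rand.carseats)    # Price and ShelveLoc typically rank highest here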
oob.err = double(10)
test.err = double(10)
# compare out-of-bag and test-set MSE for mtry = 1..10
for(mtry in 1:10){
  fit = randomForest(Sales~. -High, data = carseats, subset = train, mtry = mtry, ntree = 500)
  oob.err[mtry] = fit$mse[500]                  # OOB error after all 500 trees
  pred = predict(fit, carseats[-train,])
  test.err[mtry] = with(carseats[-train,], mean( (Sales-pred)^2 ))
}
matplot(1:mtry, cbind(test.err, oob.err), pch = 23, col = c("red", "blue"), type = "b", ylab="Mean Squared Error")
legend("topright", legend = c("OOB", "Test"), pch = 23, col = c("red", "blue"))
rand3.carseats = randomForest(Sales~. -High, data = carseats, subset = train, mtry = 3)
rand3.carseats
##
## Call:
## randomForest(formula = Sales ~ . - High, data = carseats, mtry = 3, subset = train)
## Type of random forest: regression
## Number of trees: 500
## No. of variables tried at each split: 3
##
## Mean of squared residuals: 3.16896
## % Var explained: 60.54
#install.packages("gbm")
library(gbm)
## Loading required package: survival
##
## Attaching package: 'survival'
## The following object is masked from 'package:caret':
##
## cluster
## Loading required package: splines
## Loading required package: parallel
## Loaded gbm 2.1.3
boost.cars = gbm(Sales~.-High, data = carseats[train,], distribution = "gaussian", n.trees = 10000, shrinkage = 0.01, interaction.depth = 4)
summary(boost.cars, las = 2)
## var rel.inf
## Price Price 27.8620347
## ShelveLoc ShelveLoc 23.7994541
## CompPrice CompPrice 12.9284823
## Age Age 9.8484924
## Income Income 8.1932597
## Advertising Advertising 7.0233848
## Population Population 6.8421064
## Education Education 2.4476001
## US US 0.5524655
## Urban Urban 0.5027199
plot(boost.cars,i="Price")
plot(boost.cars,i="ShelveLoc")
n.trees = seq(from = 100, to = 10000, by = 100)
predmat = predict(boost.cars, newdata = carseats[-train,], n.trees = n.trees)
dim(predmat)
## [1] 150 100
boost.err = with(carseats[-train,], apply( (predmat - Sales)^2, 2, mean) )
plot(n.trees, boost.err, pch = 23, ylab = "Mean Squared Error", xlab = "# Trees", main = "Boosting Test Error")
abline(h = min(test.err), col = "red")
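The minimum of the boosting error curve can likewise be located numerically:
n.trees[which.min(boost.err)]   # number of trees with the lowest test MSE
min(boost.err)                  # compare against min(test.err) from the random forest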
Literature-based summary: see the text referenced at the top. As a further exercise, repeat the comparison on the following sample:
mydata <- read.csv("https://stats.idre.ucla.edu/stat/data/binary.csv")
head(mydata)
## admit gre gpa rank
## 1 0 380 3.61 3
## 2 1 660 3.67 3
## 3 1 800 4.00 1
## 4 1 640 3.19 4
## 5 0 520 2.93 4
## 6 1 760 3.00 2
dim(mydata)
## [1] 400 4
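A possible starting point for the exercise (a sketch; the split size and seed are arbitrary choices). Note that rank codes the prestige of the undergraduate institution from 1 (highest) to 4 and should be treated as a factor:
mydata$rank <- factor(mydata$rank)   # categorical, not numeric
set.seed(101)
train.adm = sample(1:nrow(mydata), 250)
log.admit = glm(admit ~ gre + gpa + rank, family = "binomial", data = mydata[train.adm,])
summary(log.admit)
# then fit tree/randomForest/gbm analogues and compare confusion matrices as above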