library(ISLR)          # Carseats and OJ data sets
library(tree)          # tree(), cv.tree(), prune.tree(), prune.misclass()
library(randomForest)  # randomForest(), importance(), varImpPlot()

pm1 <- seq(0, 1, 0.01)
gini_index <- 2 * pm1 * (1 - pm1)  # two-class Gini index
entropy <- -(pm1 * log(pm1) + (1 - pm1) * log(1 - pm1))  # entropy in nats (NaN at the endpoints; matplot skips those points)
class.error <- 1 - pmax(pm1, 1 - pm1)  # classification error rate
matplot(pm1, cbind(gini_index, entropy, class.error), type = "l", lty = 1,
        col = c("orange", "green", "brown"), xlab = "pm1", ylab = "Impurity measure")
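# Add a legend so the three curves are identifiable (colors match the
# matplot call above). All three measures peak at pm1 = 0.5.
legend("topright", legend = c("Gini index", "Entropy", "Class. error"),
       col = c("orange", "green", "brown"), lty = 1)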
set.seed(1)
# (a) split the Carseats data into a training set and a test set
train <- sample(1:nrow(Carseats), nrow(Carseats) / 2)
car.train <- Carseats[train, ]
car.test <- Carseats[-train, ]
car.tree <- tree(Sales ~ ., data = car.train)
plot(car.tree)
text(car.tree, pretty = 0)
tree.pred <- predict(car.tree, car.test)
mse <- mean((tree.pred - car.test$Sales)^2)
mse
## [1] 4.922039
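# The unpruned tree's test MSE is about 4.92, i.e. a test RMSE of roughly
# 2.2 (Sales is measured in thousands of units).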
cross.cars <- cv.tree(car.tree, FUN = prune.tree)
par(mfrow = c(1, 2))
plot(cross.cars$size, cross.cars$dev, type = "b")
plot(cross.cars$k, cross.cars$dev, type = "b")
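# A programmatic way to read off the tree size that minimizes the CV
# deviance (the plots above show the same information graphically):
cross.cars$size[which.min(cross.cars$dev)]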
prun.car <- prune.tree(car.tree, best = 13)
par(mfrow = c(1, 1))
plot(prun.car)
text(prun.car, pretty = 0)
prun.pred <- predict(prun.car, car.test)
prun.mse <- mean((car.test$Sales - prun.pred)^2)
prun.mse
## [1] 4.96547
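# Pruning to 13 terminal nodes does not improve the test MSE here
# (about 4.97 vs. 4.92 for the unpruned tree), though it gives a
# simpler, more interpretable tree.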
# (d) bagging (a random forest with mtry equal to the number of predictors)
set.seed(1)
bagged.car <- randomForest(Sales ~ ., data = car.train, mtry = 10, ntree = 500, importance = TRUE)
bagged.pred <- predict(bagged.car, car.test)
bagged.mse <- mean((car.test$Sales - bagged.pred)^2)
bagged.mse
## [1] 2.605253
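# Bagging cuts the test MSE to about 2.61, roughly half that of the
# single regression tree.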
importance(bagged.car)
## %IncMSE IncNodePurity
## CompPrice 24.8888481 170.182937
## Income 4.7121131 91.264880
## Advertising 12.7692401 97.164338
## Population -1.8074075 58.244596
## Price 56.3326252 502.903407
## ShelveLoc 48.8886689 380.032715
## Age 17.7275460 157.846774
## Education 0.5962186 44.598731
## Urban 0.1728373 9.822082
## US 4.2172102 18.073863
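# randomForest also ships a dot-chart view of the same importance table;
# Price and ShelveLoc dominate both measures here:
varImpPlot(bagged.car)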
forest.car <- randomForest(Sales ~ ., data = car.train, mtry = 5, importance = TRUE)
forest.pred <- predict(forest.car, car.test)
forest.mse <- mean((car.test$Sales - forest.pred)^2)
forest.mse
## [1] 2.710806
importance(forest.car)
## %IncMSE IncNodePurity
## CompPrice 18.3350929 163.77398
## Income 4.7546303 115.94202
## Advertising 10.4995404 98.99790
## Population -0.8865995 78.91872
## Price 45.2942175 451.27503
## ShelveLoc 40.8250417 333.45728
## Age 13.6552105 165.03566
## Education 0.9386481 56.40921
## Urban -0.8796633 11.21732
## US 6.1388918 25.33755
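# With mtry = 5 the test MSE (about 2.71) is close to bagging's 2.61, and
# Price and ShelveLoc remain the most important variables. A minimal
# sketch of how the test MSE varies with m (mtry.mse and m are our own
# names, not part of the exercise; refitting for each m takes a while):
mtry.mse <- sapply(1:10, function(m) {
  set.seed(1)
  rf <- randomForest(Sales ~ ., data = car.train, mtry = m)
  mean((car.test$Sales - predict(rf, car.test))^2)
})
plot(1:10, mtry.mse, type = "b", xlab = "mtry", ylab = "Test MSE")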
# OJ classification tree
set.seed(1)
train <- sample(nrow(OJ), 800)
oj.train <- OJ[train, ]
oj.test <- OJ[-train, ]
fit.oj <- tree(Purchase ~ ., data = oj.train)
summary(fit.oj)
##
## Classification tree:
## tree(formula = Purchase ~ ., data = oj.train)
## Variables actually used in tree construction:
## [1] "LoyalCH" "PriceDiff" "SpecialCH" "ListPriceDiff"
## [5] "PctDiscMM"
## Number of terminal nodes: 9
## Residual mean deviance: 0.7432 = 587.8 / 791
## Misclassification error rate: 0.1588 = 127 / 800
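# The fitted tree uses 5 of the predictors, has 9 terminal nodes, and a
# training misclassification error rate of 15.88%.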
fit.oj
## node), split, n, deviance, yval, (yprob)
## * denotes terminal node
##
## 1) root 800 1073.00 CH ( 0.60625 0.39375 )
## 2) LoyalCH < 0.5036 365 441.60 MM ( 0.29315 0.70685 )
## 4) LoyalCH < 0.280875 177 140.50 MM ( 0.13559 0.86441 )
## 8) LoyalCH < 0.0356415 59 10.14 MM ( 0.01695 0.98305 ) *
## 9) LoyalCH > 0.0356415 118 116.40 MM ( 0.19492 0.80508 ) *
## 5) LoyalCH > 0.280875 188 258.00 MM ( 0.44149 0.55851 )
## 10) PriceDiff < 0.05 79 84.79 MM ( 0.22785 0.77215 )
## 20) SpecialCH < 0.5 64 51.98 MM ( 0.14062 0.85938 ) *
## 21) SpecialCH > 0.5 15 20.19 CH ( 0.60000 0.40000 ) *
## 11) PriceDiff > 0.05 109 147.00 CH ( 0.59633 0.40367 ) *
## 3) LoyalCH > 0.5036 435 337.90 CH ( 0.86897 0.13103 )
## 6) LoyalCH < 0.764572 174 201.00 CH ( 0.73563 0.26437 )
## 12) ListPriceDiff < 0.235 72 99.81 MM ( 0.50000 0.50000 )
## 24) PctDiscMM < 0.196196 55 73.14 CH ( 0.61818 0.38182 ) *
## 25) PctDiscMM > 0.196196 17 12.32 MM ( 0.11765 0.88235 ) *
## 13) ListPriceDiff > 0.235 102 65.43 CH ( 0.90196 0.09804 ) *
## 7) LoyalCH > 0.764572 261 91.20 CH ( 0.95785 0.04215 ) *
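# Taking terminal node 8 as an example: observations with
# LoyalCH < 0.0356 form a node of 59 training observations with deviance
# 10.14; the predicted class is MM, and about 98.3% of those customers
# bought MM, so this is a very pure node.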
plot(fit.oj)
text(fit.oj, pretty = 0)
# LoyalCH is the most important predictor in the tree: the top splits all use it.
oj.pred <- predict(fit.oj, oj.test, type = "class")
table(oj.test$Purchase, oj.pred)
## oj.pred
## CH MM
## CH 160 8
## MM 38 64
(38 + 8) / 270
## [1] 0.1703704
# The test error rate (about 17.0%) is a little higher than the training
# error rate of 15.88%, which is the usual pattern.
cross.oj <- cv.tree(fit.oj, FUN = prune.misclass)
plot(cross.oj$size, cross.oj$dev, type = "b", xlab = "Tree Size", ylab = "CV Misclassifications")
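# Read off the size with the fewest CV misclassifications directly:
cross.oj$size[which.min(cross.oj$dev)]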
prun.oj <- prune.misclass(fit.oj, best = 5)
plot(prun.oj)
text(prun.oj, pretty = 0)
# (j) compare the training error rates of the unpruned and pruned trees
summary(fit.oj)
##
## Classification tree:
## tree(formula = Purchase ~ ., data = oj.train)
## Variables actually used in tree construction:
## [1] "LoyalCH" "PriceDiff" "SpecialCH" "ListPriceDiff"
## [5] "PctDiscMM"
## Number of terminal nodes: 9
## Residual mean deviance: 0.7432 = 587.8 / 791
## Misclassification error rate: 0.1588 = 127 / 800
summary(prun.oj)
##
## Classification tree:
## snip.tree(tree = fit.oj, nodes = c(4L, 10L))
## Variables actually used in tree construction:
## [1] "LoyalCH" "PriceDiff" "ListPriceDiff" "PctDiscMM"
## Number of terminal nodes: 7
## Residual mean deviance: 0.7748 = 614.4 / 793
## Misclassification error rate: 0.1625 = 130 / 800
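# (k) For completeness, the pruned tree's test error can be computed the
# same way as before and compared with the unpruned tree's 17.0%
# (a sketch; output not reproduced here):
prun.oj.pred <- predict(prun.oj, oj.test, type = "class")
mean(prun.oj.pred != oj.test$Purchase)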