Decision Tree:

If we can come up with a set of splitting rules that segments or stratifies the predictor space into simple regions, so that each observation can be assigned to a class of the outcome variable, and we summarize these splitting rules in the form of a tree, this statistical learning approach is called a "DECISION TREE".

  1. can be applied to both regression and classification problems.
  2. works well for non-linear data (when the relationship between the outcome and the predictors is non-linear).
  3. simple to implement and understand.
  4. high interpretability.
  5. not as accurate for prediction as some other, more flexible methods.

But combining a large number of trees can often result in a dramatic increase in prediction accuracy, at the expense of some loss in interpretability.

Bagging, random forests, and boosting are tree-based methods built on this idea (see the sketch below).
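
As a hedged illustration (not part of these notes' own workflow), the sketch below performs bagging with the randomForest package: setting mtry to the number of predictors means no predictor subsampling, so averaging many bootstrapped trees is plain bagging. The data prep mirrors the Carseats example used later; the seed and split size are arbitrary assumptions.

suppressMessages(library(randomForest))
suppressMessages(library(ISLR))

# assumed example: bag 500 classification trees on the Carseats data
cs = Carseats
cs$High = as.factor(ifelse(cs$Sales <= 8, "No", "Yes"))
cs$Sales = NULL   # High is derived from Sales, so drop Sales as a predictor

set.seed(1)
idx = sample(1:nrow(cs), 200)

# mtry = number of predictors => bagging (a random forest without predictor subsampling)
bag.cs = randomForest(High ~ ., data = cs[idx, ], mtry = ncol(cs) - 1, ntree = 500)

# test-set accuracy of the bagged ensemble
bag.pred = predict(bag.cs, cs[-idx, ])
mean(bag.pred == cs$High[-idx])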

SPLIT CRITERIA:

for Regression Tree:

  1. minimum RSS (Residual Sum of Squares).

for Classification Tree:

  1. Misclassification error (two classes: 0 = perfect purity; 0.5 = no purity)
  2. Gini index (two classes: 0 = perfect purity; 0.5 = no purity)
  3. Information gain (cross-entropy) (two classes: 0 = perfect purity; 1 = no purity)

Information gain (cross-entropy) uses log2; if the natural log (loge) is used instead, the measure is called the deviance.

Normally the Gini index or information gain (cross-entropy) is used both to build trees and to prune them.

If prediction accuracy of the final model is the goal, the misclassification error is used to prune the tree instead.
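
A small sketch of these impurity measures for a two-class node, computed from the proportion p of observations in the "Yes" class (the helper name is made up for illustration):

# impurity measures for a two-class node, given the proportion p of class "Yes"
node.impurity = function(p) {
  p = c(p, 1 - p)                                    # class proportions
  c(misclass = 1 - max(p),                           # misclassification error
    gini     = sum(p * (1 - p)),                     # Gini index
    entropy  = -sum(ifelse(p > 0, p * log2(p), 0)))  # cross-entropy with log2
}

node.impurity(0.5)   # no purity:      0.5, 0.5, 1
node.impurity(0.9)   # fairly pure:    all measures close to 0
node.impurity(1.0)   # perfect purity: all measures are exactly 0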

ADVANTAGES:

  1. high interpretability - easy to explain
  2. high visualization power
  3. closely resembles the human decision-making process.
  4. can handle qualitative predictors without creating dummy variables.

DISADVANTAGES:

  1. lower prediction accuracy
  2. non-robust: small changes in the data can produce a very different tree

suppressMessages(library(tree))
suppressMessages(library(ISLR))

carseats = Carseats
names(carseats) = tolower(names(carseats))

# sales is a continuous numeric variable; recode it as a binary factor high
carseats$high = ifelse(carseats$sales <= 8, "No","Yes")
carseats$high = as.factor(carseats$high)

# creating training and test sets: 50-50 split
set.seed(2)
index = sample(1:nrow(carseats),200)
# index = sample(1:nrow(carseats),0.50*nrow(carseats)) - alternate method

train = carseats[index,]
test = carseats[-index,]

high.test = carseats$high[-index]

tree.carseats = tree(high~.-sales,data=train) # fit a classification tree on the training set, excluding sales (high is derived from it)

tree.carseats
## node), split, n, deviance, yval, (yprob)
##       * denotes terminal node
## 
##   1) root 200 269.200 No ( 0.60000 0.40000 )  
##     2) shelveloc: Bad,Medium 153 185.400 No ( 0.70588 0.29412 )  
##       4) price < 142 130 167.700 No ( 0.65385 0.34615 )  
##         8) shelveloc: Bad 39  29.870 No ( 0.87179 0.12821 )  
##          16) income < 100 34  15.210 No ( 0.94118 0.05882 )  
##            32) age < 33.5 6   7.638 No ( 0.66667 0.33333 ) *
##            33) age > 33.5 28   0.000 No ( 1.00000 0.00000 ) *
##          17) income > 100 5   6.730 Yes ( 0.40000 0.60000 ) *
##         9) shelveloc: Medium 91 124.800 No ( 0.56044 0.43956 )  
##          18) price < 86.5 9   0.000 Yes ( 0.00000 1.00000 ) *
##          19) price > 86.5 82 108.700 No ( 0.62195 0.37805 )  
##            38) advertising < 6.5 52  56.180 No ( 0.76923 0.23077 )  
##              76) advertising < 1.5 36  45.830 No ( 0.66667 0.33333 )  
##               152) compprice < 115.5 10   0.000 No ( 1.00000 0.00000 ) *
##               153) compprice > 115.5 26  35.890 No ( 0.53846 0.46154 )  
##                 306) age < 33.5 5   0.000 Yes ( 0.00000 1.00000 ) *
##                 307) age > 33.5 21  26.730 No ( 0.66667 0.33333 )  
##                   614) price < 108.5 10  13.460 Yes ( 0.40000 0.60000 ) *
##                   615) price > 108.5 11   6.702 No ( 0.90909 0.09091 ) *
##              77) advertising > 1.5 16   0.000 No ( 1.00000 0.00000 ) *
##            39) advertising > 6.5 30  39.430 Yes ( 0.36667 0.63333 )  
##              78) age < 37.5 5   0.000 Yes ( 0.00000 1.00000 ) *
##              79) age > 37.5 25  34.300 Yes ( 0.44000 0.56000 )  
##               158) compprice < 118.5 8   8.997 No ( 0.75000 0.25000 ) *
##               159) compprice > 118.5 17  20.600 Yes ( 0.29412 0.70588 )  
##                 318) advertising < 12.5 10  13.860 Yes ( 0.50000 0.50000 ) *
##                 319) advertising > 12.5 7   0.000 Yes ( 0.00000 1.00000 ) *
##       5) price > 142 23   0.000 No ( 1.00000 0.00000 ) *
##     3) shelveloc: Good 47  53.400 Yes ( 0.25532 0.74468 )  
##       6) price < 142.5 38  29.590 Yes ( 0.13158 0.86842 )  
##        12) population < 278 17   0.000 Yes ( 0.00000 1.00000 ) *
##        13) population > 278 21  23.050 Yes ( 0.23810 0.76190 )  
##          26) advertising < 10.5 13  17.320 Yes ( 0.38462 0.61538 )  
##            52) price < 99.5 5   0.000 Yes ( 0.00000 1.00000 ) *
##            53) price > 99.5 8  10.590 No ( 0.62500 0.37500 ) *
##          27) advertising > 10.5 8   0.000 Yes ( 0.00000 1.00000 ) *
##       7) price > 142.5 9   9.535 No ( 0.77778 0.22222 ) *
summary(tree.carseats)
## 
## Classification tree:
## tree(formula = high ~ . - sales, data = train)
## Variables actually used in tree construction:
## [1] "shelveloc"   "price"       "income"      "age"         "advertising"
## [6] "compprice"   "population" 
## Number of terminal nodes:  19 
## Residual mean deviance:  0.4282 = 77.51 / 181 
## Misclassification error rate: 0.105 = 21 / 200
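# assumed sanity check (not in the original notes): the training misclassification
# error reported by summary() can be reproduced directly from the fitted tree.
train.pred = predict(tree.carseats, train, type="class")
mean(train.pred != train$high)   # should equal 21/200 = 0.105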
pred = predict(tree.carseats, test, type="class") # type="class" returns predicted class labels (classification tree)

# plotting the tree
plot(tree.carseats)
text(tree.carseats, pretty=0)

# assessing the model
tree.confusion = table(pred, high.test)
tree.confusion
##      high.test
## pred  No Yes
##   No  86  27
##   Yes 30  57
tree.accuracy.percentage = round(mean(pred == high.test)*100,2)
tree.accuracy.percentage
## [1] 71.5
tree.error.percentage = round(mean(pred != high.test)*100,2)
tree.error.percentage
## [1] 28.5
tree.error.percentage+tree.accuracy.percentage
## [1] 100
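# alternative (assumed) computation: accuracy can also be read off the confusion
# matrix, since correct predictions sit on its diagonal.
round(sum(diag(tree.confusion)) / sum(tree.confusion) * 100, 2)   # 71.5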
set.seed(3)

# cross-validating the model for pruning levels
cv.carseats = cv.tree(tree.carseats, FUN=prune.misclass) # FUN=prune.misclass uses the misclassification error rate to guide cross-validation and pruning (classification trees only)

names(cv.carseats)
## [1] "size"   "dev"    "k"      "method"
cv.carseats # CV results across pruning levels; use these to choose the best tree size. This is not the final pruned model.
## $size
## [1] 19 17 14 13  9  7  3  2  1
## 
## $dev
## [1] 55 55 53 52 50 56 69 65 80
## 
## $k
## [1]       -Inf  0.0000000  0.6666667  1.0000000  1.7500000  2.0000000
## [7]  4.2500000  5.0000000 23.0000000
## 
## $method
## [1] "misclass"
## 
## attr(,"class")
## [1] "prune"         "tree.sequence"
default_par = par(mfrow=c(1,2)) # save the previous par settings so they can be restored below
plot(cv.carseats$size, cv.carseats$dev, type = "b")
plot(cv.carseats$k, cv.carseats$dev, type = "b")

par(default_par)

# prune to the size chosen from the CV results above (best = number of terminal nodes)
prune.carseats = prune.misclass(tree.carseats,best=9)

# plot the pruned model (best=9)
plot(prune.carseats)
text(prune.carseats, pretty=0)

# prediction using pruned model (best=9)
pruned.pred = predict(prune.carseats, test, type="class") # type="class" returns predicted class labels (classification tree)

# assessing the model
tree.confusion.pruned = table(pruned.pred, high.test)
tree.confusion.pruned 
##            high.test
## pruned.pred No Yes
##         No  94  24
##         Yes 22  60
tree.accuracy.percentage = round(mean(pruned.pred == high.test)*100,2)
tree.accuracy.percentage
## [1] 77
tree.error.percentage = round(mean(pruned.pred != high.test)*100,2)
tree.error.percentage
## [1] 23
tree.error.percentage+tree.accuracy.percentage
## [1] 100
suppressMessages(library(MASS))

boston = Boston
names(boston) = tolower(names(boston))

set.seed(1)

index = sample(1:nrow(boston),nrow(boston)/2)
train = boston[index,]
test = boston[-index,]
medv.test = boston$medv[-index]


# fit the regression tree (note: fit on the full boston data set here, not just the training set)
tree.boston = tree(medv ~ ., data=boston)

tree.boston
## node), split, n, deviance, yval
##       * denotes terminal node
## 
##  1) root 506 42720.0 22.53  
##    2) rm < 6.941 430 17320.0 19.93  
##      4) lstat < 14.4 255  6632.0 23.35  
##        8) dis < 1.38485 5   390.7 45.58 *
##        9) dis > 1.38485 250  3721.0 22.91  
##         18) rm < 6.543 195  1636.0 21.63 *
##         19) rm > 6.543 55   643.2 27.43 *
##      5) lstat > 14.4 175  3373.0 14.96  
##       10) crim < 6.99237 101  1151.0 17.14 *
##       11) crim > 6.99237 74  1086.0 11.98 *
##    3) rm > 6.941 76  6059.0 37.24  
##      6) rm < 7.437 46  1900.0 32.11  
##       12) lstat < 11.455 41   844.2 33.50 *
##       13) lstat > 11.455 5   329.8 20.74 *
##      7) rm > 7.437 30  1099.0 45.10  
##       14) ptratio < 17.9 25   340.7 46.82 *
##       15) ptratio > 17.9 5   312.7 36.48 *
summary(tree.boston)
## 
## Regression tree:
## tree(formula = medv ~ ., data = boston)
## Variables actually used in tree construction:
## [1] "rm"      "lstat"   "dis"     "crim"    "ptratio"
## Number of terminal nodes:  9 
## Residual mean deviance:  13.55 = 6734 / 497 
## Distribution of residuals:
##      Min.   1st Qu.    Median      Mean   3rd Qu.      Max. 
## -17.68000  -2.23000   0.07026   0.00000   2.22100  16.50000
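# assumed sanity check: the residual mean deviance reported above is the tree's RSS
# divided by (n - number of terminal nodes), i.e. 6734 / (506 - 9).
fitted.vals = predict(tree.boston, boston)
sum((boston$medv - fitted.vals)^2) / (nrow(boston) - 9)   # about 13.55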
plot(tree.boston)
text(tree.boston, pretty=0)

# cross-validating the model for pruning
cv.boston = cv.tree(tree.boston)
cv.boston
## $size
## [1] 9 8 7 6 5 4 3 2 1
## 
## $dev
## [1] 10039.23 10278.44 11620.15 12645.14 13660.15 15176.62 16243.87 27260.55
## [9] 42830.65
## 
## $k
## [1]       -Inf   445.4817   725.6002  1136.8088  1441.9267  2520.3263
## [7]  3060.9575  7311.8524 19339.5550
## 
## $method
## [1] "deviance"
## 
## attr(,"class")
## [1] "prune"         "tree.sequence"
# check the cv model
default_par = par(mfrow=c(1,2)) # save the previous par settings so they can be restored below
plot(cv.boston$size, cv.boston$dev, type="b")
plot(cv.boston$k, cv.boston$dev, type="b")

par(default_par)

# prune the tree at best=5
prune.boston = prune.tree(tree.boston, best=5)
plot(prune.boston)
text(prune.boston, pretty=0)

# predict on the test set using the pruned model
pruned.pred = predict(prune.boston, test)

tree.test.MSE = mean((pruned.pred-medv.test)^2)
tree.test.MSE
## [1] 19.90316
model_SE = sqrt(tree.test.MSE)
model_SE
## [1] 4.461296
# the square root of the test MSE, about 4.461, indicates that this model's predictions are typically within around $4,461 of the true median home value for the suburb.
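# assumed comparison (result not shown): test MSE of the unpruned tree, computed the
# same way, to see how much pruning changes out-of-sample error.
full.pred = predict(tree.boston, test)
full.tree.test.MSE = mean((full.pred - medv.test)^2)
full.tree.test.MSE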