Lesson 8.1 - Decision Trees

Robbie Beane

Load Packages

library(ggplot2)
library(gridExtra)
library(caret)
library(rpart)
library(rpart.plot)

Example 1: Wisconsin Breast Cancer Dataset

# Read the Wisconsin Breast Cancer data. stringsAsFactors = TRUE keeps
# `diagnosis` a factor on R >= 4.0 (where the read.table default became
# FALSE); the B/M count summaries below and the classification models assume
# a factor outcome.
wbc <- read.table("data/breast_cancer.csv", header = TRUE, sep = ",",
                  stringsAsFactors = TRUE)
# Drop the `id` column so `diagnosis ~ .` does not use it as a predictor.
wbc$id <- NULL
summary(wbc)
##  diagnosis  radius_mean      texture_mean   perimeter_mean  
##  B:357     Min.   : 6.981   Min.   : 9.71   Min.   : 43.79  
##  M:212     1st Qu.:11.700   1st Qu.:16.17   1st Qu.: 75.17  
##            Median :13.370   Median :18.84   Median : 86.24  
##            Mean   :14.127   Mean   :19.29   Mean   : 91.97  
##            3rd Qu.:15.780   3rd Qu.:21.80   3rd Qu.:104.10  
##            Max.   :28.110   Max.   :39.28   Max.   :188.50  
##    area_mean      smoothness_mean   compactness_mean  concavity_mean   
##  Min.   : 143.5   Min.   :0.05263   Min.   :0.01938   Min.   :0.00000  
##  1st Qu.: 420.3   1st Qu.:0.08637   1st Qu.:0.06492   1st Qu.:0.02956  
##  Median : 551.1   Median :0.09587   Median :0.09263   Median :0.06154  
##  Mean   : 654.9   Mean   :0.09636   Mean   :0.10434   Mean   :0.08880  
##  3rd Qu.: 782.7   3rd Qu.:0.10530   3rd Qu.:0.13040   3rd Qu.:0.13070  
##  Max.   :2501.0   Max.   :0.16340   Max.   :0.34540   Max.   :0.42680  
##  concave.points_mean symmetry_mean    fractal_dimension_mean
##  Min.   :0.00000     Min.   :0.1060   Min.   :0.04996       
##  1st Qu.:0.02031     1st Qu.:0.1619   1st Qu.:0.05770       
##  Median :0.03350     Median :0.1792   Median :0.06154       
##  Mean   :0.04892     Mean   :0.1812   Mean   :0.06280       
##  3rd Qu.:0.07400     3rd Qu.:0.1957   3rd Qu.:0.06612       
##  Max.   :0.20120     Max.   :0.3040   Max.   :0.09744       
##    radius_se        texture_se      perimeter_se       area_se       
##  Min.   :0.1115   Min.   :0.3602   Min.   : 0.757   Min.   :  6.802  
##  1st Qu.:0.2324   1st Qu.:0.8339   1st Qu.: 1.606   1st Qu.: 17.850  
##  Median :0.3242   Median :1.1080   Median : 2.287   Median : 24.530  
##  Mean   :0.4052   Mean   :1.2169   Mean   : 2.866   Mean   : 40.337  
##  3rd Qu.:0.4789   3rd Qu.:1.4740   3rd Qu.: 3.357   3rd Qu.: 45.190  
##  Max.   :2.8730   Max.   :4.8850   Max.   :21.980   Max.   :542.200  
##  smoothness_se      compactness_se      concavity_se    
##  Min.   :0.001713   Min.   :0.002252   Min.   :0.00000  
##  1st Qu.:0.005169   1st Qu.:0.013080   1st Qu.:0.01509  
##  Median :0.006380   Median :0.020450   Median :0.02589  
##  Mean   :0.007041   Mean   :0.025478   Mean   :0.03189  
##  3rd Qu.:0.008146   3rd Qu.:0.032450   3rd Qu.:0.04205  
##  Max.   :0.031130   Max.   :0.135400   Max.   :0.39600  
##  concave.points_se   symmetry_se       fractal_dimension_se
##  Min.   :0.000000   Min.   :0.007882   Min.   :0.0008948   
##  1st Qu.:0.007638   1st Qu.:0.015160   1st Qu.:0.0022480   
##  Median :0.010930   Median :0.018730   Median :0.0031870   
##  Mean   :0.011796   Mean   :0.020542   Mean   :0.0037949   
##  3rd Qu.:0.014710   3rd Qu.:0.023480   3rd Qu.:0.0045580   
##  Max.   :0.052790   Max.   :0.078950   Max.   :0.0298400   
##   radius_worst   texture_worst   perimeter_worst    area_worst    
##  Min.   : 7.93   Min.   :12.02   Min.   : 50.41   Min.   : 185.2  
##  1st Qu.:13.01   1st Qu.:21.08   1st Qu.: 84.11   1st Qu.: 515.3  
##  Median :14.97   Median :25.41   Median : 97.66   Median : 686.5  
##  Mean   :16.27   Mean   :25.68   Mean   :107.26   Mean   : 880.6  
##  3rd Qu.:18.79   3rd Qu.:29.72   3rd Qu.:125.40   3rd Qu.:1084.0  
##  Max.   :36.04   Max.   :49.54   Max.   :251.20   Max.   :4254.0  
##  smoothness_worst  compactness_worst concavity_worst  concave.points_worst
##  Min.   :0.07117   Min.   :0.02729   Min.   :0.0000   Min.   :0.00000     
##  1st Qu.:0.11660   1st Qu.:0.14720   1st Qu.:0.1145   1st Qu.:0.06493     
##  Median :0.13130   Median :0.21190   Median :0.2267   Median :0.09993     
##  Mean   :0.13237   Mean   :0.25427   Mean   :0.2722   Mean   :0.11461     
##  3rd Qu.:0.14600   3rd Qu.:0.33910   3rd Qu.:0.3829   3rd Qu.:0.16140     
##  Max.   :0.22260   Max.   :1.05800   Max.   :1.2520   Max.   :0.29100     
##  symmetry_worst   fractal_dimension_worst
##  Min.   :0.1565   Min.   :0.05504        
##  1st Qu.:0.2504   1st Qu.:0.07146        
##  Median :0.2822   Median :0.08004        
##  Mean   :0.2901   Mean   :0.08395        
##  3rd Qu.:0.3179   3rd Qu.:0.09208        
##  Max.   :0.6638   Max.   :0.20750
# 80/20 split of wbc, sampled within each diagnosis class so that train and
# test keep a similar B/M mix. The seed makes the split reproducible.
set.seed(1)
train.index <- createDataPartition(y = wbc$diagnosis, p = 0.8, list = FALSE)
train <- wbc[train.index, ]
test  <- wbc[-train.index, ]

summary(train[["diagnosis"]])
##   B   M 
## 286 170
summary(test[["diagnosis"]])
##  B  M 
## 71 42

Creating a Decision Tree

# Fit a classification tree on the training split. cp = 0 means no split is
# rejected for improving the fit too little; growth is bounded only by the
# node-size limits (minsplit / minbucket) and by maxdepth = 6.
tree_mod <- rpart(
  formula = diagnosis ~ .,
  data = train,
  method = "class",
  control = rpart.control(
    minsplit = 4,
    minbucket = 2,
    cp = 0,
    maxdepth = 6
  )
)

print(x = tree_mod)
## n= 456 
## 
## node), split, n, loss, yval, (yprob)
##       * denotes terminal node
## 
##  1) root 456 170 B (0.62719298 0.37280702)  
##    2) concave.points_worst< 0.1417 299  24 B (0.91973244 0.08026756)  
##      4) radius_worst< 17.72 281   8 B (0.97153025 0.02846975)  
##        8) area_se< 38.605 266   3 B (0.98872180 0.01127820)  
##         16) concave.points_worst< 0.13235 253   0 B (1.00000000 0.00000000) *
##         17) concave.points_worst>=0.13235 13   3 B (0.76923077 0.23076923)  
##           34) texture_mean< 21.54 10   0 B (1.00000000 0.00000000) *
##           35) texture_mean>=21.54 3   0 M (0.00000000 1.00000000) *
##        9) area_se>=38.605 15   5 B (0.66666667 0.33333333)  
##         18) compactness_se>=0.014885 11   1 B (0.90909091 0.09090909) *
##         19) compactness_se< 0.014885 4   0 M (0.00000000 1.00000000) *
##      5) radius_worst>=17.72 18   2 M (0.11111111 0.88888889)  
##       10) concavity_worst< 0.1981 4   2 B (0.50000000 0.50000000)  
##         20) texture_mean< 21.26 2   0 B (1.00000000 0.00000000) *
##         21) texture_mean>=21.26 2   0 M (0.00000000 1.00000000) *
##       11) concavity_worst>=0.1981 14   0 M (0.00000000 1.00000000) *
##    3) concave.points_worst>=0.1417 157  11 M (0.07006369 0.92993631)  
##      6) area_worst< 729.55 16   7 B (0.56250000 0.43750000)  
##       12) smoothness_mean< 0.1083 9   0 B (1.00000000 0.00000000) *
##       13) smoothness_mean>=0.1083 7   0 M (0.00000000 1.00000000) *
##      7) area_worst>=729.55 141   2 M (0.01418440 0.98581560) *

Creating a Decision Tree

# Draw the fitted tree; extra = 1 adds per-class observation counts to each node.
rpart.plot(tree_mod, cex = 0.8, extra = 1)

Creating a Decision Tree

# Predicted class labels for both splits, then the fraction predicted
# correctly on each (training vs. held-out accuracy).
train_pred <- predict(object = tree_mod, newdata = train, type = "class")
test_pred <- predict(object = tree_mod, newdata = test, type = "class")

cat("Training Accuracy: ", mean(train_pred == train$diagnosis), "\n",
    "Test Set Accuracy: ", mean(test_pred == test$diagnosis),
    sep = "")
## Training Accuracy: 0.9934211
## Test Set Accuracy: 0.9115044

Cross-Validation for Model Evaluation

# 10-fold cross-validated accuracy on the full wbc data for the same tree
# settings used above, with maxdepth held at 6 (a one-row tuning grid).
set.seed(1)

bc_tree_cv <- train(
  diagnosis ~ .,
  data = wbc,
  method = "rpart2",
  trControl = trainControl(method = "cv", number = 10),
  tuneGrid = expand.grid(maxdepth = c(6)),
  control = rpart.control(minsplit = 4, minbucket = 2, cp = 0)
)

bc_tree_cv
## CART 
## 
## 569 samples
##  30 predictor
##   2 classes: 'B', 'M' 
## 
## No pre-processing
## Resampling: Cross-Validated (10 fold) 
## Summary of sample sizes: 513, 512, 511, 512, 511, 513, ... 
## Resampling results:
## 
##   Accuracy   Kappa    
##   0.9211326  0.8310163
## 
## Tuning parameter 'maxdepth' was held constant at a value of 6

Cross-Validation for Tuning Maximum Depth

# Tune maxdepth over 1..20 with 10-fold CV on the full wbc data. The other
# rpart settings stay fixed; cp = 0 leaves depth and node-size limits as the
# only brakes on tree growth.
set.seed(1)

bc_tree_cv_md <- train(diagnosis ~ ., wbc, method = "rpart2",
                       trControl = trainControl(method = "cv", number = 10),
                       tuneGrid = expand.grid(maxdepth = 1:20),
                       control = rpart.control(minsplit = 4,
                                               minbucket = 2,
                                               cp = 0))

# Row of the CV results with the highest mean accuracy (first one on ties).
best_ix <- which.max(bc_tree_cv_md$results$Accuracy)
bc_tree_cv_md$results[best_ix, ]
##   maxdepth  Accuracy     Kappa AccuracySD    KappaSD
## 5        5 0.9263644 0.8418275 0.04095775 0.08643895

Cross-Validation for Tuning Maximum Depth

# Cross-validated accuracy as a function of maxdepth.
plot(x = bc_tree_cv_md)

Cross-Validation for Tuning Maximum Depth

# Tree refit by caret on all 569 rows at the selected maxdepth.
print(x = bc_tree_cv_md[["finalModel"]])
## n= 569 
## 
## node), split, n, loss, yval, (yprob)
##       * denotes terminal node
## 
##  1) root 569 212 B (0.627416520 0.372583480)  
##    2) radius_worst< 16.795 379  33 B (0.912928760 0.087071240)  
##      4) concave.points_worst< 0.1358 333   5 B (0.984984985 0.015015015)  
##        8) radius_se< 0.6431 328   3 B (0.990853659 0.009146341) *
##        9) radius_se>=0.6431 5   2 B (0.600000000 0.400000000)  
##         18) compactness_mean>=0.062575 3   0 B (1.000000000 0.000000000) *
##         19) compactness_mean< 0.062575 2   0 M (0.000000000 1.000000000) *
##      5) concave.points_worst>=0.1358 46  18 M (0.391304348 0.608695652)  
##       10) texture_worst< 25.67 19   4 B (0.789473684 0.210526316)  
##         20) area_worst< 810.3 15   1 B (0.933333333 0.066666667) *
##         21) area_worst>=810.3 4   1 M (0.250000000 0.750000000) *
##       11) texture_worst>=25.67 27   3 M (0.111111111 0.888888889)  
##         22) concavity_mean< 0.09679 6   3 B (0.500000000 0.500000000)  
##           44) texture_mean< 19.435 3   0 B (1.000000000 0.000000000) *
##           45) texture_mean>=19.435 3   0 M (0.000000000 1.000000000) *
##         23) concavity_mean>=0.09679 21   0 M (0.000000000 1.000000000) *
##    3) radius_worst>=16.795 190  11 M (0.057894737 0.942105263)  
##      6) texture_mean< 16.11 17   8 B (0.529411765 0.470588235)  
##       12) concave.points_mean< 0.06626 9   0 B (1.000000000 0.000000000) *
##       13) concave.points_mean>=0.06626 8   0 M (0.000000000 1.000000000) *
##      7) texture_mean>=16.11 173   2 M (0.011560694 0.988439306)  
##       14) concavity_worst< 0.1907 5   2 M (0.400000000 0.600000000)  
##         28) texture_mean< 21.26 2   0 B (1.000000000 0.000000000) *
##         29) texture_mean>=21.26 3   0 M (0.000000000 1.000000000) *
##       15) concavity_worst>=0.1907 168   0 M (0.000000000 1.000000000) *

Cross-Validation for Tuning Maximum Depth

# Plot the refit tree chosen by CV (extra = 1: per-class counts in each node).
rpart.plot(bc_tree_cv_md[["finalModel"]], cex = 0.8, extra = 1)

Cross-Validation for Tuning Complexity Parameter

# Tune the complexity parameter cp over 0, 0.001, ..., 0.1 with 10-fold CV.
# maxdepth = 30 is the largest depth rpart supports, so depth is left at its
# maximum and cp does the pruning.
set.seed(1)

bc_tree_cv_cp <- train(diagnosis ~ ., wbc, method = "rpart",
                       trControl = trainControl(method = "cv", number = 10),
                       tuneGrid = expand.grid(cp = seq(0, 0.1, 0.001)),
                       control = rpart.control(minsplit = 4,
                                               minbucket = 2,
                                               maxdepth = 30))

# Row of the CV results with the highest mean accuracy (first one on ties).
best_ix <- which.max(bc_tree_cv_cp$results$Accuracy)
bc_tree_cv_cp$results[best_ix, ]
##       cp  Accuracy     Kappa AccuracySD    KappaSD
## 12 0.011 0.9264584 0.8413244 0.03979754 0.08602599

Cross-Validation for Tuning Complexity Parameter

# CV accuracy versus cp, drawn without point symbols (pch = "").
plot(x = bc_tree_cv_cp, pch = "")

Example 2: Iris Dataset

# Read the iris measurements from file (this masks datasets::iris in the
# global environment). stringsAsFactors = TRUE keeps `species` a factor on
# R >= 4.0, which the per-species summary below and the classifier expect.
iris <- read.table("data/iris.txt", sep = "\t", header = TRUE,
                   stringsAsFactors = TRUE)
summary(iris)
##   sepal_length    sepal_width     petal_length    petal_width   
##  Min.   :4.300   Min.   :2.000   Min.   :1.000   Min.   :0.100  
##  1st Qu.:5.100   1st Qu.:2.800   1st Qu.:1.600   1st Qu.:0.300  
##  Median :5.800   Median :3.000   Median :4.350   Median :1.300  
##  Mean   :5.843   Mean   :3.057   Mean   :3.758   Mean   :1.199  
##  3rd Qu.:6.400   3rd Qu.:3.300   3rd Qu.:5.100   3rd Qu.:1.800  
##  Max.   :7.900   Max.   :4.400   Max.   :6.900   Max.   :2.500  
##        species  
##  setosa    :50  
##  versicolor:50  
##  virginica :50  
##                 
##                 
## 

Example 2: Iris Dataset

# Two scatter plots coloured by species, placed side by side:
# sepal width vs. length (left) and petal width vs. length (right).
p1 <- ggplot(data = iris,
             mapping = aes(x = sepal_length, y = sepal_width, col = species)) +
  geom_point(alpha = 0.8)

p2 <- ggplot(data = iris,
             mapping = aes(x = petal_length, y = petal_width, col = species)) +
  geom_point(alpha = 0.8)

grid.arrange(p1, p2, ncol = 2)

Cross-Validation for Tuning Maximum Depth

# Tune maxdepth over 1..20 with 10-fold CV. minsplit = minbucket = 1 and
# cp = 0 let the tree grow until depth is the only constraint.
set.seed(1)

iris_tree_cv <- train(species ~ ., iris, method = "rpart2",
                      trControl = trainControl(method = "cv", number = 10),
                      tuneGrid = expand.grid(maxdepth = 1:20),
                      control = rpart.control(minsplit = 1,
                                              minbucket = 1,
                                              cp = 0))

# Row of the CV results with the highest mean accuracy (first one on ties).
best_ix <- which.max(iris_tree_cv$results$Accuracy)
iris_tree_cv$results[best_ix, ]
##   maxdepth Accuracy Kappa AccuracySD    KappaSD
## 3        3     0.96  0.94 0.04661373 0.06992059

Cross-Validation for Tuning Maximum Depth

# Cross-validated accuracy as a function of maxdepth for iris.
plot(x = iris_tree_cv)

Cross-Validation for Tuning Maximum Depth

# Plot the iris tree refit at the CV-selected depth.
rpart.plot(iris_tree_cv[["finalModel"]], cex = 1, extra = 1)

Example 3: Wine Quality Dataset

wine <- read.table("data/winequality-white.csv", sep = ",", header = TRUE)

# Bin the integer quality score into three grades using left-closed intervals:
#   quality < 5       -> "L"
#   5 <= quality < 7  -> "M"
#   quality >= 7      -> "H"
# cut() returns a factor whose levels are L, M, H in that order.
wine$grade <- cut(wine$quality,
                  breaks = c(-Inf, 5, 7, Inf),
                  labels = c("L", "M", "H"),
                  right = FALSE)
# The raw score is dropped so it cannot leak into `grade ~ .` models.
wine$quality <- NULL

summary(wine)
##  fixed.acidity    volatile.acidity  citric.acid     residual.sugar  
##  Min.   : 3.800   Min.   :0.0800   Min.   :0.0000   Min.   : 0.600  
##  1st Qu.: 6.300   1st Qu.:0.2100   1st Qu.:0.2700   1st Qu.: 1.700  
##  Median : 6.800   Median :0.2600   Median :0.3200   Median : 5.200  
##  Mean   : 6.855   Mean   :0.2782   Mean   :0.3342   Mean   : 6.391  
##  3rd Qu.: 7.300   3rd Qu.:0.3200   3rd Qu.:0.3900   3rd Qu.: 9.900  
##  Max.   :14.200   Max.   :1.1000   Max.   :1.6600   Max.   :65.800  
##    chlorides       free.sulfur.dioxide total.sulfur.dioxide
##  Min.   :0.00900   Min.   :  2.00      Min.   :  9.0       
##  1st Qu.:0.03600   1st Qu.: 23.00      1st Qu.:108.0       
##  Median :0.04300   Median : 34.00      Median :134.0       
##  Mean   :0.04577   Mean   : 35.31      Mean   :138.4       
##  3rd Qu.:0.05000   3rd Qu.: 46.00      3rd Qu.:167.0       
##  Max.   :0.34600   Max.   :289.00      Max.   :440.0       
##     density             pH          sulphates         alcohol     
##  Min.   :0.9871   Min.   :2.720   Min.   :0.2200   Min.   : 8.00  
##  1st Qu.:0.9917   1st Qu.:3.090   1st Qu.:0.4100   1st Qu.: 9.50  
##  Median :0.9937   Median :3.180   Median :0.4700   Median :10.40  
##  Mean   :0.9940   Mean   :3.188   Mean   :0.4898   Mean   :10.51  
##  3rd Qu.:0.9961   3rd Qu.:3.280   3rd Qu.:0.5500   3rd Qu.:11.40  
##  Max.   :1.0390   Max.   :3.820   Max.   :1.0800   Max.   :14.20  
##  grade   
##  L: 183  
##  M:3655  
##  H:1060  
##          
##          
## 

Cross-Validation for Tuning Maximum Depth

# Tune maxdepth with 10-fold CV.
# rpart supports depths of at most 30: rpart.control() stops with
# "Maximum depth is 30". caret's "rpart2" method sets control$maxdepth
# directly, which bypasses that check, so the old grid of 50..300 produced
# unreliable trees. Search the whole supported range instead.
set.seed(1)

wine_tree_cv_md <- train(grade ~ ., wine, method = "rpart2",
                         trControl = trainControl(method = "cv", number = 10),
                         tuneGrid = expand.grid(maxdepth = 1:30),
                         control = rpart.control(minsplit = 16,
                                                 minbucket = 1,
                                                 cp = 0))

# Row of the CV results with the highest mean accuracy (first one on ties).
best_ix <- which.max(wine_tree_cv_md$results$Accuracy)
wine_tree_cv_md$results[best_ix, ]
# NOTE: the output previously recorded here came from the unsupported
# maxdepth = 50..300 grid; re-run this chunk to regenerate it.

Cross-Validation for Tuning Maximum Depth

# Cross-validated accuracy as a function of maxdepth for the wine grades.
plot(x = wine_tree_cv_md)

Cross-Validation for Tuning Complexity Parameter

# Tune the complexity parameter cp over 0, 0.001, ..., 0.05 with 10-fold CV.
# The depth cap (10) and node-size limits stay fixed.
set.seed(1)

wine_tree_cv_cp <- train(grade ~ ., wine, method = "rpart",
                         trControl = trainControl(method = "cv", number = 10),
                         tuneGrid = expand.grid(cp = seq(0, 0.05, 0.001)),
                         control = rpart.control(minsplit = 32,
                                                 minbucket = 4,
                                                 maxdepth = 10))

# Row of the CV results with the highest mean accuracy (first one on ties).
best_ix <- which.max(wine_tree_cv_cp$results$Accuracy)
wine_tree_cv_cp$results[best_ix, ]
##      cp  Accuracy     Kappa AccuracySD    KappaSD
## 5 0.004 0.7774509 0.3258256 0.02246495 0.06582335

Cross-Validation for Tuning Complexity Parameter

# Cross-validated accuracy as a function of cp for the wine grades.
plot(x = wine_tree_cv_cp)

Cross-Validation for Tuning Complexity Parameter

# Plot the wine tree refit at the CV-selected cp.
rpart.plot(wine_tree_cv_cp[["finalModel"]], cex = 0.8, extra = 1)

Example 4: Diamonds Dataset

# Read the diamonds data. stringsAsFactors = TRUE keeps cut/color/clarity as
# factors on R >= 4.0, matching the per-level counts in the summary below.
diamonds <- read.table("data/diamonds.txt", sep = "\t", header = TRUE,
                       stringsAsFactors = TRUE)
# Keep carat, the three categorical grades and the response. Selecting by name
# (rather than positions 1:4 and 7) does not depend on the file's column order.
diamonds <- diamonds[, c("carat", "cut", "color", "clarity", "price")]
summary(diamonds)
##      carat               cut        color        clarity     
##  Min.   :0.2000   Fair     : 1610   D: 6775   SI1    :13065  
##  1st Qu.:0.4000   Good     : 4906   E: 9797   VS2    :12258  
##  Median :0.7000   Ideal    :21551   F: 9542   SI2    : 9194  
##  Mean   :0.7979   Premium  :13791   G:11292   VS1    : 8171  
##  3rd Qu.:1.0400   Very Good:12082   H: 8304   VVS2   : 5066  
##  Max.   :5.0100                     I: 5422   VVS1   : 3655  
##                                     J: 2808   (Other): 2531  
##      price      
##  Min.   :  326  
##  1st Qu.:  950  
##  Median : 2401  
##  Mean   : 3933  
##  3rd Qu.: 5324  
##  Max.   :18823  
## 

Creating a Single Tree Model

# Single regression tree for price. The seed fixes the random fold assignment
# of rpart's internal cross-validation (xval, 10 folds by default), which
# fills the tree's cp table.
set.seed(1)

dmd_tree_mod <- rpart(
  formula = price ~ .,
  data = diamonds,
  method = "anova",
  control = rpart.control(
    minsplit = 32,
    minbucket = 16,
    cp = 0.001,
    maxdepth = 6
  )
)

rpart.plot(dmd_tree_mod, cex = 0.8, extra = 1)

Creating a Single Tree Model

# In-sample R^2 = 1 - SSE / SST, computed on the same rows the tree was fit to
# (so it is an optimistic estimate of out-of-sample fit).
pred <- predict(dmd_tree_mod, newdata = diamonds)
SSE <- sum((diamonds[["price"]] - pred)^2)
SST <- sum((diamonds[["price"]] - mean(diamonds[["price"]]))^2)
r2 <- 1 - SSE / SST
r2
## [1] 0.9453158

Cross-Validation for Tuning Maximum Depth

# Tune maxdepth with 10-fold CV, selecting on R^2.
# rpart supports depths of at most 30: rpart.control() stops with
# "Maximum depth is 30". caret's "rpart2" method sets control$maxdepth
# directly, which bypasses that check, so the old grid of 50..600 produced
# unreliable trees. Keep the grid within the supported range; a coarse step
# keeps the 10-fold search cheap on this large data set.
set.seed(1)

dmd_tree_cv_md <- train(price ~ ., diamonds, method = "rpart2",
                        metric = "Rsquared",
                        trControl = trainControl(method = "cv", number = 10),
                        tuneGrid = expand.grid(maxdepth = seq(5, 30, by = 5)),
                        control = rpart.control(minsplit = 32,
                                                minbucket = 16,
                                                cp = 0.00001))

# Row of the CV results with the highest mean R^2 (first one on ties).
best_ix <- which.max(dmd_tree_cv_md$results$Rsquared)
dmd_tree_cv_md$results[best_ix, ]
# NOTE: the output previously recorded here came from the unsupported
# maxdepth = 50..600 grid; re-run this chunk to regenerate it.

Cross-Validation for Tuning Maximum Depth

# Cross-validated R^2 as a function of maxdepth for the diamonds price tree.
plot(x = dmd_tree_cv_md)