Objective

  • What are decision trees?
  • Using the recursive partitioning algorithm to predict animal classes
  • An important weakness of decision trees

Showcase

  • how to build a decision tree with rpart, and how to tune its hyperparameters.
  • task: as part of public engagement at a wildlife sanctuary, we create an interactive game for children, to teach them about the different animal classes.
    • the game asks a child to think of any animal in the sanctuary, and then asks them questions about the physical characteristics of that animal.
    • given the child’s responses, the model should tell the child what class their animal belongs to (mammal, bird, reptile, etc.).

Zoo dataset

# Loading the packages used in this section
library(mlr)          # tasks, learners, tuning, and resampling
library(tidyverse)    # as_tibble() and mutate_if()
library(parallel)     # detectCores()
library(parallelMap)  # parallelStartSocket() and parallelStop()
library(rpart.plot)   # rpart.plot()

# Loading and exploring the zoo dataset
data(Zoo, package = "mlbench")
zooTib <- as_tibble(Zoo)
zooTib
## # A tibble: 101 x 17
##    hair  feathers eggs  milk  airborne aquatic predator toothed backbone
##    <lgl> <lgl>    <lgl> <lgl> <lgl>    <lgl>   <lgl>    <lgl>   <lgl>   
##  1 TRUE  FALSE    FALSE TRUE  FALSE    FALSE   TRUE     TRUE    TRUE    
##  2 TRUE  FALSE    FALSE TRUE  FALSE    FALSE   FALSE    TRUE    TRUE    
##  3 FALSE FALSE    TRUE  FALSE FALSE    TRUE    TRUE     TRUE    TRUE    
##  4 TRUE  FALSE    FALSE TRUE  FALSE    FALSE   TRUE     TRUE    TRUE    
##  5 TRUE  FALSE    FALSE TRUE  FALSE    FALSE   TRUE     TRUE    TRUE    
##  6 TRUE  FALSE    FALSE TRUE  FALSE    FALSE   FALSE    TRUE    TRUE    
##  7 TRUE  FALSE    FALSE TRUE  FALSE    FALSE   FALSE    TRUE    TRUE    
##  8 FALSE FALSE    TRUE  FALSE FALSE    TRUE    FALSE    TRUE    TRUE    
##  9 FALSE FALSE    TRUE  FALSE FALSE    TRUE    TRUE     TRUE    TRUE    
## 10 TRUE  FALSE    FALSE TRUE  FALSE    FALSE   FALSE    TRUE    TRUE    
## # ... with 91 more rows, and 8 more variables: breathes <lgl>, venomous <lgl>,
## #   fins <lgl>, legs <int>, tail <lgl>, domestic <lgl>, catsize <lgl>,
## #   type <fct>
# Converting the logical variables to factors
# (mutate_all(zooTib, as.factor) would work too, but it would also turn the integer legs column into a factor)
zooTib <- mutate_if(zooTib, is.logical, as.factor)

# Show first few lines of dataset
zooTib
## # A tibble: 101 x 17
##    hair  feathers eggs  milk  airborne aquatic predator toothed backbone
##    <fct> <fct>    <fct> <fct> <fct>    <fct>   <fct>    <fct>   <fct>   
##  1 TRUE  FALSE    FALSE TRUE  FALSE    FALSE   TRUE     TRUE    TRUE    
##  2 TRUE  FALSE    FALSE TRUE  FALSE    FALSE   FALSE    TRUE    TRUE    
##  3 FALSE FALSE    TRUE  FALSE FALSE    TRUE    TRUE     TRUE    TRUE    
##  4 TRUE  FALSE    FALSE TRUE  FALSE    FALSE   TRUE     TRUE    TRUE    
##  5 TRUE  FALSE    FALSE TRUE  FALSE    FALSE   TRUE     TRUE    TRUE    
##  6 TRUE  FALSE    FALSE TRUE  FALSE    FALSE   FALSE    TRUE    TRUE    
##  7 TRUE  FALSE    FALSE TRUE  FALSE    FALSE   FALSE    TRUE    TRUE    
##  8 FALSE FALSE    TRUE  FALSE FALSE    TRUE    FALSE    TRUE    TRUE    
##  9 FALSE FALSE    TRUE  FALSE FALSE    TRUE    TRUE     TRUE    TRUE    
## 10 TRUE  FALSE    FALSE TRUE  FALSE    FALSE   FALSE    TRUE    TRUE    
## # ... with 91 more rows, and 8 more variables: breathes <fct>, venomous <fct>,
## #   fins <fct>, legs <int>, tail <fct>, domestic <fct>, catsize <fct>,
## #   type <fct>
# List all variable names
names(zooTib)
##  [1] "hair"     "feathers" "eggs"     "milk"     "airborne" "aquatic" 
##  [7] "predator" "toothed"  "backbone" "breathes" "venomous" "fins"    
## [13] "legs"     "tail"     "domestic" "catsize"  "type"

Training the decision tree model

# Creating the task and learner
zooTask <- makeClassifTask(data = zooTib, target = "type")
tree <- makeLearner("classif.rpart")

Hyperparameters of the rpart algorithm

  • Objectives:
    • show you what hyperparameters need to be tuned for the rpart algorithm,
    • what they do, and
    • why we need to tune them in order to get the best-performing tree possible
  • Decision tree algorithms are described as greedy, i.e. at each node they search for the split that will perform best at that node, rather than the one that would produce the best result globally.

  • Cons:
    • the algorithm isn’t guaranteed to learn a globally optimal model
    • if left unchecked, the tree will continue to grow deeper until all the leaves are pure (i.e. contain cases of only one class)
    • for large datasets, growing extremely deep trees becomes computationally expensive
  • Solutions:
    • grow a full tree and then prune it
    • employ stopping criteria
  • The stopping criteria we can apply at each stage of the tree-building process are:
    • minimum number of cases in a node before splitting: minsplit. If a node has fewer than the specified number, the node will not be split further.
    • maximum depth of the tree: maxdepth. If a node is already at this depth, it will not be split further.
    • minimum improvement in performance for a split: the complexity parameter (cp in rpart) is calculated for each level of depth of the tree. If the cp value of a depth is less than the chosen threshold, the nodes at that level will not be split further. In other words, if adding another layer to the tree doesn’t improve the performance of the model by at least cp, don’t split the nodes.
    • minimum number of cases in a leaf: minbucket. If splitting a node would result in leaves containing fewer cases than minbucket, the node will not be split.
Four hyperparameters in rpart for the stopping criteria

These four criteria combined can make for very stringent and complicated stopping criteria. Because their values cannot be learned directly from the data, they are hyperparameters! \(\rightarrow\) What do we do with hyperparameters? Tune them! \(\rightarrow\) So when we build a model with rpart, we will tune these stopping criteria to find the values that give us the best-performing model.
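For reference, these four stopping criteria correspond directly to arguments of rpart.control() when rpart is called on its own, outside mlr. A minimal sketch with arbitrary example values (not the tuned values we search for below; untunedTree is just an illustrative object name):

# Setting the four stopping criteria by hand with rpart (illustration only;
# the values here are arbitrary, not tuned)
library(rpart)

untunedTree <- rpart(type ~ ., data = zooTib,
                     control = rpart.control(minsplit = 10,
                                             minbucket = 5,
                                             cp = 0.05,
                                             maxdepth = 5))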

Performing hyperparameter tuning for minsplit, minbucket, cp, and maxdepth

# Printing available rpart hyperparameters
getParamSet(tree)
##                    Type len  Def   Constr Req Tunable Trafo
## minsplit        integer   -   20 1 to Inf   -    TRUE     -
## minbucket       integer   -    - 1 to Inf   -    TRUE     -
## cp              numeric   - 0.01   0 to 1   -    TRUE     -
## maxcompete      integer   -    4 0 to Inf   -    TRUE     -
## maxsurrogate    integer   -    5 0 to Inf   -    TRUE     -
## usesurrogate   discrete   -    2    0,1,2   -    TRUE     -
## surrogatestyle discrete   -    0      0,1   -    TRUE     -
## maxdepth        integer   -   30  1 to 30   -    TRUE     -
## xval            integer   -   10 0 to Inf   -   FALSE     -
## parms           untyped   -    -        -   -    TRUE     -
# Defining the hyperparameter space for tuning
treeParamSpace <- makeParamSet(
  makeIntegerParam("minsplit", lower = 5, upper = 20),
  makeIntegerParam("minbucket", lower = 3, upper = 10),
  makeNumericParam("cp", lower = 0.01, upper = 0.1),
  makeIntegerParam("maxdepth", lower = 3, upper = 10))
# Defining the random search
randSearch <- makeTuneControlRandom(maxit = 200)
cvForTuning <- makeResampleDesc("CV", iters = 5)
# Performing hyperparameter tuning
parallelStartSocket(cpus = detectCores())

tunedTreePars <- tuneParams(tree, task = zooTask,
                            resampling = cvForTuning,
                            par.set = treeParamSpace,
                            control = randSearch)

parallelStop()

tunedTreePars
## Tune result:
## Op. pars: minsplit=8; minbucket=6; cp=0.0536; maxdepth=10
## mmce.test.mean=0.0790476

Note: with hyperparameter tuning, the best hyperparameter combination gave us a mean misclassification error (mmce) of about 0.079. (Because the search is random and the cross-validation folds are sampled, the exact winning values and mmce will vary a little from run to run.)
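If you want to inspect the winning combination programmatically, the tune result stores the best hyperparameter values and their performance in its x and y components (the same $x element is used below to train the final model):

# Extracting the best hyperparameter values and their cross-validated performance
tunedTreePars$x  # list with the tuned minsplit, minbucket, cp, and maxdepth
tunedTreePars$y  # mean misclassification error achieved by that combination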

Training the model with the tuned hyperparameters

# Training the final tuned model
tunedTree <- setHyperPars(tree, par.vals = tunedTreePars$x)
tunedTreeModel <- train(tunedTree, zooTask)
# Plotting the decision tree
treeModelData <- getLearnerModel(tunedTreeModel)
rpart.plot(treeModelData, roundint = FALSE,
           box.palette = "BuBn",
           type = 5)

# Exploring the model
printcp(treeModelData, digits = 3)
## 
## Classification tree:
## rpart::rpart(formula = f, data = d, xval = 0, minsplit = 8, minbucket = 6, 
##     cp = 0.053641479217913, maxdepth = 10)
## 
## Variables actually used in tree construction:
## [1] airborne backbone feathers fins     milk    
## 
## Root node error: 60/101 = 0.594
## 
## n= 101 
## 
##       CP nsplit rel error
## 1 0.3333      0     1.000
## 2 0.2167      1     0.667
## 3 0.1667      2     0.450
## 4 0.0917      3     0.283
## 5 0.0536      5     0.100
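The cp table printed by printcp() is what you would inspect if you took the other route to controlling tree size mentioned earlier: growing a full tree and pruning it back. A minimal sketch of pruning with rpart::prune() (the cp value here is an arbitrary example; our tree is already constrained by the tuned stopping criteria, so this is for illustration only):

# Illustration only: pruning an rpart tree back to a chosen complexity parameter
prunedTree <- rpart::prune(treeModelData, cp = 0.1)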

Cross-validating our decision tree model

# Cross-validating the model-building process
outer <- makeResampleDesc("CV", iters = 5)
treeWrapper <- makeTuneWrapper("classif.rpart", resampling = cvForTuning,
                               par.set = treeParamSpace,
                               control = randSearch)
parallelStartSocket(cpus = detectCores())
cvWithTuning <- resample(treeWrapper, zooTask, resampling = outer)
parallelStop()
# Extracting the cross-validation result
cvWithTuning
## Resample Result
## Task: zooTib
## Learner: classif.rpart.tuned
## Aggr perf: mmce.test.mean=0.0980952
## Runtime: 36.4924

The cross-validated estimate of model performance gives us an mmce of about 0.098, slightly worse than the estimate from tuning alone. This is expected: nested cross-validation gives a less optimistic estimate of how the whole model-building process will perform on new data.
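To close the loop on the sanctuary game, the trained model can be asked to classify new animals with predict(). A minimal sketch, using the first few rows of zooTib as stand-ins for a child’s answers (in the real game, the responses would be assembled into a one-row tibble with the same 16 predictor columns):

# Predicting the class of "new" animals with the tuned model
newAnimals <- zooTib[1:3, ]
predict(tunedTreeModel, newdata = newAnimals)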

Strengths and weaknesses of tree-based algorithms

  • The strengths of the tree-based algorithm are:
    • The intuition behind tree-building is quite simple, and each individual tree is very interpretable
    • They can handle categorical and continuous predictor variables
    • They make no assumptions about the distribution of the predictor variables
    • They can handle missing values in sensible ways
    • They can handle continuous variables on different scales
  • The weaknesses of tree-based algorithms are:
    • Individual trees are very susceptible to overfitting, so much so that single decision trees are rarely used on their own (in practice they are usually combined into ensembles such as random forests)

To sum up

  • The rpart algorithm is a supervised learner for both classification and regression problems
  • Tree-based learners start with all the cases in the root node, and find sequential binary splits until cases find themselves in leaf nodes
  • Tree construction is a greedy process, and can be limited by setting stopping criteria (such as the minimum number of cases allowed in a node before it can be split)
  • The Gini gain is a criterion used to decide which predictor variable will give the best split at a particular node (it is rpart’s default splitting criterion; see the sketch below)
  • Decision trees have a tendency to overfit the training set
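
The Gini gain mentioned in the summary is simple enough to compute by hand. A minimal sketch (illustration only, not part of the mlr workflow above), assuming zooTib from earlier is still in the workspace:

# Gini impurity of a set of class labels: 1 minus the sum of squared class proportions
gini <- function(classes) {
  p <- table(classes) / length(classes)
  1 - sum(p^2)
}

# Gini gain of a candidate split: impurity of the parent node minus the
# size-weighted impurities of the two child nodes
giniGain <- function(classes, goesLeft) {
  n <- length(classes)
  gini(classes) -
    sum(goesLeft)  / n * gini(classes[goesLeft]) -
    sum(!goesLeft) / n * gini(classes[!goesLeft])
}

# Example: how much does splitting the zoo animals on "produces milk" reduce impurity?
giniGain(zooTib$type, zooTib$milk == "TRUE")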