Decision Trees in R

Hello, and welcome to Decision Trees in R. Here, we will go over what decision trees are, what they are used for, and how to utilize them in the R environment.

The Classification Problem

Suppose we are lost in a forest, and are very hungry. Unable to go on without eating something first, we take a look around, only to find nothing immediately edible – just mushrooms.

We're starving, so anything looks great to us, but eating one of those carelessly might result in us getting poisoned. To know whether we can eat a mushroom, we need to classify it based on our knowledge of its features. In other words, we have a classification problem on our hands.

This sort of problem is not simple to solve effectively – there are many variables involved in correctly classifying something. Many different kinds of mathematical models have been created to aid us in classification tasks. One of these models is the Decision Tree model.

The Decision Tree is a predictive model based on the analysis of a set of data points that describe the type of object we want to classify. In our example, it might be a set of observations of a mushroom’s cap type, its color, odor, shape of its stalk, etc. These descriptions of our object are called features, and are very important in many different kinds of machine learning algorithms, including Decision Trees. The classification we want out of these features is set aside as a “result” of sorts.
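For instance, a few labeled observations might be laid out like this (a made-up illustration of the idea; we load the real dataset further below):

# Feature columns describe each mushroom; the Class column holds the "result" we want to predict
data.frame(
  cap.color = c("brown", "white", "yellow"),
  odor      = c("pungent", "almond", "anise"),
  Class     = c("poisonous", "edible", "edible")
)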

The Process of Constructing the Tree

Now, the question is: how is the tree actually constructed, and how are the probabilities behind its predictions calculated? Before we get there, let's quickly review some terminology that applies to decision trees. If some of these definitions don't make sense right now, don't worry, as we'll be going through some examples that will clarify them.

Node: A point in the tree where the dataset gets tested on a certain attribute. The goal of the node is to split the dataset on that attribute.

Leaf node: A terminal node in the tree; it predicts the outcome.

Root node: The node at the top of the tree; it contains the entire dataset for that tree.

How to Find the Best Feature to Split On at Each Node

Entropy: Entropy is calculated for each node. It is the amount of information disorder, or randomness, in the data at that node. When building a decision tree, we look for the splits that produce nodes with the smallest entropy. Entropy measures how homogeneous the samples in a node are: if the samples are completely homogeneous, the entropy is zero, and if they are equally divided between the classes, the entropy is one. In our case, if all mushrooms in a node are either poisonous or edible, the entropy is zero, but if half of them are poisonous and the other half edible, the entropy is one. For example, to calculate the entropy of our target class:

Entropy = -p(edible) * log2(p(edible)) - p(poisonous) * log2(p(poisonous))
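To make this concrete, here is a minimal sketch of the calculation in R (the entropy helper is our own illustration, not part of any package):

# Shannon entropy (base-2 logarithm) of a vector of class labels
entropy <- function(labels) {
  p <- table(labels) / length(labels)  # proportion of each class
  -sum(p * log2(p))
}
entropy(c("e", "e", "e", "e"))  # completely homogeneous: entropy 0
entropy(c("e", "e", "p", "p"))  # half edible, half poisonous: entropy 1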

Information gain: This is the information that increases the level of certainty after splitting on a particular predictor (attribute). Formally, the information gain of an attribute is the entropy of the node before the split minus the weighted average of the entropies of the child nodes produced by the split. We can think of information gain and entropy as opposites: as entropy, or the amount of randomness, decreases, the information gain, or amount of certainty, increases, and vice-versa. So, constructing a decision tree is all about finding the predictors (attributes) that return the highest information gain.
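Continuing the sketch, a hypothetical information_gain helper built on the entropy function above might look like this:

# Information gain of splitting `labels` on `attribute`: entropy before
# the split minus the weighted entropy of the resulting child nodes
information_gain <- function(labels, attribute) {
  children <- split(labels, attribute)  # one child node per attribute value
  weights <- sapply(children, length) / length(labels)
  entropy(labels) - sum(weights * sapply(children, entropy))
}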

Algorithm:

  1. Calculate the entropy of the target field (the class label) for the whole dataset.
  2. For each attribute:
    1. Split the dataset on the attribute.
    2. Calculate the entropy of the target field on the split dataset, using the attribute's values.
    3. Calculate the information gain of the attribute.
  3. Select the attribute that has the largest information gain.
  4. Branch the tree using the selected attribute.
  5. Stop if the node has an entropy of 0; otherwise, jump to step 2.

As a small illustration, suppose the dataset at the root node consists of 4 colors: red, blue, green, and yellow, with 4 dots of each color, totaling 16 dots. Using a histogram to look at the chance that an out-of-sample data point would be a certain color, every color is equally likely, so this node is as disordered as it can be and its entropy is at its maximum.
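As a quick sanity check, we can run this toy example through the helper functions sketched above:

# 16 dots, 4 of each of 4 colors: a maximally mixed node
dots <- rep(c("red", "blue", "green", "yellow"), each = 4)
entropy(dots)  # log2(4) = 2 bits, the maximum for four equally likely classes

# A (made-up) attribute that separates the colors perfectly recovers all of that entropy
shape <- rep(c("square", "circle", "star", "cross"), each = 4)
information_gain(dots, shape)  # 2: every child node is perfectly pure

With the intuition in place, let's get back to our mushroom problem. First, we download the dataset and read it into R: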
download.file("https://ibm.box.com/shared/static/dpdh09s70abyiwxguehqvcq3dn0m7wve.data", "mushroom.data")
# Read the letter-coded columns in as factors (the default in R < 4.0; explicit
# here so that the levels() relabeling further below works as intended)
mushrooms <- read.csv("mushroom.data", header = FALSE, stringsAsFactors = TRUE)
head(mushrooms)
##   V1 V2 V3 V4 V5 V6 V7 V8 V9 V10 V11 V12 V13 V14 V15 V16 V17 V18 V19 V20 V21
## 1  p  x  s  n  t  p  f  c  n   k   e   e   s   s   w   w   p   w   o   p   k
## 2  e  x  s  y  t  a  f  c  b   k   e   c   s   s   w   w   p   w   o   p   n
## 3  e  b  s  w  t  l  f  c  b   n   e   c   s   s   w   w   p   w   o   p   n
## 4  p  x  y  w  t  p  f  c  n   n   e   e   s   s   w   w   p   w   o   p   k
## 5  e  x  s  g  f  n  f  w  b   k   t   e   s   s   w   w   p   w   o   e   n
## 6  e  x  y  y  t  a  f  c  b   n   e   c   s   s   w   w   p   w   o   p   k
##   V22 V23
## 1   s   u
## 2   n   g
## 3   n   m
## 4   s   u
## 5   a   g
## 6   n   g
colnames(mushrooms) <- c("Class","cap.shape","cap.surface","cap.color","bruises","odor","gill.attachment","gill.spacing",
                         "gill.size","gill.color","stalk.shape","stalk.root","stalk.surface.above.ring",
                         "stalk.surface.below.ring","stalk.color.above.ring","stalk.color.below.ring","veil.type","veil.color",
                         "ring.number","ring.type","print","population","habitat")
head(mushrooms)
##   Class cap.shape cap.surface cap.color bruises odor gill.attachment
## 1     p         x           s         n       t    p               f
## 2     e         x           s         y       t    a               f
## 3     e         b           s         w       t    l               f
## 4     p         x           y         w       t    p               f
## 5     e         x           s         g       f    n               f
## 6     e         x           y         y       t    a               f
##   gill.spacing gill.size gill.color stalk.shape stalk.root
## 1            c         n          k           e          e
## 2            c         b          k           e          c
## 3            c         b          n           e          c
## 4            c         n          n           e          e
## 5            w         b          k           t          e
## 6            c         b          n           e          c
##   stalk.surface.above.ring stalk.surface.below.ring stalk.color.above.ring
## 1                        s                        s                      w
## 2                        s                        s                      w
## 3                        s                        s                      w
## 4                        s                        s                      w
## 5                        s                        s                      w
## 6                        s                        s                      w
##   stalk.color.below.ring veil.type veil.color ring.number ring.type print
## 1                      w         p          w           o         p     k
## 2                      w         p          w           o         p     n
## 3                      w         p          w           o         p     n
## 4                      w         p          w           o         p     k
## 5                      w         p          w           o         e     n
## 6                      w         p          w           o         p     k
##   population habitat
## 1          s       u
## 2          n       g
## 3          n       m
## 4          s       u
## 5          a       g
## 6          n       g
# Define the factor names for "Class" (levels are in alphabetical order
# of the original codes: e = edible, p = poisonous)
levels(mushrooms$Class) <- c("Edible","Poisonous")
# Define the factor names for "odor" (codes, alphabetically: a = almond,
# c = creosote, f = foul, l = anise, m = musty, n = none, p = pungent,
# s = spicy, y = fishy)
levels(mushrooms$odor) <- c("Almond","Creosote","Foul","Anise","Musty","None","Pungent","Spicy","Fishy")

# Define the factor names for "print" (spore print color; codes, alphabetically:
# b = buff, h = chocolate, k = black, n = brown, o = orange, r = green,
# u = purple, w = white, y = yellow)
levels(mushrooms$print) <- c("Buff","Chocolate","Black","Brown","Orange","Green","Purple","White","Yellow")
Now we can get to building our model proper. For Decision Trees, we are going to utilize two different, but related, libraries: rpart to create the decision tree, and rpart.plot to visualize it. To import libraries, we use the library function, like so:
# Import our required libraries
library(rpart)
library(rpart.plot)
To create our decision tree model, we can use the rpart function. rpart is simple to use: you provide it a formula, point it at the dataset it is supposed to use, and choose a method (either "class" for classification or "anova" for regression).

A great trick to know when handling very large structured datasets (our dataset has over 20 columns we want to use!) is that in formula declarations, one can use the . operator as a quick way of designating "all other columns" to R. You can also print the Decision Tree model to retrieve a summary describing it.
# Create a classification decision tree using "Class" as the variable we want to predict and everything else as its predictors.
myDecisionTree <- rpart(Class ~ ., data = mushrooms, method = "class")

# Print out a summary of our created model.
print(myDecisionTree)
## n= 8124 
## 
## node), split, n, loss, yval, (yprob)
##       * denotes terminal node
## 
## 1) root 8124 3916 e (0.51797144 0.48202856)  
##   2) odor=a,l,n 4328  120 e (0.97227357 0.02772643)  
##     4) print=b,h,k,n,o,u,w,y 4256   48 e (0.98872180 0.01127820) *
##     5) print=r 72    0 p (0.00000000 1.00000000) *
##   3) odor=c,f,m,p,s,y 3796    0 p (0.00000000 1.00000000) *

In this summary, each numbered row is a node: it shows the split rule, the number of observations (n) reaching the node, how many of them are misclassified (loss), the predicted class (yval), and the class probabilities (yprob), with a * marking the terminal nodes. Here the root split is on odor.

Now that we have our model, we can draw it to gain a better understanding of how it is classifying the data points. We can use the rpart.plot function – a specialized function for plotting trees – to render our model. This function takes some parameters for visualizing the tree in different ways – try changing the type parameter (from 1 to 4) to see what happens!

rpart.plot(myDecisionTree, type = 3, extra = 2, under = TRUE, faclen=5, cex = .75)

As we can see (under the classification results), every mushroom our decision tree labels as poisonous really is poisonous, and it is almost perfect when dealing with edible ones: only a small number of poisonous mushrooms (48 of the 4256 in the edible leaf above) are misclassified as edible.

We can also use our model to classify new cases with the predict function. As a stand-in for a new, unseen mushroom, let's take the 10th row of the dataset, dropping its Class column:

# Row 10 without the Class label, playing the role of a new observation
newCase <- mushrooms[10,-1]
newCase
##    cap.shape cap.surface cap.color bruises odor gill.attachment gill.spacing
## 10         b           s         y       t    a               f            c
##    gill.size gill.color stalk.shape stalk.root stalk.surface.above.ring
## 10         b          g           e          c                        s
##    stalk.surface.below.ring stalk.color.above.ring stalk.color.below.ring
## 10                        s                      w                      w
##    veil.type veil.color ring.number ring.type print population habitat
## 10         p          w           o         p     k          s       m
predict(myDecisionTree, newCase, type = "class")
## 10 
##  e 
## Levels: e p
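If we want class probabilities rather than just a label, predict.rpart also accepts type = "prob":

# Class probabilities for the same case: one column per class level
predict(myDecisionTree, newCase, type = "prob")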

How accurate is the model?

To estimate how well the model generalizes to unseen data, we can hold out part of the dataset: train a new tree on 75% of the rows and test it on the remaining 25%.

## 75% of the sample size
n <- nrow(mushrooms)
smp_size <- floor(0.75 * n)

## set the seed to make your partition reproducible
set.seed(123)
train_ind <- sample(c(1:n), size = smp_size)

mushrooms_train <- mushrooms[train_ind, ]
mushrooms_test <- mushrooms[-train_ind, ]
# Train a new tree on the training set only
newDT <- rpart(Class ~ ., data = mushrooms_train, method = "class")
# Predict the class of each held-out test row (dropping the Class column)
result <- predict(newDT, mushrooms_test[,-1], type = "class")
head(result)
##  3  4  7  8 10 11 
##  e  p  e  e  e  e 
## Levels: e p
head(mushrooms_test$Class)
## [1] "e" "p" "e" "e" "e" "e"
table(mushrooms_test$Class, result)
##    result
##        e    p
##   e 1068    0
##   p   15  948
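In this confusion matrix, the rows are the actual classes and the columns the predicted ones: every edible mushroom in the test set is classified correctly, while 15 poisonous mushrooms are misclassified as edible. From these counts we can compute the overall accuracy directly:

# Overall accuracy: correctly classified test cases over all test cases
confusion <- table(mushrooms_test$Class, result)
sum(diag(confusion)) / sum(confusion)  # (1068 + 948) / 2031, about 0.993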
