This tutorial shows how to use the ‘party’ R package to train a decision tree model. We use the iris data set.

library(party)
## Loading required package: grid
## Loading required package: zoo
## 
## Attaching package: 'zoo'
## 
## The following objects are masked from 'package:base':
## 
##     as.Date, as.Date.numeric
## 
## Loading required package: sandwich
## Loading required package: strucchange
## Loading required package: modeltools
## Loading required package: stats4
str(iris)
## 'data.frame':    150 obs. of  5 variables:
##  $ Sepal.Length: num  5.1 4.9 4.7 4.6 5 5.4 4.6 5 4.4 4.9 ...
##  $ Sepal.Width : num  3.5 3 3.2 3.1 3.6 3.9 3.4 3.4 2.9 3.1 ...
##  $ Petal.Length: num  1.4 1.4 1.3 1.5 1.4 1.7 1.4 1.5 1.4 1.5 ...
##  $ Petal.Width : num  0.2 0.2 0.2 0.2 0.2 0.4 0.3 0.2 0.2 0.1 ...
##  $ Species     : Factor w/ 3 levels "setosa","versicolor",..: 1 1 1 1 1 1 1 1 1 1 ...
head(iris)
##   Sepal.Length Sepal.Width Petal.Length Petal.Width Species
## 1          5.1         3.5          1.4         0.2  setosa
## 2          4.9         3.0          1.4         0.2  setosa
## 3          4.7         3.2          1.3         0.2  setosa
## 4          4.6         3.1          1.5         0.2  setosa
## 5          5.0         3.6          1.4         0.2  setosa
## 6          5.4         3.9          1.7         0.4  setosa

The following R code splits iris into a training set (about 70% of the rows) and a test set (about 30%), used for model training and model validation respectively.

set.seed(1234) # for reproducible results
ind <- sample(2, nrow(iris), replace = TRUE, prob = c(0.7, 0.3))
trainData <- iris[ind==1,]
testData <- iris[ind==2,]
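Because sample() assigns each row to subset 1 or 2 independently, with probabilities 0.7 and 0.3, the split is only approximately 70/30. In the run shown in this tutorial it produced 112 training rows (see "Number of observations" in the printed tree) and 38 test rows (the total of the test confusion matrix). A quick sanity check along these lines can confirm the split before training. This is a sketch in base R, not part of the original walkthrough:

```r
# Reproduce the random split used in the tutorial
set.seed(1234)
ind <- sample(2, nrow(iris), replace = TRUE, prob = c(0.7, 0.3))
trainData <- iris[ind == 1, ]
testData  <- iris[ind == 2, ]

# Sizes of the two subsets; together they should cover all 150 rows
c(train = nrow(trainData), test = nrow(testData))

# Class counts in each subset, to spot a badly unbalanced split
table(trainData$Species)
table(testData$Species)

# No row should appear in both subsets
length(intersect(rownames(trainData), rownames(testData)))
```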

Next, we build a decision tree with ctree() from the party package and check its predictions. Function ctree() accepts control parameters, such as minsplit, minbucket, maxsurrogate and maxdepth (passed through its controls argument via ctree_control()), that control how the decision tree is grown. Below we use the default settings to build a decision tree. Examples of setting the above parameters are available in Chapter 13. In the code below, myFormula specifies that Species is the target variable and all other variables are independent variables.
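The sketch below shows how these parameters can be passed through ctree_control(). The specific values are illustrative only and have not been tuned. The object name iris_ctree_small is a hypothetical name chosen for this example:

```r
library(party)

set.seed(1234)
ind <- sample(2, nrow(iris), replace = TRUE, prob = c(0.7, 0.3))
trainData <- iris[ind == 1, ]

# Illustrative, untuned settings:
#   minsplit     - minimum sum of weights in a node for it to be considered for splitting
#   minbucket    - minimum sum of weights in a terminal node
#   maxsurrogate - number of surrogate splits to evaluate (default 0 = none)
#   maxdepth     - maximum depth of the tree (default 0 = no limit)
ctrl <- ctree_control(minsplit = 10, minbucket = 5, maxsurrogate = 2, maxdepth = 2)

iris_ctree_small <- ctree(Species ~ Sepal.Length + Sepal.Width + Petal.Length + Petal.Width,
                          data = trainData, controls = ctrl)
print(iris_ctree_small)
```

Limiting maxdepth is a simple way to get a smaller, easier-to-read tree. Whether that costs accuracy is something to check on the test set, as done below for the default tree.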

myFormula <- Species ~ Sepal.Length + Sepal.Width + Petal.Length + Petal.Width
iris_ctree <- ctree(myFormula, data=trainData)
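Since all four measurements are used as predictors, the formula can also be written with R's `.` shorthand, which means "every other column of the data". The sketch below fits both versions and compares them. It assumes the default ctree_control() settings, under which no random resampling or random variable selection is involved, so the two fits are expected to agree:

```r
library(party)

set.seed(1234)
ind <- sample(2, nrow(iris), replace = TRUE, prob = c(0.7, 0.3))
trainData <- iris[ind == 1, ]

# Explicit formula, as in the tutorial
fit_explicit <- ctree(Species ~ Sepal.Length + Sepal.Width + Petal.Length + Petal.Width,
                      data = trainData)

# Shorthand: "." expands to all columns of trainData except Species
fit_dot <- ctree(Species ~ ., data = trainData)

# Compare the predictions on the training data
all(predict(fit_explicit) == predict(fit_dot))
```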

Next, we check the predictions on the training data itself.

# predict(iris_ctree) without newdata also returns predictions for the training data
train_predict <- predict(iris_ctree, newdata = trainData, type = "response")
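Besides `type = "response"` (the predicted class), predict() on a party tree also accepts `type = "prob"`, which gives the class probabilities in the terminal node, and `type = "node"`, which gives the terminal node ID. The sketch below illustrates both. It assumes the probability vectors follow the order of levels(trainData$Species):

```r
library(party)

set.seed(1234)
ind <- sample(2, nrow(iris), replace = TRUE, prob = c(0.7, 0.3))
trainData <- iris[ind == 1, ]
iris_ctree <- ctree(Species ~ ., data = trainData)

# type = "prob": a list with one vector of class probabilities per row
train_prob <- predict(iris_ctree, newdata = trainData, type = "prob")
prob_mat <- do.call(rbind, train_prob)
colnames(prob_mat) <- levels(trainData$Species)  # assumed column order
head(prob_mat)

# type = "node": the terminal node each row falls into
train_node <- predict(iris_ctree, newdata = trainData, type = "node")
table(train_node)
```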

Next, we build the confusion matrix and compute the misclassification error rate (in percent).

table(train_predict,trainData$Species)
##              
## train_predict setosa versicolor virginica
##    setosa         40          0         0
##    versicolor      0         37         3
##    virginica       0          1        31
mean(train_predict != trainData$Species) * 100
## [1] 3.571

OK, the misclassification error rate above is about 3.6%. In other words, the model is about 96.4% accurate on the training data.
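Accuracy can also be read straight off the confusion matrix, because correct predictions sit on its diagonal. In the table above, 40 + 37 + 31 = 108 of the 112 training rows are correct, and 108 / 112 ≈ 0.964. The following is a minimal sketch of the same calculation in R:

```r
library(party)

set.seed(1234)
ind <- sample(2, nrow(iris), replace = TRUE, prob = c(0.7, 0.3))
trainData <- iris[ind == 1, ]
iris_ctree <- ctree(Species ~ ., data = trainData)
train_predict <- predict(iris_ctree, newdata = trainData, type = "response")

cm <- table(train_predict, trainData$Species)

# Correct predictions are on the diagonal; the error rate is the complement
accuracy   <- sum(diag(cm)) / sum(cm)
error_rate <- 1 - accuracy
round(100 * c(accuracy = accuracy, error = error_rate), 1)
```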

Do you think 96.4% is good? Hold on. We evaluated the model on the same data that was used to build it, so this number is optimistic. It tells us how well the tree fits data it has already seen, not how well it will classify new flowers.

Let us validate the model on the test set.

test_predict <- predict(iris_ctree, newdata= testData,type="response")

Let us check the confusion matrix and the misclassification error rate again, this time on the test set.

table(test_predict, testData$Species)
##             
## test_predict setosa versicolor virginica
##   setosa         10          0         0
##   versicolor      0         12         2
##   virginica       0          0        14
mean(test_predict != testData$Species) * 100
## [1] 5.263

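Does this misclassification error make sense? The test error (about 5.3%, i.e. 2 of 38 flowers) is higher than the training error (about 3.6%). That is what we would expect on data the model did not see during training.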
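To see which class causes the errors, per-class rates can be computed from the test confusion matrix. This goes slightly beyond the original walkthrough and is offered as a sketch. In the table above, virginica recall is 14 / 16 = 0.875 and versicolor precision is 12 / 14 ≈ 0.857, while setosa is classified perfectly:

```r
library(party)

set.seed(1234)
ind <- sample(2, nrow(iris), replace = TRUE, prob = c(0.7, 0.3))
trainData <- iris[ind == 1, ]
testData  <- iris[ind == 2, ]
iris_ctree <- ctree(Species ~ ., data = trainData)
test_predict <- predict(iris_ctree, newdata = testData, type = "response")

cm <- table(test_predict, testData$Species)

# Recall per true class: correct predictions / column (true class) totals
recall <- diag(cm) / colSums(cm)
# Precision per predicted class: correct predictions / row (predicted class) totals
# (NaN if a class is never predicted)
precision <- diag(cm) / rowSums(cm)
round(rbind(recall, precision), 3)
```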

We can print and plot the model to understand it better.

print(iris_ctree)
## 
##   Conditional inference tree with 4 terminal nodes
## 
## Response:  Species 
## Inputs:  Sepal.Length, Sepal.Width, Petal.Length, Petal.Width 
## Number of observations:  112 
## 
## 1) Petal.Length <= 1.9; criterion = 1, statistic = 104.643
##   2)*  weights = 40 
## 1) Petal.Length > 1.9
##   3) Petal.Width <= 1.7; criterion = 1, statistic = 48.939
##     4) Petal.Length <= 4.4; criterion = 0.974, statistic = 7.397
##       5)*  weights = 21 
##     4) Petal.Length > 4.4
##       6)*  weights = 19 
##   3) Petal.Width > 1.7
##     7)*  weights = 32
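The printed tree reports only the number of training rows in each terminal node ("weights"). The function where() from party returns the terminal node of every training row, so a cross-tabulation shows each leaf's class mix. In the run shown, nodes 5 and 6 together hold the 40 rows predicted as versicolor (37 versicolor + 3 virginica in the training confusion matrix), and node 7 holds the 32 rows predicted as virginica. A short sketch:

```r
library(party)

set.seed(1234)
ind <- sample(2, nrow(iris), replace = TRUE, prob = c(0.7, 0.3))
trainData <- iris[ind == 1, ]
iris_ctree <- ctree(Species ~ ., data = trainData)

# Terminal node ID for each training observation
leaf <- where(iris_ctree)

# Class composition of each terminal node; the row totals should match
# the "weights" reported by print(iris_ctree)
table(leaf, trainData$Species)
```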
plot(iris_ctree)

[Figure: the conditional inference tree drawn by plot(iris_ctree)]

plot(iris_ctree, type="simple")

[Figure: the simplified tree drawn by plot(iris_ctree, type="simple")]

Next, we validate the model through cross-validation.