This tutorial shows how to use the 'party' R package to train a decision tree model, using the iris data set as an example.
library(party)
## Loading required package: grid
## Loading required package: zoo
##
## Attaching package: 'zoo'
##
## The following objects are masked from 'package:base':
##
## as.Date, as.Date.numeric
##
## Loading required package: sandwich
## Loading required package: strucchange
## Loading required package: modeltools
## Loading required package: stats4
str(iris)
## 'data.frame': 150 obs. of 5 variables:
## $ Sepal.Length: num 5.1 4.9 4.7 4.6 5 5.4 4.6 5 4.4 4.9 ...
## $ Sepal.Width : num 3.5 3 3.2 3.1 3.6 3.9 3.4 3.4 2.9 3.1 ...
## $ Petal.Length: num 1.4 1.4 1.3 1.5 1.4 1.7 1.4 1.5 1.4 1.5 ...
## $ Petal.Width : num 0.2 0.2 0.2 0.2 0.2 0.4 0.3 0.2 0.2 0.1 ...
## $ Species : Factor w/ 3 levels "setosa","versicolor",..: 1 1 1 1 1 1 1 1 1 1 ...
head(iris)
## Sepal.Length Sepal.Width Petal.Length Petal.Width Species
## 1 5.1 3.5 1.4 0.2 setosa
## 2 4.9 3.0 1.4 0.2 setosa
## 3 4.7 3.2 1.3 0.2 setosa
## 4 4.6 3.1 1.5 0.2 setosa
## 5 5.0 3.6 1.4 0.2 setosa
## 6 5.4 3.9 1.7 0.4 setosa
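Note that the data set is balanced, with 50 observations of each of the three species:
table(iris$Species)  # 50 setosa, 50 versicolor, 50 virginica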
The following R code shows how to split the data into training and test sets, used for model training and model validation respectively.
set.seed(1234)  # for a reproducible split
ind <- sample(2, nrow(iris), replace = TRUE, prob = c(0.7, 0.3))
trainData <- iris[ind == 1, ]
testData  <- iris[ind == 2, ]
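We can confirm the split sizes; with this seed, the training set has 112 rows and the test set 38 (the training count also appears in the model output later in this tutorial):
nrow(trainData)  # 112 observations (~70%)
nrow(testData)   # 38 observations (~30%)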
Next we build a decision tree and check its predictions. The function ctree() provides parameters such as MinSplit, MinBucket, MaxSurrogate and MaxDepth to control the training of decision trees; below we use the default settings, and a brief sketch of setting these controls follows the default fit. In the code below, myFormula specifies that Species is the target variable and all other variables are predictors.
myFormula <- Species ~ Sepal.Length + Sepal.Width + Petal.Length + Petal.Width
iris_ctree <- ctree(myFormula, data=trainData)
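For illustration, here is a hedged sketch of passing the growth controls explicitly via ctree_control(); the parameter values are assumptions chosen to mirror what we believe are the package defaults, so this fit should match the one above.
# Sketch: the same fit with growth controls made explicit (values are illustrative)
ctrl <- ctree_control(minsplit = 20,     # minimum node weight required to attempt a split
                      minbucket = 7,     # minimum weight in any terminal node
                      maxsurrogate = 0,  # number of surrogate splits to evaluate
                      maxdepth = 0)      # 0 means unrestricted tree depth
iris_ctree_ctrl <- ctree(myFormula, data = trainData, controls = ctrl)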
First, let us check the predictions on trainData itself:
#train_predict <- predict(iris_ctree)  # omitting newdata also predicts on the training data
train_predict <- predict(iris_ctree, trainData, type="response")
Next, we build the confusion matrix (rows are predicted classes, columns are actual species) and compute the misclassification error:
table(train_predict,trainData$Species)
##
## train_predict setosa versicolor virginica
## setosa 40 0 0
## versicolor 0 37 3
## virginica 0 1 31
mean(train_predict != trainData$Species) * 100
## [1] 3.571
The misclassification error rate above is about 3.6%; put another way, the model is about 96.4% accurate.
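The same number can be read off the confusion matrix: the diagonal counts the correct predictions (40 + 37 + 31 = 108 of the 112 training observations).
sum(diag(table(train_predict, trainData$Species))) / nrow(trainData)  # = 108/112, about 0.964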
Does 96.4% sound good? Hold on. So far we have only evaluated on the training data, which is misleading because the same data was used to build the model.
Let us validate the model on the test set instead:
test_predict <- predict(iris_ctree, newdata= testData,type="response")
Again, we check the confusion matrix and misclassification error:
table(test_predict, testData$Species)
##
## test_predict setosa versicolor virginica
## setosa 10 0 0
## versicolor 0 12 2
## virginica 0 0 14
mean(test_predict != testData$Species) * 100
## [1] 5.263
Does the misclassification error make sense? On the unseen test data the error rate rises to about 5.3%, which is a more realistic estimate of how the model will perform on new observations.
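To see exactly which test observations were misclassified, we can filter the test set on the disagreements between predicted and actual labels:
testData[test_predict != testData$Species, ]  # the 2 misclassified observations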
We can inspect the model and plot it to better understand its structure. In the printout below, terminal nodes are marked with '*', and 'weights' gives the number of training observations falling into each node.
print(iris_ctree)
##
## Conditional inference tree with 4 terminal nodes
##
## Response: Species
## Inputs: Sepal.Length, Sepal.Width, Petal.Length, Petal.Width
## Number of observations: 112
##
## 1) Petal.Length <= 1.9; criterion = 1, statistic = 104.643
## 2)* weights = 40
## 1) Petal.Length > 1.9
## 3) Petal.Width <= 1.7; criterion = 1, statistic = 48.939
## 4) Petal.Length <= 4.4; criterion = 0.974, statistic = 7.397
## 5)* weights = 21
## 4) Petal.Length > 4.4
## 6)* weights = 19
## 3) Petal.Width > 1.7
## 7)* weights = 32
plot(iris_ctree)
plot(iris_ctree, type="simple")
Next, we validate the model through cross-validation.
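One simple approach is a hand-rolled k-fold loop in base R; the fold assignment via sample() and the name cv_errors are our own illustrative choices, not part of the party package.
# Minimal 10-fold cross-validation sketch (helper names are our own)
set.seed(1234)
k <- 10
folds <- sample(rep(1:k, length.out = nrow(iris)))  # randomly assign each row to a fold
cv_errors <- sapply(1:k, function(i) {
  fit  <- ctree(myFormula, data = iris[folds != i, ])             # train on k-1 folds
  pred <- predict(fit, newdata = iris[folds == i, ], type = "response")
  mean(pred != iris$Species[folds == i])                          # error on the held-out fold
})
mean(cv_errors) * 100  # average misclassification error (%) across the folds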