Decision Tree

Gini index

\[ \text{Gini} = 1 - \sum_{i=1}^{C} p_i^{2}, \quad \text{where } C \text{ is the number of classes and } p_i \text{ the proportion of class } i \]
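
For example, a node holding 9 "Yes" and 5 "No" rows (the class balance of the sample data below) has the impurity computed in this minimal sketch, taken directly from the formula:

p <- c(9, 5) / 14   # class proportions
1 - sum(p^2)        # a pure node gives 0; a 50/50 binary node gives 0.5
## [1] 0.4591837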

sample_data <- data.frame(
  Day=1:14,
  Outlook=c('Sunny','Sunny','Overcast','Rain','Rain','Rain','Overcast',
            'Sunny','Sunny','Rain','Sunny','Overcast','Overcast','Rain'),
  Temp=c('Hot','Hot','Hot','Mild','Cool','Cool','Cool',
         'Mild','Cool','Mild','Mild','Mild','Hot','Mild'),
  Humidity=c('High','High','High','High','Normal','Normal','Normal',
             'High','Normal','Normal','Normal','High','Normal','High'),
  Wind=c('Weak','Strong','Weak','Weak','Weak','Strong','Strong',
         'Weak','Weak','Weak','Strong','Strong','Weak','Strong'),
  Decision=c('No','No','Yes','Yes','Yes','No','Yes',
             'No','Yes','Yes','Yes','Yes','Yes','No')
)
knitr::kable(sample_data)
| Day | Outlook  | Temp | Humidity | Wind   | Decision |
|----:|:---------|:-----|:---------|:-------|:---------|
|   1 | Sunny    | Hot  | High     | Weak   | No       |
|   2 | Sunny    | Hot  | High     | Strong | No       |
|   3 | Overcast | Hot  | High     | Weak   | Yes      |
|   4 | Rain     | Mild | High     | Weak   | Yes      |
|   5 | Rain     | Cool | Normal   | Weak   | Yes      |
|   6 | Rain     | Cool | Normal   | Strong | No       |
|   7 | Overcast | Cool | Normal   | Strong | Yes      |
|   8 | Sunny    | Mild | High     | Weak   | No       |
|   9 | Sunny    | Cool | Normal   | Weak   | Yes      |
|  10 | Rain     | Mild | Normal   | Weak   | Yes      |
|  11 | Sunny    | Mild | Normal   | Strong | Yes      |
|  12 | Overcast | Mild | High     | Strong | Yes      |
|  13 | Overcast | Hot  | Normal   | Weak   | Yes      |
|  14 | Rain     | Mild | High     | Strong | No       |

Calculation example

Outlook

table(sample_data$Outlook,sample_data$Decision)
##           
##            No Yes
##   Overcast  0   4
##   Rain      2   3
##   Sunny     3   2
Gini_outlook_sunny=1-(2/5)^2-(3/5)^2     # Sunny: 3 No, 2 Yes
Gini_outlook_overcast=1-(4/4)^2          # Overcast: pure node, 4 Yes
Gini_outlook_rain=1-(3/5)^2-(2/5)^2      # Rain: 2 No, 3 Yes
Gini_outlook=5/14*Gini_outlook_sunny+4/14*Gini_outlook_overcast+5/14*Gini_outlook_rain # weight each branch by its share of the 14 rows
Gini_outlook
## [1] 0.3428571

Temperature

table(sample_data$Temp,sample_data$Decision)
##       
##        No Yes
##   Cool  1   3
##   Hot   2   2
##   Mild  2   4
Gini_temp_cool=1-(1/4)^2-(3/4)^2         # Cool: 1 No, 3 Yes
Gini_temp_hot=1-(2/4)^2-(2/4)^2          # Hot: 2 No, 2 Yes
Gini_temp_mild=1-(2/6)^2-(4/6)^2         # Mild: 2 No, 4 Yes
Gini_temp=4/14*Gini_temp_cool+4/14*Gini_temp_hot+6/14*Gini_temp_mild # weighted average over branches
Gini_temp
## [1] 0.4404762

Humidity
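
Repeating the Outlook steps for Humidity (a sketch mirroring the block above; the counts come from the sample data table):

table(sample_data$Humidity,sample_data$Decision)
##         
##          No Yes
##   High    4   3
##   Normal  1   6
Gini_humidity_high=1-(4/7)^2-(3/7)^2     # High: 4 No, 3 Yes
Gini_humidity_normal=1-(1/7)^2-(6/7)^2   # Normal: 1 No, 6 Yes
Gini_humidity=7/14*Gini_humidity_high+7/14*Gini_humidity_normal
Gini_humidity
## [1] 0.3673469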

Wind
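
And likewise for Wind (same sketch pattern):

table(sample_data$Wind,sample_data$Decision)
##         
##          No Yes
##   Strong  3   3
##   Weak    2   6
Gini_wind_weak=1-(2/8)^2-(6/8)^2         # Weak: 2 No, 6 Yes
Gini_wind_strong=1-(3/6)^2-(3/6)^2       # Strong: 3 No, 3 Yes
Gini_wind=8/14*Gini_wind_weak+6/14*Gini_wind_strong
Gini_wind
## [1] 0.4285714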

Decision

Outlook wins the first split because its weighted Gini index is the lowest of the four features.

knitr::kable(data.frame(
  Feature=c('Outlook','Temp','Humidity','Wind'),
  GiniIndex=c(0.343,0.440,0.367,0.429)
))

| Feature  | GiniIndex |
|:---------|----------:|
| Outlook  |     0.343 |
| Temp     |     0.440 |
| Humidity |     0.367 |
| Wind     |     0.429 |
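
As a cross-check, a small generic helper (a sketch; gini_index is a name introduced here, not a library function) reproduces all four values at once:

gini_index <- function(feature, target) {
  tab <- table(feature, target)                # counts per branch and class
  node_gini <- apply(prop.table(tab, 1), 1,
                     function(p) 1 - sum(p^2)) # Gini of each branch
  weights <- rowSums(tab) / sum(tab)           # branch sizes as fractions of n
  sum(weights * node_gini)                     # size-weighted average
}
sapply(sample_data[c('Outlook','Temp','Humidity','Wind')],
       gini_index, target=sample_data$Decision)
##   Outlook      Temp  Humidity      Wind 
## 0.3428571 0.4404762 0.3673469 0.4285714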

Decision tree: Iris data

Loading packages

library(caret, quietly = TRUE)
library(rattle, quietly = TRUE)
## Rattle: A free graphical interface for data science with R.
## Version 5.4.0 Copyright (c) 2006-2020 Togaware Pty Ltd.
## Type 'rattle()' to shake, rattle, and roll your data.
  • caret : classification and regression training
  • rattle : graphical user interface for data science in R
  • library() : loads an installed package into the session

Data

data("iris")
names(iris) = tolower(names(iris)) # convert all column names to lowercase
table(iris$species)
## 
##     setosa versicolor  virginica 
##         50         50         50
index = createDataPartition(
  y=iris$species,  # variable to stratify the split on
  p=0.7,  # proportion of rows used for training
  list=FALSE) # return row indices rather than a list
train.set = iris[index,] # training data
test.set = iris[-index,] # test data
model <- train(
  species ~ .,
  data=train.set,
  method='rpart', # decision tree (CART via rpart)
  trControl=trainControl(method='cv') # k-fold cross-validation
)
iris.pred = predict(model, newdata = test.set)
confusionMatrix(table(iris.pred, test.set$species))
## Confusion Matrix and Statistics
## 
##             
## iris.pred    setosa versicolor virginica
##   setosa         15          0         0
##   versicolor      0         14         3
##   virginica       0          1        12
## 
## Overall Statistics
##                                           
##                Accuracy : 0.9111          
##                  95% CI : (0.7878, 0.9752)
##     No Information Rate : 0.3333          
##     P-Value [Acc > NIR] : 8.467e-16       
##                                           
##                   Kappa : 0.8667          
##                                           
##  Mcnemar's Test P-Value : NA              
## 
## Statistics by Class:
## 
##                      Class: setosa Class: versicolor Class: virginica
## Sensitivity                 1.0000            0.9333           0.8000
## Specificity                 1.0000            0.9000           0.9667
## Pos Pred Value              1.0000            0.8235           0.9231
## Neg Pred Value              1.0000            0.9643           0.9062
## Prevalence                  0.3333            0.3333           0.3333
## Detection Rate              0.3333            0.3111           0.2667
## Detection Prevalence        0.3333            0.3778           0.2889
## Balanced Accuracy           1.0000            0.9167           0.8833
fancyRpartPlot(model$finalModel)
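
To see which predictors the fitted tree actually relied on, caret's varImp() can be applied to the trained model (exact numbers will vary with the random partition):

varImp(model) # variable importance from the fitted rpart model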
