Case study 17: Tuning a Decision Tree

Foreword: About the Machine Learning in Medicine (MLM) project

The MLM project was initiated in 2016 and aims to:

1. Encourage the use of Machine Learning techniques in medical research in Vietnam

2. Promote the use of the R statistical programming language, an open-source and leading tool for practicing data science.

Background

The 17th case study (Chapter 16, Machine Learning in Medicine - Cookbook 1 by T. J. Cleophas and A. H. Zwinderman, Springer, 2014) introduces tree-based models for developing clinical classification rules.

Decision trees are among the most effective data mining algorithms for establishing clinical decision rules. Tree-based models offer many advantages, including full compatibility with all data types, simplicity in modelling and, above all, a straightforward translation into clinical decision making. Decision tree models are widely supported by various software packages, such as SAS, S-Plus, SPSS, Matlab, Weka and R. However, modelling performance can differ considerably among these packages.

Recently, several machine learning frameworks have been introduced in R, such as the caret and mlr packages. These packages make it easy to apply many prebuilt algorithms and offer useful features such as training, tuning and testing. In this study, we will focus on the CART algorithm (Breiman et al. 1984), a popular method for growing a decision tree. In R, the CART algorithm is provided by the rpart package. It should be noted that the train() function of the caret package performs tuning, resampling and training simultaneously, whilst the mlr package handles these processes independently.

Overfitting is a critical problem in CART modelling. There are two alternative solutions to this problem: (1) setting limits on tree size and/or (2) pruning the tree. Although the CART algorithm is backed by the same package (rpart) in both caret and mlr, the tuning procedure differs markedly between the two frameworks. In caret, a CART model can be tuned using two alternative methods:

The “rpart” method tunes the model via the complexity parameter (cp), i.e. tree pruning, while the “rpart2” method optimizes the maximum tree depth parameter, i.e. a constraint on tree size.

The tuning procedure in mlr offers greater flexibility, as the user can tune cp and maxdepth jointly by screening every possible combination of the two parameters (with a grid-based or random tuning control). mlr also allows manual feature filtering based on different criteria, such as information gain or gain ratio. Both tuning mechanisms are sketched in plain rpart code below.
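To make the two overfitting remedies concrete, here is a minimal sketch using rpart directly. The cp and maxdepth values are arbitrary examples (not the tuned values obtained later), and trainset refers to the training subset created in the Results section:

library(rpart)

# Remedy 1: constrain the tree while it is being grown, e.g. limit its depth
fit_limited <- rpart(Infarction ~ ., data = trainset, method = "class",
                     control = rpart.control(maxdepth = 5))

# Remedy 2: grow a large tree first, then prune it back with a cp threshold
fit_full   <- rpart(Infarction ~ ., data = trainset, method = "class",
                    control = rpart.control(cp = 0))
fit_pruned <- prune(fit_full, cp = 0.05)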

In this experiment, we will evaluate the performance of three different approaches: automated training in caret using either the rpart or the rpart2 method, versus user-controlled modelling in mlr. Our hypothesis is that manual tuning provides better performance.

Materials and method

The original dataset was provided in Chapter 16, Machine Learning in Medicine - Cookbook 1 (T. J. Cleophas and A. H. Zwinderman, SpringerBriefs in Statistics, 2014). You can get the original data in SPSS SAV format (decisiontreebinary.sav) from the publisher's website: extras.springer.com.

Data analysis was performed in the R statistical programming language. The caret and mlr packages are used throughout the experiment.

The original dataset (n=1004) will be randomly split into Train (80%) and Test (20%) subsets. The Train subset (n=804) will be used consistently for model tuning and training in both caret and mlr. The performance of the final models will be compared on the same Test subset (n=200).

First, the following packages must be installed and loaded in RStudio (some other packages might also be required during the analysis), and we will personalize the ggplot2 theme:

library(foreign)     # read the SPSS .sav data file
library(tidyverse)   # data manipulation and ggplot2 graphics
library(magrittr)    # compound assignment pipe %<>% used below
library(caret)       # automated training, tuning and resampling
library(mlr)         # user-controlled machine learning framework
library(gridExtra)   # arranging multiple ggplot panels
library(rattle)      # fancyRpartPlot() for drawing rpart trees
library(rpart.plot)  # additional rpart plotting utilities

my_theme <- function(base_size =8, base_family = "sans"){
  theme_minimal(base_size = base_size, base_family = base_family) +
    theme(
      axis.text = element_text(size =8),
      axis.text.x = element_text(angle = 0, vjust = 0.5, hjust = 0.5),
      axis.title = element_text(size =8),
      panel.grid.major = element_line(color = "gray"),
      panel.grid.minor = element_blank(),
      panel.background = element_rect(fill = "#fef5f9"),
      strip.background = element_rect(fill = "#290029", color = "#290029", size =0.5),
      strip.text = element_text(face = "bold", size = 8, color = "white"),
      legend.position = "bottom",
      legend.justification = "center",
      legend.background = element_blank(),
      panel.border = element_rect(color = "grey5", fill = NA, size = 0.5)
    )
}

theme_set(my_theme())

myfillcolors=c("#ff0033","#330033", "#cc0033" , "#660033", "#990033","#630063")

Results

Step 0: Data loading and reshaping

require(foreign)

df=read.spss("decisiontreebinary.sav",use.value.labels = TRUE,to.data.frame = TRUE)%>%as_tibble()
df<-df%>%dplyr::rename(.,Infarction=infarct_rating,Age=age,Cholesterol=cholesterol_level,Smoke=smoking,Education=education,Weight=weight_level)
df$Smoke%<>%as.factor()%>%recode_factor(.,`0` = "No", `1` = "Medium",`2` = "High")
df$Infarction%<>%as.factor()%>%recode_factor(.,`no` = "No", `yes` = "Occured")

df
## # A tibble: 1,004 × 6
##    Infarction      Age Cholesterol  Smoke   Education Weight
##        <fctr>    <dbl>      <fctr> <fctr>      <fctr> <fctr>
## 1          No 44.85561         Low     No High school   high
## 2          No 42.71335      Medium     No High school   high
## 3          No 43.33660        High     No     College   high
## 4          No 44.01998        High     No High school   high
## 5          No 67.96738         Low     No     College   high
## 6          No 40.30740      Medium     No     College   high
## 7          No 66.55518         Low     No     College   high
## 8          No 45.95118         Low     No     College   high
## 9          No 52.26541         Low     No High school   high
## 10         No 43.86261         Low     No High school   high
## # ... with 994 more rows

Step 1: Data exploration

Figure 1: Descriptive analysis

smoke=ggplot(data=df,aes(x=Infarction,fill=Smoke))+geom_bar(position="fill",alpha=0.8,color="black")+scale_fill_manual(values=myfillcolors)+coord_flip()

chol=ggplot(data=df,aes(x=Infarction,fill=Cholesterol))+geom_bar(position="fill",alpha=0.8,color="black")+scale_fill_manual(values=myfillcolors)+coord_flip()

wt=ggplot(data=df,aes(x=Infarction,fill=Weight))+geom_bar(position="fill",alpha=0.8,color="black")+scale_fill_manual(values=myfillcolors)+coord_flip()

edu=ggplot(data=df,aes(x=Infarction,fill=Education))+geom_bar(position="fill",alpha=0.8,color="black")+scale_fill_manual(values=myfillcolors)+coord_flip()

a=df%>%ggplot(aes(x=Infarction,y=Age,fill=Infarction))+geom_boxplot(alpha=0.8)+scale_fill_manual(values=myfillcolors)+coord_flip()

grid.arrange(smoke,chol,wt,edu,a,ncol=2)

The data exploration shows a clear contrast in Age, Cholesterol and Smoking status between the two Infarction classes, whereas Education appears to have little effect on the Infarction outcome.

Data splitting using CARET

set.seed(123)
idTrain=createDataPartition(y=df$Infarction,p=0.8,list=FALSE)
trainset=df[idTrain,]
testset=df[-idTrain,]

Step 2: Growing a decision tree in caret, based on the CART algorithm and two tuning methods

Control=trainControl(method= "repeatedcv",number=10,repeats=10,classProbs=TRUE,summaryFunction =multiClassSummary)

As mentioned above, the CART algorithm can be tuned by two alternative methods in caret. We will develop two different models: cart1 uses cp-based pruning, whilst cart2 tunes the maximum tree depth.

Figure 2: Decision tree based on cp tuning

cart1=caret::train(Infarction~.,data=trainset,method="rpart",trControl=Control,tuneLength=10)

cart1
## CART 
## 
## 804 samples
##   5 predictor
##   2 classes: 'No', 'Occured' 
## 
## No pre-processing
## Resampling: Cross-Validated (10 fold, repeated 10 times) 
## Summary of sample sizes: 724, 724, 723, 723, 723, 725, ... 
## Resampling results across tuning parameters:
## 
##   cp          logLoss    ROC        Accuracy   Kappa      Sensitivity
##   0.00000000  0.4164073  0.8431908  0.8430763  0.4761837  0.5101838  
##   0.01137885  0.3916414  0.7772318  0.8461766  0.4652982  0.4730515  
##   0.02275770  0.4022487  0.7536177  0.8471579  0.4653680  0.4689338  
##   0.03413655  0.4122506  0.7317413  0.8465235  0.4870632  0.5226838  
##   0.04551539  0.4123961  0.7284855  0.8472611  0.4930641  0.5321691  
##   0.05689424  0.4132571  0.7272709  0.8470111  0.4911052  0.5285294  
##   0.06827309  0.4140702  0.7269932  0.8459984  0.4889806  0.5285294  
##   0.07965194  0.4140702  0.7269932  0.8459984  0.4889806  0.5285294  
##   0.09103079  0.4204099  0.7132698  0.8421310  0.4603267  0.5030515  
##   0.10240964  0.4693742  0.6178621  0.8126815  0.2463664  0.2826471  
##   Specificity  Pos_Pred_Value  Neg_Pred_Value  Detection_Rate
##   0.9297842    0.6663804       0.8801372       0.10523916    
##   0.9434152    0.6985555       0.8739534       0.09754309    
##   0.9456176    0.7164750       0.8736163       0.09676647    
##   0.9308656    0.6702012       0.8832972       0.10784877    
##   0.9292907    0.6723378       0.8851412       0.10983476    
##   0.9299206    0.6740622       0.8844241       0.10908476    
##   0.9286508    0.6704020       0.8842582       0.10908476    
##   0.9286508    0.6704020       0.8842582       0.10908476    
##   0.9303795    0.6622607       0.8798929       0.10384375    
##   0.9506052    0.6053396       0.8401661       0.05833533    
##   Balanced_Accuracy
##   0.7199840        
##   0.7082333        
##   0.7072757        
##   0.7267747        
##   0.7307299        
##   0.7292250        
##   0.7285901        
##   0.7285901        
##   0.7167155        
##   0.6166261        
## 
## Accuracy was used to select the optimal model using  the largest value.
## The final value used for the model was cp = 0.04551539.
library(partykit)
library(party)

fancyRpartPlot(cart1$finalModel,palettes="RdPu")

The tree pruning result suggests that the optimized model contains three decision nodes, based on Age, high Cholesterol and medium Smoking.
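The same structure can also be inspected in text form; a minimal sketch using rpart's own summaries on the selected model:

print(cart1$finalModel)    # splits and node statistics of the pruned tree
printcp(cart1$finalModel)  # complexity table used for pruning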

Figure 3: Decision tree based on maximum tree depth tuning

cart2=caret::train(Infarction~.,data=trainset,method = "rpart2",trControl=Control,tuneLength=10)
## note: only 6 possible values of the max tree depth from the initial fit.
##  Truncating the grid to 6 .
cart2
## CART 
## 
## 804 samples
##   5 predictor
##   2 classes: 'No', 'Occured' 
## 
## No pre-processing
## Resampling: Cross-Validated (10 fold, repeated 10 times) 
## Summary of sample sizes: 723, 723, 724, 724, 723, 723, ... 
## Resampling results across tuning parameters:
## 
##   maxdepth  logLoss    ROC        Accuracy   Kappa      Sensitivity
##    3        0.4154876  0.7227306  0.8461203  0.4787661  0.5050735  
##    5        0.3892899  0.7668103  0.8539554  0.4712135  0.4489338  
##    6        0.3951423  0.7685720  0.8514614  0.4729551  0.4653309  
##   14        0.4050255  0.7931376  0.8483486  0.4804320  0.4970588  
##   15        0.4050255  0.7931376  0.8483486  0.4804320  0.4970588  
##   21        0.4050255  0.7931376  0.8483486  0.4804320  0.4970588  
##   Specificity  Pos_Pred_Value  Neg_Pred_Value  Detection_Rate
##   0.9349281    0.6753042       0.8799109       0.10422236    
##   0.9592684    0.7514931       0.8710764       0.09272201    
##   0.9519023    0.7213028       0.8735411       0.09608787    
##   0.9396652    0.6860649       0.8787761       0.10267606    
##   0.9396652    0.6860649       0.8787761       0.10267606    
##   0.9396652    0.6860649       0.8787761       0.10267606    
##   Balanced_Accuracy
##   0.7200008        
##   0.7041011        
##   0.7086166        
##   0.7183620        
##   0.7183620        
##   0.7183620        
## 
## Accuracy was used to select the optimal model using  the largest value.
## The final value used for the model was maxdepth = 5.
fancyRpartPlot(cart2$finalModel,palettes="RdPu")

Based on the maximum tree depth tuning output, the optimized decision tree includes up to 13 decision nodes, as presented above.

A) Averaged confusion matrices

confusionMatrix(cart1,positive ="Occured")
## Cross-Validated (10 fold, repeated 10 times) Confusion Matrix 
## 
## (entries are percentual average cell counts across resamples)
##  
##           Reference
## Prediction   No Occured
##    No      11.0     5.6
##    Occured  9.7    73.7
##                             
##  Accuracy (average) : 0.8473
confusionMatrix(cart2,positive ="Occured")
## Cross-Validated (10 fold, repeated 10 times) Confusion Matrix 
## 
## (entries are percentual average cell counts across resamples)
##  
##           Reference
## Prediction   No Occured
##    No       9.3     3.2
##    Occured 11.4    76.1
##                            
##  Accuracy (average) : 0.854

The performance of the two models can be compared using the averaged confusion matrices. The results show that maximum tree depth tuning provides slightly better average accuracy than cp-based pruning (0.854 vs 0.847).
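The resampling distributions of the two caret models can also be compared directly; a minimal sketch using caret's resamples() helper:

resamps <- resamples(list(cart_cp = cart1, cart_maxdepth = cart2))
summary(resamps)                       # per-metric distribution for each model
bwplot(resamps, metric = "Accuracy")   # side-by-side boxplots of resampled accuracy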

B) Model comparison on the Test subset

pred1<-predict(cart1,testset,type="prob")%>%cbind(testset,.)
pred1$Predicted=predict(cart1,testset)

confusionMatrix(pred1$Predicted,reference=pred1$Infarction,positive ="Occured",mode="everything")
## Confusion Matrix and Statistics
## 
##           Reference
## Prediction  No Occured
##    No       19      13
##    Occured  22     146
##                                          
##                Accuracy : 0.825          
##                  95% CI : (0.7651, 0.875)
##     No Information Rate : 0.795          
##     P-Value [Acc > NIR] : 0.1679         
##                                          
##                   Kappa : 0.4155         
##  Mcnemar's Test P-Value : 0.1763         
##                                          
##             Sensitivity : 0.9182         
##             Specificity : 0.4634         
##          Pos Pred Value : 0.8690         
##          Neg Pred Value : 0.5938         
##               Precision : 0.8690         
##                  Recall : 0.9182         
##                      F1 : 0.8930         
##              Prevalence : 0.7950         
##          Detection Rate : 0.7300         
##    Detection Prevalence : 0.8400         
##       Balanced Accuracy : 0.6908         
##                                          
##        'Positive' Class : Occured        
## 
pred2<-predict(cart2,testset,type="prob")%>%cbind(testset,.)
pred2$Predicted=predict(cart2,testset)

confusionMatrix(pred2$Predicted,reference=pred2$Infarction,positive ="Occured",mode="everything")
## Confusion Matrix and Statistics
## 
##           Reference
## Prediction  No Occured
##    No       21       8
##    Occured  20     151
##                                           
##                Accuracy : 0.86            
##                  95% CI : (0.8041, 0.9049)
##     No Information Rate : 0.795           
##     P-Value [Acc > NIR] : 0.01168         
##                                           
##                   Kappa : 0.5182          
##  Mcnemar's Test P-Value : 0.03764         
##                                           
##             Sensitivity : 0.9497          
##             Specificity : 0.5122          
##          Pos Pred Value : 0.8830          
##          Neg Pred Value : 0.7241          
##               Precision : 0.8830          
##                  Recall : 0.9497          
##                      F1 : 0.9152          
##              Prevalence : 0.7950          
##          Detection Rate : 0.7550          
##    Detection Prevalence : 0.8550          
##       Balanced Accuracy : 0.7309          
##                                           
##        'Positive' Class : Occured         
## 

Independent testing shows that the maximum tree depth tuning method results in clearly higher performance than cp-based model tuning (accuracy 0.86 vs 0.825 on the Test subset).

Step 3: Manual, user-controlled modelling with the mlr package

task=makeClassifTask(id="Infarction",data=trainset,target="Infarction",positive = "Occured")

learner = makeLearner("classif.rpart", predict.type = "prob")

First, we created a classification task and a basic CART learner in the mlr package.

Feature selection

Feature filtering is very important, especially when a dataset contains a large number of features. This simple data exploration step can considerably improve a model's performance and save a lot of time when training complex algorithms such as random forests, neural networks or boosting.

Figure 4: Contribution of each feature to the Classification

ig=generateFilterValuesData(task,method="information.gain")%>%.$data%>%ggplot(aes(x=reorder(name,information.gain),y=information.gain,fill=reorder(name,information.gain)))+geom_bar(stat="identity",color="black",show.legend=F)+scale_fill_manual(values=myfillcolors,name="information gain")+scale_x_discrete("Features")+coord_flip()

mr=generateFilterValuesData(task,method="mrmr")%>%.$data%>%ggplot(aes(x=reorder(name,mrmr),y=mrmr,fill=reorder(name,mrmr)))+geom_bar(stat="identity",color="black",show.legend=F)+scale_fill_manual(values=myfillcolors,name="mrmr")+scale_x_discrete("Features")+coord_flip()

pmi=generateFilterValuesData(task,method="permutation.importance",imp.learner=learner)%>%.$data%>%ggplot(aes(x=reorder(name,permutation.importance),y=permutation.importance,fill=reorder(name,permutation.importance)))+geom_bar(stat="identity",color="black",show.legend=F)+scale_fill_manual(values=myfillcolors,name="permut.Importance")+scale_x_discrete("Features")+coord_flip()

gt=generateFilterValuesData(task,method="gain.ratio")%>%.$data%>%ggplot(aes(x=reorder(name,gain.ratio),y=gain.ratio,fill=reorder(name,gain.ratio)))+geom_bar(stat="identity",color="black",show.legend=F)+scale_fill_manual(values=myfillcolors,name="Gain ratio")+scale_x_discrete("Features")+coord_flip()

grid.arrange(ig,mr,pmi,gt)

Those figures indicate that only Age, Cholesterol, Weight and Smoking status contribute substantially to the Infarction prognosis.

task4f=makeClassifTask(id="Infarction",data=trainset[,-5],target="Infarction",positive = "Occured")

tasktest=makeClassifTask(id="Infarction",data=testset[,-5],target="Infarction",positive = "Occured")

We adjusted the input data for our classification task: the number of features was reduced to four (Education was excluded from both the training and testing data).
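Alternatively, instead of dropping the Education column by hand, the reduced task could be built with mlr's filterFeatures(); a minimal sketch that keeps the four highest-scoring features by information gain:

task4f_alt <- filterFeatures(task, method = "information.gain", abs = 4)
getTaskFeatureNames(task4f_alt)   # should retain Age, Cholesterol, Smoke and Weight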

Manual Grid-based Tuning process

Then, we begin our Manual Grid-based Tuning process in MLR:

The manual tuning process in mlr consists of testing every possible combination of cp and maximum tree depth values over the tuning grid, evaluated with 10 times repeated 10-fold cross-validation, then selecting the best-performing setting. The optimal combination of cp and maxdepth will be chosen for our final model.

learner$par.set
##                    Type len  Def   Constr Req Tunable Trafo
## minsplit        integer   -   20 1 to Inf   -    TRUE     -
## minbucket       integer   -    - 1 to Inf   -    TRUE     -
## cp              numeric   - 0.01   0 to 1   -    TRUE     -
## maxcompete      integer   -    4 0 to Inf   -    TRUE     -
## maxsurrogate    integer   -    5 0 to Inf   -    TRUE     -
## usesurrogate   discrete   -    2    0,1,2   -    TRUE     -
## surrogatestyle discrete   -    0      0,1   -    TRUE     -
## maxdepth        integer   -   30  1 to 30   -    TRUE     -
## xval            integer   -   10 0 to Inf   -   FALSE     -
## parms           untyped   -    -        -   -    TRUE     -
ps=makeParamSet(makeDiscreteParam("maxdepth",values = c(1,2,3,4,5,6,7)),makeNumericParam("cp",lower=0.01,upper=0.1))

ctrlgrid = makeTuneControlGrid()

rdesc = makeResampleDesc("RepCV",reps = 10,folds=10)

set.seed(123)
res=tuneParams(learner, task=task4f,resampling=rdesc,par.set=ps,control=ctrlgrid,measures = list(mmce,bac))

mmce$minimize
## [1] TRUE
res$x
## $maxdepth
## [1] 5
## 
## $cp
## [1] 0.02

The grid-based tuning is now completed.

Figure 5: Grid-based tuning result

resdf=generateHyperParsEffectData(res)

resdata=resdf$data%>%as_tibble()

resdata%>%ggplot(aes(x=cp,y=maxdepth))+geom_point(aes(size=mmce.test.mean,fill=mmce.test.mean),alpha=0.6,shape=21)+geom_vline(xintercept=res$x$cp,color="red",size=0.7)+geom_hline(yintercept=res$x$maxdepth,color="red",size=0.7)+scale_fill_gradient(high="purple",low="#ff0033")

resdata%>%ggplot(aes(x=cp,y=maxdepth))+geom_point(aes(size=bac.test.mean,fill=bac.test.mean),alpha=0.6,shape=21)+geom_vline(xintercept=res$x$cp,color="blue",size=0.7)+geom_hline(yintercept=res$x$maxdepth,color="blue",size=0.7)+scale_fill_gradient(low="purple",high="#ff0033")

The figures show the tuning grid, with the crossing lines marking the optimal combination of cp and maxdepth for the best-performing model.
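The best few settings can also be read off numerically; a minimal sketch that ranks the grid results by mean misclassification error:

resdata %>%
  arrange(mmce.test.mean) %>%
  dplyr::select(cp, maxdepth, mmce.test.mean, bac.test.mean) %>%
  head(5)   # the five best cp/maxdepth combinations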

Training the decision tree with the optimized hyperparameters

learner2=setHyperPars(learner,par.vals = res$x)

cartmlr=mlr::train(learner2,task4f)
predmlr=predict(cartmlr,tasktest)

mets=list(auc,bac,tpr,tnr,mmce,ber,fpr,fnr)

performance(predmlr, measures =mets)
##        auc        bac        tpr        tnr       mmce        ber 
## 0.80602853 0.76123639 0.93710692 0.58536585 0.13500000 0.23876361 
##        fpr        fnr 
## 0.41463415 0.06289308
truth=predmlr$data$truth

confusionMatrix(predmlr$data$response, reference=predmlr$data$truth,positive ="Occured",mode="everything")
## Confusion Matrix and Statistics
## 
##           Reference
## Prediction  No Occured
##    No       24      10
##    Occured  17     149
##                                           
##                Accuracy : 0.865           
##                  95% CI : (0.8097, 0.9091)
##     No Information Rate : 0.795           
##     P-Value [Acc > NIR] : 0.006939        
##                                           
##                   Kappa : 0.5578          
##  Mcnemar's Test P-Value : 0.248213        
##                                           
##             Sensitivity : 0.9371          
##             Specificity : 0.5854          
##          Pos Pred Value : 0.8976          
##          Neg Pred Value : 0.7059          
##               Precision : 0.8976          
##                  Recall : 0.9371          
##                      F1 : 0.9169          
##              Prevalence : 0.7950          
##          Detection Rate : 0.7450          
##    Detection Prevalence : 0.8300          
##       Balanced Accuracy : 0.7612          
##                                           
##        'Positive' Class : Occured         
## 

As we can see, the decision tree based on the manual, user-controlled training process in mlr achieves higher performance on the Test subset than both models developed by caret (accuracy 0.865 vs 0.86 and 0.825).
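For a compact overview, a minimal sketch that collects the Test-subset accuracy of the three final models into one table (the values should reproduce the confusion-matrix outputs reported above):

# Test-subset accuracy of the two caret models and the mlr model
acc_cart1 <- confusionMatrix(pred1$Predicted, pred1$Infarction)$overall["Accuracy"]
acc_cart2 <- confusionMatrix(pred2$Predicted, pred2$Infarction)$overall["Accuracy"]
acc_mlr   <- performance(predmlr, measures = acc)

data.frame(model    = c("caret rpart (cp)", "caret rpart2 (maxdepth)", "mlr (cp + maxdepth)"),
           accuracy = round(c(acc_cart1, acc_cart2, acc_mlr), 3))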

Figure 6: The best decision tree for making prognosis of Infarction

fancyRpartPlot(cartmlr$learner.model,palettes="RdPu")

Conclusion: Our experiment demonstrated that manual, user-controlled model training that includes both feature selection and grid-based tuning can provide better model performance than the automated tuning procedures in caret, at least on this dataset.