Foreword: About the Machine Learning in Medicine (MLM) project
The MLM project was initiated in 2016 and aims to:
1. Encourage the use of Machine Learning techniques in medical research in Vietnam
2. Promote the use of the R statistical programming language, an open-source and leading tool for practicing data science.
Background
The 17th case study (Chapter 16, Machine Learning in Medicine - Cookbook 1 by T. J. Cleophas and A. H. Zwinderman, Springer, 2014) introduces tree-based models for developing clinical classification rules.
Decision trees are widely regarded as one of the most effective data mining algorithms for establishing clinical decision rules. Tree-based models offer many advantages, including compatibility with all data types, simplicity in modelling and, most of all, a straightforward solution for clinical decision making. Decision tree models are supported by various software packages, such as SAS, S-Plus, SPSS, Matlab, Weka and R. However, modelling performance can differ considerably among these packages.
Recently, several machine learning frameworks have been introduced in R, such as the caret and mlr packages. These packages make it easy to apply many prebuilt algorithms and offer useful features such as training, tuning and testing. In this study, we focus on the CART algorithm (Breiman et al. 1984), a popular method for growing a decision tree. In R, the CART algorithm is implemented in the rpart package. It should be noted that the train( ) function of the caret package performs tuning, resampling and training simultaneously, whilst the mlr package handles these processes independently.
Overfitting is a critical problem for CART modelling. There are two alternative solutions to overcome this problem: (1) setting limits on tree size and/or (2) pruning the tree. Although the CART algorithm is provided by the same package (rpart) in both caret and mlr, its tuning procedure differs greatly. In caret, a CART model can be tuned using two alternative methods:
The “rpart” method tunes the model on the complexity parameter (cp), i.e. tree pruning, while the “rpart2” method optimizes the maximum tree depth parameter (i.e. a constraint on tree size).
The tuning procedure in mlr offers higher flexibility, as the user can tune both cp and maxdepth values by screening every possible combination of these two parameters (grid-based or random tuning control). mlr also allows features to be filtered manually based on different criteria, such as information gain or gain ratio.
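The two remedies for overfitting described above can be illustrated directly with rpart. The following is a minimal sketch, not the study's code: it uses the kyphosis dataset bundled with rpart rather than the infarction data, and tree_depth is a hypothetical helper introduced here for illustration.

```r
library(rpart)  # recommended package, ships with standard R distributions

# Hypothetical helper: rpart numbers nodes so that node k sits at depth
# floor(log2(k)), with the root being node 1 at depth 0
tree_depth <- function(fit) {
  max(floor(log2(as.numeric(rownames(fit$frame)))))
}

# Deliberately overgrown tree: no complexity penalty, tiny minimum split size
full <- rpart(Kyphosis ~ Age + Number + Start, data = kyphosis,
              method = "class",
              control = rpart.control(cp = 0, minsplit = 2, xval = 0))

# (1) Pruning: cut the tree back using a complexity parameter of 0.05
pruned <- prune(full, cp = 0.05)

# (2) Size constraint: cap the depth while the tree is being grown
shallow <- rpart(Kyphosis ~ Age + Number + Start, data = kyphosis,
                 method = "class",
                 control = rpart.control(cp = 0, minsplit = 2,
                                         maxdepth = 2, xval = 0))

# Number of nodes in each tree, and the depth of the constrained one
c(full = nrow(full$frame), pruned = nrow(pruned$frame),
  shallow = nrow(shallow$frame))
tree_depth(shallow)
```

Both routes shrink the tree, but they act at different moments: pruning removes branches after growing, whereas maxdepth stops growth early. This is why tuning the two parameters jointly can reach combinations that neither single-parameter search explores.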
In this experiment, we evaluate the performance of 3 different approaches: automated training by caret using either the rpart or rpart2 method, versus user-controlled modelling in mlr. Our hypothesis is that manual tuning provides better performance.
Materials and method
The original dataset was provided in Chapter 16 of Machine Learning in Medicine - Cookbook 1 (T. J. Cleophas and A. H. Zwinderman, SpringerBriefs in Statistics, 2014). You can get the original data in SPSS SAV format (decisiontreebinary.sav) from their website: extras.springer.com.
Data analysis was performed in the R statistical programming language. The caret and mlr packages are used throughout our experiment.
The original dataset (n=1004) is randomly split into Train (80%) and Test (20%) subsets. The Train subset (n=804) is used consistently for model tuning and training in both caret and mlr. The performance of the final models is compared using the same Test subset (n=200).
First, the following packages must be loaded in RStudio (some other packages might also be required during the analysis), and we customize the aesthetics for ggplot2:
library(foreign)
library(tidyverse)
library(magrittr) # provides the %<>% assignment pipe used in Step 0
library(caret)
library(mlr)
library(gridExtra)
library(rattle)
library(rpart.plot)
my_theme <- function(base_size = 8, base_family = "sans"){
  theme_minimal(base_size = base_size, base_family = base_family) +
    theme(
      axis.text = element_text(size = 8),
      axis.text.x = element_text(angle = 0, vjust = 0.5, hjust = 0.5),
      axis.title = element_text(size = 8),
      panel.grid.major = element_line(color = "gray"),
      panel.grid.minor = element_blank(),
      panel.background = element_rect(fill = "#fef5f9"),
      strip.background = element_rect(fill = "#290029", color = "#290029", size = 0.5),
      strip.text = element_text(face = "bold", size = 8, color = "white"),
      legend.position = "bottom",
      legend.justification = "center",
      legend.background = element_blank(),
      panel.border = element_rect(color = "grey5", fill = NA, size = 0.5)
    )
}
theme_set(my_theme())
myfillcolors = c("#ff0033", "#330033", "#cc0033", "#660033", "#990033", "#630063")
Results
Step 0: Data loading and reshaping
require(foreign)
df=read.spss("decisiontreebinary.sav",use.value.labels = TRUE,to.data.frame = TRUE)%>%as_tibble()
df<-df%>%dplyr::rename(.,Infarction=infarct_rating,Age=age,Cholesterol=cholesterol_level,Smoke=smoking,Education=education,Weight=weight_level)
df$Smoke%<>%as.factor()%>%recode_factor(.,`0` = "No", `1` = "Medium",`2` = "High")
df$Infarction%<>%as.factor()%>%recode_factor(.,`no` = "No", `yes` = "Occured")
df
## # A tibble: 1,004 × 6
## Infarction Age Cholesterol Smoke Education Weight
## <fctr> <dbl> <fctr> <fctr> <fctr> <fctr>
## 1 No 44.85561 Low No High school high
## 2 No 42.71335 Medium No High school high
## 3 No 43.33660 High No College high
## 4 No 44.01998 High No High school high
## 5 No 67.96738 Low No College high
## 6 No 40.30740 Medium No College high
## 7 No 66.55518 Low No College high
## 8 No 45.95118 Low No College high
## 9 No 52.26541 Low No High school high
## 10 No 43.86261 Low No High school high
## # ... with 994 more rows
Data exploration
Figure 1: Descriptive analysis
smoke=ggplot(data=df,aes(x=Infarction,fill=Smoke))+geom_bar(position="fill",alpha=0.8,color="black")+scale_fill_manual(values=myfillcolors)+coord_flip()
chol=ggplot(data=df,aes(x=Infarction,fill=Cholesterol))+geom_bar(position="fill",alpha=0.8,color="black")+scale_fill_manual(values=myfillcolors)+coord_flip()
wt=ggplot(data=df,aes(x=Infarction,fill=Weight))+geom_bar(position="fill",alpha=0.8,color="black")+scale_fill_manual(values=myfillcolors)+coord_flip()
edu=ggplot(data=df,aes(x=Infarction,fill=Education))+geom_bar(position="fill",alpha=0.8,color="black")+scale_fill_manual(values=myfillcolors)+coord_flip()
a=df%>%ggplot(aes(x=Infarction,y=Age,fill=Infarction))+geom_boxplot(alpha=0.8)+scale_fill_manual(values=myfillcolors)+coord_flip()
grid.arrange(smoke,chol,wt,edu,a,ncol=2)
The data exploration shows a marked contrast in Age, Cholesterol and Smoking status between the two Infarction classes. The Education factor does not appear to be associated with the Infarction outcome.
Data splitting using CARET
set.seed(123)
idTrain=createDataPartition(y=df$Infarction,p=0.8,list=FALSE)
trainset=df[idTrain,]
testset=df[-idTrain,]
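The idea behind a stratified split, where the outcome classes keep roughly the same proportions in both subsets, can be sketched in base R. This is an illustration under stated assumptions, not caret's implementation of createDataPartition: the outcome vector is synthetic and stratified_split is a hypothetical helper.

```r
set.seed(123)

# Synthetic two-class outcome (made up for illustration, not the study data)
y <- factor(sample(c("No", "Occured"), size = 1000, replace = TRUE,
                   prob = c(0.2, 0.8)))

# Hypothetical helper: draw a fraction p of the row indices within each class
stratified_split <- function(y, p = 0.8) {
  idx_by_class <- split(seq_along(y), y)
  picked <- lapply(idx_by_class, function(idx) {
    idx[sample.int(length(idx), size = round(p * length(idx)))]
  })
  sort(unlist(picked, use.names = FALSE))
}

id_train <- stratified_split(y, p = 0.8)
train_y <- y[id_train]
test_y  <- y[-id_train]

# Class proportions in the full data, the train subset and the test subset
rbind(all   = prop.table(table(y)),
      train = prop.table(table(train_y)),
      test  = prop.table(table(test_y)))
```

Sampling within each class matters here because the two Infarction classes are unbalanced; a purely random split could leave the smaller class under-represented in a 200-case test subset.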
Step 1: Growing a Decision Tree in CARET, based on the CART algorithm and two tuning methods
Control=trainControl(method= "repeatedcv",number=10,repeats=10,classProbs=TRUE,summaryFunction =multiClassSummary)
As mentioned above, the CART algorithm can be tuned by two alternative methods in CARET. We develop 2 different versions: the cart1 model uses cp-based pruning, whilst the cart2 model tunes the maximum tree depth.
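The resampling scheme requested in the trainControl call (10-fold cross-validation repeated 10 times) can be sketched in base R to show how many model fits it implies. This is a simplified illustration: make_folds is a hypothetical helper, and unlike caret it does not stratify the folds by outcome class.

```r
set.seed(123)
n <- 804        # size of the Train subset
k <- 10         # folds per repeat
repeats <- 10   # number of repeats

# Hypothetical helper: one random partition of 1..n into k folds
make_folds <- function(n, k) {
  fold_id <- sample(rep(seq_len(k), length.out = n))
  split(seq_len(n), fold_id)
}

# Every repeat reshuffles the rows; every fold is held out exactly once
resamples <- unlist(lapply(seq_len(repeats), function(r) make_folds(n, k)),
                    recursive = FALSE)

length(resamples)               # resamples evaluated per candidate value
range(lengths(resamples))       # held-out fold sizes
range(n - lengths(resamples))   # sizes of the sets used for fitting
```

With tuneLength=10, each of up to 10 candidate values is fitted on all of these resamples before the final model is refitted on the whole Train subset.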
Figure 2: CP tuning based Decision Tree
cart1=caret::train(Infarction~.,data=trainset,method="rpart",trControl=Control,tuneLength=10)
cart1
## CART
##
## 804 samples
## 5 predictor
## 2 classes: 'No', 'Occured'
##
## No pre-processing
## Resampling: Cross-Validated (10 fold, repeated 10 times)
## Summary of sample sizes: 724, 724, 723, 723, 723, 725, ...
## Resampling results across tuning parameters:
##
## cp logLoss ROC Accuracy Kappa Sensitivity
## 0.00000000 0.4164073 0.8431908 0.8430763 0.4761837 0.5101838
## 0.01137885 0.3916414 0.7772318 0.8461766 0.4652982 0.4730515
## 0.02275770 0.4022487 0.7536177 0.8471579 0.4653680 0.4689338
## 0.03413655 0.4122506 0.7317413 0.8465235 0.4870632 0.5226838
## 0.04551539 0.4123961 0.7284855 0.8472611 0.4930641 0.5321691
## 0.05689424 0.4132571 0.7272709 0.8470111 0.4911052 0.5285294
## 0.06827309 0.4140702 0.7269932 0.8459984 0.4889806 0.5285294
## 0.07965194 0.4140702 0.7269932 0.8459984 0.4889806 0.5285294
## 0.09103079 0.4204099 0.7132698 0.8421310 0.4603267 0.5030515
## 0.10240964 0.4693742 0.6178621 0.8126815 0.2463664 0.2826471
## Specificity Pos_Pred_Value Neg_Pred_Value Detection_Rate
## 0.9297842 0.6663804 0.8801372 0.10523916
## 0.9434152 0.6985555 0.8739534 0.09754309
## 0.9456176 0.7164750 0.8736163 0.09676647
## 0.9308656 0.6702012 0.8832972 0.10784877
## 0.9292907 0.6723378 0.8851412 0.10983476
## 0.9299206 0.6740622 0.8844241 0.10908476
## 0.9286508 0.6704020 0.8842582 0.10908476
## 0.9286508 0.6704020 0.8842582 0.10908476
## 0.9303795 0.6622607 0.8798929 0.10384375
## 0.9506052 0.6053396 0.8401661 0.05833533
## Balanced_Accuracy
## 0.7199840
## 0.7082333
## 0.7072757
## 0.7267747
## 0.7307299
## 0.7292250
## 0.7285901
## 0.7285901
## 0.7167155
## 0.6166261
##
## Accuracy was used to select the optimal model using the largest value.
## The final value used for the model was cp = 0.04551539.
library(partykit)
library(party)
fancyRpartPlot(cart1$finalModel,palettes="RdPu")
The tree pruning result suggests that the optimized model would contain 3 decision nodes: Age, High cholesterol and Medium Smoking.
Figure 3: Max tree depth tuning based Decision Tree
cart2=caret::train(Infarction~.,data=trainset,method = "rpart2",trControl=Control,tuneLength=10)
## note: only 6 possible values of the max tree depth from the initial fit.
## Truncating the grid to 6 .
cart2
## CART
##
## 804 samples
## 5 predictor
## 2 classes: 'No', 'Occured'
##
## No pre-processing
## Resampling: Cross-Validated (10 fold, repeated 10 times)
## Summary of sample sizes: 723, 723, 724, 724, 723, 723, ...
## Resampling results across tuning parameters:
##
## maxdepth logLoss ROC Accuracy Kappa Sensitivity
## 3 0.4154876 0.7227306 0.8461203 0.4787661 0.5050735
## 5 0.3892899 0.7668103 0.8539554 0.4712135 0.4489338
## 6 0.3951423 0.7685720 0.8514614 0.4729551 0.4653309
## 14 0.4050255 0.7931376 0.8483486 0.4804320 0.4970588
## 15 0.4050255 0.7931376 0.8483486 0.4804320 0.4970588
## 21 0.4050255 0.7931376 0.8483486 0.4804320 0.4970588
## Specificity Pos_Pred_Value Neg_Pred_Value Detection_Rate
## 0.9349281 0.6753042 0.8799109 0.10422236
## 0.9592684 0.7514931 0.8710764 0.09272201
## 0.9519023 0.7213028 0.8735411 0.09608787
## 0.9396652 0.6860649 0.8787761 0.10267606
## 0.9396652 0.6860649 0.8787761 0.10267606
## 0.9396652 0.6860649 0.8787761 0.10267606
## Balanced_Accuracy
## 0.7200008
## 0.7041011
## 0.7086166
## 0.7183620
## 0.7183620
## 0.7183620
##
## Accuracy was used to select the optimal model using the largest value.
## The final value used for the model was maxdepth = 5.
fancyRpartPlot(cart2$finalModel,palettes="RdPu")
Based on the max tree depth tuning output, the optimized decision tree model would include up to 13 decision nodes as presented above.
A) Averaged confusion matrices
confusionMatrix(cart1,positive ="Occured")
## Cross-Validated (10 fold, repeated 10 times) Confusion Matrix
##
## (entries are percentual average cell counts across resamples)
##
## Reference
## Prediction No Occured
## No 11.0 5.6
## Occured 9.7 73.7
##
## Accuracy (average) : 0.8473
confusionMatrix(cart2,positive ="Occured")
## Cross-Validated (10 fold, repeated 10 times) Confusion Matrix
##
## (entries are percentual average cell counts across resamples)
##
## Reference
## Prediction No Occured
## No 9.3 3.2
## Occured 11.4 76.1
##
## Accuracy (average) : 0.854
The performance of these 2 models can be compared using the averaged confusion matrices. Results show that the max tree depth tuning provides a slightly better accuracy than the CP-based pruning (0.854 vs 0.847).
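The average accuracy printed under each matrix is simply the sum of the diagonal cells divided by the total. A short sketch, with the cell percentages copied from the two outputs above and avg_accuracy as a hypothetical helper:

```r
# Cell percentages copied from the averaged confusion matrices above
# (matrix() fills column by column: Reference "No" first, then "Occured")
labels <- list(Prediction = c("No", "Occured"), Reference = c("No", "Occured"))
cm1 <- matrix(c(11.0, 9.7, 5.6, 73.7), nrow = 2, dimnames = labels)  # rpart
cm2 <- matrix(c(9.3, 11.4, 3.2, 76.1), nrow = 2, dimnames = labels)  # rpart2

# Hypothetical helper: share of resampled cases on the diagonal
avg_accuracy <- function(cm) sum(diag(cm)) / sum(cm)

acc1 <- avg_accuracy(cm1)
acc2 <- avg_accuracy(cm2)
round(c(rpart = acc1, rpart2 = acc2), 3)
```

Because the printed cells are rounded to one decimal, the values recomputed this way can differ from the reported averages in the fourth decimal place.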
B) Models comparison based on Test subset
pred1<-predict(cart1,testset,type="prob")%>%cbind(testset,.)
pred1$Predicted=predict(cart1,testset)
confusionMatrix(pred1$Predicted,reference=pred1$Infarction,positive ="Occured",mode="everything")
## Confusion Matrix and Statistics
##
## Reference
## Prediction No Occured
## No 19 13
## Occured 22 146
##
## Accuracy : 0.825
## 95% CI : (0.7651, 0.875)
## No Information Rate : 0.795
## P-Value [Acc > NIR] : 0.1679
##
## Kappa : 0.4155
## Mcnemar's Test P-Value : 0.1763
##
## Sensitivity : 0.9182
## Specificity : 0.4634
## Pos Pred Value : 0.8690
## Neg Pred Value : 0.5938
## Precision : 0.8690
## Recall : 0.9182
## F1 : 0.8930
## Prevalence : 0.7950
## Detection Rate : 0.7300
## Detection Prevalence : 0.8400
## Balanced Accuracy : 0.6908
##
## 'Positive' Class : Occured
##
pred2<-predict(cart2,testset,type="prob")%>%cbind(testset,.)
pred2$Predicted=predict(cart2,testset)
confusionMatrix(pred2$Predicted,reference=pred2$Infarction,positive ="Occured",mode="everything")
## Confusion Matrix and Statistics
##
## Reference
## Prediction No Occured
## No 21 8
## Occured 20 151
##
## Accuracy : 0.86
## 95% CI : (0.8041, 0.9049)
## No Information Rate : 0.795
## P-Value [Acc > NIR] : 0.01168
##
## Kappa : 0.5182
## Mcnemar's Test P-Value : 0.03764
##
## Sensitivity : 0.9497
## Specificity : 0.5122
## Pos Pred Value : 0.8830
## Neg Pred Value : 0.7241
## Precision : 0.8830
## Recall : 0.9497
## F1 : 0.9152
## Prevalence : 0.7950
## Detection Rate : 0.7550
## Detection Prevalence : 0.8550
## Balanced Accuracy : 0.7309
##
## 'Positive' Class : Occured
##
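Independent testing results show that the max tree depth tuning method yields higher performance than CP-based model tuning on the Test subset (accuracy 0.860 vs 0.825, balanced accuracy 0.731 vs 0.691, kappa 0.518 vs 0.416), although no formal test of this difference was performed.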
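The statistics reported by confusionMatrix can be recomputed by hand from the four cells of each test table, which makes clear where the two models differ. In this sketch cm_metrics is a hypothetical helper; the counts are copied from the two outputs above, with "Occured" as the positive class.

```r
# Hypothetical helper: metrics from the four cells of a 2x2 table
# tp/fn = reference "Occured" predicted as "Occured"/"No"
# fp/tn = reference "No" predicted as "Occured"/"No"
cm_metrics <- function(tp, fn, fp, tn) {
  n    <- tp + fn + fp + tn
  acc  <- (tp + tn) / n
  sens <- tp / (tp + fn)
  spec <- tn / (tn + fp)
  # Agreement expected by chance, used for Cohen's kappa
  pe <- ((tp + fp) * (tp + fn) + (tn + fn) * (tn + fp)) / n^2
  c(accuracy = acc, sensitivity = sens, specificity = spec,
    balanced_accuracy = (sens + spec) / 2,
    kappa = (acc - pe) / (1 - pe))
}

m1 <- cm_metrics(tp = 146, fn = 13, fp = 22, tn = 19)  # cart1 (rpart)
m2 <- cm_metrics(tp = 151, fn = 8,  fp = 20, tn = 21)  # cart2 (rpart2)
round(rbind(rpart = m1, rpart2 = m2), 4)
```

Both models detect most "Occured" cases; the weaker point is specificity, since fewer than half to just over half of the 41 "No" cases in the Test subset are recognised.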
Step 2: Manual, User-controlled Modelling in MLR package
task=makeClassifTask(id="Infarction",data=trainset,target="Infarction",positive = "Occured")
learner = makeLearner("classif.rpart", predict.type = "prob")
First, we created a Classification Task and a basic CART learner in the MLR package.
Feature selection
Feature filtering is very important, especially when you have a dataset with a large number of features. This simple data exploration step can substantially improve your model’s performance and save a lot of time when training complex algorithms such as random forests, neural networks or boosting.
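To make one of the filter criteria concrete, information gain can be computed by hand as the entropy of the outcome minus its entropy once the feature is known. The sketch below is a base R illustration on made-up toy vectors; it is not the routine that mlr calls, which may discretize numeric features and use a different logarithm base. Both entropy and info_gain are hypothetical helpers.

```r
# Shannon entropy (in bits) of a categorical vector
entropy <- function(x) {
  p <- prop.table(table(x))
  p <- p[p > 0]
  -sum(p * log2(p))
}

# Information gain of feature x about outcome y: H(y) - H(y | x)
info_gain <- function(x, y) {
  w <- prop.table(table(x))
  h_cond <- sum(sapply(names(w), function(lv) w[[lv]] * entropy(y[x == lv])))
  entropy(y) - h_cond
}

# Toy data (made up): x_good determines y exactly, x_noise is unrelated to y
y       <- c("No", "No", "Occured", "Occured")
x_good  <- c("low", "low", "high", "high")
x_noise <- c("a", "b", "a", "b")

ig_good  <- info_gain(x_good, y)
ig_noise <- info_gain(x_noise, y)
c(good = ig_good, noise = ig_noise)
```

A feature that separates the classes perfectly recovers the full entropy of the outcome, while an unrelated feature gains nothing; real features fall between these two extremes, which is what the bar charts rank.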
Figure 4: Contribution of each feature to the Classification
ig=generateFilterValuesData(task,method="information.gain")%>%.$data%>%ggplot(aes(x=reorder(name,information.gain),y=information.gain,fill=reorder(name,information.gain)))+geom_bar(stat="identity",color="black",show.legend=F)+scale_fill_manual(values=myfillcolors,name="information gain")+scale_x_discrete("Features")+coord_flip()
mr=generateFilterValuesData(task,method="mrmr")%>%.$data%>%ggplot(aes(x=reorder(name,mrmr),y=mrmr,fill=reorder(name,mrmr)))+geom_bar(stat="identity",color="black",show.legend=F)+scale_fill_manual(values=myfillcolors,name="mrmr")+scale_x_discrete("Features")+coord_flip()
pmi=generateFilterValuesData(task,method="permutation.importance",imp.learner=learner)%>%.$data%>%ggplot(aes(x=reorder(name,permutation.importance),y=permutation.importance,fill=reorder(name,permutation.importance)))+geom_bar(stat="identity",color="black",show.legend=F)+scale_fill_manual(values=myfillcolors,name="permut.Importance")+scale_x_discrete("Features")+coord_flip()
gt=generateFilterValuesData(task,method="gain.ratio")%>%.$data%>%ggplot(aes(x=reorder(name,gain.ratio),y=gain.ratio,fill=reorder(name,gain.ratio)))+geom_bar(stat="identity",color="black",show.legend=F)+scale_fill_manual(values=myfillcolors,name="Gain ratio")+scale_x_discrete("Features")+coord_flip()
grid.arrange(ig,mr,pmi,gt)
These figures indicate that only Age, Cholesterol, Weight level and Smoking status contribute substantially to the Infarction prognosis.
task4f=makeClassifTask(id="Infarction",data=trainset[,-5],target="Infarction",positive = "Occured")
tasktest=makeClassifTask(id="Infarction",data=testset[,-5],target="Infarction",positive = "Occured")
We adjusted the input data for our classification task: the number of features was reduced to 4 (Education was excluded from both training and testing data).
Manual Grid-based Tuning process
Then, we begin our Manual Grid-based Tuning process in MLR:
The manual tuning process in MLR consists of testing every possible combination of CP and Max Tree Depth values on a tuning grid (7 maxdepth values by 10 cp values, assuming mlr's default grid resolution of 10 for the numeric cp parameter), evaluating each combination with 10-fold cross-validation repeated 10 times, then selecting the model with the best performance. The optimal combination of CP and Max Tree Depth is then chosen for our algorithm.
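The size of this search can be worked out by enumerating the grid in base R. The sketch assumes that the numeric cp range from 0.01 to 0.1 is cut into 10 equally spaced values, which is our reading of the default grid resolution in mlr and should be checked against the installed version.

```r
# Candidate values: maxdepth 1 to 7, cp from 0.01 to 0.10 in 10 steps (assumed)
grid <- expand.grid(maxdepth = 1:7,
                    cp = seq(0.01, 0.10, length.out = 10))

n_candidates <- nrow(grid)            # combinations to evaluate
n_fits <- n_candidates * 10 * 10      # times 10 folds, times 10 repeats

head(grid)
c(candidates = n_candidates, fits = n_fits)
```

Several thousand tree fits are cheap for rpart on about 800 rows, but the same design would become costly for slower learners, which is one reason random tuning control exists as an alternative.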
learner$par.set
## Type len Def Constr Req Tunable Trafo
## minsplit integer - 20 1 to Inf - TRUE -
## minbucket integer - - 1 to Inf - TRUE -
## cp numeric - 0.01 0 to 1 - TRUE -
## maxcompete integer - 4 0 to Inf - TRUE -
## maxsurrogate integer - 5 0 to Inf - TRUE -
## usesurrogate discrete - 2 0,1,2 - TRUE -
## surrogatestyle discrete - 0 0,1 - TRUE -
## maxdepth integer - 30 1 to 30 - TRUE -
## xval integer - 10 0 to Inf - FALSE -
## parms untyped - - - - TRUE -
ps=makeParamSet(makeDiscreteParam("maxdepth",values = c(1,2,3,4,5,6,7)),makeNumericParam("cp",lower=0.01,upper=0.1))
ctrlgrid = makeTuneControlGrid()
rdesc = makeResampleDesc("RepCV",reps = 10,folds=10)
set.seed(123)
res=tuneParams(learner, task=task4f,resampling=rdesc,par.set=ps,control=ctrlgrid,measures = list(mmce,bac))
mmce$minimize
## [1] TRUE
res$x
## $maxdepth
## [1] 5
##
## $cp
## [1] 0.02
The grid-based tuning is now completed.
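For readers who want to see what a grid search does internally, the following sketch performs one by hand with rpart. It is a simplified illustration, not mlr's implementation: it uses the kyphosis dataset bundled with rpart, a smaller grid, a single round of 5-fold cross-validation, and a hypothetical helper cv_mmce.

```r
library(rpart)
set.seed(123)

dat <- kyphosis
k <- 5
# Random fold label for every row
fold_id <- sample(rep(seq_len(k), length.out = nrow(dat)))

# Candidate grid (smaller than the one in the study, to keep the run short)
grid <- expand.grid(maxdepth = 1:3, cp = c(0.01, 0.05, 0.10))

# Hypothetical helper: mean misclassification error over the k held-out folds
cv_mmce <- function(maxdepth, cp) {
  errs <- sapply(seq_len(k), function(f) {
    fit <- rpart(Kyphosis ~ ., data = dat[fold_id != f, ], method = "class",
                 control = rpart.control(maxdepth = maxdepth, cp = cp, xval = 0))
    pred <- predict(fit, newdata = dat[fold_id == f, ], type = "class")
    mean(pred != dat$Kyphosis[fold_id == f])
  })
  mean(errs)
}

# Evaluate every combination, then keep the one with the lowest error
grid$mmce <- mapply(cv_mmce, grid$maxdepth, grid$cp)
best <- grid[which.min(grid$mmce), ]
best
```

When several combinations tie on the error, which.min keeps the first one in grid order, so the selected pair is not necessarily the only good choice.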
Figure 5: Grid-based tuning result
resdf=generateHyperParsEffectData(res)
resdata=resdf$data%>%as_tibble()
resdata%>%ggplot(aes(x=cp,y=maxdepth))+geom_point(aes(size=mmce.test.mean,fill=mmce.test.mean),alpha=0.6,shape=21)+geom_vline(xintercept=res$x$cp,color="red",size=0.7)+geom_hline(yintercept=res$x$maxdepth,color="red",size=0.7)+scale_fill_gradient(high="purple",low="#ff0033")
resdata%>%ggplot(aes(x=cp,y=maxdepth))+geom_point(aes(size=bac.test.mean,fill=bac.test.mean),alpha=0.6,shape=21)+geom_vline(xintercept=res$x$cp,color="blue",size=0.7)+geom_hline(yintercept=res$x$maxdepth,color="blue",size=0.7)+scale_fill_gradient(low="purple",high="#ff0033")
The figures represent the optimal combination of CP and maxtreedepth parameters for the best performing model.
Training the Decision Tree using Optimized algorithm
learner2=setHyperPars(learner,par.vals = res$x)
cartmlr=mlr::train(learner2,task4f)
predmlr=predict(cartmlr,tasktest)
mets=list(auc,bac,tpr,tnr,mmce,ber,fpr,fnr)
performance(predmlr, measures =mets)
## auc bac tpr tnr mmce ber
## 0.80602853 0.76123639 0.93710692 0.58536585 0.13500000 0.23876361
## fpr fnr
## 0.41463415 0.06289308
truth=predmlr$data$truth
confusionMatrix(predmlr$data$response, reference=predmlr$data$truth,positive ="Occured",mode="everything")
## Confusion Matrix and Statistics
##
## Reference
## Prediction No Occured
## No 24 10
## Occured 17 149
##
## Accuracy : 0.865
## 95% CI : (0.8097, 0.9091)
## No Information Rate : 0.795
## P-Value [Acc > NIR] : 0.006939
##
## Kappa : 0.5578
## Mcnemar's Test P-Value : 0.248213
##
## Sensitivity : 0.9371
## Specificity : 0.5854
## Pos Pred Value : 0.8976
## Neg Pred Value : 0.7059
## Precision : 0.8976
## Recall : 0.9371
## F1 : 0.9169
## Prevalence : 0.7950
## Detection Rate : 0.7450
## Detection Prevalence : 0.8300
## Balanced Accuracy : 0.7612
##
## 'Positive' Class : Occured
##
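As we can see, the decision tree built through the manual, user-controlled training process in MLR performs best on the Test subset, with an accuracy of 0.865 and a balanced accuracy of 0.761, compared with 0.860 and 0.731 for rpart2 and 0.825 and 0.691 for rpart in CARET. The gain comes mainly from a higher specificity (0.585 vs 0.512 and 0.463); the accuracy margin over rpart2 amounts to a single case out of 200.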
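To put the three test accuracies side by side with their uncertainty, the exact binomial confidence intervals can be recomputed with binom.test from base R. The counts of correct predictions are read from the diagonals of the three test confusion matrices above; this is only a rough comparison, since all models are scored on the same 200 cases and a paired analysis would be more appropriate.

```r
# Correct test-set predictions out of 200, read from the confusion matrices
correct <- c(caret_rpart  = 19 + 146,
             caret_rpart2 = 21 + 151,
             mlr_grid     = 24 + 149)
n_test <- 200

# Accuracy with its exact (Clopper-Pearson) 95% confidence interval
ci <- t(sapply(correct, function(x) {
  bt <- binom.test(x, n_test)
  c(accuracy = x / n_test, lower = bt$conf.int[1], upper = bt$conf.int[2])
}))
round(ci, 4)
```

The three intervals overlap widely, so the ranking observed on this Test subset should be read as indicative rather than conclusive.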
Figure 6: The best decision tree for making prognosis of Infarction
fancyRpartPlot(cartmlr$learner.model,palettes="RdPu")
Conclusion: Our experiment shows that a manual, user-controlled model training that includes both feature selection and grid-based tuning can provide better model performance than the two automated caret procedures tested here, although its advantage over rpart2 on the 200-case Test subset is small.