Case study N°38: Optimising Clinical Decision Making

Foreword: About the Machine Learning in Medicine (MLM) project

The MLM project was initiated in 2016 and aims to:

1. Encourage the use of Machine Learning techniques in medical research in Vietnam

2. Promote the use of the R statistical programming language, an open-source and leading tool for practicing data science.

Background

The 38th case study (Chapter 15 of Machine Learning in Medicine - Cookbook Two by T. J. Cleophas and A. H. Zwinderman, Springer, 2014) focuses on the question: how can machine learning be used to optimise clinical decision making?

In this data-mining experiment, 90 patients with sepsis were treated with three different treatments. Various clinical data were measured as treatment outcomes, including aspartate aminotransferase (ASAT), alanine aminotransferase (ALAT), creatinine, urea, C-reactive protein and leucocyte count. Low blood pressure (3 levels) was recorded as a side effect of treatment.

Unlike our previous case studies on machine learning, this time the goal is no longer to predict what will happen to the patient, but to improve our decision making with the help of computerized algorithms. We will therefore teach the machine to build an optimized treatment flowchart based on historical clinical data.

As decision making is involved, the most appropriate method is a Decision Tree. This study explores the ability of the CART algorithm to classify patients into 3 different treatment modalities as a function of their outcome data.
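To make the idea of a CART classifier concrete before turning to the real data, here is a minimal sketch using the rpart package. The data frame `toy`, its columns `marker_a` and `marker_b`, and all simulated values are made up for illustration; they are not the sepsis data.

```r
library(rpart)

set.seed(1)

# Synthetic, made-up data: two numeric markers and a three-level class label
toy <- data.frame(
  marker_a = c(rnorm(50, 20, 5),   rnorm(50, 200, 40), rnorm(50, 200, 40)),
  marker_b = c(rnorm(50, 100, 20), rnorm(50, 100, 20), rnorm(50, 600, 80))
)
toy$class <- factor(rep(c("Mode_1", "Mode_2", "Mode_3"), each = 50))

# Fit a shallow classification tree (at most two levels of splits)
fit <- rpart(class ~ marker_a + marker_b, data = toy, method = "class",
             control = rpart.control(maxdepth = 2))

# Printing the tree lists the split rules, i.e. the "flowchart"
print(fit)

# Predicted class for each row, and the resulting training accuracy
pred <- predict(fit, newdata = toy, type = "class")
train_accuracy <- mean(pred == toy$class)
```

Each internal node of the printed tree is a yes/no question on one marker, which is what makes this family of models easy to read as a clinical flowchart.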

Materials and method

The original dataset was provided in Chapter 15 of Machine Learning in Medicine - Cookbook Two (T. J. Cleophas and A. H. Zwinderman, SpringerBriefs in Statistics, 2014). You can get the original data in SPSS SAV format (“chap15spssmodeler.sav”) from their website: extras.springer.com.

Data analysis was performed in the R statistical programming language. The mlr package (https://mlr-org.github.io/mlr-tutorial/release/html/index.html) and the caret package (http://topepo.github.io/caret) will be used. Our experiment consists of training a multiclass classification model based on the CART algorithm, then validating this model by bootstrap resampling.

First, the following packages must be loaded in RStudio (some other packages might also be required during the analysis), and we will customize the aesthetics of ggplot2.
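As a small aside on what bootstrap resampling does, the sketch below uses base R only. It draws bootstrap samples of size 90 (the number of patients) and measures the fraction of observations left out of each sample; these left-out ("out-of-bag") observations are what a bootstrap validation can use as a test set. The seed and variable names are arbitrary choices for this illustration.

```r
set.seed(2016)

n    <- 90    # sample size, as in the sepsis data
reps <- 500   # number of bootstrap replicates

# For each replicate: draw n row indices with replacement and
# record the fraction of rows that were never drawn (out-of-bag)
oob_fraction <- replicate(reps, {
  in_bag <- sample(n, replace = TRUE)
  1 - length(unique(in_bag)) / n
})

mean_oob    <- mean(oob_fraction)
theoretical <- (1 - 1 / n)^n   # probability that a given row is never drawn

print(c(simulated = mean_oob, theoretical = theoretical))
```

Every bootstrap sample has the same size as the original data, but on average roughly a third of the patients are absent from it and remain available for testing.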

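library(plyr)       # loaded before tidyverse so that dplyr verbs are not masked
library(tidyverse)  # ggplot2, dplyr, tidyr, tibble
library(foreign)    # read.spss()
library(mlr)
library(partykit)   # as.party()
library(party)
library(caret)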

my_theme <- function(base_size = 10, base_family = "sans"){
  theme_minimal(base_size = base_size, base_family = base_family) +
    theme(
      axis.text = element_text(size = 10),
      axis.text.x = element_text(angle = 0, vjust = 0.5, hjust = 0.5),
      axis.title = element_text(size = 11),
      panel.grid.major = element_line(color = "grey"),
      panel.grid.minor = element_blank(),
      panel.background = element_rect(fill = "#f4fff9"),
      strip.background = element_rect(fill = "#006642", color = "#006642", size =0.5),
      strip.text = element_text(face = "bold", size = 10, color = "white"),
      legend.position = "bottom",
      legend.justification = "center",
      legend.background = element_blank(),
      panel.border = element_rect(color = "grey30", fill = NA, size = 0.5)
    )
}
theme_set(my_theme())

myfillcolors=c("#ff003b","#00aeff","#00ff83","#ffae00","#c300ff","#ff0000","blue")
mycolors=c("#aa0027","#016391","#008937","#ff9000","#7201aa","red4","#2600ff")

Results

Step 0: Data loading and reshape

require(foreign)

df=read.spss("chap15spssmodeler.sav",use.value.labels = TRUE,to.data.frame = TRUE)%>%as_tibble()

df$treatment=recode_factor(df$treatment,`1` = "Mode_1", `2` = "Mode_2",`3` = "Mode_3")
df$death=recode_factor(df$death,`no` = "Survive", `yes` = "Dead")

df
## # A tibble: 90 × 9
##     asat  alat ureum creatinine creactiveprotein leucos treatment
##    <dbl> <dbl> <dbl>      <dbl>            <dbl>  <dbl>    <fctr>
## 1      5    29   2.4         79               18     16    Mode_1
## 2     10    30   2.1         94               15     15    Mode_1
## 3      8    31   2.3         79               16     14    Mode_1
## 4      6    16   2.7         80               17     19    Mode_1
## 5      6    16   2.2         84               18     20    Mode_1
## 6      5    13   2.1         78               17     21    Mode_1
## 7     10    16   3.1         85               20     18    Mode_1
## 8      8    28   8.0         68               15     18    Mode_1
## 9      7    27   7.8         74               16     17    Mode_1
## 10     6    26   8.4         69               18     16    Mode_1
## # ... with 80 more rows, and 2 more variables: lowbloodpressure <fctr>,
## #   death <fctr>

Once the above commands have been executed, we get a clean data frame that will be used throughout our analysis.

Step 1: Data exploration

Table 1: Patients’ outcomes in the 3 treatment modalities

psych::describeBy(df[,-c(7:9)],group=df$treatment)
## $Mode_1
##                  vars  n   mean     sd median trimmed   mad min   max
## asat                1 35  21.51  62.99      9   10.69  4.45   5 382.0
## alat                2 35  32.06  58.52     22   22.66  8.90  11 366.0
## ureum               3 35   8.64  10.29      5    6.28  3.85   2  41.8
## creatinine          4 35 125.14 149.06     79   87.62 10.38  59 765.0
## creactiveprotein    5 35  25.60  26.01     17   18.41  1.48  14 111.0
## leucos              6 35  20.00   6.74     18   18.66  2.97  14  41.0
##                  range skew kurtosis    se
## asat             377.0 5.35    27.65 10.65
## alat             355.0 5.30    27.30  9.89
## ureum             39.8 2.21     3.92  1.74
## creatinine       706.0 3.09     9.01 25.19
## creactiveprotein  97.0 2.60     5.25  4.40
## leucos            27.0 2.13     3.53  1.14
## 
## $Mode_2
##                  vars  n   mean     sd median trimmed    mad min max range
## asat                1 36 270.67 168.86 256.00  253.43 115.64   8 754   746
## alat                2 36 277.47 215.93 188.00  244.17  94.89  20 879   859
## ureum               3 36  16.48  10.25  12.75   14.99   4.45   4  43    39
## creatinine          4 36 232.22 132.24 189.00  213.27  61.53  65 639   574
## creactiveprotein    5 36  45.53  25.42  38.00   41.60   3.71  14 131   117
## leucos              6 36  28.25   4.64  28.00   28.33   1.48  17  42    25
##                  skew kurtosis    se
## asat             1.14     1.37 28.14
## alat             1.56     1.86 35.99
## ureum            1.41     0.90  1.71
## creatinine       1.49     1.88 22.04
## creactiveprotein 1.94     3.26  4.24
## leucos           0.15     1.69  0.77
## 
## $Mode_3
##                  vars  n    mean     sd median trimmed    mad min  max
## asat                1 19 1156.53 623.29  872.0 1174.41 581.18   9 2000
## alat                2 19  745.47 239.59  847.0  773.88 134.92  32  976
## ureum               3 19   49.07  16.91   41.8   48.79  10.38  20   83
## creatinine          4 19  621.53 122.79  600.0  623.94 105.26 341  861
## creactiveprotein    5 19   63.95  45.30   58.0   62.88  62.27  15  131
## leucos              6 19   36.68   3.06   36.0   36.76   2.97  30   42
##                  range  skew kurtosis     se
## asat              1991  0.10    -1.45 142.99
## alat               944 -1.47     1.69  54.96
## ureum               63  0.42    -1.01   3.88
## creatinine         520 -0.10    -0.22  28.17
## creactiveprotein   116  0.07    -1.82  10.39
## leucos              12  0.16    -0.23   0.70
## 
## attr(,"call")
## by.data.frame(data = x, INDICES = group, FUN = describe, type = type)

Figure 1: Relationship between the outcome markers and treatment classes

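# Lower panels: scatter plot with 2D density polygons, coloured by treatment.
# The `treatment` column is taken from the data passed by ggpairs (not from df),
# and after_stat(level) replaces the deprecated ..level.. (ggplot2 >= 3.3.0).
plotfuncLow <- function(data,mapping){
  p <- ggplot(data = data,mapping=mapping)+geom_point(aes(fill=treatment),shape=21,color="black")+stat_density2d(geom="polygon",aes(fill=treatment,alpha = after_stat(level)))+scale_fill_manual(values=myfillcolors)
  p
}

# Diagonal panels: density curve of each marker by treatment
plotfuncmid <- function(data,mapping){
  p <- ggplot(data = data,mapping=mapping)+geom_density(aes(fill=treatment),alpha=0.5,color="black")+scale_fill_manual(values=myfillcolors)
  p
}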

library(GGally)

datalog<-df%>%.[,c(1:6)]%>%log(.)%>%mutate(.,treatment=df$treatment)

ggpairs(datalog,columns=1:6,lower=list(continuous=plotfuncLow),upper=NULL,diag=list(continuous=plotfuncmid))

Figure 1 provides some useful information about the relationships between the biomarkers (on a logarithmic scale), as well as their distributions across the 3 treatment modes. These graphs indicate a clear contrast between the treatment classes, and any of those features could be used as a marker for evaluating treatment outcome.

Figure 2A-B: The pattern of treatment outcomes as a function of type of treatment and mortality
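The reason for plotting the markers on a logarithmic scale can be illustrated with a short base R sketch: strongly right-skewed values (as suggested by the skew column of Table 1) become far more symmetric after a log transform. The simulated variable `x` and the helper `skewness()` are assumptions made for this illustration and are not part of the original analysis.

```r
set.seed(42)

# Made-up right-skewed values standing in for a biomarker
x <- rlnorm(1000, meanlog = 3, sdlog = 1)

# Simple moment-based skewness (hypothetical helper, base R only)
skewness <- function(v) mean((v - mean(v))^3) / sd(v)^3

skew_raw <- skewness(x)        # strongly positive for raw values
skew_log <- skewness(log(x))   # close to zero after the log transform

print(c(raw = skew_raw, log = skew_log))
```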

dfscale<-df[,c(1:6)]%>%mutate(.,Bloodpressure=as.numeric(df$lowbloodpressure))%>%as.matrix()%>%scale()%>%as_tibble()%>%mutate(.,Treatment=df$treatment,Outcome=df$death,Id=row.names(.))

dfscale%>%gather(asat:Bloodpressure,key="Marker",value="Value")%>%ggplot(aes(x=reorder(Id,Value),y=reorder(Marker,Value),fill=Value))+geom_tile(color="black",show.legend=T)+facet_wrap(~Treatment,shrink=T,scale="free",ncol=1)+scale_fill_gradient2(low="#dafc00",mid="#fcc500",high="#fc0060",midpoint=1)+theme(axis.text.x=element_blank())+scale_y_discrete("Markers")+scale_x_discrete("Patient's Id")

dfscale%>%gather(asat:Bloodpressure,key="Marker",value="Value")%>%ggplot(aes(x=reorder(Id,Value),y=reorder(Marker,Value),fill=Value))+geom_tile(color="black",show.legend=T)+facet_wrap(~Outcome,shrink=T,scale="free",ncol=1)+scale_fill_gradient2(low="#dafc00",mid="#fcc500",high="#fc0060",midpoint=0.5)+theme(axis.text.x=element_blank())+scale_y_discrete("Markers")+scale_x_discrete("Patient's Id")

Figure 2 indicates that modes 2 and 3 seem more likely to be associated with low blood pressure and a higher mortality rate. Those treatment modes also imply more severe clinical conditions, as reflected by higher levels of the 6 biomarkers. It should be noted that these are baseline levels, as the data do not provide any information about the change of those markers over time in response to treatment. These results should therefore be interpreted as: “Treatment modes 2 and 3 seem more appropriate for patients in critical condition with sepsis, and there is still a higher mortality risk among those patients.”

Step 2: CART model with caret
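The heatmaps above are drawn on standardised values: scale() centres each column on its mean and divides it by its standard deviation, so that markers measured in very different units share one colour scale. A minimal base R sketch, using the ASAT and creatinine values of the first four patients printed in Step 0:

```r
# ASAT and creatinine of the first four patients shown in the data preview
m <- cbind(asat = c(5, 10, 8, 6), creatinine = c(79, 94, 79, 80))

# Column-wise z-scores as computed by scale()
z <- scale(m)

# The same computation written out by hand
centred <- sweep(m, 2, colMeans(m), "-")
manual  <- sweep(centred, 2, apply(m, 2, sd), "/")

print(z)
print(all.equal(as.numeric(z), as.numeric(manual)))
```

After scaling, each column has mean 0 and standard deviation 1, which is why a single colour gradient can be shared by all markers.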

library(caret)

set.seed(123)
CART<-caret::train(treatment~.,
            data=df[,-9],
            method="rpart2",
            preProcess = NULL,
            tuneLength=5,
            trControl = trainControl(method = "boot", number=500L,verboseIter = FALSE,summaryFunction=multiClassSummary)
)
## note: only 2 possible values of the max tree depth from the initial fit.
##  Truncating the grid to 2 .

We have just run an integrated process in caret that includes both model tuning and training.
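What "tuning and training in one step" means can be sketched by hand with rpart: each candidate value of maxdepth is fitted on bootstrap samples, scored on the rows left out of each sample, and the value with the best average accuracy is refitted on all the data. This is only an approximation of the idea, written on made-up data (`toy`), and is not caret's actual implementation.

```r
library(rpart)

set.seed(123)

# Synthetic, made-up three-class data (not the sepsis data)
toy <- data.frame(
  marker_a = c(rnorm(50, 20, 5),   rnorm(50, 200, 40), rnorm(50, 200, 40)),
  marker_b = c(rnorm(50, 100, 20), rnorm(50, 100, 20), rnorm(50, 600, 80)),
  class    = factor(rep(c("Mode_1", "Mode_2", "Mode_3"), each = 50))
)

depths <- c(1, 2)   # candidate values of the tuning parameter
n_boot <- 50        # number of bootstrap replicates (kept small for speed)

oob_acc <- matrix(NA_real_, nrow = n_boot, ncol = length(depths),
                  dimnames = list(NULL, paste0("maxdepth_", depths)))

for (b in seq_len(n_boot)) {
  in_bag <- sample(nrow(toy), replace = TRUE)         # training rows
  oob    <- setdiff(seq_len(nrow(toy)), in_bag)       # held-out rows
  for (j in seq_along(depths)) {
    fit  <- rpart(class ~ ., data = toy[in_bag, ], method = "class",
                  control = rpart.control(maxdepth = depths[j]))
    pred <- predict(fit, newdata = toy[oob, ], type = "class")
    oob_acc[b, j] <- mean(pred == toy$class[oob])
  }
}

# Average held-out accuracy per candidate, then refit the winner on all rows
mean_acc   <- colMeans(oob_acc)
best_depth <- depths[which.max(mean_acc)]
final_fit  <- rpart(class ~ ., data = toy, method = "class",
                    control = rpart.control(maxdepth = best_depth))
print(mean_acc)
```

With three classes, a tree of depth 1 can only separate two groups, so the deeper tree wins here, in the same way that maxdepth = 2 is selected in the caret output.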

CART
## CART 
## 
## 90 samples
##  7 predictor
##  3 classes: 'Mode_1', 'Mode_2', 'Mode_3' 
## 
## No pre-processing
## Resampling: Bootstrapped (500 reps) 
## Summary of sample sizes: 90, 90, 90, 90, 90, 90, ... 
## Resampling results across tuning parameters:
## 
##   maxdepth  Accuracy   Kappa      Mean_Sensitivity  Mean_Specificity
##   1         0.7223380  0.5464793  0.6266117         0.8489765       
##   2         0.8584367  0.7782086  0.8555079         0.9283918       
##   Mean_Pos_Pred_Value  Mean_Neg_Pred_Value  Mean_Detection_Rate
##   0.8216216            0.8855241            0.2407793          
##   0.8582721            0.9302901            0.2861456          
##   Mean_Balanced_Accuracy
##   0.7377941             
##   0.8919499             
## 
## Accuracy was used to select the optimal model using  the largest value.
## The final value used for the model was maxdepth = 2.

The output shows that the optimal maximum tree depth is 2 (the larger of the two values tested), so the final tree is shallow, with at most two levels of splits. This model performed well in bootstrap resampling: accuracy = 0.86, Kappa coefficient = 0.78 and mean balanced accuracy = 0.89.

Unfortunately, with the settings used here, the caret summary does not include probability-based multiclass metrics (such as the multiclass AUC or the Brier score). We will return to them in Step 3.

Figure 3: The Decision Tree

library(partykit)
library(party)
plot(as.party(CART$finalModel))

Figure 3 represents our Decision Tree (or treatment flowchart). It involves 2 markers, ASAT and creatinine, at 2 cutoff levels. This flowchart allows us to assign one of the 3 treatment modes to a patient.

For example: if the ASAT level is above 77 IU and the creatinine level is above 416.5, the patient would receive treatment Mode 3.
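The rule quoted above can be written as a small R function. This is only a sketch: `assign_treatment()` is a hypothetical helper, it encodes the single branch spelled out in the text (Mode 3), and it returns NA for every other combination because the remaining branches are only shown in the figure.

```r
# Hypothetical helper: encodes only the branch stated in the text.
# Other branches of the tree are not reproduced here and yield NA.
assign_treatment <- function(asat, creatinine) {
  ifelse(asat > 77 & creatinine > 416.5, "Mode_3", NA_character_)
}

# Two made-up patients: the first satisfies both conditions, the second does not
example <- assign_treatment(asat = c(900, 10), creatinine = c(600, 80))
print(example)
```

In practice the full set of rules would be read from the fitted tree itself (for instance by printing CART$finalModel) rather than typed by hand.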

confusionMatrix(CART)
## Bootstrapped (500 reps) Confusion Matrix 
## 
## (entries are percentual average cell counts across resamples)
##  
##           Reference
## Prediction Mode_1 Mode_2 Mode_3
##     Mode_1   36.2    3.5    1.1
##     Mode_2    2.1   32.5    2.9
##     Mode_3    0.3    4.2   17.2
##                             
##  Accuracy (average) : 0.8583

The confusion matrix shows the average performance of the model over all bootstrap iterations. The misclassification rate is low for all 3 reference classes.

Figure 4A, B, C: Averaged performance of the CART model over 500 bootstrap iterations
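As a check on how these summary numbers relate to the table, accuracy and Cohen's Kappa can be recomputed from the averaged confusion matrix printed above, using base R only. Note that the Kappa of the averaged matrix is not exactly the same quantity as the average of the 500 per-resample Kappa values reported by caret, so only approximate agreement should be expected.

```r
# Averaged confusion matrix (percent of cases), copied from the output above
cm <- matrix(c(36.2,  3.5,  1.1,
                2.1, 32.5,  2.9,
                0.3,  4.2, 17.2),
             nrow = 3, byrow = TRUE,
             dimnames = list(Prediction = c("Mode_1", "Mode_2", "Mode_3"),
                             Reference  = c("Mode_1", "Mode_2", "Mode_3")))

p <- cm / sum(cm)                          # cell proportions

accuracy  <- sum(diag(p))                  # observed agreement
expected  <- sum(rowSums(p) * colSums(p))  # agreement expected by chance
kappa_hat <- (accuracy - expected) / (1 - expected)

print(round(c(accuracy = accuracy, kappa = kappa_hat), 3))
```

Kappa is lower than accuracy because it discounts the agreement that would be obtained by chance given the class frequencies.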

resCART<-CART$resample%>%as_tibble()%>%mutate(.,Iter=row.names(.))
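# Keep the metrics of interest by name rather than by column position,
# so the code does not depend on the column order returned by caret
resCART=dplyr::rename(resCART,ACC=Accuracy,KAP=Kappa,SEN=Mean_Sensitivity,SPEC=Mean_Specificity,PPV=Mean_Pos_Pred_Value,NPV=Mean_Neg_Pred_Value,BAC=Mean_Balanced_Accuracy)%>%dplyr::select(ACC,KAP,SEN,SPEC,PPV,NPV,BAC,Iter)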
resCART=gather(resCART,ACC:BAC,key="Metric",value="Value")

resCART%>%ggplot(aes(x=Metric,y=Value,fill=reorder(Metric,Value),color=reorder(Metric,Value)))+stat_summary(fun.min=min,fun.max=max,fun=median,shape=21,size=1)+coord_flip()+scale_color_manual(values=mycolors,name="Metrics") +
  scale_fill_manual(values=myfillcolors,name="Metrics")

resCART%>%ggplot(aes(x=Metric,y=Value))+geom_boxplot(aes(fill=Metric),alpha=0.6)+coord_flip()+facet_wrap(~Metric,scales="free",ncol=2)+theme(axis.text.y=element_blank())+scale_fill_brewer(palette = "Set1")+scale_color_brewer(palette = "Set1")

resCART%>%ggplot(aes(x=as.numeric(Iter),y=Value))+geom_path(aes(color=Metric),size=0.7,alpha=0.9)+facet_wrap(~Metric,scales="free",ncol=2)+scale_color_manual(values=mycolors,name="Metrics")

Figure 5: Important variables in the model

vim<-varImp(CART)%>%.$importance%>%as_tibble()%>%mutate(.,Marker=row.names(.))%>%.[c(1:6),]
vim%>%ggplot(aes(x=reorder(Marker,Overall),y=Overall,fill=reorder(Marker,Overall)))+geom_bar(stat="identity",alpha=0.7,width=0.5,color="black")+scale_x_discrete("Markers")+scale_fill_manual(values=myfillcolors,name="Markers")+coord_flip()

Step 3: CART model with the mlr package

As mentioned above, the caret output does not provide a full evaluation of multiclass classification tasks. Therefore, we will train the model again with the mlr package.

library(mlr)

task = makeClassifTask(id = "Multiclass", data = df[,c(1:7)], target = "treatment")
learner = makeLearner("classif.rpart", predict.type = "prob",fix.factors.prediction = TRUE,maxdepth=2)

CART2=mlr::train(learner,task)

Figure 6: Decision Tree (Again) trained by mlr package

library(partykit)
library(party)
plot(as.party(CART2$learner.model))

As you can see, the decision tree is identical to the one produced by caret.

Figure 7: Model’s performance evaluated by bootstrapping

rdesc = makeResampleDesc("Bootstrap", iters=500L)
rdf<-resample(learner,task,rdesc,measures=list(multiclass.au1p,multiclass.aunp,multiclass.brier))%>%.$measures.test%>%gather(.,multiclass.au1p:multiclass.brier,key="Metric",value="Value")

rdf%>%ggplot(aes(x=iter,y=Value))+geom_path(size=1,aes(col=Metric))+facet_wrap(~Metric,scales="free",ncol=1)+scale_color_manual(values=mycolors)

rdf%>%ggplot(aes(x=Metric,y=Value))+geom_boxplot(alpha=0.7,aes(fill=Metric,col=Metric))+facet_wrap(~Metric,scales="free",ncol=1)+coord_flip()+scale_color_manual(values=mycolors)+scale_fill_manual(values=myfillcolors)

rdf%>%ggplot(aes(x=Value))+geom_histogram(alpha=0.7,aes(fill=Metric,col=Metric))+facet_wrap(~Metric,scales="free",ncol=2)+scale_color_manual(values=mycolors)+scale_fill_manual(values=myfillcolors)

Note:

multiclass.au1p = Weighted average 1 vs. 1 multiclass AUC. Computes the AUC of c(c - 1) binary classifiers while considering the a priori distribution of the classes.

multiclass.aunp = Weighted average 1 vs. rest multiclass AUC. Computes the AUC treating a c-dimensional classifier as c two-dimensional classifiers, taking into account the prior probability of each class.

multiclass.brier = Multiclass Brier score, defined as (1/n) sum_i sum_j (y_ij - p_ij)^2, where y_ij = 1 if observation i has class j (else 0), and p_ij is the predicted probability of observation i for class j.

For further information, see http://docs.lib.noaa.gov/rescue/mwr/078/mwr-078-01-0001.pdf and https://www.math.ucdavis.edu/~saito/data/roc/ferri-class-perf-metrics.pdf
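To make two of these definitions concrete, the base R sketch below computes the multiclass Brier score and a prior-weighted one-vs-rest AUC on three made-up observations. The functions `brier_multiclass()`, `auc_binary()` and `aunp_sketch()` are hypothetical helpers that follow the verbal definitions above; they are not mlr's own code.

```r
# Multiclass Brier score: (1/n) sum_i sum_j (y_ij - p_ij)^2
brier_multiclass <- function(truth, prob) {
  y <- outer(as.character(truth), colnames(prob), "==") * 1  # one-hot matrix
  mean(rowSums((y - prob)^2))
}

# Binary AUC from ranks (Mann-Whitney formulation)
auc_binary <- function(score, is_pos) {
  r     <- rank(score)
  n_pos <- sum(is_pos)
  n_neg <- sum(!is_pos)
  (sum(r[is_pos]) - n_pos * (n_pos + 1) / 2) / (n_pos * n_neg)
}

# One-vs-rest AUC for each class, weighted by the class prior
aunp_sketch <- function(truth, prob) {
  classes <- colnames(prob)
  prior   <- sapply(classes, function(k) mean(truth == k))
  auc     <- sapply(classes, function(k) auc_binary(prob[, k], truth == k))
  sum(prior * auc)
}

# Three made-up observations, one per class, with predicted probabilities
truth <- factor(c("a", "b", "c"))
prob  <- matrix(c(0.8, 0.1, 0.1,
                  0.2, 0.6, 0.2,
                  0.3, 0.3, 0.4),
                nrow = 3, byrow = TRUE,
                dimnames = list(NULL, c("a", "b", "c")))

brier <- brier_multiclass(truth, prob)
aunp  <- aunp_sketch(truth, prob)
print(c(brier = brier, aunp = aunp))
```

A lower Brier score means better calibrated probabilities, whereas a higher AUC means the classes are better ranked; the two measures therefore complement each other.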

Figure 8A: Agreement between the model’s outcome and the truth

pred=predict(CART,newdata=df[,-9])
predf=df%>%mutate(.,Class=pred)

predf%>%ggplot(aes(x=treatment,fill=death,color=death))+geom_bar(alpha=0.7)+facet_wrap(~Class,ncol=1)+scale_fill_manual(values=myfillcolors,name="Outcome")+scale_color_manual(values=mycolors,name="Outcome")+scale_x_discrete("Given Treatment")+coord_flip()

Treatment modes 2 and 3 are associated with a higher mortality rate, particularly when the patient’s treatment was underrated (i.e. the patient was given a lower grade of treatment than the one assigned by the model).

Figure 8B: Association between the treatment modalities and Low blood pressure

predf%>%ggplot(aes(x=treatment,fill=lowbloodpressure,color=lowbloodpressure))+geom_bar(alpha=0.7)+facet_wrap(~Class,ncol=1)+scale_fill_manual(values=myfillcolors,name="Low blood Pressure")+scale_color_manual(values=mycolors,name="Low blood pressure")+scale_x_discrete("Given Treatment")+coord_flip()

Figure 8B indicates that patients treated with modes 2 and 3 seem more likely to develop low blood pressure.

Conclusion: Tree-based models, including the C5.0, C4.5 and CART algorithms, are useful for optimising therapeutic decision making. This approach offers many advantages, including simple and straightforward application in clinical practice, as well as high performance. Their accuracy can be better than that of other data-mining strategies, such as stepwise methods or arbitrary clinical scores.

END