Case study X3: Multilabel classification problem

Foreword: About the Machine Learning in Medicine (MLM) project

The MLM project was initiated in 2016 by a group of young data analysts. Our main objective is to encourage the use of machine learning techniques in medical research in Vietnam and to promote the use of the R statistical programming language, an open-source, leading tool for practicing data science.

In the beginning, our case studies focused on basic and popular machine learning techniques such as logistic regression, decision trees and support vector machines, and were entirely based on the simple, simulated datasets from the Machine Learning in Medicine Cookbooks by T. J. Cleophas and A. H. Zwinderman (Springer 2014). In the extended parts of this project, however, the case studies will evolve toward real-world problems that involve real data, more difficult tasks and more sophisticated methods.

In the present case study, we tackle for the first time a study question that involves a multilabel classification problem.

Context

Chronic kidney disease (CKD), or chronic renal failure, is defined as a gradual loss of kidney function over a long period. The clinical symptoms of CKD develop over time and may include loss of appetite, high blood pressure, edema, a rise in potassium (hyperkalemia), anemia and decreased immune response. A comprehensive blood test can provide useful information, such as the levels of glucose, potassium and sodium, the concentrations of creatinine and urea, as well as the packed cell volume and blood cell counts.

As early detection of CKD might help prevent its progression to kidney failure, the main objective of this study is to develop a classification rule that correctly identifies patients with CKD based on physical symptoms and blood analysis data.

Considering that diabetes mellitus is the most common cause of CKD, and that high blood pressure can be either an underlying cause or a complication of CKD, we attempt to develop several multilabel classification rules. These algorithms classify CKD, Hypertension and Diabetes as a unified pathological entity, with or without a hierarchical association.

Materials and methods

The original dataset in this study was created by Dr. P. Soundarapandian, L. Jerlin Rubini and Dr. P. Eswaran (India).

The data come from a real study in which physical symptoms, clinical data and blood test results were recorded from 401 patients. Among them, 250 patients had chronic kidney disease (CKD), 147 were diagnosed with Hypertension and 137 with Diabetes. The original data and further descriptions can be downloaded from the UCI Machine Learning Repository:

https://archive.ics.uci.edu/ml/datasets/Chronic_Kidney_Disease#

The following packages must be installed and loaded in R/RStudio (some other packages might also be required during the analysis). We also customize the appearance of ggplot2 figures with a personal theme:

library(tidyverse)
library(caret)
library(mlr)

my_theme <- function(base_size = 10, base_family = "sans"){
  theme_minimal(base_size = base_size, base_family = base_family) +
    theme(
      axis.text = element_text(size = 10),
      axis.text.x = element_text(angle = 0, vjust = 0.5, hjust = 0.5),
      axis.title = element_text(size = 10),
      panel.grid.major = element_line(color = "gray"),
      panel.grid.minor = element_blank(),
      panel.background = element_rect(fill = "#f7fdff"),
      strip.background = element_rect(fill = "#001d60", color = "#00113a", size =0.5),
      strip.text = element_text(face = "bold", size = 10, color = "white"),
      legend.position = "bottom",
      legend.justification = "center",
      legend.background = element_blank(),
      panel.border = element_rect(color = "grey5", fill = NA, size = 0.5)
    )
}

theme_set(my_theme())

myfillcolors=c("#ff003f","#0094ff", "#ae00ff" , "#94ff00", "#ffc700","#fc1814")
mycolors=c("#db0229","#026bdb","#48039e","#0d7502","#c97c02","#c40c09")

Unlike binary and multiclass classification problems, a multilabel classification problem covers all situations in which two or more target labels can be assigned to the same observation at the same time, instead of exactly one. This problem is commonly encountered in medical studies that focus on comorbidities, complications, or any simultaneous, coincidental events that can be observed at the same time in one patient.

To deal with this problem, two different approaches can be used:

  1. Problem transformation methods consist of transforming the multilabel classification problem into binary or multiclass classification problems. Virtually all classification algorithms can handle binary problems, and many of them can handle multiclass problems as well.

  2. Algorithm adaptation methods consist of adapting existing multiclass algorithms so that they can be applied directly to the multilabel problem. At the time of writing, mlr is one of the few machine learning frameworks in R that support multilabel learning. In our study, we will explore all the solutions it offers for our study question.
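To make the two transformation routes of point 1 concrete, here is a minimal base-R sketch on a hypothetical toy label table (not the study data). Each label can become its own binary target, or each distinct combination of labels can become one class of a single multiclass target. The second route is usually called the label powerset transformation; it is shown only for illustration and is not one of the methods used in this study.

```r
# Hypothetical toy label table: one logical column per label
Y <- data.frame(DIA = c(TRUE, FALSE, TRUE, FALSE, FALSE),
                HTA = c(TRUE, FALSE, FALSE, TRUE, FALSE),
                CKD = c(TRUE, TRUE, FALSE, FALSE, FALSE))

# Route 1: binary transformation, one 0/1 target per label
binary_targets <- lapply(Y, as.integer)

# Route 2: multiclass transformation (label powerset):
# every distinct combination of labels becomes one class
combo <- apply(as.matrix(Y), 1, function(r) paste(names(Y)[r], collapse = "+"))
combo[combo == ""] <- "none"   # observation without any label
powerset_class <- factor(combo)

binary_targets$CKD    # 1 1 0 0 0
table(powerset_class)
```

With the powerset route, the number of classes grows with the number of observed combinations, so rare combinations quickly end up with very few patients.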

Preparing the data

After downloading and extracting the dataset, we load it into R. As the original dataset is in ARFF format with an irregular structure (for example, stray tab characters around some values), extensive cleaning is required. Please replicate the code exactly as presented below:

# ARFF header lines start with "@" and are skipped as comments
ckd=read.csv("chronic_kidney_disease.arff",header=FALSE,comment.char="@",na.strings=c("?","","\t?"))%>%as_tibble()

names(ckd)=c("Age","BP","SPG","Albumin","Sugar","RBC","PC","PCC","Bact","BGlu","BUrea","SerCreat","Sodium","K","Hb","PCV","WBC","RBCCount","HTA","DIA","CAD","APP","PedEdema","Anemia","CKD")

ckd$DIA=recode_factor(ckd$DIA,`yes`="TRUE",`no`="FALSE",`\tno`="FALSE",`\tyes`="TRUE",` yes`="TRUE")%>%as.logical()

ckd$CKD=recode_factor(ckd$CKD,`ckd`="TRUE",`no`="FALSE",`ckd\t`="TRUE",`notckd`="FALSE")%>%as.logical()

ckd$HTA=recode_factor(ckd$HTA,`yes`="TRUE",`no`="FALSE")%>%as.logical()

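# as.numeric() on a factor returns its level codes, not its values,
# so convert to character first; non-numeric entries (e.g. "notckd") become NA
ckd$Age=suppressWarnings(as.numeric(as.character(ckd$Age)))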

df=select(ckd,c(Age,BP,RBCCount,SPG,Albumin,Sugar,PC,PCC,Bact,BGlu,BUrea,
                SerCreat,Sodium,K,Hb,PCV,WBC,Anemia,PedEdema,DIA,
                HTA,CKD))

# Keep only the patients whose 3 target labels are all observed
df=dplyr::filter(df,!is.na(DIA),!is.na(HTA),!is.na(CKD))

# Convert through as.character() so that factor columns keep their actual values
num.vars=c("BP","Albumin","Sugar","BGlu","PCV","WBC")
df[num.vars]=lapply(df[num.vars],function(x) as.numeric(as.character(x)))

rm(ckd)

After cleaning the noise in the original ARFF content and renaming the variables, we need to get our data into the right format. Multilabel classification requires a data frame consisting of the features and one logical vector (i.e., coded as TRUE or FALSE) per label. If the target variable in your original data is a multilevel factor, you need to transform it into dummy variables and then recode their values from 0/1 to logical FALSE/TRUE. In our case, there are 3 logical variables corresponding to the 3 studied labels: DIA, HTA and CKD.
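The recoding described above can be sketched in base R as follows, assuming a hypothetical single factor column `diagnosis` (our data do not contain such a column, since the three labels are already separate). One logical column is created per factor level; a dummy column already coded 0/1 can simply be converted with `as.logical()`. Note that a single factor gives exactly one TRUE per row, whereas in genuinely multilabel data several columns can be TRUE for the same patient.

```r
# Hypothetical multilevel target factor (one value per patient)
diagnosis <- factor(c("CKD", "Diabetes", "Hypertension", "CKD"))

# One logical column per level: TRUE where the patient has that level
label_df <- as.data.frame(sapply(levels(diagnosis), function(l) diagnosis == l))

# A dummy column already coded 0/1 converts directly to FALSE/TRUE
dummy01 <- c(0, 1, 1, 0)
as.logical(dummy01)   # FALSE TRUE TRUE FALSE

str(label_df)
```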

We also removed 4 cases with missing values in the target labels.

The cleaned data are stored in a data frame (tibble) that contains 22 variables: 3 target variables and 19 features.

Important note: the 3 labels must be arranged in this order (DIA, HTA, CKD), so that we can later develop algorithms based on a causal assumption.

Data exploration

str(df)
## Classes 'tbl_df', 'tbl' and 'data.frame':    397 obs. of  22 variables:
##  $ Age     : num  38 62 54 38 42 52 60 13 43 44 ...
##  $ BP      : num  80 50 80 70 80 90 70 NA 100 90 ...
##  $ RBCCount: num  5.2 NA NA 3.9 4.6 4.4 NA 5 4 3.7 ...
##  $ SPG     : num  1.02 1.02 1.01 1 1.01 ...
##  $ Albumin : num  1 4 2 4 2 3 0 2 3 2 ...
##  $ Sugar   : num  0 0 3 0 0 0 0 4 0 0 ...
##  $ PC      : Factor w/ 2 levels "abnormal","normal": 2 2 2 1 2 NA 2 1 1 1 ...
##  $ PCC     : Factor w/ 2 levels "notpresent","present": 1 1 1 2 1 1 1 1 2 2 ...
##  $ Bact    : Factor w/ 2 levels "notpresent","present": 1 1 1 1 1 1 1 1 1 1 ...
##  $ BGlu    : num  121 NA 423 117 106 74 100 410 138 70 ...
##  $ BUrea   : num  36 18 53 56 26 25 54 31 60 107 ...
##  $ SerCreat: num  1.2 0.8 1.8 3.8 1.4 1.1 24 1.1 1.9 7.2 ...
##  $ Sodium  : num  NA NA NA 111 NA 142 104 NA NA 114 ...
##  $ K       : num  NA NA NA 2.5 NA 3.2 4 NA NA 3.7 ...
##  $ Hb      : num  15.4 11.3 9.6 11.2 11.6 12.2 12.4 12.4 10.8 9.5 ...
##  $ PCV     : num  44 38 31 32 35 39 36 44 33 29 ...
##  $ WBC     : num  7800 6000 7500 6700 7300 7800 NA 6900 9600 12100 ...
##  $ Anemia  : Factor w/ 2 levels "no","yes": 1 1 2 2 1 1 1 1 2 2 ...
##  $ PedEdema: Factor w/ 3 levels "good","no","yes": 2 2 2 3 2 3 2 3 2 2 ...
##  $ DIA     : logi  TRUE FALSE TRUE FALSE FALSE TRUE ...
##  $ HTA     : logi  TRUE FALSE FALSE TRUE FALSE TRUE ...
##  $ CKD     : logi  TRUE TRUE TRUE TRUE TRUE TRUE ...
df%>%.[,-c(7:9,18:22)]%>%psych::describe()
##          vars   n    mean      sd  median trimmed     mad    min      max
## Age         1 388   43.65   17.03   46.00   44.50   16.31    2.0    77.00
## BP          2 385   76.55   13.70   80.00   75.53   14.83   50.0   180.00
## RBCCount    3 266    4.70    1.03    4.80    4.73    1.04    2.1     8.00
## SPG         4 350    1.02    0.01    1.02    1.02    0.01    1.0     1.02
## Albumin     5 351    1.03    1.36    0.00    0.81    0.00    0.0     5.00
## Sugar       6 348    0.45    1.10    0.00    0.15    0.00    0.0     5.00
## BGlu        7 353  148.46   79.46  121.00  133.52   37.06   22.0   490.00
## BUrea       8 378   57.56   50.67   42.00   47.40   24.46    1.5   391.00
## SerCreat    9 380    3.09    5.76    1.30    1.93    0.89    0.4    76.00
## Sodium     10 310  137.49   10.44  138.00  138.31    4.45    4.5   163.00
## K          11 309    4.63    3.21    4.40    4.35    0.74    2.5    47.00
## Hb         12 345   12.50    2.91   12.60   12.61    3.41    3.1    17.80
## PCV        13 326   38.81    9.00   40.00   39.21   10.38    9.0    54.00
## WBC        14 291 8393.13 2953.33 8000.00 8133.05 2520.42 2200.0 26400.00
##             range  skew kurtosis     se
## Age         75.00 -0.40    -0.58   0.86
## BP         130.00  1.59     8.44   0.70
## RBCCount     5.90 -0.17    -0.33   0.06
## SPG          0.02 -0.16    -1.16   0.00
## Albumin      5.00  0.98    -0.43   0.07
## Sugar        5.00  2.43     4.84   0.06
## BGlu       468.00  1.98     4.06   4.23
## BUrea      389.50  2.60     9.05   2.61
## SerCreat    75.60  7.42    77.28   0.30
## Sodium     158.50 -6.92    83.27   0.59
## K           44.50 11.42   138.02   0.18
## Hb          14.70 -0.32    -0.50   0.16
## PCV         45.00 -0.42    -0.35   0.50
## WBC      24200.00  1.62     5.97 173.13
dfscale<-df%>%.[,-c(7:9,18:22)]%>%as.matrix()%>%scale()%>%as_tibble()%>%mutate(.,Hypertension=df$HTA,Diabetes=df$DIA,CKD=df$CKD,Id=row.names(.))

library(viridis)

dfscale%>%gather(Age:WBC,key="Parameter",value="Value")%>%ggplot(aes(x=reorder(Id,Value),y=reorder(Parameter,Value),fill=Value))+geom_tile(color="black",show.legend=T)+facet_wrap(~Hypertension,ncol=1,shrink=T,scale="free")+theme(axis.text.x=element_blank())+scale_y_discrete("Parameters")+scale_x_discrete("Patient's Id")+ggtitle("Hypertension")+scale_fill_viridis(option="B",begin=0,end=1)

dfscale%>%gather(Age:WBC,key="Parameter",value="Value")%>%ggplot(aes(x=reorder(Id,Value),y=reorder(Parameter,Value),fill=Value))+geom_tile(color="black",show.legend=T)+facet_wrap(~Diabetes,ncol=1,shrink=T,scale="free")+theme(axis.text.x=element_blank())+scale_y_discrete("Parameters")+scale_x_discrete("Patient's Id")+ggtitle("Diabetes")+scale_fill_viridis(option="D",begin=0,end=1)

dfscale%>%gather(Age:WBC,key="Parameter",value="Value")%>%ggplot(aes(x=reorder(Id,Value),y=reorder(Parameter,Value),fill=Value))+geom_tile(color="black",show.legend=T)+facet_wrap(~CKD,ncol=1,shrink=T,scale="free")+theme(axis.text.x=element_blank())+scale_y_discrete("Parameters")+scale_x_discrete("Patient's Id")+ggtitle("CKD")+scale_fill_viridis(option="C",begin=0,end=1)

g1=ggplot(data=df,aes(x=CKD,fill=HTA,color=HTA))+geom_bar(position="fill")+scale_fill_manual(values=myfillcolors)+scale_color_manual(values=mycolors)+coord_flip()

g2=df%>%ggplot(aes(x=CKD,fill=DIA,color=DIA))+geom_bar(position="fill")+scale_fill_manual(values=myfillcolors)+scale_color_manual(values=mycolors)+coord_flip()

g3=df%>%ggplot(aes(x=HTA,fill=DIA,color=DIA))+geom_bar(position="fill")+scale_fill_manual(values=myfillcolors)+scale_color_manual(values=mycolors)+coord_flip()

library(gridExtra)
grid.arrange(g1,g2,g3,ncol=2)

library(mldr)

datamldr=mldr_from_dataframe(df,labelIndices=c(20:22))
plot(type="LC",datamldr, color.function =rainbow)

Those figures indicate that, in this dataset, Hypertension and Diabetes are (almost) never observed in patients without CKD, while most CKD patients also present high blood pressure and/or Diabetes. Diabetes and Hypertension also seem to be strongly associated.

Such a strong imbalance among label combinations is not a good sign for our (greedy) multilabel classification task, as the models could overfit. In the worst case, if resampling or data splitting is not carefully conducted, the balance of label combinations may be lost and training will fail, because the learner cannot handle a one-class problem (a training subset in which a label takes only one value).
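This imbalance can be quantified with a few simple statistics: the label cardinality (average number of labels per patient), the label density (cardinality divided by the number of labels) and the frequency of each label combination. The function below is a hypothetical base-R helper shown on a toy label table; applied to df[, c("DIA","HTA","CKD")] it would give the corresponding figures for our data, which we do not report here.

```r
# Hypothetical helper: basic multilabel statistics from a logical label table
label_stats <- function(Y) {
  Y <- as.matrix(Y)
  n_labels <- rowSums(Y)                       # number of labels per observation
  combo <- apply(Y, 1, function(r) paste(colnames(Y)[r], collapse = "+"))
  combo[combo == ""] <- "none"                 # observations without any label
  list(cardinality = mean(n_labels),
       density     = mean(n_labels) / ncol(Y),
       labelsets   = sort(table(combo), decreasing = TRUE))
}

# Toy example (not the study data)
Y <- data.frame(DIA = c(TRUE, FALSE, FALSE, TRUE, FALSE, FALSE),
                HTA = c(TRUE, FALSE, TRUE, TRUE, FALSE, FALSE),
                CKD = c(TRUE, TRUE, TRUE, TRUE, FALSE, FALSE))
s <- label_stats(Y)
s$cardinality   # 1.5
s$labelsets
```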

Because of the imbalance in label combinations, we will not use any resampling procedure, and the data splitting ratio is fixed at 0.76 for the training subset and 0.24 for the test subset. Each model is trained only once on the training subset, and its performance is then evaluated once on the test subset.
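The split used later in this study, createDataPartition(df$CKD, ...), stratifies on CKD only, so the balance of the other labels and of label combinations is not guaranteed. One possible alternative, shown here as a hypothetical base-R helper that is not used in this study, is to stratify on the full label combination: each combination is sampled separately with the same proportion p. Combinations observed only once end up entirely in the training subset.

```r
# Hypothetical helper: training indices stratified on the label combination
stratified_split <- function(Y, p = 0.76, seed = 123) {
  set.seed(seed)
  # One key per observation, e.g. "TRUE_FALSE_TRUE" for DIA, HTA, CKD
  key <- apply(as.matrix(Y), 1, paste, collapse = "_")
  groups <- split(seq_len(nrow(Y)), key)
  idx <- lapply(groups, function(rows) {
    k <- max(1, round(p * length(rows)))   # keep at least one per combination
    rows[sample.int(length(rows), k)]      # safe even when length(rows) == 1
  })
  sort(unlist(idx, use.names = FALSE))
}

# Toy label table (not the study data) with five label combinations
Y <- data.frame(DIA = rep(c(TRUE, FALSE), c(30, 70)),
                HTA = rep(c(TRUE, FALSE, TRUE), c(20, 50, 30)),
                CKD = rep(c(TRUE, FALSE), c(60, 40)))
train_idx <- stratified_split(Y)
test_idx <- setdiff(seq_len(nrow(Y)), train_idx)
length(train_idx)   # 77
```

With the study data, such a helper could be called as stratified_split(df[, c("DIA","HTA","CKD")]) to obtain train.set; the results reported here, however, were obtained with the createDataPartition() split.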

dflong=df%>%gather(c(Age:Sugar,BGlu:WBC),key="Parameter",value="Value")

dflong%>%ggplot(aes(x=Value,fill=HTA))+geom_density(alpha=0.6)+ggtitle("Hypertension")+facet_wrap(~Parameter,ncol=5,scales="free")+scale_fill_manual(values=myfillcolors)

dflong%>%ggplot(aes(x=Value,fill=CKD))+geom_density(alpha=0.6)+ggtitle("Chronic kidney disease")+facet_wrap(~Parameter,ncol=5,scales="free")+scale_fill_manual(values=myfillcolors)

dflong%>%ggplot(aes(x=Value,fill=DIA))+geom_density(alpha=0.6)+ggtitle("Diabetes")+facet_wrap(~Parameter,ncol=5,scales="free")+scale_fill_manual(values=myfillcolors)

library(mlr)

taskHTA= makeClassifTask(id = "HTA", data=df[,-c(20,22)],target = "HTA",positive="TRUE")
taskDIA= makeClassifTask(id = "DIA", data=df[,-c(21,22)],target = "DIA",positive="TRUE")
taskCKD= makeClassifTask(id = "CKD", data=df[,-c(20,21)],target = "CKD",positive="TRUE")

generateFilterValuesData(taskHTA,method="information.gain")%>%.$data%>%ggplot(aes(x=reorder(name,information.gain),y=information.gain,fill=reorder(name,information.gain)))+geom_bar(stat="identity",color="black",show.legend=F)+scale_x_discrete("Features")+coord_flip()+ggtitle("Target=Hypertension")

generateFilterValuesData(taskDIA,method="information.gain")%>%.$data%>%ggplot(aes(x=reorder(name,information.gain),y=information.gain,fill=reorder(name,information.gain)))+geom_bar(stat="identity",color="black",show.legend=F)+scale_x_discrete("Features")+coord_flip()+ggtitle("Target=Diabetes")

generateFilterValuesData(taskCKD,method="information.gain")%>%.$data%>%ggplot(aes(x=reorder(name,information.gain),y=information.gain,fill=reorder(name,information.gain)))+geom_bar(stat="identity",color="black",show.legend=F)+scale_x_discrete("Features")+coord_flip()+ggtitle("Target=CKD")

The basic data exploration and the feature selection analysis using information gain show that:

Except for White blood cell count, Bacteria, Potassium and PCC, all other features contribute to the prediction of Hypertension.

Features such as Creatinine, Blood sugar test result (Sugar), Glycemia (BGlu), Hb, PCV, SPG, Urea and Albumin contribute the most to the prediction of Diabetes.

Haemoglobin level, Packed cell volume, Creatinine and Red blood cell count were identified as the most relevant features for predicting CKD, but other features might also contribute to the prediction.

bar_missing <- function(x){
  library(dplyr)
  library(reshape2)
  library(ggplot2)
  x %>%
    is.na %>%
    melt %>%
    ggplot(data = .,
           aes(x = Var2)) +
    geom_bar(aes(y=(..count..),fill=value),alpha=0.7)+scale_fill_manual(values=c("skyblue","red"),name = "",
                                                                        labels = c("Available","Missing"))+
    theme_minimal()+
    theme(axis.text.x = element_text(angle=45, vjust=0.5)) +
    labs(x = "Variables in Dataset",
         y = "Observations")+coord_flip()
}


matrix_missing <- function(x){
  library(dplyr)
  library(reshape2)
  library(ggplot2)
  x %>%
    is.na %>%
    melt %>%
    ggplot(data = .,
           aes(x = Var1,
               y = Var2)) +
    geom_tile(aes(fill = value),alpha=0.6) +
    scale_fill_manual(values=c("skyblue","red"),name = "",
                      labels = c("Available","Missing")) +
    theme_minimal()+
    theme(axis.text.x = element_text(angle=45, vjust=0.5)) +
    labs(x = "Variables in Dataset",
         y = "Total observations")+coord_flip()
}

df%>%bar_missing()

df%>%matrix_missing()

There are many missing values in our data, which suggests that we should either use an algorithm that handles missing values natively or apply an imputation method. To keep the analysis simple, we decided not to use any imputation method.
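The argument against a simple complete-case analysis can be checked quickly: the missing rate of each variable and the number of complete rows show how much data listwise deletion would discard. The helper below is a hypothetical base-R sketch shown on a toy data frame; missing_summary(df) would give the figures for our data, which we do not report here.

```r
# Hypothetical helper: share of missing values per variable and complete rows
missing_summary <- function(d) {
  cc <- complete.cases(d)
  list(per_variable   = sort(colMeans(is.na(d)), decreasing = TRUE),
       complete_rows  = sum(cc),
       share_complete = mean(cc))
}

# Toy data (not the study data): no column is more than half missing...
toy <- data.frame(a = c(1, NA, 3, 4),
                  b = c(NA, "x", "y", NA),
                  c = c(TRUE, FALSE, NA, TRUE))
ms <- missing_summary(toy)
ms$per_variable    # b: 0.50, a and c: 0.25
ms$complete_rows   # 0: ...yet not a single row is complete
```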

In mlr, the following algorithms can handle missing values automatically:

listLearners("classif", properties = "missings")[c("class", "package")]
##                         class         package
## 1         classif.bartMachine     bartMachine
## 2          classif.blackboost    mboost,party
## 3            classif.boosting    adabag,rpart
## 4                 classif.C50             C50
## 5             classif.cforest           party
## 6               classif.ctree           party
## 7                 classif.gbm             gbm
## 8                 classif.J48           RWeka
## 9                classif.JRip           RWeka
## 10         classif.naiveBayes           e1071
## 11               classif.OneR           RWeka
## 12               classif.PART           RWeka
## 13    classif.randomForestSRC randomForestSRC
## 14 classif.randomForestSRCSyn randomForestSRC
## 15              classif.rpart           rpart

Machine learning experiment

Using this dataset, we create a MultilabelTask in the same way as other ClassifTasks (please refer to our previous case studies to learn how to make a classification task in mlr). However, instead of one target name, we must specify a character vector of target names corresponding to the logical variables in our data frame (HTA for Hypertension, DIA for Diabetes, CKD for Chronic kidney disease).

labels=c("DIA","HTA","CKD")
CKD.task = makeMultilabelTask(id = "CKD", data=df,target=labels)

Then we split the original data into training and test subsets using the caret package:

library(caret)
set.seed(123)
idx<- createDataPartition(df$CKD,p=0.76,list=FALSE)
train<- df[idx,]
test<- df[-idx,]
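# createDataPartition() returns row positions in df: keep them as training indices
# (row.names() of a subsetted tibble are renumbered 1..n, so they cannot be used here)
train.set=as.integer(idx)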

train%>%mldr_from_dataframe(.,labelIndices=c(20:22))%>%plot(type="LC",., color.function =rainbow)

test%>%mldr_from_dataframe(.,labelIndices=c(20:22))%>%plot(type="LC",., color.function =rainbow)

Constructing a learner

As mentioned above, multilabel classification in mlr can be done in two ways: algorithm adaptation methods, which treat the whole problem with a specific algorithm, and problem transformation methods, which transform the problem so that it can be solved with simple binary classification algorithms.

1) Algorithm adaptation methods

At the time of writing, mlr offers 2 algorithm adaptation methods: the multivariate random forest from the randomForestSRC package and the random ferns multilabel algorithm from the rFerns package. Unfortunately, our data contain many missing values and rFerns cannot handle missing values, so the SRC random forest is our only choice. We can create the learner for the SRC random forest algorithm as follows:

Figure: Algorithm adaptation

2) Problem transformation methods

This approach requires a new type of learner, called a "wrapped multilabel learner". A wrapped learner can be considered an ensemble learning algorithm: it combines and arranges several elementary learners, each of which handles one part of the learning process.

When applied to a data frame, the components of a wrapped learner work together in a coordinated way to accomplish the classification task.

First, we create a core learner using the makeLearner() function. This is a binary (or multiclass) classification learner.

Afterwards, we apply one of the available wrapper functions, namely makeMultilabelBinaryRelevanceWrapper(), makeMultilabelClassifierChainsWrapper(), makeMultilabelNestedStackingWrapper(), makeMultilabelDBRWrapper() or makeMultilabelStackingWrapper(), to the core learner to convert it into a learner that uses the respective problem transformation method.

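# Algorithm adaptation method: multivariate random forest (randomForestSRC)
lrn.rfsrc = makeLearner("multilabel.randomForestSRC",predict.type = "prob")

# Problem transformation methods
# Core learner: CART decision tree (rpart handles missing values via surrogate splits)
lrn.core= makeLearner("classif.rpart", predict.type = "prob")

#5 Wrapped learners
lrn.binrel=makeMultilabelBinaryRelevanceWrapper(lrn.core)
# Chains and nested stacking: set the label order explicitly (DIA -> HTA -> CKD);
# with the default order = NULL, mlr uses a random order of the labels
lrn.chain=makeMultilabelClassifierChainsWrapper(lrn.core, order = labels)
lrn.nest=makeMultilabelNestedStackingWrapper(lrn.core, order = labels)
lrn.dbr= makeMultilabelDBRWrapper(lrn.core)
lrn.stack=makeMultilabelStackingWrapper(lrn.core)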

Those wrapping functions correspond to the following methods:

Binary relevance method

Figure: Binary relevance method

This problem transformation method converts the multilabel problem into one binary classification problem per label and applies a simple binary classifier to each of them. By using this method, we assume that the labels are independent (i.e., each one is treated as a single, stand-alone target outcome and is classified without considering the remaining labels). In our case, a model based on binary relevance assumes that Hypertension, CKD and Diabetes are 3 independent entities; thus, if they occur in the same patient, this is merely coincidental, without any causative or hierarchical relationship.
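The mechanics of binary relevance can be illustrated with a minimal base-R sketch. It is not mlr's implementation: it swaps the CART core learner for logistic regression (glm) and uses simulated data with two hypothetical features (x1, x2) and three labels named like ours; fit_br() and predict_br() are hypothetical helper names.

```r
# Simulated data (hypothetical): two features and three dependent labels
set.seed(42)
n <- 300
X <- data.frame(x1 = rnorm(n), x2 = rnorm(n))
DIA <- rbinom(n, 1, plogis(-0.5 + 1.2 * X$x1)) == 1
HTA <- rbinom(n, 1, plogis(-0.5 + 0.8 * X$x2 + DIA)) == 1
CKD <- rbinom(n, 1, plogis(-0.3 + 0.5 * X$x1 + HTA + 0.8 * DIA)) == 1
Y <- data.frame(DIA, HTA, CKD)

# Binary relevance: one independent model per label, using the features only
fit_br <- function(X, Y) {
  models <- lapply(names(Y), function(lab)
    glm(y ~ ., data = cbind(X, y = as.numeric(Y[[lab]])), family = binomial))
  setNames(models, names(Y))
}

# Each label is predicted from the features alone, with a 0.5 threshold
predict_br <- function(models, Xnew, threshold = 0.5) {
  probs <- lapply(models, predict, newdata = Xnew, type = "response")
  as.data.frame(lapply(probs, function(p) p > threshold))
}

br <- fit_br(X, Y)
pred <- predict_br(br, X)
mean(as.matrix(pred) != as.matrix(Y))   # in-sample Hamming loss
```

No model ever sees the other labels, which is exactly the independence assumption described above, even though the simulated labels are in fact dependent.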

Classifier chains method

Figure: Classifier chains method

This method trains the labels consecutively. At each step, the input data are extended with the labels already trained (using their real observed values). Such a method is appropriate for a study hypothesis that implies a hierarchical or causative relationship between the pathological entities. For example, by placing Diabetes and Hypertension before CKD, we assume that Hypertension is a consequence of Diabetes and that each of them contributes significantly to the development of CKD. Therefore, the order of the labels has to be specified; in mlr, this is done with the order argument of makeMultilabelClassifierChainsWrapper(), and if this argument is left at its default (NULL), a random order is used. At prediction time, the labels are predicted in the same order as during training, and the label values required as inputs are taken from the predictions made in the previous steps.
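A minimal base-R sketch of a classifier chain, under the same toy assumptions as a binary relevance illustration would use (logistic regression instead of CART, simulated data, hypothetical helper names fit_chain() and predict_chain()): observed labels are appended as features during training, predicted labels during prediction. It illustrates the principle only and is not mlr's code.

```r
# Simulated data (hypothetical): two features and three dependent labels
set.seed(42)
n <- 300
X <- data.frame(x1 = rnorm(n), x2 = rnorm(n))
DIA <- rbinom(n, 1, plogis(-0.5 + 1.2 * X$x1)) == 1
HTA <- rbinom(n, 1, plogis(-0.5 + 0.8 * X$x2 + DIA)) == 1
CKD <- rbinom(n, 1, plogis(-0.3 + 0.5 * X$x1 + HTA + 0.8 * DIA)) == 1
Y <- data.frame(DIA, HTA, CKD)

# Classifier chain: labels are fitted one after another in a fixed order
fit_chain <- function(X, Y, order) {
  models <- list()
  Xaug <- X
  for (lab in order) {
    models[[lab]] <- glm(y ~ ., data = cbind(Xaug, y = as.numeric(Y[[lab]])),
                         family = binomial)
    # Training: append the OBSERVED label as a feature for the next links
    Xaug[[lab]] <- as.numeric(Y[[lab]])
  }
  models
}

predict_chain <- function(models, Xnew, threshold = 0.5) {
  Xaug <- Xnew
  out <- list()
  for (lab in names(models)) {   # same order as in training
    out[[lab]] <- predict(models[[lab]], newdata = Xaug, type = "response") > threshold
    # Prediction: append the PREDICTED label as a feature for the next links
    Xaug[[lab]] <- as.numeric(out[[lab]])
  }
  as.data.frame(out)
}

chain <- fit_chain(X, Y, order = c("DIA", "HTA", "CKD"))
pred <- predict_chain(chain, X)
names(coef(chain$CKD))   # "(Intercept)" "x1" "x2" "DIA" "HTA"
```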

Nested stacking

Figure: Nested stacking method

Its principle is the same as that of classifier chains, except that the labels added to the input data are not the real ones but estimates of the labels obtained from the already trained learners. Thus, this method also implies a hypothesis of a causative or time-ordered hierarchical relationship among our labels.
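In a base-R sketch, nested stacking differs from a classifier chain only in the training step. Here, as an assumption for illustration, the estimate appended for the next link is the in-sample predicted class of the already trained model (logistic regression instead of CART, simulated data, hypothetical helper names); this sketch does not claim to reproduce how mlr obtains its estimates.

```r
# Simulated data (hypothetical): two features and three dependent labels
set.seed(42)
n <- 300
X <- data.frame(x1 = rnorm(n), x2 = rnorm(n))
DIA <- rbinom(n, 1, plogis(-0.5 + 1.2 * X$x1)) == 1
HTA <- rbinom(n, 1, plogis(-0.5 + 0.8 * X$x2 + DIA)) == 1
CKD <- rbinom(n, 1, plogis(-0.3 + 0.5 * X$x1 + HTA + 0.8 * DIA)) == 1
Y <- data.frame(DIA, HTA, CKD)

fit_nested <- function(X, Y, order, threshold = 0.5) {
  models <- list()
  Xaug <- X
  for (lab in order) {
    m <- glm(y ~ ., data = cbind(Xaug, y = as.numeric(Y[[lab]])), family = binomial)
    models[[lab]] <- m
    # Unlike a classifier chain, append the ESTIMATED label
    # (here: the in-sample predicted class), not the observed one
    Xaug[[lab]] <- as.numeric(fitted(m) > threshold)
  }
  models
}

# Prediction works exactly as in a classifier chain
predict_nested <- function(models, Xnew, threshold = 0.5) {
  Xaug <- Xnew
  out <- list()
  for (lab in names(models)) {
    out[[lab]] <- predict(models[[lab]], newdata = Xaug, type = "response") > threshold
    Xaug[[lab]] <- as.numeric(out[[lab]])
  }
  as.data.frame(out)
}

nest <- fit_nested(X, Y, order = c("DIA", "HTA", "CKD"))
pred <- predict_nested(nest, X)
```

Because training and prediction now both use estimated labels, the later links are fitted on inputs of the same kind they will receive at prediction time.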

Dependent binary relevance (DBR) method

Figure: Dependent binary relevance (DBR) method

Each label is trained with the real observed values of all other labels as additional inputs. In the prediction phase, the values of the other labels required for a given label are first obtained by a base learner, as in the binary relevance method. Although its principle seems similar to that of binary relevance, the DBR method is more appropriate for studying our pathological entities as comorbidities and/or complications, as the classification of each entity is based on the presence or absence of the remaining entities. However, unlike the nested stacking and classifier chains methods, the DBR method does not imply any assumption of a causative or hierarchical association among the 3 comorbidities.
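The two phases of DBR can be sketched in base R as follows. This is a hypothetical illustration, not mlr's implementation: logistic regression replaces CART, the data are simulated, and fit1(), fit_dbr() and predict_dbr() are made-up helper names. Each meta model is trained on the features plus the observed values of the other labels; at prediction time those values come from binary relevance models.

```r
# Simulated data (hypothetical): two features and three dependent labels
set.seed(42)
n <- 300
X <- data.frame(x1 = rnorm(n), x2 = rnorm(n))
DIA <- rbinom(n, 1, plogis(-0.5 + 1.2 * X$x1)) == 1
HTA <- rbinom(n, 1, plogis(-0.5 + 0.8 * X$x2 + DIA)) == 1
CKD <- rbinom(n, 1, plogis(-0.3 + 0.5 * X$x1 + HTA + 0.8 * DIA)) == 1
Y <- data.frame(DIA, HTA, CKD)

# Logistic model of one target on all columns of d
fit1 <- function(d, target)
  glm(y ~ ., data = cbind(d, y = as.numeric(target)), family = binomial)

fit_dbr <- function(X, Y) {
  labs <- names(Y)
  # Base models (binary relevance), used only at prediction time
  base <- setNames(lapply(labs, function(l) fit1(X, Y[[l]])), labs)
  # Meta models: features + OBSERVED values of all other labels
  meta <- setNames(lapply(labs, function(l) {
    others <- as.data.frame(lapply(Y[setdiff(labs, l)], as.numeric))
    fit1(cbind(X, others), Y[[l]])
  }), labs)
  list(base = base, meta = meta, labs = labs)
}

predict_dbr <- function(fit, Xnew, threshold = 0.5) {
  # Step 1: binary relevance predictions for every label
  br <- as.data.frame(lapply(fit$base, function(m)
    as.numeric(predict(m, newdata = Xnew, type = "response") > threshold)))
  # Step 2: each meta model receives the BR predictions of the other labels
  out <- lapply(setNames(fit$labs, fit$labs), function(l)
    predict(fit$meta[[l]], newdata = cbind(Xnew, br[setdiff(fit$labs, l)]),
            type = "response") > threshold)
  as.data.frame(out)
}

dbr <- fit_dbr(X, Y)
pred <- predict_dbr(dbr, X)
```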

Stacking

Figure: Stacking method

This method is the same as the dependent binary relevance (DBR) method, except that in the training phase the labels used as inputs for each label are obtained by the binary relevance method instead of being the observed values. Thus, the stacking method can be considered a combination of BR and DBR.
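Following the description above (only the other labels are used as extra inputs), stacking can be sketched in base R by replacing the observed labels of the DBR training step with in-sample binary relevance predictions. As before, this is a hypothetical illustration with logistic regression instead of CART, simulated data and made-up helper names; it is not mlr's code.

```r
# Simulated data (hypothetical): two features and three dependent labels
set.seed(42)
n <- 300
X <- data.frame(x1 = rnorm(n), x2 = rnorm(n))
DIA <- rbinom(n, 1, plogis(-0.5 + 1.2 * X$x1)) == 1
HTA <- rbinom(n, 1, plogis(-0.5 + 0.8 * X$x2 + DIA)) == 1
CKD <- rbinom(n, 1, plogis(-0.3 + 0.5 * X$x1 + HTA + 0.8 * DIA)) == 1
Y <- data.frame(DIA, HTA, CKD)

# Logistic model of one target on all columns of d
fit1 <- function(d, target)
  glm(y ~ ., data = cbind(d, y = as.numeric(target)), family = binomial)

fit_stacking <- function(X, Y, threshold = 0.5) {
  labs <- names(Y)
  # Level 1: binary relevance models
  base <- setNames(lapply(labs, function(l) fit1(X, Y[[l]])), labs)
  # Level-2 training inputs: in-sample BR predictions, not observed labels
  br_train <- as.data.frame(lapply(base, function(m) as.numeric(fitted(m) > threshold)))
  # Level 2: features + BR predictions of the other labels
  meta <- setNames(lapply(labs, function(l)
    fit1(cbind(X, br_train[setdiff(labs, l)]), Y[[l]])), labs)
  list(base = base, meta = meta, labs = labs)
}

# Prediction: identical to DBR (BR predictions feed the meta models)
predict_stacking <- function(fit, Xnew, threshold = 0.5) {
  br <- as.data.frame(lapply(fit$base, function(m)
    as.numeric(predict(m, newdata = Xnew, type = "response") > threshold)))
  out <- lapply(setNames(fit$labs, fit$labs), function(l)
    predict(fit$meta[[l]], newdata = cbind(Xnew, br[setdiff(fit$labs, l)]),
            type = "response") > threshold)
  as.data.frame(out)
}

st <- fit_stacking(X, Y)
pred <- predict_stacking(st, X)
```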

As our dataset contains missing values, we adopt the CART algorithm (rpart) as the core learner for the 5 wrapped learners, since rpart handles missing values through surrogate splits. Thus, all 5 wrapped learners are built from decision tree models.

Model training and testing

All learners are trained in the same way: in our study, we train 6 models on a training subset that contains 76% (n = 302) of the original dataset, then test them on an independent test subset (n = 95).

Prediction works as usual: we pass each trained model and the test subset to the predict() function.

The performance of our 6 models can be assessed with the performance() function. The measures argument specifies which performance metrics to calculate. The default metric for multilabel classification is the Hamming loss (multilabel.hamloss). For demonstration purposes, we calculate all available multilabel metrics as follows:

#Training 6 models on train subset

mod.rf=mlr::train(lrn.rfsrc,CKD.task,subset=train.set)

mod.binrel=mlr::train(lrn.binrel,CKD.task,subset=train.set)

mod.chain=mlr::train(lrn.chain,CKD.task,subset=train.set)

mod.nest=mlr::train(lrn.nest,CKD.task,subset=train.set)

mod.dbr=mlr::train(lrn.dbr,CKD.task,subset=train.set)

mod.stack=mlr::train(lrn.stack,CKD.task,subset=train.set)

#Prediction on test subset

pred.rf=predict(mod.rf,newdata=test)
pred.binrel=predict(mod.binrel,newdata=test)
pred.chain=predict(mod.chain,newdata=test)
pred.nest=predict(mod.nest,newdata=test)
pred.dbr=predict(mod.dbr,newdata=test)
pred.stack=predict(mod.stack,newdata=test)

# Performance analysis

measures=list(multilabel.acc,multilabel.f1,multilabel.hamloss,multilabel.subset01,multilabel.ppv,multilabel.tpr)

p1=performance(pred.rf,measures)
p2=performance(pred.binrel,measures)
p3=performance(pred.chain,measures)
p4=performance(pred.nest,measures)
p5=performance(pred.dbr,measures)
p6=performance(pred.stack,measures)

# Named perf.df so that the object does not mask mlr::performance()
perf.df=as.data.frame(rbind(p1,p2,p3,p4,p5,p6))
perf.df$model=c("RandomForest","Binaryrelevance","Chains","Nested","DBR","Stacking")
library(RColorBrewer)

plong=gather(perf.df,metrics,value,multilabel.acc:multilabel.tpr, factor_key=TRUE)

ggplot(plong)+geom_point(aes(x=model,y=value,color=metrics),size=5,alpha=0.7)+facet_grid(metrics~.)+coord_flip()+theme_bw()+scale_color_manual(values=mycolors)

ggplot(plong)+geom_tile(aes(x=model,y=metrics,fill=value),color="black")+geom_text(aes(x=model,y=metrics,label=round(value,3)),color="black")+scale_fill_distiller(palette = "Spectral")

Note: Performance metrics

multilabel.acc: Accuracy (multilabel). Best = 1, worst = 0. Averaged proportion of correctly predicted labels with respect to the total number of labels for each instance, following the definition by Charte and Charte: https://journal.r-project.org/archive/2015-2/charte-charte.pdf. Fractions where the denominator becomes 0 are replaced with 1 before computing the average across all instances.

multilabel.f1: F1 measure (multilabel). Best = 1, worst = 0. Harmonic mean of precision and recall on a per-instance basis (Micro-F1), following the definition by Montanes et al.: http://www.sciencedirect.com/science/article/pii/S0031320313004019. Fractions where the denominator becomes 0 are replaced with 1 before computing the average across all instances.

multilabel.hamloss: Hamming loss. Best = 0, worst = 1. Proportion of labels that are predicted incorrectly, following the definition by Charte and Charte: https://journal.r-project.org/archive/2015-2/charte-charte.pdf.

multilabel.ppv: Positive predictive value (multilabel). Best = 1, worst = 0. Also called precision. Averaged proportion of predicted labels that are relevant (i.e., correct) for each instance, following the definition by Charte and Charte: https://journal.r-project.org/archive/2015-2/charte-charte.pdf. Fractions where the denominator becomes 0 are ignored in the average calculation.

multilabel.subset01: Subset 0/1 loss. Best = 0, worst = 1. Proportion of observations whose complete label set (all 0/1 labels) is not predicted exactly, following the definition by Charte and Charte: https://journal.r-project.org/archive/2015-2/charte-charte.pdf.

multilabel.tpr: TPR (multilabel). Best = 1, worst = 0. Also called recall. Averaged proportion of relevant (true) labels that are predicted for each instance, following the definition by Charte and Charte: https://journal.r-project.org/archive/2015-2/charte-charte.pdf. Fractions where the denominator becomes 0 are ignored in the average calculation.
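The definitions above can be checked by hand. The sketch below is a hypothetical base-R function, ml_metrics() (not part of mlr), written directly from those definitions and applied to a small made-up pair of truth and prediction matrices; it is meant to clarify the formulas, not to replace performance().

```r
# Hypothetical helper implementing the multilabel metrics defined above
ml_metrics <- function(truth, response) {
  truth <- as.matrix(truth); response <- as.matrix(response)
  inter  <- rowSums(truth & response)   # size of intersection per instance
  union  <- rowSums(truth | response)   # size of union per instance
  n_true <- rowSums(truth)              # number of true labels
  n_pred <- rowSums(response)           # number of predicted labels
  acc <- inter / union
  acc[union == 0] <- 1                  # 0/0 replaced with 1
  f1 <- 2 * inter / (n_true + n_pred)
  f1[(n_true + n_pred) == 0] <- 1       # 0/0 replaced with 1
  c(acc      = mean(acc),
    f1       = mean(f1),
    hamloss  = mean(truth != response),
    subset01 = mean(rowSums(truth != response) > 0),
    ppv      = mean(inter / n_pred, na.rm = TRUE),  # 0/0 (NaN) ignored
    tpr      = mean(inter / n_true, na.rm = TRUE))  # 0/0 (NaN) ignored
}

labs <- c("DIA", "HTA", "CKD")
truth <- matrix(c(TRUE,  FALSE, TRUE,
                  FALSE, TRUE,  FALSE,
                  FALSE, FALSE, FALSE,
                  TRUE,  TRUE,  TRUE), ncol = 3, byrow = TRUE,
                dimnames = list(NULL, labs))
response <- matrix(c(TRUE,  FALSE, FALSE,
                     FALSE, TRUE,  TRUE,
                     FALSE, FALSE, FALSE,
                     TRUE,  TRUE,  TRUE), ncol = 3, byrow = TRUE,
                   dimnames = list(NULL, labs))
round(ml_metrics(truth, response), 3)
# acc 0.750, f1 0.833, hamloss 0.167, subset01 0.500, ppv 0.833, tpr 0.833
```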

Binary performance for each label

We can also calculate binary performance measures such as accuracy, mmce or auc for each label, using the function getMultilabelBinaryPerformances(). This function can be applied to any multilabel prediction, including predictions obtained by resampling. To calculate the auc, the model must be able to generate predicted probabilities.

The meaning of these binary performance metrics has been described in our previous tutorials.
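For a single label, the binary metrics computed below reduce to counts of true and false positives and negatives, and the AUC can be obtained from the ranks of the predicted probabilities (Mann-Whitney formulation). The hypothetical helper binary_metrics() illustrates this on made-up values; it is a sketch, not mlr's implementation (for example, the handling of a probability exactly equal to the threshold may differ).

```r
# Hypothetical helper: binary metrics for one label
binary_metrics <- function(truth, prob, threshold = 0.5) {
  pred <- prob > threshold
  tp <- sum(pred & truth);  fn <- sum(!pred & truth)
  fp <- sum(pred & !truth); tn <- sum(!pred & !truth)
  tpr <- tp / (tp + fn)     # sensitivity
  tnr <- tn / (tn + fp)     # specificity
  # AUC: probability that a random positive is ranked above a random negative
  r <- rank(prob)           # ties receive average ranks
  n1 <- sum(truth); n0 <- sum(!truth)
  auc <- (sum(r[truth]) - n1 * (n1 + 1) / 2) / (n1 * n0)
  c(BAC = (tpr + tnr) / 2, AUC = auc, MMCE = mean(pred != truth),
    FNR = fn / (tp + fn), FPR = fp / (fp + tn))
}

truth <- c(TRUE, TRUE, FALSE, FALSE, TRUE, FALSE)
prob  <- c(0.9, 0.4, 0.3, 0.6, 0.8, 0.1)
round(binary_metrics(truth, prob), 3)
# BAC 0.667, AUC 0.889, MMCE 0.333, FNR 0.333, FPR 0.333
```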

perbin.chain=getMultilabelBinaryPerformances(pred.chain,list(bac,auc,mmce,fnr,fpr))%>%as.data.frame()
perbin.chain$Labels=c("Diabetes","HTA","CKD")
perbin.chain=plyr::rename(perbin.chain, c("bac.test.mean"="BAC", "auc.test.mean"="AUC","mmce.test.mean"="MMCE","fnr.test.mean"="FNR","fpr.test.mean"="FPR"))
pblong=gather(perbin.chain,metrics,value,BAC:FPR, factor_key=TRUE)
ggplot(pblong)+geom_tile(aes(x=metrics,y=Labels,fill=value),color="black")+scale_fill_distiller(palette = "Spectral")+geom_text(aes(x=metrics,y=Labels,label=round(value,3)),color="black")+ggtitle("Classifier chains")

ggplot(pblong)+geom_point(aes(x=Labels,y=value,color=Labels),size=5)+scale_y_continuous("Value",breaks=c(0,0.1,0.2,0.3,0.4,0.5,0.6,0.7,0.8,0.9,1))+coord_flip()+facet_grid(metrics~.)+theme_bw()+ggtitle("Classifier chains")

perbin.binrel=getMultilabelBinaryPerformances(pred.binrel,list(bac,auc,mmce,fnr,fpr))%>%as.data.frame()
perbin.binrel$Labels=c("Diabetes","HTA","CKD")
perbin.binrel=plyr::rename(perbin.binrel, c("bac.test.mean"="BAC", "auc.test.mean"="AUC","mmce.test.mean"="MMCE","fnr.test.mean"="FNR","fpr.test.mean"="FPR"))
pblong=gather(perbin.binrel,metrics,value,BAC:FPR, factor_key=TRUE)
ggplot(pblong)+geom_tile(aes(x=metrics,y=Labels,fill=value),color="black")+scale_fill_distiller(palette = "Spectral")+geom_text(aes(x=metrics,y=Labels,label=round(value,3)),color="black")+ggtitle("Binary relevance")

ggplot(pblong)+geom_point(aes(x=Labels,y=value,color=Labels),size=5)+scale_y_continuous("Value",breaks=c(0,0.1,0.2,0.3,0.4,0.5,0.6,0.7,0.8,0.9,1))+coord_flip()+facet_grid(metrics~.)+theme_bw()+ggtitle("Binary relevance")

perbin.dbr=getMultilabelBinaryPerformances(pred.dbr,list(bac,auc,mmce,fnr,fpr))%>%as.data.frame()
perbin.dbr$Labels=c("Diabetes","HTA","CKD")
perbin.dbr=plyr::rename(perbin.dbr, c("bac.test.mean"="BAC", "auc.test.mean"="AUC","mmce.test.mean"="MMCE","fnr.test.mean"="FNR","fpr.test.mean"="FPR"))
pblong=gather(perbin.dbr,metrics,value,BAC:FPR, factor_key=TRUE)
ggplot(pblong)+geom_tile(aes(x=metrics,y=Labels,fill=value),color="black")+scale_fill_distiller(palette = "Spectral")+geom_text(aes(x=metrics,y=Labels,label=round(value,3)),color="black")+ggtitle("Dependent binary relevance")

ggplot(pblong)+geom_point(aes(x=Labels,y=value,color=Labels),size=5)+scale_y_continuous("Value",breaks=c(0,0.1,0.2,0.3,0.4,0.5,0.6,0.7,0.8,0.9,1))+coord_flip()+facet_grid(metrics~.)+theme_bw()+ggtitle("Dependent binary relevance")

perbin.nest = getMultilabelBinaryPerformances(pred.nest, list(bac, auc, mmce, fnr, fpr)) %>% as.data.frame()
perbin.nest$Labels = c("Diabetes", "HTA", "CKD")
perbin.nest = plyr::rename(perbin.nest, c("bac.test.mean" = "BAC", "auc.test.mean" = "AUC",
  "mmce.test.mean" = "MMCE", "fnr.test.mean" = "FNR", "fpr.test.mean" = "FPR"))
pblong = gather(perbin.nest, metrics, value, BAC:FPR, factor_key = TRUE)
ggplot(pblong) + geom_tile(aes(x = metrics, y = Labels, fill = value), color = "black") +
  scale_fill_distiller(palette = "Spectral") +
  geom_text(aes(x = metrics, y = Labels, label = round(value, 3)), color = "black") + ggtitle("Nested Stacking")

ggplot(pblong) + geom_point(aes(x = Labels, y = value, color = Labels), size = 5) +
  scale_y_continuous("Value", breaks = seq(0, 1, by = 0.1)) +
  coord_flip() + facet_grid(metrics ~ .) + theme_bw() + ggtitle("Nested Stacking")

perbin.rf = getMultilabelBinaryPerformances(pred.rf, list(bac, auc, mmce, fnr, fpr)) %>% as.data.frame()
perbin.rf$Labels = c("Diabetes", "HTA", "CKD")
perbin.rf = plyr::rename(perbin.rf, c("bac.test.mean" = "BAC", "auc.test.mean" = "AUC",
  "mmce.test.mean" = "MMCE", "fnr.test.mean" = "FNR", "fpr.test.mean" = "FPR"))
pblong = gather(perbin.rf, metrics, value, BAC:FPR, factor_key = TRUE)
ggplot(pblong) + geom_tile(aes(x = metrics, y = Labels, fill = value), color = "black") +
  scale_fill_distiller(palette = "Spectral") +
  geom_text(aes(x = metrics, y = Labels, label = round(value, 3)), color = "black") + ggtitle("SRC Random Forest")

ggplot(pblong) + geom_point(aes(x = Labels, y = value, color = Labels), size = 5) +
  scale_y_continuous("Value", breaks = seq(0, 1, by = 0.1)) +
  coord_flip() + facet_grid(metrics ~ .) + theme_bw() + ggtitle("SRC Random Forest")

perbin.stack = getMultilabelBinaryPerformances(pred.stack, list(bac, auc, mmce, fnr, fpr)) %>% as.data.frame()
perbin.stack$Labels = c("Diabetes", "HTA", "CKD")
perbin.stack = plyr::rename(perbin.stack, c("bac.test.mean" = "BAC", "auc.test.mean" = "AUC",
  "mmce.test.mean" = "MMCE", "fnr.test.mean" = "FNR", "fpr.test.mean" = "FPR"))
pblong = gather(perbin.stack, metrics, value, BAC:FPR, factor_key = TRUE)
ggplot(pblong) + geom_tile(aes(x = metrics, y = Labels, fill = value), color = "black") +
  scale_fill_distiller(palette = "Spectral") +
  geom_text(aes(x = metrics, y = Labels, label = round(value, 3)), color = "black") + ggtitle("Stacking")

ggplot(pblong) + geom_point(aes(x = Labels, y = value, color = Labels), size = 5) +
  scale_y_continuous("Value", breaks = seq(0, 1, by = 0.1)) +
  coord_flip() + facet_grid(metrics ~ .) + theme_bw() + ggtitle("Stacking")
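The blocks above repeat the same reshaping steps for every model. As a minimal sketch, a small base-R helper could build the long-format table in one call. The name `tidy_binary_perf` is our own, and the function assumes its input is the per-label matrix returned by mlr's `getMultilabelBinaryPerformances()`, with columns named `<measure>.test.mean` and rows in the order Diabetes, HTA, CKD (the same assumption the hard-coded `Labels` vector makes). It replaces both `plyr::rename()` and `tidyr::gather()`.

```r
# Hypothetical helper (not part of mlr): reshape the per-label matrix returned by
# getMultilabelBinaryPerformances() into the long format used by the plots.
tidy_binary_perf = function(perf, labels = c("Diabetes", "HTA", "CKD")) {
  perf = as.data.frame(perf)
  stopifnot(length(labels) == nrow(perf))   # assumes rows are in this label order
  # "bac.test.mean" -> "BAC", "auc.test.mean" -> "AUC", and so on
  metrics = toupper(sub("\\.test\\.mean$", "", names(perf)))
  data.frame(
    Labels  = rep(labels, times = length(metrics)),
    metrics = factor(rep(metrics, each = nrow(perf)), levels = metrics),
    value   = unlist(perf, use.names = FALSE)   # stacked column by column, like gather()
  )
}

# Stand-in for a real getMultilabelBinaryPerformances() result (made-up numbers)
perf = matrix(c(0.90, 0.85, 0.95, 0.10, 0.15, 0.05), nrow = 3,
              dimnames = list(NULL, c("bac.test.mean", "mmce.test.mean")))
pblong = tidy_binary_perf(perf)
pblong
```

With such a helper, each model would reduce to something like `pblong = tidy_binary_perf(getMultilabelBinaryPerformances(pred.stack, list(bac, auc, mmce, fnr, fpr)))` followed by the same two `ggplot()` calls.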

The results suggest that all six models classify the pathological triad of diabetes, hypertension and CKD with high accuracy.
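To make the heatmaps easier to read, the per-label measures can be recomputed by hand from each label's 2x2 confusion table. The sketch below uses a hypothetical base-R helper (`binary_metrics`) on a tiny made-up truth/prediction matrix, not on the study data. It covers BAC, MMCE, FNR and FPR; AUC is left out because it needs predicted probabilities rather than class labels. Since BAC = (TPR + TNR)/2, it also equals 1 - (FNR + FPR)/2, so the BAC column of each heatmap summarises the FNR and FPR columns.

```r
# Hypothetical helper: per-label BAC, MMCE, FNR and FPR from logical matrices
# `truth` and `pred` (rows = patients, columns = labels).
binary_metrics = function(truth, pred) {
  stopifnot(identical(dim(truth), dim(pred)))
  res = t(sapply(seq_len(ncol(truth)), function(j) {
    y = truth[, j]; yhat = pred[, j]
    tp = sum(y & yhat);   fn = sum(y & !yhat)
    tn = sum(!y & !yhat); fp = sum(!y & yhat)
    tpr = tp / (tp + fn)   # sensitivity (NaN if the label has no positive case)
    tnr = tn / (tn + fp)   # specificity (NaN if the label has no negative case)
    c(BAC = (tpr + tnr) / 2, MMCE = (fp + fn) / length(y), FNR = 1 - tpr, FPR = 1 - tnr)
  }))
  rownames(res) = colnames(truth)
  res
}

# Tiny made-up example with 4 patients
truth = cbind(Diabetes = c(TRUE, TRUE, FALSE, FALSE),
              HTA      = c(TRUE, FALSE, TRUE, FALSE),
              CKD      = c(TRUE, TRUE, TRUE, FALSE))
pred  = cbind(Diabetes = c(TRUE, FALSE, FALSE, FALSE),
              HTA      = c(TRUE, FALSE, TRUE, FALSE),
              CKD      = c(TRUE, TRUE, FALSE, TRUE))
round(binary_metrics(truth, pred), 3)
# Diabetes: BAC 0.75, MMCE 0.25, FNR 0.5, FPR 0
# HTA: BAC 1, all error rates 0; CKD: BAC 0.333, MMCE 0.5, FNR 0.333, FPR 1
```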

The SRC random forest appears to be a very powerful algorithm: it handles multiple target labels natively and copes with most missing-value problems. It also ran faster than the wrapped learners.

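Because of the missing values, all problem transformation models used CART as the base learner, since this decision tree algorithm can handle missing predictor values through surrogate splits. As usual, CART performed well. Nevertheless, the performance of these models ranged from good to very good, depending on each model's structure and on the prior assumption it makes about the relationships among the labels.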
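The missing-value argument can be illustrated with `rpart`, R's CART implementation. In this sketch the data are simulated (variable names are borrowed from the context for readability, the values are invented): 20% of the creatinine values are deleted, yet `rpart` still fits the tree, and a new patient whose creatinine is missing is routed through surrogate splits (here most likely on the correlated urea variable). It assumes rpart's default surrogate settings and only illustrates the mechanism; it is not the model fitted in this study.

```r
library(rpart)  # R's CART implementation (a recommended package shipped with R)

set.seed(42)
n = 300
# Simulated, illustrative variables (not the study data)
creatinine = rexp(n, rate = 1)
urea       = 30 * creatinine + rnorm(n, sd = 5)   # strongly correlated with creatinine
hemoglobin = rnorm(n, mean = 13, sd = 2)          # unrelated noise variable
ckd        = factor(ifelse(creatinine > 1.2, "yes", "no"))
creatinine[sample(n, 60)] = NA                    # 20% of creatinine values missing
df = data.frame(creatinine, urea, hemoglobin, ckd)

# Rows with a missing predictor are kept; rpart computes surrogate splits for them
fit = rpart(ckd ~ creatinine + urea + hemoglobin, data = df, method = "class")
fit$splits   # candidate (primary/competing) and surrogate splits for each node

# A new patient whose creatinine is missing still receives a class prediction
new_patient = data.frame(creatinine = NA_real_, urea = 60, hemoglobin = 11)
p = predict(fit, newdata = new_patient, type = "class")
p
```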

The most accurate predictions appear to come from the algorithms that treat the labels independently, and the binary relevance model performed slightly better than the classifier chain. This may indicate that our assumption about the comorbidity/complication relationship does not always hold, i.e. not every diabetic or hypertensive patient develops CKD. A classifier chain should work best when there is a clear hierarchical or causal relationship among the three pathological entities.
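The two approaches differ only in the features each per-label model sees. In the sketch below, binary relevance fits one logistic regression per label on the features alone, while the classifier chain (order diabetes, then HTA, then CKD) adds the true values of the earlier labels during training and their predicted values at prediction time. We swap in `glm()` instead of CART to keep it short; the data are simulated so that CKD depends on the other two labels, the helper names (`fit_logit`, `to_class`) are hypothetical, and the printed accuracies say nothing about the study data.

```r
set.seed(1)
n = 500
X = data.frame(x1 = rnorm(n), x2 = rnorm(n))
# Simulated labels: CKD depends on diabetes and HTA (a hierarchy a chain can exploit)
Y = data.frame(diabetes = rbinom(n, 1, plogis(X$x1)),
               hta      = rbinom(n, 1, plogis(X$x2)))
Y$ckd = rbinom(n, 1, plogis(-1 + 1.5 * Y$diabetes + 1.5 * Y$hta))
lab_names = names(Y)                  # also the chain order: diabetes, hta, ckd
train = 1:400; test = 401:500

fit_logit = function(d, y) glm(y ~ ., data = cbind(d, y = y), family = binomial)
to_class  = function(fit, d) as.integer(predict(fit, newdata = d, type = "response") > 0.5)

# Binary relevance: one independent model per label, features only
br = lapply(setNames(lab_names, lab_names), function(l) fit_logit(X[train, ], Y[train, l]))
pred_br = as.data.frame(lapply(br, to_class, d = X[test, ]))

# Classifier chain, training: each link also sees the TRUE values of earlier labels
chain = list(); d_tr = X[train, ]
for (l in lab_names) {
  chain[[l]] = fit_logit(d_tr, Y[train, l])
  d_tr[[l]] = Y[train, l]
}
# Classifier chain, prediction: true labels are unknown, so predictions are passed on
d_te = X[test, ]
for (l in lab_names) d_te[[l]] = to_class(chain[[l]], d_te)
pred_chain = d_te[lab_names]

# Per-label accuracy on the held-out rows
acc = function(p) colMeans(as.matrix(p) == as.matrix(Y[test, lab_names]))
rbind(BinaryRelevance = acc(pred_br), ClassifierChain = acc(pred_chain))
```

Note that the chain's first link is identical to the binary relevance model for diabetes; the two methods can only diverge on the later labels.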

The Chain and DBR models performed somewhat worse than the SRC random forest and the other wrapped models. During training, both use the true values of the other labels as additional features, whereas at prediction time these true values are unknown and must be replaced by predicted ones. This probably makes them more sensitive to noise in the data than the nested stacking method, which trains each component on the labels predicted by the previously trained components.
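The contrast between DBR and nested stacking can be sketched in the same way. DBR trains each label on the features plus the true values of all other labels, then at prediction time substitutes first-level binary relevance predictions; nested stacking follows a chain order but trains each link on the labels predicted by the earlier links. This is a simplified base-R sketch with logistic regression instead of CART and simulated data; the nested stacking features are in-sample predictions for simplicity, and mlr's wrappers may differ in such details. Helper names are hypothetical.

```r
set.seed(2)
n = 600
X = data.frame(x1 = rnorm(n), x2 = rnorm(n))
Y = data.frame(diabetes = rbinom(n, 1, plogis(X$x1)),
               hta      = rbinom(n, 1, plogis(X$x2)))
Y$ckd = rbinom(n, 1, plogis(-1 + 1.5 * Y$diabetes + 1.5 * Y$hta))
lab_names = names(Y)
train = 1:400; test = 401:600
Xtr = X[train, ]; Xte = X[test, ]

fit_logit = function(d, y) glm(y ~ ., data = cbind(d, y = y), family = binomial)
to_class  = function(fit, d) as.integer(predict(fit, newdata = d, type = "response") > 0.5)

# Dependent binary relevance (DBR)
# Level 1: plain binary relevance, used only to supply predictions at test time
br = lapply(setNames(lab_names, lab_names), function(l) fit_logit(Xtr, Y[train, l]))
# Level 2: each label trained on the features + TRUE values of all other labels
dbr = lapply(setNames(lab_names, lab_names), function(l)
  fit_logit(cbind(Xtr, Y[train, setdiff(lab_names, l)]), Y[train, l]))
p1 = as.data.frame(lapply(br, to_class, d = Xte))
pred_dbr = as.data.frame(sapply(lab_names, function(l)
  to_class(dbr[[l]], cbind(Xte, p1[setdiff(lab_names, l)]))))

# Nested stacking: chain order, but each link is trained on PREDICTED earlier labels
ns = list(); d_tr = Xtr; d_te = Xte
for (l in lab_names) {
  ns[[l]] = fit_logit(d_tr, Y[train, l])
  d_tr[[l]] = to_class(ns[[l]], d_tr)   # in-sample prediction feeds the next link
  d_te[[l]] = to_class(ns[[l]], d_te)
}
pred_ns = d_te[lab_names]

acc = function(p) colMeans(as.matrix(p) == as.matrix(Y[test, lab_names]))
rbind(DBR = acc(pred_dbr), NestedStacking = acc(pred_ns))
```

The key design difference is visible in the training data: DBR's level-2 models never see a predicted label until test time, while nested stacking sees the same kind of (imperfect) inputs in training and in prediction.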

Conclusion

Multilabel classification is a special form of supervised learning in which each observation can carry several labels at once. It can be useful for tackling more ambitious problems, such as predicting the comorbidities or complications of a disease as a single pathological entity. This goal can be achieved with at least six alternative approaches. The simplest solution is to use algorithms that handle multiple labels natively, such as Random Forest SRC or rFerns. The wrapped models are more appropriate when there is already solid evidence about the interdependence among the labels. Model performance can vary with data quality, the mechanism of the base learner and the prior assumptions built into the model.

Regardless of the method used, multilabel classification is always data-hungry, as model training is more sensitive to noise and imperfections in the data structure than in simple binary classification tasks. Training a good multilabel model requires a larger sample size and well-balanced data splitting. Resampling is usually not possible, because labels may be mismatched or one label may be entirely absent from randomly drawn test subsets.
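The last point can be checked before resampling. The hypothetical helper below assigns observations to k random folds and reports, for every label and fold, whether both classes occur in that test fold. In the made-up example, one label has only two positive cases, so at least three of the five folds contain no positive case for it, and measures such as FNR or AUC are undefined there.

```r
# Hypothetical helper: assign rows to k random folds and report, for each label
# (rows of the result) and each fold (columns), whether the test fold contains
# both classes. FALSE means the label cannot be properly evaluated in that fold.
fold_coverage = function(Y, k = 5) {
  fold = sample(rep_len(seq_len(k), nrow(Y)))
  sapply(seq_len(k), function(f) {
    sub = Y[fold == f, , drop = FALSE]
    apply(sub, 2, function(y) length(unique(y)) == 2)
  })
}

set.seed(3)
n = 60
# Made-up binary labels; "rare_label" has only two positive cases
Y = data.frame(label_a    = rbinom(n, 1, 0.3),
               label_b    = rbinom(n, 1, 0.4),
               rare_label = c(1, 1, rep(0, n - 2)))
coverage = fold_coverage(Y, k = 5)
coverage
```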