The aim of this project is to compare different decision tree packages in R. Decision trees tend, on their own, to be weak learners, although their performance can be meaningfully improved with boosting or random forests. But there are a number of packages for ordinary decision trees in R, so I was curious how they performed. These packages differ mostly in their handling of missing data, so I needed a data set with at least some missing data (an easy task).
My working hypothesis was that packages that could handle missing data (for example, by imputation) would outperform packages that simply skipped cases with missing values. So-called ‘completer analysis’, where only complete cases are considered, is typically a bad idea: subjects with no missing data at all are routinely different from the population of interest in important but unknown ways. I would expect, therefore, that predictions of new cases based on a completer analysis would perform poorly.
My chosen data set includes 1700 patients who presented to a hospital with myocardial infarction in Russia (data collected from 1992 – 1995; source is UCI data repository, donated 2020). Data was collected for 12 outcomes (different possible complications). These outcome variables were categorized into a single complications variable with 3 levels: no complications, non-fatal complications, and fatal complications. There was no missing data in the outcome variable. A list of all variables and descriptions, and a list of papers in which this database has been used, can be found at https://leicester.figshare.com/articles/dataset/Myocardial_infarction_complications_Database/12045261?file=22803572.
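The recoding was conceptually along these lines (a hedged sketch with placeholder names; the actual code is hidden, as noted below):
# Sketch only: `mi_raw` is the raw data frame, `complication_cols` names the
# 12 complication indicators, and `fatal_col` flags a fatal outcome.
any_complication <- rowSums(mi_raw[, complication_cols] == 1) > 0
mi_raw$outcome.cat <- factor(ifelse(mi_raw[[fatal_col]] == 1, 2,
                                    ifelse(any_complication, 1, 0)))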
Missing data in the predictors at a rate of ≤10% was considered acceptable. For predictors missing at a rate in the range (10%, 25%], a physician was consulted to advise whether the variable is clinically relevant to the outcome. Variables in this range that were not clinically relevant were excluded; variables that were considered clinically relevant were noted for further analysis to determine their impact on the model.
The following variables had a rate of missingness in this range: duration of arterial hypertension, blood pressure (systolic and diastolic), low serum potassium (hypokalemia), serum potassium content (mmol/L), increase in serum sodium, serum sodium content (mmol/L), serum ALT, and serum AST. Of these, only blood pressure was considered clinically relevant; all others were excluded. (Note: two blood pressure measurements were reported, one taken in the ER and one taken in the ICU; most patients had blood pressure measured in only one location, so the two were coalesced into a single variable.)
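The coalescing step was along these lines (a sketch; the ER/ICU column names are placeholders, not the ones in the hidden code):
library(dplyr)

# Keep whichever blood pressure reading was actually recorded.
# systolic_er / systolic_icu etc. are placeholder column names.
Z <- Z %>%
  mutate(systolic  = coalesce(systolic_er,  systolic_icu),
         diastolic = coalesce(diastolic_er, diastolic_icu))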
Predictors missing at rates >25% were excluded from further analysis. Three variables fell into this category: family history of coronary heart disease (CHD), serum CPK content, and use of opioids and lidocaine by the emergency cardiology department.
Additionally, variables assessing Q-wave changes in the QRS complex were excluded, because recent research indicates that this categorization is not medically significant (https://www.healio.com/news/cardiology).
The final dataset included 89 predictors. Next, cases missing values for more than 20% of the predictors were removed (128 cases in total).
(I have hidden the code for ease of reading, but I am happy to share it if you reach out to me.)
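For readers who want a starting point, here is a minimal sketch of the screening logic described above; the raw data frame name mi_raw and the use of dplyr are my assumptions, not the hidden code:
library(dplyr)

# Proportion of missing values per predictor
miss_by_var <- colMeans(is.na(mi_raw))

drop_vars   <- names(miss_by_var)[miss_by_var > 0.25]                        # excluded outright
review_vars <- names(miss_by_var)[miss_by_var > 0.10 & miss_by_var <= 0.25]  # reviewed with a physician

mi_screened <- mi_raw %>% select(-all_of(drop_vars))

# Drop cases missing values for more than 20% of the remaining predictors
case_missing <- rowMeans(is.na(mi_screened))
mi_final     <- mi_screened[case_missing <= 0.20, ]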
The sample is mostly male (1065 male, 635 female), aged 26-92 (average 62 years). Female subjects were slightly older than male subjects.
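These summaries are easy to reproduce with base R (a quick check, not the hidden preprocessing code):
# Counts by sex and the age distribution within each sex
table(Z$SEX)
tapply(Z$AGE, Z$SEX, summary)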
library(dplyr)    # for the pipe
library(ggplot2)

# Boxplots of age by sex; axis labels now match the mapping (x = sex, y = age)
Z %>%
  ggplot(aes(x = SEX, y = AGE, fill = SEX)) +
  geom_boxplot(alpha = .5) +
  theme_classic() +
  theme(legend.position = "top") +
  labs(x = "Sex", y = "Age", title = "Age Distribution by Sex", fill = "Sex") +
  scale_fill_discrete(labels = c("Female", "Male"))
The majority of subjects (1468) did not have a history of chronic heart failure (ZSN_A), and 13.4% of subjects had a history of diabetes mellitus. Only 195 subjects (11.5%) were administered nitrates in the ICU (NITR_S), and 28% of subjects were given lidocaine in the ICU (LID_S_n). Fifty subjects presented with a right ventricular myocardial infarction (IM_PG_P), and 46 subjects were in a state of cardiogenic shock at admission to the ICU (K_SH_POST).
For all methods below, a training dataset consisted of 80% of the total dataset (1360 observations). In the training set, 468 cases were complete (no missing values).
set.seed(1200)                       # reproducible train/test split
round(.8 * 1700, 0)                  # 80% of the 1700 subjects
## [1] 1360
training = sample(1:nrow(Z), 1360)   # row indices used to subset Z below
The tree package handles missing data by pushing missing values down the tree as far as possible. While not a pure ‘completer’ analysis, this decision tree used only 561 cases out of the 1360 available in the training set. A tree created using the tree package is shown below. The predictive accuracy was 49.7%. Compared with the plots from the other packages, this one is also much less attractive.
library(tree)

# Fit a classification tree on the training set and plot it
tree.train = tree(outcome.cat ~ ., data = Z[training, ])
plot(tree.train)
text(tree.train, pretty = 0, cex = 0.7)

# Confusion matrix on the held-out 20%
pred = predict(tree.train, Z[-training, ], type = "class")
(ct <- table(pred, Z[-training, ]$outcome.cat))
##
## pred 0 1 2
## 0 98 69 13
## 1 44 60 18
## 2 6 21 11
sum(diag(ct))/nrow(Z[-training,])
## [1] 0.4970588
The rpart package allows all cases to be included in the modeling as long as the outcome and at least one predictor are available. It handles missing values by creating surrogate splits, an approach loosely similar to imputation. If an observation is missing the variable used at a split, rpart goes down the list of surrogate variables and uses the best surrogate for which data are available. If the observation is also missing all of the surrogate variables, the “blind rule” is applied: the observation is sent down the branch that the majority of cases followed. Only surrogates that perform better than the blind rule are retained in the tree.
library(rpart)
library(rattle)   # for fancyRpartPlot()

# Fit and plot an rpart classification tree
tree.r <- rpart(outcome.cat ~ .,
                data = Z[training, ],
                method = "class")
fancyRpartPlot(tree.r)

# Confusion matrix on the test set
pred <- predict(tree.r, Z[-training, ], type = "class")
(ct <- table(pred, Z[-training, ]$outcome.cat))
##
## pred 0 1 2
## 0 52 39 5
## 1 96 111 33
## 2 0 0 4
sum(diag(ct))/nrow(Z[-training,])
## [1] 0.4911765
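To see which surrogate variables rpart chose at each node, and how many observations were routed through each surrogate, print the model summary (output omitted here because it is long):
# Prints, for every node, the primary split, competing splits, and the
# surrogate splits with their agreement and the number of cases they handled
summary(tree.r)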
The caret package imputes missing data where possible using k-nearest neighbours (KNN) and then omits any remaining cases with missing values. Prior to imputation, 34.4% of the training set and 35% of the testing set were complete cases. After imputation, 47.7% of the training data and 47.9% of the testing data were complete cases: an improvement, but still not great.
Predictive accuracy for the decision tree through the caret package is 49.1%. Partly this is because the final model never predicts the third outcome level (fatal complications), as the confusion matrix below shows.
Notably, this tree is much more parsimonious than the other trees. Only two variables appear in it: ZSN_A and AGE. (The preprocessing step that imputes missing values also centers and scales the predictors, which is why the split thresholds look different.)
library(caret)

# Fit kNN-imputation (with centering and scaling) models separately on the
# training and test sets, then apply them to create imputed copies
preProcess_missingdata_model <- preProcess(Z[training, ], method = 'knnImpute')
ppmissingtest <- preProcess(Z[-training, ], method = 'knnImpute')
trainData  <- predict(preProcess_missingdata_model, newdata = Z[training, ])
pptestdata <- predict(ppmissingtest, newdata = Z[-training, ])
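# (Sketch, not from the original post: the complete-case percentages quoted
#  above can be checked directly with complete.cases().)
mean(complete.cases(Z[training, ]))    # training set before imputation
mean(complete.cases(trainData))        # training set after imputation
mean(complete.cases(Z[-training, ]))   # test set before imputation
mean(complete.cases(pptestdata))       # test set after imputation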
# Train an rpart tree through caret with 10-fold cross-validation,
# dropping rows that still contain missing values
tree.tr = train(outcome.cat ~ .,
                data = trainData,
                method = "rpart",
                #preProcess = "knnImpute",
                trControl = trainControl(method = "cv"),
                na.action = na.omit)
suppressMessages(library(rattle))
fancyRpartPlot(tree.tr$finalModel)
# Predict on the preprocessed test data, then drop its remaining incomplete
# rows so that pred and the outcome column refer to the same cases
pred = predict(tree.tr, pptestdata, type = "raw")
pptestdata = na.omit(pptestdata)
ct <- table(pred, pptestdata$outcome.cat)
ct
##
## pred 0 1 2
## 0 37 22 5
## 1 45 43 11
## 2 0 0 0
sum(diag(ct))/nrow(pptestdata)
## [1] 0.4907975
The C50 package uses an “information entropy” based algorithm to split the data into successively purer leaves, and it handles missing values without imputing them. The C5.0 function does have a built-in winnowing option (“winnow”), which pre-screens predictors before the tree is grown, but the resulting tree was still too complex to be legible as a plot. Instead, predictor importance is shown below; the “usage” metric reflects the percentage of the training sample that ends up in the terminal nodes below a split on that predictor. The predictive accuracy for the decision tree through the C50 package is 54.4%.
library(C50)

# Fit a single (unboosted) C5.0 classification tree and print its summary
tree.c50 = C5.0(outcome.cat ~ .,
                data = Z[training, ])
tree.c50
##
## Call:
## C5.0.formula(formula = outcome.cat ~ ., data = Z[training, ])
##
## Classification Tree
## Number of samples: 1360
## Number of predictors: 85
##
## Tree size: 170
##
## Non-standard options: attempt to group attributes
C5imp(tree.c50)
## Overall
## K_SH_POST 99.12
## IM_PG_P 96.84
## NITR_S 93.60
## ZSN_A 91.47
## zab_leg_02 80.81
## MP_TP_POST 75.15
## FIB_G_POST 70.59
## fibr_ter_02 67.43
## endocr_02 66.91
## endocr_01 66.32
## AGE 62.57
## O_L_POST 60.74
## LID_S_n 60.07
## n_p_ecg_p_06 47.35
## n_p_ecg_p_12 44.41
## TIME_B_S 31.25
## zab_leg_01 25.81
## GB 24.71
## n_p_ecg_p_07 24.41
## SIM_GIPERT 22.28
## n_p_ecg_p_11 21.32
## TRENT_S_n 21.18
## FK_STENOK 20.51
## L_BLOOD 20.00
## ASP_S_n 17.79
## n_r_ecg_p_04 17.43
## INF_ANAM 17.13
## nr_11 14.85
## SEX 14.12
## ANT_CA_S_n 14.12
## TIKL_S_n 14.12
## nr_04 11.91
## ROE 10.15
## GEPAR_S_n 9.85
## n_r_ecg_p_03 8.90
## IBS_POST 8.82
## n_r_ecg_p_01 8.75
## ritm_ecg_p_07 8.68
## nr_03 7.50
## fibr_ter_03 7.50
## zab_leg_03 7.43
## STENOK_AN 4.56
## ritm_ecg_p_01 4.56
## B_BLOK_S_n 3.75
## diastolic 3.60
## systolic 3.46
## n_r_ecg_p_05 0.44
## nr_01 0.00
## nr_02 0.00
## nr_07 0.00
## nr_08 0.00
## np_01 0.00
## np_04 0.00
## np_05 0.00
## np_07 0.00
## np_08 0.00
## np_09 0.00
## np_10 0.00
## endocr_03 0.00
## zab_leg_04 0.00
## zab_leg_06 0.00
## SVT_POST 0.00
## GT_POST 0.00
## ritm_ecg_p_02 0.00
## ritm_ecg_p_04 0.00
## ritm_ecg_p_06 0.00
## ritm_ecg_p_08 0.00
## n_r_ecg_p_02 0.00
## n_r_ecg_p_06 0.00
## n_r_ecg_p_08 0.00
## n_r_ecg_p_09 0.00
## n_r_ecg_p_10 0.00
## n_p_ecg_p_01 0.00
## n_p_ecg_p_03 0.00
## n_p_ecg_p_04 0.00
## n_p_ecg_p_05 0.00
## n_p_ecg_p_08 0.00
## n_p_ecg_p_09 0.00
## n_p_ecg_p_10 0.00
## fibr_ter_01 0.00
## fibr_ter_05 0.00
## fibr_ter_06 0.00
## fibr_ter_07 0.00
## fibr_ter_08 0.00
## NOT_NA_KB 0.00
pred<-predict(tree.c50, Z[-training,], type = "class")
ct<-table(pred,Z[-training,]$outcome.cat)
ct
##
## pred 0 1 2
## 0 79 46 6
## 1 61 94 24
## 2 8 10 12
sum(diag(ct))/nrow(Z[-training,])
## [1] 0.5441176
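For reference, the winnowing option mentioned above is enabled through C5.0Control(); this variant was not kept in the analysis, so treat it as a sketch:
# Winnowing pre-screens predictors before the tree is grown; even with it
# switched on, the tree was too complex to plot legibly for this data
tree.win = C5.0(outcome.cat ~ .,
                data = Z[training, ],
                control = C5.0Control(winnow = TRUE))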
A neat thing about this package is that you can easily boost your tree by setting the ‘trials’ argument (default is 1) to any integer between 1 and 100. Here, however, boosting with 80 trials slightly lowered the predictive accuracy, to 52.9%.
# Boosted C5.0 tree: 80 boosting iterations via the `trials` argument
tree.boost = C5.0(outcome.cat ~ .,
                  data = Z[training, ],
                  trials = 80)
pred<-predict(tree.boost, Z[-training,], type = "class")
ct<-table(pred,Z[-training,]$outcome.cat)
ct
##
## pred 0 1 2
## 0 81 49 6
## 1 58 84 21
## 2 9 17 15
sum(diag(ct))/nrow(Z[-training,])
## [1] 0.5294118
Like rpart, the ctree function in the party package handles missing data using surrogate splits. In ctree, surrogate splits are splitting variables that lead to approximately the same left/right division of cases as the original split. However, the rpart output is very clear about what the surrogate variables are and how many observations were routed through them; the ctree function is a bit more of a ‘black box’ in this regard. Like the C5.0 tree, the plot is too busy (at least with this data) to be informative.
The predictive accuracy of the ctree under the party package is 52.4%.
library(party)

# Fit a conditional inference tree on the training data
tree.p = ctree(outcome.cat ~ .,
               data = Z[training, ])
tree.p
##
## Conditional inference tree with 17 terminal nodes
##
## Response: outcome.cat
## Inputs: AGE, SEX, INF_ANAM, STENOK_AN, FK_STENOK, IBS_POST, GB, SIM_GIPERT, ZSN_A, nr_11, nr_01, nr_02, nr_03, nr_04, nr_07, nr_08, np_01, np_04, np_05, np_07, np_08, np_09, np_10, endocr_01, endocr_02, endocr_03, zab_leg_01, zab_leg_02, zab_leg_03, zab_leg_04, zab_leg_06, O_L_POST, K_SH_POST, MP_TP_POST, SVT_POST, GT_POST, FIB_G_POST, IM_PG_P, ritm_ecg_p_01, ritm_ecg_p_02, ritm_ecg_p_04, ritm_ecg_p_06, ritm_ecg_p_07, ritm_ecg_p_08, n_r_ecg_p_01, n_r_ecg_p_02, n_r_ecg_p_03, n_r_ecg_p_04, n_r_ecg_p_05, n_r_ecg_p_06, n_r_ecg_p_08, n_r_ecg_p_09, n_r_ecg_p_10, n_p_ecg_p_01, n_p_ecg_p_03, n_p_ecg_p_04, n_p_ecg_p_05, n_p_ecg_p_06, n_p_ecg_p_07, n_p_ecg_p_08, n_p_ecg_p_09, n_p_ecg_p_10, n_p_ecg_p_11, n_p_ecg_p_12, fibr_ter_01, fibr_ter_02, fibr_ter_03, fibr_ter_05, fibr_ter_06, fibr_ter_07, fibr_ter_08, L_BLOOD, ROE, TIME_B_S, NOT_NA_KB, NITR_S, LID_S_n, B_BLOK_S_n, ANT_CA_S_n, GEPAR_S_n, ASP_S_n, TIKL_S_n, TRENT_S_n, systolic, diastolic
## Number of observations: 1360
##
## 1) K_SH_POST == {0}; criterion = 1, statistic = 179.953
## 2) ZSN_A == {2, 4}; criterion = 1, statistic = 125.722
## 3)* weights = 53
## 2) ZSN_A == {0, 1}
## 4) AGE <= 61; criterion = 1, statistic = 116.815
## 5) NITR_S == {0}; criterion = 1, statistic = 38.228
## 6) ZSN_A == {0}; criterion = 1, statistic = 30.574
## 7) n_p_ecg_p_12 == {1}; criterion = 1, statistic = 30.212
## 8)* weights = 21
## 7) n_p_ecg_p_12 == {0}
## 9) LID_S_n == {0}; criterion = 0.99, statistic = 18.013
## 10) zab_leg_02 == {0}; criterion = 0.978, statistic = 16.455
## 11)* weights = 323
## 10) zab_leg_02 == {1}
## 12)* weights = 14
## 9) LID_S_n == {1}
## 13)* weights = 143
## 6) ZSN_A == {1}
## 14)* weights = 25
## 5) NITR_S == {1}
## 15)* weights = 49
## 4) AGE > 61
## 16) NITR_S == {0}; criterion = 1, statistic = 31.737
## 17) IM_PG_P == {1}; criterion = 1, statistic = 28.43
## 18)* weights = 20
## 17) IM_PG_P == {0}
## 19) ZSN_A == {1}; criterion = 0.996, statistic = 19.694
## 20)* weights = 42
## 19) ZSN_A == {0}
## 21) zab_leg_02 == {0}; criterion = 0.992, statistic = 18.598
## 22) n_p_ecg_p_12 == {0}; criterion = 0.989, statistic = 17.834
## 23) ritm_ecg_p_01 == {1}; criterion = 0.996, statistic = 20.087
## 24) endocr_01 == {0}; criterion = 0.973, statistic = 16.05
## 25) TIME_B_S <= 6; criterion = 0.964, statistic = 15.515
## 26)* weights = 195
## 25) TIME_B_S > 6
## 27)* weights = 102
## 24) endocr_01 == {1}
## 28)* weights = 54
## 23) ritm_ecg_p_01 == {0}
## 29)* weights = 131
## 22) n_p_ecg_p_12 == {1}
## 30)* weights = 17
## 21) zab_leg_02 == {1}
## 31)* weights = 42
## 16) NITR_S == {1}
## 32)* weights = 87
## 1) K_SH_POST == {1}
## 33)* weights = 42
pred<-predict(tree.p, Z[-training,], type = "response")
ct<-table(pred,Z[-training,]$outcome.cat)
ct
##
## pred 0 1 2
## 0 65 43 5
## 1 76 100 24
## 2 7 7 13
sum(diag(ct))/nrow(Z[-training,])
## [1] 0.5235294
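If a plot is still wanted, party’s plot method has a more compact mode that skips the terminal-node barplots (not used in this analysis, so consider it a sketch):
# Inner nodes show only the split variable and p-value; terminal nodes are
# summarized instead of drawn as barplots
plot(tree.p, type = "simple")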
Overall, none of the methods achieved high predictive accuracy, although the C5.0 tree performed best (54.4%).
Predictors commonly identified across the trees for complications following myocardial infarction are AGE, history of chronic heart failure (ZSN_A), receiving liquid nitrates in the ICU (NITR_S), SEX, whether the subject presented to the ICU in a state of cardiogenic shock (K_SH_POST), whether the subject received lidocaine in the ICU (LID_S_n), whether the subject had a history of diabetes (endocr_01), and whether the subject presented with a right ventricular myocardial infarction (IM_PG_P).