# Importing Libraries -----------------------------------------------
library(tidyverse)
library(summarytools)
library(ggplot2)
library(esquisse)
library(gtsummary)
library(flextable)
library(ggcorrplot)
library(corrplot)
library(foreign)
library(plotly)
library(tibble)
library(dplyr)
library(caret)
library(readr)
library(tidymodels)
library(pROC)
library(rpart.plot)
library(randomForest)
Comparison Of Machine Learning Models Performance On Predicting Survival After Pediatric HSCT
1. Introduction
Hematopoietic stem cell transplantation (HSCT) is a potentially curative therapy for several hematological diseases and can improve survival for selected groups of patients (KANATE et al., 2020).
HSCT consists of obtaining hematopoietic stem cells (HSC) from a donor and infusing them into the recipient in order to reestablish the recipient’s hematopoiesis and immune function. In autologous transplants, the source of cells is the patient; in allogeneic transplants, the source is a donor who is not genetically identical to the recipient and who may be a family member (related transplant) or not (unrelated transplant).
HSCT-associated adverse events such as graft-versus-host disease (GVHD), graft failure, and relapse can affect the recipient, reducing quality of life and overall survival. These events vary according to the complex interaction of variables related to the patient, the donor, and the HSCT itself.
Tools capable of predicting a recipient’s survival probability from common HSCT variables can help the medical team plan pre- and post-HSCT interventions and possibly improve HSCT success. Supervised machine learning strategies are one option: they rely on a data-driven approach in which computational algorithms extract patterns from a labeled dataset, generating a model that can be applied to new observations to infer the probability of the associated outcomes.
For this purpose, a dataset of HSCT observations is required in which each observation is described by the relevant variables and the outcome to be predicted is known.
In this work, I train machine learning algorithms on an HSCT dataset obtained from the study described in the following paper:
https://www.sciencedirect.com/science/article/pii/S1083879110001485?via%3Dihub
2. Materials and Methods
2.1 Dataset Pre-Processing
setwd("C:\\Users\\vinic\\OneDrive\\Área de Trabalho\\doutorado_peb\\AnaliseIntelDados_I\\_trabalho_final")
dataset_bmt_raw <- read.arff("./bone-marrow.arff")
hscts <- nrow(dataset_bmt_raw)
hsct_variables <- ncol(dataset_bmt_raw)
2.1.1 Variables description
The dataset is composed of 187 HSCT events characterized by 36 independent features and 1 dependent variable (survival status), which is the variable to be predicted.
The variables and their respective definitions are as follows:
a) Donor variables
• donor_age - Age of the donor at the time of hematopoietic stem cells apheresis
• donor_age_below_35 - Is donor age less than 35 (yes, no)
• donor_ABO - ABO blood group of the donor of hematopoietic stem cells (0, A, B, AB)
• donor_CMV - Presence of cytomegalovirus infection in the donor of hematopoietic stem cells prior to transplantation (present, absent)
b) Recipient variables
• recipient_age - Age of the recipient of hematopoietic stem cells at the time of transplantation
• recipient_age_below_10 - Is recipient age below 10 (yes, no)
• recipient_age_int - Age of the recipient discretized to intervals (0,5], (5, 10], (10, 20]
• recipient_gender - Gender of the recipient (female, male)
• recipient_body_mass - Body mass of the recipient of hematopoietic stem cells at the time of the transplantation
• recipient_ABO - ABO blood group of the recipient of hematopoietic stem cells (0, A, B, AB)
• recipient_rh - Presence of the Rh factor on recipient’s red blood cells (plus, minus)
• recipient_CMV - Presence of cytomegalovirus infection in the recipient of hematopoietic stem cells prior to transplantation (present, absent)
• disease - Type of disease (ALL, AML, chronic, nonmalignant, lymphoma)
• disease_group - Type of disease (malignant, nonmalignant)
c) Donor-Recipient match variables
• gender_match - Compatibility of the donor and recipient according to their gender (female to male, other)
• ABO_match - Compatibility of the donor and the recipient of hematopoietic stem cells according to ABO blood group (matched, mismatched)
• CMV_status - Serological compatibility of the donor and the recipient of hematopoietic stem cells according to cytomegalovirus infection prior to transplantation (the higher the value the lower the compatibility)
• HLA_match - Compatibility of antigens of the main histocompatibility complex of the donor and the recipient of hematopoietic stem cells (10/10, 9/10, 8/10, 7/10)
• HLA_mismatch - HLA matched or mismatched
• antigen - Number of antigens that differ between the donor and the recipient (0-3)
• allel - Number of alleles that differ between the donor and the recipient (0-4)
• HLA_group_1 - The difference type between the donor and the recipient (HLA matched, one antigen, one allel, DRB1 cell, two allele or allel+antigen, two antigenes+allel, mismatched)
• risk_group - Risk group (high, low)
• stem_cell_source - Source of hematopoietic stem cells (peripheral blood, bone marrow)
d) Transplant variables
• tx_post_relapse - The second bone marrow transplantation after relapse (yes, no)
• CD34_x1e6_per_kg - CD34+ cell dose per kg of recipient body weight (10^6/kg)
• CD3_x1e8_per_kg - CD3+ cell dose per kg of recipient body weight (10^8/kg)
• CD3_to_CD34_ratio - CD3+ cell to CD34+ cell ratio
• ANC_recovery - Time to neutrophils recovery defined as neutrophils count >0.5 x 10^9/L
• PLT_recovery - Time to platelet recovery defined as platelet count >50000/mm3
• acute_GvHD_II_III_IV - Development of acute graft versus host disease stage II or III or IV (yes, no)
• acute_GvHD_III_IV - Development of acute graft versus host disease stage III or IV (yes, no)
• time_to_acute_GvHD_III_IV - Time to development of acute graft versus host disease stage III or IV
str(dataset_bmt_raw)
'data.frame': 187 obs. of 37 variables:
$ donor_age : num 22.8 23.3 26.4 39.7 33.4 ...
$ donor_age_below_35 : Factor w/ 2 levels "no","yes": 2 2 2 1 2 2 2 2 2 2 ...
$ donor_ABO : Factor w/ 4 levels "0","A","AB","B": 2 4 4 2 2 3 1 1 3 2 ...
$ donor_CMV : Factor w/ 2 levels "absent","present": 2 1 1 2 1 NA 1 2 1 1 ...
$ recipient_age : num 9.6 4 6.6 18.1 1.3 8.9 14.4 18.2 7.9 4.7 ...
$ recipient_age_below_10 : Factor w/ 2 levels "no","yes": 2 2 2 1 2 2 1 1 2 2 ...
$ recipient_age_int : Factor w/ 3 levels "0_5","10_20",..: 3 1 3 2 1 3 2 2 3 1 ...
$ recipient_gender : Factor w/ 2 levels "female","male": 2 2 2 1 1 2 1 2 2 2 ...
$ recipient_body_mass : num 35 20.6 23.4 50 9 40 51 56 20.5 16.5 ...
$ recipient_ABO : Factor w/ 4 levels "0","A","AB","B": 2 4 4 3 3 1 2 2 1 1 ...
$ recipient_rh : Factor w/ 2 levels "minus","plus": 2 2 2 2 1 2 1 2 2 2 ...
$ recipient_CMV : Factor w/ 2 levels "absent","present": 2 1 2 1 2 2 NA 1 2 2 ...
$ disease : Factor w/ 5 levels "ALL","AML","chronic",..: 1 1 1 2 3 3 2 5 5 5 ...
$ disease_group : Factor w/ 2 levels "malignant","nonmalignant": 1 1 1 1 1 1 1 2 2 2 ...
$ gender_match : Factor w/ 2 levels "female_to_male",..: 2 2 2 2 2 2 2 2 2 1 ...
$ ABO_match : Factor w/ 2 levels "matched","mismatched": 1 1 1 2 2 2 2 2 2 2 ...
$ CMV_status : Factor w/ 4 levels "0","1","2","3": 4 1 3 2 1 NA NA 2 3 3 ...
$ HLA_match : Factor w/ 4 levels "10/10","7/10",..: 1 1 1 1 4 1 1 2 1 4 ...
$ HLA_mismatch : Factor w/ 2 levels "matched","mismatched": 1 1 1 1 1 1 1 2 1 1 ...
$ antigen : Factor w/ 4 levels "0","1","2","3": 1 1 1 1 3 1 1 3 1 2 ...
$ allel : Factor w/ 5 levels "0","1","2","3",..: 1 1 1 1 2 1 1 4 1 3 ...
$ HLA_group_1 : Factor w/ 7 levels "DRB1_cell","matched",..: 2 2 2 2 5 2 2 3 2 1 ...
$ risk_group : Factor w/ 2 levels "high","low": 1 2 2 2 1 1 2 2 2 2 ...
$ stem_cell_source : Factor w/ 2 levels "bone_marrow",..: 2 1 1 1 2 1 2 1 2 2 ...
$ tx_post_relapse : Factor w/ 2 levels "no","yes": 1 1 1 1 1 2 1 1 1 1 ...
$ CD34_x1e6_per_kg : num 7.2 4.5 7.94 4.25 51.85 ...
$ CD3_x1e8_per_kg : num 5.38 0.41 0.42 0.14 13.05 ...
$ CD3_to_CD34_ratio : num 1.34 11.08 19.01 29.48 3.97 ...
$ ANC_recovery : num 19 16 23 23 14 16 17 22 15 16 ...
$ PLT_recovery : num 51 37 20 29 14 70 29 58 14 17 ...
$ acute_GvHD_II_III_IV : Factor w/ 2 levels "no","yes": 2 2 2 2 1 1 2 2 1 2 ...
$ acute_GvHD_III_IV : Factor w/ 2 levels "no","yes": 2 1 1 2 1 1 2 2 1 1 ...
$ time_to_acute_GvHD_III_IV: num 32 1000000 1000000 19 1000000 1000000 18 22 1000000 1000000 ...
$ extensive_chronic_GvHD : Factor w/ 2 levels "no","yes": 1 1 1 NA 1 1 NA NA 1 1 ...
$ relapse : Factor w/ 2 levels "no","yes": 1 2 2 1 1 1 1 1 1 1 ...
$ survival_time : num 999 163 435 53 2043 ...
$ survival_status : Factor w/ 2 levels "0","1": 1 2 2 2 1 1 2 2 1 1 ...
The levels of the dependent variable (survival_status) were changed from numeric character labels (“0”, “1”) to descriptive strings (“alive”, “dead”) to prevent misunderstandings.
levels(dataset_bmt_raw$survival_status)
[1] "0" "1"
levels(dataset_bmt_raw$survival_status) <- c("alive", "dead")
levels(dataset_bmt_raw$survival_status)
[1] "alive" "dead"
2.1.2 Summary tables (raw data)
theme_gtsummary_language(
language = "en",
iqr.sep = "-",
ci.sep = "-",
set_theme = TRUE
)
# Separating numerical from categorical columns --------------------------------
num_cols <- sapply(dataset_bmt_raw, is.numeric )
dataset_bmt_raw_numeric <- dataset_bmt_raw[num_cols]
dataset_bmt_raw_numeric$survival_status <- dataset_bmt_raw$survival_status
cat_cols<- sapply(dataset_bmt_raw, is.factor)
dataset_bmt_raw_cat <- dataset_bmt_raw[cat_cols]
# Function to generate table summaries -----------------------------------------
generate_tblsummary <- function(dataset, outcome) {
dataset |>
tbl_summary(
by = all_of(outcome),
statistic = list(
all_continuous() ~ "{median} ({IQR})",
all_categorical() ~ "{n} ({p}%)"
),
digits = all_continuous() ~ 2,
) |>
add_p() |>
add_overall()
}
# Function to generate histogram plots -----------------------------------------
hist_plot_cont <- function(data, x_column){
  # x_column is a column name passed as a string
  data |>
    ggplot() +
    aes(x = .data[[x_column]]) +
    geom_histogram(bins = 25L, fill = "#112446") +
    labs(x = x_column) +
    theme_bw()
}
hist_plot_cat <- function(data, x_column){
  # x_column is a column name passed as a string
  data |>
    ggplot() +
    aes(x = .data[[x_column]]) +
    geom_bar(fill = "#112446") +
    labs(x = x_column) +
    theme_bw()
}
2.1.2.1 Numeric variables
#Table summary: Numeric variables ---------------------------------------------
generate_tblsummary(dataset_bmt_raw_numeric, "survival_status")
| Characteristic | Overall | alive | dead | p-value |
|---|---|---|---|---|
| donor_age | 33.55 (13.08) | 31.71 (12.70) | 35.09 (13.66) | 0.5 |
| recipient_age | 9.60 (9.00) | 8.50 (8.53) | 12.10 (9.90) | 0.005 |
| recipient_body_mass | 33.00 (31.60) | 26.85 (28.68) | 40.30 (31.85) | 0.004 |
| Unknown | 2 | 0 | 2 | |
| CD34_x1e6_per_kg | 9.72 (10.07) | 11.11 (9.72) | 7.91 (6.80) | 0.007 |
| CD3_x1e8_per_kg | 4.33 (5.10) | 5.09 (5.21) | 3.32 (4.80) | 0.002 |
| Unknown | 5 | 1 | 4 | |
| CD3_to_CD34_ratio | 2.73 (4.04) | 2.69 (2.40) | 2.93 (5.46) | 0.12 |
| Unknown | 5 | 1 | 4 | |
| ANC_recovery | 15.00 (4.00) | 15.00 (3.00) | 15.00 (5.00) | 0.5 |
| PLT_recovery | 21.00 (21.00) | 20.00 (16.25) | 27.00 (31.00) | 0.007 |
| time_to_acute_GvHD_III_IV | 1,000,000.00 (0.00) | 1,000,000.00 (0.00) | 1,000,000.00 (999,938.00) | 0.2 |
| survival_time | 676.00 (1,435.50) | 1,428.00 (1,042.00) | 149.00 (272.00) | <0.001 |
1 Median (IQR)
2 Wilcoxon rank sum test
2.1.2.2 Categorical variables
#Table summary: Categorical variables ------------------------------------------
generate_tblsummary(dataset_bmt_raw_cat, "survival_status")
| Characteristic | Overall | alive | dead | p-value |
|---|---|---|---|---|
| donor_age_below_35 | 104 (56%) | 62 (61%) | 42 (49%) | 0.12 |
| donor_ABO | | | | 0.4 |
| 0 | 73 (39%) | 40 (39%) | 33 (39%) | |
| A | 71 (38%) | 35 (34%) | 36 (42%) | |
| AB | 15 (8.0%) | 11 (11%) | 4 (4.7%) | |
| B | 28 (15%) | 16 (16%) | 12 (14%) | |
| donor_CMV | | | | 0.4 |
| absent | 113 (61%) | 59 (58%) | 54 (64%) | |
| present | 72 (39%) | 42 (42%) | 30 (36%) | |
| Unknown | 2 | 1 | 1 | |
| recipient_age_below_10 | 99 (53%) | 61 (60%) | 38 (45%) | 0.039 |
| recipient_age_int | | | | 0.14 |
| 0_5 | 47 (25%) | 30 (29%) | 17 (20%) | |
| 10_20 | 89 (48%) | 42 (41%) | 47 (55%) | |
| 5_10 | 51 (27%) | 30 (29%) | 21 (25%) | |
| recipient_gender | | | | 0.7 |
| female | 75 (40%) | 42 (41%) | 33 (39%) | |
| male | 112 (60%) | 60 (59%) | 52 (61%) | |
| recipient_ABO | | | | >0.9 |
| 0 | 48 (26%) | 27 (26%) | 21 (25%) | |
| A | 75 (40%) | 39 (38%) | 36 (43%) | |
| AB | 13 (7.0%) | 8 (7.8%) | 5 (6.0%) | |
| B | 50 (27%) | 28 (27%) | 22 (26%) | |
| Unknown | 1 | 0 | 1 | |
| recipient_rh | | | | 0.085 |
| minus | 27 (15%) | 19 (19%) | 8 (9.6%) | |
| plus | 158 (85%) | 83 (81%) | 75 (90%) | |
| Unknown | 2 | 0 | 2 | |
| recipient_CMV | | | | 0.6 |
| absent | 73 (42%) | 43 (44%) | 30 (40%) | |
| present | 100 (58%) | 55 (56%) | 45 (60%) | |
| Unknown | 14 | 4 | 10 | |
| disease | | | | 0.012 |
| ALL | 68 (36%) | 38 (37%) | 30 (35%) | |
| AML | 33 (18%) | 18 (18%) | 15 (18%) | |
| chronic | 45 (24%) | 26 (25%) | 19 (22%) | |
| lymphoma | 9 (4.8%) | 0 (0%) | 9 (11%) | |
| nonmalignant | 32 (17%) | 20 (20%) | 12 (14%) | |
| disease_group | | | | 0.3 |
| malignant | 155 (83%) | 82 (80%) | 73 (86%) | |
| nonmalignant | 32 (17%) | 20 (20%) | 12 (14%) | |
| gender_match | | | | 0.9 |
| female_to_male | 32 (17%) | 17 (17%) | 15 (18%) | |
| other | 155 (83%) | 85 (83%) | 70 (82%) | |
| ABO_match | | | | 0.2 |
| matched | 52 (28%) | 25 (25%) | 27 (32%) | |
| mismatched | 134 (72%) | 77 (75%) | 57 (68%) | |
| Unknown | 1 | 0 | 1 | |
| CMV_status | | | | 0.7 |
| 0 | 48 (28%) | 26 (27%) | 22 (29%) | |
| 1 | 27 (16%) | 18 (19%) | 9 (12%) | |
| 2 | 57 (33%) | 31 (32%) | 26 (35%) | |
| 3 | 39 (23%) | 21 (22%) | 18 (24%) | |
| Unknown | 16 | 6 | 10 | |
| HLA_match | | | | 0.9 |
| 10/10 | 94 (50%) | 53 (52%) | 41 (48%) | |
| 7/10 | 5 (2.7%) | 2 (2.0%) | 3 (3.5%) | |
| 8/10 | 23 (12%) | 13 (13%) | 10 (12%) | |
| 9/10 | 65 (35%) | 34 (33%) | 31 (36%) | |
| HLA_mismatch | | | | >0.9 |
| matched | 159 (85%) | 87 (85%) | 72 (85%) | |
| mismatched | 28 (15%) | 15 (15%) | 13 (15%) | |
| antigen | | | | 0.8 |
| 0 | 93 (50%) | 52 (51%) | 41 (48%) | |
| 1 | 21 (11%) | 10 (9.9%) | 11 (13%) | |
| 2 | 65 (35%) | 36 (36%) | 29 (34%) | |
| 3 | 7 (3.8%) | 3 (3.0%) | 4 (4.7%) | |
| Unknown | 1 | 1 | 0 | |
| allel | | | | >0.9 |
| 0 | 93 (50%) | 52 (51%) | 41 (48%) | |
| 1 | 54 (29%) | 28 (28%) | 26 (31%) | |
| 2 | 32 (17%) | 18 (18%) | 14 (16%) | |
| 3 | 6 (3.2%) | 3 (3.0%) | 3 (3.5%) | |
| 4 | 1 (0.5%) | 0 (0%) | 1 (1.2%) | |
| Unknown | 1 | 1 | 0 | |
| HLA_group_1 | | | | >0.9 |
| DRB1_cell | 9 (4.8%) | 4 (3.9%) | 5 (5.9%) | |
| matched | 94 (50%) | 53 (52%) | 41 (48%) | |
| mismatched | 5 (2.7%) | 2 (2.0%) | 3 (3.5%) | |
| one_allel | 14 (7.5%) | 8 (7.8%) | 6 (7.1%) | |
| one_antigen | 42 (22%) | 22 (22%) | 20 (24%) | |
| three_diffs | 4 (2.1%) | 3 (2.9%) | 1 (1.2%) | |
| two_diffs | 19 (10%) | 10 (9.8%) | 9 (11%) | |
| risk_group | | | | 0.043 |
| high | 69 (37%) | 31 (30%) | 38 (45%) | |
| low | 118 (63%) | 71 (70%) | 47 (55%) | |
| stem_cell_source | | | | 0.084 |
| bone_marrow | 42 (22%) | 18 (18%) | 24 (28%) | |
| peripheral_blood | 145 (78%) | 84 (82%) | 61 (72%) | |
| tx_post_relapse | 23 (12%) | 9 (8.8%) | 14 (16%) | 0.11 |
| acute_GvHD_II_III_IV | 112 (60%) | 58 (57%) | 54 (64%) | 0.4 |
| acute_GvHD_III_IV | 40 (21%) | 17 (17%) | 23 (27%) | 0.084 |
| extensive_chronic_GvHD | 28 (18%) | 11 (11%) | 17 (31%) | 0.001 |
| Unknown | 31 | 0 | 31 | |
| relapse | 28 (15%) | 5 (4.9%) | 23 (27%) | <0.001 |
1 n (%)
2 Pearson’s Chi-squared test; Fisher’s exact test
Summary tables stratified by outcome class (alive, dead), showing absolute and relative frequencies for categorical variables and the median (interquartile range) for numeric variables.
The summary tables show a cohort that is well balanced with respect to the outcome. The cohort is composed of child and adolescent recipients (median age 9.6, IQR 9.0). Compared with the alive group, patients in the dead group are older, have higher body mass, a higher proportion of high-risk scores, a higher frequency of relapse, slower platelet recovery, and received lower doses of CD34+ and CD3+ cells.
#Histogram graph: Recipient Age ------------------------------------------------
ggplot(dataset_bmt_raw) +
aes(x = recipient_age) +
geom_histogram(bins = 30L, fill = "#112446") +
labs(x = "Recipient age", y = "Frequency") +
theme_minimal()
2.2 Missing data
total_values_count <- dim(dataset_bmt_raw)[1] * dim(dataset_bmt_raw)[2]
na_count <- sum(is.na(dataset_bmt_raw))
na_counts_col <- colSums(is.na(dataset_bmt_raw))
cols_with_na_count <- na_counts_col[na_counts_col>0]
cols_with_na_count <- sort(cols_with_na_count[cols_with_na_count>0], decreasing=TRUE)
frame <- enframe(cols_with_na_count)
There are 81 missing values out of 6919 total values (1.17 %).
frame
# A tibble: 12 × 2
name value
<chr> <dbl>
1 extensive_chronic_GvHD 31
2 CMV_status 16
3 recipient_CMV 14
4 CD3_x1e8_per_kg 5
5 CD3_to_CD34_ratio 5
6 donor_CMV 2
7 recipient_body_mass 2
8 recipient_rh 2
9 recipient_ABO 1
10 ABO_match 1
11 antigen 1
12 allel 1
Number of missing values for each dataset column.
2.2.1 Removing columns from the dataset
HLA match variables
This cohort has five columns describing donor-recipient HLA match status, whereas the literature treats the number of HLA allele differences (mismatches) as the most relevant marker of this characteristic. HLA_match is therefore converted into a numeric mismatch count, and the remaining HLA columns are removed.
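The conversion of HLA_match into a mismatch count used in this section assigns the new values by level position, so it depends on the order in which R sorts the labels. A minimal sketch of an equivalent mapping keyed by label, on a made-up vector (`hla_toy` and `mismatch_lookup` are illustrative names, not objects from the analysis):

```r
# Toy factor with the same labels as HLA_match (values are made up)
hla_toy <- factor(c("10/10", "9/10", "7/10", "8/10", "10/10"))

# Named lookup: the mapping no longer depends on the level order
mismatch_lookup <- c("10/10" = 0, "9/10" = 1, "8/10" = 2, "7/10" = 3)

# Index the lookup by label, then drop the names
hla_mismatch_count <- unname(mismatch_lookup[as.character(hla_toy)])
print(hla_mismatch_count)
```

Both approaches should give the same counts here; the lookup version only makes the label-to-count pairing explicit.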
levels(dataset_bmt_raw$HLA_match)
[1] "10/10" "7/10" "8/10" "9/10"
# Map each level to its mismatch count: 10/10 -> 0, 7/10 -> 3, 8/10 -> 2, 9/10 -> 1
levels(dataset_bmt_raw$HLA_match) <- c(0, 3, 2, 1)
dataset_bmt_raw[,'HLA_match'] <- as.numeric(as.character(dataset_bmt_raw[,'HLA_match']))
dataset_bmt_raw$HLA_mismatch_count <- dataset_bmt_raw$HLA_match
dataset_bmt_raw$HLA_match <- NULL
GvHD variables
extensive_chronic_GvHD
The variable extensive_chronic_GvHD has a high number of missing values, all of them in the dead group. Because chronic GvHD is detected late after HSCT, it cannot be assessed in patients who died early after the transplant. To avoid artifacts in model training, this column is discarded.
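A pattern of this kind can be checked by cross-tabulating missingness against the outcome. The sketch below uses a small made-up data frame (`toy` is hypothetical and only mimics the column names), not the actual cohort:

```r
# Toy data (made up): the feature is missing only for some "dead" rows
toy <- data.frame(
  survival_status = factor(c("alive", "alive", "dead", "dead", "dead", "alive")),
  extensive_chronic_GvHD = factor(c("no", "yes", NA, "no", NA, "no"))
)

# Cross-tabulate missingness against the outcome
na_by_outcome <- table(
  outcome = toy$survival_status,
  missing = is.na(toy$extensive_chronic_GvHD)
)
print(na_by_outcome)
```

When all the missing values fall in one outcome class, as in this toy table, a model could learn the missingness itself as a proxy for the outcome.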
Acute GvHD
Similarly, acute GvHD is an important event affecting overall survival after HSCT. Three columns are dedicated to this event, whereas the literature describes an association between high-grade (III-IV) GvHD and mortality after HSCT. The occurrence of this event is summarized by the variable acute_GvHD_III_IV, so the remaining two columns can be removed.
Survival Time
Survival time is likely to cause data leakage, since it carries partial information about the outcome to be predicted, and is therefore removed.
# Removing variables from dataset ----------------------------------------------
dataset_bmt <- mutate(dataset_bmt_raw,
extensive_chronic_GvHD = NULL,
HLA_mismatch = NULL,
antigen = NULL,
allel = NULL,
HLA_group_1 = NULL,
time_to_acute_GvHD_III_IV = NULL,
acute_GvHD_II_III_IV = NULL,
survival_time = NULL)
Categorical variables
Some variables were derived by feature engineering from raw data already present in the dataset. They were removed to reduce dimensionality.
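One way to confirm that a derived flag adds no information is to rebuild it from its source column and compare. A minimal sketch on made-up values, assuming the flag is defined by a strict "less than 35" rule as in the variable description (`toy` and `rebuilt` are illustrative names):

```r
# Toy data (made up) mimicking donor_age and its derived flag
toy <- data.frame(
  donor_age = c(22.8, 39.7, 33.4, 35.0, 41.2),
  donor_age_below_35 = factor(c("yes", "no", "yes", "no", "no"))
)

# Rebuild the flag from the raw age (assumed rule: age < 35)
rebuilt <- ifelse(toy$donor_age < 35, "yes", "no")

# TRUE when the stored flag is fully determined by the raw column
flag_is_redundant <- all(rebuilt == as.character(toy$donor_age_below_35))
print(flag_is_redundant)
```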
dataset_bmt <- mutate(dataset_bmt,
donor_age_below_35 = NULL,
recipient_age_below_10 = NULL,
recipient_age_int = NULL,
gender_match = NULL,
CD34_x1e6_per_kg = NULL,
CD3_x1e8_per_kg = NULL
)
2.2.2 Transforming columns from the dataset
Two variables, PLT_recovery and ANC_recovery, have large positive outliers that might compromise model training and may be better accommodated as categories.
hist_plot_cont(dataset_bmt, "PLT_recovery")
#hist_plot_cont(filter(dataset_bmt, PLT_recovery<300), "PLT_recovery")
hist_plot_cont(dataset_bmt, "ANC_recovery")
#hist_plot_cont( filter(dataset_bmt, ANC_recovery<300), "ANC_recovery")
plt_breaks <- c(0, 20, 30, 100, Inf)
plt_labels <- c("0-20", "20-30", "30-100", "100+")
dataset_bmt$PLT_recovery_range <- cut(dataset_bmt$PLT_recovery,
breaks = plt_breaks,
labels = plt_labels)
hist_plot_cat(dataset_bmt, "PLT_recovery_range")
anc_breaks <- c(0, 15, 20, Inf)
anc_labels <- c("0-15", "15-20", "20+")
dataset_bmt$ANC_recovery_range <- cut(dataset_bmt$ANC_recovery,
breaks = anc_breaks,
labels = anc_labels
)
hist_plot_cat(dataset_bmt, "ANC_recovery_range")
generate_tblsummary(dataset_bmt[, c("ANC_recovery_range", "PLT_recovery_range", "survival_status")], "survival_status" )
| Characteristic | Overall | alive | dead | p-value |
|---|---|---|---|---|
| ANC_recovery_range | | | | 0.050 |
| 0-15 | 102 (55%) | 57 (56%) | 45 (53%) | |
| 15-20 | 67 (36%) | 40 (39%) | 27 (32%) | |
| 20+ | 18 (9.6%) | 5 (4.9%) | 13 (15%) | |
| PLT_recovery_range | | | | 0.007 |
| 0-20 | 84 (45%) | 53 (52%) | 31 (36%) | |
| 20-30 | 49 (26%) | 23 (23%) | 26 (31%) | |
| 30-100 | 29 (16%) | 19 (19%) | 10 (12%) | |
| 100+ | 25 (13%) | 7 (6.9%) | 18 (21%) | |
1 n (%)
2 Pearson’s Chi-squared test
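With its default arguments, `cut()` builds right-closed intervals, so a value equal to a break falls in the lower bin, and a value equal to the lowest break is not binned at all unless `include.lowest = TRUE` is set. A short sketch with made-up recovery times (`toy_days` is hypothetical) using the same ANC breaks:

```r
anc_breaks <- c(0, 15, 20, Inf)
anc_labels <- c("0-15", "15-20", "20+")

# Made-up values chosen to sit on and around the breaks
toy_days <- c(0, 10, 15, 16, 20, 21)

# Default: right = TRUE, include.lowest = FALSE -> (0,15], (15,20], (20,Inf]
binned <- cut(toy_days, breaks = anc_breaks, labels = anc_labels)
print(data.frame(toy_days, binned))
```

Here 15 is labeled "0-15", 20 is labeled "15-20", and 0 becomes NA.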
dataset_bmt <- mutate(dataset_bmt,
ANC_recovery = NULL,
PLT_recovery = NULL)
2.3 Missing data remaining
total_values_count <- dim(dataset_bmt)[1] * dim(dataset_bmt)[2]
na_count <- sum(is.na(dataset_bmt))
na_counts_col <- colSums(is.na(dataset_bmt))
cols_with_na_count <- na_counts_col[na_counts_col>0]
cols_with_na_count <- sort(cols_with_na_count[cols_with_na_count>0], decreasing=TRUE)
frame <- enframe(cols_with_na_count)
There are 43 remaining missing values out of 4301 total values (1 %).
frame
# A tibble: 8 × 2
name value
<chr> <dbl>
1 CMV_status 16
2 recipient_CMV 14
3 CD3_to_CD34_ratio 5
4 donor_CMV 2
5 recipient_body_mass 2
6 recipient_rh 2
7 recipient_ABO 1
8 ABO_match 1
2.4 Eliminating rows with any missing data
dataset_bmt <- dataset_bmt[complete.cases(dataset_bmt),]
colnames(dataset_bmt)
[1] "donor_age" "donor_ABO" "donor_CMV"
[4] "recipient_age" "recipient_gender" "recipient_body_mass"
[7] "recipient_ABO" "recipient_rh" "recipient_CMV"
[10] "disease" "disease_group" "ABO_match"
[13] "CMV_status" "risk_group" "stem_cell_source"
[16] "tx_post_relapse" "CD3_to_CD34_ratio" "acute_GvHD_III_IV"
[19] "relapse" "survival_status" "HLA_mismatch_count"
[22] "PLT_recovery_range" "ANC_recovery_range"
str(dataset_bmt)
'data.frame': 165 obs. of 23 variables:
$ donor_age : num 22.8 23.3 26.4 39.7 33.4 ...
$ donor_ABO : Factor w/ 4 levels "0","A","AB","B": 2 4 4 2 2 3 2 1 2 1 ...
$ donor_CMV : Factor w/ 2 levels "absent","present": 2 1 1 2 1 1 1 2 1 2 ...
$ recipient_age : num 9.6 4 6.6 18.1 1.3 7.9 4.7 1.9 13.4 5.1 ...
$ recipient_gender : Factor w/ 2 levels "female","male": 2 2 2 1 1 2 2 1 1 2 ...
$ recipient_body_mass: num 35 20.6 23.4 50 9 20.5 16.5 10.5 47 18.1 ...
$ recipient_ABO : Factor w/ 4 levels "0","A","AB","B": 2 4 4 3 3 1 1 4 2 2 ...
$ recipient_rh : Factor w/ 2 levels "minus","plus": 2 2 2 2 1 2 2 2 2 2 ...
$ recipient_CMV : Factor w/ 2 levels "absent","present": 2 1 2 1 2 2 2 1 1 1 ...
$ disease : Factor w/ 5 levels "ALL","AML","chronic",..: 1 1 1 2 3 5 5 3 3 1 ...
$ disease_group : Factor w/ 2 levels "malignant","nonmalignant": 1 1 1 1 1 2 2 1 1 1 ...
$ ABO_match : Factor w/ 2 levels "matched","mismatched": 1 1 1 2 2 2 2 2 1 2 ...
$ CMV_status : Factor w/ 4 levels "0","1","2","3": 4 1 3 2 1 3 3 2 1 2 ...
$ risk_group : Factor w/ 2 levels "high","low": 1 2 2 2 1 2 2 1 1 2 ...
$ stem_cell_source : Factor w/ 2 levels "bone_marrow",..: 2 1 1 1 2 2 2 2 2 2 ...
$ tx_post_relapse : Factor w/ 2 levels "no","yes": 1 1 1 1 1 1 1 1 1 1 ...
$ CD3_to_CD34_ratio : num 1.34 11.08 19.01 29.48 3.97 ...
$ acute_GvHD_III_IV : Factor w/ 2 levels "no","yes": 2 1 1 2 1 1 1 1 1 1 ...
$ relapse : Factor w/ 2 levels "no","yes": 1 2 2 1 1 1 1 1 1 1 ...
$ survival_status : Factor w/ 2 levels "alive","dead": 1 2 2 2 1 1 1 1 1 1 ...
$ HLA_mismatch_count : num 0 0 0 0 1 0 1 1 2 0 ...
$ PLT_recovery_range : Factor w/ 4 levels "0-20","20-30",..: 3 3 1 2 1 1 1 1 1 3 ...
$ ANC_recovery_range : Factor w/ 3 levels "0-15","15-20",..: 2 2 3 3 1 1 2 1 1 1 ...
2.5 Numeric variables correlation Matrix
nums <- unlist(lapply(dataset_bmt, is.numeric), use.names = FALSE)
correlation_plot <- corrplot(cor(dataset_bmt[,nums]), type = "lower", tl.col = "black", tl.srt = 45, na.label=" ")
2.6 Training Models with complete dataset
2.6.1 Train test partition | Cross-validation settings
# Train/test partition
set.seed(01)
intrain <- createDataPartition(
y = dataset_bmt$survival_status,
p = 0.70,
list = FALSE)
train_bmt <- dataset_bmt[intrain,]
test_bmt <- dataset_bmt[-intrain,]
### Defining Control - repeated cross validation
ctrl_rcv <- trainControl(
method = "repeatedcv",
number = 10,
repeats = 10,
summaryFunction = twoClassSummary,
classProbs = TRUE,
allowParallel = TRUE)
### Defining Control - leave-one-out
ctrl_loocv <- trainControl(
method = "LOOCV",
summaryFunction = twoClassSummary,
classProbs = TRUE,
allowParallel = TRUE)
Function to train models and select best hyperparameters set
#Function to train models selecting best hyperparameters set
tunning_model <- function(method, ctrl, grid, train, seed){
#Setting seed
if (!missing(seed))
set.seed(seed)
#Training the model looking for best hyperparameters
if (method == "rpart2"){
#print('rpart2')
tunning_model <- train(survival_status ~ .,
method = "rpart2",
tuneLength = grid,
trControl = ctrl,
metric = "ROC",
data = train)
}
else if(method=="naive_bayes"){
#print('naive bayes')
#No preprocessing on numeric variables
tunning_model <- train(survival_status ~ .,
method = method,
#preProcess = c("center","scale"),
trControl = ctrl,
tuneGrid = grid,
metric = "ROC",
data = train)
}
else {
#print('others')
tunning_model <- train(survival_status ~ .,
method = method,
preProcess = c("center","scale"),
trControl = ctrl,
tuneGrid = grid,
metric = "ROC",
data = train)
}
#Returning trained model
return(tunning_model)
}
Function to evaluate model performance
# Function returning performance metrics for trained models
model_performance <- function(model, test, cutoff){
print("Pred prob")
pred_prob <- predict(model, test, type = "prob")
print("result cat")
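# Fix the factor levels so confusionMatrix() still works when only one class is predicted
result_cat <- factor(ifelse(pred_prob[,2] > cutoff,"dead","alive"), levels = c("alive", "dead"))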
print("cfm")
cfm <- confusionMatrix(as.factor(result_cat),
test$survival_status,
positive = "dead")
print("auc")
auc <- roc(test$survival_status, pred_prob[,2])
print("fmeas")
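# caret::F_meas() selects the class of interest with `relevant`;
# an argument named `positive` is ignored and the first level ("alive") is used instead
f_meas <- F_meas(as.factor(result_cat),
test$survival_status,
relevant = "dead")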
print("list")
model_results <- list()
model_results$cfm <- cfm
model_results$auc <- auc
model_results$fscore <- f_meas
return(model_results)
}
2.6.1 Model 01: Penalized Logistic Regression
Training
# Logistic regression: Hyperparameters tested
tunegrid_logreg <- expand.grid(lambda = c(1e-4, 1e-3, 1e-2, 1),
cp = c("aic", "bic"))
#Training model
method <- "plr"
ctrl_set <- ctrl_rcv
grid <- tunegrid_logreg
train <- train_bmt
seed <- 01
tunning_model_logreg <- tunning_model(method, ctrl_set, grid, train, seed)
ctrl_set <- ctrl_loocv
tunning_model_logreg_loocv <- tunning_model(method, ctrl_set, grid, train, seed)
Evaluating (repeated CV)
cutoff = 0.5
model_logreg_results <- model_performance(tunning_model_logreg,
test_bmt,
cutoff)
[1] "Pred prob"
[1] "result cat"
[1] "cfm"
[1] "auc"
[1] "fmeas"
[1] "list"
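The quantities returned by `model_performance` all follow from the confusion-matrix counts obtained at a given cutoff. A minimal base-R sketch with made-up probabilities and labels (`prob_dead`, `truth`, and `toy_cutoff` are hypothetical), treating "dead" as the class of interest:

```r
# Made-up predicted probabilities of "dead" and the true labels
prob_dead <- c(0.90, 0.80, 0.35, 0.60, 0.20, 0.10, 0.55, 0.40)
truth <- factor(c("dead", "dead", "dead", "alive", "alive", "alive", "dead", "alive"),
                levels = c("alive", "dead"))

# Turn probabilities into classes at the chosen cutoff
toy_cutoff <- 0.5
pred <- factor(ifelse(prob_dead > toy_cutoff, "dead", "alive"), levels = levels(truth))

# Confusion-matrix counts with "dead" as the positive class
tp <- sum(pred == "dead" & truth == "dead")
fp <- sum(pred == "dead" & truth == "alive")
fn <- sum(pred == "alive" & truth == "dead")
tn <- sum(pred == "alive" & truth == "alive")

sensitivity <- tp / (tp + fn)
specificity <- tn / (tn + fp)
precision <- tp / (tp + fp)
f1 <- 2 * precision * sensitivity / (precision + sensitivity)

print(c(sensitivity = sensitivity, specificity = specificity, f1 = f1))
```

Changing `toy_cutoff` trades sensitivity against specificity, which is why each model is re-evaluated after a new cutoff is chosen.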
plot.roc(model_logreg_results$auc, print.thres = 'best')
cutoff = 0.744
model_logreg_results <- model_performance(tunning_model_logreg,
test_bmt,
cutoff)
[1] "Pred prob"
[1] "result cat"
[1] "cfm"
[1] "auc"
[1] "fmeas"
[1] "list"
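The cutoff passed to the second evaluation is the threshold printed on the ROC plot by `print.thres = 'best'`. Assuming pROC's default `youden` method is in effect, this is the threshold that maximizes Youden's J (sensitivity + specificity - 1). A base-R sketch of that selection on made-up data (all names are hypothetical, and the candidate thresholds are taken as midpoints between sorted probabilities, which is an assumption of this sketch):

```r
# Made-up predicted probabilities of "dead" and the true labels
prob_dead <- c(0.90, 0.80, 0.45, 0.60, 0.20, 0.10, 0.55, 0.40)
truth <- c("dead", "dead", "dead", "alive", "alive", "alive", "dead", "alive")

# Candidate thresholds: midpoints between consecutive sorted probabilities
sorted <- sort(unique(prob_dead))
thresholds <- (head(sorted, -1) + tail(sorted, -1)) / 2

# Youden's J at each candidate threshold
youden_j <- sapply(thresholds, function(t) {
  pred_dead <- prob_dead > t
  sens <- sum(pred_dead & truth == "dead") / sum(truth == "dead")
  spec <- sum(!pred_dead & truth == "alive") / sum(truth == "alive")
  sens + spec - 1
})

best_threshold <- thresholds[which.max(youden_j)]
print(best_threshold)
```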
model_logreg_results$method <- tunning_model_logreg$method
model_logreg_results$bestune <- tunning_model_logreg$bestTune
model_logreg_results$final_model <- tunning_model_logreg$finalModel
model_logreg_results$cutoff <- cutoff
Evaluating (leave-one-out)
cutoff = 0.5
model_logreg_results_loocv <- model_performance(tunning_model_logreg_loocv,
test_bmt,
cutoff)
[1] "Pred prob"
[1] "result cat"
[1] "cfm"
[1] "auc"
[1] "fmeas"
[1] "list"
plot.roc(model_logreg_results_loocv$auc, print.thres = 'best')
cutoff = 0.744
model_logreg_results_loocv <- model_performance(tunning_model_logreg_loocv,
test_bmt,
cutoff)
[1] "Pred prob"
[1] "result cat"
[1] "cfm"
[1] "auc"
[1] "fmeas"
[1] "list"
model_logreg_results_loocv$method <- tunning_model_logreg_loocv$method
model_logreg_results_loocv$bestune <- tunning_model_logreg_loocv$bestTune
model_logreg_results_loocv$final_model <- tunning_model_logreg_loocv$finalModel
model_logreg_results_loocv$cutoff <- cutoff
2.6.2 Model 02: Decision Trees
Training
method <- "rpart2"
train <- train_bmt
seed <- 01
grid <- 50
ctrl_set <- ctrl_rcv
tunning_model_dt <- tunning_model(method, ctrl_set, grid, train, seed)
note: only 5 possible values of the max tree depth from the initial fit.
Truncating the grid to 5 .
ctrl_set <- ctrl_loocv
tunning_model_dt_loocv <- tunning_model(method, ctrl_set, grid, train, seed)
note: only 5 possible values of the max tree depth from the initial fit.
Truncating the grid to 5 .
Evaluating (repeated CV)
cutoff = 0.5
model_dt_results <- model_performance(tunning_model_dt,
test_bmt,
cutoff)
[1] "Pred prob"
[1] "result cat"
[1] "cfm"
[1] "auc"
[1] "fmeas"
[1] "list"
plot.roc(model_dt_results$auc, print.thres = 'best')
cutoff = 0.540
model_dt_results <- model_performance(tunning_model_dt,
test_bmt,
cutoff)
[1] "Pred prob"
[1] "result cat"
[1] "cfm"
[1] "auc"
[1] "fmeas"
[1] "list"
model_dt_results$method <- tunning_model_dt$method
model_dt_results$bestune <- tunning_model_dt$bestTune
model_dt_results$final_model <- tunning_model_dt$finalModel
model_dt_results$cutoff <- cutoff
Evaluating (leave-one-out)
cutoff = 0.5
model_dt_results_loocv <- model_performance(tunning_model_dt_loocv,
test_bmt,
cutoff)
[1] "Pred prob"
[1] "result cat"
[1] "cfm"
[1] "auc"
[1] "fmeas"
[1] "list"
plot.roc(model_dt_results_loocv$auc, print.thres = 'best')
cutoff = 0.739
model_dt_results_loocv <- model_performance(tunning_model_dt_loocv,
test_bmt,
cutoff)
[1] "Pred prob"
[1] "result cat"
[1] "cfm"
[1] "auc"
[1] "fmeas"
[1] "list"
model_dt_results_loocv$method <- tunning_model_dt_loocv$method
model_dt_results_loocv$bestune <- tunning_model_dt_loocv$bestTune
model_dt_results_loocv$final_model <- tunning_model_dt_loocv$finalModel
model_dt_results_loocv$cutoff <- cutoff
2.6.3 Model 03: K-NN
# K-NN: Hyperparameters tested
odds <- function(x) subset(x, x %% 2 != 0)
k_options <- odds(2:14)
distances_all = c("triangular", "rectangular", "epanechnikov", "optimal", "inv")
tunegrid_knn <- expand.grid(kmax = k_options,
distance = c(1,2),
kernel = distances_all)
#Training model
method <- "kknn"
grid <- tunegrid_knn
train <- train_bmt
seed <- 01
ctrl_set <- ctrl_rcv
tunning_model_knn <- tunning_model(method, ctrl_set, grid, train, seed)
ctrl_set <- ctrl_loocv
tunning_model_knn_loocv <- tunning_model(method, ctrl_set, grid, train, seed)
Evaluating (repeated CV)
cutoff = 0.5
model_knn_results <- model_performance(tunning_model_knn,
test_bmt,
cutoff)
[1] "Pred prob"
[1] "result cat"
[1] "cfm"
[1] "auc"
[1] "fmeas"
[1] "list"
plot.roc(model_knn_results$auc, print.thres = 'best')
cutoff = 0.676
model_knn_results <- model_performance(tunning_model_knn,
test_bmt,
cutoff)
[1] "Pred prob"
[1] "result cat"
[1] "cfm"
[1] "auc"
[1] "fmeas"
[1] "list"
model_knn_results$method <- tunning_model_knn$method
model_knn_results$bestune <- tunning_model_knn$bestTune
model_knn_results$final_model <- tunning_model_knn$finalModel
model_knn_results$cutoff <- cutoff
Evaluating (leave-one-out)
cutoff = 0.5
model_knn_results_loocv <- model_performance(tunning_model_knn_loocv,
test_bmt,
cutoff)
[1] "Pred prob"
[1] "result cat"
[1] "cfm"
[1] "auc"
[1] "fmeas"
[1] "list"
plot.roc(model_knn_results_loocv$auc, print.thres = 'best')
cutoff = 0.750
model_knn_results_loocv <- model_performance(tunning_model_knn_loocv,
test_bmt,
cutoff)
[1] "Pred prob"
[1] "result cat"
[1] "cfm"
[1] "auc"
[1] "fmeas"
[1] "list"
model_knn_results_loocv$method <- tunning_model_knn_loocv$method
model_knn_results_loocv$bestune <- tunning_model_knn_loocv$bestTune
model_knn_results_loocv$final_model <- tunning_model_knn_loocv$finalModel
model_knn_results_loocv$cutoff <- cutoff
2.6.4 Model 04: Naive Bayes
Training
# Naive bayes: Hyperparameters tested
tunegrid_nb <- expand.grid(usekernel = c(TRUE, FALSE),
laplace = 0:1,
adjust = 1:10)
#Training model
method <- "naive_bayes"
grid <- tunegrid_nb
train <- train_bmt
seed <- 01
ctrl_set <- ctrl_rcv
tunning_model_nb <- tunning_model(method, ctrl_set, grid, train, seed)
ctrl_set <- ctrl_loocv
tunning_model_nb_loocv <- tunning_model(method, ctrl_set, grid, train, seed)
Evaluating (repeated CV)
cutoff = 0.5
model_nb_results <- model_performance(tunning_model_nb,
                                      test_bmt,
                                      cutoff)
[1] "Pred prob"
[1] "result cat"
[1] "cfm"
[1] "auc"
[1] "fmeas"
[1] "list"
plot.roc(model_nb_results$auc, print.thres = 'best')
cutoff = 0.226
model_nb_results <- model_performance(tunning_model_nb,
                                      test_bmt,
                                      cutoff)
[1] "Pred prob"
[1] "result cat"
[1] "cfm"
[1] "auc"
[1] "fmeas"
[1] "list"
model_nb_results$method <- tunning_model_nb$method
model_nb_results$bestune <- tunning_model_nb$bestTune
model_nb_results$final_model <- tunning_model_nb$finalModel
model_nb_results$cutoff <- cutoff
Evaluating (leave-one-out)
cutoff = 0.5
model_nb_results_loocv <- model_performance(tunning_model_nb_loocv,
                                            test_bmt,
                                            cutoff)
[1] "Pred prob"
[1] "result cat"
[1] "cfm"
[1] "auc"
[1] "fmeas"
[1] "list"
plot.roc(model_nb_results_loocv$auc, print.thres = 'best')
cutoff = 0.220
model_nb_results_loocv <- model_performance(tunning_model_nb_loocv,
                                            test_bmt,
                                            cutoff)
[1] "Pred prob"
[1] "result cat"
[1] "cfm"
[1] "auc"
[1] "fmeas"
[1] "list"
model_nb_results_loocv$method <- tunning_model_nb_loocv$method
model_nb_results_loocv$bestune <- tunning_model_nb_loocv$bestTune
model_nb_results_loocv$final_model <- tunning_model_nb_loocv$finalModel
model_nb_results_loocv$cutoff <- cutoff
2.6.5 Model 05: Random Forest
# mtry is kept fixed; only the number of trees varies between fits
tunegrid_rf <- expand.grid(.mtry = c(sqrt(ncol(train_bmt))))
modellist_rf <- list()
for (ntree in c(500, 1000, 1500)) {
  set.seed(01)  # same seed for every fit so the resamples are comparable
  fit <- train(survival_status~.,
               data = train_bmt,
               method = "rf",
               metric = "ROC",
               tuneGrid = tunegrid_rf,
               trControl = ctrl_rcv,
               ntree = ntree)
  key <- toString(ntree)
  modellist_rf[[key]] <- fit
}
# compare results
results_rf <- resamples(modellist_rf)
summary(results_rf)
Call:
summary.resamples(object = results_rf)
Models: 500, 1000, 1500
Number of resamples: 100
ROC
Min. 1st Qu. Median Mean 3rd Qu. Max. NA's
500 0.3000000 0.5285714 0.6285714 0.6277976 0.7333333 0.9857143 0
1000 0.2666667 0.5142857 0.6285714 0.6251667 0.7333333 0.9714286 0
1500 0.2666667 0.5107143 0.6333333 0.6251548 0.7190476 0.9714286 0
Sens
Min. 1st Qu. Median Mean 3rd Qu. Max. NA's
500 0.4285714 0.7142857 0.8333333 0.7830952 0.8571429 1 0
1000 0.4285714 0.7023810 0.8333333 0.7785714 0.8571429 1 0
1500 0.4285714 0.7023810 0.8333333 0.7828571 0.8571429 1 0
Spec
Min. 1st Qu. Median Mean 3rd Qu. Max. NA's
500 0 0.2375 0.4 0.4250 0.6 1 0
1000 0 0.2000 0.4 0.4225 0.6 1 0
1500 0 0.2000 0.4 0.4245 0.6 1 0
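Note that `sqrt(ncol(train_bmt))` counts the outcome column together with the predictors, which would explain a non-integer `mtry` close to 4.796 (the square root of 23). The sketch below, assuming the training set has 22 predictors plus `survival_status`, shows how a whole-number `mtry` could be derived instead; `n_columns`, `n_predictors`, `mtry_floor` and `tunegrid_rf_int` are illustrative names, not objects from the analysis.

```r
# Assumption: 22 predictors plus the survival_status outcome column
n_columns    <- 23
n_predictors <- n_columns - 1

# Value produced when the outcome column is included in the count
sqrt(n_columns)

# Whole-number alternative based on the predictors only
mtry_floor <- floor(sqrt(n_predictors))

# Single-row tuning grid in the format caret expects for method = "rf"
tunegrid_rf_int <- expand.grid(.mtry = mtry_floor)
tunegrid_rf_int
```

This is only one possible convention; whether it changes the fitted forest depends on how the non-integer value is handled internally.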
dotplot(results_rf)
# Function to train the random forest with a fixed number of trees
tunning_rf <- function(train_bmt, ctrl, best_ntree, seed){
  if (!missing(seed))
    set.seed(seed)
  tunegrid_rf <- expand.grid(.mtry = c(sqrt(ncol(train_bmt))))
  tunned_rf_model <- train(survival_status~.,
                           data = train_bmt,
                           method = "rf",
                           metric = "ROC",
                           tuneGrid = tunegrid_rf,
                           trControl = ctrl,
                           ntree = best_ntree)
  return(tunned_rf_model)
}
best_ntree <- 500 # as determined in the previous chunk
seed <- 01
ctrl_set <- ctrl_rcv
tunning_model_rf <- tunning_rf(train_bmt, ctrl_set, best_ntree, seed)
ctrl_set <- ctrl_loocv
tunning_model_rf_loocv <- tunning_rf(train_bmt, ctrl_set, best_ntree, seed)
Evaluating (repeated CV)
cutoff = 0.5
model_rf_results <- model_performance(tunning_model_rf,
                                      test_bmt,
                                      cutoff)
[1] "Pred prob"
[1] "result cat"
[1] "cfm"
[1] "auc"
[1] "fmeas"
[1] "list"
plot.roc(model_rf_results$auc, print.thres = 'best')
cutoff = 0.462
model_rf_results <- model_performance(tunning_model_rf,
                                      test_bmt,
                                      cutoff)
[1] "Pred prob"
[1] "result cat"
[1] "cfm"
[1] "auc"
[1] "fmeas"
[1] "list"
model_rf_results$method <- tunning_model_rf$method
model_rf_results$bestune <- tunning_model_rf$bestTune
model_rf_results$final_model <- tunning_model_rf$finalModel
model_rf_results$cutoff <- cutoff
Evaluating (leave-one-out)
cutoff = 0.5
model_rf_results_loocv <- model_performance(tunning_model_rf_loocv,
                                            test_bmt,
                                            cutoff)
[1] "Pred prob"
[1] "result cat"
[1] "cfm"
[1] "auc"
[1] "fmeas"
[1] "list"
plot.roc(model_rf_results_loocv$auc, print.thres = 'best')
cutoff = 0.455
model_rf_results_loocv <- model_performance(tunning_model_rf_loocv,
                                            test_bmt,
                                            cutoff)
[1] "Pred prob"
[1] "result cat"
[1] "cfm"
[1] "auc"
[1] "fmeas"
[1] "list"
model_rf_results_loocv$method <- tunning_model_rf_loocv$method
model_rf_results_loocv$bestune <- tunning_model_rf_loocv$bestTune
model_rf_results_loocv$final_model <- tunning_model_rf_loocv$finalModel
model_rf_results_loocv$cutoff <- cutoff
2.7 Variable importance:
2.7.1 Genetic algorithm
set.seed(01)
ga_ctrl <- gafsControl(functions = rfGA,
method = "cv",
genParallel=TRUE,
allowParallel = TRUE)
train_bmt_ga <- mutate(train_bmt, survival_status=NULL)
rf_ga3 <- gafs(x = train_bmt_ga,
y = train_bmt$survival_status,
iters = 100,
popSize = 15,
gafsControl = ga_ctrl)
rf_ga3
Genetic Algorithm Feature Selection
116 samples
22 predictors
2 classes: 'alive', 'dead'
Maximum generations: 100
Population per generation: 15
Crossover probability: 0.8
Mutation probability: 0.1
Elitism: 0
Internal performance values: Accuracy, Kappa
Subset selection driven to maximize internal Accuracy
External performance values: Accuracy, Kappa
Best iteration chose by maximizing external Accuracy
External resampling method: Cross-Validated (10 fold)
During resampling:
* the top 5 selected variables (out of a possible 22):
ANC_recovery_range (100%), HLA_mismatch_count (100%), recipient_rh (100%), relapse (100%), donor_CMV (80%)
* on average, 9.2 variables were selected (min = 5, max = 12)
In the final search using the entire training set:
* 10 features selected at iteration 54 including:
donor_age, donor_ABO, donor_CMV, recipient_gender, recipient_rh ...
* external performance at this iteration is
Accuracy Kappa
0.6644 0.2783
2.7.2 Random Forest
set.seed(1)
rf <- randomForest(survival_status ~ .,
data = train_bmt,
importance = T,
mtry = 4,
ntree = 500)
3. Results: Comparing models
Under repeated cross-validation, penalized logistic regression presented higher ROC and specificity values than the other models.
3.1 Hyperparameters: best tuning (repeatedCV and Leave-one-out)
# Joining best tune hyperparameters on a single dataframe
## Joining repeatedCV
best_tunes <- cbind(model_logreg_results$bestune,
model_dt_results$bestune,
model_knn_results$bestune,
model_nb_results$bestune,
model_rf_results$bestune
)
rownames(best_tunes) <- "repeatedcv"
## Joining leave-one-out
best_tunes_loocv <- cbind(model_logreg_results_loocv$bestune,
model_dt_results_loocv$bestune,
model_knn_results_loocv$bestune,
model_nb_results_loocv$bestune,
model_rf_results_loocv$bestune
)
rownames(best_tunes_loocv) <- "leave-one-out"
## Joining repeated CV and leave-one-out
best_tunes_all <- rbind(best_tunes, best_tunes_loocv)
best_tunes_all$ntree <- 500
best_tunes_all_t <- t(best_tunes_all)
best_tunes_all_t <- as.data.frame(best_tunes_all_t)
best_tunes_all_t$model_trained <- c("Penalized Logistic Regression",
                                    "Penalized Logistic Regression",
                                    "Decision Tree",
                                    "K-Nearest Neighbors",
                                    "K-Nearest Neighbors",
                                    "K-Nearest Neighbors",
                                    "Naive Bayes",
                                    "Naive Bayes",
                                    "Naive Bayes",
                                    "Random Forest",
                                    "Random Forest"
                                    )
best_tunes_all_t <- tibble::rownames_to_column(best_tunes_all_t, "hyperparameter")
best_tunes_all_t <- best_tunes_all_t[,c(4,1,2,3)]
best_tunes_all_t
   model_trained                 hyperparameter repeatedcv leave-one-out
1  Penalized Logistic Regression lambda         1          1
2  Penalized Logistic Regression cp             aic        aic
3  Decision Tree                 maxdepth       2          8
4  K-Nearest Neighbors           kmax           9          9
5  K-Nearest Neighbors           distance       2          2
6  K-Nearest Neighbors           kernel         inv        triangular
7  Naive Bayes                   laplace        0          0
8  Naive Bayes                   usekernel      TRUE       TRUE
9  Naive Bayes                   adjust         1          2
10 Random Forest                 mtry           4.795832   4.795832
11 Random Forest                 ntree          500        500
3.2 Accuracy [C.I. range], Kappa metrics (repeatedCV and Leave-one-out)
Function to generate a summarized metrics data frame for each model
sumarized_model_metrics <- function(model_res, model_res_loocv, model){
# Capturing metrics ----------------------------------------------------------
accuracy = model_res$cfm$overall["Accuracy"]
accuracy_ci_lower = model_res$cfm$overall["AccuracyLower"]
accuracy_ci_upper = model_res$cfm$overall["AccuracyUpper"]
accuracy_pvalue = model_res$cfm$overall["AccuracyPValue"]
sensitivity = model_res$cfm$byClass["Sensitivity"]
specificity = model_res$cfm$byClass["Specificity"]
f_score = model_res$fscore
kappa = model_res$cfm$overall["Kappa"]
accuracy_loocv = model_res_loocv$cfm$overall["Accuracy"]
accuracy_loocv_ci_lower = model_res_loocv$cfm$overall["AccuracyLower"]
accuracy_loocv_ci_upper = model_res_loocv$cfm$overall["AccuracyUpper"]
accuracy_loocv_pvalue = model_res_loocv$cfm$overall["AccuracyPValue"]
sensitivity_loocv = model_res_loocv$cfm$byClass["Sensitivity"]
specificity_loocv = model_res_loocv$cfm$byClass["Specificity"]
f_score_loocv = model_res_loocv$fscore
kappa_loocv = model_res_loocv$cfm$overall["Kappa"]
# Building dataframe ---------------------------------------------------------
df <- data.frame(accuracy,
accuracy_ci_lower,
accuracy_ci_upper,
accuracy_pvalue,
sensitivity,
specificity,
f_score,
kappa,
accuracy_loocv,
accuracy_loocv_ci_lower,
accuracy_loocv_ci_upper,
accuracy_loocv_pvalue,
sensitivity_loocv,
specificity_loocv,
f_score_loocv,
kappa_loocv)
rownames(df) <- model
return(df)
}
model <- "Logistic Regression"
logreg_metrics <- sumarized_model_metrics(model_logreg_results,
model_logreg_results_loocv,
model)
model <- "Decision tree"
dt_metrics <- sumarized_model_metrics(model_dt_results, model_dt_results_loocv, model)
model <- "K-Nearest Neighbors"
knn_metrics <- sumarized_model_metrics(model_knn_results, model_knn_results_loocv, model)
model <- "Naive Bayes"
nb_metrics <- sumarized_model_metrics(model_nb_results, model_nb_results_loocv, model)
model <- "Random Forest"
rf_metrics <- sumarized_model_metrics(model_rf_results, model_rf_results_loocv, model)
transpose_round <- function(metric_df){
result <- metric_df %>%
mutate_if(is.numeric, round, digits=3) %>%
t() %>%
as.data.frame() #%>%
#tibble::rownames_to_column(var="metric")
return(result)
}
logreg_metrics_t <- transpose_round(logreg_metrics)
knn_metrics_t <- transpose_round(knn_metrics)
nb_metrics_t <- transpose_round(nb_metrics)
dt_metrics_t <- transpose_round(dt_metrics)
rf_metrics_t <- transpose_round(rf_metrics)
#rownames(all_models_metrics)
all_models_metrics <- cbind(logreg_metrics_t,
dt_metrics_t,
knn_metrics_t,
nb_metrics_t,
rf_metrics_t)
all_models_metrics_full <- tibble::rownames_to_column(all_models_metrics,
                                                      "ML models")
all_models_metrics_full
   ML models               Logistic Regression Decision tree K-Nearest Neighbors Naive Bayes Random Forest
1  accuracy                0.735               0.714         0.612               0.735       0.694
2  accuracy_ci_lower       0.589               0.567         0.462               0.589       0.546
3  accuracy_ci_upper       0.851               0.834         0.748               0.851       0.817
4  accuracy_pvalue         0.014               0.028         0.335               0.014       0.054
5  sensitivity             0.381               0.381         0.333               0.429       0.524
6  specificity             1.000               0.964         0.821               0.964       0.821
7  f_score                 0.812               0.794         0.708               0.806       0.754
8  kappa                   0.413               0.372         0.164               0.420       0.356
9  accuracy_loocv          0.735               0.673         0.633               0.714       0.694
10 accuracy_loocv_ci_lower 0.589               0.525         0.483               0.567       0.546
11 accuracy_loocv_ci_upper 0.851               0.801         0.766               0.834       0.817
12 accuracy_loocv_pvalue   0.014               0.096         0.236               0.028       0.054
13 sensitivity_loocv       0.381               0.381         0.333               0.333       0.524
14 specificity_loocv       1.000               0.893         0.857               1.000       0.821
15 f_score_loocv           0.812               0.758         0.727               0.800       0.754
16 kappa_loocv             0.413               0.291         0.203               0.364       0.356
Table showing relevant metrics for the five models studied under the two cross-validation strategies. The accuracy confidence intervals of all models overlap, so no significant difference in accuracy was detected among them.
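The metrics in this table all derive from a 2x2 confusion matrix computed on the test set. The sketch below recomputes accuracy, sensitivity, specificity and Cohen's kappa in base R. The counts are an assumption: they were chosen to be consistent with the Logistic Regression column (49 test cases, 21 of them in the positive class) and were not read from the original confusion matrix.

```r
# Assumed cell counts: true positives, false negatives, false positives, true negatives
tp <- 8; fn <- 13; fp <- 0; tn <- 28
n  <- tp + fn + fp + tn

accuracy    <- (tp + tn) / n
sensitivity <- tp / (tp + fn)
specificity <- tn / (tn + fp)

# Cohen's kappa: observed agreement corrected for the agreement expected by chance
p_chance    <- ((tp + fp) * (tp + fn) + (fn + tn) * (fp + tn)) / n^2
cohen_kappa <- (accuracy - p_chance) / (1 - p_chance)

round(c(accuracy = accuracy,
        sensitivity = sensitivity,
        specificity = specificity,
        kappa = cohen_kappa), 3)
```

With these counts the rounded values are 0.735, 0.381, 1.000 and 0.413, matching the Logistic Regression column. A perfect specificity combined with a low sensitivity means the model rarely predicts the positive class.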
3.3 ROC, Sensitivity and Specificity (repeatedCV)
set.seed(1)
resamples_list <- resamples(list("Logistic Regression" = tunning_model_logreg,
"KNN" = tunning_model_knn,
"Naive Bayes" = tunning_model_nb,
"Decision Tree" = tunning_model_dt,
"Random Forest" = tunning_model_rf
)
)
summary(resamples_list)
Call:
summary.resamples(object = resamples_list)
Models: Logistic Regression, KNN, Naive Bayes, Decision Tree, Random Forest
Number of resamples: 100
ROC
                    Min.      1st Qu.   Median    Mean      3rd Qu.   Max.      NA's
Logistic Regression 0.3333333 0.6000000 0.7142857 0.7053810 0.8285714 1.0000000 0
KNN                 0.1428571 0.4702381 0.5863095 0.5891310 0.7000000 1.0000000 0
Naive Bayes         0.3142857 0.5714286 0.6571429 0.6766786 0.8000000 0.9714286 0
Decision Tree       0.2714286 0.4851190 0.5482143 0.5668988 0.6428571 0.8571429 0
Random Forest       0.3000000 0.5285714 0.6285714 0.6277976 0.7333333 0.9857143 0
Sens
Min. 1st Qu. Median Mean 3rd Qu. Max. NA's
Logistic Regression 0.3333333 0.6666667 0.7142857 0.7488095 0.8571429 1 0
KNN 0.3333333 0.6428571 0.8333333 0.7573810 0.8571429 1 0
Naive Bayes 0.4285714 0.7142857 0.8571429 0.8657143 1.0000000 1 0
Decision Tree 0.2857143 0.5714286 0.7142857 0.7252381 0.8571429 1 0
Random Forest 0.4285714 0.7142857 0.8333333 0.7830952 0.8571429 1 0
Spec
Min. 1st Qu. Median Mean 3rd Qu. Max. NA's
Logistic Regression 0.2 0.4000 0.6 0.5595 0.6375 1.0 0
KNN 0.0 0.2000 0.4 0.3255 0.4000 0.8 0
Naive Bayes 0.0 0.2000 0.2 0.3045 0.5000 1.0 0
Decision Tree 0.0 0.2000 0.4 0.3910 0.5250 0.8 0
Random Forest 0.0 0.2375 0.4 0.4250 0.6000 1.0 0
dotplot(resamples_list, scale.min = 0.1)
Under repeated cross-validation, penalized logistic regression presented the highest specificity and Naive Bayes the highest sensitivity when compared to Random Forest, KNN and Decision Tree.
There is no statistically significant difference between Logistic Regression and Naive Bayes in terms of ROC, but Logistic Regression presented higher specificity.
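One way to support a statement like this is a paired comparison of the per-resample ROC values, since both models are scored on the same resamples. caret provides `diff()` for `resamples` objects for this purpose. The sketch below only illustrates the underlying idea with a paired t-test in base R on simulated AUC values; the numbers are made up and are not the study's results, and all object names are illustrative.

```r
set.seed(1)
n_resamples <- 100

# Simulated per-resample AUCs for two models scored on the same resamples,
# clipped to the valid [0, 1] range
auc_model_a <- pmin(pmax(rnorm(n_resamples, mean = 0.70, sd = 0.15), 0), 1)
auc_model_b <- pmin(pmax(auc_model_a - rnorm(n_resamples, mean = 0.03, sd = 0.10), 0), 1)

# Paired test: each resample contributes one difference between the two models
paired_test <- t.test(auc_model_a, auc_model_b, paired = TRUE)

paired_test$estimate  # mean of the per-resample differences
paired_test$p.value
```

When more than two models are compared, the p-values of the pairwise tests should be adjusted for multiple comparisons.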
3.4 Confusion Matrix
library(ggplot2)
cfm_image <- function(cfm){
ggplot(data = cfm, mapping = aes(x = Reference, y = Prediction)) +
geom_tile(aes(fill = n), colour = "white") +
geom_text(aes(label = sprintf("%1.0f", n)), vjust = 1) +
scale_fill_gradient(low="white", high="#009194") +
theme_bw() + theme(legend.position = "none") +
scale_x_discrete(position = "top") +
coord_trans(y = "reverse")
}
cfm_img_lr <- cfm_image(as_tibble(model_logreg_results$cfm$table))
cfm_img_lr_loocv <- cfm_image(as_tibble(model_logreg_results_loocv$cfm$table))
cfm_img_knn <- cfm_image(as_tibble(model_knn_results$cfm$table))
cfm_img_knn_loocv <- cfm_image(as_tibble(model_knn_results_loocv$cfm$table))
cfm_img_dt <- cfm_image(as_tibble(model_dt_results$cfm$table))
cfm_img_dt_loocv <- cfm_image(as_tibble(model_dt_results_loocv$cfm$table))
cfm_img_nb <- cfm_image(as_tibble(model_nb_results$cfm$table))
cfm_img_nb_loocv <- cfm_image(as_tibble(model_nb_results_loocv$cfm$table))
cfm_img_rf <- cfm_image(as_tibble(model_rf_results$cfm$table))
cfm_img_rf_loocv <- cfm_image(as_tibble(model_rf_results_loocv$cfm$table))
3.4.1 Penalized Logistic Regression
cfm_img_lr
cfm_img_lr_loocv
3.4.2 Decision Trees
cfm_img_dt
cfm_img_dt_loocv
3.4.3 K-Nearest Neighbors
cfm_img_knn
cfm_img_knn_loocv
3.4.4 Naive Bayes
cfm_img_nb
cfm_img_nb_loocv
3.4.5 Random Forest
cfm_img_rf
cfm_img_rf_loocv
3.5 Variable Importance
3.5.1 Plotting Decision Tree
dtree <- function(tunning_model){
  rpart.plot(tunning_model$finalModel,
             extra = 4,               # extra information shown in the nodes
             type = 4,                # plot type
             box.palette = "RdYlGn")  # color palette
}
dtree_cv <- dtree(tunning_model_dt)
dtree_loocv <- dtree(tunning_model_dt_loocv)
dt_importance_loocv <- as.data.frame(dtree_loocv$obj$variable.importance)
dt_importance_loocv$importance <- dt_importance_loocv$`dtree_loocv$obj$variable.importance`
dt_importance_loocv$importance <- round(dt_importance_loocv$importance, 3)
dt_importance_loocv$`dtree_loocv$obj$variable.importance`= NULL
dt_importance_loocv <- tibble::rownames_to_column(dt_importance_loocv, "variable")
dt_importance_loocv <- dt_importance_loocv[order(dt_importance_loocv$importance,decreasing=TRUE),]
dt_importance_loocv
variable importance
1 recipient_age 9.974
2 donor_age 7.502
3 recipient_body_mass 6.151
4 relapseyes 4.668
5 PLT_recovery_range100+ 2.070
6 CMV_status3 2.063
7 donor_ABOA 1.830
8 ANC_recovery_range20+ 1.052
9 donor_CMVpresent 1.031
10 CD3_to_CD34_ratio 0.849
11 ABO_matchmismatched 0.575
12 diseaselymphoma 0.549
13 disease_groupnonmalignant 0.395
14 diseasenonmalignant 0.395
15 risk_grouplow 0.261
16 HLA_mismatch_count 0.209
17 recipient_CMVpresent 0.206
18 CMV_status1 0.138
19 PLT_recovery_range30-100 0.138
20 recipient_rhplus 0.069
3.5.2 Genetic algorithm features selected
plot(rf_ga3)
ga_imp_variables <- rf_ga3$ga$final
ga_imp_variables
[1] "donor_ABO" "donor_CMV" "recipient_gender"
[4] "recipient_rh" "CMV_status" "CD3_to_CD34_ratio"
[7] "acute_GvHD_III_IV" "relapse" "HLA_mismatch_count"
3.5.3 Random Forest
rf_variables <- cbind(round(importance(rf, type = 1), 3),
round(importance(rf, type = 2), 3),
varUsed(rf, count= TRUE))
rf_variables_df <- as.data.frame(rf_variables)
rf_variables_df <- tibble::rownames_to_column(rf_variables_df, "variable")
rf_variables_df$count <- rf_variables_df$V3
rf_variables_df$V3 = NULL
rf_variables_df$ga <- ifelse(rf_variables_df$variable %in% ga_imp_variables, "yes", "no")
rf_variables_df$`decrease_ac+gini` <- rf_variables_df$MeanDecreaseAccuracy + rf_variables_df$MeanDecreaseGini
rf_variables_df <- rf_variables_df[order(rf_variables_df$`decrease_ac+gini`,decreasing=TRUE),]
rf_variables_df
   variable            MeanDecreaseAccuracy MeanDecreaseGini count ga  decrease_ac+gini
19 relapse             10.994               3.441            477   yes 14.435
4  recipient_age       2.267                6.443            1425  no  8.710
6  recipient_body_mass 1.454                6.575            1433  no  8.029
17 CD3_to_CD34_ratio   0.315                6.194            1443  yes 6.509
2  donor_ABO           3.065                2.466            635   yes 5.531
8  recipient_rh        3.788                1.270            324   yes 5.058
22 ANC_recovery_range  2.149                2.117            545   no  4.266
1  donor_age           -1.722               5.796            1427  no  4.074
10 disease             0.580                3.375            827   no  3.955
3  donor_CMV           2.547                0.966            331   yes 3.513
20 HLA_mismatch_count  1.242                2.161            660   yes 3.403
21 PLT_recovery_range  0.081                3.253            748   no  3.334
12 ABO_match           0.787                1.961            410   no  2.748
13 CMV_status          -0.311               2.545            764   yes 2.234
5  recipient_gender    -0.148               0.860            321   yes 0.712
7  recipient_ABO       -1.306               2.017            622   no  0.711
9  recipient_CMV       -0.415               0.763            271   no  0.348
15 stem_cell_source    -0.505               0.797            246   no  0.292
18 acute_GvHD_III_IV   -1.200               0.874            281   yes -0.326
14 risk_group          -1.477               0.791            285   no  -0.686
11 disease_group       -2.129               0.444            178   no  -1.685
16 tx_post_relapse     -2.951               0.697            184   no  -2.254
# Variable importance plot
varImpPlot(rf, sort = T)
3.5.4 Penalized Logistic Regression coefficients
coeficients = (as.data.frame(tunning_model_logreg$finalModel$coefficients))
coeficients <- tibble::rownames_to_column(coeficients, "variable")
coeficients$coeficient <- coeficients$`tunning_model_logreg$finalModel$coefficients`
coeficients$`tunning_model_logreg$finalModel$coefficients`= NULL
coeficients$signal <- ifelse (coeficients$coeficient < 0, "negative", "positive" )
coeficients$coeficient_module <- round(abs(coeficients$coeficient), 3) # absolute value of the coefficient
coeficients$coeficient = NULL
coeficients <- coeficients[order(coeficients$coeficient_module,decreasing=TRUE),]
coeficients
variable signal coeficient_module
29 relapseyes negative 0.885
9 recipient_body_mass negative 0.692
13 recipient_rhplus negative 0.667
26 tx_post_relapseyes negative 0.547
16 diseasechronic negative 0.479
5 donor_ABOB positive 0.476
20 ABO_matchmismatched positive 0.472
35 ANC_recovery_range20+ negative 0.466
1 Intercept positive 0.420
33 PLT_recovery_range100+ negative 0.412
12 recipient_ABOB negative 0.369
27 CD3_to_CD34_ratio negative 0.313
17 diseaselymphoma negative 0.305
24 risk_grouplow negative 0.272
30 HLA_mismatch_count positive 0.256
4 donor_ABOAB positive 0.234
2 donor_age negative 0.219
31 PLT_recovery_range20-30 negative 0.201
3 donor_ABOA negative 0.200
34 ANC_recovery_range15-20 positive 0.177
18 diseasenonmalignant negative 0.165
19 disease_groupnonmalignant negative 0.165
8 recipient_gendermale negative 0.163
7 recipient_age positive 0.130
6 donor_CMVpresent positive 0.126
28 acute_GvHD_III_IVyes negative 0.121
32 PLT_recovery_range30-100 negative 0.105
15 diseaseAML positive 0.080
11 recipient_ABOAB positive 0.073
10 recipient_ABOA negative 0.052
25 stem_cell_sourceperipheral_blood positive 0.050
21 CMV_status1 positive 0.035
22 CMV_status2 negative 0.032
23 CMV_status3 positive 0.020
14 recipient_CMVpresent positive 0.008
dt_imp_logreg_coef <- merge(coeficients,
dt_importance_loocv,
by='variable',
all = TRUE
)
dt_imp_logreg_coef <- dt_imp_logreg_coef[order(dt_imp_logreg_coef$coeficient_module,
                                               decreasing=TRUE),]
dt_imp_logreg_coef
variable signal coeficient_module importance
32 relapseyes negative 0.885 4.668
28 recipient_body_mass negative 0.692 6.151
31 recipient_rhplus negative 0.667 0.069
35 tx_post_relapseyes negative 0.547 NA
11 diseasechronic negative 0.479 NA
16 donor_ABOB positive 0.476 NA
1 ABO_matchmismatched positive 0.472 0.575
4 ANC_recovery_range20+ negative 0.466 1.052
20 Intercept positive 0.420 NA
21 PLT_recovery_range100+ negative 0.412 2.070
26 recipient_ABOB negative 0.369 NA
5 CD3_to_CD34_ratio negative 0.313 0.849
12 diseaselymphoma negative 0.305 0.549
33 risk_grouplow negative 0.272 0.261
19 HLA_mismatch_count positive 0.256 0.209
15 donor_ABOAB positive 0.234 NA
17 donor_age negative 0.219 7.502
22 PLT_recovery_range20-30 negative 0.201 NA
14 donor_ABOA negative 0.200 1.830
3 ANC_recovery_range15-20 positive 0.177 NA
9 disease_groupnonmalignant negative 0.165 0.395
13 diseasenonmalignant negative 0.165 0.395
30 recipient_gendermale negative 0.163 NA
27 recipient_age positive 0.130 9.974
18 donor_CMVpresent positive 0.126 1.031
2 acute_GvHD_III_IVyes negative 0.121 NA
23 PLT_recovery_range30-100 negative 0.105 0.138
10 diseaseAML positive 0.080 NA
25 recipient_ABOAB positive 0.073 NA
24 recipient_ABOA negative 0.052 NA
34 stem_cell_sourceperipheral_blood positive 0.050 NA
6 CMV_status1 positive 0.035 0.138
7 CMV_status2 negative 0.032 NA
8 CMV_status3 positive 0.020 2.063
29 recipient_CMVpresent positive 0.008 0.206
Table showing variable importance for penalized logistic regression (coeficient_module, the absolute value of each coefficient) and decision tree (importance). Rows are ordered from highest to lowest coeficient_module. Relapse and recipient body mass are relevant variables for the predictions of both algorithms.
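The NA entries in the importance column come from the full outer join (`all = TRUE`) used to combine the two tables: a term that has a logistic regression coefficient but was never used by the decision tree has no importance value to match. A minimal base R sketch of this behaviour, using three rows copied from the tables above; `coef_small`, `imp_small` and `joined` are illustrative names.

```r
# Three logistic regression terms with their absolute coefficients
coef_small <- data.frame(variable = c("relapseyes", "recipient_body_mass", "donor_ABOB"),
                         coeficient_module = c(0.885, 0.692, 0.476))

# Decision tree importance is available for only two of those terms
imp_small <- data.frame(variable = c("relapseyes", "recipient_body_mass"),
                        importance = c(4.668, 6.151))

# Full outer join: every term is kept, unmatched cells become NA
joined <- merge(coef_small, imp_small, by = "variable", all = TRUE)
joined <- joined[order(joined$coeficient_module, decreasing = TRUE), ]
joined
```

An inner join (`all = FALSE`) would instead drop `donor_ABOB` and keep only the terms present in both tables.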
rf_variables_df
   variable            MeanDecreaseAccuracy MeanDecreaseGini count ga  decrease_ac+gini
19 relapse             10.994               3.441            477   yes 14.435
4  recipient_age       2.267                6.443            1425  no  8.710
6  recipient_body_mass 1.454                6.575            1433  no  8.029
17 CD3_to_CD34_ratio   0.315                6.194            1443  yes 6.509
2  donor_ABO           3.065                2.466            635   yes 5.531
8  recipient_rh        3.788                1.270            324   yes 5.058
22 ANC_recovery_range  2.149                2.117            545   no  4.266
1  donor_age           -1.722               5.796            1427  no  4.074
10 disease             0.580                3.375            827   no  3.955
3  donor_CMV           2.547                0.966            331   yes 3.513
20 HLA_mismatch_count  1.242                2.161            660   yes 3.403
21 PLT_recovery_range  0.081                3.253            748   no  3.334
12 ABO_match           0.787                1.961            410   no  2.748
13 CMV_status          -0.311               2.545            764   yes 2.234
5  recipient_gender    -0.148               0.860            321   yes 0.712
7  recipient_ABO       -1.306               2.017            622   no  0.711
9  recipient_CMV       -0.415               0.763            271   no  0.348
15 stem_cell_source    -0.505               0.797            246   no  0.292
18 acute_GvHD_III_IV   -1.200               0.874            281   yes -0.326
14 risk_group          -1.477               0.791            285   no  -0.686
11 disease_group       -2.129               0.444            178   no  -1.685
16 tx_post_relapse     -2.951               0.697            184   no  -2.254
Table showing variable importance for Random Forest (decrease_ac+gini) and genetic algorithm (ga [yes|no]). Rows are ordered from highest to lowest decrease_ac+gini. Relapse, recipient_body_mass, ANC_recovery_range and donor_age are among the most important variables for both algorithms.
4. Conclusions
The classical machine learning models studied did not differ significantly in accuracy when predicting survival after pediatric HSCT in the cohort studied. However, Logistic Regression and Naive Bayes presented higher AUC values than the other models.
Naive Bayes presented the highest sensitivity. Naive Bayes and K-Nearest Neighbors had lower specificity than the other models.