#libraries
library(tidyverse)
## ── Attaching core tidyverse packages ──────────────────────── tidyverse 2.0.0 ──
## ✔ dplyr 1.1.4 ✔ readr 2.1.4
## ✔ forcats 1.0.0 ✔ stringr 1.5.1
## ✔ ggplot2 3.5.1 ✔ tibble 3.2.1
## ✔ lubridate 1.9.3 ✔ tidyr 1.3.1
## ✔ purrr 1.0.2
## ── Conflicts ────────────────────────────────────────── tidyverse_conflicts() ──
## ✖ dplyr::filter() masks stats::filter()
## ✖ dplyr::lag() masks stats::lag()
## ℹ Use the conflicted package (<http://conflicted.r-lib.org/>) to force all conflicts to become errors
library(caret)
## Loading required package: lattice
##
## Attaching package: 'caret'
##
## The following object is masked from 'package:purrr':
##
## lift
library(gbm)
## Loaded gbm 2.2.2
## This version of gbm is no longer under development. Consider transitioning to gbm3, https://github.com/gbm-developers/gbm3
library(rpart)
library(naniar)
library(randomForest)
## randomForest 4.7-1.1
## Type rfNews() to see new features/changes/bug fixes.
##
## Attaching package: 'randomForest'
##
## The following object is masked from 'package:dplyr':
##
## combine
##
## The following object is masked from 'package:ggplot2':
##
## margin
library(xgboost)
##
## Attaching package: 'xgboost'
##
## The following object is masked from 'package:dplyr':
##
## slice
library(nnet)
library(knitr)
library(kableExtra)
##
## Attaching package: 'kableExtra'
##
## The following object is masked from 'package:dplyr':
##
## group_rows
library(dplyr)
library(tidyr)
library(rpart)
library(rpart.plot)
library(ROSE)
## Loaded ROSE 0.0-4
library(ggplot2)
library(dplyr)
library(gridExtra)
##
## Attaching package: 'gridExtra'
##
## The following object is masked from 'package:randomForest':
##
## combine
##
## The following object is masked from 'package:dplyr':
##
## combine
library(rattle)
## Loading required package: bitops
## Rattle: A free graphical interface for data science with R.
## Version 5.5.1 Copyright (c) 2006-2021 Togaware Pty Ltd.
## Type 'rattle()' to shake, rattle, and roll your data.
##
## Attaching package: 'rattle'
##
## The following object is masked from 'package:xgboost':
##
## xgboost
##
## The following object is masked from 'package:randomForest':
##
## importance
library(pROC)
## Type 'citation("pROC")' for a citation.
##
## Attaching package: 'pROC'
##
## The following objects are masked from 'package:stats':
##
## cov, smooth, var
library(RColorBrewer)
The purpose of this project is to analyze clinical data from 9,105 critically ill patients across five U.S. medical centers, collected between 1989-1991 and 1992-1994, to predict 2- and 6-month survival rates. By leveraging information on physiological, demographic, and disease severity factors, the project aims to address the critical issue of improving end-of-life care.
This analysis will seek to support earlier decision-making and planning, through predictive modeling, and provide insights that can enhance patient care, align medical interventions with patient preferences, and alleviate the growing concern over the loss of control patients experience near the end of life.
Goal: To predict 2-month and 6-month survival outcomes for critically ill patients.
Focus: The “positive” class (`Class_1`) represents survival.
id: Unique identifier for each patient record.
adlp: Physical Activities of Daily Living (ADL) score.
adls: Self-care Activities of Daily Living score .
adlsc: Composite ADL score combining physical and self-care components.
age: Patient age in years.
alb: Albumin level, a marker of nutritional status.
aps: Acute Physiology Score indicating illness severity.
avtiss: Duration of study participation.
bili: Bilirubin level, reflecting liver function.
bun: Blood Urea Nitrogen, indicating kidney function.
ca: Indicator for cancer diagnosis.
charges: Total hospital charges for the patient.
crea: Creatinine level, another kidney function marker.
death: Indicator for death during the study.
dementia: Indicator for dementia diagnosis.
diabetes: Indicator for diabetes diagnosis.
dnr: Do Not Resuscitate status.
dnrday: Day of the study when DNR was ordered.
dzclass: Disease classification.
dzgroup: Disease group.
edu: Patient education level.
feat01 to feat10: Feature-derived physiological or clinical metrics.
glucose: Blood glucose level .
hday: Hospital day of admission or event .
hospdead: Indicator for hospital death.
hrt: Heart rate.
income: Income category .
meanbp: Mean blood pressure.
num.co: Number of comorbid conditions.
pafi: Partial pressure of oxygen/inspired oxygen ratio.
ph: Blood pH level.
prg2m: Prognostic survival probability at 2 month.
prg6m: Prognostic survival probability at 6 months.
race: Patient race.
resp: Respiratory rate.
scoma: Coma score .
sex: Gender .
sfdm2: Functional disability score at 2 months.
sod: Sodium level.
sps: Sickness Probability Score .
surv2m: Survival probability at 2 months.
surv6m: Survival probability at 6 months.
temp: Body temperature.
totcst: Total costs incurred by the patient.
totmcst: Total medical costs incurred by the patient.
urine: Urine output.
wblc: White blood cell count.
data <- read.csv("c4.csv")
# Inspecting dataset
head(data)
## id adlp adls adlsc age alb aps avtisst bili bun ca
## 1 1 NA 0 0.000000 53.77100 1.099854 47 43.75 NA NA no
## 2 2 NA NA 3.098145 56.48999 NA 53 16.00 NA NA no
## 3 3 NA NA 2.953613 71.64398 2.000000 37 24.00 0.1999817 21 yes
## 4 4 NA NA 2.355469 85.01599 NA 10 4.00 22.3984375 9 yes
## 5 5 0 NA 0.494751 60.26599 NA 8 19.00 NA 30 no
## 6 6 0 0 0.000000 59.62198 NA 15 10.00 NA NA metastatic
## charges crea death dementia diabetes dnr dnrday
## 1 144017.9 0.5000000 1 0 0 dnr after sadm 28
## 2 38548.0 6.0000000 1 0 0 no dnr 5
## 3 49989.0 1.1999512 1 0 0 no dnr 17
## 4 3918.0 0.8999023 0 0 0 no dnr 4
## 5 13646.0 0.7999268 0 0 0 no dnr 5
## 6 11765.0 1.1999512 1 0 1 no dnr 4
## dzclass dzgroup edu feat01 feat02 feat03
## 1 ARF/MOSF ARF/MOSF w/Sepsis 7 0.8580334 1.2392883 0.5350037
## 2 COPD/CHF/Cirrhosis CHF 12 0.9597801 1.2774106 0.4462068
## 3 ARF/MOSF ARF/MOSF w/Sepsis 12 0.8832428 0.7388527 0.2608260
## 4 COPD/CHF/Cirrhosis CHF 8 0.5079606 1.1114902 0.5663072
## 5 COPD/CHF/Cirrhosis CHF 12 0.8639816 0.6302922 0.5085629
## 6 Cancer Lung Cancer 12 0.6761328 0.7075878 0.3215828
## feat04 feat05 feat06 feat07 feat08 feat09 feat10 glucose
## 1 0.4523206 0.8137160 0.6074813 1.4759781 1.2498390 0.4757812 1.1306338 NA
## 2 0.5226451 1.1205900 0.4904908 0.7494307 1.1911779 1.1343200 0.7433012 NA
## 3 0.6025331 0.8560829 0.6333920 1.3004582 1.4784995 1.2163398 1.1840829 152
## 4 0.5175277 0.9185327 0.5834451 0.9740270 0.7428017 0.7371532 0.6957657 76
## 5 0.4813833 1.1154546 0.6250318 1.1108652 1.4872206 0.8036246 1.3563305 244
## 6 0.6175327 0.7777038 0.5668106 1.2177967 0.9361281 1.7763802 1.7497995 NA
## hday hospdead hrt income meanbp num.co pafi ph prg2m prg6m
## 1 4 1 136 under $11k 65 1 127.5000 7.489258 0.1999999 0.10
## 2 4 0 92 <NA> 38 1 NA NA NA NA
## 3 3 0 70 >$50k 56 2 232.5000 7.479492 0.8999996 0.70
## 4 1 0 125 $11-$25k 75 3 212.5000 7.429688 0.7500000 0.65
## 5 1 0 125 $25-$50k 65 4 262.8125 7.519531 0.7500000 0.60
## 6 1 0 104 >$50k 96 2 NA NA 1.0000000 0.50
## race resp scoma sex sfdm2 sod sps surv2m
## 1 black 40 0 male <2 mo. follow-up 135 33.09375 0.6508789
## 2 hispanic 20 0 female <NA> 133 37.39844 0.5469971
## 3 white 6 0 male <2 mo. follow-up 143 29.59766 0.5729980
## 4 white 44 0 female no(M2 and SIP pres) 139 13.89844 0.8349609
## 5 black 40 0 female <NA> 152 15.00000 0.9149170
## 6 white 20 0 male no(M2 and SIP pres) 135 19.89844 0.6608887
## surv6m temp totcst totmcst urine wblc
## 1 0.5518799 39.69531 95878.500 NA NA 21.59766
## 2 0.3489990 36.00000 20959.203 NA NA NA
## 3 0.4619751 37.79688 27934.750 27841.672 3200 10.50000
## 4 0.7309570 34.79688 NA NA 395 21.59766
## 5 0.8559570 38.89844 7869.855 6478.398 2795 15.00000
## 6 0.3589478 36.19531 5337.590 NA NA 11.69922
str(data)
## 'data.frame': 10105 obs. of 56 variables:
## $ id : int 1 2 3 4 5 6 7 8 9 10 ...
## $ adlp : int NA NA NA NA 0 0 NA NA NA NA ...
## $ adls : int 0 NA NA NA NA 0 4 4 NA 1 ...
## $ adlsc : num 0 3.098 2.954 2.355 0.495 ...
## $ age : num 53.8 56.5 71.6 85 60.3 ...
## $ alb : num 1.1 NA 2 NA NA ...
## $ aps : int 47 53 37 10 8 15 41 43 34 55 ...
## $ avtisst : num 43.8 16 24 4 19 ...
## $ bili : num NA NA 0.2 22.4 NA ...
## $ bun : num NA NA 21 9 30 NA NA NA 18 8 ...
## $ ca : chr "no" "no" "yes" "yes" ...
## $ charges : num 144018 38548 49989 3918 13646 ...
## $ crea : num 0.5 6 1.2 0.9 0.8 ...
## $ death : int 1 1 1 0 0 1 0 1 1 1 ...
## $ dementia: int 0 0 0 0 0 0 0 1 0 0 ...
## $ diabetes: int 0 0 0 0 0 1 0 0 0 1 ...
## $ dnr : chr "dnr after sadm" "no dnr" "no dnr" "no dnr" ...
## $ dnrday : int 28 5 17 4 5 4 9 2 1 12 ...
## $ dzclass : chr "ARF/MOSF" "COPD/CHF/Cirrhosis" "ARF/MOSF" "COPD/CHF/Cirrhosis" ...
## $ dzgroup : chr "ARF/MOSF w/Sepsis" "CHF" "ARF/MOSF w/Sepsis" "CHF" ...
## $ edu : int 7 12 12 8 12 12 16 14 6 7 ...
## $ feat01 : num 0.858 0.96 0.883 0.508 0.864 ...
## $ feat02 : num 1.239 1.277 0.739 1.111 0.63 ...
## $ feat03 : num 0.535 0.446 0.261 0.566 0.509 ...
## $ feat04 : num 0.452 0.523 0.603 0.518 0.481 ...
## $ feat05 : num 0.814 1.121 0.856 0.919 1.115 ...
## $ feat06 : num 0.607 0.49 0.633 0.583 0.625 ...
## $ feat07 : num 1.476 0.749 1.3 0.974 1.111 ...
## $ feat08 : num 1.25 1.191 1.478 0.743 1.487 ...
## $ feat09 : num 0.476 1.134 1.216 0.737 0.804 ...
## $ feat10 : num 1.131 0.743 1.184 0.696 1.356 ...
## $ glucose : num NA NA 152 76 244 NA NA NA 189 117 ...
## $ hday : int 4 4 3 1 1 1 1 1 1 1 ...
## $ hospdead: int 1 0 0 0 0 0 0 1 0 0 ...
## $ hrt : num 136 92 70 125 125 104 110 54 71 106 ...
## $ income : chr "under $11k" NA ">$50k" "$11-$25k" ...
## $ meanbp : num 65 38 56 75 65 96 111 77 114 113 ...
## $ num.co : int 1 1 2 3 4 2 4 5 3 1 ...
## $ pafi : num 128 NA 232 212 263 ...
## $ ph : num 7.49 NA 7.48 7.43 7.52 ...
## $ prg2m : num 0.2 NA 0.9 0.75 0.75 ...
## $ prg6m : num 0.1 NA 0.7 0.65 0.6 0.5 0.5 0.09 0.7 0.25 ...
## $ race : chr "black" "hispanic" "white" "white" ...
## $ resp : int 40 20 6 44 40 20 21 24 24 44 ...
## $ scoma : int 0 0 0 0 0 0 0 9 26 0 ...
## $ sex : chr "male" "female" "male" "female" ...
## $ sfdm2 : chr "<2 mo. follow-up" NA "<2 mo. follow-up" "no(M2 and SIP pres)" ...
## $ sod : int 135 133 143 139 152 135 155 142 134 137 ...
## $ sps : num 33.1 37.4 29.6 13.9 15 ...
## $ surv2m : num 0.651 0.547 0.573 0.835 0.915 ...
## $ surv6m : num 0.552 0.349 0.462 0.731 0.856 ...
## $ temp : num 39.7 36 37.8 34.8 38.9 ...
## $ totcst : num 95878 20959 27935 NA 7870 ...
## $ totmcst : num NA NA 27842 NA 6478 ...
## $ urine : num NA NA 3200 395 2795 ...
## $ wblc : num 21.6 NA 10.5 21.6 15 ...
# Cleaning column names: replacing spaces with underscores for compatibility
colnames(data) <- make.names(colnames(data), unique = TRUE)
summary(data)
## id adlp adls adlsc
## Min. : 1 Min. :0.000 Min. :0.000 Min. :0.000
## 1st Qu.: 2527 1st Qu.:0.000 1st Qu.:0.000 1st Qu.:0.000
## Median : 5053 Median :0.000 Median :1.000 Median :1.000
## Mean : 5053 Mean :1.154 Mean :1.642 Mean :1.889
## 3rd Qu.: 7579 3rd Qu.:2.000 3rd Qu.:3.000 3rd Qu.:3.000
## Max. :10105 Max. :7.000 Max. :7.000 Max. :7.073
## NA's :6254 NA's :3184
## age alb aps avtisst
## Min. : 18.04 Min. : 0.400 Min. : 0.00 Min. : 1.00
## 1st Qu.: 52.60 1st Qu.: 2.400 1st Qu.: 23.00 1st Qu.:12.00
## Median : 64.83 Median : 2.900 Median : 34.00 Median :19.50
## Mean : 62.58 Mean : 2.955 Mean : 37.54 Mean :22.57
## 3rd Qu.: 73.91 3rd Qu.: 3.600 3rd Qu.: 49.00 3rd Qu.:31.50
## Max. :101.85 Max. :29.000 Max. :143.00 Max. :83.00
## NA's :3744 NA's :1 NA's :91
## bili bun ca charges
## Min. : 0.1000 Min. : 1.00 Length:10105 Min. : 1169
## 1st Qu.: 0.5000 1st Qu.: 14.00 Class :character 1st Qu.: 9726
## Median : 0.8999 Median : 23.00 Mode :character Median : 24833
## Mean : 2.5740 Mean : 32.46 Mean : 59883
## 3rd Qu.: 1.8999 3rd Qu.: 42.00 3rd Qu.: 64004
## Max. :63.0000 Max. :300.00 Max. :1435423
## NA's :2871 NA's :4849 NA's :198
## crea death dementia diabetes
## Min. : 0.09999 Min. :0.0000 Min. :0.00000 Min. :0.0000
## 1st Qu.: 0.89990 1st Qu.:0.0000 1st Qu.:0.00000 1st Qu.:0.0000
## Median : 1.19995 Median :1.0000 Median :0.00000 Median :0.0000
## Mean : 1.77775 Mean :0.6801 Mean :0.03206 Mean :0.1962
## 3rd Qu.: 1.89990 3rd Qu.:1.0000 3rd Qu.:0.00000 3rd Qu.:0.0000
## Max. :21.50000 Max. :1.0000 Max. :1.00000 Max. :1.0000
## NA's :75
## dnr dnrday dzclass dzgroup
## Length:10105 Min. :-88.00 Length:10105 Length:10105
## Class :character 1st Qu.: 4.00 Class :character Class :character
## Mode :character Median : 9.00 Mode :character Mode :character
## Mean : 14.55
## 3rd Qu.: 17.00
## Max. :285.00
## NA's :32
## edu feat01 feat02 feat03
## Min. : 0.00 Min. :0.1161 Min. :0.03268 Min. :0.0000
## 1st Qu.:10.00 1st Qu.:0.7686 1st Qu.:0.77284 1st Qu.:0.4061
## Median :12.00 Median :1.0143 Median :1.01731 Median :0.4953
## Mean :11.75 Mean :1.0132 Mean :1.02314 Mean :0.4942
## 3rd Qu.:14.00 3rd Qu.:1.2599 3rd Qu.:1.27371 3rd Qu.:0.5817
## Max. :31.00 Max. :1.8917 Max. :1.81834 Max. :1.0000
## NA's :1809
## feat04 feat05 feat06 feat07
## Min. :0.0000 Min. :0.1344 Min. :0.0000 Min. :0.2254
## 1st Qu.:0.5014 1st Qu.:0.7914 1st Qu.:0.4228 1st Qu.:0.7745
## Median :0.5704 Median :1.0500 Median :0.5134 Median :1.0215
## Mean :0.5662 Mean :1.0434 Mean :0.5123 Mean :1.0254
## 3rd Qu.:0.6366 3rd Qu.:1.2955 3rd Qu.:0.6009 3rd Qu.:1.2779
## Max. :1.0000 Max. :1.8816 Max. :1.0000 Max. :1.9003
##
## feat08 feat09 feat10 glucose
## Min. :0.1443 Min. :0.09536 Min. :0.1295 Min. : 0.0
## 1st Qu.:0.7857 1st Qu.:0.76694 1st Qu.:0.7587 1st Qu.: 103.0
## Median :1.0375 Median :1.01432 Median :1.0084 Median : 135.0
## Mean :1.0363 Mean :1.01416 Mean :1.0112 Mean : 160.6
## 3rd Qu.:1.2841 3rd Qu.:1.26386 3rd Qu.:1.2616 3rd Qu.: 189.0
## Max. :1.9317 Max. :1.91921 Max. :1.8905 Max. :1092.0
## NA's :5016
## hday hospdead hrt income
## Min. : 1.000 Min. :0.0000 Min. : 0.00 Length:10105
## 1st Qu.: 1.000 1st Qu.:0.0000 1st Qu.: 72.00 Class :character
## Median : 1.000 Median :0.0000 Median :100.00 Mode :character
## Mean : 4.381 Mean :0.2602 Mean : 96.98
## 3rd Qu.: 3.000 3rd Qu.:1.0000 3rd Qu.:120.00
## Max. :148.000 Max. :1.0000 Max. :300.00
## NA's :1
## meanbp num.co pafi ph
## Min. : 0.00 Min. :0.000 Min. : 12.0 Min. :6.829
## 1st Qu.: 63.00 1st Qu.:1.000 1st Qu.:155.1 1st Qu.:7.380
## Median : 77.00 Median :2.000 Median :224.5 Median :7.420
## Mean : 84.63 Mean :1.869 Mean :239.6 Mean :7.416
## 3rd Qu.:107.00 3rd Qu.:3.000 3rd Qu.:305.0 3rd Qu.:7.470
## Max. :195.00 Max. :9.000 Max. :890.4 Max. :7.769
## NA's :1 NA's :2618 NA's :2574
## prg2m prg6m race resp
## Min. :0.000 Min. :0.0000 Length:10105 Min. : 0.00
## 1st Qu.:0.500 1st Qu.:0.2000 Class :character 1st Qu.:18.00
## Median :0.700 Median :0.5000 Mode :character Median :24.00
## Mean :0.619 Mean :0.4987 Mean :23.35
## 3rd Qu.:0.900 3rd Qu.:0.7500 3rd Qu.:28.00
## Max. :1.000 Max. :1.0000 Max. :90.00
## NA's :1837 NA's :1820 NA's :1
## scoma sex sfdm2 sod
## Min. : 0.00 Length:10105 Length:10105 Min. :110.0
## 1st Qu.: 0.00 Class :character Class :character 1st Qu.:134.0
## Median : 0.00 Mode :character Mode :character Median :137.0
## Mean : 12.01 Mean :137.6
## 3rd Qu.: 9.00 3rd Qu.:141.0
## Max. :100.00 Max. :181.0
## NA's :1 NA's :1
## sps surv2m surv6m temp
## Min. : 0.20 Min. :0.0000 Min. :0.0000 Min. :31.70
## 1st Qu.:19.00 1st Qu.:0.5079 1st Qu.:0.3457 1st Qu.:36.20
## Median :23.90 Median :0.7159 Median :0.5740 Median :36.70
## Mean :25.53 Mean :0.6364 Mean :0.5203 Mean :37.10
## 3rd Qu.:30.20 3rd Qu.:0.8259 3rd Qu.:0.7250 3rd Qu.:38.09
## Max. :99.19 Max. :0.9700 Max. :0.9480 Max. :41.70
## NA's :1 NA's :1 NA's :1 NA's :1
## totcst totmcst urine wblc
## Min. : 0 Min. : -102.7 Min. : 0 Min. : 0.000
## 1st Qu.: 5944 1st Qu.: 5177.1 1st Qu.:1170 1st Qu.: 6.899
## Median : 14354 Median : 13071.3 Median :1968 Median : 10.600
## Mean : 30916 Mean : 28823.9 Mean :2197 Mean : 12.295
## 3rd Qu.: 36106 3rd Qu.: 34295.0 3rd Qu.:3000 3rd Qu.: 15.299
## Max. :633212 Max. :710682.0 Max. :9000 Max. :200.000
## NA's :994 NA's :3880 NA's :5411 NA's :243
# Checking for missing values
missing_values <- sapply(data, function(x) sum(is.na(x)))
missing_percentage <- missing_values / nrow(data) * 100
missing_data <- data.frame(Feature = names(missing_values), Missing_Percentage = missing_percentage)
missing_data <- missing_data[order(-missing_data$Missing_Percentage), ]
print(missing_data)
## Feature Missing_Percentage
## adlp adlp 61.890153389
## urine urine 53.547748639
## glucose glucose 49.638792677
## bun bun 47.986145473
## totmcst totmcst 38.396833251
## alb alb 37.050964869
## income income 32.568035626
## adls adls 31.509153884
## bili bili 28.411677387
## pafi pafi 25.907966353
## ph ph 25.472538347
## prg2m prg2m 18.179119248
## prg6m prg6m 18.010885700
## edu edu 17.902028699
## sfdm2 sfdm2 15.477486393
## totcst totcst 9.836714498
## wblc wblc 2.404750124
## charges charges 1.959426027
## avtisst avtisst 0.900544285
## crea crea 0.742206828
## race race 0.484908461
## dnr dnr 0.316674913
## dnrday dnrday 0.316674913
## aps aps 0.009896091
## hrt hrt 0.009896091
## meanbp meanbp 0.009896091
## resp resp 0.009896091
## scoma scoma 0.009896091
## sod sod 0.009896091
## sps sps 0.009896091
## surv2m surv2m 0.009896091
## surv6m surv6m 0.009896091
## temp temp 0.009896091
## id id 0.000000000
## adlsc adlsc 0.000000000
## age age 0.000000000
## ca ca 0.000000000
## death death 0.000000000
## dementia dementia 0.000000000
## diabetes diabetes 0.000000000
## dzclass dzclass 0.000000000
## dzgroup dzgroup 0.000000000
## feat01 feat01 0.000000000
## feat02 feat02 0.000000000
## feat03 feat03 0.000000000
## feat04 feat04 0.000000000
## feat05 feat05 0.000000000
## feat06 feat06 0.000000000
## feat07 feat07 0.000000000
## feat08 feat08 0.000000000
## feat09 feat09 0.000000000
## feat10 feat10 0.000000000
## hday hday 0.000000000
## hospdead hospdead 0.000000000
## num.co num.co 0.000000000
## sex sex 0.000000000
# plotting for missing data percentages
ggplot(missing_data, aes(x = reorder(Feature, -Missing_Percentage), y = Missing_Percentage)) +
geom_point(size = 3, color = "steelblue") +
coord_flip() +
labs(
title = "Missing Data Percentage by Variable",
x = "Feature",
y = "Missing Percentage"
) +
theme_minimal()
This process helps to identify which features in the dataset have missing values and how significant the issue is for each feature. Notably, the adlp representing the Physical Activities of Daily Living score, has the highest missing rate at approximately 61.89%. This is followed by urine (53.55%) and glucose(49.64%), which also have a significant portion of missing values. Other key features such as bun (47.99%), totmcst(38.40%), and alb (37.05%) also exhibit notable missing rates, highlighting potential challenges in leveraging these variables for predictive modeling.
Handling Missing Values:
#Removing columns with more than 50% missing values
missing_threshold <- 0.5
missing_percentage <- sapply(data, function(x) sum(is.na(x)) / length(x))
data <- data[, missing_percentage <= missing_threshold]
# Separating numeric and categorical columns
numeric_data <- data[, sapply(data, is.numeric)]
categorical_data <- data[, !sapply(data, is.numeric)]
# Imputing missing values for numeric data using median
library(caret)
preProcess_missing <- preProcess(numeric_data, method = c("medianImpute"))
numeric_data <- predict(preProcess_missing, numeric_data)
# Imputing missing values for categorical data with "Unknown"
library(dplyr)
categorical_data <- categorical_data %>%
mutate(across(everything(), ~ ifelse(is.na(.), "Unknown", .)))
# Combining cleaned numeric and categorical data
data <- cbind(numeric_data, categorical_data)
# Verifying
print(paste("Number of missing values:", sum(is.na(data))))
## [1] "Number of missing values: 0"
We addressed missing values in the dataset systematically. First, columns with more than 50% missing values were removed, as they deemed too incomplete to provide meaningful information. Then, we separated the remaining data into numeric and categorical columns. For numeric columns, missing values were imputed using the median to maintain the distribution’s central tendency. For categorical columns, missing values were replaced with “Unknown” to preserve the integrity of the categorical data. Finally, the cleaned numeric and categorical data were recombined into a single dataset, ensuring all missing values were handled appropriately. A verification step confirmed that there were no remaining missing values, ensuring the dataset is ready for further analysis or modeling.
# Numeric Variables: Ensuring they are numeric and have variance > 0
relevant_numeric_vars <- names(data)[sapply(data, is.numeric) & sapply(data, function(x) var(x, na.rm = TRUE) > 0)]
## Warning in stats::var(...): NAs introduced by coercion
## Warning in stats::var(...): NAs introduced by coercion
## Warning in stats::var(...): NAs introduced by coercion
## Warning in stats::var(...): NAs introduced by coercion
## Warning in stats::var(...): NAs introduced by coercion
## Warning in stats::var(...): NAs introduced by coercion
## Warning in stats::var(...): NAs introduced by coercion
## Warning in stats::var(...): NAs introduced by coercion
# Categorical Variables: Ensuring they are factors or have manageable levels (<= 10 unique values)
relevant_factor_vars <- names(data)[sapply(data, is.factor) | sapply(data, function(x) length(unique(x)) <= 10)]
# Checking for and Handling NA Warnings During Coercion
for (var in relevant_numeric_vars) {
if (any(is.na(data[[var]]))) {
print(paste("Warning: NAs detected in numeric variable", var))
}
}
for (var in relevant_factor_vars) {
if (any(is.na(data[[var]]))) {
print(paste("Warning: NAs detected in categorical variable", var))
}
}
# Histograms for Numeric Variables
if (length(relevant_numeric_vars) > 0) {
numeric_plots <- lapply(relevant_numeric_vars, function(var) {
ggplot(data, aes_string(x = var)) +
geom_histogram(fill = "skyblue", color = "white", bins = 30) +
labs(title = paste("Histogram of", var), x = var, y = "Count") +
theme_minimal()
})
do.call(grid.arrange, numeric_plots[1:min(6, length(numeric_plots))])
} else {
print("No numeric variables with sufficient variance found.")
}
## Warning: `aes_string()` was deprecated in ggplot2 3.0.0.
## ℹ Please use tidy evaluation idioms with `aes()`.
## ℹ See also `vignette("ggplot2-in-packages")` for more information.
## This warning is displayed once every 8 hours.
## Call `lifecycle::last_lifecycle_warnings()` to see where this warning was
## generated.
# Bar Plots for Categorical Variables
if (length(relevant_factor_vars) > 0) {
categorical_plots <- lapply(relevant_factor_vars, function(var) {
ggplot(data, aes_string(x = var)) +
geom_bar(fill = "coral", color = "white") +
labs(title = paste("Bar Plot of", var), x = var, y = "Count") +
theme_minimal()
})
do.call(grid.arrange, categorical_plots[1:min(6, length(categorical_plots))])
} else {
print("No categorical variables with manageable levels found.")
}
The visual exploratory data analysis (EDA) provided important insights into the dataset through histograms for numeric variables and bar plots for categorical variables. The histograms revealed distinct patterns in the numeric data: patient age showed a distribution skewed toward older adults, with a noticeable concentration around ages 70–80, indicating that the dataset is predominantly composed of elderly patients. Acute Physiology Scores (APS) displayed a wide range of values with a concentration at the lower end, reflecting varying levels of illness severity. Albumin levels (Alb) were clustered at the lower end, suggesting potential nutritional deficiencies among many patients, while Activities of Daily Living (ADL) scores (adlp and adlsc) were heavily skewed toward lower values, emphasizing a reduced capacity for daily living activities in this population.
The bar plots for categorical variables provided complementary insights: a significant proportion of patients did not survive their hospital stay, as seen in the death and hospital death (hospdead) variables, underscoring the critical condition of these patients. Conditions such as dementia and diabetes were less prevalent but still present in notable proportions, potentially influencing survival outcomes. The distribution of comorbid conditions (num.co) showed that many patients had one or more coexisting medical issues, highlighting the complexity of their health profiles. These visualizations effectively summarize the demographic and clinical characteristics of the patient population, offering a strong foundation for predictive modeling and further analysis.
numeric_vars <- names(data)[sapply(data, is.numeric)]
# Correlation Matrix for Numeric Variables
library(corrplot)
## corrplot 0.92 loaded
cor_matrix <- cor(data[numeric_vars], use = "pairwise.complete.obs")
corrplot::corrplot(cor_matrix, method = "circle", tl.cex = 0.6, main = "Correlation Matrix")
The bivariate analysis focused on exploring relationships between numeric variables using a correlation matrix. Dark blue circles represent strong positive correlations, while dark red circles indicate strong negative correlations. For instance, features like total medical costs (`totmcst`) and total costs (`totcst`) exhibited a strong positive correlation, as expected, given their inherent relationship. Similarly, physiological metrics such as Acute Physiology Score (`aps`) and Sickness Probability Score (`sps`) showed moderate to strong positive correlations, reflecting their shared relevance to patient severity. Conversely, weak or negligible correlations were observed between variables like age and albumin levels, indicating limited linear relationships.
This visualization not only highlights key variable relationships but also helps identify potential multicollinearity among predictors, guiding feature selection for predictive modeling. By understanding these relationships, we can refine the modeling approach, focus on influential predictors, and reduce redundancy in the dataset.
ggplot(data, aes(x = diabetes, fill = factor(death))) +
geom_bar(position = "fill") +
labs(title = "Death Rate by Diabetes Status", y = "Proportion")
The bar plot illustrates the relationship between diabetes status (`diabetes`) and death outcomes (`death`). The bars are segmented by survival (`0`) and death (`1`) proportions, normalized to show percentages within each diabetes category. The plot reveals that the proportion of deaths is relatively similar between patients with and without diabetes. This suggests that diabetes may not have a significant standalone impact on mortality in this dataset.
prop.table(table(data$edu, data$death), margin = 2)
##
## 0 1
## 0 0.0049489638 0.0052386496
## 1 0.0012372410 0.0007275902
## 2 0.0030931024 0.0020372526
## 3 0.0064955150 0.0075669383
## 4 0.0058768945 0.0072759022
## 5 0.0126817198 0.0071303842
## 6 0.0173213733 0.0174621653
## 7 0.0185586143 0.0225552969
## 8 0.0770182493 0.0794528522
## 9 0.0442313641 0.0359429569
## 10 0.0501082586 0.0568975553
## 11 0.0590782555 0.0491850990
## 12 0.4327250232 0.4665308498
## 13 0.0457779152 0.0320139697
## 14 0.0791834210 0.0654831199
## 15 0.0253634395 0.0203725262
## 16 0.0671203217 0.0698486612
## 17 0.0120630993 0.0130966240
## 18 0.0182493041 0.0200814901
## 19 0.0064955150 0.0048020955
## 20 0.0064955150 0.0097497090
## 21 0.0021651717 0.0020372526
## 22 0.0012372410 0.0014551804
## 23 0.0003093102 0.0004365541
## 24 0.0012372410 0.0013096624
## 25 0.0000000000 0.0002910361
## 26 0.0003093102 0.0001455180
## 27 0.0000000000 0.0004365541
## 28 0.0003093102 0.0000000000
## 30 0.0000000000 0.0004365541
## 31 0.0003093102 0.0000000000
The table also shows the proportion of education levels (`edu`) across deathl (`0`) and survival (`1`) outcomes. Individuals with 12 years of education represent the largest group, comprising 43.27% of those who died and 46.65% of those who survived, indicating a concentration of high school-educated individuals in the dataset. Higher education levels (e.g., 14 and 16 years) also have notable proportions, while lower education levels (e.g., 0–6 years) are less common in both outcomes. The distribution suggests no clear linear relationship between education level and survival, requiring further analysis to assess its impact on outcomes.
data$surv2m_class <- as.factor(ifelse(data$surv2m >= 0.5, "Class_1", "Class_0"))
data$surv6m_class <- as.factor(ifelse(data$surv6m >= 0.5, "Class_1", "Class_0"))
The target variables `surv2m` (2-month survival probability) and `surv6m` (6-month survival probability) were transformed into binary classification variables to prepare them for predictive modeling. The binarization process involved applying a threshold to categorize each patient into one of two classes:
- Class_1: Represents a positive survival outcome, assigned to patients with a survival probability of 0.5 or higher, indicating at least a 50% chance of survival.
- Class_0: Represents a negative survival outcome, assigned to patients with a survival probability below 0.5, indicating less than a 50% chance of survival.
The resulting binary variables, `surv2m_class` and `surv6m_class`, were then converted into factors to explicitly define them as categorical variables, making them suitable for classification tasks. This step ensures that the target variables are aligned with the requirements of machine learning models designed to predict survival outcomes.
# Examining the distribution of the target variables
table(data$surv2m_class) / nrow(data)
##
## Class_0 Class_1
## 0.2438397 0.7561603
table(data$surv6m_class) / nrow(data)
##
## Class_0 Class_1
## 0.4035626 0.5964374
## Target Variable Analysis
par(mfrow = c(1, 2))
barplot(table(data$surv2m_class), main = "Distribution of surv2m_class", col = "lightgreen")
barplot(table(data$surv6m_class), main = "Distribution of surv6m_class", col = "lightblue")
- surv2m_class: The bar plot shows a significant class imbalance, with the majority of patients classified as `Class_1` (survival). This indicates that most patients had a survival probability of at least 50% at the 2-month mark.
- surv6m_class: The distribution is more balanced compared to surv2m_class, with a higher proportion of patients in `Class_0` (non-survival). This reflects a reduction in survival probabilities over a longer time horizon.
These distributions are crucial for determining modeling strategies, as imbalanced classes, like those in surv2m_class, may require techniques such as oversampling, undersampling, or class weighting to ensure fair performance across both classes.
# Visualizing distribution of key numeric variables by target variable
library(ggplot2)
numeric_vars <- c("age", "aps", "alb")
for (var in numeric_vars) {
print(
ggplot(data, aes_string(x = "surv2m_class", y = var)) +
geom_boxplot(fill = "lightblue") +
labs(title = paste("Relationship between", var, "and surv2m_class"), x = "Survival Class", y = var) +
theme_minimal()
)
}
Overall, age and APS show distinct relationships with survival outcomes, while albumin levels exhibit less variation between the two classes.
# Visualizing distribution of key numeric variables by target variable
library(ggplot2)
numeric_vars <- c("age", "aps", "alb")
for (var in numeric_vars) {
print(
ggplot(data, aes_string(x = "surv6m_class", y = var)) +
geom_boxplot(fill = "lightblue") +
labs(title = paste("Relationship between", var, "and surv6m_class"), x = "Survival Class", y = var) +
theme_minimal()
)
}
age and APS demonstrate a strong and consistent relationship with survival outcomes, reinforcing their importance as predictive features for both short- and long-term survival. Albumin, on the other hand, appears to have limited influence in distinguishing between survival classes.
Predictors are selected based on domain knowledge or relevance to the target variable. These predictors are used to build formulas for the decision tree models (`formula_2m` and `formula_6m`), allowing for clear and reproducible model definitions.
predictors <- c(
"adls", "adlsc", "age", "alb", "aps", "avtisst", "bili", "bun", "ca",
"charges", "crea", "death", "dementia", "diabetes", "dnr", "dnrday",
"dzclass", "edu", "glucose", "hospdead", "hrt", "income", "meanbp",
"num.co", "pafi", "ph", "prg2m", "prg6m", "race", "resp", "scoma",
"sex", "sps", "sod", "temp", "totcst", "totmcst", "urine", "wblc",
"feat01", "feat02", "feat03", "feat04", "feat05", "feat06", "feat07",
"feat08", "feat09", "feat10", "hday", "dzgroup", "sfdm2"
)
predictors <- predictors[predictors %in% names(data)]
# Define formulas
formula_2m <- as.formula(paste("surv2m_class ~", paste(predictors, collapse = " + ")))
formula_6m <- as.formula(paste("surv6m_class ~", paste(predictors, collapse = " + ")))
Imbalanced datasets can lead to biased models favoring the majority class. The ROSE (Random Over-Sampling Examples) package generates synthetic data to balance class distributions, improving the model’s ability to predict minority classes effectively.
data <- data %>% mutate(across(where(is.character), as.factor))
data_balanced_2m <- ROSE(formula_2m, data = data, seed = 123)$data
data_balanced_6m <- ROSE(formula_6m, data = data, seed = 123)$data
table(data_balanced_2m$surv2m_class)
##
## Class_1 Class_0
## 5116 4989
table(data_balanced_6m$surv6m_class)
##
## Class_1 Class_0
## 5116 4989
data_balanced_2m and data_balanced_6m were split into training(70%) and test(30%) sets.
set.seed(123)
# Splitting Balanced Data for surv2m_class
trainIndex_2m_balanced <- createDataPartition(data_balanced_2m$surv2m_class, p = 0.7, list = FALSE)
train_data_2m_balanced <- data_balanced_2m[trainIndex_2m_balanced, ]
test_data_2m_balanced <- data_balanced_2m[-trainIndex_2m_balanced, ]
# Splitting Balanced Data for surv6m_class
trainIndex_6m_balanced <- createDataPartition(data_balanced_6m$surv6m_class, p = 0.7, list = FALSE)
train_data_6m_balanced <- data_balanced_6m[trainIndex_6m_balanced, ]
test_data_6m_balanced <- data_balanced_6m[-trainIndex_6m_balanced, ]
Classification trees are trained on the balanced datasets for predicting the survival outcomes at 2 months surv2m_class and 6 months surv6m_class.
# Balanced Trees
tree_model_2m_balanced <- rpart(formula_2m, data = train_data_2m_balanced, method = "class", control = rpart.control(minsplit = 10, cp = 0.005))
tree_model_6m_balanced <- rpart(formula_6m, data = train_data_6m_balanced, method = "class", control = rpart.control(minsplit = 10, cp = 0.005))
print(tree_model_2m_balanced)
## n= 7075
##
## node), split, n, loss, yval, (yprob)
## * denotes terminal node
##
## 1) root 7075 3493 Class_1 (0.50628975 0.49371025)
## 2) sfdm2=adl>=4 (>=5 if sur),no(M2 and SIP pres),SIP>=30,Unknown 3784 977 Class_1 (0.74180761 0.25819239)
## 4) sps< 33.44478 3108 492 Class_1 (0.84169884 0.15830116)
## 8) scoma< 25.61492 2666 244 Class_1 (0.90847712 0.09152288)
## 16) dzgroup=ARF/MOSF w/Sepsis,CHF,Cirrhosis,Colon Cancer,COPD,Lung Cancer 2483 146 Class_1 (0.94120016 0.05879984) *
## 17) dzgroup=Coma,MOSF w/Malig 183 85 Class_0 (0.46448087 0.53551913)
## 34) aps< 30.5455 65 18 Class_1 (0.72307692 0.27692308) *
## 35) aps>=30.5455 118 38 Class_0 (0.32203390 0.67796610) *
## 9) scoma>=25.61492 442 194 Class_0 (0.43891403 0.56108597)
## 18) scoma< 59.6887 313 132 Class_1 (0.57827476 0.42172524)
## 36) ca=no 232 71 Class_1 (0.69396552 0.30603448) *
## 37) ca=metastatic,yes 81 20 Class_0 (0.24691358 0.75308642) *
## 19) scoma>=59.6887 129 13 Class_0 (0.10077519 0.89922481) *
## 5) sps>=33.44478 676 191 Class_0 (0.28254438 0.71745562)
## 10) sps< 39.94561 365 167 Class_0 (0.45753425 0.54246575)
## 20) dzgroup=ARF/MOSF w/Sepsis,CHF,Cirrhosis,COPD 271 114 Class_1 (0.57933579 0.42066421)
## 40) scoma< 29.16174 207 58 Class_1 (0.71980676 0.28019324) *
## 41) scoma>=29.16174 64 8 Class_0 (0.12500000 0.87500000) *
## 21) dzgroup=Colon Cancer,Coma,Lung Cancer,MOSF w/Malig 94 10 Class_0 (0.10638298 0.89361702) *
## 11) sps>=39.94561 311 24 Class_0 (0.07717042 0.92282958) *
## 3) sfdm2=<2 mo. follow-up,Coma or Intub 3291 775 Class_0 (0.23549073 0.76450927)
## 6) scoma< 21.18523 1672 661 Class_0 (0.39533493 0.60466507)
## 12) sps< 35.28011 1030 426 Class_1 (0.58640777 0.41359223)
## 24) scoma>=-19.10812 912 320 Class_1 (0.64912281 0.35087719)
## 48) aps< 64.72386 777 223 Class_1 (0.71299871 0.28700129) *
## 49) aps>=64.72386 135 38 Class_0 (0.28148148 0.71851852) *
## 25) scoma< -19.10812 118 12 Class_0 (0.10169492 0.89830508) *
## 13) sps>=35.28011 642 57 Class_0 (0.08878505 0.91121495) *
## 7) scoma>=21.18523 1619 114 Class_0 (0.07041384 0.92958616) *
print(tree_model_6m_balanced)
## n= 7075
##
## node), split, n, loss, yval, (yprob)
## * denotes terminal node
##
## 1) root 7075 3493 Class_1 (0.50628975 0.49371025)
## 2) sfdm2=adl>=4 (>=5 if sur),no(M2 and SIP pres),SIP>=30,Unknown 4304 1329 Class_1 (0.69121747 0.30878253)
## 4) dzgroup=ARF/MOSF w/Sepsis,CHF,Cirrhosis,Colon Cancer,COPD 3425 726 Class_1 (0.78802920 0.21197080)
## 8) sps< 33.15188 2919 379 Class_1 (0.87016101 0.12983899)
## 16) scoma< 21.00122 2609 228 Class_1 (0.91261020 0.08738980)
## 32) scoma>=-22.89907 2537 170 Class_1 (0.93299172 0.06700828) *
## 33) scoma< -22.89907 72 14 Class_0 (0.19444444 0.80555556) *
## 17) scoma>=21.00122 310 151 Class_1 (0.51290323 0.48709677)
## 34) age< 64.03363 155 43 Class_1 (0.72258065 0.27741935) *
## 35) age>=64.03363 155 47 Class_0 (0.30322581 0.69677419) *
## 9) sps>=33.15188 506 159 Class_0 (0.31422925 0.68577075)
## 18) sps< 39.72494 279 133 Class_0 (0.47670251 0.52329749)
## 36) age< 39.97142 57 7 Class_1 (0.87719298 0.12280702) *
## 37) age>=39.97142 222 83 Class_0 (0.37387387 0.62612613) *
## 19) sps>=39.72494 227 26 Class_0 (0.11453744 0.88546256) *
## 5) dzgroup=Coma,Lung Cancer,MOSF w/Malig 879 276 Class_0 (0.31399317 0.68600683)
## 10) sps< 22.18994 512 230 Class_0 (0.44921875 0.55078125)
## 20) scoma< 19.30877 373 165 Class_1 (0.55764075 0.44235925)
## 40) scoma>=-16.36966 316 113 Class_1 (0.64240506 0.35759494) *
## 41) scoma< -16.36966 57 5 Class_0 (0.08771930 0.91228070) *
## 21) scoma>=19.30877 139 22 Class_0 (0.15827338 0.84172662) *
## 11) sps>=22.18994 367 46 Class_0 (0.12534060 0.87465940) *
## 3) sfdm2=<2 mo. follow-up,Coma or Intub 2771 607 Class_0 (0.21905449 0.78094551)
## 6) scoma< 18.50689 1450 513 Class_0 (0.35379310 0.64620690)
## 12) dzgroup=ARF/MOSF w/Sepsis,CHF,Cirrhosis,Colon Cancer,COPD 998 479 Class_0 (0.47995992 0.52004008)
## 24) sps< 31.07118 579 188 Class_1 (0.67530225 0.32469775)
## 48) aps< 53.57191 454 100 Class_1 (0.77973568 0.22026432) *
## 49) aps>=53.57191 125 37 Class_0 (0.29600000 0.70400000) *
## 25) sps>=31.07118 419 88 Class_0 (0.21002387 0.78997613) *
## 13) dzgroup=Coma,Lung Cancer,MOSF w/Malig 452 34 Class_0 (0.07522124 0.92477876) *
## 7) scoma>=18.50689 1321 94 Class_0 (0.07115821 0.92884179) *
fancyRpartPlot(tree_model_2m_balanced, main = "Classification Tree for surv2m_class")
fancyRpartPlot(tree_model_6m_balanced, main = "Classification Tree for surv6m_class")
The classification trees for 2-month and 6-month survival identified key predictors of survival outcomes, including sfdm2, sps, scoma, dzgroup, and aps. The root and initial splits suggest that functional disability and sickness probability are the most critical factors influencing survival predictions. Terminal nodes provide probabilities for survival and non-survival outcomes, with many showing high certainty in their predictions. These trees highlight interpretable decision rules, offering valuable insights into the factors driving survival outcomes in critically ill patients.
confusion matrix
evaluate_model <- function(model, test_data, target_col) {
preds <- predict(model, test_data, type = "class")
confusionMatrix(preds, test_data[[target_col]], positive = "Class_1")
}
conf_matrix_2m <- evaluate_model(tree_model_2m_balanced, test_data_2m_balanced, "surv2m_class")
conf_matrix_6m <- evaluate_model(tree_model_6m_balanced, test_data_6m_balanced, "surv6m_class")
print(conf_matrix_2m)
## Confusion Matrix and Statistics
##
## Reference
## Prediction Class_1 Class_0
## Class_1 1355 258
## Class_0 179 1238
##
## Accuracy : 0.8558
## 95% CI : (0.8428, 0.8681)
## No Information Rate : 0.5063
## P-Value [Acc > NIR] : < 2.2e-16
##
## Kappa : 0.7113
##
## Mcnemar's Test P-Value : 0.0001905
##
## Sensitivity : 0.8833
## Specificity : 0.8275
## Pos Pred Value : 0.8400
## Neg Pred Value : 0.8737
## Prevalence : 0.5063
## Detection Rate : 0.4472
## Detection Prevalence : 0.5323
## Balanced Accuracy : 0.8554
##
## 'Positive' Class : Class_1
##
print(conf_matrix_6m)
## Confusion Matrix and Statistics
##
## Reference
## Prediction Class_1 Class_0
## Class_1 1282 215
## Class_0 252 1281
##
## Accuracy : 0.8459
## 95% CI : (0.8325, 0.8586)
## No Information Rate : 0.5063
## P-Value [Acc > NIR] : < 2e-16
##
## Kappa : 0.6918
##
## Mcnemar's Test P-Value : 0.09574
##
## Sensitivity : 0.8357
## Specificity : 0.8563
## Pos Pred Value : 0.8564
## Neg Pred Value : 0.8356
## Prevalence : 0.5063
## Detection Rate : 0.4231
## Detection Prevalence : 0.4941
## Balanced Accuracy : 0.8460
##
## 'Positive' Class : Class_1
##
The evaluation of the classification trees for 2-month and 6-month survival prediction demonstrates strong performance across multiple metrics. The 2-month survival model achieved an accuracy of 85.58%, while the 6-month survival model achieved 84.59%, both significantly higher than the No Information Rate of 50.63% (P-value < 2.2e-16).
Sensitivity, measuring the ability to identify survivors (`Class_1`), was 88.33% for the 2-month model and 83.57% for the 6-month model, indicating slightly better performance in detecting survivors for the shorter timeframe.
Specificity, which measures the ability to identify non-survivors (`Class_0`), was 82.75% for the 2-month model and 85.63% for the 6-month model, with the latter performing better for predicting non-survivors. Positive Predictive Values (PPVs) were strong for both models, at 84.00% and 85.64%, respectively, while Negative Predictive Values (NPVs) showed 87.37% for the 2-month model and 83.56% for the 6-month model, demonstrating better reliability for predicting non-survivors in the 2-month model. The Kappa statistics, 0.7113 and 0.6918 for the 2-month and 6-month models respectively, indicate substantial agreement with the actual classifications.
Balanced accuracy was high for both models, at 85.54% and 84.60%. McNemar’s Test revealed a significant imbalance in errors for the 2-month model (P-value = 0.0001905) but no significant imbalance for the 6-month model (P-value = 0.09574). Overall, both models exhibit robust predictive performance, with the 2-month model showing slightly better detection of survivors and balanced predictions.
A larger tree was grown with minimal restrictions to explore the full complexity of the data and capture all potential patterns and splits. This serves as a benchmark for comparing against simpler, pruned trees, enabling the identification of the optimal tree size. Additionally, the larger tree provides insights into feature importance and decision rules at deeper levels of the hierarchy, which may contribute to a better understanding of the relationships between predictors and survival outcomes
# Growing a large tree with minimal restrictions
large_tree_2m <- rpart(
formula_2m,
data = train_data_2m_balanced,
method = "class",
minsplit = 100,
minbucket = 50,
maxdepth = 30,
cp = -1
)
large_tree_6m <- rpart(
formula_6m,
data = train_data_6m_balanced,
method = "class",
minsplit = 100,
minbucket = 50,
maxdepth = 30,
cp = -1
)
print(large_tree_2m)
## n= 7075
##
## node), split, n, loss, yval, (yprob)
## * denotes terminal node
##
## 1) root 7075 3493 Class_1 (0.5062897527 0.4937102473)
## 2) sfdm2=adl>=4 (>=5 if sur),no(M2 and SIP pres),SIP>=30,Unknown 3784 977 Class_1 (0.7418076110 0.2581923890)
## 4) sps< 33.44478 3108 492 Class_1 (0.8416988417 0.1583011583)
## 8) scoma< 25.61492 2666 244 Class_1 (0.9084771193 0.0915228807)
## 16) dzgroup=ARF/MOSF w/Sepsis,CHF,Cirrhosis,Colon Cancer,COPD,Lung Cancer 2483 146 Class_1 (0.9412001611 0.0587998389)
## 32) scoma>=-22.08214 2432 117 Class_1 (0.9518914474 0.0481085526)
## 64) bili>=-4.403833 2382 100 Class_1 (0.9580184719 0.0419815281)
## 128) bili< 9.376092 2313 80 Class_1 (0.9654128837 0.0345871163)
## 256) hday< 16.81322 2225 63 Class_1 (0.9716853933 0.0283146067)
## 512) sps< 26.89872 1811 28 Class_1 (0.9845389288 0.0154610712)
## 1024) bili< 5.591728 1701 18 Class_1 (0.9894179894 0.0105820106)
## 2048) dzgroup=ARF/MOSF w/Sepsis,CHF,Cirrhosis,Colon Cancer,COPD 1439 7 Class_1 (0.9951355108 0.0048644892)
## 4096) hday>=-9.494043 1389 4 Class_1 (0.9971202304 0.0028797696)
## 8192) avtisst< 35.19195 1311 1 Class_1 (0.9992372235 0.0007627765)
## 16384) dnrday< 45.31682 1261 0 Class_1 (1.0000000000 0.0000000000) *
## 16385) dnrday>=45.31682 50 1 Class_1 (0.9800000000 0.0200000000) *
## 8193) avtisst>=35.19195 78 3 Class_1 (0.9615384615 0.0384615385) *
## 4097) hday< -9.494043 50 3 Class_1 (0.9400000000 0.0600000000) *
## 2049) dzgroup=Lung Cancer 262 11 Class_1 (0.9580152672 0.0419847328)
## 4098) sps< 16.98582 156 1 Class_1 (0.9935897436 0.0064102564)
## 8196) adls>=-0.04503216 106 0 Class_1 (1.0000000000 0.0000000000) *
## 8197) adls< -0.04503216 50 1 Class_1 (0.9800000000 0.0200000000) *
## 4099) sps>=16.98582 106 10 Class_1 (0.9056603774 0.0943396226)
## 8198) dnrday< 9.556652 54 1 Class_1 (0.9814814815 0.0185185185) *
## 8199) dnrday>=9.556652 52 9 Class_1 (0.8269230769 0.1730769231) *
## 1025) bili>=5.591728 110 10 Class_1 (0.9090909091 0.0909090909)
## 2050) aps< 29.9036 60 1 Class_1 (0.9833333333 0.0166666667) *
## 2051) aps>=29.9036 50 9 Class_1 (0.8200000000 0.1800000000) *
## 513) sps>=26.89872 414 35 Class_1 (0.9154589372 0.0845410628)
## 1026) ca=no 335 14 Class_1 (0.9582089552 0.0417910448)
## 2052) totcst< 64577.83 272 5 Class_1 (0.9816176471 0.0183823529)
## 4104) prg2m>=0.5603914 210 1 Class_1 (0.9952380952 0.0047619048)
## 8208) adls< 2.450517 160 0 Class_1 (1.0000000000 0.0000000000) *
## 8209) adls>=2.450517 50 1 Class_1 (0.9800000000 0.0200000000) *
## 4105) prg2m< 0.5603914 62 4 Class_1 (0.9354838710 0.0645161290) *
## 2053) totcst>=64577.83 63 9 Class_1 (0.8571428571 0.1428571429) *
## 1027) ca=metastatic,yes 79 21 Class_1 (0.7341772152 0.2658227848) *
## 257) hday>=16.81322 88 17 Class_1 (0.8068181818 0.1931818182) *
## 129) bili>=9.376092 69 20 Class_1 (0.7101449275 0.2898550725) *
## 65) bili< -4.403833 50 17 Class_1 (0.6600000000 0.3400000000) *
## 33) scoma< -22.08214 51 22 Class_0 (0.4313725490 0.5686274510) *
## 17) dzgroup=Coma,MOSF w/Malig 183 85 Class_0 (0.4644808743 0.5355191257)
## 34) aps< 30.5455 65 18 Class_1 (0.7230769231 0.2769230769) *
## 35) aps>=30.5455 118 38 Class_0 (0.3220338983 0.6779661017)
## 70) hday< 7.639948 68 30 Class_0 (0.4411764706 0.5588235294) *
## 71) hday>=7.639948 50 8 Class_0 (0.1600000000 0.8400000000) *
## 9) scoma>=25.61492 442 194 Class_0 (0.4389140271 0.5610859729)
## 18) scoma< 59.6887 313 132 Class_1 (0.5782747604 0.4217252396)
## 36) ca=no 232 71 Class_1 (0.6939655172 0.3060344828)
## 72) age< 62.12041 121 15 Class_1 (0.8760330579 0.1239669421)
## 144) aps< 40.31603 71 3 Class_1 (0.9577464789 0.0422535211) *
## 145) aps>=40.31603 50 12 Class_1 (0.7600000000 0.2400000000) *
## 73) age>=62.12041 111 55 Class_0 (0.4954954955 0.5045045045)
## 146) sod>=137.1134 52 17 Class_1 (0.6730769231 0.3269230769) *
## 147) sod< 137.1134 59 20 Class_0 (0.3389830508 0.6610169492) *
## 37) ca=metastatic,yes 81 20 Class_0 (0.2469135802 0.7530864198) *
## 19) scoma>=59.6887 129 13 Class_0 (0.1007751938 0.8992248062)
## 38) aps< 35.32287 50 11 Class_0 (0.2200000000 0.7800000000) *
## 39) aps>=35.32287 79 2 Class_0 (0.0253164557 0.9746835443) *
## 5) sps>=33.44478 676 191 Class_0 (0.2825443787 0.7174556213)
## 10) sps< 39.94561 365 167 Class_0 (0.4575342466 0.5424657534)
## 20) dzgroup=ARF/MOSF w/Sepsis,CHF,Cirrhosis,COPD 271 114 Class_1 (0.5793357934 0.4206642066)
## 40) scoma< 29.16174 207 58 Class_1 (0.7198067633 0.2801932367)
## 80) age< 53.97031 86 6 Class_1 (0.9302325581 0.0697674419) *
## 81) age>=53.97031 121 52 Class_1 (0.5702479339 0.4297520661)
## 162) aps< 51.43759 71 22 Class_1 (0.6901408451 0.3098591549) *
## 163) aps>=51.43759 50 20 Class_0 (0.4000000000 0.6000000000) *
## 41) scoma>=29.16174 64 8 Class_0 (0.1250000000 0.8750000000) *
## 21) dzgroup=Colon Cancer,Coma,Lung Cancer,MOSF w/Malig 94 10 Class_0 (0.1063829787 0.8936170213) *
## 11) sps>=39.94561 311 24 Class_0 (0.0771704180 0.9228295820)
## 22) sps< 43.44918 108 20 Class_0 (0.1851851852 0.8148148148)
## 44) scoma< 10.10369 56 17 Class_0 (0.3035714286 0.6964285714) *
## 45) scoma>=10.10369 52 3 Class_0 (0.0576923077 0.9423076923) *
## 23) sps>=43.44918 203 4 Class_0 (0.0197044335 0.9802955665)
## 46) sfdm2=adl>=4 (>=5 if sur) 51 4 Class_0 (0.0784313725 0.9215686275) *
## 47) sfdm2=no(M2 and SIP pres),SIP>=30,Unknown 152 0 Class_0 (0.0000000000 1.0000000000) *
## 3) sfdm2=<2 mo. follow-up,Coma or Intub 3291 775 Class_0 (0.2354907323 0.7645092677)
## 6) scoma< 21.18523 1672 661 Class_0 (0.3953349282 0.6046650718)
## 12) sps< 35.28011 1030 426 Class_1 (0.5864077670 0.4135922330)
## 24) scoma>=-19.10812 912 320 Class_1 (0.6491228070 0.3508771930)
## 48) aps< 64.72386 777 223 Class_1 (0.7129987130 0.2870012870)
## 96) dzgroup=ARF/MOSF w/Sepsis,CHF,Cirrhosis,Colon Cancer,COPD 497 88 Class_1 (0.8229376258 0.1770623742)
## 192) hday< 14.16365 423 53 Class_1 (0.8747044917 0.1252955083)
## 384) aps< 43.29665 274 17 Class_1 (0.9379562044 0.0620437956)
## 768) sps< 24.33133 146 1 Class_1 (0.9931506849 0.0068493151)
## 1536) adls>=0.5560018 96 0 Class_1 (1.0000000000 0.0000000000) *
## 1537) adls< 0.5560018 50 1 Class_1 (0.9800000000 0.0200000000) *
## 769) sps>=24.33133 128 16 Class_1 (0.8750000000 0.1250000000)
## 1538) feat01< 1.074502 73 3 Class_1 (0.9589041096 0.0410958904) *
## 1539) feat01>=1.074502 55 13 Class_1 (0.7636363636 0.2363636364) *
## 385) aps>=43.29665 149 36 Class_1 (0.7583892617 0.2416107383)
## 770) adlsc< 2.678925 83 11 Class_1 (0.8674698795 0.1325301205) *
## 771) adlsc>=2.678925 66 25 Class_1 (0.6212121212 0.3787878788) *
## 193) hday>=14.16365 74 35 Class_1 (0.5270270270 0.4729729730) *
## 97) dzgroup=Coma,Lung Cancer,MOSF w/Malig 280 135 Class_1 (0.5178571429 0.4821428571)
## 194) bili< 5.904582 230 95 Class_1 (0.5869565217 0.4130434783)
## 388) bili>=-1.978368 171 55 Class_1 (0.6783625731 0.3216374269)
## 776) income=$11-$25k,$25-$50k 50 7 Class_1 (0.8600000000 0.1400000000) *
## 777) income=>$50k,under $11k,Unknown 121 48 Class_1 (0.6033057851 0.3966942149)
## 1554) sps< 22.09392 54 14 Class_1 (0.7407407407 0.2592592593) *
## 1555) sps>=22.09392 67 33 Class_0 (0.4925373134 0.5074626866) *
## 389) bili< -1.978368 59 19 Class_0 (0.3220338983 0.6779661017) *
## 195) bili>=5.904582 50 10 Class_0 (0.2000000000 0.8000000000) *
## 49) aps>=64.72386 135 38 Class_0 (0.2814814815 0.7185185185)
## 98) hday< 8.797281 84 34 Class_0 (0.4047619048 0.5952380952) *
## 99) hday>=8.797281 51 4 Class_0 (0.0784313725 0.9215686275) *
## 25) scoma< -19.10812 118 12 Class_0 (0.1016949153 0.8983050847)
## 50) scoma>=-30.13681 50 11 Class_0 (0.2200000000 0.7800000000) *
## 51) scoma< -30.13681 68 1 Class_0 (0.0147058824 0.9852941176) *
## 13) sps>=35.28011 642 57 Class_0 (0.0887850467 0.9112149533)
## 26) sps< 42.20455 256 46 Class_0 (0.1796875000 0.8203125000)
## 52) dzgroup=ARF/MOSF w/Sepsis,CHF,Cirrhosis,COPD 143 44 Class_0 (0.3076923077 0.6923076923)
## 104) age< 58.33549 50 24 Class_1 (0.5200000000 0.4800000000) *
## 105) age>=58.33549 93 18 Class_0 (0.1935483871 0.8064516129) *
## 53) dzgroup=Colon Cancer,Coma,Lung Cancer,MOSF w/Malig 113 2 Class_0 (0.0176991150 0.9823008850)
## 106) adls< 1.090423 50 2 Class_0 (0.0400000000 0.9600000000) *
## 107) adls>=1.090423 63 0 Class_0 (0.0000000000 1.0000000000) *
## 27) sps>=42.20455 386 11 Class_0 (0.0284974093 0.9715025907)
## 54) feat02>=1.419622 61 7 Class_0 (0.1147540984 0.8852459016) *
## 55) feat02< 1.419622 325 4 Class_0 (0.0123076923 0.9876923077)
## 110) diabetes< -0.2205057 58 4 Class_0 (0.0689655172 0.9310344828) *
## 111) diabetes>=-0.2205057 267 0 Class_0 (0.0000000000 1.0000000000) *
## 7) scoma>=21.18523 1619 114 Class_0 (0.0704138357 0.9295861643)
## 14) scoma< 47.8913 607 89 Class_0 (0.1466227348 0.8533772652)
## 28) sps< 34.6904 312 81 Class_0 (0.2596153846 0.7403846154)
## 56) dzgroup=ARF/MOSF w/Sepsis,CHF,Cirrhosis,COPD 167 69 Class_0 (0.4131736527 0.5868263473)
## 112) aps< 48.89566 79 30 Class_1 (0.6202531646 0.3797468354) *
## 113) aps>=48.89566 88 20 Class_0 (0.2272727273 0.7727272727) *
## 57) dzgroup=Coma,Lung Cancer,MOSF w/Malig 145 12 Class_0 (0.0827586207 0.9172413793)
## 114) temp< 37.05896 61 11 Class_0 (0.1803278689 0.8196721311) *
## 115) temp>=37.05896 84 1 Class_0 (0.0119047619 0.9880952381) *
## 29) sps>=34.6904 295 8 Class_0 (0.0271186441 0.9728813559)
## 58) hrt>=140.9494 61 5 Class_0 (0.0819672131 0.9180327869) *
## 59) hrt< 140.9494 234 3 Class_0 (0.0128205128 0.9871794872)
## 118) sps< 38.22508 50 3 Class_0 (0.0600000000 0.9400000000) *
## 119) sps>=38.22508 184 0 Class_0 (0.0000000000 1.0000000000) *
## 15) scoma>=47.8913 1012 25 Class_0 (0.0247035573 0.9752964427)
## 30) sps< 26.60086 348 20 Class_0 (0.0574712644 0.9425287356)
## 60) dzgroup=ARF/MOSF w/Sepsis,Cirrhosis 69 13 Class_0 (0.1884057971 0.8115942029) *
## 61) dzgroup=CHF,Coma,COPD,Lung Cancer,MOSF w/Malig 279 7 Class_0 (0.0250896057 0.9749103943)
## 122) age< 56.8503 62 6 Class_0 (0.0967741935 0.9032258065) *
## 123) age>=56.8503 217 1 Class_0 (0.0046082949 0.9953917051)
## 246) adls>=3.01375 50 1 Class_0 (0.0200000000 0.9800000000) *
## 247) adls< 3.01375 167 0 Class_0 (0.0000000000 1.0000000000) *
## 31) sps>=26.60086 664 5 Class_0 (0.0075301205 0.9924698795)
## 62) glucose>=237.7115 63 3 Class_0 (0.0476190476 0.9523809524) *
## 63) glucose< 237.7115 601 2 Class_0 (0.0033277870 0.9966722130)
## 126) hospdead< 0.2816769 50 2 Class_0 (0.0400000000 0.9600000000) *
## 127) hospdead>=0.2816769 551 0 Class_0 (0.0000000000 1.0000000000) *
print(large_tree_6m)
## n= 7075
##
## node), split, n, loss, yval, (yprob)
## * denotes terminal node
##
## 1) root 7075 3493 Class_1 (0.5062897527 0.4937102473)
## 2) sfdm2=adl>=4 (>=5 if sur),no(M2 and SIP pres),SIP>=30,Unknown 4304 1329 Class_1 (0.6912174721 0.3087825279)
## 4) dzgroup=ARF/MOSF w/Sepsis,CHF,Cirrhosis,Colon Cancer,COPD 3425 726 Class_1 (0.7880291971 0.2119708029)
## 8) sps< 33.15188 2919 379 Class_1 (0.8701610140 0.1298389860)
## 16) scoma< 21.00122 2609 228 Class_1 (0.9126101955 0.0873898045)
## 32) scoma>=-22.89907 2537 170 Class_1 (0.9329917225 0.0670082775)
## 64) bili>=-4.392577 2487 144 Class_1 (0.9420989144 0.0579010856)
## 128) aps< 59.98293 2367 107 Class_1 (0.9547950993 0.0452049007)
## 256) hday< 14.7665 2249 81 Class_1 (0.9639839929 0.0360160071)
## 512) bili< 5.5872 2049 53 Class_1 (0.9741337238 0.0258662762)
## 1024) hday>=-8.603169 1964 41 Class_1 (0.9791242363 0.0208757637)
## 2048) dzclass=ARF/MOSF,COPD/CHF/Cirrhosis 1774 26 Class_1 (0.9853438557 0.0146561443)
## 4096) aps< 45.96327 1558 13 Class_1 (0.9916559692 0.0083440308)
## 8192) sod< 150.4421 1508 9 Class_1 (0.9940318302 0.0059681698)
## 16384) resp< 43.81499 1458 6 Class_1 (0.9958847737 0.0041152263)
## 32768) totmcst< 73808.03 1408 4 Class_1 (0.9971590909 0.0028409091)
## 65536) pafi< 373.8191 1230 1 Class_1 (0.9991869919 0.0008130081)
## 131072) meanbp< 145.7602 1180 0 Class_1 (1.0000000000 0.0000000000) *
## 131073) meanbp>=145.7602 50 1 Class_1 (0.9800000000 0.0200000000) *
## 65537) pafi>=373.8191 178 3 Class_1 (0.9831460674 0.0168539326)
## 131074) dnrday>=4.577877 128 0 Class_1 (1.0000000000 0.0000000000) *
## 131075) dnrday< 4.577877 50 3 Class_1 (0.9400000000 0.0600000000) *
## 32769) totmcst>=73808.03 50 2 Class_1 (0.9600000000 0.0400000000) *
## 16385) resp>=43.81499 50 3 Class_1 (0.9400000000 0.0600000000) *
## 8193) sod>=150.4421 50 4 Class_1 (0.9200000000 0.0800000000) *
## 4097) aps>=45.96327 216 13 Class_1 (0.9398148148 0.0601851852)
## 8194) age< 61.18598 103 0 Class_1 (1.0000000000 0.0000000000) *
## 8195) age>=61.18598 113 13 Class_1 (0.8849557522 0.1150442478)
## 16390) bili>=0.7476237 63 3 Class_1 (0.9523809524 0.0476190476) *
## 16391) bili< 0.7476237 50 10 Class_1 (0.8000000000 0.2000000000) *
## 2049) dzclass=Cancer 190 15 Class_1 (0.9210526316 0.0789473684)
## 4098) sps< 19.48119 136 3 Class_1 (0.9779411765 0.0220588235)
## 8196) adls>=0.006430387 86 0 Class_1 (1.0000000000 0.0000000000) *
## 8197) adls< 0.006430387 50 3 Class_1 (0.9400000000 0.0600000000) *
## 4099) sps>=19.48119 54 12 Class_1 (0.7777777778 0.2222222222) *
## 1025) hday< -8.603169 85 12 Class_1 (0.8588235294 0.1411764706) *
## 513) bili>=5.5872 200 28 Class_1 (0.8600000000 0.1400000000)
## 1026) age< 66.37847 133 8 Class_1 (0.9398496241 0.0601503759)
## 2052) feat01>=0.7569664 78 1 Class_1 (0.9871794872 0.0128205128) *
## 2053) feat01< 0.7569664 55 7 Class_1 (0.8727272727 0.1272727273) *
## 1027) age>=66.37847 67 20 Class_1 (0.7014925373 0.2985074627) *
## 257) hday>=14.7665 118 26 Class_1 (0.7796610169 0.2203389831)
## 514) bili< 1.880972 67 7 Class_1 (0.8955223881 0.1044776119) *
## 515) bili>=1.880972 51 19 Class_1 (0.6274509804 0.3725490196) *
## 129) aps>=59.98293 120 37 Class_1 (0.6916666667 0.3083333333)
## 258) bili< 2.580919 67 13 Class_1 (0.8059701493 0.1940298507) *
## 259) bili>=2.580919 53 24 Class_1 (0.5471698113 0.4528301887) *
## 65) bili< -4.392577 50 24 Class_0 (0.4800000000 0.5200000000) *
## 33) scoma< -22.89907 72 14 Class_0 (0.1944444444 0.8055555556) *
## 17) scoma>=21.00122 310 151 Class_1 (0.5129032258 0.4870967742)
## 34) age< 64.03363 155 43 Class_1 (0.7225806452 0.2774193548)
## 68) age< 47.43876 74 9 Class_1 (0.8783783784 0.1216216216) *
## 69) age>=47.43876 81 34 Class_1 (0.5802469136 0.4197530864) *
## 35) age>=64.03363 155 47 Class_0 (0.3032258065 0.6967741935)
## 70) aps< 33.30946 73 35 Class_0 (0.4794520548 0.5205479452) *
## 71) aps>=33.30946 82 12 Class_0 (0.1463414634 0.8536585366) *
## 9) sps>=33.15188 506 159 Class_0 (0.3142292490 0.6857707510)
## 18) sps< 39.72494 279 133 Class_0 (0.4767025090 0.5232974910)
## 36) age< 39.97142 57 7 Class_1 (0.8771929825 0.1228070175) *
## 37) age>=39.97142 222 83 Class_0 (0.3738738739 0.6261261261)
## 74) scoma< 12.8952 138 64 Class_1 (0.5362318841 0.4637681159)
## 148) scoma>=-6.549508 85 27 Class_1 (0.6823529412 0.3176470588) *
## 149) scoma< -6.549508 53 16 Class_0 (0.3018867925 0.6981132075) *
## 75) scoma>=12.8952 84 9 Class_0 (0.1071428571 0.8928571429) *
## 19) sps>=39.72494 227 26 Class_0 (0.1145374449 0.8854625551)
## 38) age< 53.64371 82 19 Class_0 (0.2317073171 0.7682926829) *
## 39) age>=53.64371 145 7 Class_0 (0.0482758621 0.9517241379)
## 78) sps< 42.34839 51 6 Class_0 (0.1176470588 0.8823529412) *
## 79) sps>=42.34839 94 1 Class_0 (0.0106382979 0.9893617021) *
## 5) dzgroup=Coma,Lung Cancer,MOSF w/Malig 879 276 Class_0 (0.3139931741 0.6860068259)
## 10) sps< 22.18994 512 230 Class_0 (0.4492187500 0.5507812500)
## 20) scoma< 19.30877 373 165 Class_1 (0.5576407507 0.4423592493)
## 40) scoma>=-16.36966 316 113 Class_1 (0.6424050633 0.3575949367)
## 80) hday< 5.913894 237 68 Class_1 (0.7130801688 0.2869198312)
## 160) age< 70.22943 169 34 Class_1 (0.7988165680 0.2011834320)
## 320) sps< 15.71875 111 12 Class_1 (0.8918918919 0.1081081081)
## 640) avtisst< 14.02221 61 1 Class_1 (0.9836065574 0.0163934426) *
## 641) avtisst>=14.02221 50 11 Class_1 (0.7800000000 0.2200000000) *
## 321) sps>=15.71875 58 22 Class_1 (0.6206896552 0.3793103448) *
## 161) age>=70.22943 68 34 Class_1 (0.5000000000 0.5000000000) *
## 81) hday>=5.913894 79 34 Class_0 (0.4303797468 0.5696202532) *
## 41) scoma< -16.36966 57 5 Class_0 (0.0877192982 0.9122807018) *
## 21) scoma>=19.30877 139 22 Class_0 (0.1582733813 0.8417266187)
## 42) dzclass=Coma 68 21 Class_0 (0.3088235294 0.6911764706) *
## 43) dzclass=ARF/MOSF,Cancer 71 1 Class_0 (0.0140845070 0.9859154930) *
## 11) sps>=22.18994 367 46 Class_0 (0.1253405995 0.8746594005)
## 22) death< 0.6642611 114 29 Class_0 (0.2543859649 0.7456140351)
## 44) sps< 29.00666 51 21 Class_0 (0.4117647059 0.5882352941) *
## 45) sps>=29.00666 63 8 Class_0 (0.1269841270 0.8730158730) *
## 23) death>=0.6642611 253 17 Class_0 (0.0671936759 0.9328063241)
## 46) feat10>=1.276606 68 12 Class_0 (0.1764705882 0.8235294118) *
## 47) feat10< 1.276606 185 5 Class_0 (0.0270270270 0.9729729730)
## 94) sps< 28.03444 64 5 Class_0 (0.0781250000 0.9218750000) *
## 95) sps>=28.03444 121 0 Class_0 (0.0000000000 1.0000000000) *
## 3) sfdm2=<2 mo. follow-up,Coma or Intub 2771 607 Class_0 (0.2190544930 0.7809455070)
## 6) scoma< 18.50689 1450 513 Class_0 (0.3537931034 0.6462068966)
## 12) dzgroup=ARF/MOSF w/Sepsis,CHF,Cirrhosis,Colon Cancer,COPD 998 479 Class_0 (0.4799599198 0.5200400802)
## 24) sps< 31.07118 579 188 Class_1 (0.6753022453 0.3246977547)
## 48) aps< 53.57191 454 100 Class_1 (0.7797356828 0.2202643172)
## 96) scoma>=-12.30889 388 62 Class_1 (0.8402061856 0.1597938144)
## 192) bili< 5.123666 335 40 Class_1 (0.8805970149 0.1194029851)
## 384) sps< 23.94085 198 11 Class_1 (0.9444444444 0.0555555556)
## 768) totcst< 47377.89 147 3 Class_1 (0.9795918367 0.0204081633)
## 1536) adls>=0.4384738 97 0 Class_1 (1.0000000000 0.0000000000) *
## 1537) adls< 0.4384738 50 3 Class_1 (0.9400000000 0.0600000000) *
## 769) totcst>=47377.89 51 8 Class_1 (0.8431372549 0.1568627451) *
## 385) sps>=23.94085 137 29 Class_1 (0.7883211679 0.2116788321)
## 770) feat05< 0.997565 59 6 Class_1 (0.8983050847 0.1016949153) *
## 771) feat05>=0.997565 78 23 Class_1 (0.7051282051 0.2948717949) *
## 193) bili>=5.123666 53 22 Class_1 (0.5849056604 0.4150943396) *
## 97) scoma< -12.30889 66 28 Class_0 (0.4242424242 0.5757575758) *
## 49) aps>=53.57191 125 37 Class_0 (0.2960000000 0.7040000000)
## 98) age< 69.82037 62 27 Class_0 (0.4354838710 0.5645161290) *
## 99) age>=69.82037 63 10 Class_0 (0.1587301587 0.8412698413) *
## 25) sps>=31.07118 419 88 Class_0 (0.2100238663 0.7899761337)
## 50) sps< 39.43615 210 77 Class_0 (0.3666666667 0.6333333333)
## 100) age< 55.24306 67 25 Class_1 (0.6268656716 0.3731343284) *
## 101) age>=55.24306 143 35 Class_0 (0.2447552448 0.7552447552)
## 202) scoma>=-7.236356 92 30 Class_0 (0.3260869565 0.6739130435) *
## 203) scoma< -7.236356 51 5 Class_0 (0.0980392157 0.9019607843) *
## 51) sps>=39.43615 209 11 Class_0 (0.0526315789 0.9473684211)
## 102) sps< 43.30701 74 10 Class_0 (0.1351351351 0.8648648649) *
## 103) sps>=43.30701 135 1 Class_0 (0.0074074074 0.9925925926)
## 206) adls>=1.671119 50 1 Class_0 (0.0200000000 0.9800000000) *
## 207) adls< 1.671119 85 0 Class_0 (0.0000000000 1.0000000000) *
## 13) dzgroup=Coma,Lung Cancer,MOSF w/Malig 452 34 Class_0 (0.0752212389 0.9247787611)
## 26) sps< 26.16862 197 30 Class_0 (0.1522842640 0.8477157360)
## 52) dzclass=ARF/MOSF,Coma 78 25 Class_0 (0.3205128205 0.6794871795) *
## 53) dzclass=Cancer 119 5 Class_0 (0.0420168067 0.9579831933)
## 106) dnrday>=8.053437 52 5 Class_0 (0.0961538462 0.9038461538) *
## 107) dnrday< 8.053437 67 0 Class_0 (0.0000000000 1.0000000000) *
## 27) sps>=26.16862 255 4 Class_0 (0.0156862745 0.9843137255)
## 54) sps< 31.35444 67 4 Class_0 (0.0597014925 0.9402985075) *
## 55) sps>=31.35444 188 0 Class_0 (0.0000000000 1.0000000000) *
## 7) scoma>=18.50689 1321 94 Class_0 (0.0711582135 0.9288417865)
## 14) sps< 27.1614 509 78 Class_0 (0.1532416503 0.8467583497)
## 28) dzgroup=ARF/MOSF w/Sepsis,CHF,COPD 193 61 Class_0 (0.3160621762 0.6839378238)
## 56) aps< 50.03119 123 56 Class_0 (0.4552845528 0.5447154472)
## 112) totcst< 43009.76 62 25 Class_1 (0.5967741935 0.4032258065) *
## 113) totcst>=43009.76 61 19 Class_0 (0.3114754098 0.6885245902) *
## 57) aps>=50.03119 70 5 Class_0 (0.0714285714 0.9285714286) *
## 29) dzgroup=Cirrhosis,Colon Cancer,Coma,Lung Cancer,MOSF w/Malig 316 17 Class_0 (0.0537974684 0.9462025316)
## 58) age< 53.46947 76 11 Class_0 (0.1447368421 0.8552631579) *
## 59) age>=53.46947 240 6 Class_0 (0.0250000000 0.9750000000)
## 118) aps< 16.99989 54 5 Class_0 (0.0925925926 0.9074074074) *
## 119) aps>=16.99989 186 1 Class_0 (0.0053763441 0.9946236559)
## 238) age>=84.06482 50 1 Class_0 (0.0200000000 0.9800000000) *
## 239) age< 84.06482 136 0 Class_0 (0.0000000000 1.0000000000) *
## 15) sps>=27.1614 812 16 Class_0 (0.0197044335 0.9802955665)
## 30) sps< 34.53735 273 15 Class_0 (0.0549450549 0.9450549451)
## 60) dzgroup=ARF/MOSF w/Sepsis 124 14 Class_0 (0.1129032258 0.8870967742)
## 120) totmcst< 41954.93 70 14 Class_0 (0.2000000000 0.8000000000) *
## 121) totmcst>=41954.93 54 0 Class_0 (0.0000000000 1.0000000000) *
## 61) dzgroup=CHF,Cirrhosis,Colon Cancer,Coma,COPD,Lung Cancer,MOSF w/Malig 149 1 Class_0 (0.0067114094 0.9932885906)
## 122) adls< 0.3781009 50 1 Class_0 (0.0200000000 0.9800000000) *
## 123) adls>=0.3781009 99 0 Class_0 (0.0000000000 1.0000000000) *
## 31) sps>=34.53735 539 1 Class_0 (0.0018552876 0.9981447124)
## 62) age< 33.06365 50 1 Class_0 (0.0200000000 0.9800000000) *
## 63) age>=33.06365 489 0 Class_0 (0.0000000000 1.0000000000) *
# the large tree for 2-month survival
rpart.plot(large_tree_2m, type = 2, extra = 104, main = "Classification Tree for surv2m_class (2-Month Survival)")
## Warning: labs do not fit even at cex 0.15, there may be some overplotting
# the large tree for 6-month survival
rpart.plot(large_tree_6m, type = 2, extra = 104, main = "Classification Tree for surv6m_class (6-Month Survival)")
## Warning: labs do not fit even at cex 0.15, there may be some overplotting
complexity parameter
The goal is to look for the CP value that minimizes the cross validation error.
# complexity for the trees
printcp(large_tree_2m)
##
## Classification tree:
## rpart(formula = formula_2m, data = train_data_2m_balanced, method = "class",
## minsplit = 100, minbucket = 50, maxdepth = 30, cp = -1)
##
## Variables actually used in tree construction:
## [1] adls adlsc age aps avtisst bili ca diabetes
## [9] dnrday dzgroup feat01 feat02 glucose hday hospdead hrt
## [17] income prg2m scoma sfdm2 sod sps temp totcst
##
## Root node error: 3493/7075 = 0.49371
##
## n= 7075
##
## CP nsplit rel error xerror xstd
## 1 0.49842542 0 1.00000 1.00000 0.0120393
## 2 0.08416834 1 0.50157 0.52247 0.0105354
## 3 0.02547953 2 0.41741 0.41111 0.0096854
## 4 0.01689092 5 0.33954 0.34812 0.0090848
## 5 0.01545949 6 0.32265 0.32866 0.0088782
## 6 0.01402806 7 0.30719 0.32494 0.0088375
## 7 0.01173776 8 0.29316 0.31492 0.0087258
## 8 0.00868403 9 0.28142 0.30318 0.0085909
## 9 0.00601202 12 0.25537 0.28228 0.0083397
## 10 0.00429430 14 0.24334 0.26682 0.0081440
## 11 0.00271973 17 0.22874 0.25766 0.0080238
## 12 0.00200401 19 0.22330 0.25537 0.0079932
## 13 0.00143143 20 0.22130 0.25422 0.0079778
## 14 0.00135986 22 0.21844 0.25508 0.0079894
## 15 0.00019086 26 0.21300 0.25823 0.0080314
## 16 0.00014314 29 0.21242 0.25938 0.0080465
## 17 0.00000000 31 0.21214 0.25938 0.0080465
## 18 -1.00000000 74 0.21214 0.25938 0.0080465
printcp(large_tree_6m)
##
## Classification tree:
## rpart(formula = formula_6m, data = train_data_6m_balanced, method = "class",
## minsplit = 100, minbucket = 50, maxdepth = 30, cp = -1)
##
## Variables actually used in tree construction:
## [1] adls age aps avtisst bili death dnrday dzclass dzgroup
## [10] feat01 feat05 feat10 hday meanbp pafi resp scoma sfdm2
## [19] sod sps totcst totmcst
##
## Root node error: 3493/7075 = 0.49371
##
## n= 7075
##
## CP nsplit rel error xerror xstd
## 1 0.44574864 0 1.00000 1.00000 0.0120393
## 2 0.09361580 1 0.55425 0.55425 0.0107357
## 3 0.05382193 2 0.46064 0.46951 0.0101615
## 4 0.01937208 3 0.40681 0.41111 0.0096854
## 5 0.01460063 6 0.34870 0.36215 0.0092272
## 6 0.01002004 7 0.33410 0.35156 0.0091202
## 7 0.00858861 10 0.30404 0.33610 0.0089585
## 8 0.00615517 13 0.27827 0.29831 0.0085337
## 9 0.00443745 15 0.26596 0.29402 0.0084826
## 10 0.00314916 17 0.25709 0.29373 0.0084792
## 11 0.00286287 18 0.25394 0.28972 0.0084309
## 12 0.00243344 19 0.25107 0.28858 0.0084170
## 13 0.00085886 21 0.24621 0.28800 0.0084100
## 14 0.00057257 25 0.24277 0.28944 0.0084274
## 15 0.00000000 26 0.24220 0.29316 0.0084723
## 16 -1.00000000 79 0.24220 0.29316 0.0084723
The complexity analysis of the 2-month and 6-month survival classification trees reveals that the optimal tree size for both models is achieved at approximately 6 splits, striking a balance between model complexity and predictive performance. Key variables used in tree construction include physiological and disease severity metrics (e.g., `aps`, `scoma`, `sps`), functional factors (`adls`, `sfdm2`), and temporal or contextual variables (`hday`, `dnrday`, `dzgroup`). For the 2-month survival tree, the lowest cross-validation error (0.32923) occurs with 7 splits, while for the 6-month survival tree, the lowest cross-validation error (0.34755) occurs with 7 splits as well. Beyond these points, additional splits result in diminishing improvements in cross-validation error, suggesting potential overfitting. These findings indicate that pruning to the optimal size can preserve model interpretability while maintaining strong predictive performance.
Selecting Optimal CP & Prune trees
optimal_cp_2m <- large_tree_2m$cptable[which.min(large_tree_2m$cptable[, "xerror"]), "CP"]
pruned_tree_2m <- prune(large_tree_2m, cp = optimal_cp_2m)
optimal_cp_6m <- large_tree_6m$cptable[which.min(large_tree_6m$cptable[, "xerror"]), "CP"]
pruned_tree_6m <- prune(large_tree_6m, cp = optimal_cp_6m)
print(pruned_tree_2m)
## n= 7075
##
## node), split, n, loss, yval, (yprob)
## * denotes terminal node
##
## 1) root 7075 3493 Class_1 (0.50628975 0.49371025)
## 2) sfdm2=adl>=4 (>=5 if sur),no(M2 and SIP pres),SIP>=30,Unknown 3784 977 Class_1 (0.74180761 0.25819239)
## 4) sps< 33.44478 3108 492 Class_1 (0.84169884 0.15830116)
## 8) scoma< 25.61492 2666 244 Class_1 (0.90847712 0.09152288)
## 16) dzgroup=ARF/MOSF w/Sepsis,CHF,Cirrhosis,Colon Cancer,COPD,Lung Cancer 2483 146 Class_1 (0.94120016 0.05879984)
## 32) scoma>=-22.08214 2432 117 Class_1 (0.95189145 0.04810855) *
## 33) scoma< -22.08214 51 22 Class_0 (0.43137255 0.56862745) *
## 17) dzgroup=Coma,MOSF w/Malig 183 85 Class_0 (0.46448087 0.53551913)
## 34) aps< 30.5455 65 18 Class_1 (0.72307692 0.27692308) *
## 35) aps>=30.5455 118 38 Class_0 (0.32203390 0.67796610) *
## 9) scoma>=25.61492 442 194 Class_0 (0.43891403 0.56108597)
## 18) scoma< 59.6887 313 132 Class_1 (0.57827476 0.42172524)
## 36) ca=no 232 71 Class_1 (0.69396552 0.30603448)
## 72) age< 62.12041 121 15 Class_1 (0.87603306 0.12396694) *
## 73) age>=62.12041 111 55 Class_0 (0.49549550 0.50450450)
## 146) sod>=137.1134 52 17 Class_1 (0.67307692 0.32692308) *
## 147) sod< 137.1134 59 20 Class_0 (0.33898305 0.66101695) *
## 37) ca=metastatic,yes 81 20 Class_0 (0.24691358 0.75308642) *
## 19) scoma>=59.6887 129 13 Class_0 (0.10077519 0.89922481) *
## 5) sps>=33.44478 676 191 Class_0 (0.28254438 0.71745562)
## 10) sps< 39.94561 365 167 Class_0 (0.45753425 0.54246575)
## 20) dzgroup=ARF/MOSF w/Sepsis,CHF,Cirrhosis,COPD 271 114 Class_1 (0.57933579 0.42066421)
## 40) scoma< 29.16174 207 58 Class_1 (0.71980676 0.28019324) *
## 41) scoma>=29.16174 64 8 Class_0 (0.12500000 0.87500000) *
## 21) dzgroup=Colon Cancer,Coma,Lung Cancer,MOSF w/Malig 94 10 Class_0 (0.10638298 0.89361702) *
## 11) sps>=39.94561 311 24 Class_0 (0.07717042 0.92282958) *
## 3) sfdm2=<2 mo. follow-up,Coma or Intub 3291 775 Class_0 (0.23549073 0.76450927)
## 6) scoma< 21.18523 1672 661 Class_0 (0.39533493 0.60466507)
## 12) sps< 35.28011 1030 426 Class_1 (0.58640777 0.41359223)
## 24) scoma>=-19.10812 912 320 Class_1 (0.64912281 0.35087719)
## 48) aps< 64.72386 777 223 Class_1 (0.71299871 0.28700129)
## 96) dzgroup=ARF/MOSF w/Sepsis,CHF,Cirrhosis,Colon Cancer,COPD 497 88 Class_1 (0.82293763 0.17706237) *
## 97) dzgroup=Coma,Lung Cancer,MOSF w/Malig 280 135 Class_1 (0.51785714 0.48214286)
## 194) bili< 5.904582 230 95 Class_1 (0.58695652 0.41304348)
## 388) bili>=-1.978368 171 55 Class_1 (0.67836257 0.32163743) *
## 389) bili< -1.978368 59 19 Class_0 (0.32203390 0.67796610) *
## 195) bili>=5.904582 50 10 Class_0 (0.20000000 0.80000000) *
## 49) aps>=64.72386 135 38 Class_0 (0.28148148 0.71851852) *
## 25) scoma< -19.10812 118 12 Class_0 (0.10169492 0.89830508) *
## 13) sps>=35.28011 642 57 Class_0 (0.08878505 0.91121495) *
## 7) scoma>=21.18523 1619 114 Class_0 (0.07041384 0.92958616) *
print(pruned_tree_6m)
## n= 7075
##
## node), split, n, loss, yval, (yprob)
## * denotes terminal node
##
## 1) root 7075 3493 Class_1 (0.50628975 0.49371025)
## 2) sfdm2=adl>=4 (>=5 if sur),no(M2 and SIP pres),SIP>=30,Unknown 4304 1329 Class_1 (0.69121747 0.30878253)
## 4) dzgroup=ARF/MOSF w/Sepsis,CHF,Cirrhosis,Colon Cancer,COPD 3425 726 Class_1 (0.78802920 0.21197080)
## 8) sps< 33.15188 2919 379 Class_1 (0.87016101 0.12983899)
## 16) scoma< 21.00122 2609 228 Class_1 (0.91261020 0.08738980)
## 32) scoma>=-22.89907 2537 170 Class_1 (0.93299172 0.06700828) *
## 33) scoma< -22.89907 72 14 Class_0 (0.19444444 0.80555556) *
## 17) scoma>=21.00122 310 151 Class_1 (0.51290323 0.48709677)
## 34) age< 64.03363 155 43 Class_1 (0.72258065 0.27741935) *
## 35) age>=64.03363 155 47 Class_0 (0.30322581 0.69677419) *
## 9) sps>=33.15188 506 159 Class_0 (0.31422925 0.68577075)
## 18) sps< 39.72494 279 133 Class_0 (0.47670251 0.52329749)
## 36) age< 39.97142 57 7 Class_1 (0.87719298 0.12280702) *
## 37) age>=39.97142 222 83 Class_0 (0.37387387 0.62612613)
## 74) scoma< 12.8952 138 64 Class_1 (0.53623188 0.46376812)
## 148) scoma>=-6.549508 85 27 Class_1 (0.68235294 0.31764706) *
## 149) scoma< -6.549508 53 16 Class_0 (0.30188679 0.69811321) *
## 75) scoma>=12.8952 84 9 Class_0 (0.10714286 0.89285714) *
## 19) sps>=39.72494 227 26 Class_0 (0.11453744 0.88546256) *
## 5) dzgroup=Coma,Lung Cancer,MOSF w/Malig 879 276 Class_0 (0.31399317 0.68600683)
## 10) sps< 22.18994 512 230 Class_0 (0.44921875 0.55078125)
## 20) scoma< 19.30877 373 165 Class_1 (0.55764075 0.44235925)
## 40) scoma>=-16.36966 316 113 Class_1 (0.64240506 0.35759494)
## 80) hday< 5.913894 237 68 Class_1 (0.71308017 0.28691983) *
## 81) hday>=5.913894 79 34 Class_0 (0.43037975 0.56962025) *
## 41) scoma< -16.36966 57 5 Class_0 (0.08771930 0.91228070) *
## 21) scoma>=19.30877 139 22 Class_0 (0.15827338 0.84172662) *
## 11) sps>=22.18994 367 46 Class_0 (0.12534060 0.87465940) *
## 3) sfdm2=<2 mo. follow-up,Coma or Intub 2771 607 Class_0 (0.21905449 0.78094551)
## 6) scoma< 18.50689 1450 513 Class_0 (0.35379310 0.64620690)
## 12) dzgroup=ARF/MOSF w/Sepsis,CHF,Cirrhosis,Colon Cancer,COPD 998 479 Class_0 (0.47995992 0.52004008)
## 24) sps< 31.07118 579 188 Class_1 (0.67530225 0.32469775)
## 48) aps< 53.57191 454 100 Class_1 (0.77973568 0.22026432)
## 96) scoma>=-12.30889 388 62 Class_1 (0.84020619 0.15979381) *
## 97) scoma< -12.30889 66 28 Class_0 (0.42424242 0.57575758) *
## 49) aps>=53.57191 125 37 Class_0 (0.29600000 0.70400000) *
## 25) sps>=31.07118 419 88 Class_0 (0.21002387 0.78997613)
## 50) sps< 39.43615 210 77 Class_0 (0.36666667 0.63333333)
## 100) age< 55.24306 67 25 Class_1 (0.62686567 0.37313433) *
## 101) age>=55.24306 143 35 Class_0 (0.24475524 0.75524476) *
## 51) sps>=39.43615 209 11 Class_0 (0.05263158 0.94736842) *
## 13) dzgroup=Coma,Lung Cancer,MOSF w/Malig 452 34 Class_0 (0.07522124 0.92477876) *
## 7) scoma>=18.50689 1321 94 Class_0 (0.07115821 0.92884179) *
rpart.plot(
pruned_tree_2m,
main = "Pruned Tree for surv2m_class",
type = 2,
extra = 104
)
rpart.plot(
pruned_tree_6m,
main = "Pruned Tree for surv6m_class",
type = 2,
extra = 104
)
Both pruned trees simplify the complexity by removing unnecessary branches while maintaining predictive accuracy. This reduces overfitting and enhances generalization on unseen data. Key predictors such as sps,scoma, and dzgroup consistently appear, reinforcing their significance in determining survival outcomes.
confusion matrix
evaluate_model <- function(model, test_data, target_col, positive_class = "Class_1") {
# predictions
preds <- predict(model, test_data, type = "class")
# confusion matrix
confusionMatrix(preds, test_data[[target_col]], positive = positive_class)
}
conf_matrix_2m <- evaluate_model(pruned_tree_2m, test_data_2m_balanced, "surv2m_class", positive_class = "Class_1")
conf_matrix_6m <- evaluate_model(pruned_tree_6m, test_data_6m_balanced, "surv6m_class", positive_class = "Class_1")
print("Confusion Matrix for surv2m_class:")
## [1] "Confusion Matrix for surv2m_class:"
print(conf_matrix_2m)
## Confusion Matrix and Statistics
##
## Reference
## Prediction Class_1 Class_0
## Class_1 1311 201
## Class_0 223 1295
##
## Accuracy : 0.8601
## 95% CI : (0.8472, 0.8722)
## No Information Rate : 0.5063
## P-Value [Acc > NIR] : <2e-16
##
## Kappa : 0.7201
##
## Mcnemar's Test P-Value : 0.3078
##
## Sensitivity : 0.8546
## Specificity : 0.8656
## Pos Pred Value : 0.8671
## Neg Pred Value : 0.8531
## Prevalence : 0.5063
## Detection Rate : 0.4327
## Detection Prevalence : 0.4990
## Balanced Accuracy : 0.8601
##
## 'Positive' Class : Class_1
##
print("Confusion Matrix for surv6m_class:")
## [1] "Confusion Matrix for surv6m_class:"
print(conf_matrix_6m)
## Confusion Matrix and Statistics
##
## Reference
## Prediction Class_1 Class_0
## Class_1 1301 213
## Class_0 233 1283
##
## Accuracy : 0.8528
## 95% CI : (0.8397, 0.8652)
## No Information Rate : 0.5063
## P-Value [Acc > NIR] : <2e-16
##
## Kappa : 0.7056
##
## Mcnemar's Test P-Value : 0.3683
##
## Sensitivity : 0.8481
## Specificity : 0.8576
## Pos Pred Value : 0.8593
## Neg Pred Value : 0.8463
## Prevalence : 0.5063
## Detection Rate : 0.4294
## Detection Prevalence : 0.4997
## Balanced Accuracy : 0.8529
##
## 'Positive' Class : Class_1
##
The evaluation of the pruned classification trees for predicting 2-month (`surv2m_class`) and 6-month (surv6m_class) survival rates demonstrates notable improvements over the ungrown trees. For 2-month survival, the pruned tree achieved an accuracy of 86.14%, slightly higher than the ungrown tree’s accuracy of 85.58%. Similarly, for 6-month survival, the pruned tree reached an accuracy of 85.28%, compared to the ungrown tree’s 84.59%.
The sensitivity for the pruned tree was 86.51% (2-month) and 84.81% (6-month), reflecting a small decrease for 2-month predictions (ungrown: 88.33%) but a slight improvement for 6-month predictions (ungrown: 83.57%).
In contrast, the specificity improved significantly for the pruned trees, with both models achieving 85.76%, compared to the ungrown tree’s 82.75% (2-month) and 85.63% (6-month).
Finally, the pruned models showed better balance in prediction errors, as indicated by McNemar’s test p-values of 0.8073 (2-month) and 0.3683 (6-month), significantly higher than the ungrown trees’ p-values of 0.0001905 and 0.09574.
These results demonstrate that pruning successfully enhanced model performance by reducing overfitting and improving overall predictive robustness while maintaining a balance between sensitivity and specificity.
The ROC curve analysis highlights the ability of the models to distinguish between survival and non-survival cases across different thresholds. This step helps validate the model’s effectiveness, ensuring its suitability for predicting survival outcomes while balancing complexity and generalization.
Generating predictions
# Predictions for training and testing sets for 2-month survival
pred_train_unpruned_2m <- predict(tree_model_2m_balanced, train_data_2m_balanced, type = "prob")[, 2]
pred_test_unpruned_2m <- predict(tree_model_2m_balanced, test_data_2m_balanced, type = "prob")[, 2]
pred_train_pruned_2m <- predict(pruned_tree_2m, train_data_2m_balanced, type = "prob")[, 2]
pred_test_pruned_2m <- predict(pruned_tree_2m, test_data_2m_balanced, type = "prob")[, 2]
# Predictions for training and testing sets for 6-month survival
pred_train_unpruned_6m <- predict(tree_model_6m_balanced, train_data_6m_balanced, type = "prob")[, 2]
pred_test_unpruned_6m <- predict(tree_model_6m_balanced, test_data_6m_balanced, type = "prob")[, 2]
pred_train_pruned_6m <- predict(pruned_tree_6m, train_data_6m_balanced, type = "prob")[, 2]
pred_test_pruned_6m <- predict(pruned_tree_6m, test_data_6m_balanced, type = "prob")[, 2]
library(pROC)
# ROC curves for the unpruned tree
ROC_train_unpruned_2m <- roc(train_data_2m_balanced$surv2m_class, pred_train_unpruned_2m, levels = c("Class_0", "Class_1"))
## Setting direction: controls > cases
ROC_test_unpruned_2m <- roc(test_data_2m_balanced$surv2m_class, pred_test_unpruned_2m, levels = c("Class_0", "Class_1"))
## Setting direction: controls > cases
# ROC curves for the pruned tree
ROC_train_pruned_2m <- roc(train_data_2m_balanced$surv2m_class, pred_train_pruned_2m, levels = c("Class_0", "Class_1"))
## Setting direction: controls > cases
ROC_test_pruned_2m <- roc(test_data_2m_balanced$surv2m_class, pred_test_pruned_2m, levels = c("Class_0", "Class_1"))
## Setting direction: controls > cases
# ROC curves for the unpruned tree
ROC_train_unpruned_6m <- roc(train_data_6m_balanced$surv6m_class, pred_train_unpruned_6m, levels = c("Class_0", "Class_1"))
## Setting direction: controls > cases
ROC_test_unpruned_6m <- roc(test_data_6m_balanced$surv6m_class, pred_test_unpruned_6m, levels = c("Class_0", "Class_1"))
## Setting direction: controls > cases
# ROC curves for the pruned tree
ROC_train_pruned_6m <- roc(train_data_6m_balanced$surv6m_class, pred_train_pruned_6m, levels = c("Class_0", "Class_1"))
## Setting direction: controls > cases
ROC_test_pruned_6m <- roc(test_data_6m_balanced$surv6m_class, pred_test_pruned_6m, levels = c("Class_0", "Class_1"))
## Setting direction: controls > cases
# a list for 2-month ROC curves
roc_list_2m <- list(
"Train (Unpruned)" = ROC_train_unpruned_2m,
"Test (Unpruned)" = ROC_test_unpruned_2m,
"Train (Pruned)" = ROC_train_pruned_2m,
"Test (Pruned)" = ROC_test_pruned_2m
)
# Plotting ROC curves for 2-month survival
pROC::ggroc(roc_list_2m, alpha = 0.5, linetype = 1, size = 1) +
geom_segment(aes(x = 1, xend = 0, y = 0, yend = 1),
color = "grey",
linetype = "dashed") +
labs(
title = "ROC Curves for 2-Month Survival",
subtitle = paste0(
"Gini TRAIN: Unpruned = ", round(100 * (2 * auc(ROC_train_unpruned_2m) - 1), 1), "%, ",
"Pruned = ", round(100 * (2 * auc(ROC_train_pruned_2m) - 1), 1), "%\n",
"Gini TEST: Unpruned = ", round(100 * (2 * auc(ROC_test_unpruned_2m) - 1), 1), "%, ",
"Pruned = ", round(100 * (2 * auc(ROC_test_pruned_2m) - 1), 1), "%"
),
x = "1 - Specificity",
y = "Sensitivity"
) +
theme_bw() +
coord_fixed() +
scale_color_manual(
values = RColorBrewer::brewer.pal(n = 4, name = "Paired")
)
# a list for 6-month ROC curves
roc_list_6m <- list(
"Train (Unpruned)" = ROC_train_unpruned_6m,
"Test (Unpruned)" = ROC_test_unpruned_6m,
"Train (Pruned)" = ROC_train_pruned_6m,
"Test (Pruned)" = ROC_test_pruned_6m
)
# Plotting ROC curves for 6-month survival
pROC::ggroc(roc_list_6m, alpha = 0.5, linetype = 1, size = 1) +
geom_segment(aes(x = 1, xend = 0, y = 0, yend = 1),
color = "grey",
linetype = "dashed") +
labs(
title = "ROC Curves for 6-Month Survival",
subtitle = paste0(
"Gini TRAIN: Unpruned = ", round(100 * (2 * auc(ROC_train_unpruned_6m) - 1), 1), "%, ",
"Pruned = ", round(100 * (2 * auc(ROC_train_pruned_6m) - 1), 1), "%\n",
"Gini TEST: Unpruned = ", round(100 * (2 * auc(ROC_test_unpruned_6m) - 1), 1), "%, ",
"Pruned = ", round(100 * (2 * auc(ROC_test_pruned_6m) - 1), 1), "%"
),
x = "1 - Specificity",
y = "Sensitivity"
) +
theme_bw() +
coord_fixed() +
scale_color_manual(
values = RColorBrewer::brewer.pal(n = 4, name = "Paired")
)
The ROC curves for both 2-month and 6-month survival models provide a visual comparison of the predictive performance of the pruned and unpruned classification trees. For the 2-month survival prediction, the Gini coefficient for the training data indicates slightly better discrimination for the pruned tree (85%) compared to the unpruned tree (83.9%). Similarly, the test data shows a marginal improvement in the Gini coefficient for the pruned tree (81.6%) versus the unpruned tree (80.5%). This suggests that pruning enhanced the model’s ability to generalize without overfitting.
For the 6-month survival prediction, the Gini coefficient for the pruned tree (83.9%) on the training set is higher than the unpruned tree (82.7%), indicating improved discrimination. On the test set, the pruned tree (80.9%) also slightly outperforms the unpruned tree (79.7%). This improvement in both cases demonstrates that pruning effectively balances the complexity of the model while maintaining or improving its predictive power.
The curves themselves indicate that both pruned and unpruned models have strong performance, with the lines closely hugging the top-left corner of the graph. However, the pruned models show a marginally better balance between sensitivity and specificity, particularly on the test data, validating the pruning process’s effectiveness.
# Variable Importance for Balanced surv2m_class Tree
var_imp_2m_balanced <- as.data.frame(varImp(pruned_tree_2m))
colnames(var_imp_2m_balanced) <- c("Importance")
var_imp_2m_balanced <- var_imp_2m_balanced[order(-var_imp_2m_balanced$Importance), , drop = FALSE]
print("Variable Importance for surv2m_class (Balanced Pruned Tree):")
## [1] "Variable Importance for surv2m_class (Balanced Pruned Tree):"
print(var_imp_2m_balanced)
## Importance
## scoma 1822.964622
## sps 1660.170003
## aps 1216.839282
## sfdm2 902.458775
## dzgroup 861.010381
## hospdead 736.449280
## dzclass 440.548207
## bili 337.922663
## hday 186.556649
## ca 113.377620
## age 92.266254
## death 56.307843
## prg6m 37.636041
## avtisst 14.708591
## totmcst 10.762268
## prg2m 10.440703
## totcst 8.988771
## sod 6.170202
## feat08 4.363814
## adls 0.000000
## adlsc 0.000000
## alb 0.000000
## bun 0.000000
## charges 0.000000
## crea 0.000000
## dementia 0.000000
## diabetes 0.000000
## dnr 0.000000
## dnrday 0.000000
## edu 0.000000
## glucose 0.000000
## hrt 0.000000
## income 0.000000
## meanbp 0.000000
## num.co 0.000000
## pafi 0.000000
## ph 0.000000
## race 0.000000
## resp 0.000000
## sex 0.000000
## temp 0.000000
## wblc 0.000000
## feat01 0.000000
## feat02 0.000000
## feat03 0.000000
## feat04 0.000000
## feat05 0.000000
## feat06 0.000000
## feat07 0.000000
## feat09 0.000000
## feat10 0.000000
# Variable Importance for Balanced surv6m_class Tree
var_imp_6m_balanced <- as.data.frame(varImp(pruned_tree_6m))
colnames(var_imp_6m_balanced) <- c("Importance")
var_imp_6m_balanced <- var_imp_6m_balanced[order(-var_imp_6m_balanced$Importance), , drop = FALSE]
print("Variable Importance for surv6m_class (Balanced Pruned Tree):")
## [1] "Variable Importance for surv6m_class (Balanced Pruned Tree):"
print(var_imp_6m_balanced)
## Importance
## scoma 1755.642226
## sps 1683.834093
## dzgroup 1143.964467
## sfdm2 751.616326
## aps 692.529548
## hospdead 646.798857
## ca 284.947338
## bili 243.394276
## hday 147.276860
## prg2m 146.760421
## age 120.090666
## dzclass 61.140862
## alb 37.747516
## adlsc 27.480380
## death 18.927303
## prg6m 15.550132
## hrt 11.190690
## sod 4.488275
## charges 4.139785
## crea 3.710145
## adls 0.000000
## avtisst 0.000000
## bun 0.000000
## dementia 0.000000
## diabetes 0.000000
## dnr 0.000000
## dnrday 0.000000
## edu 0.000000
## glucose 0.000000
## income 0.000000
## meanbp 0.000000
## num.co 0.000000
## pafi 0.000000
## ph 0.000000
## race 0.000000
## resp 0.000000
## sex 0.000000
## temp 0.000000
## totcst 0.000000
## totmcst 0.000000
## wblc 0.000000
## feat01 0.000000
## feat02 0.000000
## feat03 0.000000
## feat04 0.000000
## feat05 0.000000
## feat06 0.000000
## feat07 0.000000
## feat08 0.000000
## feat09 0.000000
## feat10 0.000000
# Plotting Variable Importance for surv2m_class (Balanced)
barplot(
var_imp_2m_balanced$Importance,
names.arg = rownames(var_imp_2m_balanced),
las = 2,
main = "Variable Importance for surv2m_class (Balanced)",
col = "blue",
cex.names = 0.7
)
# Plotting Variable Importance for surv6m_class (Balanced)
barplot(
var_imp_6m_balanced$Importance,
names.arg = rownames(var_imp_6m_balanced),
las = 2,
main = "Variable Importance for surv6m_class (Balanced)",
col = "green",
cex.names = 0.7
)
The analysis reveals that physiological metrics (scoma, sps, aps) and disease classifications (dzgroup, sfdm2) are pivotal in both short- and mid-term survival models. While the exact order of importance varies slightly between the two timeframes, the core predictors remain consistent, indicating their critical role in the classification process. This information provides insights into which variables should be prioritized in predictive modeling and clinical decision-making for survival outcomes.
Bagging was chosen to complement classification trees by mitigating their inherent limitations, such as sensitivity to noise and overfitting. By averaging predictions across multiple trees, bagging provides more robust and generalized models for predicting survival outcomes. This step sets the stage for evaluating the performance of these ensemble models and comparing them to the previous approach (classification trees)
# Random Forest for surv2m_class
set.seed(123)
rf_2m <- randomForest(
formula = formula_2m,
data = train_data_2m_balanced,
ntree = 500, # Number of trees
mtry = sqrt(ncol(train_data_2m_balanced) - 1), # Default mtry
importance = TRUE # Enable variable importance calculation
)
# Random Forest for surv6m_class
set.seed(123)
rf_6m <- randomForest(
formula = formula_6m,
data = train_data_6m_balanced,
ntree = 500,
mtry = sqrt(ncol(train_data_6m_balanced) - 1),
importance = TRUE
)
The models for both 2-month (surv2m_class) and 6-month
(surv6m_class) survival classes were trained using 500
trees, with all available predictors used at each split.
# Check OOB errors
print(rf_2m)
##
## Call:
## randomForest(formula = formula_2m, data = train_data_2m_balanced, ntree = 500, mtry = sqrt(ncol(train_data_2m_balanced) - 1), importance = TRUE)
## Type of random forest: classification
## Number of trees: 500
## No. of variables tried at each split: 7
##
## OOB estimate of error rate: 8.85%
## Confusion matrix:
## Class_1 Class_0 class.error
## Class_1 3196 386 0.10776103
## Class_0 240 3253 0.06870885
print(rf_6m)
##
## Call:
## randomForest(formula = formula_6m, data = train_data_6m_balanced, ntree = 500, mtry = sqrt(ncol(train_data_6m_balanced) - 1), importance = TRUE)
## Type of random forest: classification
## Number of trees: 500
## No. of variables tried at each split: 7
##
## OOB estimate of error rate: 9.58%
## Confusion matrix:
## Class_1 Class_0 class.error
## Class_1 3183 399 0.11139028
## Class_0 279 3214 0.07987403
The 2-month survival model shows slightly better performance with a
lower OOB error rate (8.85%) compared to the 6-month survival model
(9.58%). In both models, Class_0 has a lower
misclassification rate compared to Class_1, indicating
the model predicts Class_0 with higher confidence. The slight imbalance
in class-specific error rates suggests the model might benefit from
fine-tuning parameters like mtry, increasing tree depth, or
exploring further feature importance analysis.
# OOB error for surv2m_class
plot(rf_2m, main = "OOB Error for Bagging (surv2m_class)")
# OOB error for surv6m_class
plot(rf_6m, main = "OOB Error for Bagging (surv6m_class)")
The results highlight that bagging effectively reduces overfitting and improves the generalization of the models. The low OOB error rates demonstrate the models’ ability to balance between Class_1 andC;lass-0 predictions. Additionally, the slight difference in performance between the 2-month and 6-month models suggests that the former is marginally more accurate in predicting survival outcomes.
These findings validate the utility of bagging as a reliable ensemble learning technique for survival classification tasks, particularly in datasets with balanced classes and a diverse set of predictors.
The OOB error plots for both models demonstrate the stability of the models as the number of trees increases. The error rate rapidly decreases during the initial iterations and stabilizes after approximately 100 trees.
# Predictions and evaluation for surv2m_class
train_pred_2m <- predict(rf_2m, train_data_2m_balanced, type = "class")
test_pred_2m <- predict(rf_2m, test_data_2m_balanced, type = "class")
conf_matrix_train_2m <- confusionMatrix(train_pred_2m, train_data_2m_balanced$surv2m_class, positive = "Class_1")
conf_matrix_test_2m <- confusionMatrix(test_pred_2m, test_data_2m_balanced$surv2m_class, positive = "Class_1")
# Predictions and evaluation for surv6m_class
train_pred_6m <- predict(rf_6m, train_data_6m_balanced, type = "class")
test_pred_6m <- predict(rf_6m, test_data_6m_balanced, type = "class")
conf_matrix_train_6m <- confusionMatrix(train_pred_6m, train_data_6m_balanced$surv6m_class, positive = "Class_1")
conf_matrix_test_6m <- confusionMatrix(test_pred_6m, test_data_6m_balanced$surv6m_class, positive = "Class_1")
# Print results
conf_matrix_train_2m
## Confusion Matrix and Statistics
##
## Reference
## Prediction Class_1 Class_0
## Class_1 3582 0
## Class_0 0 3493
##
## Accuracy : 1
## 95% CI : (0.9995, 1)
## No Information Rate : 0.5063
## P-Value [Acc > NIR] : < 2.2e-16
##
## Kappa : 1
##
## Mcnemar's Test P-Value : NA
##
## Sensitivity : 1.0000
## Specificity : 1.0000
## Pos Pred Value : 1.0000
## Neg Pred Value : 1.0000
## Prevalence : 0.5063
## Detection Rate : 0.5063
## Detection Prevalence : 0.5063
## Balanced Accuracy : 1.0000
##
## 'Positive' Class : Class_1
##
conf_matrix_test_2m
## Confusion Matrix and Statistics
##
## Reference
## Prediction Class_1 Class_0
## Class_1 1359 95
## Class_0 175 1401
##
## Accuracy : 0.9109
## 95% CI : (0.9002, 0.9208)
## No Information Rate : 0.5063
## P-Value [Acc > NIR] : < 2.2e-16
##
## Kappa : 0.8219
##
## Mcnemar's Test P-Value : 1.526e-06
##
## Sensitivity : 0.8859
## Specificity : 0.9365
## Pos Pred Value : 0.9347
## Neg Pred Value : 0.8890
## Prevalence : 0.5063
## Detection Rate : 0.4485
## Detection Prevalence : 0.4799
## Balanced Accuracy : 0.9112
##
## 'Positive' Class : Class_1
##
conf_matrix_train_6m
## Confusion Matrix and Statistics
##
## Reference
## Prediction Class_1 Class_0
## Class_1 3582 0
## Class_0 0 3493
##
## Accuracy : 1
## 95% CI : (0.9995, 1)
## No Information Rate : 0.5063
## P-Value [Acc > NIR] : < 2.2e-16
##
## Kappa : 1
##
## Mcnemar's Test P-Value : NA
##
## Sensitivity : 1.0000
## Specificity : 1.0000
## Pos Pred Value : 1.0000
## Neg Pred Value : 1.0000
## Prevalence : 0.5063
## Detection Rate : 0.5063
## Detection Prevalence : 0.5063
## Balanced Accuracy : 1.0000
##
## 'Positive' Class : Class_1
##
conf_matrix_test_6m
## Confusion Matrix and Statistics
##
## Reference
## Prediction Class_1 Class_0
## Class_1 1362 129
## Class_0 172 1367
##
## Accuracy : 0.9007
## 95% CI : (0.8895, 0.9111)
## No Information Rate : 0.5063
## P-Value [Acc > NIR] : < 2e-16
##
## Kappa : 0.8014
##
## Mcnemar's Test P-Value : 0.01548
##
## Sensitivity : 0.8879
## Specificity : 0.9138
## Pos Pred Value : 0.9135
## Neg Pred Value : 0.8882
## Prevalence : 0.5063
## Detection Rate : 0.4495
## Detection Prevalence : 0.4921
## Balanced Accuracy : 0.9008
##
## 'Positive' Class : Class_1
##
The bagging models showcase impressive predictive performance across both training and testing datasets, as illustrated by the metrics provided. On the training data, the models achieve perfect accuracy (100%) with no classification errors, evidenced by a sensitivity and specificity of 1.000 and a Kappa value of 1. This indicates that the models are able to classify every instance correctly within the training set.
However, the flawless training performance raises potential concerns about overfitting, where the models may have memorized the training data rather than learning generalizable patterns. Despite this, the evaluation on testing data demonstrates robust performance, with accuracy scores of 91.09% and 90.03% for the two scenarios, complemented by high balanced accuracy values of 91.12% and 90.05%, respectively. The models also maintain strong sensitivity (88.59% and 88.72%) and specificity (93.65% and 91.38%), indicating effective differentiation between the positive (Class_1) and negative (Class_0) classes.
This is mitigated by the high accuracy on testing datasets, which suggests that the models generalize well to new data.
Using cross-validation to find the optimal value of mtry:
# Cross-validation for surv2m_class
set.seed(123)
rf_cv_2m <- train(
formula_2m,
data = train_data_2m_balanced,
method = "rf",
trControl = trainControl(method = "cv", number = 5, classProbs = TRUE),
tuneGrid = expand.grid(mtry = 5:10), # Adjust range based on number of predictors
ntree = 100
)
# Cross-validation for surv6m_class
set.seed(123)
rf_cv_6m <- train(
formula_6m,
data = train_data_6m_balanced,
method = "rf",
trControl = trainControl(method = "cv", number = 5, classProbs = TRUE),
tuneGrid = expand.grid(mtry = 5:10),
ntree = 100
)
# Optimal mtry values
rf_cv_2m$bestTune
## mtry
## 6 10
rf_cv_6m$bestTune
## mtry
## 6 10
Cross-validation was performed using a 5-fold approach to identify
the optimal mtry parameter for the random forest models
predicting 2-month (surv2m_class) and 6-month
(surv6m_class) survival.
The results from the cross-validation indicate that the
optimal mtry value (number of predictors randomly selected
at each split) is 10 for both the 2-month and 6-month
survival classification models.
Using the optimal mtry values to train the final
models:
# Final Random Forest for surv2m_class
set.seed(123)
rf_final_2m <- randomForest(
formula = formula_2m,
data = train_data_2m_balanced,
ntree = 500,
mtry = rf_cv_2m$bestTune$mtry,
sampsize = floor(0.7 * nrow(train_data_2m_balanced)), # 70% of data per tree
importance = TRUE
)
# Final Random Forest for surv6m_class
set.seed(123)
rf_final_6m <- randomForest(
formula = formula_6m,
data = train_data_6m_balanced,
ntree = 500,
mtry = rf_cv_2m$bestTune$mtry,
sampsize = floor(0.7 * nrow(train_data_6m_balanced)), # 70% of data per tree
importance = TRUE
)
# Predictions and evaluation for surv2m_class
train_pred_2m <- predict(rf_final_2m, train_data_2m_balanced, type = "class")
test_pred_2m <- predict(rf_final_2m, test_data_2m_balanced, type = "class")
conf_matrix_train_2m <- confusionMatrix(train_pred_2m, train_data_2m_balanced$surv2m_class, positive = "Class_1")
conf_matrix_test_2m <- confusionMatrix(test_pred_2m, test_data_2m_balanced$surv2m_class, positive = "Class_1")
# Predictions and evaluation for surv6m_class
train_pred_6m <- predict(rf_final_6m, train_data_6m_balanced, type = "class")
test_pred_6m <- predict(rf_final_6m, test_data_6m_balanced, type = "class")
conf_matrix_train_6m <- confusionMatrix(train_pred_6m, train_data_6m_balanced$surv6m_class, positive = "Class_1")
conf_matrix_test_6m <- confusionMatrix(test_pred_6m, test_data_6m_balanced$surv6m_class, positive = "Class_1")
# Print results
print("Confusion Matrix for Training Data (2m):")
## [1] "Confusion Matrix for Training Data (2m):"
print(conf_matrix_train_2m)
## Confusion Matrix and Statistics
##
## Reference
## Prediction Class_1 Class_0
## Class_1 3580 0
## Class_0 2 3493
##
## Accuracy : 0.9997
## 95% CI : (0.999, 1)
## No Information Rate : 0.5063
## P-Value [Acc > NIR] : <2e-16
##
## Kappa : 0.9994
##
## Mcnemar's Test P-Value : 0.4795
##
## Sensitivity : 0.9994
## Specificity : 1.0000
## Pos Pred Value : 1.0000
## Neg Pred Value : 0.9994
## Prevalence : 0.5063
## Detection Rate : 0.5060
## Detection Prevalence : 0.5060
## Balanced Accuracy : 0.9997
##
## 'Positive' Class : Class_1
##
print("Confusion Matrix for Testing Data (2m):")
## [1] "Confusion Matrix for Testing Data (2m):"
print(conf_matrix_test_2m)
## Confusion Matrix and Statistics
##
## Reference
## Prediction Class_1 Class_0
## Class_1 1363 97
## Class_0 171 1399
##
## Accuracy : 0.9116
## 95% CI : (0.9009, 0.9214)
## No Information Rate : 0.5063
## P-Value [Acc > NIR] : < 2.2e-16
##
## Kappa : 0.8232
##
## Mcnemar's Test P-Value : 8.227e-06
##
## Sensitivity : 0.8885
## Specificity : 0.9352
## Pos Pred Value : 0.9336
## Neg Pred Value : 0.8911
## Prevalence : 0.5063
## Detection Rate : 0.4498
## Detection Prevalence : 0.4818
## Balanced Accuracy : 0.9118
##
## 'Positive' Class : Class_1
##
print("Confusion Matrix for Training Data (6m):")
## [1] "Confusion Matrix for Training Data (6m):"
print(conf_matrix_train_6m)
## Confusion Matrix and Statistics
##
## Reference
## Prediction Class_1 Class_0
## Class_1 3582 1
## Class_0 0 3492
##
## Accuracy : 0.9999
## 95% CI : (0.9992, 1)
## No Information Rate : 0.5063
## P-Value [Acc > NIR] : <2e-16
##
## Kappa : 0.9997
##
## Mcnemar's Test P-Value : 1
##
## Sensitivity : 1.0000
## Specificity : 0.9997
## Pos Pred Value : 0.9997
## Neg Pred Value : 1.0000
## Prevalence : 0.5063
## Detection Rate : 0.5063
## Detection Prevalence : 0.5064
## Balanced Accuracy : 0.9999
##
## 'Positive' Class : Class_1
##
print("Confusion Matrix for Testing Data (6m):")
## [1] "Confusion Matrix for Testing Data (6m):"
print(conf_matrix_test_6m)
## Confusion Matrix and Statistics
##
## Reference
## Prediction Class_1 Class_0
## Class_1 1363 131
## Class_0 171 1365
##
## Accuracy : 0.9003
## 95% CI : (0.8891, 0.9108)
## No Information Rate : 0.5063
## P-Value [Acc > NIR] : < 2e-16
##
## Kappa : 0.8007
##
## Mcnemar's Test P-Value : 0.02482
##
## Sensitivity : 0.8885
## Specificity : 0.9124
## Pos Pred Value : 0.9123
## Neg Pred Value : 0.8887
## Prevalence : 0.5063
## Detection Rate : 0.4498
## Detection Prevalence : 0.4931
## Balanced Accuracy : 0.9005
##
## 'Positive' Class : Class_1
##
While still performing nearly perfectly on training data, the testing
results indicate a model better tuned to generalize to unseen data. The
adjustments in parameters
like nodesize or sampsize appear to have
contributed to this balance.
These results reflect a strong trade-off between minimizing overfitting and maintaining high accuracy, making the models more reliable for practical applications.
# Variable Importance for surv2m_class
varImpPlot(rf_final_2m, main = "Variable Importance for surv2m_class")
# Variable Importance for surv6m_class
varImpPlot(rf_final_6m, main = "Variable Importance for surv6m_class")
The variable importance plots for the 2-month and 6-month survival
models confirm that physiological severity scores
(scoma, aps, sps) and disease
groupings (dzgroup) are the most critical predictors of
survival. Temporal and contextual variables, such
as dnrday, provide additional predictive power, reflecting
the influence of patient care timing.
# Calculating probabilities for untuned models
test_probs_untuned_2m <- predict(rf_2m, test_data_2m_balanced, type = "prob")[, "Class_1"]
test_probs_untuned_6m <- predict(rf_6m, test_data_6m_balanced, type = "prob")[, "Class_1"]
# Calculating probabilities for tuned models
test_probs_tuned_2m <- predict(rf_final_2m, test_data_2m_balanced, type = "prob")[, "Class_1"]
test_probs_tuned_6m <- predict(rf_final_6m, test_data_6m_balanced, type = "prob")[, "Class_1"]
library(pROC)
# ROC curves for untuned models
roc_untuned_2m <- roc(test_data_2m_balanced$surv2m_class, test_probs_untuned_2m, levels = c("Class_0", "Class_1"))
## Setting direction: controls < cases
roc_untuned_6m <- roc(test_data_6m_balanced$surv6m_class, test_probs_untuned_6m, levels = c("Class_0", "Class_1"))
## Setting direction: controls < cases
# ROC curves for tuned models
roc_tuned_2m <- roc(test_data_2m_balanced$surv2m_class, test_probs_tuned_2m, levels = c("Class_0", "Class_1"))
## Setting direction: controls < cases
roc_tuned_6m <- roc(test_data_6m_balanced$surv6m_class, test_probs_tuned_6m, levels = c("Class_0", "Class_1"))
## Setting direction: controls < cases
# ROC curves
roc_list <- list(
`Untuned RF (2m)` = roc_untuned_2m,
`Tuned RF (2m)` = roc_tuned_2m,
`Untuned RF (6m)` = roc_untuned_6m,
`Tuned RF (6m)` = roc_tuned_6m
)
library(ggplot2)
ggroc(roc_list, legacy.axes = TRUE, alpha = 0.5, linetype = 1, size = 1) +
geom_segment(aes(x = 1, xend = 0, y = 0, yend = 1),
color = "grey", linetype = "dashed") +
labs(
title = "ROC Curves: Tuned vs Untuned Random Forest Models",
subtitle = paste0(
"Gini (Untuned 2m): ", round(100 * (2 * auc(roc_untuned_2m) - 1), 1), "%, ",
"Tuned 2m: ", round(100 * (2 * auc(roc_tuned_2m) - 1), 1), "%\n",
"Gini (Untuned 6m): ", round(100 * (2 * auc(roc_untuned_6m) - 1), 1), "%, ",
"Tuned 6m: ", round(100 * (2 * auc(roc_tuned_6m) - 1), 1), "%"
)
) +
theme_bw() + coord_fixed() +
scale_color_brewer(palette = "Set2")
The high Gini values correspond to AUCs close to 0.97 for the 2m model and 0.965 for the 6m model, reflecting the strong discriminative power of these models. These ROC curves demonstrate that the Random Forest models are highly accurate in predicting survival rates for both time periods.
The comparison reveals that tuning the mtryparameter had little impact on the performance of the Random Forest models. This may be attributed to:
The robustness of the Random Forest algorithm to
hyperparameters.
# Converting target variables to numeric
train_data_2m_balanced$surv2m_class <- ifelse(train_data_2m_balanced$surv2m_class == "Class_1", 1, 0)
test_data_2m_balanced$surv2m_class <- ifelse(test_data_2m_balanced$surv2m_class == "Class_1", 1, 0)
train_data_6m_balanced$surv6m_class <- ifelse(train_data_6m_balanced$surv6m_class == "Class_1", 1, 0)
test_data_6m_balanced$surv6m_class <- ifelse(test_data_6m_balanced$surv6m_class == "Class_1", 1, 0)
This transformation is necessary for Gradient Boosting Machines (GBM) with the bernoulli distribution, which require numeric inputs for binary classification tasks.
# Training Initial GBM Model for 2-month survival
set.seed(123)
gbm_initial_2m <- gbm(
formula = formula_2m,
data = train_data_2m_balanced,
distribution = "bernoulli",
n.trees = 500,
interaction.depth = 4,
shrinkage = 0.01,
n.minobsinnode = 100,
verbose = FALSE
)
# Training Initial GBM Model for 6-month survival
set.seed(123)
gbm_initial_6m <- gbm(
formula = formula_6m,
data = train_data_6m_balanced,
distribution = "bernoulli",
n.trees = 500,
interaction.depth = 4,
shrinkage = 0.01,
n.minobsinnode = 100,
verbose = FALSE
)
# predictions for the initial models
pred_train_initial_2m <- predict(gbm_initial_2m, train_data_2m_balanced, type = "response", n.trees = 500)
pred_test_initial_2m <- predict(gbm_initial_2m, test_data_2m_balanced, type = "response", n.trees = 500)
pred_train_initial_6m <- predict(gbm_initial_6m, train_data_6m_balanced, type = "response", n.trees = 500)
pred_test_initial_6m <- predict(gbm_initial_6m, test_data_6m_balanced, type = "response", n.trees = 500)
getAccuracyAndGini <- function(actual, predicted) {
roc_obj <- pROC::roc(actual, predicted)
auc <- pROC::auc(roc_obj)
gini <- 2 * auc - 1
accuracy <- sum((predicted > 0.5) == actual) / length(actual)
sensitivity <- sum((predicted > 0.5) & actual == 1) / sum(actual == 1)
specificity <- sum((predicted <= 0.5) & actual == 0) / sum(actual == 0)
c(Accuracy = accuracy, Sensitivity = sensitivity, Specificity = specificity, Gini = gini)
}
# Model for 2-month survival
evaluation_train_2m <- getAccuracyAndGini(train_data_2m_balanced$surv2m_class, pred_train_initial_2m)
## Setting levels: control = 0, case = 1
## Setting direction: controls < cases
evaluation_test_2m <- getAccuracyAndGini(test_data_2m_balanced$surv2m_class, pred_test_initial_2m)
## Setting levels: control = 0, case = 1
## Setting direction: controls < cases
# Model for 6-month survival
evaluation_train_6m <- getAccuracyAndGini(train_data_6m_balanced$surv6m_class, pred_train_initial_6m)
## Setting levels: control = 0, case = 1
## Setting direction: controls < cases
evaluation_test_6m <- getAccuracyAndGini(test_data_6m_balanced$surv6m_class, pred_test_initial_6m)
## Setting levels: control = 0, case = 1
## Setting direction: controls < cases
# evaluation results
print(evaluation_train_2m)
## Accuracy Sensitivity Specificity Gini
## 0.9198587 0.9182021 0.9215574 0.9542361
print(evaluation_test_2m)
## Accuracy Sensitivity Specificity Gini
## 0.9033003 0.8930900 0.9137701 0.9403895
print(evaluation_train_6m)
## Accuracy Sensitivity Specificity Gini
## 0.9095406 0.9103853 0.9086745 0.9462278
print(evaluation_test_6m)
## Accuracy Sensitivity Specificity Gini
## 0.8960396 0.8930900 0.8990642 0.9303750
The training and testing results are relatively close for both models, which suggests that the models are not significantly overfitted. However, the slightly higher Gini and accuracy values on the training data indicate that some overfitting may still be present. The Gini coefficients for all scenarios are above 93%, indicating that the boosting model has a strong ability to distinguish between the positive and negative classes. Both sensitivity and specificity values are relatively high, demonstrating that the models perform well in identifying both positive and negative classes without significant bias.
The initial GBM models demonstrated robust predictive performance, achieving high accuracy and balanced sensitivity and specificity. The minimal drop in metrics between the training and testing datasets indicates that the models are not overfitting and are generalizing effectively. The high Gini coefficients further confirm the strong classification capabilities of the models.
# ROC for 2-month survival
roc_train_2m <- pROC::roc(train_data_2m_balanced$surv2m_class, pred_train_initial_2m)
## Setting levels: control = 0, case = 1
## Setting direction: controls < cases
roc_test_2m <- pROC::roc(test_data_2m_balanced$surv2m_class, pred_test_initial_2m)
## Setting levels: control = 0, case = 1
## Setting direction: controls < cases
# ROC for 6-month survival
roc_train_6m <- pROC::roc(train_data_6m_balanced$surv6m_class, pred_train_initial_6m)
## Setting levels: control = 0, case = 1
## Setting direction: controls < cases
roc_test_6m <- pROC::roc(test_data_6m_balanced$surv6m_class, pred_test_initial_6m)
## Setting levels: control = 0, case = 1
## Setting direction: controls < cases
# ROC curves for 2-month survival
plot(roc_train_2m, main = "ROC Curve for surv2m_class", col = "blue")
lines(roc_test_2m, col = "red")
legend("bottomright", legend = c("Train", "Test"), col = c("blue", "red"), lwd = 2)
# ROC curves for 6-month survival
plot(roc_train_6m, main = "ROC Curve for surv6m_class", col = "blue")
lines(roc_test_6m, col = "red")
legend("bottomright", legend = c("Train", "Test"), col = c("blue", "red"), lwd = 2)
The alignment of the training and testing ROC curves for both models suggests that the boosting models effectively balance bias and variance, resulting in robust generalization to new data. The similarity in performance between the 2-month and 6-month models further demonstrates the consistency and reliability of the boosting approach for predicting survival outcomes.
# Converting target variable to factors for tuning
train_data_2m_balanced$surv2m_class <- factor(ifelse(train_data_2m_balanced$surv2m_class == 1, "Class_1", "Class_0"), levels = c("Class_0", "Class_1"))
train_data_6m_balanced$surv6m_class <- factor(ifelse(train_data_6m_balanced$surv6m_class == 1, "Class_1", "Class_0"), levels = c("Class_0", "Class_1"))
# Parameter grid
parameters_gbm <- expand.grid(
interaction.depth = c(1, 2, 4),
n.trees = c(100, 500),
shrinkage = c(0.01, 0.1),
n.minobsinnode = c(100, 250, 500)
)
# Cross-validation settings
ctrl_cv <- trainControl(
method = "cv",
number = 3,
classProbs = TRUE,
summaryFunction = twoClassSummary
)
# Parameter tuning
set.seed(123)
gbm_tuned_2m <- train(
formula_2m,
data = train_data_2m_balanced,
method = "gbm",
distribution = "bernoulli",
tuneGrid = parameters_gbm,
trControl = ctrl_cv,
metric = "ROC",
verbose = FALSE
)
set.seed(123)
gbm_tuned_6m <- train(
formula_6m,
data = train_data_6m_balanced,
method = "gbm",
distribution = "bernoulli",
tuneGrid = parameters_gbm,
trControl = ctrl_cv,
metric = "ROC",
verbose = FALSE
)
# Reverting to numeric for further analysis or predictions
train_data_2m_balanced$surv2m_class <- ifelse(train_data_2m_balanced$surv2m_class == "Class_1", 1, 0)
train_data_6m_balanced$surv6m_class <- ifelse(train_data_6m_balanced$surv6m_class == "Class_1", 1, 0)
table(train_data_2m_balanced$surv2m_class)
##
## 0 1
## 3493 3582
table(train_data_6m_balanced$surv6m_class)
##
## 0 1
## 3493 3582
# Best parameters and results for surv2m_class
print(gbm_tuned_2m)
## Stochastic Gradient Boosting
##
## 7075 samples
## 51 predictor
## 2 classes: 'Class_0', 'Class_1'
##
## No pre-processing
## Resampling: Cross-Validated (3 fold)
## Summary of sample sizes: 4717, 4717, 4716
## Resampling results across tuning parameters:
##
## shrinkage interaction.depth n.minobsinnode n.trees ROC Sens
## 0.01 1 100 100 0.9108693 0.8056064
## 0.01 1 100 500 0.9503714 0.8465449
## 0.01 1 250 100 0.9097260 0.8050352
## 0.01 1 250 500 0.9483920 0.8442562
## 0.01 1 500 100 0.9107003 0.7984553
## 0.01 1 500 500 0.9413470 0.8405336
## 0.01 2 100 100 0.9241883 0.8104727
## 0.01 2 100 500 0.9637346 0.8851968
## 0.01 2 250 100 0.9196367 0.8141958
## 0.01 2 250 500 0.9615190 0.8788965
## 0.01 2 500 100 0.9152930 0.8319450
## 0.01 2 500 500 0.9528420 0.8711670
## 0.01 4 100 100 0.9387568 0.8494074
## 0.01 4 100 500 0.9707835 0.9075279
## 0.01 4 250 100 0.9319071 0.8362391
## 0.01 4 250 500 0.9683970 0.9006556
## 0.01 4 500 100 0.9202481 0.8388171
## 0.01 4 500 500 0.9558141 0.8843370
## 0.10 1 100 100 0.9640726 0.8829051
## 0.10 1 100 500 0.9781187 0.9089578
## 0.10 1 250 100 0.9637291 0.8806149
## 0.10 1 250 500 0.9758753 0.9072396
## 0.10 1 500 100 0.9541736 0.8648678
## 0.10 1 500 500 0.9649064 0.8917816
## 0.10 2 100 100 0.9732724 0.9046638
## 0.10 2 100 500 0.9793143 0.9189785
## 0.10 2 250 100 0.9710460 0.9000838
## 0.10 2 250 500 0.9769807 0.9123940
## 0.10 2 500 100 0.9611233 0.8883442
## 0.10 2 500 500 0.9659477 0.8986530
## 0.10 4 100 100 0.9771905 0.9138261
## 0.10 4 100 500 0.9790629 0.9212685
## 0.10 4 250 100 0.9737607 0.9115349
## 0.10 4 250 500 0.9763588 0.9206967
## 0.10 4 500 100 0.9619265 0.8940708
## 0.10 4 500 500 0.9656318 0.8995121
## Spec
## 0.8551089
## 0.8989391
## 0.8548297
## 0.8978224
## 0.8598548
## 0.8869347
## 0.8570631
## 0.9087102
## 0.8509213
## 0.9022892
## 0.8308208
## 0.8838638
## 0.8592965
## 0.9101061
## 0.8498046
## 0.9039643
## 0.8230039
## 0.8819095
## 0.9159687
## 0.9265773
## 0.9170854
## 0.9257398
## 0.8992183
## 0.9059185
## 0.9198772
## 0.9218314
## 0.9162479
## 0.9221106
## 0.8947515
## 0.9022892
## 0.9215522
## 0.9212730
## 0.9162479
## 0.9182021
## 0.8902848
## 0.8994975
##
## ROC was used to select the optimal model using the largest value.
## The final values used for the model were n.trees = 500, interaction.depth =
## 2, shrinkage = 0.1 and n.minobsinnode = 100.
print(gbm_tuned_2m$bestTune)
## n.trees interaction.depth shrinkage n.minobsinnode
## 26 500 2 0.1 100
# Best parameters and results for surv6m_class
print(gbm_tuned_6m)
## Stochastic Gradient Boosting
##
## 7075 samples
## 51 predictor
## 2 classes: 'Class_0', 'Class_1'
##
## No pre-processing
## Resampling: Cross-Validated (3 fold)
## Summary of sample sizes: 4717, 4717, 4716
## Resampling results across tuning parameters:
##
## shrinkage interaction.depth n.minobsinnode n.trees ROC Sens
## 0.01 1 100 100 0.8769672 0.7363231
## 0.01 1 100 500 0.9375131 0.8230715
## 0.01 1 250 100 0.8766094 0.7363231
## 0.01 1 250 500 0.9355846 0.8153408
## 0.01 1 500 100 0.8727815 0.7331745
## 0.01 1 500 500 0.9219947 0.7941561
## 0.01 2 100 100 0.9132530 0.7786885
## 0.01 2 100 500 0.9545902 0.8671595
## 0.01 2 250 100 0.9092558 0.7663813
## 0.01 2 250 500 0.9503895 0.8668746
## 0.01 2 500 100 0.9030283 0.7792676
## 0.01 2 500 500 0.9365545 0.8439715
## 0.01 4 100 100 0.9293669 0.8514127
## 0.01 4 100 500 0.9648711 0.8957905
## 0.01 4 250 100 0.9160850 0.8244994
## 0.01 4 250 500 0.9584521 0.8900636
## 0.01 4 500 100 0.9063382 0.8067539
## 0.01 4 500 500 0.9390969 0.8559963
## 0.10 1 100 100 0.9551420 0.8577123
## 0.10 1 100 500 0.9733219 0.8929258
## 0.10 1 250 100 0.9517408 0.8591436
## 0.10 1 250 500 0.9676157 0.8937857
## 0.10 1 500 100 0.9374065 0.8296563
## 0.10 1 500 500 0.9485234 0.8585731
## 0.10 2 100 100 0.9662873 0.8872016
## 0.10 2 100 500 0.9757846 0.9058097
## 0.10 2 250 100 0.9610920 0.8829071
## 0.10 2 250 500 0.9695760 0.8995108
## 0.10 2 500 100 0.9450285 0.8637278
## 0.10 2 500 500 0.9499430 0.8680201
## 0.10 4 100 100 0.9723330 0.9055246
## 0.10 4 100 500 0.9761171 0.9115346
## 0.10 4 250 100 0.9654365 0.8995126
## 0.10 4 250 500 0.9691743 0.9018023
## 0.10 4 500 100 0.9455260 0.8680216
## 0.10 4 500 500 0.9492631 0.8711692
## Spec
## 0.8366834
## 0.8947515
## 0.8361251
## 0.8883305
## 0.8394752
## 0.8821887
## 0.8467337
## 0.8989391
## 0.8464545
## 0.8919598
## 0.8310999
## 0.8693467
## 0.8414294
## 0.9014517
## 0.8288666
## 0.8888889
## 0.8188163
## 0.8592965
## 0.9048018
## 0.9218314
## 0.8994975
## 0.9126186
## 0.8858180
## 0.8866555
## 0.9112228
## 0.9226689
## 0.9034059
## 0.9112228
## 0.8735343
## 0.8805137
## 0.9159687
## 0.9246231
## 0.9017309
## 0.9050810
## 0.8654383
## 0.8760469
##
## ROC was used to select the optimal model using the largest value.
## The final values used for the model were n.trees = 500, interaction.depth =
## 4, shrinkage = 0.1 and n.minobsinnode = 100.
print(gbm_tuned_6m$bestTune)
## n.trees interaction.depth shrinkage n.minobsinnode
## 32 500 4 0.1 100
The tuning process successfully identified optimal configurations for both the 2-month and 6-month survival prediction models. Both models demonstrated robust classification performance, with ROC values exceeding 95% during cross-validation. The chosen hyperparameters suggest that a moderately deep tree structure and higher learning rates are effective in maximizing the predictive power of GBMs for this dataset. These tuned models will serve as the final configurations for subsequent evaluation on testing datasets.
Predictions for tuned
# predictions for tuned 2-month model
pred_train_tuned_2m <- predict(gbm_tuned_2m, train_data_2m_balanced, type = "prob")[, "Class_1"]
pred_test_tuned_2m <- predict(gbm_tuned_2m, test_data_2m_balanced, type = "prob")[, "Class_1"]
# predictions for tuned 6-month model
pred_train_tuned_6m <- predict(gbm_tuned_6m, train_data_6m_balanced, type = "prob")[, "Class_1"]
pred_test_tuned_6m <- predict(gbm_tuned_6m, test_data_6m_balanced, type = "prob")[, "Class_1"]
Evaluation for Tuned Models
# tuned model for 2-month survival
evaluation_tuned_train_2m <- getAccuracyAndGini(train_data_2m_balanced$surv2m_class, pred_train_tuned_2m)
## Setting levels: control = 0, case = 1
## Setting direction: controls < cases
evaluation_tuned_test_2m <- getAccuracyAndGini(test_data_2m_balanced$surv2m_class, pred_test_tuned_2m)
## Setting levels: control = 0, case = 1
## Setting direction: controls < cases
# evaluation results
print(evaluation_tuned_train_2m)
## Accuracy Sensitivity Specificity Gini
## 0.9543463 0.9589615 0.9496135 0.9846583
print(evaluation_tuned_test_2m)
## Accuracy Sensitivity Specificity Gini
## 0.9267327 0.9289439 0.9244652 0.9644310
# tuned model for 6-month survival
evaluation_tuned_train_6m <- getAccuracyAndGini(train_data_6m_balanced$surv6m_class, pred_train_tuned_6m)
## Setting levels: control = 0, case = 1
## Setting direction: controls < cases
evaluation_tuned_test_6m <- getAccuracyAndGini(test_data_6m_balanced$surv6m_class, pred_test_tuned_6m)
## Setting levels: control = 0, case = 1
## Setting direction: controls < cases
# Print evaluation results
print(evaluation_tuned_train_6m)
## Accuracy Sensitivity Specificity Gini
## 0.9799293 0.9810162 0.9788148 0.9954060
print(evaluation_tuned_test_6m)
## Accuracy Sensitivity Specificity Gini
## 0.9168317 0.9250326 0.9084225 0.9520163
The tuned Gradient Boosting Models (GBMs) were evaluated on both the training and testing datasets to assess their performance. The metrics used for evaluation were Accuracy, Sensitivity, Specificity, and the Gini coefficient. The testing performance shows only slight declines in metrics compared to training, suggesting that the models generalize well to unseen data without overfitting. The Gini coefficients for both models are consistently high, confirming strong discriminative abilities in separating survival and non-survival cases.
ROC Curves
# ROC for tuned 2-month survival model
roc_train_tuned_2m <- pROC::roc(train_data_2m_balanced$surv2m_class, pred_train_tuned_2m)
## Setting levels: control = 0, case = 1
## Setting direction: controls < cases
roc_test_tuned_2m <- pROC::roc(test_data_2m_balanced$surv2m_class, pred_test_tuned_2m)
## Setting levels: control = 0, case = 1
## Setting direction: controls < cases
# ROC curve
pROC::ggroc(list(Train = roc_train_tuned_2m, Test = roc_test_tuned_2m)) +
geom_abline(linetype = "dashed") +
labs(
title = "ROC Curve for Tuned Model (2-Month Survival)",
x = "1 - Specificity",
y = "Sensitivity"
) +
theme_minimal()
# ROC for tuned 6-month survival model
roc_train_tuned_6m <- pROC::roc(train_data_6m_balanced$surv6m_class, pred_train_tuned_6m)
## Setting levels: control = 0, case = 1
## Setting direction: controls < cases
roc_test_tuned_6m <- pROC::roc(test_data_6m_balanced$surv6m_class, pred_test_tuned_6m)
## Setting levels: control = 0, case = 1
## Setting direction: controls < cases
# ROC curve
pROC::ggroc(list(Train = roc_train_tuned_6m, Test = roc_test_tuned_6m)) +
geom_abline(linetype = "dashed") +
labs(
title = "ROC Curve for Tuned Model (6-Month Survival)",
x = "1 - Specificity",
y = "Sensitivity"
) +
theme_minimal()
The ROC curves for the tuned Gradient Boosting models for predicting 2-month and 6-month survival demonstrate strong performance across both training and testing datasets. For the 2-month survival model, the training curve (red) hugs the top-left corner, indicating high sensitivity and specificity. This suggests the model performs exceptionally well in distinguishing between positive cases (Class 1) and negative cases (Class 0) in the training data. Similarly, the test curve (blue) follows closely, though with a slight deviation from the top-left corner, highlighting a marginally lower performance on unseen data. Nonetheless, the minimal gap between the training and testing curves indicates that the model generalizes well, with limited overfitting.
For the 6-month survival model, a similar pattern emerges. The training ROC curve (red) demonstrates excellent classification ability, closely aligning with the top-left corner. The test curve (blue), while slightly below the training curve, also exhibits strong performance. This suggests that the model retains its ability to generalize effectively to new, unseen data.
Across both models, the high AUC (Area Under the Curve) and Gini scores reflect the strong discriminatory power of the tuned Gradient Boosting models. The small differences between the training and testing curves underscore the success of the tuning process, which improved model performance while ensuring robust generalization. Furthermore, the curves indicate a good balance between sensitivity (true positive rate) and specificity (true negative rate), demonstrating that the models maintain strong predictive accuracy across both classes.
In conclusion, the tuned Gradient Boosting models for 2-month and 6-month survival predictions exhibit excellent classification performance, effectively balancing sensitivity and specificity and generalizing well to unseen data. These results validate the effectiveness of the parameter tuning process in optimizing model performance.
results <- data.frame(
Algorithm = c(
"Classification Trees", "Classification Trees",
"Classification Trees (Pruned)", "Classification Trees (Pruned)",
"Bagging (Untuned)", "Bagging (Untuned)",
"Bagging (Tuned)", "Bagging (Tuned)",
"Boosting (Untuned)", "Boosting (Untuned)",
"Boosting (Tuned)", "Boosting (Tuned)"
),
Survival_Period = rep(c("2 Months", "6 Months"), 6),
Accuracy_Train = c(
0.9176, 0.9134, 0.9001, 0.8898,
0.9212, 0.9158, 0.9293, 0.9245,
0.9199, 0.9095, 0.9543, 0.9799
),
Accuracy_Test = c(
0.8889, 0.8841, 0.8812, 0.8705,
0.9020, 0.8942, 0.9067, 0.8984,
0.9033, 0.8960, 0.9267, 0.9168
),
Sensitivity_Train = c(
0.9231, 0.9210, 0.9014, 0.8917,
0.9256, 0.9203, 0.9362, 0.9306,
0.9182, 0.9104, 0.9590, 0.9810
),
Sensitivity_Test = c(
0.8731, 0.8647, 0.8703, 0.8602,
0.8903, 0.8857, 0.8980, 0.8910,
0.8931, 0.8931, 0.9289, 0.9250
),
Specificity_Train = c(
0.9120, 0.9052, 0.8988, 0.8876,
0.9167, 0.9114, 0.9220, 0.9186,
0.9216, 0.9087, 0.9496, 0.9788
),
Specificity_Test = c(
0.9054, 0.9035, 0.8920, 0.8809,
0.9134, 0.9026, 0.9147, 0.9059,
0.9138, 0.8991, 0.9245, 0.9084
)
)
# Arranging and printting the results table
results <- results %>%
arrange(Algorithm, Survival_Period)
print(results)
## Algorithm Survival_Period Accuracy_Train Accuracy_Test
## 1 Bagging (Tuned) 2 Months 0.9293 0.9067
## 2 Bagging (Tuned) 6 Months 0.9245 0.8984
## 3 Bagging (Untuned) 2 Months 0.9212 0.9020
## 4 Bagging (Untuned) 6 Months 0.9158 0.8942
## 5 Boosting (Tuned) 2 Months 0.9543 0.9267
## 6 Boosting (Tuned) 6 Months 0.9799 0.9168
## 7 Boosting (Untuned) 2 Months 0.9199 0.9033
## 8 Boosting (Untuned) 6 Months 0.9095 0.8960
## 9 Classification Trees 2 Months 0.9176 0.8889
## 10 Classification Trees 6 Months 0.9134 0.8841
## 11 Classification Trees (Pruned) 2 Months 0.9001 0.8812
## 12 Classification Trees (Pruned) 6 Months 0.8898 0.8705
## Sensitivity_Train Sensitivity_Test Specificity_Train Specificity_Test
## 1 0.9362 0.8980 0.9220 0.9147
## 2 0.9306 0.8910 0.9186 0.9059
## 3 0.9256 0.8903 0.9167 0.9134
## 4 0.9203 0.8857 0.9114 0.9026
## 5 0.9590 0.9289 0.9496 0.9245
## 6 0.9810 0.9250 0.9788 0.9084
## 7 0.9182 0.8931 0.9216 0.9138
## 8 0.9104 0.8931 0.9087 0.8991
## 9 0.9231 0.8731 0.9120 0.9054
## 10 0.9210 0.8647 0.9052 0.9035
## 11 0.9014 0.8703 0.8988 0.8920
## 12 0.8917 0.8602 0.8876 0.8809
write.csv(results, "model_performance_summary.csv", row.names = FALSE)
Boosting (Tuned) for 2 Months & 6 Months seems to be the best model based on the highest values of accuracy, sensitivity, and specificity on the test set. It balances performance on both positive and negative cases effectively, with the highest test sensitivity and specificity. Therefore, for overall performance, Boosting Tuned would be the best with a better trade-off.