Prediksi Kematian Pasien Di Rumah Sakit Menggunakan Model Naive Bayes, Decision Tree dan Random Forest

1. Introduction

Kematian di rumah sakit khususnya pada kasus gagal jantung di unit perawatan intensif (ICU) menjadi perhatian pada dataset ini. Hasil utama pada penelitian ini menurut sumber data didefinisikan sebagai status vital pada saat keluar dari rumah sakit sebagai orang yang selamat (alive) dan tidak selamat (death). Dataset diambil menggunakan SQL dan PostgreSQL meliputi data demografis pasien, tanda vital, penyakit penyerta serta hasil laboratorium. Tujuan pembuatan berbagai model machine learning adalah untuk mendapatkan model terbaik dalam memprediksi kematian terutama pada pasien gagal jantung yang dirawat di ICU rumah sakit.

2. Import Library

Berikut merupakan packages yang digunakan pada analisis data

library(dplyr) # for data wrangling
library(ggplot2) # to visualize data
library(gridExtra) # to display multiple graph
library(inspectdf) # for EDA
library(tidymodels) # to build tidy models
library(caret) # to pre-process data
library(caretEnsemble)
library(naivebayes)
library(tidyr)
library(e1071) # model naive bayes
library(haven)
library(nnet)
library(class)
library(rmdformats)
library(psych)
library(pracma)
library(ROCR)
library(randomForest)
library(remotes)
remotes::install_github("ahmadhusain/cmplot")
library(zoo)

3. Read Data

Dataset didapat dari link kaggle [https://www.kaggle.com/datasets/saurabhshahane/in-hospital-mortality-prediction]

mortality <- read.csv("data input in hospital mortality prediction/data01.csv")
rmarkdown::paged_table(mortality)

Terdapat beberapa kolom yang sebaiknya dihapus dan diganti tipe datanya sebagai berikut

mortality_clean <- mortality %>% 
  mutate(gendera=as.factor(ifelse(gendera == "1", "female", "male")),
         hypertensive=as.factor(ifelse(hypertensive == "1", "yes", "no")),
         atrialfibrillation=as.factor(ifelse(atrialfibrillation == "1", "yes", "no")),
         CHD.with.no.MI=as.factor(ifelse(CHD.with.no.MI == "1", "yes", "no")),
         diabetes=as.factor(ifelse(diabetes == "1", "yes", "no")),
         deficiencyanemias=as.factor(ifelse(deficiencyanemias == "1", "yes", "no")),
         depression=as.factor(ifelse(depression == "1", "yes", "no")),
         Hyperlipemia=as.factor(ifelse(Hyperlipemia == "1", "yes", "no")),
         Renal.failure=as.factor(ifelse(Renal.failure == "1", "yes", "no")),
         COPD=as.factor(ifelse(COPD == "1", "yes", "no")),
         outcome = as.factor(ifelse(outcome == "1", "death", "alive"))) %>% 
  select(-c(group,ID))
  

glimpse(mortality_clean)
#> Rows: 1,177
#> Columns: 49
#> $ outcome                  <fct> alive, alive, alive, alive, alive, alive, ali…
#> $ age                      <int> 72, 75, 83, 43, 75, 76, 72, 83, 61, 67, 70, 8…
#> $ gendera                  <fct> female, male, male, male, male, female, femal…
#> $ BMI                      <dbl> 37.58818, NA, 26.57263, 83.26463, 31.82484, 2…
#> $ hypertensive             <fct> no, no, no, no, yes, yes, yes, yes, yes, yes,…
#> $ atrialfibrillation       <fct> no, no, no, no, no, yes, no, yes, yes, no, no…
#> $ CHD.with.no.MI           <fct> no, no, no, no, no, no, no, no, no, no, no, n…
#> $ diabetes                 <fct> yes, no, no, no, no, no, no, yes, yes, yes, y…
#> $ deficiencyanemias        <fct> yes, yes, yes, no, yes, yes, no, yes, no, no,…
#> $ depression               <fct> no, no, no, no, no, no, no, no, no, no, no, n…
#> $ Hyperlipemia             <fct> yes, no, no, no, no, yes, yes, no, no, no, ye…
#> $ Renal.failure            <fct> yes, no, yes, no, yes, yes, yes, no, yes, no,…
#> $ COPD                     <fct> no, yes, no, no, yes, yes, yes, no, no, no, n…
#> $ heart.rate               <dbl> 68.83784, 101.37037, 72.31818, 94.50000, 67.9…
#> $ Systolic.blood.pressure  <dbl> 155.86667, 140.00000, 135.33333, 126.40000, 1…
#> $ Diastolic.blood.pressure <dbl> 68.33333, 65.00000, 61.37500, 73.20000, 58.12…
#> $ Respiratory.rate         <dbl> 16.62162, 20.85185, 23.64000, 21.85714, 21.36…
#> $ temperature              <dbl> 36.71429, 36.68254, 36.45370, 36.28704, 36.76…
#> $ SP.O2                    <dbl> 98.39474, 96.92308, 95.29167, 93.84615, 99.28…
#> $ Urine.output             <dbl> 2155, 1425, 2425, 8760, 4455, 1840, 2450, 303…
#> $ hematocrit               <dbl> 26.27273, 30.78000, 27.70000, 36.63750, 29.93…
#> $ RBC                      <dbl> 2.960000, 3.138000, 2.620000, 4.277500, 3.286…
#> $ MCH                      <dbl> 28.25000, 31.06000, 34.32000, 26.06250, 30.66…
#> $ MCHC                     <dbl> 31.52000, 31.66000, 31.30000, 30.41250, 33.66…
#> $ MCV                      <dbl> 89.90000, 98.20000, 109.80000, 85.62500, 91.0…
#> $ RDW                      <dbl> 16.22000, 14.26000, 23.82000, 17.03750, 16.26…
#> $ Leucocyte                <dbl> 7.6500000, 12.7400000, 5.4800000, 8.2250000, …
#> $ Platelets                <dbl> 305.10000, 246.40000, 204.20000, 216.37500, 2…
#> $ Neutrophils              <dbl> 74.65000, NA, 68.10000, 81.80000, NA, 85.4000…
#> $ Basophils                <dbl> 0.40, NA, 0.55, 0.15, NA, 0.30, 0.20, NA, 0.5…
#> $ Lymphocyte               <dbl> 13.300000, NA, 24.500000, 14.500000, NA, 9.30…
#> $ PT                       <dbl> 10.60000, NA, 11.27500, 27.06667, NA, 18.7833…
#> $ INR                      <dbl> 1.000000, NA, 0.950000, 2.666667, NA, 1.70000…
#> $ NT.proBNP                <dbl> 1956.0, 2384.0, 4081.0, 668.0, 30802.0, 34183…
#> $ Creatine.kinase          <dbl> 148.00000, 60.60000, 16.00000, 85.00000, 111.…
#> $ Creatinine               <dbl> 1.9583333, 1.1222222, 1.8714286, 0.5857143, 1…
#> $ Urea.nitrogen            <dbl> 50.00000, 20.33333, 33.85714, 15.28571, 43.00…
#> $ glucose                  <dbl> 114.63636, 147.50000, 149.00000, 128.25000, 1…
#> $ Blood.potassium          <dbl> 4.816667, 4.450000, 5.825000, 4.386667, 4.783…
#> $ Blood.sodium             <dbl> 138.7500, 138.8889, 140.7143, 138.5000, 136.6…
#> $ Blood.calcium            <dbl> 7.463636, 8.162500, 8.266667, 9.476923, 8.733…
#> $ Chloride                 <dbl> 109.16667, 98.44444, 105.85714, 92.07143, 104…
#> $ Anion.gap                <dbl> 13.166667, 11.444444, 10.000000, 12.357143, 1…
#> $ Magnesium.ion            <dbl> 2.618182, 1.887500, 2.157143, 1.942857, 1.650…
#> $ PH                       <dbl> 7.230000, 7.225000, 7.268000, 7.370000, 7.250…
#> $ Bicarbonate              <dbl> 21.16667, 33.44444, 30.57143, 38.57143, 22.00…
#> $ Lactic.acid              <dbl> 0.5000000, 0.5000000, 0.5000000, 0.6000000, 0…
#> $ PCO2                     <dbl> 40.00000, 78.00000, 71.50000, 75.00000, 50.00…
#> $ EF                       <int> 55, 55, 35, 55, 55, 35, 55, 75, 50, 55, 75, 5…

Dataset terdiri dari 49 kolom di antaranya:
* data demografis pasien : umur, jenis kelamin, index massa tubuh
* tanda vital : laju detak jantung (heart.rate), tekanan darah (Systolic.blood.pressure & Diastolic.blood.pressure), laju pernapasan (Respiratory.rate), suhu tubuh (temperature), oksigen nadi saturasi (SP.O2), urin 24 jam pertama (Urine.output)
* penyakit penyerta : hypertensive, fibrilasi atrium (atrialfibrillation), penyakit jantung iskemik tanpa infark miokard (CHD.with.no.MI), diabetes, depression anemia (deficiencyanemias), hiperlipidemia (Hyperlipemia), penyakit ginjal kronis (Renal.failure), dan penyakit paru obstruktif kronik (COPD)
* hasil laboratorium mencakup:
- sel darah: hematokrit, eritrosit (RBC), rata-rata jumlah hemoglobin dalam satu eritrosit (MCH), rata-rata konsentrasi hemoglobin dalam satu eritrosit (MCHC), volume rata-rata eritrosit (MCV), variasi ukuran eritrosit (RDW), jumlah trombosit (Platelets), sel darah putih (Leucocyte), neutrofil, basofil, limfosit
- diagnosa resiko bleeding dan blood clot: waktu protrombin (PT), laju pembekuan darah (INR)
- diagnosa gagal jantung: parameter hormon pendeteksi gagal jantung (NT-proBNP), Left ventricular ejection fraction (EF)
- diagnosa diabetes: gula darah (glucose)
- diagnosa asidosis metabolik: nitrogen urea darah (BUN), Urea.nitrogen, celah anion (Anion.gap), Bicarbonate, Lactic.acid, konsentrasi ion hidrogen (pH), glucose, Creatinine, Creatine.kinase
- deteksi defisiensi ataupun toksisitas mineral dalam darah: kalium (Blood.potassium), natrium (Blood.sodium), kalsium (Blood.calcium), klorida (Chloride), magnesium (Magnesium.ion), tekanan parsial CO2 dalam darah arteri (PCO2)

4. Data Wrangling

# cek data duplikat
anyDuplicated(mortality_clean)
#> [1] 0

Tidak terdapat duplikat pada data. Selanjutnya, akan dilakukan pengecekan pada missing value

# cek missing value
is.na(mortality_clean) %>% colSums()
#>                  outcome                      age                  gendera 
#>                        1                        0                        0 
#>                      BMI             hypertensive       atrialfibrillation 
#>                      215                        0                        0 
#>           CHD.with.no.MI                 diabetes        deficiencyanemias 
#>                        0                        0                        0 
#>               depression             Hyperlipemia            Renal.failure 
#>                        0                        0                        0 
#>                     COPD               heart.rate  Systolic.blood.pressure 
#>                        0                       13                       16 
#> Diastolic.blood.pressure         Respiratory.rate              temperature 
#>                       16                       13                       19 
#>                    SP.O2             Urine.output               hematocrit 
#>                       13                       36                        0 
#>                      RBC                      MCH                     MCHC 
#>                        0                        0                        0 
#>                      MCV                      RDW                Leucocyte 
#>                        0                        0                        0 
#>                Platelets              Neutrophils                Basophils 
#>                        0                      144                      259 
#>               Lymphocyte                       PT                      INR 
#>                      145                       20                       20 
#>                NT.proBNP          Creatine.kinase               Creatinine 
#>                        0                      165                        0 
#>            Urea.nitrogen                  glucose          Blood.potassium 
#>                        0                       18                        0 
#>             Blood.sodium            Blood.calcium                 Chloride 
#>                        0                        1                        0 
#>                Anion.gap            Magnesium.ion                       PH 
#>                        0                        0                      292 
#>              Bicarbonate              Lactic.acid                     PCO2 
#>                        0                      229                      294 
#>                       EF 
#>                        0

Ternyata, dari hasil pengecekan terdapat beberapa kolom yang memiliki banyak nilai missing value. Pertama-tama mari kita cek summary pada data numerik

mortality_clean %>% 
  select_if(is.numeric) %>% 
  summary()
#>       age             BMI           heart.rate     Systolic.blood.pressure
#>  Min.   :19.00   Min.   : 13.35   Min.   : 36.00   Min.   : 75.0          
#>  1st Qu.:65.00   1st Qu.: 24.33   1st Qu.: 72.37   1st Qu.:105.4          
#>  Median :77.00   Median : 28.31   Median : 83.61   Median :116.1          
#>  Mean   :74.06   Mean   : 30.19   Mean   : 84.58   Mean   :118.0          
#>  3rd Qu.:85.00   3rd Qu.: 33.63   3rd Qu.: 95.91   3rd Qu.:128.6          
#>  Max.   :99.00   Max.   :104.97   Max.   :135.71   Max.   :203.0          
#>                  NA's   :215      NA's   :13       NA's   :16             
#>  Diastolic.blood.pressure Respiratory.rate  temperature        SP.O2       
#>  Min.   : 24.74           Min.   :11.14    Min.   :33.25   Min.   : 75.92  
#>  1st Qu.: 52.17           1st Qu.:17.93    1st Qu.:36.29   1st Qu.: 95.00  
#>  Median : 58.46           Median :20.37    Median :36.65   Median : 96.45  
#>  Mean   : 59.53           Mean   :20.80    Mean   :36.68   Mean   : 96.27  
#>  3rd Qu.: 65.46           3rd Qu.:23.39    3rd Qu.:37.02   3rd Qu.: 97.92  
#>  Max.   :107.00           Max.   :40.90    Max.   :39.13   Max.   :100.00  
#>  NA's   :16               NA's   :13       NA's   :19      NA's   :13      
#>   Urine.output    hematocrit         RBC             MCH             MCHC      
#>  Min.   :   0   Min.   :20.31   Min.   :2.030   Min.   :18.12   Min.   :27.82  
#>  1st Qu.: 980   1st Qu.:28.16   1st Qu.:3.120   1st Qu.:28.25   1st Qu.:32.01  
#>  Median :1675   Median :30.80   Median :3.490   Median :29.75   Median :32.99  
#>  Mean   :1899   Mean   :31.91   Mean   :3.575   Mean   :29.54   Mean   :32.86  
#>  3rd Qu.:2500   3rd Qu.:35.01   3rd Qu.:3.900   3rd Qu.:31.24   3rd Qu.:33.83  
#>  Max.   :8820   Max.   :55.42   Max.   :6.575   Max.   :40.31   Max.   :37.01  
#>  NA's   :36                                                                    
#>       MCV              RDW          Leucocyte       Platelets       
#>  Min.   : 62.60   Min.   :12.09   Min.   : 0.10   Min.   :   9.571  
#>  1st Qu.: 86.25   1st Qu.:14.46   1st Qu.: 7.44   1st Qu.: 168.909  
#>  Median : 90.00   Median :15.51   Median : 9.68   Median : 222.667  
#>  Mean   : 89.90   Mean   :15.95   Mean   :10.71   Mean   : 241.504  
#>  3rd Qu.: 93.86   3rd Qu.:16.94   3rd Qu.:12.74   3rd Qu.: 304.250  
#>  Max.   :116.71   Max.   :29.05   Max.   :64.75   Max.   :1028.200  
#>                                                                     
#>   Neutrophils      Basophils        Lymphocyte            PT       
#>  Min.   : 5.00   Min.   :0.1000   Min.   : 0.9667   Min.   :10.10  
#>  1st Qu.:74.78   1st Qu.:0.2000   1st Qu.: 6.6500   1st Qu.:13.16  
#>  Median :82.47   Median :0.3000   Median :10.4750   Median :14.63  
#>  Mean   :80.11   Mean   :0.4056   Mean   :12.2330   Mean   :17.48  
#>  3rd Qu.:87.45   3rd Qu.:0.5000   3rd Qu.:15.4625   3rd Qu.:18.80  
#>  Max.   :98.00   Max.   :8.8000   Max.   :83.5000   Max.   :71.27  
#>  NA's   :144     NA's   :259      NA's   :145       NA's   :20     
#>       INR           NT.proBNP      Creatine.kinase      Creatinine     
#>  Min.   :0.8714   Min.   :    50   Min.   :    8.00   Min.   : 0.2667  
#>  1st Qu.:1.1400   1st Qu.:  2251   1st Qu.:   46.00   1st Qu.: 0.9400  
#>  Median :1.3000   Median :  5840   Median :   89.25   Median : 1.2875  
#>  Mean   :1.6255   Mean   : 11014   Mean   :  246.78   Mean   : 1.6428  
#>  3rd Qu.:1.7364   3rd Qu.: 14968   3rd Qu.:  185.19   3rd Qu.: 1.9000  
#>  Max.   :8.3429   Max.   :118928   Max.   :42987.50   Max.   :15.5273  
#>  NA's   :20                        NA's   :165                         
#>  Urea.nitrogen        glucose       Blood.potassium  Blood.sodium  
#>  Min.   :  5.357   Min.   : 66.67   Min.   :3.000   Min.   :114.7  
#>  1st Qu.: 20.833   1st Qu.:113.94   1st Qu.:3.900   1st Qu.:136.7  
#>  Median : 30.667   Median :136.40   Median :4.115   Median :139.2  
#>  Mean   : 36.298   Mean   :148.80   Mean   :4.177   Mean   :138.9  
#>  3rd Qu.: 45.250   3rd Qu.:169.50   3rd Qu.:4.400   3rd Qu.:141.6  
#>  Max.   :161.750   Max.   :414.10   Max.   :6.567   Max.   :154.7  
#>                    NA's   :18                                      
#>  Blood.calcium       Chloride        Anion.gap      Magnesium.ion  
#>  Min.   : 6.700   Min.   : 80.27   Min.   : 6.636   Min.   :1.400  
#>  1st Qu.: 8.149   1st Qu.: 99.00   1st Qu.:12.250   1st Qu.:1.956  
#>  Median : 8.500   Median :102.50   Median :13.667   Median :2.092  
#>  Mean   : 8.501   Mean   :102.28   Mean   :13.925   Mean   :2.120  
#>  3rd Qu.: 8.869   3rd Qu.:105.57   3rd Qu.:15.417   3rd Qu.:2.242  
#>  Max.   :10.950   Max.   :122.53   Max.   :25.500   Max.   :4.073  
#>  NA's   :1                                                         
#>        PH         Bicarbonate     Lactic.acid         PCO2      
#>  Min.   :7.090   Min.   :12.86   Min.   :0.500   Min.   :18.75  
#>  1st Qu.:7.335   1st Qu.:23.45   1st Qu.:1.200   1st Qu.:37.04  
#>  Median :7.380   Median :26.50   Median :1.600   Median :43.00  
#>  Mean   :7.379   Mean   :26.91   Mean   :1.853   Mean   :45.54  
#>  3rd Qu.:7.430   3rd Qu.:29.88   3rd Qu.:2.200   3rd Qu.:50.59  
#>  Max.   :7.580   Max.   :47.67   Max.   :8.333   Max.   :98.60  
#>  NA's   :292                     NA's   :229     NA's   :294    
#>        EF       
#>  Min.   :15.00  
#>  1st Qu.:40.00  
#>  Median :55.00  
#>  Mean   :48.72  
#>  3rd Qu.:55.00  
#>  Max.   :75.00  
#> 
pairs.panels(mortality_clean[c(20,35)])

Secara keseluruhan, nilai mean dan median dari setiap kolom hampir sama. Namun, ada dua kolom yang memiliki perbedaan mean dan median yang cukup tinggi yaitu Urine.output dan Creatine.kinase. Bila dilihat persebaran datanya, maka keduanya cenderung mengalami skewed positif. Berdasarkan literatur [https://medium.com/analytics-vidhya/appropriate-ways-to-treat-missing-values-f82f00edd9be], bila terdapat data yang skewed maka disarankan untuk mengisi missing value dengan nilai median atau modus. Dalam hal ini mari kita isi missing value data dengan menggunakan nilai median.

# handle missing value
mortality_clean <- mortality_clean %>% 
  mutate_if(is.numeric, na.aggregate, FUN = median)

Ternyata, setelah dilakukan pengecekan kembali, masih terdapat 1 kolom yang mengandung missing value

# re-check missing value
is.na(mortality_clean) %>% colSums()
#>                  outcome                      age                  gendera 
#>                        1                        0                        0 
#>                      BMI             hypertensive       atrialfibrillation 
#>                        0                        0                        0 
#>           CHD.with.no.MI                 diabetes        deficiencyanemias 
#>                        0                        0                        0 
#>               depression             Hyperlipemia            Renal.failure 
#>                        0                        0                        0 
#>                     COPD               heart.rate  Systolic.blood.pressure 
#>                        0                        0                        0 
#> Diastolic.blood.pressure         Respiratory.rate              temperature 
#>                        0                        0                        0 
#>                    SP.O2             Urine.output               hematocrit 
#>                        0                        0                        0 
#>                      RBC                      MCH                     MCHC 
#>                        0                        0                        0 
#>                      MCV                      RDW                Leucocyte 
#>                        0                        0                        0 
#>                Platelets              Neutrophils                Basophils 
#>                        0                        0                        0 
#>               Lymphocyte                       PT                      INR 
#>                        0                        0                        0 
#>                NT.proBNP          Creatine.kinase               Creatinine 
#>                        0                        0                        0 
#>            Urea.nitrogen                  glucose          Blood.potassium 
#>                        0                        0                        0 
#>             Blood.sodium            Blood.calcium                 Chloride 
#>                        0                        0                        0 
#>                Anion.gap            Magnesium.ion                       PH 
#>                        0                        0                        0 
#>              Bicarbonate              Lactic.acid                     PCO2 
#>                        0                        0                        0 
#>                       EF 
#>                        0
#drop na
mortality_clean <- mortality_clean %>% drop_na()
# inspect the data distribution of numerical variables
x <- inspect_num(mortality_clean) 
show_plot(x)

# Categoric Variable Proportion
x2 <- inspect_cat(mortality_clean) 
show_plot(x2)

Mari kita cek proporsi variable target outcome

# cek proporsi variabel target
prop.table(table(mortality_clean$outcome))
#> 
#>     alive     death 
#> 0.8647959 0.1352041

Terdapat ketidakseimbangan proporsi antara level target kita. Namun, alih-alih langsung menyeimbangkan proporsi data. Akan lebih baik jika data tersebut digunakan dalam pemodelan terlebih dahulu.

5. Cross Validation

Sebelum dilakukan pemodelan, data akan dibagi menjadi data train dan data test. Keduanya memiliki kegunaan yang berbeda. Data train digunakan untuk melatih model, sedangkan data test digunakan untuk menguji model dengan beberapa metrik evaluasi. Kali ini kita akan membagi data train dan data test dengan proporsi 80% dan 20%. K-fold cross validation adalah strategi membagi data untuk membangun model yang lebih general. Proses pemisahan data yang dilakukan oleh k-fold cross validation lebih efektif sehingga sering digunakan oleh data scientist. Hal ini disebabkan karena metode ini akan menghasilkan hasil pengujian keseluruhan model yang lebih konsisten karena data train dan data test dipilih dan kemudian dibagi secara berulang.

RNGkind(sample.kind = "Rounding")
set.seed(123) # mengunci seed agar hasil split sama di tiap komputer

split <- sample(nrow(mortality_clean), nrow(mortality_clean)*0.80)
mortality_train <- mortality_clean[split, ] 
mortality_test <- mortality_clean[-split, ] 

6. Model Naive Bayes

Karakteristik model Naive Bayes adalah asumsikan bahwa seluruh prediktor saling independent dan memiliki bobot yang sama. Kelebihan model ini adalah waktu komputasi yang relatif lebih cepat dibanding model klasifikasi lain, karena hanya mengkomputasi proporsi tabel frekuensi. Oleh karena itu, model ini sering digunakan sebagai baseline model yaitu model acuan terhadap model yang lebih kompleks. Namun, kekurangannya sangat terpengaruh oleh data yang memiliki skewness pada frekuensi data. Skewness Due To Scarcity adalah kondisi ketika terdapat suatu prediktor yang frekuensi nilainya 0 untuk salah satu kelas. Model menjadi bias dalam melakukan prediksi. Untuk mengatasinya, dapat dilakukan Laplace Smoothing, yaitu dengan menambahkan frekuensi dari setiap prediktor sebanyak angka tertentu, sehingga tidak ada lagi prediktor yang memiliki nilai 0.

Beberapa metrik yang akan digunakan untuk mengevaluasi model klasifikasi adalah confusion matrix dimana hasilnya berupa:
* Accuracy : memprediksi dengan benar baik kelas positif maupun negatif TP+TN/TOTAL
* Precision/Pos Pred Value : memprediksi dengan benar kelas positif dari total prediksi kelas positif TP/(TP+FP)
* Recall/Sensitivity : memprediksi dengan benar kelas positif dari total aktual kelas positif TP/(TP+FN)
* Specificity : memprediksi dengan benar kelas negatif dari total aktual kelas negatif TN/(TN+FP)
Selain itu, ada juga metrik lain yaitu Receiver Operating Characteristics (ROC) curve and Area Under Curve (AUC). ROC merupakan kurva yang menggambarkan hubungan antara True Positive Rate (Sensitivity atau Recall) dengan False Positive Rate (1-Specificity) pada setiap thresholdnya. Model yang baik idealnya memiliki True Positive Rate yang tinggi dan False Positive Rate yang rendah. kurve ROC adalah kumpulan titik-titik yang tidak dapat diukur nilainya. Oleh karena itu, hadirlah metrik lain yang bernama AUC yang menunjukkan luas area yang berada di bawah kurva ROC. Semakin tinggi nilai AUC, semakin bagus performa modelnya, dengan kata lain model memiliki TPR/Recall yang tinggi dan FPR yang rendah. Nilai AUC yang paling bagus adalah 1. Baik ROC maupun AUC digunakan untuk mengukur sebaik apakah model dalam membedakan kelas positif maupun negatif.

Model Fitting

set.seed(123)
ctrl <- trainControl(method = "repeatedcv", number = 5, repeats = 3)
# parameter method dapat disesuaikan dengan metode klasifikasi yang digunakan dalam hal ini naive bayes
model_kfold_naive_train <- train(outcome ~ .,
               data = mortality_train,
               method = "nb",
               trControl = ctrl)

Evaluation Data Train

Prediction Data Train

# prediction pada data train
naive_train_pred2 <- predict(model_kfold_naive_train, mortality_train, type = 'raw') # for the probability
naive_train_prob2 <- predict(model_kfold_naive_train, mortality_train, type = 'prob') # for the class prediction

# hasil prediction pada data train
naive_train_table2 <- select(mortality_train, outcome) %>%
  bind_cols(outcome_pred2 = naive_train_pred2) %>% 
  bind_cols(outcome_eprob2 = round(naive_train_prob2[,1],4)) %>% 
  bind_cols(outcome_pprob2 = round(naive_train_prob2[,2],4))

Confusion Matrix Data Train

# performance evaluation - confusion matrix pada data train
naive_train_table2 %>% 
  conf_mat(outcome, outcome_pred2) %>% 
  ggplot2::autoplot(type = "heatmap")

naive_train_table2 %>%
  summarise(
    accuracy = accuracy_vec(outcome, outcome_pred2),
    sensitivity = sens_vec(outcome, outcome_pred2),
    specificity = spec_vec(outcome, outcome_pred2),
    precision = precision_vec(outcome, outcome_pred2)
  )
#>    accuracy sensitivity specificity precision
#> 1 0.9159574   0.9889163    0.453125 0.9198167

Dari hasil confusion matrix, diperoleh hasil:
- Accuracy : 0.915
- Sensitivity : 0.988
- Precision : 0.919
- Specificity : 0.453 Nilai accuracy, sensitivity dan precision sangat tinggi, sedangkan specificity nya cukup rendah, ini berarti bahwa nilai True Negatif nya rendah. Jika kita lihat pembagian confusion matrix di atas, kita dapat mengambil kesimpulan bahwa yang menjadi kelas negatif adalah death. Sehingga, yg ingin kita turunkan adalah False Positive yang berarti bahwa diprediksi hidup alive ternyata meninggal death. Dalam hal ini berarti kita ingin meningkatkan adalah Precision.

ROC & AUC Data Train

# ROC
naive_train_roc2 <- data.frame(prediction=round(naive_train_prob2[,1],4),
                      trueclass=as.numeric(naive_train_table2$outcome=="alive"))

library(pROC)
# Create ROC curve
roc_naive_train <- roc(naive_train_roc2$trueclass, naive_train_roc2$prediction) # membentuk kurva ROC, fungsi ini akan memperhitungkan nilai true positive rate (tpr) dan false positive rate (fpr) sebagai nilai range treshold yang akan dikembalikan sebagai kelas objek roc

# Plot ROC curve
plot(roc_naive_train, col="blue", main="ROC Curve", print.thres=TRUE) #membentuk garis plot ROC, main untuk memberikan judul plot, print.thres untuk mencetak nilai treshold plot yaitu 0.597

# Add diagonal line
lines(x=c(0,1), y=c(0,1), lty=2, col="gray") # menambahkan garis diagonal dimana x dan y adalah titik start dan end garis, fungsi lty untuk menentukan tipe garis yaitu dashed, fungsi col untuk mewarnai garis diagonal

# Add AUC to plot
text(0.8, 0.2, paste("AUC =", round(auc(roc_naive_train), 2)), cex=1.2) #menambahkan nilai AUC ke plot pakai fungsi text dimana 0.8 dan 0.2 adalah posisi dri teks, fungsi cex untuk menentukan ukuran teks

Nilai AUC data train yang diperoleh sangat tinggi yaitu 0.91. Mari kita bandingkan hasilnya dengan data test.

Evaluation Data Test

Prediction Data Test

# prediction pada data test
naive_test_pred2 <- predict(model_kfold_naive_train, mortality_test, type = 'raw') # for the probability
naive_test_prob2 <- predict(model_kfold_naive_train, mortality_test, type = 'prob') # for the class prediction

# hasil prediction pada data test
naive_test_table2 <- select(mortality_test, outcome) %>%
  bind_cols(outcome_pred2 = naive_test_pred2) %>% 
  bind_cols(outcome_eprob2 = round(naive_test_prob2[,1],4)) %>% 
  bind_cols(outcome_pprob2 = round(naive_test_prob2[,2],4))

Confusion Matrix Data Test

# performance evaluation - confusion matrix data test
naive_test_table2 %>% 
  conf_mat(outcome, outcome_pred2) %>% 
  ggplot2::autoplot(type = "heatmap")

naive_test_table2 %>%
  summarise(
    accuracy = accuracy_vec(outcome, outcome_pred2),
    sensitivity = sens_vec(outcome, outcome_pred2),
    specificity = spec_vec(outcome, outcome_pred2),
    precision = precision_vec(outcome, outcome_pred2)
  )
#>    accuracy sensitivity specificity precision
#> 1 0.8771186   0.9756098   0.2258065 0.8928571

Nilai precision yang diperoleh pada data test adalah 0.892. Selisih precision antara data train dengan data test pada model naive bayes adalah 2.7% (tidak lebih dari 10% menurut standar industri), hal ini berarti model kita tidak mengalami overfitting.

ROC & AUC Data Test

# ROC
naive_test_roc2 <- data.frame(prediction=round(naive_test_prob2[,1],4),
                      trueclass=as.numeric(naive_test_table2$outcome=="alive"))

library(pROC)
# Create ROC curve
roc_naive_test <- roc(naive_test_roc2$trueclass, naive_test_roc2$prediction) # membentuk kurva ROC, fungsi ini akan memperhitungkan nilai true positive rate (tpr) dan false positive rate (fpr) sebagai nilai range treshold yang akan dikembalikan sebagai kelas objek roc

# Plot ROC curve
plot(roc_naive_test, col="blue", main="ROC Curve", print.thres=TRUE) #membentuk garis plot ROC, main untuk memberikan judul plot, print.thres untuk mencetak nilai treshold plot yaitu 0.597

# Add diagonal line
lines(x=c(0,1), y=c(0,1), lty=2, col="gray") # menambahkan garis diagonal dimana x dan y adalah titik start dan end garis, fungsi lty untuk menentukan tipe garis yaitu dashed, fungsi col untuk mewarnai garis diagonal

# Add AUC to plot
text(0.8, 0.2, paste("AUC =", round(auc(roc_naive_test), 2)), cex=1.2) #menambahkan nilai AUC ke plot pakai fungsi text dimana 0.8 dan 0.2 adalah posisi dri teks, fungsi cex untuk menentukan ukuran teks

Nilai AUC data test yang diperoleh yaitu 0.86. Selisih AUC antara data train dan data test pada model naive bayes adalah 5% (tidak lebih dari 10%), hal ini berarti model kita tidak mengalami overfitting.

7. Model Decision Tree

Decision Tree merupakan model berdasar pada pohon keputusan yang cukup sederhana dengan performa yang robust/powerful untuk melakukan prediksi. Decision Tree menghasilkan visualisasi berupa pohon keputusan yang dapat diinterpretasi dengan mudah. Decision Tree tidak hanya terbatas pada kasus Classification, namun dapat digunakan pada kasus Regression. Terdapat tiga komponen utama pada pohon keputusan yang terbentuk, yaitu: - Root Node: Percabangan pertama dalam menentukan nilai target, biasa disebut sebagai predictor utama
- Interior/Internal Node: Percabangan selanjutnya yang menggunakan predictor lain apabila root node tidak cukup dalam menentukan target
- Terminal/Leaf Node: Keputusan akhir berupa nilai target yang diprediksi

Model Fitting

# parameter method dapat disesuaikan dengan metode klasifikasi yang digunakan dalam hal ini decision tree
library(rpart)
library(rattle)
library(rpart.plot)
model_dtree_train <- rpart(outcome ~ .,
data = mortality_train)

Visualisasi Model

fancyRpartPlot(model_dtree_train, sub=NULL)

Evaluation Data Train

Prediction Data Train

# prediction pada data train
dtree_train_pred4 <- predict(model_dtree_train, mortality_train, type = 'class') # for the probability
dtree_train_prob4 <- predict(model_dtree_train, mortality_train, type = 'prob') # for the class prediction
# hasil prediksi pada data train
dtree_train_table4 <- select(mortality_train, outcome) %>%
bind_cols(outcome_pred = dtree_train_pred4) %>%
bind_cols(outcome_eprob = round(dtree_train_prob4[,1],4)) %>%
bind_cols(outcome_pprob = round(dtree_train_prob4[,2],4))

Confusion Matrix Data Train

# performance evaluation - confusion matrix data train
dtree_train_table4 %>% 
  conf_mat(outcome, outcome_pred) %>% 
  ggplot2::autoplot(type = "heatmap")

dtree_train_table4 %>%
  summarise(
    accuracy = accuracy_vec(outcome, outcome_pred),
    sensitivity = sens_vec(outcome, outcome_pred),
    specificity = spec_vec(outcome, outcome_pred),
    precision = precision_vec(outcome, outcome_pred)
  )
#>    accuracy sensitivity specificity precision
#> 1 0.9138298   0.9815271    0.484375 0.9235226

Dari hasil confusion matrix, diperoleh hasil:
- Accuracy : 0.913
- Sensitivity : 0.981
- Precision : 0.923
- Specificity : 0.484

ROC & AUC Data Train

# ROC
dtree_train_roc4 <- data.frame(prediction=round(dtree_train_prob4[,1],4),
                      trueclass=as.numeric(dtree_train_table4$outcome=="alive"))

library(pROC)
# Create ROC curve
roc_dtrain_test <- roc(dtree_train_roc4$trueclass, dtree_train_roc4$prediction) # membentuk kurva ROC, fungsi ini akan memperhitungkan nilai true positive rate (tpr) dan false positive rate (fpr) sebagai nilai range treshold yang akan dikembalikan sebagai kelas objek roc

# Plot ROC curve
plot(roc_dtrain_test, col="blue", main="ROC Curve", print.thres=TRUE) #membentuk garis plot ROC, main untuk memberikan judul plot, print.thres untuk mencetak nilai treshold plot yaitu 0.597

# Add diagonal line
lines(x=c(0,1), y=c(0,1), lty=2, col="gray") # menambahkan garis diagonal dimana x dan y adalah titik start dan end garis, fungsi lty untuk menentukan tipe garis yaitu dashed, fungsi col untuk mewarnai garis diagonal

# Add AUC to plot
text(0.8, 0.2, paste("AUC =", round(auc(roc_dtrain_test), 2)), cex=1.2) #menambahkan nilai AUC ke plot pakai fungsi text dimana 0.8 dan 0.2 adalah posisi dri teks, fungsi cex untuk menentukan ukuran teks

Nilai AUC data train yang diperoleh sangat tinggi yaitu 0.77. Mari kita bandingkan hasilnya dengan data test.

Evaluation Data Test

Prediction Data Test

# prediction pada data test
dtree_test_pred4 <- predict(model_dtree_train, mortality_test, type = 'class') # for the probability
dtree_test_prob4 <- predict(model_dtree_train, mortality_test, type = 'prob') # for the class prediction

# hasil prediction data test
dtree_test_table4 <- select(mortality_test, outcome) %>%
  bind_cols(outcome_pred = dtree_test_pred4) %>% 
  bind_cols(outcome_eprob = round(dtree_test_prob4[,1],4)) %>% 
  bind_cols(outcome_pprob = round(dtree_test_prob4[,2],4))

Confusion Matrix Data Test

# performance evaluation - confusion matrix data test
dtree_test_table4 %>% 
  conf_mat(outcome, outcome_pred) %>% 
  ggplot2::autoplot(type = "heatmap")

dtree_test_table4 %>%
  summarise(
    accuracy = accuracy_vec(outcome, outcome_pred),
    sensitivity = sens_vec(outcome, outcome_pred),
    specificity = spec_vec(outcome, outcome_pred),
    precision = precision_vec(outcome, outcome_pred)
  )
#>    accuracy sensitivity specificity precision
#> 1 0.8474576   0.9463415   0.1935484 0.8858447

Nilai precision yang diperoleh pada data test adalah 0.885. Selisih precision antara data train dengan data test pada model decision tree adalah 3.8% (tidak lebih dari 10% menurut standar industri), hal ini berarti model kita tidak mengalami overfitting.

ROC & AUC Data Test

# ROC
dtree_test_roc4 <- data.frame(prediction=round(dtree_test_prob4[,1],4),
                      trueclass=as.numeric(dtree_test_table4$outcome=="alive"))

library(pROC)
# Create ROC curve
roc_dtree_test <- roc(dtree_test_roc4$trueclass, dtree_test_roc4$prediction) # membentuk kurva ROC, fungsi ini akan memperhitungkan nilai true positive rate (tpr) dan false positive rate (fpr) sebagai nilai range treshold yang akan dikembalikan sebagai kelas objek roc

# Plot ROC curve
plot(roc_dtree_test, col="blue", main="ROC Curve", print.thres=TRUE) #membentuk garis plot ROC, main untuk memberikan judul plot, print.thres untuk mencetak nilai treshold plot yaitu 0.597

# Add diagonal line
lines(x=c(0,1), y=c(0,1), lty=2, col="gray") # menambahkan garis diagonal dimana x dan y adalah titik start dan end garis, fungsi lty untuk menentukan tipe garis yaitu dashed, fungsi col untuk mewarnai garis diagonal

# Add AUC to plot
text(0.8, 0.2, paste("AUC =", round(auc(roc_dtree_test), 2)), cex=1.2) #menambahkan nilai AUC ke plot pakai fungsi text dimana 0.8 dan 0.2 adalah posisi dri teks, fungsi cex untuk menentukan ukuran teks

Nilai AUC data test yang diperoleh yaitu 0.59. Selisih AUC antara data train dan data test pada model decision tree adalah 18% (lebih dari 10%), hal ini berarti model kita mengalami overfitting. Karena performa model belum cukup baik, perlu dilakukan model tuning.

Model Tuning

Pruning and Tree Size

Kekurangan dari Decision Tree adalah rentan overfitting, karena mampu melakukan splitting data hingga amat detail, bahkan sampai kondisi dimana 1 leaf node hanya terdapat 1 observasi. Hal ini membuat model Decision Tree hanya menghafal pola data train dengan membuat rules yang kompleks, yang seharusnya mempelajari pola data tersebut. Hasilnya model kurang bisa mengeneralisir pola data test, sehingga performanya jauh lebih buruk. Untuk mengatasi hal itu, Decision Tree perlu berhenti membuat percabangan sehingga pohon lebih sederhana, disebut sebagai pruning.

Parameter pruning pada function rpart():

  • minsplit: Jumlah minimal observasi di tiap cabang (internal node) setelah splitting. Bila tidak terpenuhi, tidak dilakukan percabangan
  • minbucket: Jumlah minimal observasi di terminal/leaf node. Bila tidak terpenuhi, tidak dilakukan percabangan

Model Fitting

model_dtree_train_tuned <- rpart(outcome ~ .,
data = mortality_train,
minsplit = 15, minbucket = 13)

Visualisasi Model Tuning

fancyRpartPlot(model_dtree_train_tuned, sub=NULL)

Evaluation Data Train

Prediction Data Train

# prediction pada data train
dtree_train_pred4_tuned <- predict(model_dtree_train_tuned, mortality_train, type = 'class') # for the probability
dtree_train_prob4_tuned <- predict(model_dtree_train_tuned, mortality_train, type = 'prob') # for the class prediction
# hasil prediksi pada data train
dtree_train_table4_tuned <- select(mortality_train, outcome) %>%
bind_cols(outcome_pred = dtree_train_pred4_tuned) %>%
bind_cols(outcome_eprob = round(dtree_train_prob4_tuned[,1],4)) %>%
bind_cols(outcome_pprob = round(dtree_train_prob4_tuned[,2],4))

Confusion Matrix Data Train

# performance evaluation - confusion matrix
dtree_train_table4_tuned %>%
conf_mat(outcome, outcome_pred) %>%
ggplot2::autoplot(type = "heatmap")

dtree_train_table4_tuned %>%
summarise(
accuracy = accuracy_vec(outcome, outcome_pred),
sensitivity = sens_vec(outcome, outcome_pred),
specificity = spec_vec(outcome, outcome_pred),
precision = precision_vec(outcome, outcome_pred)
)
#>    accuracy sensitivity specificity precision
#> 1 0.8989362   0.9716749      0.4375 0.9163763

Dari hasil confusion matrix, diperoleh hasil:
- Accuracy : 0.898
- Sensitivity : 0.971
- Precision : 0.9163
- Specificity : 0.437

ROC & AUC Data Train

# ROC
dtree_train_roc4_tuned <- data.frame(prediction=round(dtree_train_prob4_tuned[,1],4),
trueclass=as.numeric(dtree_train_table4_tuned$outcome=="alive"))

library(pROC)
# Create ROC curve
roc_dtreetuned_train <- roc(dtree_train_roc4_tuned$trueclass, dtree_train_roc4_tuned$prediction) # membentuk kurva ROC, fungsi ini akan memperhitungkan nilai true positive rate (tpr) dan false positive rate (fpr) sebagai nilai range treshold yang akan dikembalikan sebagai kelas objek roc

# Plot ROC curve
plot(roc_dtreetuned_train, col="blue", main="ROC Curve", print.thres=TRUE) #membentuk garis plot ROC, main untuk memberikan judul plot, print.thres untuk mencetak nilai treshold plot yaitu 0.597

# Add diagonal line
lines(x=c(0,1), y=c(0,1), lty=2, col="gray") # menambahkan garis diagonal dimana x dan y adalah titik start dan end garis, fungsi lty untuk menentukan tipe garis yaitu dashed, fungsi col untuk mewarnai garis diagonal

# Add AUC to plot
text(0.8, 0.2, paste("AUC =", round(auc(roc_dtreetuned_train), 2)), cex=1.2) #menambahkan nilai AUC ke plot pakai fungsi text dimana 0.8 dan 0.2 adalah posisi dri teks, fungsi cex untuk menentukan ukuran teks

Nilai AUC data train yang diperoleh cukup tinggi yaitu 0.75. Mari kita bandingkan hasilnya dengan data test.

Evaluation Data Test

Prediction Data Test

# prediction pada data test
dtree_test_pred4_tuned <- predict(model_dtree_train_tuned, mortality_test, type = 'class') # for the probability
dtree_test_prob4_tuned <- predict(model_dtree_train_tuned, mortality_test, type = 'prob') # for the class prediction

# hasil prediction data test
dtree_test_table4_tuned <- select(mortality_test, outcome) %>%
  bind_cols(outcome_pred = dtree_test_pred4_tuned) %>% 
  bind_cols(outcome_eprob = round(dtree_test_prob4_tuned[,1],4)) %>% 
  bind_cols(outcome_pprob = round(dtree_test_prob4_tuned[,2],4))

Confusion Matrix Data Test

# performance evaluation - confusion matrix data test
dtree_test_table4_tuned %>% 
  conf_mat(outcome, outcome_pred) %>% 
  ggplot2::autoplot(type = "heatmap")

dtree_test_table4_tuned %>%
  summarise(
    accuracy = accuracy_vec(outcome, outcome_pred),
    sensitivity = sens_vec(outcome, outcome_pred),
    specificity = spec_vec(outcome, outcome_pred),
    precision = precision_vec(outcome, outcome_pred)
  )
#>    accuracy sensitivity specificity precision
#> 1 0.8644068   0.9609756   0.2258065 0.8914027

Nilai precision yang diperoleh pada data test adalah 0.891. Selisih precision antara data train dengan data test pada model decision tree tuning adalah 2.53% (tidak lebih dari 10% menurut standar industri), hal ini berarti model kita tidak mengalami overfitting.

ROC & AUC Data Test

# ROC
dtree_test_roc4_tuned <- data.frame(prediction=round(dtree_test_prob4_tuned[,1],4),
                      trueclass=as.numeric(dtree_test_table4_tuned$outcome=="alive"))

library(pROC)
# Create ROC curve
roc_dtreetuned_test <- roc(dtree_test_roc4_tuned$trueclass, dtree_test_roc4_tuned$prediction) # membentuk kurva ROC, fungsi ini akan memperhitungkan nilai true positive rate (tpr) dan false positive rate (fpr) sebagai nilai range treshold yang akan dikembalikan sebagai kelas objek roc

# Plot ROC curve
plot(roc_dtreetuned_test, col="blue", main="ROC Curve", print.thres=TRUE) #membentuk garis plot ROC, main untuk memberikan judul plot, print.thres untuk mencetak nilai treshold plot yaitu 0.597

# Add diagonal line
lines(x=c(0,1), y=c(0,1), lty=2, col="gray") # menambahkan garis diagonal dimana x dan y adalah titik start dan end garis, fungsi lty untuk menentukan tipe garis yaitu dashed, fungsi col untuk mewarnai garis diagonal

# Add AUC to plot
text(0.8, 0.2, paste("AUC =", round(auc(roc_dtreetuned_test), 2)), cex=1.2) #menambahkan nilai AUC ke plot pakai fungsi text dimana 0.8 dan 0.2 adalah posisi dri teks, fungsi cex untuk menentukan ukuran teks

Nilai AUC data test yang diperoleh yaitu 0.72. Selisih AUC antara data train dan data test pada model decision tree tuning adalah 3% (tidak lebih dari 10%), hal ini berarti model tidak mengalami overfitting.

8. Model Random Forest

Random Forest adalah metode algoritma Ensemble yang terdiri dari banyak Decision Tree. Masing-masing Decision Tree memiliki karakteristik yang berbeda dan tidak saling berhubungan. Random Forest memanfaatkan konsep Bagging (Bootstrap and Aggregation) dalam pembuatannya. Prosesnya adalah:
1. Bootstrap sampling: Membuat data dengan random sampling (with replacement) dari data keseluruhan dan memungkinkan adanya baris yang terduplikat
2. Dibentuk 1 decision tree untuk masing-masing data hasil bootstrap. Digunakan parameter mtry untuk memilih banyaknya calon prediktor secara random (Automatic Feature Selection) 3. Melakukan prediksi terhadap observasi yang baru untuk setiap Decision Tree
4. Aggregation: Menghasilkan satu prediksi tunggal untuk memprediksi Kelebihan Random Forest:
- Menekan bias dan variance dari Decision Tree, sehingga performa lebih baik dalam memprediksi
- Automatic feature selection: Prediktor dipilih secara random pada pembuatan Decision Tree
- Terdapat out-of-bag error sebagai pengganti evaluasi model

Kekurangan dari Random Forest adalah membutuhkan waktu komputasi yang cukup lama. Hal ini dapat diatasi dengan membuang predictor yang variansinya mendekati nol (dianggap kurang informatif). Untuk mengetahuinya dapat menggunakan function nearZeroVar() dari caret:

nearZeroVar(mortality_clean)
#> integer(0)

Ternyata tidak ada kolom yang variansinya mendekati nol. Selain cara di atas, cara lain agar menghemat waktu komputasi adalah dengan menyimpan model random forest yang telah dibuat dengan function saveRDS, kemudian ketika ingin dijalankan cukup dipanggil menggunakan function readRDS.

Model Fitting

# set.seed(123)
# ctrl <- trainControl(method = "repeatedcv", number = 5, repeats = 3)
# # parameter method dapat disesuaikan dengan metode klasifikasi yang digunakan dalam hal ini random forest
# model_kfold_rforest_train <- train(outcome ~ .,
#                data = mortality_train,
#                method = "rf",
#                trControl = ctrl)
# 
# saveRDS(model_kfold_rforest_train, "model_kfold_rforest_train.RDS") # simpan model

Read Model

model_kfold_rforest_train <- readRDS("model_kfold_rforest_train.RDS")
model_kfold_rforest_train
#> Random Forest 
#> 
#> 940 samples
#>  48 predictor
#>   2 classes: 'alive', 'death' 
#> 
#> No pre-processing
#> Resampling: Cross-Validated (5 fold, repeated 3 times) 
#> Summary of sample sizes: 752, 753, 751, 751, 753, 753, ... 
#> Resampling results across tuning parameters:
#> 
#>   mtry  Accuracy   Kappa    
#>    2    0.8712850  0.1023046
#>   25    0.8797807  0.2317939
#>   48    0.8790715  0.2351254
#> 
#> Accuracy was used to select the optimal model using the largest value.
#> The final value used for the model was mtry = 25.

Pada summary model di atas, dilakukan beberapa kali percobaan mtry (jumlah predictor random yang digunakan saat splitting node). Secara default akan dicoba sebanyak 3 nilai mtry. Model yang dipilih adalah mtry = 25 dengan nilai Accuracy tertinggi ketika diujikan ke data hasil bootstrap sampling (atau data in-sample, bisa dianggap sebagai data train seperti pada pembuatan model).

Out-of-Bag (OOB) Error

Pada tahap Bootstrap sampling, terdapat data yang tidak digunakan dalam pembuatan model, ini yang disebut sebagai data Out-of-Bag (OOB). Model Random Forest akan menggunakan data OOB sebagai data test untuk melakukan evaluasi dengan cara menghitung error. Error inilah yang disebut sebagai OOB Error. Dalam kasus klasifikasi, OOB error merupakan persentase data OOB yang misklasifikasi. Perlu dicatat bahwa error di unseen data, bukan di data train.

model_kfold_rforest_train$finalModel
#> 
#> Call:
#>  randomForest(x = x, y = y, mtry = param$mtry) 
#>                Type of random forest: classification
#>                      Number of trees: 500
#> No. of variables tried at each split: 25
#> 
#>         OOB estimate of  error rate: 12.02%
#> Confusion matrix:
#>       alive death class.error
#> alive   806     6 0.007389163
#> death   107    21 0.835937500

Nilai OOB Error pada model sebesar 12.02%. Dengan kata lain, akurasi model pada data OOB adalah 87.98%!

Interpretation

Pada machine learning model, terdapat trade-off antara sisi interpretability dan performance. Performance Random Forest dapat diunggulkan dibandingkan model yang lain, namun tidak terlalu dapat diinterpretasi karena banyak faktor random yang terlibat. Namun setidaknya kita dapat melihat predictor apa saja yang paling penting dalam pembuatan Random Forest melalui variable importancenya:

varImp(model_kfold_rforest_train)
#> rf variable importance
#> 
#>   only 20 most important variables shown (out of 48)
#> 
#>                  Overall
#> Lymphocyte        100.00
#> Leucocyte          97.85
#> Lactic.acid        97.75
#> Anion.gap          96.61
#> Bicarbonate        96.14
#> Blood.calcium      65.54
#> SP.O2              60.27
#> PH                 55.54
#> Platelets          55.43
#> Urea.nitrogen      54.21
#> Blood.sodium       51.45
#> Urine.output       49.42
#> Creatinine         47.10
#> RBC                46.94
#> temperature        45.92
#> Respiratory.rate   44.58
#> NT.proBNP          43.58
#> heart.rate         42.39
#> Blood.potassium    40.50
#> PCO2               40.26
plot(varImp(model_kfold_rforest_train))

> ini adalah variabel-variabel yg paling sering dipakai di random forest. Akan dihitung seberapa sering variabel tersebut digunakan oleh model.

Evaluation Data Train

Prediction Data Train

# prediction pada data train
rforest_train_pred5 <- predict(model_kfold_rforest_train, mortality_train, type = 'raw') # for the probability
rforest_train_prob5 <- predict(model_kfold_rforest_train, mortality_train, type = 'prob') # for the class prediction

# hasil prediction pada data train
rforest_train_table5 <- select(mortality_train, outcome) %>%
  bind_cols(outcome_pred = rforest_train_pred5) %>% 
  bind_cols(outcome_eprob = round(rforest_train_prob5[,1],4)) %>% 
  bind_cols(outcome_pprob = round(rforest_train_prob5[,2],4))

Confusion Matrix Data Train

# performance evaluation - confusion matrix data train
rforest_train_table5 %>% 
  conf_mat(outcome, outcome_pred) %>% 
  ggplot2::autoplot(type = "heatmap")

rforest_train_table5 %>%
  summarise(
    accuracy = accuracy_vec(outcome, outcome_pred),
    sensitivity = sens_vec(outcome, outcome_pred),
    specificity = spec_vec(outcome, outcome_pred),
    precision = precision_vec(outcome, outcome_pred)
  )
#>   accuracy sensitivity specificity precision
#> 1        1           1           1         1

Dari hasil confusion matrix, diperoleh hasil:
- Accuracy : 1
- Sensitivity : 1
- Precision : 1
- Specificity : 1

ROC & AUC Data Train

# ROC
rforest_train_roc5 <- data.frame(prediction=round(rforest_train_prob5[,1],4),
                      trueclass=as.numeric(rforest_train_table5$outcome=="alive"))

library(pROC)
# Create ROC curve
roc_rforest_train <- roc(rforest_train_roc5$trueclass, rforest_train_roc5$prediction) # membentuk kurva ROC, fungsi ini akan memperhitungkan nilai true positive rate (tpr) dan false positive rate (fpr) sebagai nilai range treshold yang akan dikembalikan sebagai kelas objek roc

# Plot ROC curve
plot(roc_rforest_train, col="blue", main="ROC Curve", print.thres=TRUE) #membentuk garis plot ROC, main untuk memberikan judul plot, print.thres untuk mencetak nilai treshold plot yaitu 0.597

# Add diagonal line
lines(x=c(0,1), y=c(0,1), lty=2, col="gray") # menambahkan garis diagonal dimana x dan y adalah titik start dan end garis, fungsi lty untuk menentukan tipe garis yaitu dashed, fungsi col untuk mewarnai garis diagonal

# Add AUC to plot
text(0.8, 0.2, paste("AUC =", round(auc(roc_dtreetuned_test), 2)), cex=1.2) #menambahkan nilai AUC ke plot pakai fungsi text dimana 0.8 dan 0.2 adalah posisi dri teks, fungsi cex untuk menentukan ukuran teks

Nilai AUC data train yang diperoleh sangat baik yaitu 0.72. Mari kita bandingkan hasilnya dengan data test.

Evaluation Data Test

Prediction pada Data Test

# prediction pada data test
rforest_test_pred5 <- predict(model_kfold_rforest_train, mortality_test, type = 'raw') # for the probability
rforest_test_prob5 <- predict(model_kfold_rforest_train, mortality_test, type = 'prob') # for the class prediction

# hasil prediction pada data test
rforest_test_table5 <- select(mortality_test, outcome) %>%
  bind_cols(outcome_pred = rforest_test_pred5) %>% 
  bind_cols(outcome_eprob = round(rforest_test_prob5[,1],4)) %>% 
  bind_cols(outcome_pprob = round(rforest_test_prob5[,2],4))

Confusion Matrix Data Test

# performance evaluation - confusion matrix data test
rforest_test_table5 %>% 
  conf_mat(outcome, outcome_pred) %>% 
  ggplot2::autoplot(type = "heatmap")

rforest_test_table5 %>%
  summarise(
    accuracy = accuracy_vec(outcome, outcome_pred),
    sensitivity = sens_vec(outcome, outcome_pred),
    specificity = spec_vec(outcome, outcome_pred),
    precision = precision_vec(outcome, outcome_pred)
  )
#>    accuracy sensitivity specificity precision
#> 1 0.8855932    0.995122   0.1612903 0.8869565

Nilai precision yang diperoleh pada data test adalah 0.985, nilai ini sangat mirip dengan precision data test (selisih 1.5%) yang berarti model ini sangat baik dalam melakukan prediksi.

ROC & AUC Data Test

# ROC
rforest_test_roc5 <- data.frame(prediction=round(rforest_test_prob5[,1],4),
                      trueclass=as.numeric(rforest_test_table5$outcome=="alive"))

library(pROC)
# Create ROC curve
roc_rforest_test <- roc(rforest_test_roc5$trueclass, rforest_test_roc5$prediction) # membentuk kurva ROC, fungsi ini akan memperhitungkan nilai true positive rate (tpr) dan false positive rate (fpr) sebagai nilai range treshold yang akan dikembalikan sebagai kelas objek roc

# Plot ROC curve
plot(roc_rforest_test, col="blue", main="ROC Curve", print.thres=TRUE) #membentuk garis plot ROC, main untuk memberikan judul plot, print.thres untuk mencetak nilai treshold plot yaitu 0.597

# Add diagonal line
lines(x=c(0,1), y=c(0,1), lty=2, col="gray") # menambahkan garis diagonal dimana x dan y adalah titik start dan end garis, fungsi lty untuk menentukan tipe garis yaitu dashed, fungsi col untuk mewarnai garis diagonal

# Add AUC to plot
text(0.8, 0.2, paste("AUC =", round(auc(roc_rforest_test), 2)), cex=1.2) #menambahkan nilai AUC ke plot pakai fungsi text dimana 0.8 dan 0.2 adalah posisi dri teks, fungsi cex untuk menentukan ukuran teks

Nilai AUC data test yang diperoleh yaitu 0.81, mirip dengan nilai AUC data train. Artinya model cukup baik dalam membedakan kelas positif dan negatif.

Ringkasan hasil evaluasi

Berikut adalah ringkasan hasil evaluasi ketiga model pada data test

final_n <- naive_test_table2 %>%
summarise(
accuracy = accuracy_vec(outcome, outcome_pred2),
sensitivity = sens_vec(outcome, outcome_pred2),
specificity = spec_vec(outcome, outcome_pred2),
precision = precision_vec(outcome, outcome_pred2)
)%>%
cbind(AUC = roc_naive_test$auc)

final_dtuned <- dtree_test_table4_tuned %>%
summarise(
accuracy = accuracy_vec(outcome, outcome_pred),
sensitivity = sens_vec(outcome, outcome_pred),
specificity = spec_vec(outcome, outcome_pred),
precision = precision_vec(outcome, outcome_pred)
)%>%
cbind(AUC = roc_dtreetuned_test$auc)

final_f <- rforest_test_table5 %>%
summarise(
accuracy = accuracy_vec(outcome, outcome_pred),
sensitivity = sens_vec(outcome, outcome_pred),
specificity = spec_vec(outcome, outcome_pred),
precision = precision_vec(outcome, outcome_pred)
)%>%
cbind(AUC = roc_rforest_test$auc)

9. Kesimpulan

rbind("Naive Bayes" = final_n, "Decision Tree Tuning" = final_dtuned, "Random Forest" = final_f)
#>                       accuracy sensitivity specificity precision       AUC
#> Naive Bayes          0.8771186   0.9756098   0.2258065 0.8928571 0.8570417
#> Decision Tree Tuning 0.8644068   0.9609756   0.2258065 0.8914027 0.7206137
#> Random Forest        0.8855932   0.9951220   0.1612903 0.8869565 0.8148702

Berdasarkan hasil evaluasi ketiga model di atas, diperoleh hasil yang tinggi secara keseluruhan. Namun, yang memiliki nilai AUC dan precision yang paling tinggi adalah model Naive Bayes. Dalam hal ini, model Naive Bayes cukup baik dalam memprediksi kematian pasien gagal jantung di ICU rumah sakit.