Το dataset αφορά την πρόβλεψη ύπαρξης καρδιοπάθειας με βάση ιατρικά χαρακτηριστικά ασθενών.

Η εξαρτημένη μεταβλητή είναι:

data <- read.csv("heart.csv")

str(data)
## 'data.frame':    1025 obs. of  14 variables:
##  $ age     : int  52 53 70 61 62 58 58 55 46 54 ...
##  $ sex     : int  1 1 1 1 0 0 1 1 1 1 ...
##  $ cp      : int  0 0 0 0 0 0 0 0 0 0 ...
##  $ trestbps: int  125 140 145 148 138 100 114 160 120 122 ...
##  $ chol    : int  212 203 174 203 294 248 318 289 249 286 ...
##  $ fbs     : int  0 1 0 0 1 0 0 0 0 0 ...
##  $ restecg : int  1 0 1 1 1 0 2 0 0 0 ...
##  $ thalach : int  168 155 125 161 106 122 140 145 144 116 ...
##  $ exang   : int  0 1 1 0 0 0 0 1 0 1 ...
##  $ oldpeak : num  1 3.1 2.6 0 1.9 1 4.4 0.8 0.8 3.2 ...
##  $ slope   : int  2 0 0 2 1 1 0 1 2 1 ...
##  $ ca      : int  2 0 0 1 3 0 3 1 0 2 ...
##  $ thal    : int  3 3 3 3 2 2 1 3 3 2 ...
##  $ target  : int  0 0 0 0 0 1 0 0 0 0 ...
summary(data)
##       age             sex               cp            trestbps    
##  Min.   :29.00   Min.   :0.0000   Min.   :0.0000   Min.   : 94.0  
##  1st Qu.:48.00   1st Qu.:0.0000   1st Qu.:0.0000   1st Qu.:120.0  
##  Median :56.00   Median :1.0000   Median :1.0000   Median :130.0  
##  Mean   :54.43   Mean   :0.6956   Mean   :0.9424   Mean   :131.6  
##  3rd Qu.:61.00   3rd Qu.:1.0000   3rd Qu.:2.0000   3rd Qu.:140.0  
##  Max.   :77.00   Max.   :1.0000   Max.   :3.0000   Max.   :200.0  
##       chol          fbs            restecg          thalach     
##  Min.   :126   Min.   :0.0000   Min.   :0.0000   Min.   : 71.0  
##  1st Qu.:211   1st Qu.:0.0000   1st Qu.:0.0000   1st Qu.:132.0  
##  Median :240   Median :0.0000   Median :1.0000   Median :152.0  
##  Mean   :246   Mean   :0.1493   Mean   :0.5298   Mean   :149.1  
##  3rd Qu.:275   3rd Qu.:0.0000   3rd Qu.:1.0000   3rd Qu.:166.0  
##  Max.   :564   Max.   :1.0000   Max.   :2.0000   Max.   :202.0  
##      exang           oldpeak          slope             ca        
##  Min.   :0.0000   Min.   :0.000   Min.   :0.000   Min.   :0.0000  
##  1st Qu.:0.0000   1st Qu.:0.000   1st Qu.:1.000   1st Qu.:0.0000  
##  Median :0.0000   Median :0.800   Median :1.000   Median :0.0000  
##  Mean   :0.3366   Mean   :1.072   Mean   :1.385   Mean   :0.7541  
##  3rd Qu.:1.0000   3rd Qu.:1.800   3rd Qu.:2.000   3rd Qu.:1.0000  
##  Max.   :1.0000   Max.   :6.200   Max.   :2.000   Max.   :4.0000  
##       thal           target      
##  Min.   :0.000   Min.   :0.0000  
##  1st Qu.:2.000   1st Qu.:0.0000  
##  Median :2.000   Median :1.0000  
##  Mean   :2.324   Mean   :0.5132  
##  3rd Qu.:3.000   3rd Qu.:1.0000  
##  Max.   :3.000   Max.   :1.0000

Το dataset περιλαμβάνει δημογραφικά και ιατρικά χαρακτηριστικά και είναι κατάλληλο για πρόβλημα ταξινόμησης.

library(caTools)
set.seed(123)

spl <- sample.split(data$target, SplitRatio = 0.7)
Train <- subset(data, spl==TRUE)
Test <- subset(data, spl==FALSE)

log_model <- glm(target ~ ., data = Train, family = binomial)

summary(log_model)
## 
## Call:
## glm(formula = target ~ ., family = binomial, data = Train)
## 
## Coefficients:
##              Estimate Std. Error z value Pr(>|z|)    
## (Intercept)  4.198600   1.732006   2.424  0.01535 *  
## age         -0.010376   0.015420  -0.673  0.50100    
## sex         -1.914140   0.313187  -6.112 9.85e-10 ***
## cp           0.842150   0.119272   7.061 1.66e-12 ***
## trestbps    -0.020664   0.006969  -2.965  0.00303 ** 
## chol        -0.006319   0.002473  -2.556  0.01060 *  
## fbs         -0.349104   0.348385  -1.002  0.31631    
## restecg      0.342112   0.227805   1.502  0.13315    
## thalach      0.026931   0.006820   3.949 7.86e-05 ***
## exang       -1.053479   0.270717  -3.891 9.96e-05 ***
## oldpeak     -0.577705   0.141253  -4.090 4.32e-05 ***
## slope        0.400683   0.233948   1.713  0.08677 .  
## ca          -0.727202   0.124163  -5.857 4.72e-09 ***
## thal        -0.931300   0.185996  -5.007 5.53e-07 ***
## ---
## Signif. codes:  0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1
## 
## (Dispersion parameter for binomial family taken to be 1)
## 
##     Null deviance: 993.47  on 716  degrees of freedom
## Residual deviance: 498.70  on 703  degrees of freedom
## AIC: 526.7
## 
## Number of Fisher Scoring iterations: 6
library(rpart)
library(rpart.plot)

tree_model <- rpart(target ~ ., data = Train, method = "class", minbucket = 20)

rpart.plot(tree_model)

pred_tree <- predict(tree_model, Test, type = "class")

table(Test$target, pred_tree)
##    pred_tree
##       0   1
##   0 124  26
##   1  33 125
mean(pred_tree == Test$target)
## [1] 0.8084416
prob_log <- predict(log_model, Test, type = "response")
pred_log <- ifelse(prob_log > 0.5, 1, 0)

mean(pred_log == Test$target)
## [1] 0.8538961
library(ROCR)

pred <- prediction(predict(tree_model, Test)[,2], Test$target)
perf <- performance(pred, "tpr", "fpr")

plot(perf)

as.numeric(performance(pred, "auc")@y.values)
## [1] 0.8690295
library(caret)
## Loading required package: ggplot2
## Loading required package: lattice
library(e1071)
## 
## Attaching package: 'e1071'
## The following object is masked from 'package:ggplot2':
## 
##     element
ctrl <- trainControl(method = "cv", number = 5)
grid <- expand.grid(.cp = seq(0.001, 0.05, 0.005))

Train$target <- as.factor(Train$target)

cv_model <- train(target ~ ., data = Train,
                  method = "rpart",
                  trControl = ctrl,
                  tuneGrid = grid)

tree_cv <- rpart(target ~ ., data = Train,
                 method = "class",
                 cp = cv_model$bestTune)

pred_cv <- predict(tree_cv, Test, type = "class")

mean(pred_cv == Test$target)
## [1] 0.8506494

Συμπεράσματα:

  1. Το CART είναι πιο ερμηνεύσιμο
  2. Logistic Regression πιο στατιστική
  3. Παρόμοια απόδοση
  4. Cross-validation βελτιώνει την ακρίβεια