# Loading Libraries ----

library(MASS)
library(rpart)
library(rpart.plot)
library(ggplot2)

# Data Exploration ----

# Take a working copy of the `birthwt` data frame and inspect its structure.
bwt <- birthwt
str(bwt)
## 'data.frame':    189 obs. of  10 variables:
##  $ low  : int  0 0 0 0 0 0 0 0 0 0 ...
##  $ age  : int  19 33 20 21 18 21 22 17 29 26 ...
##  $ lwt  : int  182 155 105 108 107 124 118 103 123 113 ...
##  $ race : int  2 3 1 1 1 3 1 3 1 1 ...
##  $ smoke: int  0 0 1 1 1 0 0 0 1 1 ...
##  $ ptl  : int  0 0 0 0 0 0 0 0 0 0 ...
##  $ ht   : int  0 0 0 0 0 0 0 0 0 0 ...
##  $ ui   : int  1 0 0 1 1 0 0 0 0 0 ...
##  $ ftv  : int  0 3 1 2 0 0 1 1 1 0 ...
##  $ bwt  : int  2523 2551 2557 2594 2600 2622 2637 2637 2663 2665 ...
# Every column arrives as an integer; the ones listed in `fac` are recoded as
# factors so that summary() reports level counts instead of quantiles.
fac <- c("low", "race", "smoke", "ht", "ui")
bwt[fac] <- lapply(bwt[fac], factor)
summary(bwt)
##  low          age             lwt        race   smoke        ptl        
##  0:130   Min.   :14.00   Min.   : 80.0   1:96   0:115   Min.   :0.0000  
##  1: 59   1st Qu.:19.00   1st Qu.:110.0   2:26   1: 74   1st Qu.:0.0000  
##          Median :23.00   Median :121.0   3:67           Median :0.0000  
##          Mean   :23.24   Mean   :129.8                  Mean   :0.1958  
##          3rd Qu.:26.00   3rd Qu.:140.0                  3rd Qu.:0.0000  
##          Max.   :45.00   Max.   :250.0                  Max.   :3.0000  
##  ht      ui           ftv              bwt      
##  0:177   0:161   Min.   :0.0000   Min.   : 709  
##  1: 12   1: 28   1st Qu.:0.0000   1st Qu.:2414  
##                  Median :0.0000   Median :2977  
##                  Mean   :0.7937   Mean   :2945  
##                  3rd Qu.:1.0000   3rd Qu.:3487  
##                  Max.   :6.0000   Max.   :4990

# Data Visualization ----

# Histogram of the `bwt` column, with bars filled by the `ui` factor.
# bins = 30 is the value ggplot2 falls back to anyway; stating it keeps the
# plot unchanged while silencing the "Pick better value with `binwidth`" message.
ggplot(data = bwt, aes(x = bwt, fill = ui)) +
  geom_histogram(bins = 30)

# Scatter plot of `bwt` against `age`.
ggplot(data = bwt, aes(x = age, y = bwt)) +
  geom_point()

# Histogram of `bwt`, filled by `ptl` treated as a discrete variable
# (ptl is still an integer column, hence the in-line as.factor()).
ggplot(data = bwt, aes(x = bwt, fill = as.factor(ptl))) +
  geom_histogram(bins = 30)

# Scatter plot of `bwt` against `lwt` with a smoothed trend line.
# method/formula are the defaults ggplot2 picks for a data set of this size
# (189 rows); naming them avoids the "using method = 'loess'" message.
ggplot(data = bwt, aes(x = lwt, y = bwt)) +
  geom_point() +
  geom_smooth(method = "loess", formula = y ~ x)

# Class balance of the outcome `low`.
table(bwt$low)
## 
##   0   1 
## 130  59

# Tree Model ----

# Fix the RNG seed so the train/test split below is reproducible.
set.seed(1)
# Draw 75% of the row indices for training. seq_len() is safe even for a
# zero-row frame, and floor() spells out the truncation that sample() would
# otherwise apply silently to the fractional size (0.75 * 189 = 141.75).
index <- sample(seq_len(nrow(bwt)), floor(0.75 * nrow(bwt)))
train <- bwt[index, ]
test <- bwt[-index, ]

# Classification tree for `low` using every other column except `bwt`.
# NOTE(review): `bwt` is presumably excluded because `low` is derived from it
# (see ?birthwt) -- confirm.
birthwtTree <- rpart(low ~ . - bwt, data = train, method = "class")

# Print the fitted splits; each row reads: node), split, n, loss, yval, (yprob).
birthwtTree
## n= 141 
## 
## node), split, n, loss, yval, (yprob)
##       * denotes terminal node
## 
##  1) root 141 44 0 (0.6879433 0.3120567)  
##    2) ptl< 0.5 117 30 0 (0.7435897 0.2564103)  
##      4) lwt>=106 93 19 0 (0.7956989 0.2043011)  
##        8) ht=0 86 15 0 (0.8255814 0.1744186) *
##        9) ht=1 7  3 1 (0.4285714 0.5714286) *
##      5) lwt< 106 24 11 0 (0.5416667 0.4583333)  
##       10) age< 22.5 15  4 0 (0.7333333 0.2666667) *
##       11) age>=22.5 9  2 1 (0.2222222 0.7777778) *
##    3) ptl>=0.5 24 10 1 (0.4166667 0.5833333)  
##      6) lwt>=131.5 7  2 0 (0.7142857 0.2857143) *
##      7) lwt< 131.5 17  5 1 (0.2941176 0.7058824) *
# Draw the fitted tree.
rpart.plot(birthwtTree)

# Validation on Test Dataset ----

# Score the held-out rows with the fitted tree, requesting class labels
# rather than class probabilities.
birthwtPred <- predict(birthwtTree, type = "class", newdata = test)
# Confusion matrix: predicted class in rows, observed `low` in columns.
table(birthwtPred, test[["low"]])
##            
## birthwtPred  0  1
##           0 31 10
##           1  2  5
# Overall accuracy: share of test rows whose prediction matches the observed
# class, i.e. (31 + 5) / 48 from the table above.
mean(test$low == birthwtPred)
## [1] 0.75