# Loading Libraries
library(MASS)
library(rpart)
library(rpart.plot)
library(ggplot2)
# Data Exploration
# Take a working copy of the birthwt data shipped with MASS and inspect its
# structure: all ten columns arrive as integer vectors.
bwt <- MASS::birthwt
str(bwt)
## 'data.frame': 189 obs. of 10 variables:
## $ low : int 0 0 0 0 0 0 0 0 0 0 ...
## $ age : int 19 33 20 21 18 21 22 17 29 26 ...
## $ lwt : int 182 155 105 108 107 124 118 103 123 113 ...
## $ race : int 2 3 1 1 1 3 1 3 1 1 ...
## $ smoke: int 0 0 1 1 1 0 0 0 1 1 ...
## $ ptl : int 0 0 0 0 0 0 0 0 0 0 ...
## $ ht : int 0 0 0 0 0 0 0 0 0 0 ...
## $ ui : int 1 0 0 1 1 0 0 0 0 0 ...
## $ ftv : int 0 3 1 2 0 0 1 1 1 0 ...
## $ bwt : int 2523 2551 2557 2594 2600 2622 2637 2637 2663 2665 ...
# These columns hold integer codes for categories. Turning them into factors
# makes summary() tabulate them and makes models treat them as categorical.
fac <- c("low", "race", "smoke", "ht", "ui")
bwt[, fac] <- lapply(X = bwt[, fac], FUN = as.factor)
summary(bwt)
## low age lwt race smoke ptl
## 0:130 Min. :14.00 Min. : 80.0 1:96 0:115 Min. :0.0000
## 1: 59 1st Qu.:19.00 1st Qu.:110.0 2:26 1: 74 1st Qu.:0.0000
## Median :23.00 Median :121.0 3:67 Median :0.0000
## Mean :23.24 Mean :129.8 Mean :0.1958
## 3rd Qu.:26.00 3rd Qu.:140.0 3rd Qu.:0.0000
## Max. :45.00 Max. :250.0 Max. :3.0000
## ht ui ftv bwt
## 0:177 0:161 Min. :0.0000 Min. : 709
## 1: 12 1: 28 1st Qu.:0.0000 1st Qu.:2414
## Median :0.0000 Median :2977
## Mean :0.7937 Mean :2945
## 3rd Qu.:1.0000 3rd Qu.:3487
## Max. :6.0000 Max. :4990
# Data Visualization
# Birth-weight distribution, filled by uterine irritability (ui).
# The bin width is explicit rather than ggplot2's default of 30 bins. With
# boundary = 0 and closed = "left", 2500 g sits exactly on a bin edge, which
# matches the `low` cutoff (birth weight < 2500 g).
ggplot(data = bwt, aes(x = bwt, fill = ui)) +
  geom_histogram(binwidth = 250, boundary = 0, closed = "left") +
  labs(x = "Birth weight (g)", fill = "Uterine irritability")

# Birth weight against mother's age.
ggplot(data = bwt, aes(x = age, y = bwt)) +
  geom_point() +
  labs(x = "Mother's age (years)", y = "Birth weight (g)")

# Birth-weight distribution by number of previous premature labours (ptl).
# ptl stays numeric in the data, so it is converted to a factor here to get a
# discrete fill. labs() replaces the auto-generated "as.factor(ptl)" title.
ggplot(data = bwt, aes(x = bwt, fill = as.factor(ptl))) +
  geom_histogram(binwidth = 250, boundary = 0, closed = "left") +
  labs(x = "Birth weight (g)", fill = "Previous premature labours")

# Birth weight against mother's weight at last menstrual period (lwt), with a
# smoother. The loess method and formula are spelled out rather than left to
# geom_smooth()'s data-size-dependent default.
ggplot(data = bwt, aes(x = lwt, y = bwt)) +
  geom_point() +
  geom_smooth(method = "loess", formula = y ~ x) +
  labs(x = "Mother's weight (lb)", y = "Birth weight (g)")

# Class balance of the outcome: 59 of 189 births (~31%) are low weight.
table(bwt$low)
##
## 0 1
## 130 59
# Classification Tree Model
# Hold out 25% of rows as a test set; the seed makes the split, and therefore
# the tree printed below, reproducible.
# sample.int() draws row indices directly. It avoids sample()'s length-1
# footgun and consumes the RNG exactly as sample(1:n, k) does, so the split
# is unchanged. floor() makes the truncation of the fractional training size
# explicit (0.75 * 189 = 141.75 -> 141 rows).
set.seed(1)
n_train <- floor(0.75 * nrow(bwt))
index <- sample.int(nrow(bwt), size = n_train)
train <- bwt[index, ]
test <- bwt[-index, ]

# `low` is an indicator of birth weight below 2.5 kg, i.e. it is derived from
# `bwt`. The `- bwt` term drops birth weight itself from the predictors to
# avoid target leakage.
# NOTE(review): the tree is grown with rpart defaults and not pruned; consider
# printcp()/prune() using the cross-validated complexity parameter.
birthwtTree <- rpart(low ~ . - bwt, data = train, method = "class")
birthwtTree
## n= 141
##
## node), split, n, loss, yval, (yprob)
## * denotes terminal node
##
## 1) root 141 44 0 (0.6879433 0.3120567)
## 2) ptl< 0.5 117 30 0 (0.7435897 0.2564103)
## 4) lwt>=106 93 19 0 (0.7956989 0.2043011)
## 8) ht=0 86 15 0 (0.8255814 0.1744186) *
## 9) ht=1 7 3 1 (0.4285714 0.5714286) *
## 5) lwt< 106 24 11 0 (0.5416667 0.4583333)
## 10) age< 22.5 15 4 0 (0.7333333 0.2666667) *
## 11) age>=22.5 9 2 1 (0.2222222 0.7777778) *
## 3) ptl>=0.5 24 10 1 (0.4166667 0.5833333)
## 6) lwt>=131.5 7 2 0 (0.7142857 0.2857143) *
## 7) lwt< 131.5 17 5 1 (0.2941176 0.7058824) *
rpart.plot(birthwtTree)

# Validation on the Test Set
# Predict class labels for the held-out rows and cross-tabulate them against
# the observed outcome. Naming the dimensions shows which margin holds the
# predictions and which holds the truth.
birthwtPred <- predict(birthwtTree, newdata = test, type = "class")
table(predicted = birthwtPred, actual = test$low)
##          actual
## predicted  0  1
##         0 31 10
##         1  2  5

# Overall accuracy.
mean(birthwtPred == test$low)
## [1] 0.75

# Accuracy of always predicting the most frequent test-set class. The outcome
# is imbalanced, so judge the overall accuracy above against this baseline.
max(prop.table(table(test$low)))
## [1] 0.6875

# Sensitivity: the share of truly low-birth-weight ("1") infants that the tree
# also classifies as "1".
mean(birthwtPred[test$low == "1"] == "1")
## [1] 0.3333333