# Demo Dataset
# Attach C50; the data() call below relies on it to supply the `churn` data.
# NOTE(review): recent C50 releases reportedly no longer bundle `churn`
# (moved to the modeldata package) -- confirm against the installed version.
library(C50)
# Load the data set; the code below uses the `churnTrain` and `churnTest`
# data frames that become available after this call.
data("churn")
# Stack the training and test partitions into a single data frame.
churn <- rbind(churnTrain, churnTest)
# Inspect column types of the training partition only.
str(churnTrain)
## 'data.frame': 3333 obs. of 20 variables:
## $ state : Factor w/ 51 levels "AK","AL","AR",..: 17 36 32 36 37 2 20 25 19 50 ...
## $ account_length : int 128 107 137 84 75 118 121 147 117 141 ...
## $ area_code : Factor w/ 3 levels "area_code_408",..: 2 2 2 1 2 3 3 2 1 2 ...
## $ international_plan : Factor w/ 2 levels "no","yes": 1 1 1 2 2 2 1 2 1 2 ...
## $ voice_mail_plan : Factor w/ 2 levels "no","yes": 2 2 1 1 1 1 2 1 1 2 ...
## $ number_vmail_messages : int 25 26 0 0 0 0 24 0 0 37 ...
## $ total_day_minutes : num 265 162 243 299 167 ...
## $ total_day_calls : int 110 123 114 71 113 98 88 79 97 84 ...
## $ total_day_charge : num 45.1 27.5 41.4 50.9 28.3 ...
## $ total_eve_minutes : num 197.4 195.5 121.2 61.9 148.3 ...
## $ total_eve_calls : int 99 103 110 88 122 101 108 94 80 111 ...
## $ total_eve_charge : num 16.78 16.62 10.3 5.26 12.61 ...
## $ total_night_minutes : num 245 254 163 197 187 ...
## $ total_night_calls : int 91 103 104 89 121 118 118 96 90 97 ...
## $ total_night_charge : num 11.01 11.45 7.32 8.86 8.41 ...
## $ total_intl_minutes : num 10 13.7 12.2 6.6 10.1 6.3 7.5 7.1 8.7 11.2 ...
## $ total_intl_calls : int 3 3 5 7 3 6 7 6 4 5 ...
## $ total_intl_charge : num 2.7 3.7 3.29 1.78 2.73 1.7 2.03 1.92 2.35 3.02 ...
## $ number_customer_service_calls: int 1 1 0 2 3 0 3 0 1 0 ...
## $ churn : Factor w/ 2 levels "yes","no": 2 2 2 2 2 2 2 2 2 2 ...
# Model Training
library(e1071)
library(caret)
## Loading required package: lattice
## Loading required package: ggplot2
library(rpart)
# Fit a CART classifier (caret method "rpart") that predicts `churn` from all
# other columns of churnTrain; the tuning value is chosen by resampled accuracy.
# Seed the RNG first: train() evaluates candidates on random resamples, so
# without a seed the resampled metrics change from run to run.
# NOTE(review): the output recorded in this file came from an unseeded run, so
# resampled Accuracy/Kappa figures may differ slightly -- re-run to refresh.
set.seed(123)
dt_model <- train(
  churn ~ .,
  data = churnTrain,
  method = "rpart",
  metric = "Accuracy"
)
# The fitted object is stored as a list; show its storage type ...
typeof(dt_model)
## [1] "list"
# ... and the names of its components (e.g. `finalModel`, `results`).
names(dt_model)
## [1] "method" "modelInfo" "modelType" "results"
## [5] "pred" "bestTune" "call" "dots"
## [9] "metric" "control" "finalModel" "preProcess"
## [13] "trainingData" "resample" "resampledCM" "perfNames"
## [17] "maximize" "yLimits" "times" "levels"
## [21] "terms" "coefnames" "contrasts" "xlevels"
# Check Decision Tree Classifiers
# Print the caret summary of the fit: the resampling scheme, the accuracy and
# kappa obtained for each candidate cp, and the cp chosen for the final model.
print(dt_model)
## CART
##
## 3333 samples
## 19 predictor
## 2 classes: 'yes', 'no'
##
## No pre-processing
## Resampling: Bootstrapped (25 reps)
## Summary of sample sizes: 3333, 3333, 3333, 3333, 3333, 3333, ...
## Resampling results across tuning parameters:
##
## cp Accuracy Kappa
## 0.07867495 0.8799575 0.3800602
## 0.08488613 0.8699528 0.2870143
## 0.08902692 0.8669191 0.2570134
##
## Accuracy was used to select the optimal model using the largest value.
## The final value used for the model was cp = 0.07867495.
# Check Decision Tree Classifier Details
# Show the tree kept as the final model: one row per node giving its split
# rule, node size, loss, predicted class, and class probabilities.
print(dt_model[["finalModel"]])
## n= 3333
##
## node), split, n, loss, yval, (yprob)
## * denotes terminal node
##
## 1) root 3333 483 no (0.1449145 0.8550855)
## 2) total_day_minutes>=264.45 211 84 yes (0.6018957 0.3981043)
## 4) voice_mail_planyes< 0.5 158 37 yes (0.7658228 0.2341772) *
## 5) voice_mail_planyes>=0.5 53 6 no (0.1132075 0.8867925) *
## 3) total_day_minutes< 264.45 3122 356 no (0.1140295 0.8859705) *
# Model Prediction (1)
# Score the churnTest rows, asking for class probabilities (one column per
# class) instead of labels; rows with missing values are omitted.
dt_predict <- predict(
  dt_model,
  newdata = churnTest,
  type = "prob",
  na.action = na.omit
)
# Preview the first five rows of predicted probabilities.
head(dt_predict, n = 5)
## yes no
## 1 0.1140295 0.8859705
## 2 0.1140295 0.8859705
## 3 0.1132075 0.8867925
## 4 0.1140295 0.8859705
## 5 0.1140295 0.8859705
# Model Prediction (2)
# Predict class labels for the churnTest rows ("raw" returns the predicted
# classes as a factor rather than probabilities).
dt_predict2 <- predict(dt_model, type = "raw", newdata = churnTest)
# Preview the first six predicted labels.
head(dt_predict2, n = 6L)
## [1] no no no no no no
## Levels: yes no
# Model Tuning (1)
# Tune the tree by letting caret generate 8 candidate cp values
# (tuneLength = 8) and keeping the one with the best resampled accuracy.
# Seed the RNG so the random resamples -- and therefore the accuracy estimates
# that decide which cp wins -- are reproducible between runs.
# NOTE(review): the tree recorded below came from an unseeded run; re-run to
# refresh it if the selected cp changes.
set.seed(123)
dt_model_tune <- train(
  churn ~ .,
  data = churnTrain,
  method = "rpart",
  metric = "Accuracy",
  tuneLength = 8
)
# Print the tree refit with the selected cp.
print(dt_model_tune$finalModel)
## n= 3333
##
## node), split, n, loss, yval, (yprob)
## * denotes terminal node
##
## 1) root 3333 483 no (0.14491449 0.85508551)
## 2) total_day_minutes>=264.45 211 84 yes (0.60189573 0.39810427)
## 4) voice_mail_planyes< 0.5 158 37 yes (0.76582278 0.23417722)
## 8) total_eve_minutes>=187.75 101 5 yes (0.95049505 0.04950495) *
## 9) total_eve_minutes< 187.75 57 25 no (0.43859649 0.56140351)
## 18) total_day_minutes>=277.7 32 11 yes (0.65625000 0.34375000)
## 36) total_eve_minutes>=144.35 24 4 yes (0.83333333 0.16666667) *
## 37) total_eve_minutes< 144.35 8 1 no (0.12500000 0.87500000) *
## 19) total_day_minutes< 277.7 25 4 no (0.16000000 0.84000000) *
## 5) voice_mail_planyes>=0.5 53 6 no (0.11320755 0.88679245) *
## 3) total_day_minutes< 264.45 3122 356 no (0.11402947 0.88597053)
## 6) number_customer_service_calls>=3.5 251 124 yes (0.50597610 0.49402390)
## 12) total_day_minutes< 160.2 102 13 yes (0.87254902 0.12745098) *
## 13) total_day_minutes>=160.2 149 38 no (0.25503356 0.74496644)
## 26) total_eve_minutes< 141.75 19 5 yes (0.73684211 0.26315789) *
## 27) total_eve_minutes>=141.75 130 24 no (0.18461538 0.81538462)
## 54) total_day_minutes< 175.75 34 14 no (0.41176471 0.58823529)
## 108) total_eve_minutes< 212.15 16 2 yes (0.87500000 0.12500000) *
## 109) total_eve_minutes>=212.15 18 0 no (0.00000000 1.00000000) *
## 55) total_day_minutes>=175.75 96 10 no (0.10416667 0.89583333) *
## 7) number_customer_service_calls< 3.5 2871 229 no (0.07976315 0.92023685)
## 14) international_planyes>=0.5 267 101 no (0.37827715 0.62172285)
## 28) total_intl_calls< 2.5 51 0 yes (1.00000000 0.00000000) *
## 29) total_intl_calls>=2.5 216 50 no (0.23148148 0.76851852)
## 58) total_intl_minutes>=13.1 43 0 yes (1.00000000 0.00000000) *
## 59) total_intl_minutes< 13.1 173 7 no (0.04046243 0.95953757) *
## 15) international_planyes< 0.5 2604 128 no (0.04915515 0.95084485)
## 30) total_day_minutes>=223.25 383 68 no (0.17754569 0.82245431)
## 60) total_eve_minutes>=259.8 51 17 yes (0.66666667 0.33333333)
## 120) voice_mail_planyes< 0.5 40 6 yes (0.85000000 0.15000000) *
## 121) voice_mail_planyes>=0.5 11 0 no (0.00000000 1.00000000) *
## 61) total_eve_minutes< 259.8 332 34 no (0.10240964 0.89759036) *
## 31) total_day_minutes< 223.25 2221 60 no (0.02701486 0.97298514) *
# Model Tuning (2)
# Tune the tree over an explicit grid of cp values: 0, 0.01, ..., 0.1.
# Seed the RNG so the random resamples used to compare the candidates are
# reproducible between runs.
# `metric = "Accuracy"` is spelled out to match the other train() calls in
# this script; it is also caret's default for a factor outcome.
# NOTE(review): the tree recorded below came from an unseeded run; re-run to
# refresh it if the selected cp changes.
set.seed(123)
dt_model_tune2 <- train(
  churn ~ .,
  data = churnTrain,
  method = "rpart",
  metric = "Accuracy",
  tuneGrid = expand.grid(cp = seq(0, 0.1, by = 0.01))
)
# Print the tree refit with the selected cp.
print(dt_model_tune2$finalModel)
## n= 3333
##
## node), split, n, loss, yval, (yprob)
## * denotes terminal node
##
## 1) root 3333 483 no (0.14491449 0.85508551)
## 2) total_day_minutes>=264.45 211 84 yes (0.60189573 0.39810427)
## 4) voice_mail_planyes< 0.5 158 37 yes (0.76582278 0.23417722)
## 8) total_eve_minutes>=187.75 101 5 yes (0.95049505 0.04950495) *
## 9) total_eve_minutes< 187.75 57 25 no (0.43859649 0.56140351)
## 18) total_day_minutes>=277.7 32 11 yes (0.65625000 0.34375000)
## 36) total_eve_minutes>=144.35 24 4 yes (0.83333333 0.16666667) *
## 37) total_eve_minutes< 144.35 8 1 no (0.12500000 0.87500000) *
## 19) total_day_minutes< 277.7 25 4 no (0.16000000 0.84000000) *
## 5) voice_mail_planyes>=0.5 53 6 no (0.11320755 0.88679245) *
## 3) total_day_minutes< 264.45 3122 356 no (0.11402947 0.88597053)
## 6) number_customer_service_calls>=3.5 251 124 yes (0.50597610 0.49402390)
## 12) total_day_minutes< 160.2 102 13 yes (0.87254902 0.12745098) *
## 13) total_day_minutes>=160.2 149 38 no (0.25503356 0.74496644)
## 26) total_eve_minutes< 141.75 19 5 yes (0.73684211 0.26315789) *
## 27) total_eve_minutes>=141.75 130 24 no (0.18461538 0.81538462)
## 54) total_day_minutes< 175.75 34 14 no (0.41176471 0.58823529)
## 108) total_eve_minutes< 212.15 16 2 yes (0.87500000 0.12500000) *
## 109) total_eve_minutes>=212.15 18 0 no (0.00000000 1.00000000) *
## 55) total_day_minutes>=175.75 96 10 no (0.10416667 0.89583333) *
## 7) number_customer_service_calls< 3.5 2871 229 no (0.07976315 0.92023685)
## 14) international_planyes>=0.5 267 101 no (0.37827715 0.62172285)
## 28) total_intl_calls< 2.5 51 0 yes (1.00000000 0.00000000) *
## 29) total_intl_calls>=2.5 216 50 no (0.23148148 0.76851852)
## 58) total_intl_minutes>=13.1 43 0 yes (1.00000000 0.00000000) *
## 59) total_intl_minutes< 13.1 173 7 no (0.04046243 0.95953757) *
## 15) international_planyes< 0.5 2604 128 no (0.04915515 0.95084485)
## 30) total_day_minutes>=223.25 383 68 no (0.17754569 0.82245431)
## 60) total_eve_minutes>=259.8 51 17 yes (0.66666667 0.33333333)
## 120) voice_mail_planyes< 0.5 40 6 yes (0.85000000 0.15000000) *
## 121) voice_mail_planyes>=0.5 11 0 no (0.00000000 1.00000000) *
## 61) total_eve_minutes< 259.8 332 34 no (0.10240964 0.89759036) *
## 31) total_day_minutes< 223.25 2221 60 no (0.02701486 0.97298514) *
# Model Pre-Pruning
# Pre-prune: constrain tree growth while still tuning cp over 8 candidates.
# `control` is not an argument of train() itself; it is forwarded through
# `...` to the rpart fit, where it limits growth:
#   minsplit  = 50  a node needs at least 50 observations to be split
#   minbucket = 20  every terminal node keeps at least 20 observations
#   maxdepth  = 5   no node lies deeper than 5 levels below the root
# Seed the RNG so the random resamples used to pick cp are reproducible.
# NOTE(review): the tree recorded below came from an unseeded run; re-run to
# refresh it if the selected cp changes.
set.seed(123)
dt_model_preprune <- train(
  churn ~ .,
  data = churnTrain,
  method = "rpart",
  metric = "Accuracy",
  tuneLength = 8,
  control = rpart.control(minsplit = 50, minbucket = 20, maxdepth = 5)
)
# Print the constrained tree refit with the selected cp.
print(dt_model_preprune$finalModel)
## n= 3333
##
## node), split, n, loss, yval, (yprob)
## * denotes terminal node
##
## 1) root 3333 483 no (0.14491449 0.85508551)
## 2) total_day_minutes>=264.45 211 84 yes (0.60189573 0.39810427)
## 4) voice_mail_planyes< 0.5 158 37 yes (0.76582278 0.23417722)
## 8) total_eve_minutes>=187.75 101 5 yes (0.95049505 0.04950495) *
## 9) total_eve_minutes< 187.75 57 25 no (0.43859649 0.56140351)
## 18) total_day_minutes>=277.7 32 11 yes (0.65625000 0.34375000) *
## 19) total_day_minutes< 277.7 25 4 no (0.16000000 0.84000000) *
## 5) voice_mail_planyes>=0.5 53 6 no (0.11320755 0.88679245) *
## 3) total_day_minutes< 264.45 3122 356 no (0.11402947 0.88597053)
## 6) number_customer_service_calls>=3.5 251 124 yes (0.50597610 0.49402390)
## 12) total_day_minutes< 160.2 102 13 yes (0.87254902 0.12745098) *
## 13) total_day_minutes>=160.2 149 38 no (0.25503356 0.74496644)
## 26) total_eve_minutes< 155.5 29 11 yes (0.62068966 0.37931034) *
## 27) total_eve_minutes>=155.5 120 20 no (0.16666667 0.83333333) *
## 7) number_customer_service_calls< 3.5 2871 229 no (0.07976315 0.92023685)
## 14) international_planyes>=0.5 267 101 no (0.37827715 0.62172285)
## 28) total_intl_calls< 2.5 51 0 yes (1.00000000 0.00000000) *
## 29) total_intl_calls>=2.5 216 50 no (0.23148148 0.76851852)
## 58) total_intl_minutes>=13.1 43 0 yes (1.00000000 0.00000000) *
## 59) total_intl_minutes< 13.1 173 7 no (0.04046243 0.95953757) *
## 15) international_planyes< 0.5 2604 128 no (0.04915515 0.95084485)
## 30) total_day_minutes>=223.25 383 68 no (0.17754569 0.82245431)
## 60) total_eve_minutes>=259.8 51 17 yes (0.66666667 0.33333333) *
## 61) total_eve_minutes< 259.8 332 34 no (0.10240964 0.89759036) *
## 31) total_day_minutes< 223.25 2221 60 no (0.02701486 0.97298514) *
# Model Post-Pruning
# Post-prune the first fitted tree at a complexity threshold of 0.2.
# In the recorded run no split survives that threshold, so only the root
# node remains.
dt_model_postprune <- prune(dt_model[["finalModel"]], cp = 0.2)
print(dt_model_postprune)
## n= 3333
##
## node), split, n, loss, yval, (yprob)
## * denotes terminal node
##
## 1) root 3333 483 no (0.1449145 0.8550855) *
# Check Decision Tree Classifier (1)
# Draw the tree with the plot() method, then overlay labels with text();
# text() adds to the plot that plot() just opened, so the order matters.
plot(dt_model[["finalModel"]])
text(dt_model[["finalModel"]])

# Check Decision Tree Classifier (2)
library(rattle)
## Rattle: A free graphical interface for data science with R.
## Version 5.2.0 Copyright (c) 2006-2018 Togaware Pty Ltd.
## Type 'rattle()' to shake, rattle, and roll your data.
# Render the same final tree with rattle's fancyRpartPlot().
fancyRpartPlot(dt_model[["finalModel"]])

# Check Decision Tree Classifier (3)
library(rpart.plot)
# Render the same final tree with rpart.plot's prp().
prp(dt_model[["finalModel"]])
