Friedman (1991) introduced several benchmark data sets created by simulation. One of these simulations used the following nonlinear equation to generate data:
$$y = 10 \sin(\pi x_1 x_2) + 20 (x_3 - 0.5)^2 + 10 x_4 + 5 x_5 + N(0, \sigma^2)$$
where the x values are random variables uniformly distributed on [0, 1] (the simulation also creates five additional non-informative predictors). The mlbench package contains a function called mlbench.friedman1 that simulates these data:
library(mlbench)
set.seed(200)
simulated <- mlbench.friedman1(200, sd = 1)
simulated <- cbind(simulated$x, simulated$y)
simulated <- as.data.frame(simulated)
colnames(simulated)[ncol(simulated)] <- "y"
library(randomForest)
## Warning: package 'randomForest' was built under R version 4.4.2
## randomForest 4.7-1.2
## Type rfNews() to see new features/changes/bug fixes.
##
## Attaching package: 'randomForest'
## The following object is masked from 'package:ggplot2':
##
## margin
library(caret)
## Loading required package: lattice
model1 <- randomForest(y ~ ., data = simulated, importance = TRUE, ntree = 5000)
rfImp1 <- varImp(model1, scale = FALSE)
print(rfImp1)
## Overall
## V1 8.79515634
## V2 6.56654060
## V3 0.73300417
## V4 7.64635222
## V5 2.15996803
## V6 0.13519686
## V7 0.06891977
## V8 -0.14606085
## V9 -0.08615019
## V10 -0.02833177
Did the random forest model significantly use the uninformative predictors (V6–V10)?
The highest importance values belong to V1, V2, V4, and V5, with V3 appearing slightly important but well below the others. The importance values for V6 through V10 are essentially zero, so the random forest made little to no use of the uninformative predictors.
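One way to gauge whether the uninformative predictors are used "significantly" is to look at the scaled permutation importance from randomForest (the permutation importance divided by its standard error, giving z-score-like values); a sketch using the model1 fit above:
# Permutation importance scaled by its standard error (z-score-like values)
importance(model1, type = 1, scale = TRUE)
# Side-by-side plot of both importance measures
varImpPlot(model1)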
simulated$duplicate1 <- simulated$V1 + rnorm(200) * .1
cor(simulated$duplicate1, simulated$V1)
## [1] 0.953051
Fit another random forest model to these data. Did the importance score for V1 change? What happens when you add another predictor that is also highly correlated with V1?
model_corr <- randomForest(y ~ ., data = simulated, importance = TRUE, ntree = 5000)
rfImp_corr <- varImp(model_corr, scale = FALSE)
print(rfImp_corr)
## Overall
## V1 6.17739898
## V2 5.97111496
## V3 0.59650620
## V4 6.93143877
## V5 2.07678953
## V6 0.18118973
## V7 0.05271119
## V8 -0.09553865
## V9 -0.05854792
## V10 0.02412532
## duplicate1 3.81659740
After adding this correlated predictor, the importance of V1 drops noticeably (from roughly 8.80 to 6.18) because credit for V1's signal is now shared with duplicate1; the importances of the other informative predictors decrease only slightly. They remain important (V3 less so), and duplicate1 itself receives a fairly high score, though it is not the most important variable.
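The question also asks what happens when yet another predictor highly correlated with V1 is added. A sketch of that follow-up (not run here; it works on a copy of the data so the later fits are unaffected, and sim2, duplicate2, and model_corr2 are illustrative names):
# Add a second predictor correlated with V1 and refit on a copy of the data
sim2 <- simulated
sim2$duplicate2 <- sim2$V1 + rnorm(200) * .1
model_corr2 <- randomForest(y ~ ., data = sim2, importance = TRUE, ntree = 5000)
varImp(model_corr2, scale = FALSE) # V1's importance would likely be diluted further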
library(partykit)
## Warning: package 'partykit' was built under R version 4.4.3
## Loading required package: grid
## Loading required package: libcoin
## Warning: package 'libcoin' was built under R version 4.4.3
## Loading required package: mvtnorm
cf <- cforest(y~., data = simulated)
cf_imp <- varimp(cf)
print(cf_imp)
## V1 V2 V3 V4 V5 V6
## 6.64135674 5.69655881 0.18338281 6.57547998 1.92508374 -0.06731552
## V7 V8 V9 V10 duplicate1
## 0.15733663 0.04363038 -0.11867178 -0.02181568 4.59634023
We see the same pattern as the random forest models in parts (a) and (b), with the main exception that V3 now has an importance on the level of V6 through V10. V1, V2, V4, V5, and duplicate1 are relatively important, while the remaining predictors contribute little to the model's predictions.
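Conditional permutation importance adjusts for correlated predictors and typically down-weights variables such as duplicate1. A sketch of that check using the party package, whose varimp() accepts conditional = TRUE directly (assuming party is installed; results not shown here and the computation can be slow):
# Conditional variable importance via the party package
library(party)
cf_party <- party::cforest(y ~ ., data = simulated,
                           controls = cforest_unbiased(ntree = 1000))
party::varimp(cf_party, conditional = TRUE)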
library(xgboost)
## Warning: package 'xgboost' was built under R version 4.4.2
# prepare data
X <- as.matrix(simulated[, -which(names(simulated) == "y")])
y <- simulated$y
# train model
xgb <- xgboost(data = X, label = y, nrounds = 100, verbose = 0)
# Variable importance
xgb_imp <- xgb.importance(model = xgb)
print(xgb_imp)
## Feature Gain Cover Frequency
## <char> <num> <num> <num>
## 1: V2 0.257443654 0.09690187 0.12219321
## 2: V1 0.255829290 0.09024620 0.18328982
## 3: V4 0.235146059 0.12314760 0.10182768
## 4: duplicate1 0.102528395 0.07278831 0.05378590
## 5: V5 0.082429515 0.09899870 0.08668407
## 6: V3 0.026268440 0.10382848 0.09033943
## 7: V7 0.017039093 0.07496761 0.07362924
## 8: V6 0.010008273 0.07063258 0.07624021
## 9: V9 0.006440871 0.07231712 0.06736292
## 10: V10 0.004831916 0.10693839 0.07571802
## 11: V8 0.002034494 0.08923313 0.06892950
For XGBoost, the highest Gain values belong to V2, V1, V4, duplicate1, and V5. Gain measures each feature's fractional contribution to the model, so a higher value indicates a more important predictor. This is consistent with the earlier models, which also identified these five as the most important features.
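For a quick visual comparison of these Gain values, xgboost provides a plotting helper; a small sketch using the importance table computed above:
# Plot the Gain-based importance computed above
xgb.plot.importance(xgb_imp, measure = "Gain")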
library(rpart)
# CART
cart <- rpart(y ~ ., data = simulated)
print(cart$variable.importance)
## duplicate1 V1 V4 V2 V5 V6 V7
## 1660.81880 1587.46111 1313.16709 1083.63087 394.50197 333.28745 321.67664
## V10 V8 V9 V3
## 258.80701 195.88849 115.31916 42.80127
Using CART, we see that duplicate1 is rated the most important variable, followed by V1, V4, and V2. Interestingly, this model does not appear to rely on V5 very much, which sets it apart from the previous models.
CART splits on duplicate1 and V1 more or less interchangeably because of their high correlation, which inflates the apparent importance of this pair. A random forest considers only a random subset of predictors at each split, so often only one of the two correlated predictors is available and the trees can lean on other variables; this is why random forests are said to decorrelate the trees, and the importance scores above reflect that. XGBoost counteracts correlation through boosting: each tree is fit to the residuals of the current model, so once one of the correlated predictors has been used, the other carries little additional information, and the built-in regularization penalties further discourage redundant splits. Like a random forest, it can also subsample predictors for each tree, which adds further resilience against correlated variables.
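Note that the xgboost fit above used the default of all columns for every tree; column subsampling has to be requested explicitly. A sketch of a refit with per-tree column subsampling (the colsample_bytree and subsample values are arbitrary illustrative choices):
# Refit with per-tree column subsampling to reduce reliance on correlated predictors
xgb_sub <- xgboost(data = X, label = y, nrounds = 100,
                   colsample_bytree = 0.5, subsample = 0.8, verbose = 0)
xgb.importance(model = xgb_sub)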
The “churn” data set in the MLC++ software package was developed to predict telecom customer churn based on information about their account. The data files state that the data are “artificial based on claims similar to real world.” The data consist of 19 predictors related to the customer account, such as the number of customer service calls, the area code, and the number of minutes. The outcome is whether the customer churned.
The data are contained in the modeldata package and can be loaded using
library(modeldata)
## Warning: package 'modeldata' was built under R version 4.4.2
data(mlc_churn)
str(mlc_churn)
## tibble [5,000 × 20] (S3: tbl_df/tbl/data.frame)
## $ state : Factor w/ 51 levels "AK","AL","AR",..: 17 36 32 36 37 2 20 25 19 50 ...
## $ account_length : int [1:5000] 128 107 137 84 75 118 121 147 117 141 ...
## $ area_code : Factor w/ 3 levels "area_code_408",..: 2 2 2 1 2 3 3 2 1 2 ...
## $ international_plan : Factor w/ 2 levels "no","yes": 1 1 1 2 2 2 1 2 1 2 ...
## $ voice_mail_plan : Factor w/ 2 levels "no","yes": 2 2 1 1 1 1 2 1 1 2 ...
## $ number_vmail_messages : int [1:5000] 25 26 0 0 0 0 24 0 0 37 ...
## $ total_day_minutes : num [1:5000] 265 162 243 299 167 ...
## $ total_day_calls : int [1:5000] 110 123 114 71 113 98 88 79 97 84 ...
## $ total_day_charge : num [1:5000] 45.1 27.5 41.4 50.9 28.3 ...
## $ total_eve_minutes : num [1:5000] 197.4 195.5 121.2 61.9 148.3 ...
## $ total_eve_calls : int [1:5000] 99 103 110 88 122 101 108 94 80 111 ...
## $ total_eve_charge : num [1:5000] 16.78 16.62 10.3 5.26 12.61 ...
## $ total_night_minutes : num [1:5000] 245 254 163 197 187 ...
## $ total_night_calls : int [1:5000] 91 103 104 89 121 118 118 96 90 97 ...
## $ total_night_charge : num [1:5000] 11.01 11.45 7.32 8.86 8.41 ...
## $ total_intl_minutes : num [1:5000] 10 13.7 12.2 6.6 10.1 6.3 7.5 7.1 8.7 11.2 ...
## $ total_intl_calls : int [1:5000] 3 3 5 7 3 6 7 6 4 5 ...
## $ total_intl_charge : num [1:5000] 2.7 3.7 3.29 1.78 2.73 1.7 2.03 1.92 2.35 3.02 ...
## $ number_customer_service_calls: int [1:5000] 1 1 0 2 3 0 3 0 1 0 ...
## $ churn : Factor w/ 2 levels "yes","no": 2 2 2 2 2 2 2 2 2 2 ...
table(mlc_churn$churn)
##
## yes no
## 707 4293
We can look for degenerate distributions first:
# Identify near-zero variance variables
degen_cols <- nearZeroVar(mlc_churn)
var_name <- names(mlc_churn)[degen_cols[1]] # pick just one to plot
# Plot the distribution of that one variable
ggplot(mlc_churn, aes(x = .data[[var_name]])) + # tidy evaluation instead of the deprecated aes_string()
  geom_bar(fill = "steelblue") +
  labs(title = paste("Bar Plot of", var_name),
       x = var_name,
       y = "Count") +
  theme_minimal()
# proportion of observations with zero voicemail messages
mean(mlc_churn$number_vmail_messages == 0)
## [1] 0.7356
We see that about 74% of number_vmail_messages values are 0. Although this makes the variable nearly zero-variance, it may still carry predictive information, so we keep it in the dataset.
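For a fuller picture of the near-zero-variance diagnostics, nearZeroVar() can return its metrics directly; a sketch using caret, which is already loaded:
# Inspect frequency ratio and percent-unique diagnostics for all predictors
nzv_metrics <- nearZeroVar(mlc_churn, saveMetrics = TRUE)
nzv_metrics[nzv_metrics$nzv, ] # show only the flagged columns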
We will next look at correlations between predictors:
library(corrplot)
## corrplot 0.94 loaded
mlc_churn_numeric <- mlc_churn[sapply(mlc_churn, is.numeric)]
corrplot(cor(mlc_churn_numeric))
We see perfect correlation between each minutes variable and its corresponding charge variable. This makes sense, as charges are billed per minute. We will therefore remove the minutes variables and keep the charges.
mlc_churn <- mlc_churn[!grepl("minutes", names(mlc_churn), ignore.case = TRUE)]
mlc_churn_numeric <- mlc_churn_numeric[!grepl("minutes", names(mlc_churn_numeric), ignore.case = TRUE)]
corrplot(cor(mlc_churn_numeric))
There appear to be no further issues with correlation. Interestingly, even the correlation between total calls and total charge within each period is almost zero. We can look at a pairwise plot of these predictors along with churn below:
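As a programmatic double-check, caret's findCorrelation() flags columns whose pairwise correlations exceed a cutoff; a sketch (the 0.9 cutoff is an arbitrary choice, and an empty result would confirm no remaining problems):
# Flag any remaining numeric predictors with pairwise correlation above 0.9
high_corr <- findCorrelation(cor(mlc_churn_numeric), cutoff = 0.9)
names(mlc_churn_numeric)[high_corr]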
pairs(mlc_churn[,c(7:14,16)])
We next create two new columns, the total number of calls and the total charge across all call types, and plot them:
mlc_churn$total_calls <- with(mlc_churn_numeric,
total_day_calls + total_eve_calls + total_night_calls + total_intl_calls
)
mlc_churn$total_charge <- with(mlc_churn_numeric,
total_day_charge + total_eve_charge + total_night_charge + total_intl_charge
)
churn_copy <- mlc_churn[,c(1:6,15:18)]
pairs(churn_copy[,-1])
First we split the data into training and test sets using an 80/20 split:
set.seed(123)
# get train test split
train_idx <- createDataPartition(churn_copy$churn, p=0.80, list=F)
train <- churn_copy[train_idx,]
test <- churn_copy[-train_idx,]
Let’s first do boosting:
library(gbm)
## Warning: package 'gbm' was built under R version 4.4.3
## Loaded gbm 2.2.2
## This version of gbm is no longer under development. Consider transitioning to gbm3, https://github.com/gbm-developers/gbm3
# fix data for boosting
train_boost <- train
train_boost$churn <- ifelse(train_boost$churn == 'yes', 1, 0)
test_boost <- test
test_boost$churn <- ifelse(test_boost$churn == 'yes', 1, 0)
boost <- gbm(churn~., data=train_boost, distribution = 'bernoulli', n.trees = 5000, interaction.depth = 4, shrinkage = 0.2, verbose = F)
boost_preds <- predict(boost, newdata = test_boost, type = "response") # predicted probabilities (predict.gbm returns log-odds by default)
## Using 5000 trees...
boost_preds <- as.factor(ifelse(boost_preds >= 0.5, 'yes','no'))
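The boosting hyperparameters above (5,000 trees, depth 4, shrinkage 0.2) were fixed rather than tuned. A sketch of how the number of trees could instead be chosen by cross-validation with gbm (not run here; boost_cv and the shrinkage value are illustrative):
# Use gbm's built-in cross-validation to pick the number of trees
boost_cv <- gbm(churn ~ ., data = train_boost, distribution = 'bernoulli',
                n.trees = 5000, interaction.depth = 4, shrinkage = 0.01,
                cv.folds = 5, verbose = FALSE)
best_iter <- gbm.perf(boost_cv, method = "cv") # optimal number of trees
boost_cv_preds <- predict(boost_cv, newdata = test_boost, n.trees = best_iter,
                          type = "response")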
Now to continue with bagging:
library(randomForest)
# bagging is rf with m=p, 9 in this case
bag <- randomForest(churn~., data=train, mtry=9)
bag_preds <- predict(bag, newdata=test)
Next with random forests:
rf <- randomForest(churn~., data = train, mtry = 3) # use sqrt(p) which is 3 for classification
rf_preds <- predict(rf, newdata=test)
BART:
library(BART)
## Warning: package 'BART' was built under R version 4.4.3
## Loading required package: nlme
## Loading required package: survival
##
## Attaching package: 'survival'
## The following object is masked from 'package:caret':
##
## cluster
ytrain <- ifelse(train$churn == "yes", 1, 0)
ytest <- ifelse(test$churn == "yes", 1, 0)
xtrain <- data.matrix(train[, setdiff(names(train), "churn")])
xtest <- data.matrix(test[, setdiff(names(test), "churn")])
bart_model <- pbart(x.train = xtrain, y.train = ytrain, x.test = xtest)
## *****Into main of pbart
## *****Data:
## data:n,p,np: 4001, 9, 999
## y1,yn: 0, 0
## x1,x[n*p]: 17.000000, 54.180000
## xp1,xp[np*p]: 2.000000, 59.090000
## *****Number of Trees: 50
## *****Number of Cut Points: 50 ... 100
## *****burn and ndpost: 100, 1000
## *****Prior:mybeta,alpha,tau: 2.000000,0.950000,0.212132
## *****binaryOffset: -1.073762
## *****Dirichlet:sparse,theta,omega,a,b,rho,augment: 0,0,1,0.5,1,9,0
## *****nkeeptrain,nkeeptest,nkeeptreedraws: 1000,1000,1000
## *****printevery: 100
## *****skiptr,skipte,skiptreedraws: 1,1,1
##
## MCMC
## done 0 (out of 1100)
## done 100 (out of 1100)
## done 200 (out of 1100)
## done 300 (out of 1100)
## done 400 (out of 1100)
## done 500 (out of 1100)
## done 600 (out of 1100)
## done 700 (out of 1100)
## done 800 (out of 1100)
## done 900 (out of 1100)
## done 1000 (out of 1100)
## time: 6s
## check counts
## trcnt,tecnt: 1000,1000
bart_probs <- bart_model$prob.test.mean
bart_preds <- as.factor(ifelse(bart_probs >= 0.5, "yes", "no"))
Finally, we fit a logistic regression model with glm() as a baseline comparison:
library(glmnet)
## Loading required package: Matrix
## Loaded glmnet 4.1-8
lr <- glm(churn~., data=train, family='binomial')
lr_preds <- predict(lr, newdata = test, type = "response")
lr_preds <- factor(ifelse(lr_preds >= 0.5, "yes", "no"), levels = levels(test$churn))
We can now compare all of the models using classification metrics such as accuracy, sensitivity, specificity, and misclassification rate.
# function to get metrics
get_metrics <- function(preds, ytrue) {
  cm <- confusionMatrix(preds, ytrue, positive = "yes")
  accuracy <- as.numeric(cm$overall["Accuracy"])
  sensitivity <- as.numeric(cm$byClass["Sensitivity"])
  specificity <- as.numeric(cm$byClass["Specificity"])
  misclass_rate <- 1 - accuracy
  return(c(Accuracy = accuracy, Sensitivity = sensitivity,
           Specificity = specificity, Misclassification = misclass_rate))
}
# getting metrics
results <- rbind(
Logistic = get_metrics(lr_preds, test$churn),
BART = get_metrics(bart_preds, test$churn),
Bagging = get_metrics(bag_preds, test$churn),
Boosting = get_metrics(boost_preds, test$churn),
RF = get_metrics(rf_preds, test$churn)
)
results_df <- as.data.frame(round(results, 3))
results_df
## Accuracy Sensitivity Specificity Misclassification
## Logistic 0.145 0.809 0.036 0.855
## BART 0.945 0.674 0.990 0.055
## Bagging 0.948 0.702 0.988 0.052
## Boosting 0.942 0.695 0.983 0.058
## RF 0.949 0.667 0.995 0.051
At first glance logistic regression appears to perform terribly, but this is largely an artifact of how the outcome is coded: the churn factor has levels c("yes", "no"), and glm() models the probability of the second level, so the predicted probabilities are P(churn = "no") and the 0.5 threshold above assigns the labels backwards. Flipping the labels would give an accuracy of roughly 0.855 (1 - 0.145), which is still clearly below the tree-based methods. BART, bagging, boosting, and random forests all perform well and at a similar level. Random forest has the highest accuracy and lowest misclassification rate, but the differences are small enough that any of the ensemble models would give good performance.
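A sketch of the corrected logistic prediction step, given the factor coding described above (not re-run here; lr_probs_no and lr_preds_fixed are illustrative names):
# glm() models the probability of the second factor level, i.e. P(churn == "no")
lr_probs_no <- predict(lr, newdata = test, type = "response")
lr_preds_fixed <- factor(ifelse(lr_probs_no >= 0.5, "no", "yes"),
                         levels = levels(test$churn))
get_metrics(lr_preds_fixed, test$churn)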
Draw an example (of your own invention) of a partition of two-dimensional feature space that could result from recursive binary splitting. Your example should contain at least six regions. Draw a decision tree corresponding to this partition. Be sure to label all aspects of your figures, including the regions R1, R2, . . ., the cutpoints t1, t2, . . ., and so forth.
Hint: Your result should look something like Figures 8.1 and 8.2.
For this exercise I invent a scenario involving heart disease risk. The two features are systolic blood pressure (SBP) and A1c, both measures doctors use to assess heart health, and the outcome is whether a patient has a 10% or greater risk of developing heart disease over the next 10 years. This mirrors the risk calculators doctors use when deciding whether to prescribe statins, among other things.
Here the feature space is divided into six regions: green indicates a low risk of heart disease (under 10%), while red indicates a high risk (10% or more). The decision tree that creates this partition is shown below:
                 [SBP ≥ 140?]
                /            \
       [A1c ≥ 6.5?]       [A1c ≥ 8.0?]
        /       \          /        \
       No       Yes       No    [A1c ≥ 9.0?]
                                  /       \
                                 No       Yes
In this decision tree, each 'Yes' leaf indicates a 10% or greater risk of heart disease within the next 10 years, while each 'No' leaf indicates a risk below 10%.
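A sketch of how this partition could be drawn in base R, using the cutpoints from the tree above (the axis ranges and label positions are assumptions):
# Draw the cutpoints of the (SBP, A1c) partition implied by the tree above
plot(NULL, xlim = c(90, 200), ylim = c(4, 12),
     xlab = "Systolic Blood Pressure (SBP)", ylab = "A1c",
     main = "Recursive binary partition of the feature space")
abline(v = 140, lwd = 2)              # t1: SBP = 140 (root split)
segments(90, 6.5, 140, 6.5, lwd = 2)  # t2: A1c = 6.5 (SBP < 140 branch)
segments(140, 8.0, 200, 8.0, lwd = 2) # t3: A1c = 8.0 (SBP >= 140 branch)
segments(140, 9.0, 200, 9.0, lwd = 2) # t4: A1c = 9.0 (SBP >= 140 branch)
text(c(142, 92, 195, 195), c(4.3, 6.7, 8.2, 9.2), c("t1", "t2", "t3", "t4"))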
This problem involves the OJ data set which is part of the ISLR2 package.
library(ISLR2)
str(OJ)
## 'data.frame': 1070 obs. of 18 variables:
## $ Purchase : Factor w/ 2 levels "CH","MM": 1 1 1 2 1 1 1 1 1 1 ...
## $ WeekofPurchase: num 237 239 245 227 228 230 232 234 235 238 ...
## $ StoreID : num 1 1 1 1 7 7 7 7 7 7 ...
## $ PriceCH : num 1.75 1.75 1.86 1.69 1.69 1.69 1.69 1.75 1.75 1.75 ...
## $ PriceMM : num 1.99 1.99 2.09 1.69 1.69 1.99 1.99 1.99 1.99 1.99 ...
## $ DiscCH : num 0 0 0.17 0 0 0 0 0 0 0 ...
## $ DiscMM : num 0 0.3 0 0 0 0 0.4 0.4 0.4 0.4 ...
## $ SpecialCH : num 0 0 0 0 0 0 1 1 0 0 ...
## $ SpecialMM : num 0 1 0 0 0 1 1 0 0 0 ...
## $ LoyalCH : num 0.5 0.6 0.68 0.4 0.957 ...
## $ SalePriceMM : num 1.99 1.69 2.09 1.69 1.69 1.99 1.59 1.59 1.59 1.59 ...
## $ SalePriceCH : num 1.75 1.75 1.69 1.69 1.69 1.69 1.69 1.75 1.75 1.75 ...
## $ PriceDiff : num 0.24 -0.06 0.4 0 0 0.3 -0.1 -0.16 -0.16 -0.16 ...
## $ Store7 : Factor w/ 2 levels "No","Yes": 1 1 1 1 2 2 2 2 2 2 ...
## $ PctDiscMM : num 0 0.151 0 0 0 ...
## $ PctDiscCH : num 0 0 0.0914 0 0 ...
## $ ListPriceDiff : num 0.24 0.24 0.23 0 0 0.3 0.3 0.24 0.24 0.24 ...
## $ STORE : num 1 1 1 1 0 0 0 0 0 0 ...
library(tree)
## Warning: package 'tree' was built under R version 4.4.3
set.seed(123)
train_idx <- sample(1:length(OJ$Purchase), size = 800) # train indices
train <- OJ[train_idx,]
test <- OJ[-train_idx,]
tree <- tree(Purchase ~., data=train)
summary(tree)
##
## Classification tree:
## tree(formula = Purchase ~ ., data = train)
## Variables actually used in tree construction:
## [1] "LoyalCH" "PriceDiff"
## Number of terminal nodes: 8
## Residual mean deviance: 0.7625 = 603.9 / 792
## Misclassification error rate: 0.165 = 132 / 800
The training error here is the misclassification error rate, which is 16.5% for this tree. The tree has 8 terminal nodes and a residual mean deviance of 0.7625.
tree
## node), split, n, deviance, yval, (yprob)
## * denotes terminal node
##
## 1) root 800 1071.00 CH ( 0.60875 0.39125 )
## 2) LoyalCH < 0.5036 350 415.10 MM ( 0.28000 0.72000 )
## 4) LoyalCH < 0.276142 170 131.00 MM ( 0.12941 0.87059 )
## 8) LoyalCH < 0.0356415 56 10.03 MM ( 0.01786 0.98214 ) *
## 9) LoyalCH > 0.0356415 114 108.90 MM ( 0.18421 0.81579 ) *
## 5) LoyalCH > 0.276142 180 245.20 MM ( 0.42222 0.57778 )
## 10) PriceDiff < 0.05 74 74.61 MM ( 0.20270 0.79730 ) *
## 11) PriceDiff > 0.05 106 144.50 CH ( 0.57547 0.42453 ) *
## 3) LoyalCH > 0.5036 450 357.10 CH ( 0.86444 0.13556 )
## 6) PriceDiff < -0.39 27 32.82 MM ( 0.29630 0.70370 ) *
## 7) PriceDiff > -0.39 423 273.70 CH ( 0.90071 0.09929 )
## 14) LoyalCH < 0.705326 130 135.50 CH ( 0.78462 0.21538 )
## 28) PriceDiff < 0.145 43 58.47 CH ( 0.58140 0.41860 ) *
## 29) PriceDiff > 0.145 87 62.07 CH ( 0.88506 0.11494 ) *
## 15) LoyalCH > 0.705326 293 112.50 CH ( 0.95222 0.04778 ) *
Let us examine node 8, a terminal node. To reach it, the first split requires LoyalCH < 0.5036, which leaves 350 observations. The next split requires LoyalCH < 0.276, reducing the group to 170 observations. The final split requires LoyalCH < 0.036, leaving 56 individuals, all classified as purchasing MM: 98.2% of the training observations meeting these criteria purchased MM (meaning only about one person in this group purchased CH). This outcome makes sense: customers with very low loyalty to CH are very unlikely to buy CH.
plot(tree)
text(tree, pretty = 0)
In this tree diagram, the left branch at each split means the condition is met, while the right branch means it is not. We can see visually the same result as in the previous question: individuals with LoyalCH < 0.5036, LoyalCH < 0.276, and LoyalCH < 0.036 are classified as MM.
The full tree shows an overall trend: customers with LoyalCH < 0.5 tend to purchase MM, while those with higher loyalty tend to purchase CH. The exceptions involve PriceDiff: on the low-loyalty side, when MM is sufficiently more expensive (PriceDiff > 0.05) customers are still likely to buy CH, while on the high-loyalty side, when MM is much cheaper (PriceDiff < -0.39) even loyal customers are likely to buy MM.
Interestingly, some splits produce children with the same classification (such as nodes 8 and 9 discussed above). These splits still increase node purity: the predicted class probability is much higher in node 8 (98.2% MM) than in node 9 (81.6% MM).
tree_preds <- predict(tree, newdata = test)
tree_preds <- as.factor(ifelse(tree_preds[,1] >= 0.5, "CH","MM")) # classify as CH or MM
confusionMatrix(as.factor(test$Purchase), tree_preds)
## Confusion Matrix and Statistics
##
## Reference
## Prediction CH MM
## CH 150 16
## MM 34 70
##
## Accuracy : 0.8148
## 95% CI : (0.7633, 0.8593)
## No Information Rate : 0.6815
## P-Value [Acc > NIR] : 5.99e-07
##
## Kappa : 0.596
##
## Mcnemar's Test P-Value : 0.01621
##
## Sensitivity : 0.8152
## Specificity : 0.8140
## Pos Pred Value : 0.9036
## Neg Pred Value : 0.6731
## Prevalence : 0.6815
## Detection Rate : 0.5556
## Detection Prevalence : 0.6148
## Balanced Accuracy : 0.8146
##
## 'Positive' Class : CH
##
cat('misclassification error: ', mean(tree_preds != as.factor(test$Purchase)))
## misclassification error: 0.1851852
We see the test misclassification error is 18.5%, slightly higher than the 16.5% training error.
cv_tree <- cv.tree(tree) # uses prune.tree with cross validation
print(cv_tree)
## $size
## [1] 8 7 6 5 4 3 2 1
##
## $dev
## [1] 690.2346 692.8691 683.0423 717.4446 717.4446 743.3604 793.1266
## [8] 1073.0916
##
## $k
## [1] -Inf 12.03823 14.92474 25.76707 26.02613 38.91686 50.61655
## [8] 298.68751
##
## $method
## [1] "deviance"
##
## attr(,"class")
## [1] "prune" "tree.sequence"
plot(cv_tree)
We see the lowest cross-validated deviance (about 683) at size 6, meaning the optimal tree size to prune to is 6 terminal nodes.
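By default cv.tree() cross-validates the deviance. Since this is a classification problem, an alternative sketch (not run here) cross-validates the misclassification rate directly via FUN = prune.misclass:
# Cross-validate on classification error instead of deviance
set.seed(123)
cv_tree_misclass <- cv.tree(tree, FUN = prune.misclass)
plot(cv_tree_misclass$size, cv_tree_misclass$dev, type = "b",
     xlab = "Tree size", ylab = "CV misclassification count")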
tree_pruned <- prune.tree(tree, best = 6)
plot(tree_pruned)
text(tree_pruned, pretty = 0)
summary(tree)
##
## Classification tree:
## tree(formula = Purchase ~ ., data = train)
## Variables actually used in tree construction:
## [1] "LoyalCH" "PriceDiff"
## Number of terminal nodes: 8
## Residual mean deviance: 0.7625 = 603.9 / 792
## Misclassification error rate: 0.165 = 132 / 800
summary(tree_pruned)
##
## Classification tree:
## snip.tree(tree = tree, nodes = c(4L, 14L))
## Variables actually used in tree construction:
## [1] "LoyalCH" "PriceDiff"
## Number of terminal nodes: 6
## Residual mean deviance: 0.7945 = 630.9 / 794
## Misclassification error rate: 0.165 = 132 / 800
The training error is 16.5% for both trees. This is because the pruned splits (nodes 4 and 14) led to children with the same predicted class, so removing them does not change any training classifications, and pruning costs nothing in training performance.
prune_preds <- predict(tree_pruned, newdata = test)
prune_preds <- as.factor(ifelse(prune_preds[,1] >= 0.5, "CH","MM")) # classify as CH or MM
cat('misclassification error for full tree: ', mean(tree_preds != as.factor(test$Purchase)), '\n')
## misclassification error for full tree: 0.1851852
cat('misclassification error for pruned tree: ', mean(prune_preds != as.factor(test$Purchase)))
## misclassification error for pruned tree: 0.1851852
We see the same test error after pruning, for the same reason: the pruned splits only improved node purity without changing any predicted classes, so classification performance is unaffected.