# Load necessary library
library(ggplot2)
# Generating pm1 values from 0 to 1
pm1 <- seq(0, 1, by = 0.01)
# Calculating Gini index, Classification error, and Entropy
gini_index <- 2 * pm1 * (1 - pm1)
classification_error <- 1 - pmax(pm1, 1 - pm1)
entropy <- -pm1 * log2(pm1) - (1 - pm1) * log2(1 - pm1)
# Replacing NaN with 0 for log(0) case in entropy
entropy[is.na(entropy)] <- 0
# Preparing data for ggplot
df <- data.frame(pm1, gini_index, classification_error, entropy)
# Reshape the data to long format for ggplot (requires the reshape2 package)
df_melted <- reshape2::melt(df, id.vars = 'pm1')
# Plotting
ggplot(df_melted, aes(x = pm1, y = value, color = variable)) +
geom_line() +
labs(x = expression(hat(p)[m1]), y = "Value", title = "Gini Index, Classification Error, and Entropy") +
scale_color_manual(values = c("gini_index" = "blue", "classification_error" = "red", "entropy" = "green"), labels = c("Gini Index", "Classification Error", "Entropy")) +
theme_minimal()
The process of fitting a regression tree is a methodical approach that starts with the entire dataset represented at the root node. The core of the algorithm involves iteratively selecting the best feature and split point to partition the data, aiming to maximize variance reduction at each step. This selection process considers each possible feature and its potential split points, evaluating them based on their ability to segregate the dataset such that subsets contain observations with similar response values, thus minimizing the variance within these subsets.
Once the optimal split is determined, the dataset is divided into two subsets, and this procedure of recursive binary splitting continues. The algorithm further splits each resulting subset, adhering to the same criteria for selecting the best split, and this recursive process persists until certain predefined stopping conditions are met. These conditions include reaching a minimum number of observations in a node, a situation where further splits do not significantly reduce the variance, or achieving a maximum depth of the tree.
To ensure the model does not overfit the data, a pruning phase may follow the tree’s growth. Pruning involves trimming down the fully grown tree by removing splits that contribute minimally to the model’s predictive power. This is regulated by a complexity parameter, which is optimized through cross-validation to find a balance between the tree’s fit to the training data and its simplicity.
The final model, represented by the pruned tree, uses piecewise constant approximations for predictions. An observation is directed down the tree to a leaf node based on its feature values and the splits defined within the tree. The mean of the target variable for the training observations in that leaf node then becomes the prediction for the observation. Through this process, the regression tree algorithm captures complex, non-linear relationships in the data with a series of simple, interpretable rules, yielding a model that is not only adaptable but also straightforward for users to understand. This combination of adaptability and interpretability makes regression trees a valuable tool in the repertoire of machine learning methods.
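As a concrete illustration of this workflow, the sketch below shows how a regression tree might be grown and pruned with the rpart package; the data frame my_data and response y are hypothetical placeholders rather than part of this exercise.
# Sketch of the regression-tree workflow described above (my_data and y are placeholders)
library(rpart)
# Grow the tree; method = "anova" chooses splits that most reduce the residual sum of squares,
# and minsplit, cp, and maxdepth act as the stopping conditions discussed above
reg_tree <- rpart(y ~ ., data = my_data, method = "anova",
                  control = rpart.control(minsplit = 20, cp = 0.001, maxdepth = 10))
# Inspect cross-validated error for each subtree size, then prune at the cp with the lowest xerror
printcp(reg_tree)
best_cp <- reg_tree$cptable[which.min(reg_tree$cptable[, "xerror"]), "CP"]
reg_tree_pruned <- prune(reg_tree, cp = best_cp)
# Predictions are the mean response of the training observations in each leaf
reg_predictions <- predict(reg_tree_pruned, newdata = my_data)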
install.packages("rpart")
WARNING: Rtools is required to build R packages but is not currently installed. Please download and install the appropriate version of Rtools before proceeding:
https://cran.rstudio.com/bin/windows/Rtools/
Installing package into ‘C:/Users/ngaku/AppData/Local/R/win-library/4.3’
(as ‘lib’ is unspecified)
trying URL 'https://cran.rstudio.com/bin/windows/contrib/4.3/rpart_4.1.23.zip'
Content type 'application/zip' length 710883 bytes (694 KB)
downloaded 694 KB
package ‘rpart’ successfully unpacked and MD5 sums checked
The downloaded binary packages are in
C:\Users\ngaku\AppData\Local\Temp\RtmpULmSbf\downloaded_packages
library(rpart)
install.packages("ISLR")
WARNING: Rtools is required to build R packages but is not currently installed. Please download and install the appropriate version of Rtools before proceeding:
https://cran.rstudio.com/bin/windows/Rtools/
Installing package into ‘C:/Users/ngaku/AppData/Local/R/win-library/4.3’
(as ‘lib’ is unspecified)
trying URL 'https://cran.rstudio.com/bin/windows/contrib/4.3/ISLR_1.4.zip'
Content type 'application/zip' length 2924120 bytes (2.8 MB)
downloaded 2.8 MB
package ‘ISLR’ successfully unpacked and MD5 sums checked
The downloaded binary packages are in
C:\Users\ngaku\AppData\Local\Temp\RtmpULmSbf\downloaded_packages
library(ISLR)
data("OJ")
str(OJ)
'data.frame': 1070 obs. of 18 variables:
$ Purchase : Factor w/ 2 levels "CH","MM": 1 1 1 2 1 1 1 1 1 1 ...
$ WeekofPurchase: num 237 239 245 227 228 230 232 234 235 238 ...
$ StoreID : num 1 1 1 1 7 7 7 7 7 7 ...
$ PriceCH : num 1.75 1.75 1.86 1.69 1.69 1.69 1.69 1.75 1.75 1.75 ...
$ PriceMM : num 1.99 1.99 2.09 1.69 1.69 1.99 1.99 1.99 1.99 1.99 ...
$ DiscCH : num 0 0 0.17 0 0 0 0 0 0 0 ...
$ DiscMM : num 0 0.3 0 0 0 0 0.4 0.4 0.4 0.4 ...
$ SpecialCH : num 0 0 0 0 0 0 1 1 0 0 ...
$ SpecialMM : num 0 1 0 0 0 1 1 0 0 0 ...
$ LoyalCH : num 0.5 0.6 0.68 0.4 0.957 ...
$ SalePriceMM : num 1.99 1.69 2.09 1.69 1.69 1.99 1.59 1.59 1.59 1.59 ...
$ SalePriceCH : num 1.75 1.75 1.69 1.69 1.69 1.69 1.69 1.75 1.75 1.75 ...
$ PriceDiff : num 0.24 -0.06 0.4 0 0 0.3 -0.1 -0.16 -0.16 -0.16 ...
$ Store7 : Factor w/ 2 levels "No","Yes": 1 1 1 1 2 2 2 2 2 2 ...
$ PctDiscMM : num 0 0.151 0 0 0 ...
$ PctDiscCH : num 0 0 0.0914 0 0 ...
$ ListPriceDiff : num 0.24 0.24 0.23 0 0 0.3 0.3 0.24 0.24 0.24 ...
$ STORE : num 1 1 1 1 0 0 0 0 0 0 ...
set.seed(123) # Ensure reproducibility
# Create a random sample of 800 observations for the training set
train_indices <- sample(1:nrow(OJ), 800)
# Split the data into training and test sets
train_data <- OJ[train_indices, ]
test_data <- OJ[-train_indices, ]
# Fit the decision tree model
tree_model <- rpart(Purchase ~ ., data = train_data, method="class")
# Check the summary of the model
summary(tree_model)
Call:
rpart(formula = Purchase ~ ., data = train_data, method = "class")
n= 800
CP nsplit rel error xerror xstd
1 0.49201278 0 1.0000000 1.0000000 0.04410089
2 0.03514377 1 0.5079872 0.5335463 0.03672576
3 0.02555911 2 0.4728435 0.5335463 0.03672576
4 0.01277955 4 0.4217252 0.4504792 0.03443205
5 0.01000000 7 0.3833866 0.4728435 0.03508854
Variable importance
LoyalCH StoreID PriceDiff SalePriceMM WeekofPurchase PriceMM DiscMM
45 9 9 6 6 5 5
PctDiscMM PriceCH ListPriceDiff SalePriceCH STORE SpecialCH
4 4 3 2 1 1
Node number 1: 800 observations, complexity param=0.4920128
predicted class=CH expected loss=0.39125 P(node) =1
class counts: 487 313
probabilities: 0.609 0.391
left son=2 (450 obs) right son=3 (350 obs)
Primary splits:
LoyalCH < 0.5036 to the right, improve=134.49530, (0 missing)
StoreID < 3.5 to the right, improve= 40.88655, (0 missing)
STORE < 0.5 to the left, improve= 20.84871, (0 missing)
Store7 splits as RL, improve= 20.84871, (0 missing)
PriceDiff < 0.015 to the right, improve= 19.14298, (0 missing)
Surrogate splits:
StoreID < 3.5 to the right, agree=0.660, adj=0.223, (0 split)
WeekofPurchase < 246.5 to the right, agree=0.625, adj=0.143, (0 split)
PriceCH < 1.825 to the right, agree=0.600, adj=0.086, (0 split)
PriceMM < 1.89 to the right, agree=0.596, adj=0.077, (0 split)
ListPriceDiff < 0.035 to the right, agree=0.581, adj=0.043, (0 split)
Node number 2: 450 observations, complexity param=0.03514377
predicted class=CH expected loss=0.1355556 P(node) =0.5625
class counts: 389 61
probabilities: 0.864 0.136
left son=4 (423 obs) right son=5 (27 obs)
Primary splits:
PriceDiff < -0.39 to the right, improve=18.543390, (0 missing)
DiscMM < 0.72 to the left, improve= 9.309254, (0 missing)
SalePriceMM < 1.435 to the right, improve= 9.309254, (0 missing)
PctDiscMM < 0.3342595 to the left, improve= 9.309254, (0 missing)
LoyalCH < 0.7645725 to the right, improve= 8.822549, (0 missing)
Surrogate splits:
DiscMM < 0.72 to the left, agree=0.967, adj=0.444, (0 split)
SalePriceMM < 1.435 to the right, agree=0.967, adj=0.444, (0 split)
PctDiscMM < 0.3342595 to the left, agree=0.967, adj=0.444, (0 split)
SalePriceCH < 2.075 to the left, agree=0.949, adj=0.148, (0 split)
Node number 3: 350 observations, complexity param=0.02555911
predicted class=MM expected loss=0.28 P(node) =0.4375
class counts: 98 252
probabilities: 0.280 0.720
left son=6 (180 obs) right son=7 (170 obs)
Primary splits:
LoyalCH < 0.2761415 to the right, improve=14.991900, (0 missing)
StoreID < 3.5 to the right, improve= 6.562913, (0 missing)
Store7 splits as RL, improve= 4.617311, (0 missing)
STORE < 0.5 to the left, improve= 4.617311, (0 missing)
SpecialCH < 0.5 to the right, improve= 4.512108, (0 missing)
Surrogate splits:
STORE < 1.5 to the left, agree=0.629, adj=0.235, (0 split)
StoreID < 1.5 to the left, agree=0.589, adj=0.153, (0 split)
PriceCH < 1.875 to the left, agree=0.589, adj=0.153, (0 split)
SalePriceCH < 1.875 to the left, agree=0.586, adj=0.147, (0 split)
SalePriceMM < 1.84 to the left, agree=0.571, adj=0.118, (0 split)
Node number 4: 423 observations
predicted class=CH expected loss=0.09929078 P(node) =0.52875
class counts: 381 42
probabilities: 0.901 0.099
Node number 5: 27 observations
predicted class=MM expected loss=0.2962963 P(node) =0.03375
class counts: 8 19
probabilities: 0.296 0.704
Node number 6: 180 observations, complexity param=0.02555911
predicted class=MM expected loss=0.4222222 P(node) =0.225
class counts: 76 104
probabilities: 0.422 0.578
left son=12 (106 obs) right son=13 (74 obs)
Primary splits:
PriceDiff < 0.05 to the right, improve=12.110850, (0 missing)
SalePriceMM < 2.04 to the right, improve=11.572070, (0 missing)
DiscMM < 0.25 to the left, improve= 5.760121, (0 missing)
PctDiscMM < 0.1345485 to the left, improve= 5.760121, (0 missing)
ListPriceDiff < 0.18 to the right, improve= 5.597236, (0 missing)
Surrogate splits:
SalePriceMM < 1.94 to the right, agree=0.933, adj=0.838, (0 split)
DiscMM < 0.08 to the left, agree=0.822, adj=0.568, (0 split)
PctDiscMM < 0.038887 to the left, agree=0.822, adj=0.568, (0 split)
ListPriceDiff < 0.135 to the right, agree=0.800, adj=0.514, (0 split)
PriceMM < 2.04 to the right, agree=0.783, adj=0.473, (0 split)
Node number 7: 170 observations
predicted class=MM expected loss=0.1294118 P(node) =0.2125
class counts: 22 148
probabilities: 0.129 0.871
Node number 12: 106 observations, complexity param=0.01277955
predicted class=CH expected loss=0.4245283 P(node) =0.1325
class counts: 61 45
probabilities: 0.575 0.425
left son=24 (8 obs) right son=25 (98 obs)
Primary splits:
LoyalCH < 0.3084325 to the left, improve=3.118983, (0 missing)
WeekofPurchase < 247.5 to the right, improve=2.489639, (0 missing)
SpecialMM < 0.5 to the left, improve=2.454538, (0 missing)
PriceCH < 1.755 to the right, improve=2.048863, (0 missing)
PriceMM < 2.04 to the right, improve=1.514675, (0 missing)
Node number 13: 74 observations
predicted class=MM expected loss=0.2027027 P(node) =0.0925
class counts: 15 59
probabilities: 0.203 0.797
Node number 24: 8 observations
predicted class=CH expected loss=0 P(node) =0.01
class counts: 8 0
probabilities: 1.000 0.000
Node number 25: 98 observations, complexity param=0.01277955
predicted class=CH expected loss=0.4591837 P(node) =0.1225
class counts: 53 45
probabilities: 0.541 0.459
left son=50 (46 obs) right son=51 (52 obs)
Primary splits:
LoyalCH < 0.442144 to the right, improve=3.071463, (0 missing)
WeekofPurchase < 248.5 to the right, improve=2.208454, (0 missing)
SpecialMM < 0.5 to the left, improve=2.011796, (0 missing)
STORE < 0.5 to the left, improve=1.624324, (0 missing)
StoreID < 5.5 to the right, improve=1.624324, (0 missing)
Surrogate splits:
WeekofPurchase < 255 to the left, agree=0.622, adj=0.196, (0 split)
SalePriceCH < 1.755 to the right, agree=0.571, adj=0.087, (0 split)
STORE < 2.5 to the right, agree=0.571, adj=0.087, (0 split)
PriceMM < 2.205 to the right, agree=0.561, adj=0.065, (0 split)
DiscCH < 0.115 to the left, agree=0.561, adj=0.065, (0 split)
Node number 50: 46 observations
predicted class=CH expected loss=0.326087 P(node) =0.0575
class counts: 31 15
probabilities: 0.674 0.326
Node number 51: 52 observations, complexity param=0.01277955
predicted class=MM expected loss=0.4230769 P(node) =0.065
class counts: 22 30
probabilities: 0.423 0.577
left son=102 (8 obs) right son=103 (44 obs)
Primary splits:
SpecialCH < 0.5 to the right, improve=2.020979, (0 missing)
STORE < 1.5 to the left, improve=1.724009, (0 missing)
SpecialMM < 0.5 to the left, improve=1.680070, (0 missing)
WeekofPurchase < 245 to the right, improve=1.384615, (0 missing)
StoreID < 5.5 to the right, improve=1.319751, (0 missing)
Surrogate splits:
DiscCH < 0.27 to the right, agree=0.942, adj=0.625, (0 split)
SalePriceCH < 1.54 to the left, agree=0.942, adj=0.625, (0 split)
PctDiscCH < 0.149059 to the right, agree=0.942, adj=0.625, (0 split)
SalePriceMM < 1.64 to the left, agree=0.923, adj=0.500, (0 split)
DiscMM < 0.42 to the right, agree=0.904, adj=0.375, (0 split)
Node number 102: 8 observations
predicted class=CH expected loss=0.25 P(node) =0.01
class counts: 6 2
probabilities: 0.750 0.250
Node number 103: 44 observations
predicted class=MM expected loss=0.3636364 P(node) =0.055
class counts: 16 28
probabilities: 0.364 0.636
# Assuming you have already loaded the OJ dataset from the ISLR package
# Install and load the ISLR package if you haven't already
# install.packages("ISLR")
library(ISLR)
# Load the OJ dataset
data(OJ)
# Setting a seed for reproducibility
set.seed(123)
# Create indices for randomly sampling 800 observations for the training set
train_indices <- sample(nrow(OJ), 800)
# Create the training and test sets
train_data <- OJ[train_indices, ]
test_data <- OJ[-train_indices, ]
# Install and load the rpart package for fitting a tree model
# install.packages("rpart")
library(rpart)
# Fit a tree model to the training data
model <- rpart(Purchase ~ ., data = train_data, method = "class")
# Predict class labels for the training data
predictions <- predict(model, train_data, type = "class")
# Calculate the training error rate
train_error_rate <- mean(predictions != train_data$Purchase)
# Print the training error rate
print(train_error_rate)
[1] 0.15
The training error rate is 0.15; about 15% of the training observations are misclassified.
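As an optional cross-check, the same rate can be read off a training confusion matrix, using the predictions computed above:
# Training confusion matrix; the off-diagonal cells are the misclassified observations
table(Predicted = predictions, Actual = train_data$Purchase)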
# Plot the tree
plot(model, main="Decision Tree")
text(model, use.n=TRUE)
The plotted tree has eight terminal nodes (the CP table reports seven splits). Terminal nodes show the predicted class for that node along with the counts of training observations of each class that fell into it. For example, in the leftmost leaf, ‘381/42’ indicates that 381 of the training observations in that node are ‘CH’ and 42 are ‘MM’. These counts help assess node purity: the larger the majority-class count relative to the other class, the purer the node.
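The leaf count and the size of each leaf can also be pulled directly from the fitted rpart object; a small sketch using the model object fitted above:
# Rows of model$frame where var == "<leaf>" are the terminal nodes
leaves <- model$frame$var == "<leaf>"
sum(leaves)                            # number of terminal nodes
model$frame[leaves, c("n", "yval")]    # node size and predicted class code for each leaf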
install.packages("rattle")
WARNING: Rtools is required to build R packages but is not currently installed. Please download and install the appropriate version of Rtools before proceeding:
https://cran.rstudio.com/bin/windows/Rtools/
Installing package into ‘C:/Users/ngaku/AppData/Local/R/win-library/4.3’
(as ‘lib’ is unspecified)
also installing the dependencies ‘XML’, ‘rpart.plot’
trying URL 'https://cran.rstudio.com/bin/windows/contrib/4.3/XML_3.99-0.16.1.zip'
Content type 'application/zip' length 3092318 bytes (2.9 MB)
downloaded 2.9 MB
trying URL 'https://cran.rstudio.com/bin/windows/contrib/4.3/rpart.plot_3.1.2.zip'
Content type 'application/zip' length 1035325 bytes (1011 KB)
downloaded 1011 KB
trying URL 'https://cran.rstudio.com/bin/windows/contrib/4.3/rattle_5.5.1.zip'
Content type 'application/zip' length 6369857 bytes (6.1 MB)
downloaded 6.1 MB
package ‘XML’ successfully unpacked and MD5 sums checked
package ‘rpart.plot’ successfully unpacked and MD5 sums checked
package ‘rattle’ successfully unpacked and MD5 sums checked
The downloaded binary packages are in
C:\Users\ngaku\AppData\Local\Temp\RtmpULmSbf\downloaded_packages
# install.packages("rpart.plot") # Uncomment if you haven't installed the package
library(rpart.plot)
# Plot the tree with rpart.plot for a more readable visualization
rpart.plot(model, type = 4, extra = 101)
# Print a summary of the tree
summary(model)
(Output omitted: identical to the summary(tree_model) output shown above, since the refit uses the same seed and training data.)
I picked node number 4.
Predicted class: CH. The model predicts ‘CH’ for observations falling into this node.
Expected loss: 0.09929078. This is the misclassification rate at this node, about 9.93%, which is relatively low and indicates confident predictions here.
P(node): 0.52875. This is the proportion of training observations that fall into this node, roughly 52.88%.
Class counts: 381 for CH and 42 for MM. There are 381 observations labeled ‘CH’ and 42 labeled ‘MM’ in this node.
Probabilities: 0.901 for CH and 0.099 for MM. Given that an observation falls into this node, there is a 90.1% chance it is classified as ‘CH’ and a 9.9% chance it is classified as ‘MM’.
This node does not split further and is a terminal node, meaning it does not branch out into more decisions. Observations that fall into this node are classified based on the highest probability which, in this case, is for class ‘CH’. The classification is made with a high confidence level (90.1%), and the error rate for this group of observations is quite low, suggesting that the features leading to this node are strong predictors for the ‘CH’ class.
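The sequence of splits leading to this node can also be printed directly, which is a convenient way to read off the rule behind the leaf (a sketch using rpart's path.rpart on the model fitted above):
# Print the path of splits from the root to node 4
path.rpart(model, nodes = 4)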
# Assuming 'model' is your trained decision tree and 'test_data' is your test dataset
# Predict class labels for the test data
test_predictions <- predict(model, test_data, type = "class")
# Create a confusion matrix
confusion_matrix <- table(Predicted = test_predictions, Actual = test_data$Purchase)
# Print the confusion matrix
print(confusion_matrix)
Actual
Predicted CH MM
CH 141 24
MM 25 80
# Calculate and print the test error rate
test_error_rate <- sum(test_predictions != test_data$Purchase) / length(test_predictions)
print(paste("Test error rate:", test_error_rate))
[1] "Test error rate: 0.181481481481481"
The test error rate is approximately 0.181, i.e., about 18.1% of the test observations are misclassified.
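As a sanity check, the same number follows directly from the confusion matrix above: the off-diagonal counts are the misclassified test observations.
# (24 + 25) misclassified out of 270 test observations ≈ 0.1815
misclassified <- confusion_matrix["CH", "MM"] + confusion_matrix["MM", "CH"]
misclassified / sum(confusion_matrix)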
# Load necessary libraries
library(caret)
Loading required package: lattice
library(rpart)
# Set up cross-validation control
train_control <- trainControl(method="cv", number=10)
# Train the model with cross-validation to select the optimal complexity parameter (cp)
model_cv <- train(Purchase ~ .,
data=train_data,
method="rpart",
trControl=train_control,
tuneLength=10)
# Print the results
print(model_cv)
CART
800 samples
17 predictor
2 classes: 'CH', 'MM'
No pre-processing
Resampling: Cross-Validated (10 fold)
Summary of sample sizes: 719, 720, 720, 720, 720, 720, ...
Resampling results across tuning parameters:
cp Accuracy Kappa
0.00000000 0.8062750 0.5898040
0.05466809 0.7862588 0.5531378
0.10933617 0.7862588 0.5531378
0.16400426 0.7862588 0.5531378
0.21867235 0.7862588 0.5531378
0.27334043 0.7862588 0.5531378
0.32800852 0.7862588 0.5531378
0.38267661 0.7862588 0.5531378
0.43734469 0.7862588 0.5531378
0.49201278 0.7225088 0.3644616
Accuracy was used to select the optimal model using the largest value.
The final value used for the model was cp = 0.
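An alternative to caret worth noting: rpart itself already stores 10-fold cross-validated error (the xerror column of the CP table), so a cp value can be selected from the fitted model directly. A sketch:
# CP table and cross-validation plot from rpart's built-in cross-validation
printcp(model)
plotcp(model)
# cp with the smallest cross-validated error; from the CP table above this is
# cp ≈ 0.0128 (four splits), rather than cp = 0
cp_min_xerror <- model$cptable[which.min(model$cptable[, "xerror"]), "CP"]
cp_min_xerror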
# Assuming 'model_cv' is the model trained with cross-validation as shown in the previous example
# Extract the results
results <- model_cv$results
# Plot cross-validated accuracy vs. tree size (1/cp); the cp = 0 row gives an
# infinite value of 1/cp, so it is dropped before plotting
finite_rows <- is.finite(1 / results$cp)
plot(1/results$cp[finite_rows], results$Accuracy[finite_rows], type='b', col='blue',
     xlab='Tree Size (1/Complexity Parameter)', ylab='Cross-Validated Accuracy',
     main='Tree Size vs. Cross-Validated Accuracy')
# Assuming model_cv is the model trained with cross-validation
# Find the row with the highest accuracy
best_model_row <- which.max(model_cv$results$Accuracy)
# Extract the best cp value
best_cp <- model_cv$results$cp[best_model_row]
# Calculate the corresponding tree size (1/cp)
best_tree_size <- 1 / best_cp
# Print the best cp value and the corresponding tree size
print(paste("Best cp:", best_cp))
[1] "Best cp: 0"
print(paste("Corresponding tree size (1/cp):", best_tree_size))
[1] "Corresponding tree size (1/cp): Inf"
# Assuming 'model' is your original tree model and 'best_cp' is the optimal cp value you've identified
# Prune the tree using the best_cp value
pruned_model <- prune(model, cp = best_cp)
# Plot the pruned tree
plot(pruned_model, main="Pruned Decision Tree")
text(pruned_model, use.n=TRUE)
# Optionally, if you want to see the summary of the pruned tree
print(summary(pruned_model))
(Summary output omitted: identical to the summary shown above, because pruning with cp = 0 removes no splits.)
n= 800
node), split, n, loss, yval, (yprob)
* denotes terminal node
1) root 800 313 CH (0.60875000 0.39125000)
2) LoyalCH>=0.5036 450 61 CH (0.86444444 0.13555556)
4) PriceDiff>=-0.39 423 42 CH (0.90070922 0.09929078) *
5) PriceDiff< -0.39 27 8 MM (0.29629630 0.70370370) *
3) LoyalCH< 0.5036 350 98 MM (0.28000000 0.72000000)
6) LoyalCH>=0.2761415 180 76 MM (0.42222222 0.57777778)
12) PriceDiff>=0.05 106 45 CH (0.57547170 0.42452830)
24) LoyalCH< 0.3084325 8 0 CH (1.00000000 0.00000000) *
25) LoyalCH>=0.3084325 98 45 CH (0.54081633 0.45918367)
50) LoyalCH>=0.442144 46 15 CH (0.67391304 0.32608696) *
51) LoyalCH< 0.442144 52 22 MM (0.42307692 0.57692308)
102) SpecialCH>=0.5 8 2 CH (0.75000000 0.25000000) *
103) SpecialCH< 0.5 44 16 MM (0.36363636 0.63636364) *
13) PriceDiff< 0.05 74 15 MM (0.20270270 0.79729730) *
7) LoyalCH< 0.2761415 170 22 MM (0.12941176 0.87058824) *
# Predict class labels for the training data using the pruned tree
pruned_predictions <- predict(pruned_model, train_data, type = "class")
# Calculate the training error rate for the pruned tree
pruned_train_error_rate <- mean(pruned_predictions != train_data$Purchase)
# Print the training error rate for the pruned tree
print(pruned_train_error_rate)
[1] 0.15
The pruned and unpruned trees have the same training error rate of 0.15. This is expected: the selected cp of 0 removes no splits, so the pruned tree is identical to the original.
# For the unpruned tree
unpruned_predictions <- predict(model, newdata = test_data, type = "class")
unpruned_test_error_rate <- mean(unpruned_predictions != test_data$Purchase)
print(paste("Unpruned Test Error Rate:", unpruned_test_error_rate))
[1] "Unpruned Test Error Rate: 0.181481481481481"
# For the pruned tree
pruned_predictions <- predict(pruned_model, newdata = test_data, type = "class")
pruned_test_error_rate <- mean(pruned_predictions != test_data$Purchase)
print(paste("Pruned Test Error Rate:", pruned_test_error_rate))
[1] "Pruned Test Error Rate: 0.181481481481481"
The pruned and unpruned trees also have the same test error rate (about 0.181), again because pruning with cp = 0 left the tree unchanged.
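For comparison, pruning at the cp with the lowest cross-validated xerror in rpart's CP table would give a smaller tree whose test error can be computed the same way (a sketch; the resulting error is not shown here):
# Prune at the cp that minimizes the cross-validated error in the CP table
alt_cp <- model$cptable[which.min(model$cptable[, "xerror"]), "CP"]
alt_pruned <- prune(model, cp = alt_cp)
alt_predictions <- predict(alt_pruned, newdata = test_data, type = "class")
mean(alt_predictions != test_data$Purchase)   # test error of the smaller tree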