Classification Trees are a type of decision tree algorithm specifically tailored for solving classification problems. Their primary attraction lies in their interpretability, ease of implementation, and ability to handle a variety of data types.
The general goal of Classification Trees is to create a model that predicts the value of a binary target variable by learning simple decision rules inferred from the data features.
The decision rules are easy to understand and interpret, making Classification Trees ideal for applications where interpretability is crucial.
Medical Diagnosis: For diagnosing diseases based on medical test results.
Credit Default Prediction: For predicting the likelihood of a borrower defaulting on a loan.
Email Filtering: To categorize emails as spam or not spam.
Customer Segmentation: To segment customers based on their purchasing behavior.
Root Node: The topmost node that includes the entire dataset. It represents the starting point of the decision-making process.
Decision Node: An internal node that performs a test based on feature values and makes a decision to split the data. Each decision node results in two child nodes.
Leaf Node: A terminal node where no further splitting occurs. It contains the class label to be assigned to new data points.
Branch: A section of the tree that connects nodes, representing the decision path leading to a leaf node.
Depth: The length of the longest path from the root node to a leaf; it measures how many levels of decisions the tree makes.
We will be working on the credit data again; as before, our goal is to predict default!
file_path = "https://xiaoruizhu.github.io/Data-Mining-R/lecture/data/credit_default.csv"
credit.data <- read.csv(file = file_path, header=T)[1:2000,]
credit.data$SEX<- as.factor(credit.data$SEX)
credit.data$EDUCATION<- as.factor(credit.data$EDUCATION)
credit.data$MARRIAGE<- as.factor(credit.data$MARRIAGE)
colnames(credit.data)
## [1] "LIMIT_BAL" "SEX"
## [3] "EDUCATION" "MARRIAGE"
## [5] "AGE" "PAY_0"
## [7] "PAY_2" "PAY_3"
## [9] "PAY_4" "PAY_5"
## [11] "PAY_6" "BILL_AMT1"
## [13] "BILL_AMT2" "BILL_AMT3"
## [15] "BILL_AMT4" "BILL_AMT5"
## [17] "BILL_AMT6" "PAY_AMT1"
## [19] "PAY_AMT2" "PAY_AMT3"
## [21] "PAY_AMT4" "PAY_AMT5"
## [23] "PAY_AMT6" "default.payment.next.month"
Let’s look at a very simple classification tree. It is so simple that it is just a single rule: whether PAY_0 < 2.
That means:
if PAY_0 < 2, we predict 0, which means no default;
if PAY_0 >= 2, we predict 1, which means default.
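Such a one-rule tree can be produced with rpart by allowing only a single split. Here is a minimal sketch using the credit.data loaded above (the exact cutoff rpart chooses depends on the data, so the split may not land exactly at 2):
library(rpart)
library(rpart.plot)
# Restrict the tree to one split; method = "class" forces a classification tree
simple_tree <- rpart(as.factor(default.payment.next.month) ~ PAY_0,
                     data = credit.data, method = "class",
                     control = rpart.control(maxdepth = 1))
# extra = 4 prints the class probabilities in each node (the .78 .22 numbers)
rpart.plot(simple_tree, extra = 4, yesno = 2)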
There is one Root Node (which is naturally a Decision Node) and two Leaf Nodes. Each node is labeled with three numbers, for example:
The Root Node is labeled 0 with .78 .22. That means: the predicted class for this node is 0; 78% of the data in this node is actually 0 and 22% is actually 1.
For a tree with one more split, we would have 1 Root Node, 2 Decision Nodes (again, the Root Node is naturally a Decision Node), and 3 Leaf Nodes.
In this section, we will address four key questions concerning the construction and application of Classification Trees.
How do we choose the split variable and its value?
When should we stop splitting?
For a leaf node, how do we decide the predicted class?
How do we classify a new data point?
For each split, here are the general steps involved:
Calculate Split Quality: the Gini Index is the default criterion used by the rpart function in R. It measures node impurity; a Gini index of 0 indicates perfect purity (all elements are of the same class). The Gini index is computationally less intensive than information gain.
\[
\text{Gini} = 1 - \sum_{i=1}^{c} p(i)^{2}
\]
where c is the number of classes (c = 2 for a binary classification problem) and \(p(i)\) is the proportion of data points in the node belonging to class i.
Calculate Split Quality (continued):
\[
\text{Information Gain} = \text{Entropy}(parent) - \sum \left( \frac{|child|}{|parent|} \times \text{Entropy}(child) \right)
\]
The entropy for a node t is given by:
\[
\text{Entropy}(t) = - \sum_{i=1}^{c} p(i) \log_{2} p(i)
\]
Select Best Split: Choose the variable and its split value that result in the best split according to the chosen criterion.
Implement the Split: Divide the data into subsets based on the chosen variable and value, resulting in two child nodes.
From these two child nodes, we can do another round of splitting on each node, and so on. (A small sketch of these impurity calculations follows below.)
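To make the formulas concrete, here is a minimal sketch; the helper functions gini_index, entropy, and split_gini are hypothetical illustrations, not part of rpart. They compute node impurity and the weighted impurity of a candidate split such as PAY_0 < 2.
# Gini index of a node: 1 - sum of squared class proportions
gini_index <- function(y) {
  p <- prop.table(table(y))
  1 - sum(p^2)
}
# Entropy of a node: -sum p(i) * log2(p(i)), over classes present in the node
entropy <- function(y) {
  p <- prop.table(table(y))
  p <- p[p > 0]                      # avoid log2(0)
  -sum(p * log2(p))
}
# Weighted impurity of the two children created by splitting x at `cutoff`
split_gini <- function(y, x, cutoff) {
  left  <- y[x <  cutoff]
  right <- y[x >= cutoff]
  w     <- c(length(left), length(right)) / length(y)
  w[1] * gini_index(left) + w[2] * gini_index(right)
}
# Example with the credit data loaded above:
gini_index(credit.data$default.payment.next.month)
split_gini(credit.data$default.payment.next.month, credit.data$PAY_0, 2)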
Splitting too much can lead to overfitting, while not splitting enough can result in underfitting.
Generally, we want to stop at the point where the error on unseen data is lowest. But this is not realistic, as we don’t know in advance how the tree will perform on new, unseen data. There are, however, several things we can do:
Pruning: First build a big tree, then use pruning methods like reduced-error pruning or cost-complexity pruning to remove unnecessary branches.
Complexity Parameter (cp): Setting a threshold for the complexity parameter can stop the tree from growing when the improvement from a split is marginal. Lower cp values result in larger trees, while higher values result in smaller trees.
Minimum Samples: Limit the number of samples required for a split (minsplit) or a leaf node (minbucket). For example, minsplit=20 means that a node must have at least 20 samples to be considered for a further split.
Maximum Depth: As demonstrated earlier, the maximum depth (maxdepth) of the tree can be explicitly set. (A sketch of how these options are passed to rpart follows this list.)
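Here is a minimal sketch of how these stopping rules are passed to rpart via rpart.control(); the parameter values are illustrative only, and credit.train refers to the training split created later in the lab section, so the fitting call is shown commented out.
library(rpart)
# Illustrative thresholds; tune these rather than treating them as recommendations
ctrl <- rpart.control(cp = 0.01,      # stop when a split improves fit by less than cp
                      minsplit = 20,  # a node needs at least 20 samples to be split
                      minbucket = 7,  # every leaf must contain at least 7 samples
                      maxdepth = 5)   # no node can be deeper than 5 levels
# fit_small <- rpart(as.factor(default) ~ ., data = credit.train,
#                    method = "class", control = ctrl)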
Once a record reaches a leaf node, it’s time to make a classification decision. Here are the key approaches:
Majority Voting: The most straightforward method: just assign the class that is the majority in the leaf node.
Class Probabilities: Instead of making a hard assignment, you can calculate the class probabilities based on the proportion of each class in the leaf node.
Weighted Voting: In some cases, the classes in the leaf node can be weighted based on some external metric or business rule.
Start at the Root: The classification process begins at the root node of the tree.
Follow the Path: Based on the decision rules at each internal node, navigate through the tree.
Reach a Leaf Node: Once you reach a leaf node, no further tests are applied; the leaf determines the classification.
Assign Class Label: Finally, assign the class label or class probabilities present at the leaf node to the new data point (a short predict() sketch follows these steps).
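These steps are exactly what predict() does for us. A minimal sketch, assuming the fit_tree model and credit.test split built in the lab section below, with a single row standing in for a "new" record:
# Treat the first test row as a brand-new data point
new_point <- credit.test[1, ]
# Majority voting in the leaf: a hard class label
predict(fit_tree, new_point, type = "class")
# Class probabilities: the class proportions of the leaf the point lands in
predict(fit_tree, new_point, type = "prob")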
Once a tree model is built, it’s essential to evaluate its performance to ensure it is generalizing well to new data. We can also use it to compare different models.
Understanding both the strengths and weaknesses of Classification Trees will help you make informed decisions when selecting algorithms for different tasks.
Interpretable Model: One of the most interpretable machine learning algorithms, making it easy to explain to non-experts.
Handles both Categorical and Numerical Features: The algorithm can handle data of various types.
Minimal Data Preprocessing: No need for scaling, and it can handle missing values.
Quick to Build: Computational complexity is generally linear in terms of the number of features and data points.
Non-Parametric: Does not make strong assumptions about the underlying data distribution.
Overfitting: Trees can easily capture noise in data and overfit, although techniques like pruning can help.
Unstable: Slight changes in data can result in a significantly different tree.
Biased to Dominant Class: In imbalanced datasets, the algorithm is biased towards the dominant class.
Limited Expressiveness: Each decision boundary is axis-aligned; in other words, the tree considers only one variable at a time. (This will become clearer when we talk about Regression Trees; see also the sketch below.)
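To see what "axis-aligned" means, here is a minimal sketch (illustrative only; the choice of LIMIT_BAL and AGE and the cp value are arbitrary) that fits a tree on just two predictors and colors a grid of points by the predicted class. Every boundary in the resulting picture is parallel to one of the axes.
library(rpart)
# Fit a tree on two predictors only so the decision regions can be drawn in 2D
two_var_tree <- rpart(as.factor(default.payment.next.month) ~ LIMIT_BAL + AGE,
                      data = credit.data, method = "class",
                      control = rpart.control(cp = 0.001))
# Evaluate the tree on a dense grid covering the observed range of both variables
grid <- expand.grid(
  LIMIT_BAL = seq(min(credit.data$LIMIT_BAL), max(credit.data$LIMIT_BAL), length.out = 200),
  AGE       = seq(min(credit.data$AGE), max(credit.data$AGE), length.out = 200))
grid$pred <- predict(two_var_tree, grid, type = "class")
# Each single-color rectangle is a leaf; all edges are vertical or horizontal
plot(grid$LIMIT_BAL, grid$AGE, col = c("grey80", "grey40")[grid$pred],
     pch = 15, cex = 0.4, xlab = "LIMIT_BAL", ylab = "AGE")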
Classification Trees are not limited to binary classification. They can be readily extended to handle multi-class problems.
Same Algorithm: The same basic algorithm used for binary classification can be extended to multi-class classification.
Multiple Classes in Leaf Nodes: Each leaf node in the tree will now contain samples from multiple classes, rather than just two.
Majority Voting: Classification is done by assigning the most frequent class in the leaf node to the new data point, as shown in the sketch below.
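As a quick illustration (using the built-in iris data rather than the credit data, since iris has three classes), the same rpart call handles a multi-class target without any changes:
library(rpart)
library(rpart.plot)
# Species has three levels: setosa, versicolor, virginica
iris_tree <- rpart(Species ~ ., data = iris, method = "class")
rpart.plot(iris_tree, extra = 4, yesno = 2)
# Majority voting still applies: the leaf's most frequent species is predicted
predict(iris_tree, iris[c(1, 51, 101), ], type = "class")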
rpart Package
Overview: An important package that implements the classification and regression tree (CART) algorithm.
Key Functions:
rpart(): Builds the tree model.
predict(): Makes predictions based on the model.
summary(): Provides a detailed summary of the tree model.
Installation:
#If you haven't, run the following to install rpart package
install.packages('rpart')
rpart.plot Package
Overview: A companion package that provides enhanced plotting of rpart trees.
Key Function:
rpart.plot(): Creates an advanced plot of the tree.
Installation:
#If you haven't, run the following to install rpart.plot package
install.packages("rpart.plot")
We will be working on the credit data again.
file_path = "https://xiaoruizhu.github.io/Data-Mining-R/lecture/data/credit_default.csv"
credit.data <- read.csv(file = file_path, header=T)
And let’s take a look at the column names:
colnames(credit.data)
## [1] "LIMIT_BAL" "SEX"
## [3] "EDUCATION" "MARRIAGE"
## [5] "AGE" "PAY_0"
## [7] "PAY_2" "PAY_3"
## [9] "PAY_4" "PAY_5"
## [11] "PAY_6" "BILL_AMT1"
## [13] "BILL_AMT2" "BILL_AMT3"
## [15] "BILL_AMT4" "BILL_AMT5"
## [17] "BILL_AMT6" "PAY_AMT1"
## [19] "PAY_AMT2" "PAY_AMT3"
## [21] "PAY_AMT4" "PAY_AMT5"
## [23] "PAY_AMT6" "default.payment.next.month"
We rename the default.payment.next.month column to default:
library(dplyr)
credit.data<- rename(credit.data, default=default.payment.next.month)
We know that SEX, EDUCATION, and MARRIAGE are categorical, so we convert them to factors.
credit.data$SEX<- as.factor(credit.data$SEX)
credit.data$EDUCATION<- as.factor(credit.data$EDUCATION)
credit.data$MARRIAGE<- as.factor(credit.data$MARRIAGE)
Next, we randomly split the data into 60% training and 40% testing sets.
set.seed(2023)
index <- sample(1:nrow(credit.data),nrow(credit.data)*0.60)
credit.train = credit.data[index,]
credit.test = credit.data[-index,]
We will use the function rpart from package rpart.
library(rpart)
library(rpart.plot)
# fit the model
fit_tree <- rpart(as.factor(default) ~ ., data=credit.train)
rpart.plot(fit_tree,extra=4, yesno=2)
Obtain the predicted values for training data and confusion matrix.
# Make predictions on the train data
pred_credit_train <- predict(fit_tree, credit.train, type="class")
# Confusion matrix to evaluate the model on train data
Cmatrix_train = table(true = credit.train$default,
pred = pred_credit_train)
Cmatrix_train
## pred
## true 0 1
## 0 5316 310
## 1 956 618
Misclassification Rate (MR):
1 - sum(diag(Cmatrix_train))/sum(Cmatrix_train)
## [1] 0.1758333
Obtain the predicted values for testing data and confusion matrix.
# Make predictions on the test data
pred_credit_test <- predict(fit_tree, credit.test, type="class")
# Confusion matrix to evaluate the model on test data
Cmatrix_test = table(true = credit.test$default,
pred = pred_credit_test)
Cmatrix_test
## pred
## true 0 1
## 0 3542 200
## 1 644 414
Misclassification Rate (MR):
1 - sum(diag(Cmatrix_test))/sum(Cmatrix_test)
## [1] 0.1758333
Using a classification tree with asymmetric costs is, again, a little different!
# We need to define a loss matrix first; rows are the true class (0, 1),
# columns are the predicted class, and the diagonal must stay 0
cost_matrix <- matrix(c(0, 1,  # true 0 predicted as 1 (FP) costs 1
                        5, 0), # true 1 predicted as 0 (FN) costs 5
                      byrow = TRUE, nrow = 2)
fit_tree_asym <- rpart(as.factor(default) ~ ., data=credit.train,
parms = list(loss = cost_matrix))
rpart.plot(fit_tree_asym,extra=4, yesno=2)
#get predictions for training
pred_credit_train <- predict(fit_tree_asym, credit.train,
type = "class")
#C matrix for training
table( true = credit.train$default, pred = pred_credit_train)
## pred
## true 0 1
## 0 3550 2076
## 1 364 1210
#get predictions for testing
pred_credit_test <- predict(fit_tree_asym, credit.test, type = "class")
#C matrix for testing with more weights on "1"
Cmatrix_test_weight = table( true = credit.test$default, pred = pred_credit_test)
Cmatrix_test_weight
## pred
## true 0 1
## 0 2359 1383
## 1 289 769
# Recall the testing confusion matrix without any asymmetric cost:
Cmatrix_test # no weights
## pred
## true 0 1
## 0 3542 200
## 1 644 414
# work on it here!
While tree-based models like those generated with rpart are not probabilistic classifiers in the same sense as logistic regression or some other algorithms, you can still compute an AUC score based on the leaf node probabilities of belonging to the positive class.
We will obtain the predicted probabilities first with:
# obtain predicted probability
pred_prob_train = predict(fit_tree, credit.train, type = "prob")
# This is necessary again, as predict() for a tree model returns two columns, one for 0 and one for 1.
# Replace "1" with the actual category label if the response variable is a factor
pred_prob_train = pred_prob_train[,"1"]
# Looks familiar, right?
library(ROCR)
pred <- prediction(pred_prob_train, credit.train$default)
perf <- performance(pred, "tpr", "fpr")
plot(perf, colorize=TRUE)
#Get the AUC
unlist(slot(performance(pred, "auc"), "y.values"))
## [1] 0.6965917
# obtain predicted probability
pred_prob_test = predict(fit_tree, credit.test, type = "prob")
# This is necessary again, as predict() for a tree model returns two columns, one for 0 and one for 1.
pred_prob_test = pred_prob_test[,"1"] # replace "1" with the actual category label if the response variable is a factor
#ROC
pred <- prediction(pred_prob_test, credit.test$default)
perf <- performance(pred, "tpr", "fpr")
plot(perf, colorize=TRUE)
#Get the AUC
unlist(slot(performance(pred, "auc"), "y.values"))
## [1] 0.6950851
# work on it here!
maxdepth
# Fit the tree with a maximum depth of 3
depth_control_tree <- rpart(default ~ ., data = credit.train, method = "class",
                            control = rpart.control(maxdepth = 3),
                            parms = list(loss = cost_matrix))
# I am using cost_matrix just for demonstration; remove it if needed
rpart.plot(depth_control_tree,extra=4, yesno=2)
CP
# Fit the tree with a small complexity parameter (cp = 0.003)
depth_control_tree <- rpart(default ~ ., data = credit.train,
                            method = "class", cp = 0.003,
                            parms = list(loss = cost_matrix))
# I am using cost_matrix just for demonstration; remove it if needed
rpart.plot(depth_control_tree,extra=4, yesno=2)
# Fit a full tree with very low cp value
full_tree <- rpart(default ~ ., data = credit.train,
method = "class", cp = 0.001)
# Display CP table to identify the optimal cp
printcp(full_tree)
##
## Classification tree:
## rpart(formula = default ~ ., data = credit.train, method = "class",
## cp = 0.001)
##
## Variables actually used in tree construction:
## [1] AGE BILL_AMT1 BILL_AMT2 BILL_AMT4 BILL_AMT5 BILL_AMT6 EDUCATION
## [8] LIMIT_BAL PAY_0 PAY_2 PAY_5 PAY_6 PAY_AMT1 PAY_AMT2
## [15] PAY_AMT3 PAY_AMT5 PAY_AMT6 SEX
##
## Root node error: 1574/7200 = 0.21861
##
## n= 7200
##
## CP nsplit rel error xerror xstd
## 1 0.1747141 0 1.00000 1.00000 0.022281
## 2 0.0104828 1 0.82529 0.82529 0.020730
## 3 0.0082592 3 0.80432 0.80940 0.020573
## 4 0.0041296 4 0.79606 0.79860 0.020464
## 5 0.0034943 8 0.77891 0.80686 0.020547
## 6 0.0031766 10 0.77192 0.80559 0.020535
## 7 0.0021177 11 0.76874 0.81067 0.020586
## 8 0.0019060 19 0.75095 0.83037 0.020779
## 9 0.0015883 32 0.72490 0.84180 0.020890
## 10 0.0012706 36 0.71792 0.85896 0.021053
## 11 0.0010589 46 0.70457 0.88564 0.021301
## 12 0.0010000 53 0.69314 0.91169 0.021535
From the cp table, choose the cp value that gives you the lowest xerror, and use that value to prune the tree as follows:
# Identify the optimal cp manually or use the following code
optimal_cp <- full_tree$cptable[which.min(
full_tree$cptable[,"xerror"]),"CP"]
# Prune the tree
pruned_tree <- prune(full_tree, cp = optimal_cp)
rpart.plot(pruned_tree,extra=4, yesno=2)
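As a follow-up, here is a minimal sketch (the numbers will depend on your seed and cp choice) that evaluates the pruned tree on the hold-out set, mirroring the earlier evaluation steps:
# Predictions and confusion matrix for the pruned tree on the test set
pred_pruned_test <- predict(pruned_tree, credit.test, type = "class")
Cmatrix_pruned <- table(true = credit.test$default, pred = pred_pruned_test)
Cmatrix_pruned
# Misclassification rate of the pruned tree
1 - sum(diag(Cmatrix_pruned))/sum(Cmatrix_pruned)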