Classification and Regression Trees (CART) refers to a decision tree algorithm that can be used for both classification and regression predictive modelling. The goal is to classify or predict an outcome from a set of predictors using rules represented by a tree, built through recursive partitioning, i.e., repeatedly splitting the records into two parts so as to achieve maximum homogeneity within the new parts.
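To make the idea of "maximum homogeneity" concrete, here is a minimal sketch (not taken from the original analysis) of the Gini impurity that rpart uses by default to score candidate splits for classification; a good split produces children whose weighted impurity is lower than the parent's. The names gini, left, right and parent are purely illustrative.
# Gini impurity of a node: 1 - sum(p_k^2) over the class proportions p_k
gini = function(y) {
  p = prop.table(table(y))
  1 - sum(p^2)
}
# toy example: a candidate split sends 80 of 100 records to the left node
left   = c(rep(0, 75), rep(1, 5))   # mostly non-fraud
right  = c(rep(0, 5),  rep(1, 15))  # mostly fraud
parent = c(left, right)
gini(parent)                                      # impurity before the split (0.32)
(length(left)  / length(parent)) * gini(left) +
  (length(right) / length(parent)) * gini(right)  # weighted impurity after the split (~0.17)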
We will use the same credit card data we used for the SVM model and compare the outputs of the two models.
Let’s begin by loading the required libraries and importing the data set we are going to use for this model.
#set working directory
setwd("C:/Users/awani/Desktop/50daysofAnalytics")
options(scipen = 999)
# load required libraries
if (!require("pacman")) install.packages("pacman")
pacman::p_load(rpart, rpart.plot, caret, knitr, dplyr, kableExtra, ggplot2, tidyr, reshape2, ROSE)
#read data
data = read.csv("creditcard.csv", stringsAsFactors = F)
Before we proceed any further, it is essential to understand the data well, get it into the correct format and, more importantly, check data correctness.
The data seems to be in the correct format, except for the dependent variable Class, which needs to be converted to type factor.
#data format correction
str(data)
## 'data.frame': 255120 obs. of 31 variables:
## $ Time : num 157072 157072 157072 157073 157073 ...
## $ Var1 : num 1.989 -1.042 -0.785 -3.752 1.865 ...
## $ Var2 : num -0.351 -0.173 0.424 -3.417 -0.911 ...
## $ Var3 : num -0.391 -0.187 2.568 0.734 -0.153 ...
## $ Var4 : num 0.399 -2.27 0.263 1.211 0.613 ...
## $ Var5 : num -0.4611 2.6017 -1.0306 -0.0221 -1.2008 ...
## $ Var6 : num -0.151 3.997 0.701 -0.646 -0.269 ...
## $ Var7 : num -0.5949 1.3717 0.1882 -0.0273 -0.8684 ...
## $ Var8 : num -0.0169 0.6916 0.1949 0.6058 0.0884 ...
## $ Var9 : num 1.3236 -0.0615 0.8801 -1.9 -0.3732 ...
## $ Var10 : num -0.199 -1.329 -1.007 0.153 1.161 ...
## $ Var11 : num -0.895 -0.324 -1.31 0.421 0.594 ...
## $ Var12 : num 0.9077 -0.0696 0.2669 0.6044 0.3659 ...
## $ Var13 : num 1.046 -0.552 0.067 0.662 -0.436 ...
## $ Var14 : num -0.3952 -0.0109 -0.821 0.5525 0.269 ...
## $ Var15 : num 0.4227 -0.3995 -0.0821 -0.0862 0.0013 ...
## $ Var16 : num 0.0722 -0.4544 -0.6816 -1.4636 -0.769 ...
## $ Var17 : num -0.591 -0.5096 0.2718 0.0521 -0.6953 ...
## $ Var18 : num 0.137 -0.614 0.159 2.333 1.897 ...
## $ Var19 : num -0.2096 -0.3594 1.1792 -0.0557 -1.1062 ...
## $ Var20 : num -0.1317 0.4493 0.0418 1.5443 -0.5398 ...
## $ Var21 : num 0.19535 -0.09585 -0.00605 0.27723 -0.38351 ...
## $ Var22 : num 0.827 -0.513 0.314 -0.509 -0.794 ...
## $ Var23 : num 0.138 0.053 -0.316 0.779 0.385 ...
## $ Var24 : num 0.78176 0.68561 0.02175 -0.00436 -0.05757 ...
## $ Var25 : num -0.0412 0.9004 0.3478 0.6919 -0.5991 ...
## $ Var26 : num -0.2494 -0.0638 -0.0702 -0.4183 -0.9538 ...
## $ Var27 : num 0.0426 -0.2348 0.048 0.1337 0.0593 ...
## $ Var28 : num -0.0278 -0.1739 0.0572 -0.3452 -0.0158 ...
## $ Amount: num 9.99 212.36 96.38 474 84 ...
## $ Class : int 0 0 0 0 0 0 0 0 0 0 ...
data$Class = factor(data$Class)
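As an aside, the data-correctness check mentioned above can be made explicit with a few quick base R calls; this is an illustrative sketch, assuming the same data frame, and was not part of the original output.
# basic correctness checks (illustrative)
sum(is.na(data))       # count of missing values
sum(duplicated(data))  # count of duplicate rows
range(data$Amount)     # transaction amounts should be non-negative
levels(data$Class)     # should be "0" and "1" after the factor conversion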
The frequency distribution of the variable Class shows that frauds are underrepresented in the data. We should use either an oversampling or an undersampling method to correct this imbalance.
#- Exploratory Data Analysis
# dependent variable
kable(table(data$Class),
col.names = c("Fraud", "Frequency"), align = 'l') %>%
kable_styling(bootstrap_options = "striped", full_width = F, position = "left")
Fraud | Frequency |
---|---|
0 | 254721 |
1 | 399 |
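To quantify the imbalance, the class proportions can also be computed directly; this is a small illustrative addition.
# proportion of fraud vs non-fraud transactions
round(prop.table(table(data$Class)), 4)
Frauds account for only about 0.16% of all transactions (399 out of 255120), which is why resampling is needed before modelling.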
Since most of the data is numeric, a univariate summary (minimum, quartiles, mean and maximum) will increase our familiarity with it and give us an idea of the central tendency and spread of each variable. Additionally, density plots of the variables will give us a good idea of their distributions.
# Independent Variables
summary(data[,2:(ncol(data)-1)])
## Var1 Var2 Var3
## Min. :-56.40751 Min. :-72.71573 Min. :-48.32559
## 1st Qu.: -0.91672 1st Qu.: -0.61862 1st Qu.: -1.01313
## Median : 0.03103 Median : 0.05509 Median : 0.07276
## Mean : 0.02444 Mean : -0.01469 Mean : -0.08388
## 3rd Qu.: 1.41813 3rd Qu.: 0.80507 3rd Qu.: 0.94627
## Max. : 2.45493 Max. : 22.05773 Max. : 9.38256
## Var4 Var5 Var6
## Min. :-5.68317 Min. :-113.74331 Min. :-26.16051
## 1st Qu.:-0.86487 1st Qu.: -0.67592 1st Qu.: -0.78081
## Median :-0.04890 Median : -0.02945 Median : -0.28719
## Mean :-0.02484 Mean : 0.02367 Mean : -0.01035
## 3rd Qu.: 0.69977 3rd Qu.: 0.64481 3rd Qu.: 0.38862
## Max. :16.87534 Max. : 34.80167 Max. : 73.30163
## Var7 Var8 Var9
## Min. :-43.55724 Min. :-73.21672 Min. :-13.43407
## 1st Qu.: -0.54757 1st Qu.: -0.21420 1st Qu.: -0.66336
## Median : 0.05341 Median : 0.02040 Median : -0.07463
## Mean : 0.01416 Mean : -0.00357 Mean : -0.04116
## 3rd Qu.: 0.58541 3rd Qu.: 0.33159 3rd Qu.: 0.54326
## Max. :120.58949 Max. : 18.70925 Max. : 15.59500
## Var10 Var11 Var12
## Min. :-24.58826 Min. :-4.79747 Min. :-18.68371
## 1st Qu.: -0.52766 1st Qu.:-0.80390 1st Qu.: -0.33576
## Median : -0.08453 Median :-0.08581 Median : 0.17299
## Mean : 0.01345 Mean :-0.05975 Mean : 0.07936
## 3rd Qu.: 0.47006 3rd Qu.: 0.67395 3rd Qu.: 0.63654
## Max. : 23.74514 Max. : 9.41304 Max. : 7.84839
## Var13 Var14 Var15
## Min. :-5.79188 Min. :-18.39209 Min. :-4.498945
## 1st Qu.:-0.67002 1st Qu.: -0.44978 1st Qu.:-0.589985
## Median :-0.04320 Median : 0.02636 Median : 0.033537
## Mean :-0.04603 Mean : -0.04010 Mean :-0.006425
## 3rd Qu.: 0.60862 3rd Qu.: 0.45410 3rd Qu.: 0.636654
## Max. : 7.12688 Max. : 10.52677 Max. : 8.877742
## Var16 Var17 Var18
## Min. :-14.129855 Min. :-25.16280 Min. :-9.498746
## 1st Qu.: -0.465704 1st Qu.: -0.50046 1st Qu.:-0.492541
## Median : 0.066457 Median : -0.09022 Median : 0.003446
## Mean : -0.000296 Mean : -0.02336 Mean : 0.009396
## 3rd Qu.: 0.518873 3rd Qu.: 0.37030 3rd Qu.: 0.510935
## Max. : 17.315112 Max. : 7.34307 Max. : 5.041069
## Var19 Var20 Var21
## Min. :-7.213527 Min. :-54.49772 Min. :-34.83038
## 1st Qu.:-0.446838 1st Qu.: -0.21784 1st Qu.: -0.22592
## Median : 0.008237 Median : -0.06763 Median : -0.02056
## Mean : 0.004718 Mean : -0.00539 Mean : 0.00395
## 3rd Qu.: 0.457672 3rd Qu.: 0.12903 3rd Qu.: 0.19623
## Max. : 5.591971 Max. : 39.42090 Max. : 27.20284
## Var22 Var23 Var24
## Min. :-10.93314 Min. :-44.80774 Min. :-2.824849
## 1st Qu.: -0.54195 1st Qu.: -0.15976 1st Qu.:-0.357605
## Median : 0.02241 Median : -0.00577 Median : 0.038739
## Mean : 0.01442 Mean : 0.00486 Mean :-0.001086
## 3rd Qu.: 0.55834 3rd Qu.: 0.15716 3rd Qu.: 0.447899
## Max. : 10.50309 Max. : 22.52841 Max. : 4.584549
## Var25 Var26 Var27
## Min. :-10.29540 Min. :-2.604551 Min. :-22.565679
## 1st Qu.: -0.33405 1st Qu.:-0.325949 1st Qu.: -0.071411
## Median : -0.00945 Median :-0.051439 Median : 0.000553
## Mean : -0.01544 Mean :-0.002654 Mean : -0.001439
## 3rd Qu.: 0.33965 3rd Qu.: 0.236202 3rd Qu.: 0.091187
## Max. : 7.51959 Max. : 3.415636 Max. : 31.612198
## Var28 Amount
## Min. :-15.43008 Min. : 0.00
## 1st Qu.: -0.05512 1st Qu.: 5.49
## Median : 0.00908 Median : 22.29
## Mean : -0.00056 Mean : 89.37
## 3rd Qu.: 0.07850 3rd Qu.: 78.00
## Max. : 33.84781 Max. :25691.16
# distribution plot
data[,-31] %>%
gather() %>%
ggplot(aes(value)) +
facet_wrap(~ key, scales = "free") +
geom_density()
We should now perform some basic multivariate analysis in an attempt to uncover interesting trends. A correlation heatmap is a quick way to check for multicollinearity and highly correlated variables.
# Plot Amount against Time, coloured by Class
ggplot(data, aes(x = Time, y = Amount, shape = Class, color = Class)) +
geom_point() +
ylim(0,5000)+
ggtitle("Frauds by Amount and Time")
## Warning: Removed 50 rows containing missing values (geom_point).
#correlation heat map
cor_matrix = round(cor(data[,1:30]), 2)
ggplot(data = melt(cor_matrix), aes(x=Var1, y=Var2, fill=value)) +
geom_tile()
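Beyond the visual heatmap, a programmatic check can list any pairs of variables whose absolute correlation exceeds a cutoff (0.75 here, purely for illustration; an empty result means no strongly correlated pairs). The object name high_cor is illustrative.
# variable pairs with absolute correlation above the cutoff
high_cor = which(abs(cor_matrix) > 0.75 & upper.tri(cor_matrix), arr.ind = TRUE)
data.frame(var1 = rownames(cor_matrix)[high_cor[, 1]],
           var2 = colnames(cor_matrix)[high_cor[, 2]],
           corr = cor_matrix[high_cor])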
Let us now prepare our data sets for training and validation of the CART model. Since the data is imbalanced, with unequal representation of fraud and non-fraud cases, we should use either oversampling or undersampling. Undersampling, although it results in some loss of information, is the better option here: oversampling would almost double the number of observations in the dataset, and using the same undersampling approach keeps the results comparable with the SVM model.
# Undersampling
data_corrected = ovun.sample(Class ~ ., data = data, method = "under", N = 399*2)$data
kable(table(data_corrected$Class),
col.names = c("Fraud", "Frequency"), align = 'l') %>%
kable_styling(bootstrap_options = "striped", full_width = F, position = "left")
Fraud | Frequency |
---|---|
0 | 399 |
1 | 399 |
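For reference, the oversampling alternative discussed above could be produced with the same ROSE function; this is only an illustration (the object name data_over is hypothetical) and is not used anywhere else in this post.
# oversampling alternative: N is the desired total size, here all 254721 non-frauds
# plus an equal number of resampled fraud records, roughly doubling the data
data_over = ovun.sample(Class ~ ., data = data, method = "over", N = 254721*2)$data
table(data_over$Class)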
#training and Validation dataset
set.seed(123)
smp_size = floor(0.7 * nrow(data_corrected))
train_ind = sample(seq_len(nrow(data_corrected)), size = smp_size)
train = data_corrected[train_ind, ]
val = data_corrected[-train_ind, ]
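Since the split above is a simple random sample rather than a stratified one, it is worth confirming that both classes are reasonably represented in each partition; a quick illustrative check:
# class counts in the training and validation sets
table(train$Class)
table(val$Class)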
Let us now train the CART model using the training data set we prepared. We will then test the model's accuracy by using it to classify frauds in the validation data set.
#model training
rpart_model = rpart(Class ~ ., data = train)
#model summary
rpart_model
## n= 558
##
## node), split, n, loss, yval, (yprob)
## * denotes terminal node
##
## 1) root 558 272 0 (0.51254480 0.48745520)
## 2) Var14>=-1.615586 311 30 0 (0.90353698 0.09646302)
## 4) Var21>=-0.8291239 297 21 0 (0.92929293 0.07070707) *
## 5) Var21< -0.8291239 14 5 1 (0.35714286 0.64285714) *
## 3) Var14< -1.615586 247 5 1 (0.02024291 0.97975709) *
#Tree plot
rpart.plot(rpart_model)
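rpart grows this tree with its default complexity parameter (cp = 0.01). A common follow-up, sketched here as an illustration rather than taken from the original post, is to inspect the cross-validated error across cp values and prune if a simpler tree performs just as well; best_cp and pruned_model are illustrative names.
# cross-validated error for each value of the complexity parameter
printcp(rpart_model)
plotcp(rpart_model)
# prune back to the cp with the lowest cross-validated error
best_cp = rpart_model$cptable[which.min(rpart_model$cptable[, "xerror"]), "CP"]
pruned_model = prune(rpart_model, cp = best_cp)
rpart.plot(pruned_model)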
The CART model does a good job of classification, with an overall accuracy of about 90% on the validation set: it correctly identifies roughly 96% of the genuine transactions and catches about 84% of the frauds. If we recall the results of the SVM model, CART does a better job in terms of accuracy.
#prediction
val$fraud_prob = predict(rpart_model, val[,-31])[,2]
val$pred = ifelse(val$fraud_prob > 0.5,1,0)
#confusion Matrix
#note: confusionMatrix() expects the predicted classes first and the reference (actual) classes
#second; with the actual classes passed first, as below, the printed Sensitivity/Specificity
#actually correspond to the positive/negative predictive values, and vice versa
confusionMatrix(val$Class, factor(val$pred))
## Confusion Matrix and Statistics
##
## Reference
## Prediction 0 1
## 0 109 4
## 1 21 106
##
## Accuracy : 0.8958
## 95% CI : (0.8501, 0.9314)
## No Information Rate : 0.5417
## P-Value [Acc > NIR] : < 0.00000000000000022
##
## Kappa : 0.7927
## Mcnemar's Test P-Value : 0.001374
##
## Sensitivity : 0.8385
## Specificity : 0.9636
## Pos Pred Value : 0.9646
## Neg Pred Value : 0.8346
## Prevalence : 0.5417
## Detection Rate : 0.4542
## Detection Prevalence : 0.4708
## Balanced Accuracy : 0.9010
##
## 'Positive' Class : 0
##
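As a closing illustration (not part of the original write-up), the splits printed earlier can also be expressed as plain-language rules, which is the rule-based view of CART mentioned at the start. Recent versions of rpart.plot provide a helper for this:
# human-readable decision rules behind the fitted tree
rpart.rules(rpart_model, cover = TRUE)  # cover = TRUE also shows the share of cases each rule covers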