This notebook uses the dataset from "AV - JanataHack: Customer Segmentation". Its objective is to identify the key features that predict the customer segments of an automobile company.
library(tidyverse)
library(skimr)
library(scales)
library(ggpubr)
library(psych)
library(colorspace)
library(PerformanceAnalytics)
library(caret)
library(rpart)
library(rpart.plot)
library(rattle)
library(randomForest)
library(xgboost)
library(nnet)
# Labelled training data (includes the Segmentation target column).
train <- readr::read_csv(file = "cust_seg_train.csv")
── Column specification ─────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
cols(
ID = col_double(),
Gender = col_character(),
Ever_Married = col_character(),
Age = col_double(),
Graduated = col_character(),
Profession = col_character(),
Work_Experience = col_double(),
Spending_Score = col_character(),
Family_Size = col_double(),
Var_1 = col_character(),
Segmentation = col_character()
)
# Test data: same feature columns as train, but no Segmentation column.
test <- readr::read_csv(file = "cust_seg_test.csv")
── Column specification ─────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
cols(
ID = col_double(),
Gender = col_character(),
Ever_Married = col_character(),
Age = col_double(),
Graduated = col_character(),
Profession = col_character(),
Work_Experience = col_double(),
Spending_Score = col_character(),
Family_Size = col_double(),
Var_1 = col_character()
)
# Per-column overview of the training data: types, missingness, distributions.
skimr::skim(data = train)
── Data Summary ────────────────────────
Values
Name train
Number of rows 8068
Number of columns 11
_______________________
Column type frequency:
character 7
numeric 4
________________________
Group variables None
── Variable type: character ──────────────────────────────────────────────────────────────────────────────────────────
skim_variable n_missing complete_rate min max empty n_unique whitespace
1 Gender 0 1 4 6 0 2 0
2 Ever_Married 140 0.983 2 3 0 2 0
3 Graduated 78 0.990 2 3 0 2 0
4 Profession 124 0.985 6 13 0 9 0
5 Spending_Score 0 1 3 7 0 3 0
6 Var_1 76 0.991 5 5 0 7 0
7 Segmentation 0 1 1 1 0 4 0
── Variable type: numeric ────────────────────────────────────────────────────────────────────────────────────────────
skim_variable n_missing complete_rate mean sd p0 p25 p50 p75 p100 hist
1 ID 0 1 463479. 2595. 458982 461241. 463472. 465744. 467974 ▇▇▇▇▇
2 Age 0 1 43.5 16.7 18 30 40 53 89 ▇▇▅▃▂
3 Work_Experience 829 0.897 2.64 3.41 0 0 1 4 14 ▇▁▂▁▁
4 Family_Size 335 0.958 2.85 1.53 1 2 3 4 9 ▇▆▁▁▁
# Per-column overview of the test data: types, missingness, distributions.
skimr::skim(data = test)
── Data Summary ────────────────────────
Values
Name test
Number of rows 2627
Number of columns 10
_______________________
Column type frequency:
character 6
numeric 4
________________________
Group variables None
── Variable type: character ──────────────────────────────────────────────────────────────────────────────────────────
skim_variable n_missing complete_rate min max empty n_unique whitespace
1 Gender 0 1 4 6 0 2 0
2 Ever_Married 50 0.981 2 3 0 2 0
3 Graduated 24 0.991 2 3 0 2 0
4 Profession 38 0.986 6 13 0 9 0
5 Spending_Score 0 1 3 7 0 3 0
6 Var_1 32 0.988 5 5 0 7 0
── Variable type: numeric ────────────────────────────────────────────────────────────────────────────────────────────
skim_variable n_missing complete_rate mean sd p0 p25 p50 p75 p100 hist
1 ID 0 1 463434. 2618. 458989 461162. 463379 465696 467968 ▇▇▇▇▇
2 Age 0 1 43.6 17.0 18 30 41 53 89 ▇▇▅▃▂
3 Work_Experience 269 0.898 2.55 3.34 0 0 1 4 14 ▇▁▁▁▁
4 Family_Size 113 0.957 2.83 1.55 1 2 2 4 9 ▇▆▁▁▁
# Highlighted bar plots of data completeness.
#
# plot_completeness(df, title) returns a horizontal bar chart with one bar per
# column of `df`, showing the share of non-missing values in that column.
# Fully complete columns (share == 1) are drawn grey; columns with any
# missing values are drawn blue. Bars are ordered by completeness.
#   df    - a data frame / tibble
#   title - plot title (string)
plot_completeness <- function(df, title) {
  df %>%
    summarise(across(everything(), ~ mean(!is.na(.x)))) %>%
    pivot_longer(everything(), names_to = "key", values_to = "value") %>%
    mutate(key = fct_reorder(key, value)) %>%
    ggplot(aes(key, value)) +
    geom_col(aes(fill = I(ifelse(value == 1, "#adb5bd", "#98c1d9"))),
             width = 0.8) +
    geom_text(aes(label = scales::percent(value)),
              nudge_y = 0.07, size = 3, hjust = 1.7) +
    scale_y_continuous(labels = scales::percent) +
    theme_minimal() +
    theme(
      panel.grid.major.x = element_blank(),
      panel.grid.minor.x = element_blank(),
      panel.grid.major.y = element_blank(),
      axis.title = element_text(size = 10),
      plot.title = element_text(size = 12, hjust = 0.5),
      plot.title.position = "plot",
      legend.position = "none",
      axis.text.x = element_blank()
    ) +
    labs(x = "Feature", y = "% of data present", title = title) +
    coord_flip()
}

p1 <- plot_completeness(train, "Train data")
p2 <- plot_completeness(test, "Test data")
ggarrange(p1, p2, ncol = 2)
# Tag each row with its origin (0 = train, 1 = test) so the combined frame
# can be split back apart after preprocessing.
train <- train %>% mutate(test_1 = 0)
test <- test %>% mutate(test_1 = 1)

# Stack both sets; test rows receive NA in the Segmentation column.
data <- bind_rows(train, test)
dim(data)
[1] 10695 12
head(data, n = 6)
# Segmentation: class balance of the target in the training data
train %>%
  count(Segmentation) %>%
  mutate(proportion = round(n / sum(n), 3))
# Var_1 (anonymised customer category)

# Frequency of each Var_1 level over train + test combined (NA kept as a level)
data %>%
  count(Var_1) %>%
  mutate(proportion = round(n / sum(n), 3))

# Dodged counts of each Var_1 level, split by segment (train data)
v1 <- ggplot(train, aes(x = Var_1, fill = Segmentation)) +
  geom_bar(stat = "count", position = "dodge", alpha = 0.9) +
  scale_fill_manual(values = c("#f3722c", "#f9c74f", "#43aa8b", "#577590")) +
  labs(subtitle = "Var_1 (Anonymised Category for the customer)") +
  theme_light() +
  theme(
    axis.title = element_text(size = 10),
    legend.title = element_text(size = 9),
    legend.position = c(0.13, 0.75)
  )

# Heatmap of the Var_1 mix within each segment (percentages per segment row)
v2 <- train %>%
  count(Segmentation, Var_1) %>%
  group_by(Segmentation) %>%
  mutate(pct = round(n / sum(n) * 100, 2)) %>%
  ggplot(aes(x = Var_1, y = fct_rev(Segmentation), fill = pct)) +
  geom_tile(color = "white", size = 2) +
  geom_text(aes(label = paste(pct, "%")), size = 3) +
  scale_fill_distiller(palette = "Spectral") +
  theme_light() +
  theme(
    legend.position = "top",
    axis.text.y = element_text(size = 12),
    axis.ticks = element_blank(),
    axis.title = element_text(size = 10),
    legend.title = element_text(size = 9),
    axis.title.y = element_text(margin = margin(t = 0, r = 5, b = 0, l = 0)),
    axis.title.x = element_text(margin = margin(t = 5, r = 0, b = 0, l = 0))
  ) +
  labs(y = "Segmentation", x = "Var_1",
       fill = "Percentage", subtitle = "Var_1 and Segmentation") +
  guides(fill = guide_colorbar(title.position = "top",
                               title.hjust = .5,
                               barwidth = unit(20, "lines"),
                               barheight = unit(.5, "lines")))

# Side-by-side layout
ggarrange(v1, v2, ncol = 2)
# Age

# Histogram (density scale) with a kernel density curve, train + test combined.
age1 <- ggplot(data, aes(x = Age)) +
  geom_histogram(aes(y = after_stat(density)),
                 binwidth = 2,
                 fill = "#6c757d", alpha = 0.5) +
  geom_density(alpha = .2, color = "#00a8e8", size = 1) +
  labs(subtitle = "Age Distribution (train and test data combined)") +
  scale_x_continuous(limits = c(0, 100), breaks = seq(0, 100, 20),
                     oob = scales::oob_keep) +
  theme_light()

# Age density per segment (train data only). The combined `data` would add an
# extra NA-coloured curve for the unlabelled test rows, which does not belong
# in a per-segment comparison.
age2 <- ggplot(train, aes(x = Age)) +
  geom_density(aes(color = Segmentation), key_glyph = draw_key_point, size = 0.8) +
  scale_x_continuous(limits = c(0, 100), breaks = seq(0, 100, 20)) +
  scale_y_continuous(limits = c(0, 0.045)) +
  scale_color_manual(values = c("#f3722c", "#f9c74f", "#43aa8b", "#577590"),
                     na.value = "#c9ada7") +
  theme_light() +
  theme(legend.position = c(0.9, 0.75),
        legend.text = element_text(size = 8),
        legend.title = element_text(size = 8)) +
  guides(color = guide_legend(override.aes = list(shape = 15, size = 5))) +
  labs(subtitle = "Age Distribution (train data)")

ggarrange(age1, age2, ncol = 2)
# Descriptive statistics of Age for each segment (train data), one row per segment
psych::describeBy(x = train$Age, group = train$Segmentation, mat = TRUE)
# Work_Experience

# Descriptive statistics of Work_Experience for each segment (train data)
psych::describeBy(train$Work_Experience, train$Segmentation, mat = TRUE)

# Number of customers per year of work experience (train + test, NAs dropped)
we1 <- data %>%
  filter(!is.na(Work_Experience)) %>%
  ggplot(aes(x = factor(Work_Experience))) +
  geom_bar(stat = "count", fill = "#98c1d9") +
  theme_light() +
  theme(panel.grid.major.x = element_blank(),
        axis.title = element_text(size = 10)) +
  labs(x = "Work_Experience (in years)", subtitle = "Work_Experience")

# Same counts, faceted by segment (train data, NAs dropped)
we2 <- train %>%
  filter(!is.na(Work_Experience)) %>%
  ggplot(aes(x = factor(Work_Experience), fill = Segmentation)) +
  geom_bar(stat = "count", show.legend = FALSE) +
  theme_light() +
  facet_wrap(~Segmentation, labeller = label_both) +
  theme(panel.grid.major.x = element_blank(),
        axis.title = element_text(size = 10)) +
  labs(x = "Work_Experience (in years)",
       subtitle = "Work_Experience and Segmentation") +
  scale_fill_manual(values = c("#f3722c", "#f9c74f", "#43aa8b", "#577590"))

# Side-by-side layout
ggarrange(we1, we2, ncol = 2)
# Family_Size

# Descriptive statistics of Family_Size for each segment (train data)
psych::describeBy(train$Family_Size, train$Segmentation, mat = TRUE)

# Mean family size over the plotted data (train + test, NAs ignored). It is
# computed rather than hard-coded, because the plotted data is the combined
# set, whose mean differs from the test-only mean of 2.83.
fs_mean <- mean(data$Family_Size, na.rm = TRUE)

# Histogram with a dashed line and a label at the mean. The line and label are
# constants, so they are added once (xintercept / annotate) instead of being
# mapped in aes(), which would draw one copy per data row.
fs1 <- data %>%
  filter(!is.na(Family_Size)) %>%
  ggplot(aes(x = Family_Size)) +
  geom_histogram(binwidth = 1, color = "white", fill = "#98c1d9") +
  geom_vline(xintercept = fs_mean, color = "#d66853",
             linetype = "dashed", size = 1) +
  annotate("text", x = fs_mean + 1.05, y = 2900,
           label = sprintf("Mean = %.2f", fs_mean),
           size = 3.5, color = "#495057") +
  scale_x_continuous(breaks = seq(1, 10, 1)) +
  theme_light() +
  theme(panel.grid.minor.x = element_blank(),
        axis.title = element_text(size = 10)) +
  labs(subtitle = "Family_Size")

# Bubble plot: number of train customers for each (segment, family size) pair
fs2 <- train %>%
  filter(!is.na(Family_Size)) %>%
  group_by(Segmentation, Family_Size) %>%
  tally() %>%
  rename(Count = n) %>%
  ggplot(aes(y = Segmentation, x = factor(Family_Size))) +
  geom_point(aes(size = Count, color = Segmentation)) +
  scale_color_manual(values = c("#f3722c", "#f9c74f", "#43aa8b", "#577590")) +
  theme(legend.position = "bottom",
        axis.title = element_text(size = 10),
        legend.title = element_text(size = 9)) +
  labs(x = "Family_Size", subtitle = "Family_Size and Segmentation") +
  guides(color = "none") +
  scale_size(range = c(2, 12))

ggarrange(fs1, fs2, ncol = 2)
# Gender

# Number of train customers per gender, dodged by segment, with count labels
ge1 <- train %>%
  ggplot(aes(x = Gender, fill = Segmentation)) +
  geom_bar(stat = "count", position = "dodge", alpha = 0.9) +
  geom_text(stat = "count",
            aes(label = after_stat(count), group = Segmentation),
            vjust = 1.4, position = position_dodge(width = 0.9), size = 3.5) +
  scale_fill_manual(values = c("#f3722c", "#f9c74f", "#43aa8b", "#577590")) +
  labs(subtitle = "Gender: Count") +
  theme_light() +
  theme(legend.position = "bottom",
        axis.title = element_text(size = 10),
        legend.title = element_text(size = 9))

# Gender mix within each segment as stacked proportions (train data)
ge2 <- train %>%
  group_by(Segmentation, Gender) %>%
  tally() %>%
  mutate(proportion = round(n / sum(n), 3)) %>%
  ggplot(aes(y = fct_rev(Segmentation), x = proportion, fill = Gender)) +
  geom_col(width = 0.7, alpha = 0.9) +
  # Labels inherit x, y and fill from ggplot(), so they share the bars'
  # reversed Segmentation axis.
  geom_text(aes(label = paste0(proportion * 100, "%")),
            size = 3, position = position_fill(vjust = 0.5), color = "white") +
  theme_light() +
  theme(legend.position = "bottom",
        axis.title = element_text(size = 10),
        legend.title = element_text(size = 9)) +
  scale_fill_manual(values = c("#d66853", "#364156"),
                    guide = guide_legend(reverse = TRUE)) +
  labs(fill = "Gender", y = "Segmentation", x = "percentage",
       subtitle = "Gender: Percentage") +
  scale_x_continuous(labels = scales::percent)

ggarrange(ge1, ge2, ncol = 2)
# Ever_Married

# Number of train customers by marital history, dodged by segment
# (rows with missing Ever_Married are dropped explicitly).
em1 <- train %>%
  filter(!is.na(Ever_Married)) %>%
  ggplot(aes(x = Ever_Married, fill = Segmentation)) +
  geom_bar(stat = "count", position = "dodge", alpha = 0.9) +
  geom_text(stat = "count",
            aes(label = after_stat(count), group = Segmentation),
            vjust = 1.4, position = position_dodge(width = 0.9), size = 3.5) +
  scale_fill_manual(values = c("#f3722c", "#f9c74f", "#43aa8b", "#577590")) +
  labs(subtitle = "Ever_Married: Count") +
  theme_light() +
  theme(legend.position = "bottom",
        axis.title = element_text(size = 10),
        legend.title = element_text(size = 9))

# Share of Yes/No within each segment as stacked proportions (train data)
em2 <- train %>%
  filter(!is.na(Ever_Married)) %>%
  group_by(Segmentation, Ever_Married) %>%
  tally() %>%
  mutate(proportion = round(n / sum(n), 3)) %>%
  ggplot(aes(y = fct_rev(Segmentation), x = proportion,
             fill = factor(Ever_Married))) +
  geom_col(width = 0.7, alpha = 0.9) +
  # Labels inherit x, y and fill from ggplot(), so they share the bars'
  # reversed Segmentation axis.
  geom_text(aes(label = paste0(proportion * 100, "%")),
            size = 3, position = position_fill(vjust = 0.5), color = "white") +
  theme_light() +
  theme(legend.position = "bottom",
        axis.title = element_text(size = 10),
        legend.title = element_text(size = 9)) +
  scale_fill_manual(values = c("#d66853", "#364156"),
                    guide = guide_legend(reverse = TRUE)) +
  labs(fill = "Ever_Married", y = "Segmentation", x = "percentage",
       subtitle = "Ever_Married: Percentage") +
  scale_x_continuous(labels = scales::percent)

ggarrange(em1, em2, ncol = 2)
# Graduated

# Number of train customers by graduation status, dodged by segment
# (rows with missing Graduated are dropped explicitly).
gr1 <- train %>%
  filter(!is.na(Graduated)) %>%
  ggplot(aes(x = Graduated, fill = Segmentation)) +
  geom_bar(stat = "count", position = "dodge", alpha = 0.9) +
  geom_text(stat = "count",
            aes(label = after_stat(count), group = Segmentation),
            vjust = 1.4, position = position_dodge(width = 0.9), size = 3.5) +
  scale_fill_manual(values = c("#f3722c", "#f9c74f", "#43aa8b", "#577590")) +
  labs(subtitle = "Graduated: Count") +
  theme_light() +
  theme(legend.position = "bottom",
        axis.title = element_text(size = 10),
        legend.title = element_text(size = 9))

# Share of graduates within each segment as stacked proportions (train data)
gr2 <- train %>%
  filter(!is.na(Graduated)) %>%
  group_by(Segmentation, Graduated) %>%
  tally() %>%
  mutate(proportion = round(n / sum(n), 3)) %>%
  ggplot(aes(y = fct_rev(Segmentation), x = proportion,
             fill = factor(Graduated))) +
  geom_col(width = 0.7, alpha = 0.9) +
  # Labels inherit x, y and fill from ggplot(), so they share the bars'
  # reversed Segmentation axis.
  geom_text(aes(label = paste0(proportion * 100, "%")),
            size = 3, position = position_fill(vjust = 0.5), color = "white") +
  theme_light() +
  theme(legend.position = "bottom",
        axis.title = element_text(size = 10),
        legend.title = element_text(size = 9)) +
  scale_fill_manual(values = c("#d66853", "#364156"),
                    guide = guide_legend(reverse = TRUE)) +
  labs(fill = "Graduated", y = "Segmentation", x = "percentage",
       subtitle = "Graduated: Percentage") +
  scale_x_continuous(labels = scales::percent)

ggarrange(gr1, gr2, ncol = 2)
# Spending_Score

# Number of train customers per spending score, dodged by segment
ss1 <- train %>%
  ggplot(aes(x = Spending_Score, fill = Segmentation)) +
  geom_bar(stat = "count", position = "dodge", alpha = 0.9) +
  geom_text(stat = "count",
            aes(label = after_stat(count), group = Segmentation),
            vjust = 1.4, position = position_dodge(width = 0.9), size = 3) +
  scale_fill_manual(values = c("#f3722c", "#f9c74f", "#43aa8b", "#577590")) +
  theme_light() +
  theme(legend.position = "bottom") +
  labs(subtitle = "Spending Score: Count")

# Spending-score mix within each segment as stacked proportions (train data).
# `levels =` is spelled out: `level =` only worked through partial matching.
ss2 <- train %>%
  group_by(Segmentation, Spending_Score) %>%
  tally() %>%
  mutate(proportion = round(n / sum(n), 3)) %>%
  ggplot(aes(y = fct_rev(Segmentation), x = proportion,
             fill = factor(Spending_Score,
                           levels = c("High", "Average", "Low")))) +
  geom_col(width = 0.7, alpha = 0.9) +
  # Labels inherit x, y and fill from ggplot(), so they share the bars'
  # reversed Segmentation axis.
  geom_text(aes(label = paste0(proportion * 100, "%")),
            size = 3, position = position_fill(vjust = 0.5), color = "white") +
  theme_light() +
  theme(legend.position = "bottom") +
  scale_fill_manual(values = c("#d66853", "#7d4e57", "#364156"),
                    guide = guide_legend(reverse = TRUE)) +
  labs(fill = "Spending_Score", y = "Segmentation", x = "percentage",
       subtitle = "Spending Score: Percentage") +
  scale_x_continuous(labels = scales::percent)

ggarrange(ss1, ss2, ncol = 2)
# Profession

# Frequency of each profession in the train data (NA kept as its own level)
train %>%
  group_by(Profession) %>%
  tally() %>%
  mutate(proportion = round(n / sum(n), 3))

# Dot plot: count of each segment within each profession. Rows with a missing
# Profession are dropped explicitly.
train %>%
  filter(!is.na(Profession)) %>%
  group_by(Profession, Segmentation) %>%
  tally() %>%
  ggplot(aes(y = Profession, x = n)) +
  geom_line(aes(group = Profession), color = "#736f72") +
  geom_point(aes(color = Segmentation), key_glyph = draw_key_point,
             alpha = 0.8, size = 3) +
  scale_color_manual(values = c("#f3722c", "#f9c74f", "#43aa8b", "#577590")) +
  guides(color = guide_legend(override.aes = list(shape = 15, size = 5))) +
  theme_light() +
  theme(legend.position = "top",
        plot.title.position = "plot") +
  labs(x = "count", subtitle = "Profession: Count")
# Heatmap of the profession mix within each segment (percentages per segment
# row; NA Profession kept as a column), train data.
train %>%
  count(Segmentation, Profession) %>%
  group_by(Segmentation) %>%
  mutate(pct = round(n / sum(n) * 100, 2)) %>%
  ggplot(aes(x = Profession, y = fct_rev(Segmentation), fill = pct)) +
  geom_tile(color = "white", size = 2) +
  geom_text(aes(label = paste(pct, "%")), size = 3) +
  scale_fill_distiller(palette = "Spectral") +
  theme_light() +
  theme(
    legend.position = "top",
    axis.text.y = element_text(size = 12),
    axis.ticks = element_blank(),
    axis.title = element_text(size = 10),
    legend.title = element_text(size = 9),
    plot.title.position = "plot",
    axis.title.y = element_text(margin = margin(t = 0, r = 5, b = 0, l = 0)),
    axis.title.x = element_text(margin = margin(t = 5, r = 0, b = 0, l = 0))
  ) +
  labs(y = "Segmentation", x = "Profession",
       fill = "Percentage", subtitle = "Profession: Percentage") +
  guides(fill = guide_colorbar(title.position = "top",
                               title.hjust = .5,
                               barwidth = unit(20, "lines"),
                               barheight = unit(.5, "lines")))
# Correlation matrix, histograms and scatterplots of the numeric features
# (train + test combined). Columns are selected by name rather than by
# position, so a change in column order cannot silently swap the features.
c_data <- data[, c("Age", "Work_Experience", "Family_Size")]
chart.Correlation(c_data, histogram = TRUE, pch = 19)
# Number of missing values per column (train + test combined). readr already
# parsed missing cells as NA, so no type conversion is needed first; vapply
# guarantees a named integer vector.
vapply(data, function(x) sum(is.na(x)), integer(1))
ID Gender Ever_Married Age Graduated Profession Work_Experience Spending_Score
0 0 190 0 102 162 1098 0
Family_Size Var_1 Segmentation test_1
448 108 2627 0
cdf <- data

# Ordinal-encode Spending_Score: Low = 1, Average = 2, High = 3.
# (`ordered =` is spelled out; `order =` only worked via partial matching.)
cdf$Spending_Score <- as.numeric(factor(cdf$Spending_Score, ordered = TRUE,
                                        levels = c("Low", "Average", "High")))

# Most frequent non-missing value of a vector (first level alphabetically on ties).
most_frequent <- function(x) {
  counts <- table(x, useNA = "no")
  names(counts)[which.max(counts)]
}

# Impute missing values:
#   numeric          -> column mean
#   Profession       -> explicit "unknown" level
#   other categorical -> most frequent level, computed rather than hard-coded
# NOTE(review): these statistics use the combined train + test features,
# including rows later held out as `xtest`. Compute them on the training
# partition only if strict separation is needed.
cdf <- cdf %>%
  mutate(
    Family_Size     = coalesce(Family_Size, mean(Family_Size, na.rm = TRUE)),
    Work_Experience = coalesce(Work_Experience, mean(Work_Experience, na.rm = TRUE)),
    Profession      = coalesce(Profession, "unknown"),
    Graduated       = coalesce(Graduated, most_frequent(Graduated)),
    Ever_Married    = coalesce(Ever_Married, most_frequent(Ever_Married)),
    Var_1           = coalesce(Var_1, most_frequent(Var_1))
  )

# Check: only Segmentation (unknown for the test rows) should still have NAs.
vapply(cdf, function(x) sum(is.na(x)), integer(1))
ID Gender Ever_Married Age Graduated Profession Work_Experience Spending_Score
0 0 0 0 0 0 0 0
Family_Size Var_1 Segmentation test_1
0 0 2627 0
# Split the combined frame back into its labelled (train) and unlabelled (test)
# parts, dropping the origin flag and the ID column.
traindf <- cdf %>%
  filter(test_1 == 0) %>%
  select(-c(test_1, ID)) %>%
  as.data.frame()
testdf <- cdf %>%
  filter(test_1 == 1) %>%
  select(-c(test_1, ID, Segmentation)) %>%
  as.data.frame()

# 70/30 hold-out split of the labelled rows, stratified on Segmentation.
set.seed(3456)
train.index <- createDataPartition(traindf$Segmentation, p = 0.7, list = FALSE)
xtrain <- traindf[train.index, ]
xtest <- traindf[-train.index, ]
Research question: What are the key features for predicting customer segments?
To identify the key features for predicting customer segments, all available features (every column except the `ID` identifier) are used for modeling.
# Decision tree (method = "rpart"): the complexity parameter is tuned over
# 10 candidate values with 10-fold cross-validation.
set.seed(123)
dt <- train(
  Segmentation ~ .,
  data = xtrain,
  method = "rpart",
  trControl = trainControl(method = "cv", number = 10),
  tuneLength = 10
)

# Cross-validated performance across the complexity-parameter grid
plot(dt)
# Selected complexity parameter
unlist(dt$bestTune)
cp
0.001477469
# Draw the final fitted tree
fancyRpartPlot(dt$finalModel)

# Hold-out predictions and accuracy on xtest
dt.p <- predict(dt, xtest)
mean(dt.p == xtest$Segmentation)
[1] 0.5084746
# Variable importance of the tuned decision tree
dt %>% varImp() %>% plot()
# Random forest via caret (method = "rf"), mtry tuned with 10-fold CV;
# importance = TRUE is forwarded to the underlying randomForest() call.
set.seed(123)
rf <- train(
  Segmentation ~ .,
  data = xtrain,
  method = "rf",
  trControl = trainControl("cv", number = 10),
  importance = TRUE
)
rf$bestTune
rf$finalModel
Call:
randomForest(x = x, y = y, mtry = param$mtry, importance = TRUE)
Type of random forest: classification
Number of trees: 500
No. of variables tried at each split: 2
OOB estimate of error rate: 48.54%
Confusion matrix:
A B C D class.error
A 705 78 279 319 0.4895004
B 442 135 552 172 0.8962337
C 210 66 900 203 0.3473532
D 317 40 64 1167 0.2651134
# Hold-out predictions of the random forest, with accuracy and the full
# confusion matrix on xtest
rf.p <- predict(rf, xtest)
mean(rf.p == xtest$Segmentation)
confusionMatrix(data = rf.p, reference = factor(xtest$Segmentation))
Confusion Matrix and Statistics
Reference
Prediction A B C D
A 296 194 73 163
B 46 59 25 12
C 119 237 406 26
D 130 67 87 479
Overall Statistics
Accuracy : 0.5126
95% CI : (0.4925, 0.5327)
No Information Rate : 0.2811
P-Value [Acc > NIR] : < 2.2e-16
Kappa : 0.3457
Mcnemar's Test P-Value : < 2.2e-16
Statistics by Class:
Class: A Class: B Class: C Class: D
Sensitivity 0.5008 0.10592 0.6870 0.7044
Specificity 0.7648 0.95542 0.7910 0.8367
Pos Pred Value 0.4077 0.41549 0.5152 0.6278
Neg Pred Value 0.8258 0.78129 0.8866 0.8786
Prevalence 0.2443 0.23026 0.2443 0.2811
Detection Rate 0.1224 0.02439 0.1678 0.1980
Detection Prevalence 0.3001 0.05870 0.3258 0.3154
Balanced Accuracy 0.6328 0.53067 0.7390 0.7705
# AUC: requires the pROC package, which this notebook does not load, so the
# call stays commented out. AUC should be computed from class probabilities,
# not from hard class predictions.
#multiclass.roc(response = xtest$Segmentation, predictor = predict(rf, xtest, type = "prob"))

# Variable importance of the final forest (type = 2: mean decrease in node impurity)
varImpPlot(rf$finalModel, type = 2)
# Alternatives:
#varImp(rf)
#importance(rf$finalModel)
# Gradient-boosted trees (method = "xgbTree"), caret's default tuning grid,
# 10-fold CV.
set.seed(123)
xgb <- train(
  Segmentation ~ .,
  data = xtrain,
  method = "xgbTree",
  trControl = trainControl("cv", number = 10)
)
xgb$bestTune

# Hold-out predictions and accuracy on xtest
xgb.p <- predict(xgb, xtest)
mean(xgb.p == xtest$Segmentation)
[1] 0.536172
# Scaled variable importance of the boosted model (printed table)
xgb %>% varImp()
xgbTree variable importance
only 20 most important variables shown (out of 22)
xgb %>% varImp() %>% plot()
# Bagged classification trees (method = "treebag"), 10-fold CV.
set.seed(123)
bag <- train(
  Segmentation ~ .,
  data = xtrain,
  method = "treebag",
  trControl = trainControl("cv", number = 10)
)
bag$bestTune

# Hold-out predictions and accuracy on xtest
bag.p <- predict(bag, xtest)
mean(bag.p == xtest$Segmentation)
[1] 0.4745763
# Scaled variable importance of the bagged trees (printed table)
bag %>% varImp()
treebag variable importance
only 20 most important variables shown (out of 22)
bag %>% varImp() %>% plot()
## Multinomial logistic regression
# Make segment A the reference level, so each coefficient row (B, C, D) is a
# log-odds contrast against A.
xtrain$Segmentation <- relevel(factor(xtrain$Segmentation), ref = "A")

multinom_model <- multinom(Segmentation ~ ., data = xtrain)
# weights: 96 (69 variable)
initial value 7831.176846
iter 10 value 6893.779420
iter 20 value 6573.644554
iter 30 value 6391.502335
iter 40 value 6308.610680
iter 50 value 6293.992444
iter 60 value 6292.638859
iter 70 value 6292.459550
final value 6292.454232
converged
summary(object = multinom_model)
Call:
multinom(formula = Segmentation ~ ., data = xtrain)
Coefficients:
(Intercept) GenderMale Ever_MarriedYes Age GraduatedYes ProfessionDoctor ProfessionEngineer ProfessionEntertainment
B -1.6186815 -0.1617574 0.3262336 0.01387245 0.3235098 -0.3058758 -0.6295035 -0.7374869
C -2.8465722 -0.2052905 0.4663696 0.02041253 0.6778097 -0.5185472 -1.6203062 -1.3488591
D 0.4382531 0.2269359 -0.1768833 -0.02807020 -0.5622723 1.0604937 0.7220786 0.6590366
ProfessionExecutive ProfessionHealthcare ProfessionHomemaker ProfessionLawyer ProfessionMarketing Professionunknown Work_Experience
B -0.4121788 0.06529989 -0.294444 -1.270369 -1.0150341 -0.6172276 -0.03937352
C -1.0782623 0.31493832 -1.114422 -1.989324 -0.7130235 -1.7333173 -0.04234758
D 1.3823799 2.81205689 1.485937 2.044152 1.8483612 1.5679038 0.02855183
Spending_Score Family_Size Var_1Cat_2 Var_1Cat_3 Var_1Cat_4 Var_1Cat_5 Var_1Cat_6 Var_1Cat_7
B 0.3566670 0.1622673 0.4171412 0.04682155 -0.1332478 0.4320620 0.1241479 -0.14628249
C 0.5751489 0.3226334 0.2521779 -0.17074694 -0.9242315 0.2900542 0.3182524 -0.02047237
D -0.3507156 0.1205068 -0.4020011 -0.44457033 -0.3033641 -0.2418461 -0.1437171 -0.60624816
Std. Errors:
(Intercept) GenderMale Ever_MarriedYes Age GraduatedYes ProfessionDoctor ProfessionEngineer ProfessionEntertainment
B 0.3820602 0.08782314 0.1121508 0.003744930 0.09291390 0.1537537 0.1448114 0.1278468
C 0.3996658 0.09056851 0.1209493 0.003909013 0.10216074 0.1583170 0.1795986 0.1412876
D 0.4003354 0.09292432 0.1178644 0.004406274 0.09152707 0.1755550 0.1772991 0.1599361
ProfessionExecutive ProfessionHealthcare ProfessionHomemaker ProfessionLawyer ProfessionMarketing Professionunknown Work_Experience
B 0.1762950 0.1917768 0.2391510 0.1865849 0.2823286 0.3324021 0.01315727
C 0.1824463 0.1850256 0.2896999 0.1943822 0.2579919 0.4544046 0.01371233
D 0.2172659 0.1809245 0.2342785 0.2308178 0.2188313 0.3052694 0.01293735
Spending_Score Family_Size Var_1Cat_2 Var_1Cat_3 Var_1Cat_4 Var_1Cat_5 Var_1Cat_6 Var_1Cat_7
B 0.07110530 0.03205418 0.3625032 0.3385279 0.3345384 0.4976097 0.3207676 0.4106196
C 0.07435629 0.03310478 0.3766607 0.3537579 0.3571772 0.5256959 0.3310724 0.4181147
D 0.08566326 0.03162401 0.3690472 0.3399052 0.3334015 0.5001568 0.3208734 0.4086847
Residual Deviance: 12584.91
AIC: 12722.91
# Hold-out class predictions of the multinomial model, and their confusion
# matrix against the true segments
multinorm.p <- predict(object = multinom_model, newdata = xtest, type = "class")
confusionMatrix(data = multinorm.p, reference = factor(xtest$Segmentation))
Confusion Matrix and Statistics
Reference
Prediction A B C D
A 272 174 70 152
B 72 83 48 30
C 130 242 398 38
D 117 58 75 460
Overall Statistics
Accuracy : 0.5014
95% CI : (0.4813, 0.5216)
No Information Rate : 0.2811
P-Value [Acc > NIR] : < 2.2e-16
Kappa : 0.3319
Mcnemar's Test P-Value : < 2.2e-16
Statistics by Class:
Class: A Class: B Class: C Class: D
Sensitivity 0.4602 0.14901 0.6734 0.6765
Specificity 0.7834 0.91944 0.7757 0.8562
Pos Pred Value 0.4072 0.35622 0.4926 0.6479
Neg Pred Value 0.8178 0.78317 0.8802 0.8713
Prevalence 0.2443 0.23026 0.2443 0.2811
Detection Rate 0.1124 0.03431 0.1645 0.1902
Detection Prevalence 0.2761 0.09632 0.3340 0.2935
Balanced Accuracy 0.6218 0.53423 0.7246 0.7664
# ROC AUC (needs the pROC package, which is not loaded here)
#multiclass.roc(xtest$Segmentation, factor(multinorm.p,ordered = T))

# Two-sided Wald tests per coefficient: z = estimate / standard error,
# p = 2 * P(Z > |z|) under a standard normal.
zvalues <- with(summary(multinom_model), coefficients / standard.errors)
pnorm(abs(zvalues), lower.tail = FALSE) * 2
(Intercept) GenderMale Ever_MarriedYes Age GraduatedYes ProfessionDoctor ProfessionEngineer ProfessionEntertainment
B 2.268107e-05 0.06549656 0.0036272048 2.119511e-04 4.980110e-04 4.665839e-02 1.379769e-05 7.997036e-09
C 1.060779e-12 0.02340930 0.0001153003 1.770834e-07 3.250816e-11 1.055226e-03 1.849888e-19 1.336331e-21
D 2.736416e-01 0.01459973 0.1334240094 1.884047e-10 8.085750e-10 1.533463e-09 4.647971e-05 3.778460e-05
ProfessionExecutive ProfessionHealthcare ProfessionHomemaker ProfessionLawyer ProfessionMarketing Professionunknown Work_Experience
B 1.938695e-02 7.334805e-01 2.182460e-01 9.859979e-12 3.241138e-04 6.332962e-02 0.002766766
C 3.420572e-09 8.873027e-02 1.196626e-04 1.395038e-24 5.714237e-03 1.364699e-04 0.002013157
D 1.983440e-10 1.782935e-54 2.259075e-10 8.284192e-19 3.001382e-17 2.804512e-07 0.027318918
Spending_Score Family_Size Var_1Cat_2 Var_1Cat_3 Var_1Cat_4 Var_1Cat_5 Var_1Cat_6 Var_1Cat_7
B 5.274762e-07 4.142697e-07 0.2498458 0.8899960 0.690406589 0.3852439 0.6987311 0.7216547
C 1.033708e-14 1.922077e-22 0.5031705 0.6293329 0.009664721 0.5811177 0.3364127 0.9609484
D 4.237754e-05 1.386236e-04 0.2760239 0.1908989 0.362871866 0.6287119 0.6542301 0.1379647