library(gmodels) # Cross Tables [CrossTable()]
library(ggplot2)
library(ggmosaic) # Mosaic plot with ggplot [geom_mosaic()]
library(corrplot) # Correlation plot [corrplot()]
library(ggpubr) # Arranging ggplots together [ggarrange()]
library(cowplot) # Arranging ggplots together [plot_grid()]
library(caret) # ML [train(), confusionMatrix(), createDataPartition(), varImp(), trainControl()]
library(ROCR) # Model performance [performance(), prediction()]
library(plotROC) # ROC Curve with ggplot [geom_roc()]
library(pROC) # AUC computation [auc()]
library(PRROC) # AUPR computation [pr.curve()]
library(vcd) # Categorical data visualisation and stats [mosaic(), assocstats()]
library(rpart) # Decision trees [rpart(), plotcp(), prune()]
library(rpart.plot) # Decision trees plotting [rpart.plot()]
library(ranger) # Optimized Random Forest [ranger()]
library(lightgbm) # Light GBM [lgb.train()]
library(xgboost) # XGBoost [xgb.DMatrix(), xgb.train()]
library(MLmetrics) # Custom metrics (F1 score for example)
library(tidyverse) # Data manipulation
#library(doMC) # Parallel processing
bank = read.csv(file = "./bank-additional-full.csv",
sep = ";",
stringsAsFactors = F)
head(bank)
rows <- nrow(bank)
dim(bank)
## [1] 41188 21
names(bank)
## [1] "age" "job" "marital" "education"
## [5] "default" "housing" "loan" "contact"
## [9] "month" "day_of_week" "duration" "campaign"
## [13] "pdays" "previous" "poutcome" "emp.var.rate"
## [17] "cons.price.idx" "cons.conf.idx" "euribor3m" "nr.employed"
## [21] "y"
CrossTable(bank$y)
##
##
## Cell Contents
## |-------------------------|
## | N |
## | N / Table Total |
## |-------------------------|
##
##
## Total Observations in Table: 41188
##
##
## | no | yes |
## |-----------|-----------|
## | 36548 | 4640 |
## | 0.887 | 0.113 |
## |-----------|-----------|
##
##
##
##
bank = bank %>%
mutate(y = factor(if_else(y == "yes", "1", "0"),
levels = c("0", "1")))
sum(is.na(bank))
## [1] 0
There are 12,718 "unknown" values in the dataset; let's find out which variables suffer the most from these "missing values".
Six features have at least one unknown value. Before deciding how to manage them, we'll study each variable and make a decision after visualisation. We can't afford to delete the 8,597 affected rows of the worst feature (default): that's more than 20% of our observations.
sum(bank == "unknown")
## [1] 12718
# summarise_all is a function from the dplyr package that applies a function to every column of the data frame. Here the function is list(~sum(. == "unknown")), an anonymous function (written with the ~ notation) that counts how many times the value "unknown" appears in each column. The output is a one-row data frame with one count per column.
#
# gather is a function from the tidyr package that reshapes data from wide format to long format. The key parameter names the new column that holds the original column names, and the value parameter names the new column that holds the counts. The resulting data frame has two columns: "Variable Name" and "Unknown_Count".
#
# arrange is a function from the dplyr package that sorts rows by one or more columns. Here -Unknown_Count sorts the rows in descending order of the "Unknown_Count" column.
bank %>%
summarise_all(list(~sum(. == "unknown"))) %>%
gather(key = "Variable Name", value = "Unknown_Count") %>%
arrange(-Unknown_Count)
crosstable_f = function(df, x1, x2){
# df: data frame containing both columns to cross
# x1, x2: names of the columns to cross together
CrossTable(df[, x1], df[, x2],
prop.r = T, # include row percentages
prop.c = F,
prop.t = F,
prop.chisq = F,
dnn = c(x1, x2)) # set names of the dimensions
}
mosaic_theme <- theme(axis.text.x = element_text(angle = 90, hjust = 1, vjust = 0.5),
axis.text.y = element_blank(),
axis.ticks.y = element_blank())
fun_mosaic_plot <- function(data, x_var, y_var, x_label, y_label){
data %>%
ggplot() +
geom_mosaic(aes_string(x = paste0("product(", y_var, ",", x_var, ")"), fill = y_var)) +
mosaic_theme +
xlab(x_label) +
ylab(y_label)
}
summary(bank$age)
## Min. 1st Qu. Median Mean 3rd Qu. Max.
## 17.00 32.00 38.00 40.02 47.00 98.00
bank %>%
ggplot() + # begin plot construction
aes(x = age) + # specify x axis = age
geom_bar() + # plot a bar plot faceted on y
facet_grid(y ~ .,
scales = "free_y") + # scales vary across y, y-axis scales of each panel will be independent of each other
scale_x_continuous(breaks = seq(0, 100, 5))
# First create the age_60 variable; then count() computes the frequency of each combination of age_60 and y.
# Next, the data is grouped by y and nr_y is computed as the total count within each y group.
# Finally, relative_freq is computed as the percentage that each age_60 group represents within its y group, and the result is ungrouped with the relevant columns selected.
bank %>%
mutate(age_60 = if_else(age > 60, "1", "0")) %>%
count(age_60, y) %>%
group_by(y) %>%
mutate(nr_y = sum(n)) %>%
mutate(relative_freq = round(100*n/nr_y, 2)) %>%
ungroup() %>%
select(age_60, y, n, relative_freq)
We can also slice the age feature at 30 years to make three easily interpretable classes: [0, 30], ]30, 60] and ]60, +Inf[. The minimum and maximum observed values are 17 and 98, but we can expect new observations outside this range. We're replacing the continuous variable "age" with this categorical variable.
We might lose some information in this continuous-to-discrete transformation, but there was no clear pattern between individual years, and cutting into classes makes the models easier to interpret later.
bank = bank %>%
mutate(age = if_else(age > 60, "high", if_else(age > 30, "mid", "low")))
crosstable_f(bank, "age", "y")
##
##
## Cell Contents
## |-------------------------|
## | N |
## | N / Row Total |
## |-------------------------|
##
##
## Total Observations in Table: 41188
##
##
## | y
## age | 0 | 1 | Row Total |
## -------------|-----------|-----------|-----------|
## high | 496 | 414 | 910 |
## | 0.545 | 0.455 | 0.022 |
## -------------|-----------|-----------|-----------|
## low | 6259 | 1124 | 7383 |
## | 0.848 | 0.152 | 0.179 |
## -------------|-----------|-----------|-----------|
## mid | 29793 | 3102 | 32895 |
## | 0.906 | 0.094 | 0.799 |
## -------------|-----------|-----------|-----------|
## Column Total | 36548 | 4640 | 41188 |
## -------------|-----------|-----------|-----------|
##
##
# types of jobs
table(bank$job)
##
## admin. blue-collar entrepreneur housemaid management
## 10422 9254 1456 1060 2924
## retired self-employed services student technician
## 1720 1421 3969 875 6743
## unemployed unknown
## 1014 330
crosstable_f(bank, "job", "y")
##
##
## Cell Contents
## |-------------------------|
## | N |
## | N / Row Total |
## |-------------------------|
##
##
## Total Observations in Table: 41188
##
##
## | y
## job | 0 | 1 | Row Total |
## --------------|-----------|-----------|-----------|
## admin. | 9070 | 1352 | 10422 |
## | 0.870 | 0.130 | 0.253 |
## --------------|-----------|-----------|-----------|
## blue-collar | 8616 | 638 | 9254 |
## | 0.931 | 0.069 | 0.225 |
## --------------|-----------|-----------|-----------|
## entrepreneur | 1332 | 124 | 1456 |
## | 0.915 | 0.085 | 0.035 |
## --------------|-----------|-----------|-----------|
## housemaid | 954 | 106 | 1060 |
## | 0.900 | 0.100 | 0.026 |
## --------------|-----------|-----------|-----------|
## management | 2596 | 328 | 2924 |
## | 0.888 | 0.112 | 0.071 |
## --------------|-----------|-----------|-----------|
## retired | 1286 | 434 | 1720 |
## | 0.748 | 0.252 | 0.042 |
## --------------|-----------|-----------|-----------|
## self-employed | 1272 | 149 | 1421 |
## | 0.895 | 0.105 | 0.035 |
## --------------|-----------|-----------|-----------|
## services | 3646 | 323 | 3969 |
## | 0.919 | 0.081 | 0.096 |
## --------------|-----------|-----------|-----------|
## student | 600 | 275 | 875 |
## | 0.686 | 0.314 | 0.021 |
## --------------|-----------|-----------|-----------|
## technician | 6013 | 730 | 6743 |
## | 0.892 | 0.108 | 0.164 |
## --------------|-----------|-----------|-----------|
## unemployed | 870 | 144 | 1014 |
## | 0.858 | 0.142 | 0.025 |
## --------------|-----------|-----------|-----------|
## unknown | 293 | 37 | 330 |
## | 0.888 | 0.112 | 0.008 |
## --------------|-----------|-----------|-----------|
## Column Total | 36548 | 4640 | 41188 |
## --------------|-----------|-----------|-----------|
##
##
bank = bank %>%
filter(job != "unknown")
head(bank)
fun_mosaic_plot(bank, "job", "y", "Job", "Proportion")
crosstable_f(bank, "marital", "y")
##
##
## Cell Contents
## |-------------------------|
## | N |
## | N / Row Total |
## |-------------------------|
##
##
## Total Observations in Table: 40858
##
##
## | y
## marital | 0 | 1 | Row Total |
## -------------|-----------|-----------|-----------|
## divorced | 4126 | 473 | 4599 |
## | 0.897 | 0.103 | 0.113 |
## -------------|-----------|-----------|-----------|
## married | 22178 | 2516 | 24694 |
## | 0.898 | 0.102 | 0.604 |
## -------------|-----------|-----------|-----------|
## single | 9889 | 1605 | 11494 |
## | 0.860 | 0.140 | 0.281 |
## -------------|-----------|-----------|-----------|
## unknown | 62 | 9 | 71 |
## | 0.873 | 0.127 | 0.002 |
## -------------|-----------|-----------|-----------|
## Column Total | 36255 | 4603 | 40858 |
## -------------|-----------|-----------|-----------|
##
##
bank = bank %>%
filter(marital != "unknown")
fun_mosaic_plot(bank, "marital", "y", "Job", "Proportion")
crosstable_f(bank, 'education', 'y')
##
##
## Cell Contents
## |-------------------------|
## | N |
## | N / Row Total |
## |-------------------------|
##
##
## Total Observations in Table: 40787
##
##
## | y
## education | 0 | 1 | Row Total |
## --------------------|-----------|-----------|-----------|
## basic.4y | 3695 | 423 | 4118 |
## | 0.897 | 0.103 | 0.101 |
## --------------------|-----------|-----------|-----------|
## basic.6y | 2077 | 187 | 2264 |
## | 0.917 | 0.083 | 0.056 |
## --------------------|-----------|-----------|-----------|
## basic.9y | 5536 | 470 | 6006 |
## | 0.922 | 0.078 | 0.147 |
## --------------------|-----------|-----------|-----------|
## high.school | 8436 | 1028 | 9464 |
## | 0.891 | 0.109 | 0.232 |
## --------------------|-----------|-----------|-----------|
## illiterate | 14 | 4 | 18 |
## | 0.778 | 0.222 | 0.000 |
## --------------------|-----------|-----------|-----------|
## professional.course | 4631 | 594 | 5225 |
## | 0.886 | 0.114 | 0.128 |
## --------------------|-----------|-----------|-----------|
## university.degree | 10442 | 1654 | 12096 |
## | 0.863 | 0.137 | 0.297 |
## --------------------|-----------|-----------|-----------|
## unknown | 1362 | 234 | 1596 |
## | 0.853 | 0.147 | 0.039 |
## --------------------|-----------|-----------|-----------|
## Column Total | 36193 | 4594 | 40787 |
## --------------------|-----------|-----------|-----------|
##
##
bank = bank %>%
filter(education != "illiterate")
Among the 1,596 rows containing the "unknown" value, 234 subscribed to a term deposit: around 5% of all subscribers. Since the dependent variable is heavily imbalanced, we cannot afford to discard those rows. Because this category has the highest relative frequency of "y = 1" (14.7%), we'll merge it into the "university.degree" level, which has the second-highest "y = 1" relative frequency (13.7%).
There appears to be a positive association between the number of years of education and the odds of subscribing to a term deposit.
bank = bank %>%
mutate(education = recode(education, "unknown" = "university.degree"))
fun_mosaic_plot(bank, "education", "y", "Job", "Proportion")
bank %>%
ggplot() +
aes(x = education, y = ..count../rows, fill = y) +
geom_bar() +
ylab("relative frequency")
crosstable_f(bank, "default", "y")
##
##
## Cell Contents
## |-------------------------|
## | N |
## | N / Row Total |
## |-------------------------|
##
##
## Total Observations in Table: 40769
##
##
## | y
## default | 0 | 1 | Row Total |
## -------------|-----------|-----------|-----------|
## no | 28182 | 4155 | 32337 |
## | 0.872 | 0.128 | 0.793 |
## -------------|-----------|-----------|-----------|
## unknown | 7994 | 435 | 8429 |
## | 0.948 | 0.052 | 0.207 |
## -------------|-----------|-----------|-----------|
## yes | 3 | 0 | 3 |
## | 1.000 | 0.000 | 0.000 |
## -------------|-----------|-----------|-----------|
## Column Total | 36179 | 4590 | 40769 |
## -------------|-----------|-----------|-----------|
##
##
bank = bank %>%
select(-default)
One way to assess the strength and significance of the association between the response and a categorical predictor is a chi-squared test.
The chi-squared test determines whether there is a statistically significant difference in the distribution of the response variable across the categories of the predictor. In other words, it tests whether the proportions of the response variable differ significantly between the levels of the categorical predictor.
The p-value associated with the chi-squared test equals 0.065, which is above the 0.05 threshold. So, at a 95% confidence level, we cannot conclude that there is any association between the dependent variable y and the housing feature. We're removing it from the dataset.
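To make the test concrete, here is a minimal sketch of what chisq.test() computes under the hood: the counts expected under independence, then Pearson's X² statistic and its p-value.
# Hand computation of Pearson's chi-squared test for housing vs y
ct <- table(bank$housing, bank$y) # observed counts
expected <- outer(rowSums(ct), colSums(ct)) / sum(ct) # counts expected under independence
x2 <- sum((ct - expected)^2 / expected) # Pearson's X-squared statistic
pchisq(x2, df = (nrow(ct) - 1) * (ncol(ct) - 1), lower.tail = FALSE) # p-value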
# Stacked bar chart for categorical variable and binary response variable
bank %>%
ggplot(aes(x = y, fill = housing)) +
geom_bar() +
labs(x = "Response Variable", y = "Count") +
scale_fill_discrete(name = "Housing Loan")
crosstable_f(bank, "housing", "y")
##
##
## Cell Contents
## |-------------------------|
## | N |
## | N / Row Total |
## |-------------------------|
##
##
## Total Observations in Table: 40769
##
##
## | y
## housing | 0 | 1 | Row Total |
## -------------|-----------|-----------|-----------|
## no | 16416 | 2003 | 18419 |
## | 0.891 | 0.109 | 0.452 |
## -------------|-----------|-----------|-----------|
## unknown | 877 | 107 | 984 |
## | 0.891 | 0.109 | 0.024 |
## -------------|-----------|-----------|-----------|
## yes | 18886 | 2480 | 21366 |
## | 0.884 | 0.116 | 0.524 |
## -------------|-----------|-----------|-----------|
## Column Total | 36179 | 4590 | 40769 |
## -------------|-----------|-----------|-----------|
##
##
chisq.test(bank$housing, bank$y)
##
## Pearson's Chi-squared test
##
## data: bank$housing and bank$y
## X-squared = 5.4627, df = 2, p-value = 0.06513
bank = bank %>%
select(-housing)
crosstable_f(bank, "loan", "y")
##
##
## Cell Contents
## |-------------------------|
## | N |
## | N / Row Total |
## |-------------------------|
##
##
## Total Observations in Table: 40769
##
##
## | y
## loan | 0 | 1 | Row Total |
## -------------|-----------|-----------|-----------|
## no | 29799 | 3806 | 33605 |
## | 0.887 | 0.113 | 0.824 |
## -------------|-----------|-----------|-----------|
## unknown | 877 | 107 | 984 |
## | 0.891 | 0.109 | 0.024 |
## -------------|-----------|-----------|-----------|
## yes | 5503 | 677 | 6180 |
## | 0.890 | 0.110 | 0.152 |
## -------------|-----------|-----------|-----------|
## Column Total | 36179 | 4590 | 40769 |
## -------------|-----------|-----------|-----------|
##
##
chisq.test(bank$loan, bank$y)
##
## Pearson's Chi-squared test
##
## data: bank$loan and bank$y
## X-squared = 0.86841, df = 2, p-value = 0.6478
bank = bank %>%
select(-loan)
crosstable_f(bank, "contact", "y")
##
##
## Cell Contents
## |-------------------------|
## | N |
## | N / Row Total |
## |-------------------------|
##
##
## Total Observations in Table: 40769
##
##
## | y
## contact | 0 | 1 | Row Total |
## -------------|-----------|-----------|-----------|
## cellular | 22098 | 3815 | 25913 |
## | 0.853 | 0.147 | 0.636 |
## -------------|-----------|-----------|-----------|
## telephone | 14081 | 775 | 14856 |
## | 0.948 | 0.052 | 0.364 |
## -------------|-----------|-----------|-----------|
## Column Total | 36179 | 4590 | 40769 |
## -------------|-----------|-----------|-----------|
##
##
fun_mosaic_plot(bank, "contact", "y", "Job", "Proportion")
bank %>%
ggplot() +
aes(x = contact, y = ..count../rows, fill = y) +
geom_bar() +
ylab("relative frequency")
crosstable_f(bank, 'month', 'y')
##
##
## Cell Contents
## |-------------------------|
## | N |
## | N / Row Total |
## |-------------------------|
##
##
## Total Observations in Table: 40769
##
##
## | y
## month | 0 | 1 | Row Total |
## -------------|-----------|-----------|-----------|
## apr | 2082 | 536 | 2618 |
## | 0.795 | 0.205 | 0.064 |
## -------------|-----------|-----------|-----------|
## aug | 5459 | 644 | 6103 |
## | 0.894 | 0.106 | 0.150 |
## -------------|-----------|-----------|-----------|
## dec | 92 | 88 | 180 |
## | 0.511 | 0.489 | 0.004 |
## -------------|-----------|-----------|-----------|
## jul | 6471 | 642 | 7113 |
## | 0.910 | 0.090 | 0.174 |
## -------------|-----------|-----------|-----------|
## jun | 4697 | 548 | 5245 |
## | 0.896 | 0.104 | 0.129 |
## -------------|-----------|-----------|-----------|
## mar | 267 | 274 | 541 |
## | 0.494 | 0.506 | 0.013 |
## -------------|-----------|-----------|-----------|
## may | 12734 | 882 | 13616 |
## | 0.935 | 0.065 | 0.334 |
## -------------|-----------|-----------|-----------|
## nov | 3672 | 412 | 4084 |
## | 0.899 | 0.101 | 0.100 |
## -------------|-----------|-----------|-----------|
## oct | 396 | 311 | 707 |
## | 0.560 | 0.440 | 0.017 |
## -------------|-----------|-----------|-----------|
## sep | 309 | 253 | 562 |
## | 0.550 | 0.450 | 0.014 |
## -------------|-----------|-----------|-----------|
## Column Total | 36179 | 4590 | 40769 |
## -------------|-----------|-----------|-----------|
##
##
bank %>%
ggplot() +
aes(x = month, y = ..count../rows, fill = y) +
geom_bar() +
ylab("relative frequency")
crosstable_f(bank, 'day_of_week', 'y')
##
##
## Cell Contents
## |-------------------------|
## | N |
## | N / Row Total |
## |-------------------------|
##
##
## Total Observations in Table: 40769
##
##
## | y
## day_of_week | 0 | 1 | Row Total |
## -------------|-----------|-----------|-----------|
## fri | 6936 | 839 | 7775 |
## | 0.892 | 0.108 | 0.191 |
## -------------|-----------|-----------|-----------|
## mon | 7578 | 841 | 8419 |
## | 0.900 | 0.100 | 0.207 |
## -------------|-----------|-----------|-----------|
## thu | 7493 | 1031 | 8524 |
## | 0.879 | 0.121 | 0.209 |
## -------------|-----------|-----------|-----------|
## tue | 7056 | 945 | 8001 |
## | 0.882 | 0.118 | 0.196 |
## -------------|-----------|-----------|-----------|
## wed | 7116 | 934 | 8050 |
## | 0.884 | 0.116 | 0.197 |
## -------------|-----------|-----------|-----------|
## Column Total | 36179 | 4590 | 40769 |
## -------------|-----------|-----------|-----------|
##
##
bank %>%
ggplot() +
aes(x = day_of_week, y = ..count../rows, fill = y) +
geom_bar() +
ylab("relative frequency")
bank = bank %>%
select(-day_of_week)
ggplot(bank, aes(x = duration, fill = y)) +
geom_density(alpha = 0.5) +
ggtitle("Distribution of y and duration") +
xlab("Duration") +
ylab("Density")
# How many times was the client contacted during this campaign?
bank %>%
ggplot() +
aes(x = campaign, y = y, fill = y) +
geom_boxplot() +
ylab("relative frequency")
bank %>%
ggplot() + # begin plot construction
aes(x = campaign) + # specify x axis = campaign
geom_bar() + # plot a bar plot faceted on y
facet_grid(y ~ .,
scales = "free_y") + # scales vary across y, y-axis scales of each panel will be independent of each other
scale_x_continuous(breaks = seq(0, 100, 5))
bank = bank %>%
filter(campaign <= 10)
bank %>%
ggplot() + # begin plot construction
aes(x = campaign) + # specify x axis = campaign
geom_bar() + # plot a bar plot faceted on y
facet_grid(y ~ .,
scales = "free_y") + # scales vary across y, y-axis scales of each panel will be independent of each other
scale_x_continuous(breaks = seq(0, 100, 5))
- We can see that the more times a customer is contacted during a
campaign, the less likely they are to subscribe to a term
deposit
bank %>%
ggplot() +
aes(x = campaign, y = ..count../rows, fill = y) +
geom_bar() +
ylab("relative frequency")
fun_mosaic_plot(bank, "campaign", "y", "Campaign", "Proportion")
This is the number of days that have passed since the client was last contacted in a previous campaign; a value of 999 means the client was never previously contacted. Let's make a dummy variable out of it.
Clients who weren't contacted in a previous campaign will be labelled "0" in the pdays_dummy variable.
It's interesting to note that recontacting a client after a previous campaign seems to greatly increase the odds of subscription.
# number of days that passed by after the client was last contacted from a previous campaign
bank_counts <- table(bank$pdays)
pdays_freq <- data.frame(bank_counts) # avoid naming this `c`, which masks base::c()
pdays_freq
bank = bank %>%
mutate(pdays_dummy = if_else(pdays == 999, "0", "1")) %>%
select(-pdays)
fun_mosaic_plot(bank, "pdays_dummy", "y", "Previous Campaign Contact", "Proportion")
# previous: number of contacts performed before this campaign and for this client
fun_mosaic_plot(bank, "previous", "y", "Number of Contacts in Previous Campaign", "Proportion")
bank %>%
ggplot() +
aes(x = previous, y = ..count../rows, fill = y) +
geom_bar() +
ylab("relative frequency")
crosstable_f(bank, "poutcome", "y")
##
##
## Cell Contents
## |-------------------------|
## | N |
## | N / Row Total |
## |-------------------------|
##
##
## Total Observations in Table: 39912
##
##
## | y
## poutcome | 0 | 1 | Row Total |
## -------------|-----------|-----------|-----------|
## failure | 3616 | 595 | 4211 |
## | 0.859 | 0.141 | 0.106 |
## -------------|-----------|-----------|-----------|
## nonexistent | 31268 | 3085 | 34353 |
## | 0.910 | 0.090 | 0.861 |
## -------------|-----------|-----------|-----------|
## success | 464 | 884 | 1348 |
## | 0.344 | 0.656 | 0.034 |
## -------------|-----------|-----------|-----------|
## Column Total | 35348 | 4564 | 39912 |
## -------------|-----------|-----------|-----------|
##
##
The Euribor 3-month rate is the interest rate at which euro interbank term deposits are offered by one prime bank to another within the eurozone. It is an important benchmark in financial markets and serves as a reference rate for many financial instruments. For a person considering a term deposit, a higher Euribor 3-month rate generally indicates that interest rates, and therefore the returns on the deposit, are higher. This may make the term deposit more attractive to potential investors, including individuals considering a subscription. Conversely, a lower Euribor 3-month rate makes term deposits less attractive, as the returns on the investment are lower.
The employment variation rate is an economic indicator that measures the change in the number of employed people over a certain period, usually a quarter, expressed as a percentage change from one period to the next. It can influence a person's decision to subscribe to a term deposit in several ways: a high employment variation rate may indicate a strong, growing economy, which can raise consumer confidence and spending, and that spending may include investing in term deposits to save money and earn interest.
library(reshape2)
# Select the variables to include in the correlation matrix
vars <- bank %>%
select(emp.var.rate, cons.price.idx, cons.conf.idx, euribor3m, nr.employed)
# Calculate the correlation matrix
cor_matrix <- cor(vars)
# Convert the correlation matrix to a data frame
cor_df <- cor_matrix %>%
as.data.frame() %>%
rownames_to_column(var = "var1") %>%
pivot_longer(cols = -var1, names_to = "var2", values_to = "correlation")
# Plot the correlation heatmap
ggplot(cor_df, aes(var1, var2)) +
geom_tile(aes(fill = correlation)) +
scale_fill_gradient2(low = "blue", mid = "white", high = "red", midpoint = 0) +
geom_text(aes(label = round(correlation, 2)), color = "black", size = 3) +
theme_minimal() +
theme(axis.text.x = element_text(angle = 90)) +
coord_fixed() +
labs(title = "Correlation Heatmap", x = "", y = "")
The pivot_longer function reshapes the data frame from wide format to long format: it takes the columns emp.var.rate, cons.price.idx, cons.conf.idx, euribor3m and nr.employed and stacks them into a single column named variable, creating a new column named value to store the actual values.
The facet_wrap function creates a faceted plot, using the variable column as the faceting variable with the scales argument set to "free" so that each panel's scales adjust independently.
geom_histogram adds a histogram layer to the plot.
scale_fill_manual sets the fill colours for the two levels of the y variable.
theme customises the appearance of the plot; here it tilts the x-axis labels by 90 degrees.
bank %>%
pivot_longer(cols = c(emp.var.rate, cons.price.idx, cons.conf.idx, euribor3m, nr.employed),
names_to = "variable") %>%
ggplot(aes(x = value, fill = y)) +
facet_wrap(~variable, scales = "free") +
geom_histogram(alpha = 0.5, position = "identity", bins = 20) +
xlab("Variable Value") +
ylab("Count") +
ggtitle("Histogram of Financial Variables against Response") +
scale_fill_manual(values = c("#999999", "#E69F00")) + # custom colors for the fill
theme(axis.text.x = element_text(angle = 90))
library(tidyverse)
# Scatterplot matrix
plot(bank %>%
select(emp.var.rate, cons.price.idx, cons.conf.idx, euribor3m, nr.employed))
# Boxplot of employment variation rate by response
ggplot(bank, aes(x = y, y = emp.var.rate)) +
geom_boxplot()
# Density plot of consumer price index by response
ggplot(bank, aes(x = cons.price.idx, fill = y)) +
geom_density(alpha = 0.5)
# Histogram of consumer confidence index by response
ggplot(bank, aes(x = cons.conf.idx, fill = y)) +
geom_histogram(alpha = 0.5, bins = 30)
# Density plot of euribor 3 month rate by response
ggplot(bank, aes(x = euribor3m, fill = y)) +
geom_density(alpha = 0.5)
# Boxplot of number of employees by response
ggplot(bank, aes(x = y, y = nr.employed)) +
geom_boxplot()
library(ggplot2)
library(dplyr)
# Scatter plots for numeric variables
bank %>%
ggplot(aes(x = emp.var.rate, y = nr.employed)) +
geom_point() +
labs(x = "Employment Variation Rate", y = "Number of Employees")
bank %>%
ggplot(aes(x = cons.price.idx, y = euribor3m)) +
geom_point() +
labs(x = "Consumer Price Index", y = "Euribor 3 Month Rate")
bank %>%
ggplot(aes(x = cons.conf.idx, y = nr.employed)) +
geom_point() +
labs(x = "Consumer Confidence Index", y = "Number of Employees")
# Boxplots for numeric variables and binary response variable
bank %>%
ggplot(aes(x = y, y = emp.var.rate)) +
geom_boxplot() +
labs(x = "Response Variable", y = "Employment Variation Rate")
bank %>%
ggplot(aes(x = y, y = cons.price.idx)) +
geom_boxplot() +
labs(x = "Response Variable", y = "Consumer Price Index")
bank %>%
ggplot(aes(x = y, y = cons.conf.idx)) +
geom_boxplot() +
labs(x = "Response Variable", y = "Consumer Confidence Index")
bank %>%
ggplot(aes(x = y, y = euribor3m)) +
geom_boxplot() +
labs(x = "Response Variable", y = "Euribor 3 Month Rate")
bank %>%
ggplot(aes(x = y, y = nr.employed)) +
geom_boxplot() +
labs(x = "Response Variable", y = "Number of Employees")
library(rcompanion)
cat_vars = c(2, 3, 4, 5, 6, 7, 8, 9, 10, 15) # column indices of the categorical predictors
y_cV <- setNames(data.frame(matrix(ncol = 4, nrow = 0)), c("Variable", "Pvalue", "CramerV", "DF"))
for (var in cat_vars){
ct <- xtabs(~bank[, var]+y, data=bank)
cV <- cramerV(ct)
y_cV[nrow(y_cV)+1, 1] <- names(bank)[var]
y_cV[nrow(y_cV), 2] <- round(chisq.test(ct)$p.value, 4)
y_cV[nrow(y_cV), 3] <- cV
y_cV[nrow(y_cV), 4] <- min(length(unique(bank[, var]))-1, length(unique(bank$y))-1) # DF = min(r-1, c-1)
}
y_cV <- y_cV[order(y_cV$CramerV, decreasing=TRUE), ]
y_cV
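For reference, here is a minimal hand computation of Cramér's V for one pair, following the definition V = sqrt(X² / (n · (min(r, c) − 1))); it may differ marginally from cramerV() depending on that function's correction options.
# Hand-rolled Cramér's V for job vs y, as a sanity check of cramerV()
ct <- table(bank$job, bank$y)
x2 <- suppressWarnings(chisq.test(ct)$statistic)
sqrt(as.numeric(x2) / (sum(ct) * (min(dim(ct)) - 1)))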
# y is independent of housing and loan
# Create a bar plot
ggplot(y_cV[c("Variable","CramerV")], aes(x = Variable, y = CramerV, fill=CramerV)) +
geom_col() +
#scale_fill_gradient(low = "white", high = "red") +
ggtitle("Cramer V for different variables") +
xlab("Variables") +
ylab("Cramer V vs Response") +
theme_minimal() +
theme(axis.text.x = element_text(angle = 90, vjust = 0.5))
cat_vars2 = c(2, 3, 4, 5, 8, 9, 10, 15) # categorical predictors to cross with each other
cat_cV <- setNames(data.frame(matrix(ncol = 4, nrow = 0)), c("Variable1", "Variable2", "CramerV", "DF"))
for (var1 in cat_vars2){
for (var2 in cat_vars2){
if (var1 == var2){break}
else{
ct <- xtabs(~bank[, var1]+bank[, var2], data=bank)
cV <- cramerV(ct)
cat_cV[nrow(cat_cV)+1, 1] <- names(bank)[var1]
cat_cV[nrow(cat_cV), 2] <- names(bank)[var2]
cat_cV[nrow(cat_cV), 3] <- cV
cat_cV[nrow(cat_cV), 4] <- min(length(unique(bank[, var1]))-1, length(unique(bank[, var2]))-1)
}
}
}
cat_cV <- cat_cV[order(cat_cV$CramerV, decreasing=TRUE), ]
cat_cV
bank = bank %>%
select(-duration)
bank = bank %>%
select(-emp.var.rate)
library(caret)
bank[sapply(bank, is.character)] <- lapply(bank[sapply(bank, is.character)], as.factor) # convert remaining character columns to factors
# Split the dataset into training (60%), testing (20%), and validation (20%) sets
set.seed(123) # For reproducibility
trainIndex <- createDataPartition(bank$y, p = 0.6, list = FALSE, times = 1)
train_data <- bank[trainIndex, ]
test <- bank[-trainIndex, ]
testIndex <- createDataPartition(test$y, p = 0.5, list = FALSE, times = 1)
test_data <- test[-testIndex, ]
val_data <- test[testIndex, ]
truth_values <- train_data$y
# Check the class distribution in each dataset
prop.table(table(bank$y))
##
## 0 1
## 0.8856484 0.1143516
prop.table(table(train_data$y))
##
## 0 1
## 0.8856272 0.1143728
prop.table(table(test_data$y))
##
## 0 1
## 0.8857286 0.1142714
prop.table(table(val_data$y))
##
## 0 1
## 0.885632 0.114368
nrow(train_data)
## [1] 23948
nrow(test_data)
## [1] 7981
nrow(val_data)
## [1] 7983
# Plots F1 for a range of thresholds
plot_f1_vs_threshold <- function(pred, true) {
# Compute f1 scores across different thresholds
threshold_seq <- seq(0, 1, by = 0.01)
f1_scores <- lapply(threshold_seq, function(th) {
pred_labels <- ifelse(pred > th, 1, 0)
f1 <- 2 * sum((pred_labels == 1) & (true == 1)) / (sum(pred_labels == 1) + sum(true == 1))
return(f1)
})
# Convert to data frame
df <- data.frame(threshold = threshold_seq, f1_score = unlist(f1_scores))
# Find the threshold with the highest F1 score
threshold_max_f1 <- df %>% filter(f1_score == max(f1_score)) %>% pull(threshold)
f1_max <- df %>% filter(f1_score == max(f1_score)) %>% pull(f1_score)
# printing text
best_f1_measure = paste("Best F1 Score: x =", round(f1_max, 3))
best_cutoff_measure = paste("Best cutoff for F1 Score: x =", round(threshold_max_f1, 3))
cat(best_f1_measure, "\n", best_cutoff_measure, sep = "")
txt_tot = paste(best_f1_measure, best_cutoff_measure, sep = "\n") # cat() returns NULL, so build the caption with paste()
# Plot f1 score vs threshold
ggplot(df, aes(x = threshold, y = f1_score)) +
geom_line() +
geom_hline(yintercept = f1_max, linetype = "dashed", color = "red") +
geom_vline(xintercept = threshold_max_f1, linetype = "dashed", color = "red") +
labs(caption = txt_tot) +
scale_x_continuous(limits = c(0, 1)) +
labs(x = "Threshold", y = "F1 Score")
}
# Code to plot variable importance
plot_varimp <- function(model, top_n = Inf, title = "Variable Importance", rf=T) {
# Get variable importance
if (rf == TRUE){
varimp <- model$variable.importance
varimp <- varimp[order(varimp, decreasing = TRUE)]
# Limit to top n variables
if (is.finite(top_n)) {
varimp <- varimp[1:min(length(varimp), top_n)]
}
# Create data frame for plotting
varimp_df <- data.frame(variable = names(varimp), importance = varimp)
} else {
imp_df <- varImp(model)
names(imp_df)[names(imp_df) == "Overall"] <- "importance"
imp_df = imp_df %>%
rownames_to_column(var="variable") %>%
arrange(-importance)
varimp_df <- head(imp_df, top_n)
}
# Plot bar chart
p <- ggplot(varimp_df, aes(x = reorder(variable, importance, decreasing=T), y = importance)) +
geom_bar(stat = "identity", fill = "steelblue") +
theme_minimal() +
theme(axis.text.x = element_text(angle = 90)) +
labs(title = title, x = "Variable", y = "Importance")
return(p)
}
plot_roc_auc <- function(actual, predicted) {
# Create a data frame with the actual and predicted values
df <- data.frame(actual, predicted)
# Calculate sensitivity and specificity for different threshold values
threshold <- seq(0, 1, length.out = 100)
sens <- sapply(threshold, function(x) sum(predicted[actual == 1] >= x) / sum(actual == 1))
spec <- sapply(threshold, function(x) sum(predicted[actual == 0] < x) / sum(actual == 0))
# Combine the results into a data frame
results <- data.frame(threshold, sens, spec)
# Create the plot
ggplot(results, aes(x = 1 - spec, y = sens)) +
geom_line() +
geom_abline(intercept = 0, slope = 1, linetype = "dashed") +
labs(title = "ROC AUC",
x = "1 - Specificity",
y = "Sensitivity")
}
plot_roc <- function(train_truth, test_truth, val_truth, trainp, testp, valp) {
train_roc <- roc(train_truth, trainp)
test_roc <- roc(test_truth, testp)
val_roc <- roc(val_truth, valp)
{
plot(train_roc, type = "S", col = "blue", main = "ROC Curves") # pROC labels the axes Specificity / Sensitivity
lines(test_roc, type = "S", col = "red")
lines(val_roc, type = "S", col = "maroon")
legend("bottomleft", c("Train", "Test", "Validation"), lty = 1, col = c("blue", "red", "maroon"))
}
}
plot_sens_spec <- function(actual, predicted) {
# Create a data frame with the actual and predicted values
df <- data.frame(actual, predicted)
# Calculate sensitivity and specificity for different threshold values
threshold <- seq(0, 1, length.out = 100)
sens <- sapply(threshold, function(x) sum(predicted[actual == 1] >= x) / sum(actual == 1))
spec <- sapply(threshold, function(x) sum(predicted[actual == 0] < x) / sum(actual == 0))
# Create the plot
matplot(threshold,cbind(sens,spec),type="l",xlab="Threshold",
ylab="Proportion",lty=1:2,cex.lab=1.5)
legend(.35,0.8, c("spec","sens"), lty=c(2,1),col=c(2,1),cex=1.5)
}
# job:education + age:marital + nr.employed:previous + previous:contact
# + job:education + job:marital + nr.employed:previous + age:contact
lmod <- glm(y ~ .,
data = train_data,
family = "binomial")
summary(lmod)
##
## Call:
## glm(formula = y ~ ., family = "binomial", data = train_data)
##
## Deviance Residuals:
## Min 1Q Median 3Q Max
## -2.1859 -0.3910 -0.3301 -0.2605 2.7860
##
## Coefficients:
## Estimate Std. Error z value Pr(>|z|)
## (Intercept) 85.086749 28.506946 2.985 0.002838 **
## agelow -0.179879 0.150590 -1.194 0.232284
## agemid -0.235427 0.137068 -1.718 0.085872 .
## jobblue-collar -0.191425 0.088321 -2.167 0.030205 *
## jobentrepreneur -0.051428 0.136699 -0.376 0.706755
## jobhousemaid -0.043258 0.160331 -0.270 0.787309
## jobmanagement -0.120108 0.097345 -1.234 0.217264
## jobretired 0.194827 0.129758 1.501 0.133236
## jobself-employed 0.034118 0.126224 0.270 0.786934
## jobservices -0.257345 0.098335 -2.617 0.008870 **
## jobstudent 0.132654 0.126602 1.048 0.294730
## jobtechnician -0.123156 0.080994 -1.521 0.128369
## jobunemployed -0.063584 0.143344 -0.444 0.657348
## maritalmarried 0.103884 0.079139 1.313 0.189292
## maritalsingle 0.139359 0.088678 1.572 0.116063
## educationbasic.6y -0.021776 0.139695 -0.156 0.876127
## educationbasic.9y -0.009546 0.107369 -0.089 0.929152
## educationhigh.school 0.069019 0.103676 0.666 0.505595
## educationprofessional.course 0.069309 0.114878 0.603 0.546292
## educationuniversity.degree 0.165239 0.101034 1.635 0.101949
## contacttelephone -0.454409 0.078144 -5.815 6.06e-09 ***
## monthaug -0.159740 0.118536 -1.348 0.177787
## monthdec 0.056154 0.245474 0.229 0.819059
## monthjul 0.118661 0.107037 1.109 0.267606
## monthjun 0.244799 0.105780 2.314 0.020656 *
## monthmar 0.570147 0.152284 3.744 0.000181 ***
## monthmay -0.711175 0.085650 -8.303 < 2e-16 ***
## monthnov -0.443611 0.137680 -3.222 0.001273 **
## monthoct -0.367721 0.178018 -2.066 0.038862 *
## monthsep -0.677734 0.189292 -3.580 0.000343 ***
## campaign -0.044135 0.015426 -2.861 0.004222 **
## previous -0.144290 0.072895 -1.979 0.047769 *
## poutcomenonexistent 0.381601 0.112115 3.404 0.000665 ***
## poutcomesuccess 0.700424 0.266456 2.629 0.008572 **
## cons.price.idx -0.166727 0.157455 -1.059 0.289653
## cons.conf.idx 0.010285 0.009082 1.133 0.257408
## euribor3m 0.152148 0.152573 0.997 0.318660
## nr.employed -0.013900 0.002920 -4.760 1.93e-06 ***
## pdays_dummy1 1.251361 0.270141 4.632 3.62e-06 ***
## ---
## Signif. codes: 0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1
##
## (Dispersion parameter for binomial family taken to be 1)
##
## Null deviance: 17030 on 23947 degrees of freedom
## Residual deviance: 13430 on 23909 degrees of freedom
## AIC: 13508
##
## Number of Fisher Scoring iterations: 6
plot_varimp(lmod, top_n = 10, title = "Top 10 Variable Importance", rf=F)
lmod_f <- step(lmod, direction = "backward", trace=T)
## Start: AIC=13507.99
## y ~ age + job + marital + education + contact + month + campaign +
## previous + poutcome + cons.price.idx + cons.conf.idx + euribor3m +
## nr.employed + pdays_dummy
##
## Df Deviance AIC
## - education 5 13436 13504
## - marital 2 13432 13506
## - euribor3m 1 13431 13507
## - cons.price.idx 1 13431 13507
## - cons.conf.idx 1 13431 13507
## - age 2 13434 13508
## - job 10 13450 13508
## <none> 13430 13508
## - previous 1 13434 13510
## - campaign 1 13438 13514
## - poutcome 2 13446 13520
## - pdays_dummy 1 13451 13527
## - nr.employed 1 13453 13529
## - contact 1 13465 13541
## - month 9 13646 13706
##
## Step: AIC=13503.81
## y ~ age + job + marital + contact + month + campaign + previous +
## poutcome + cons.price.idx + cons.conf.idx + euribor3m + nr.employed +
## pdays_dummy
##
## Df Deviance AIC
## - euribor3m 1 13437 13503
## - marital 2 13439 13503
## - cons.price.idx 1 13437 13503
## - cons.conf.idx 1 13437 13503
## - age 2 13439 13503
## <none> 13436 13504
## - previous 1 13440 13506
## - campaign 1 13444 13510
## - job 10 13466 13514
## - poutcome 2 13452 13516
## - pdays_dummy 1 13457 13523
## - nr.employed 1 13459 13525
## - contact 1 13472 13538
## - month 9 13656 13706
##
## Step: AIC=13502.87
## y ~ age + job + marital + contact + month + campaign + previous +
## poutcome + cons.price.idx + cons.conf.idx + nr.employed +
## pdays_dummy
##
## Df Deviance AIC
## - cons.price.idx 1 13437 13501
## - marital 2 13440 13502
## - age 2 13440 13502
## <none> 13437 13503
## - previous 1 13441 13505
## - cons.conf.idx 1 13445 13509
## - campaign 1 13446 13510
## - job 10 13468 13514
## - poutcome 2 13453 13515
## - pdays_dummy 1 13458 13522
## - contact 1 13473 13537
## - month 9 13673 13721
## - nr.employed 1 14157 14221
##
## Step: AIC=13500.98
## y ~ age + job + marital + contact + month + campaign + previous +
## poutcome + cons.conf.idx + nr.employed + pdays_dummy
##
## Df Deviance AIC
## - marital 2 13440 13500
## - age 2 13440 13500
## <none> 13437 13501
## - previous 1 13441 13503
## - campaign 1 13446 13508
## - cons.conf.idx 1 13448 13510
## - job 10 13468 13512
## - poutcome 2 13453 13513
## - pdays_dummy 1 13458 13520
## - contact 1 13487 13549
## - month 9 13686 13732
## - nr.employed 1 14173 14235
##
## Step: AIC=13500.12
## y ~ age + job + contact + month + campaign + previous + poutcome +
## cons.conf.idx + nr.employed + pdays_dummy
##
## Df Deviance AIC
## <none> 13440 13500
## - age 2 13444 13500
## - previous 1 13444 13502
## - campaign 1 13449 13507
## - cons.conf.idx 1 13451 13509
## - job 10 13472 13512
## - poutcome 2 13457 13513
## - pdays_dummy 1 13461 13519
## - contact 1 13490 13548
## - month 9 13690 13732
## - nr.employed 1 14181 14239
lmod_f
##
## Call: glm(formula = y ~ age + job + contact + month + campaign + previous +
## poutcome + cons.conf.idx + nr.employed + pdays_dummy, family = "binomial",
## data = train_data)
##
## Coefficients:
## (Intercept) agelow agemid
## 56.04435 -0.12900 -0.21442
## jobblue-collar jobentrepreneur jobhousemaid
## -0.29910 -0.08210 -0.12956
## jobmanagement jobretired jobself-employed
## -0.10930 0.11776 0.03273
## jobservices jobstudent jobtechnician
## -0.31619 0.10334 -0.14972
## jobunemployed contacttelephone monthaug
## -0.11887 -0.47510 -0.12122
## monthdec monthjul monthjun
## 0.12764 0.11671 0.21407
## monthmar monthmay monthnov
## 0.62986 -0.69974 -0.36196
## monthoct monthsep campaign
## -0.27547 -0.58194 -0.04477
## previous poutcomenonexistent poutcomesuccess
## -0.14556 0.38769 0.70419
## cons.conf.idx nr.employed pdays_dummy1
## 0.01812 -0.01109 1.24810
##
## Degrees of Freedom: 23947 Total (i.e. Null); 23918 Residual
## Null Deviance: 17030
## Residual Deviance: 13440 AIC: 13500
summary(lmod_f)
##
## Call:
## glm(formula = y ~ age + job + contact + month + campaign + previous +
## poutcome + cons.conf.idx + nr.employed + pdays_dummy, family = "binomial",
## data = train_data)
##
## Deviance Residuals:
## Min 1Q Median 3Q Max
## -2.1497 -0.3906 -0.3350 -0.2584 2.7672
##
## Coefficients:
## Estimate Std. Error z value Pr(>|z|)
## (Intercept) 56.0443469 1.9514880 28.719 < 2e-16 ***
## agelow -0.1289966 0.1465645 -0.880 0.378786
## agemid -0.2144181 0.1361576 -1.575 0.115307
## jobblue-collar -0.2991033 0.0725860 -4.121 3.78e-05 ***
## jobentrepreneur -0.0820966 0.1352856 -0.607 0.543957
## jobhousemaid -0.1295578 0.1535232 -0.844 0.398727
## jobmanagement -0.1092953 0.0957946 -1.141 0.253897
## jobretired 0.1177620 0.1253524 0.939 0.347501
## jobself-employed 0.0327288 0.1248999 0.262 0.793291
## jobservices -0.3161882 0.0942917 -3.353 0.000799 ***
## jobstudent 0.1033420 0.1246808 0.829 0.407188
## jobtechnician -0.1497205 0.0721528 -2.075 0.037982 *
## jobunemployed -0.1188718 0.1413978 -0.841 0.400521
## contacttelephone -0.4751003 0.0683749 -6.948 3.69e-12 ***
## monthaug -0.1212155 0.1129865 -1.073 0.283346
## monthdec 0.1276397 0.2334000 0.547 0.584467
## monthjul 0.1167137 0.0984333 1.186 0.235735
## monthjun 0.2140664 0.1001425 2.138 0.032548 *
## monthmar 0.6298647 0.1414603 4.453 8.48e-06 ***
## monthmay -0.6997387 0.0830111 -8.429 < 2e-16 ***
## monthnov -0.3619644 0.1077067 -3.361 0.000778 ***
## monthoct -0.2754652 0.1432417 -1.923 0.054470 .
## monthsep -0.5819378 0.1534541 -3.792 0.000149 ***
## campaign -0.0447699 0.0153960 -2.908 0.003639 **
## previous -0.1455625 0.0709488 -2.052 0.040203 *
## poutcomenonexistent 0.3876943 0.1104632 3.510 0.000449 ***
## poutcomesuccess 0.7041907 0.2663473 2.644 0.008196 **
## cons.conf.idx 0.0181243 0.0055743 3.251 0.001148 **
## nr.employed -0.0110950 0.0003918 -28.316 < 2e-16 ***
## pdays_dummy1 1.2481037 0.2700595 4.622 3.81e-06 ***
## ---
## Signif. codes: 0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1
##
## (Dispersion parameter for binomial family taken to be 1)
##
## Null deviance: 17030 on 23947 degrees of freedom
## Residual deviance: 13440 on 23918 degrees of freedom
## AIC: 13500
##
## Number of Fisher Scoring iterations: 6
plot_varimp(lmod_f, top_n = 10, title = "Top 10 Variable Importance", rf=F)
logistic_train_prob = predict(lmod_f,
newdata = train_data,
type = "response")
logistic_test_prob = predict(lmod_f,
newdata = test_data,
type = "response")
logistic_val_prob = predict(lmod_f,
newdata = val_data,
type = "response")
plot_f1_vs_threshold(logistic_train_prob, truth_values)
## Best F1 Score: x = 0.488
## Best cutoff for F1 Score: x = 0.22
plot_f1_vs_threshold(logistic_test_prob, test_data$y)
## Best F1 Score: x = 0.491
## Best cutoff for F1 Score: x = 0.23
plot_f1_vs_threshold(logistic_val_prob, val_data$y)
## Best F1 Score: x = 0.49
## Best cutoff for F1 Score: x = 0.22
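Since MLmetrics is loaded for exactly this purpose, the hand-rolled F1 at a given cutoff can be cross-checked; a minimal sketch at the ≈0.23 cutoff found above, on the test split:
# F1 on the test set at the 0.23 cutoff, via MLmetrics (positive class = "1")
F1_Score(y_true = test_data$y,
y_pred = factor(ifelse(logistic_test_prob > 0.23, "1", "0"),
levels = levels(test_data$y)),
positive = "1")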
plot_roc_auc(truth_values, logistic_train_prob)
plot_roc_auc(test_data$y, logistic_test_prob)
plot_roc_auc(val_data$y, logistic_val_prob)
plot_roc(truth_values, test_data$y, val_data$y, logistic_train_prob, logistic_test_prob, logistic_val_prob)
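As a numeric companion to the curves, pROC (loaded at the top for auc()) reports the area under each ROC curve directly; a minimal sketch:
# ROC AUC for each split via pROC (direction and levels are inferred from the data)
auc(roc(truth_values, logistic_train_prob))
auc(roc(test_data$y, logistic_test_prob))
auc(roc(val_data$y, logistic_val_prob))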
plot_sens_spec(truth_values, logistic_train_prob)
plot_sens_spec(test_data$y, logistic_test_prob)
plot_sens_spec(val_data$y, logistic_val_prob)
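Given the strong class imbalance, the area under the precision-recall curve is arguably more informative than ROC AUC; PRROC was loaded for this. A minimal sketch on the test split:
# AUPR on the test set (scores.class0 takes the scores of the positive class)
pr.curve(scores.class0 = logistic_test_prob[test_data$y == "1"],
scores.class1 = logistic_test_prob[test_data$y == "0"])$auc.integral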
library(faraway)
halfnorm(hatvalues(lmod_f))
bank_logreg <- mutate(train_data, residuals=residuals(lmod_f), linpred=predict(lmod_f))
gdf <- group_by(bank_logreg, cut(linpred, breaks=c(min(linpred),
unique(quantile(linpred, (1:100)/101)),max(linpred)),
include.lowest = TRUE))
diagdf <- summarise(gdf, residuals=mean(residuals), linpred=mean(linpred))
par(mar=c(5, 6, 4, 2) + 0.2) ### making plot margins a bit bigger
plot(residuals ~ linpred, diagdf, xlab="linear predictor",cex.lab=2)
##### Goodness of Fit Test
# Hosmer-Lemeshow: bins observations by fitted probability (g = 10 groups) and
# compares observed vs expected event counts; a small p-value signals lack of fit.
suppressMessages(library(ResourceSelection))
hoslem.test(lmod_f$y, fitted(lmod_f), g = 10)
##
## Hosmer and Lemeshow goodness of fit (GOF) test
##
## data: lmod_f$y, fitted(lmod_f)
## X-squared = 26.08, df = 8, p-value = 0.001018
lmod_f
##
## Call: glm(formula = y ~ age + job + contact + month + campaign + previous +
## poutcome + cons.conf.idx + nr.employed + pdays_dummy, family = "binomial",
## data = train_data)
##
## Coefficients:
## (Intercept) agelow agemid
## 56.04435 -0.12900 -0.21442
## jobblue-collar jobentrepreneur jobhousemaid
## -0.29910 -0.08210 -0.12956
## jobmanagement jobretired jobself-employed
## -0.10930 0.11776 0.03273
## jobservices jobstudent jobtechnician
## -0.31619 0.10334 -0.14972
## jobunemployed contacttelephone monthaug
## -0.11887 -0.47510 -0.12122
## monthdec monthjul monthjun
## 0.12764 0.11671 0.21407
## monthmar monthmay monthnov
## 0.62986 -0.69974 -0.36196
## monthoct monthsep campaign
## -0.27547 -0.58194 -0.04477
## previous poutcomenonexistent poutcomesuccess
## -0.14556 0.38769 0.70419
## cons.conf.idx nr.employed pdays_dummy1
## 0.01812 -0.01109 1.24810
##
## Degrees of Freedom: 23947 Total (i.e. Null); 23918 Residual
## Null Deviance: 17030
## Residual Deviance: 13440 AIC: 13500
# job:education + age:marital + nr.employed:previous + previous:contact
# + job:education + job:marital + nr.employed:previous + age:contact
lmod2 <- glm(y ~ age + job + contact + month + campaign + previous +
poutcome + cons.conf.idx + nr.employed + pdays_dummy + nr.employed:previous + age:contact + euribor3m:nr.employed,
data = train_data,
family = "binomial")
summary(lmod2)
##
## Call:
## glm(formula = y ~ age + job + contact + month + campaign + previous +
## poutcome + cons.conf.idx + nr.employed + pdays_dummy + nr.employed:previous +
## age:contact + euribor3m:nr.employed, family = "binomial",
## data = train_data)
##
## Deviance Residuals:
## Min 1Q Median 3Q Max
## -2.1576 -0.3913 -0.3330 -0.2602 2.7914
##
## Coefficients:
## Estimate Std. Error z value Pr(>|z|)
## (Intercept) 5.602e+01 5.876e+00 9.533 < 2e-16 ***
## agelow -1.326e-01 1.513e-01 -0.876 0.380959
## agemid -2.705e-01 1.399e-01 -1.934 0.053153 .
## jobblue-collar -2.993e-01 7.262e-02 -4.122 3.75e-05 ***
## jobentrepreneur -8.450e-02 1.353e-01 -0.625 0.532226
## jobhousemaid -1.219e-01 1.534e-01 -0.795 0.426875
## jobmanagement -1.099e-01 9.580e-02 -1.147 0.251456
## jobretired 1.186e-01 1.253e-01 0.947 0.343850
## jobself-employed 3.046e-02 1.250e-01 0.244 0.807462
## jobservices -3.139e-01 9.432e-02 -3.328 0.000876 ***
## jobstudent 9.103e-02 1.250e-01 0.728 0.466420
## jobtechnician -1.486e-01 7.221e-02 -2.059 0.039538 *
## jobunemployed -1.204e-01 1.414e-01 -0.852 0.394365
## contacttelephone -8.906e-01 3.157e-01 -2.821 0.004783 **
## monthaug -8.792e-02 1.174e-01 -0.749 0.453841
## monthdec 1.575e-01 2.362e-01 0.667 0.504935
## monthjul 1.386e-01 1.078e-01 1.285 0.198751
## monthjun 2.091e-01 1.015e-01 2.060 0.039361 *
## monthmar 6.296e-01 1.433e-01 4.393 1.12e-05 ***
## monthmay -7.065e-01 8.346e-02 -8.466 < 2e-16 ***
## monthnov -3.411e-01 1.152e-01 -2.962 0.003059 **
## monthoct -2.513e-01 1.536e-01 -1.636 0.101767
## monthsep -5.485e-01 1.674e-01 -3.276 0.001054 **
## campaign -4.414e-02 1.540e-02 -2.865 0.004165 **
## previous -1.921e+00 4.391e+00 -0.438 0.661690
## poutcomenonexistent 4.262e-01 1.410e-01 3.023 0.002506 **
## poutcomesuccess 7.146e-01 2.667e-01 2.679 0.007384 **
## cons.conf.idx 1.655e-02 5.642e-03 2.933 0.003357 **
## nr.employed -1.110e-02 1.178e-03 -9.426 < 2e-16 ***
## pdays_dummy1 1.241e+00 2.701e-01 4.596 4.30e-06 ***
## previous:nr.employed 3.568e-04 8.818e-04 0.405 0.685713
## agelow:contacttelephone 2.357e-01 3.443e-01 0.685 0.493646
## agemid:contacttelephone 5.088e-01 3.267e-01 1.558 0.119336
## nr.employed:euribor3m -1.191e-06 1.017e-05 -0.117 0.906789
## ---
## Signif. codes: 0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1
##
## (Dispersion parameter for binomial family taken to be 1)
##
## Null deviance: 17030 on 23947 degrees of freedom
## Residual deviance: 13435 on 23914 degrees of freedom
## AIC: 13503
##
## Number of Fisher Scoring iterations: 6
suppressMessages(library(ResourceSelection))
hoslem.test(lmod_f$y,fitted(lmod2),g=10)
##
## Hosmer and Lemeshow goodness of fit (GOF) test
##
## data: lmod_f$y, fitted(lmod2)
## X-squared = 22.765, df = 8, p-value = 0.00368
pred_values <- ifelse(fitted(lmod_f) > .5, "1", "0")
confusionMatrix(factor(pred_values, levels = levels(truth_values)), truth_values)
## Confusion Matrix and Statistics
##
## Reference
## Prediction 0 1
## 0 20887 2096
## 1 322 643
##
## Accuracy : 0.899
## 95% CI : (0.8951, 0.9028)
## No Information Rate : 0.8856
## P-Value [Acc > NIR] : 1.781e-11
##
## Kappa : 0.3058
##
## Mcnemar's Test P-Value : < 2.2e-16
##
## Sensitivity : 0.9848
## Specificity : 0.2348
## Pos Pred Value : 0.9088
## Neg Pred Value : 0.6663
## Prevalence : 0.8856
## Detection Rate : 0.8722
## Detection Prevalence : 0.9597
## Balanced Accuracy : 0.6098
##
## 'Positive' Class : 0
##
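The default 0.5 cutoff is a poor fit for classes this imbalanced; as a sketch, re-evaluating at the ≈0.22 cutoff suggested by the F1 analysis above shifts errors from false negatives to false positives:
# Confusion matrix at the F1-optimal cutoff (~0.22) instead of the default 0.5
pred_values_f1 <- ifelse(fitted(lmod_f) > .22, "1", "0")
confusionMatrix(factor(pred_values_f1, levels = levels(truth_values)), truth_values)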
variables = c(2, 4, 6, 8, 10, 12) # candidate values for mtry (variables sampled at each split)
set.seed(123)
tune_grid = expand.grid(
mtry = variables,
splitrule = "gini",
min.node.size = 1
)
tune_control = trainControl(
method = "cv", # cross-validation
number = 3, # with n folds
summaryFunction = prSummary,
verboseIter = FALSE, # no training log
allowParallel = T, # FALSE for reproducible results
classProbs = TRUE
)
levels(train_data$y) <- make.names(levels(train_data$y))
ranger_tune = train(
y ~ .,
data = train_data,
metric = "F",
trControl = tune_control,
tuneGrid = tune_grid,
method = "ranger"
)
ggplot(ranger_tune) +
theme(legend.position = "bottom")
ranger_tune$bestTune$mtry
## [1] 2
ranger_tune$bestTune$min.node.size
## [1] 1
rf = ranger(y ~ .,
data = train_data,
num.trees = 1000,
importance = "impurity",
splitrule = ranger_tune$bestTune$splitrule,
mtry = ranger_tune$bestTune$mtry,
min.node.size = ranger_tune$bestTune$min.node.size,
write.forest = T,
probability = T)
rf_train_score = predict(rf,
data = train_data)$predictions[, 2]
rf_test_score = predict(rf,
data = test_data)$predictions[, 2]
rf_val_score = predict(rf,
data = val_data)$predictions[, 2]
pred_values <- ifelse(rf_train_score > .5, "1", "0")
confusionMatrix(factor(pred_values, levels = levels(truth_values)), truth_values)
## Confusion Matrix and Statistics
##
## Reference
## Prediction 0 1
## 0 21123 1638
## 1 86 1101
##
## Accuracy : 0.928
## 95% CI : (0.9247, 0.9313)
## No Information Rate : 0.8856
## P-Value [Acc > NIR] : < 2.2e-16
##
## Kappa : 0.5283
##
## Mcnemar's Test P-Value : < 2.2e-16
##
## Sensitivity : 0.9959
## Specificity : 0.4020
## Pos Pred Value : 0.9280
## Neg Pred Value : 0.9275
## Prevalence : 0.8856
## Detection Rate : 0.8820
## Detection Prevalence : 0.9504
## Balanced Accuracy : 0.6990
##
## 'Positive' Class : 0
##
plot_f1_vs_threshold(rf_test_score, test_data$y)
## Best F1 Score: x = 0.509
## Best cutoff for F1 Score: x = 0.2
plot_f1_vs_threshold(rf_val_score, val_data$y)
## Best F1 Score: x = 0.517
## Best cutoff for F1 Score: x = 0.26
plot_f1_vs_threshold(rf_train_score, truth_values)
## Best F1 Score: x = 0.621
## Best cutoff for F1 Score: x = 0.36
plot_varimp(rf, top_n = 10, title = "Top 10 Variable Importance", rf=T)
plot_roc_auc(truth_values, rf_train_score)
plot_roc_auc(test_data$y, rf_test_score)
plot_roc_auc(val_data$y, rf_val_score)
plot_sens_spec(truth_values, rf_train_score)
plot_sens_spec(test_data$y, rf_test_score)
plot_sens_spec(val_data$y, rf_val_score)
# parameter grid for XGBoost
parameterGrid <- expand.grid(eta = c(0.01, 0.1, 0.3), # shrinkage (learning rate)
colsample_bytree = 0.7, # subsample ration of columns
max_depth = c(2,3,4), # max tree depth. model complexity
nrounds = c(800,900,1000), # boosting iterations
gamma = 1, # minimum loss reduction
subsample = 0.8, # ratio of the training instances
min_child_weight = 1)
# tune_grid = expand.grid(
# nrounds = seq(from = 200, to = nrounds, by = 50),
# eta = c(0.025, 0.05, 0.1, 0.3),
# max_depth = c(2, 3, 4, 5, 6),
# gamma = 0,
# colsample_bytree = 1,
# min_child_weight = 1,
# subsample = 1
# )
tune_control = trainControl(
method = "cv", # cross-validation
number = 3, # with n folds
summaryFunction = prSummary,
verboseIter = FALSE, # no training log
allowParallel = FALSE, # FALSE for reproducible results
classProbs = TRUE
)
xgb_tune = train(
y ~ .,
data = train_data,
metric = "F",
trControl = tune_control,
tuneGrid = parameterGrid,
method = "xgbTree",
verbose = FALSE
)
## [11:15:38] WARNING: src/c_api/c_api.cc:935: `ntree_limit` is deprecated, use `iteration_range` instead.
## (the same xgboost deprecation warning repeats for every resample and grid cell; remaining repetitions omitted)
## [11:17:17] WARNING: src/c_api/c_api.cc:935: `ntree_limit` is deprecated, use `iteration_range` instead.
## [11:17:17] WARNING: src/c_api/c_api.cc:935: `ntree_limit` is deprecated, use `iteration_range` instead.
## [11:17:17] WARNING: src/c_api/c_api.cc:935: `ntree_limit` is deprecated, use `iteration_range` instead.
## [11:17:17] WARNING: src/c_api/c_api.cc:935: `ntree_limit` is deprecated, use `iteration_range` instead.
## [11:17:22] WARNING: src/c_api/c_api.cc:935: `ntree_limit` is deprecated, use `iteration_range` instead.
## [11:17:22] WARNING: src/c_api/c_api.cc:935: `ntree_limit` is deprecated, use `iteration_range` instead.
## [11:17:22] WARNING: src/c_api/c_api.cc:935: `ntree_limit` is deprecated, use `iteration_range` instead.
## [11:17:22] WARNING: src/c_api/c_api.cc:935: `ntree_limit` is deprecated, use `iteration_range` instead.
## [11:17:29] WARNING: src/c_api/c_api.cc:935: `ntree_limit` is deprecated, use `iteration_range` instead.
## [11:17:29] WARNING: src/c_api/c_api.cc:935: `ntree_limit` is deprecated, use `iteration_range` instead.
## [11:17:29] WARNING: src/c_api/c_api.cc:935: `ntree_limit` is deprecated, use `iteration_range` instead.
## [11:17:29] WARNING: src/c_api/c_api.cc:935: `ntree_limit` is deprecated, use `iteration_range` instead.
## [11:17:33] WARNING: src/c_api/c_api.cc:935: `ntree_limit` is deprecated, use `iteration_range` instead.
## [11:17:33] WARNING: src/c_api/c_api.cc:935: `ntree_limit` is deprecated, use `iteration_range` instead.
## [11:17:33] WARNING: src/c_api/c_api.cc:935: `ntree_limit` is deprecated, use `iteration_range` instead.
## [11:17:33] WARNING: src/c_api/c_api.cc:935: `ntree_limit` is deprecated, use `iteration_range` instead.
## [11:17:39] WARNING: src/c_api/c_api.cc:935: `ntree_limit` is deprecated, use `iteration_range` instead.
## [11:17:39] WARNING: src/c_api/c_api.cc:935: `ntree_limit` is deprecated, use `iteration_range` instead.
## [11:17:39] WARNING: src/c_api/c_api.cc:935: `ntree_limit` is deprecated, use `iteration_range` instead.
## [11:17:39] WARNING: src/c_api/c_api.cc:935: `ntree_limit` is deprecated, use `iteration_range` instead.
## [11:17:46] WARNING: src/c_api/c_api.cc:935: `ntree_limit` is deprecated, use `iteration_range` instead.
## [11:17:46] WARNING: src/c_api/c_api.cc:935: `ntree_limit` is deprecated, use `iteration_range` instead.
## [11:17:46] WARNING: src/c_api/c_api.cc:935: `ntree_limit` is deprecated, use `iteration_range` instead.
## [11:17:46] WARNING: src/c_api/c_api.cc:935: `ntree_limit` is deprecated, use `iteration_range` instead.
## [11:17:50] WARNING: src/c_api/c_api.cc:935: `ntree_limit` is deprecated, use `iteration_range` instead.
## [11:17:50] WARNING: src/c_api/c_api.cc:935: `ntree_limit` is deprecated, use `iteration_range` instead.
## [11:17:50] WARNING: src/c_api/c_api.cc:935: `ntree_limit` is deprecated, use `iteration_range` instead.
## [11:17:50] WARNING: src/c_api/c_api.cc:935: `ntree_limit` is deprecated, use `iteration_range` instead.
## [11:17:55] WARNING: src/c_api/c_api.cc:935: `ntree_limit` is deprecated, use `iteration_range` instead.
## [11:17:55] WARNING: src/c_api/c_api.cc:935: `ntree_limit` is deprecated, use `iteration_range` instead.
## [11:17:55] WARNING: src/c_api/c_api.cc:935: `ntree_limit` is deprecated, use `iteration_range` instead.
## [11:17:55] WARNING: src/c_api/c_api.cc:935: `ntree_limit` is deprecated, use `iteration_range` instead.
## [11:18:02] WARNING: src/c_api/c_api.cc:935: `ntree_limit` is deprecated, use `iteration_range` instead.
## [11:18:02] WARNING: src/c_api/c_api.cc:935: `ntree_limit` is deprecated, use `iteration_range` instead.
## [11:18:02] WARNING: src/c_api/c_api.cc:935: `ntree_limit` is deprecated, use `iteration_range` instead.
## [11:18:02] WARNING: src/c_api/c_api.cc:935: `ntree_limit` is deprecated, use `iteration_range` instead.
xgb_tune$bestTune
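If xgb_tune came from caret with method = "xgbTree" (an assumption here, since the tuning call is above), bestTune holds all seven tuned hyperparameters, so they can be copied straight into the native parameter list rather than hard-coded. A minimal sketch:
# Copy every tuned value from caret's bestTune into an xgboost parameter list
# (assumes xgb_tune was fit with method = "xgbTree")
best <- xgb_tune$bestTune
params_tuned <- list(
objective = "binary:logistic",
eval_metric = "auc",
eta = best$eta,
max_depth = best$max_depth,
min_child_weight = best$min_child_weight,
subsample = best$subsample,
colsample_bytree = best$colsample_bytree,
gamma = best$gamma
)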
outcome <- "y"
predictors <- setdiff(names(train_data), outcome)
params <- list(
objective = "binary:logistic", # binary classification, probability output
eval_metric = "auc",
eta = 0.01, # low learning rate, compensated by a large nrounds
max_depth = 3, # shallow trees to limit overfitting
min_child_weight = xgb_tune$bestTune$min_child_weight,
subsample = 0.8, # fraction of rows sampled per tree
colsample_bytree = 0.7, # fraction of columns sampled per tree
gamma = 1, # minimum loss reduction required to make a split
scale_pos_weight = 21209 / 2739 # negative-to-positive ratio in the training set, to offset class imbalance
)
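The hard-coded ratio matches the training-set class counts (21,209 negatives to 2,739 positives); computing it from the labels keeps it correct if the split ever changes. A small sketch using the same truth_values vector:
# negative-to-positive ratio in the training labels, usable as scale_pos_weight
neg_pos_ratio <- sum(truth_values == "0") / sum(truth_values == "1")
neg_pos_ratio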
# Fit the one-hot encoder on the training split only (column 14, the target, is dropped
# before encoding), then apply the same encoder to all three splits so their columns match.
one_hot_encoder <- dummyVars(~ ., data = train_data[, -14])
train_data_one_hot <- predict(one_hot_encoder, newdata = train_data[, -14])
test_data_one_hot <- predict(one_hot_encoder, newdata = test_data[, -14])
val_data_one_hot <- predict(one_hot_encoder, newdata = val_data[, -14])
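Since the encoder is now fit once on the training split, the three matrices are guaranteed to share the same column layout; a cheap assertion makes that explicit:
# all splits must be encoded into the same column layout
stopifnot(identical(colnames(train_data_one_hot), colnames(test_data_one_hot)),
identical(colnames(train_data_one_hot), colnames(val_data_one_hot)))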
# Build the DMatrix objects once and reuse them for training, monitoring and prediction
dtrain <- xgb.DMatrix(data = as.matrix(train_data_one_hot), label = as.numeric(as.character(truth_values)))
dval <- xgb.DMatrix(data = as.matrix(val_data_one_hot), label = as.numeric(as.character(val_data$y)))
dtest <- xgb.DMatrix(data = as.matrix(test_data_one_hot), label = as.numeric(as.character(test_data$y)))
# Train XGBoost model
xgb_model <- xgb.train(
params = params,
data = dtrain,
watchlist = list(train = dtrain, test = dval),
early_stopping_rounds = 10, # stop when the validation AUC has not improved for 10 consecutive rounds
nrounds = 1000, # upper bound on boosting rounds; early stopping usually ends training sooner
verbose = FALSE
)
# Make predictions on the test and training sets
test_y_pred <- predict(xgb_model, dtest)
train_y_pred <- predict(xgb_model, dtrain)
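With early stopping enabled, the fitted booster records which round scored best on the validation set; it is worth checking before trusting the predictions (standard xgboost accessors):
# best_iteration is the round with the highest validation AUC;
# evaluation_log holds the per-round train/validation metrics
xgb_model$best_iteration
tail(xgb_model$evaluation_log)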
plot_f1_vs_threshold(test_y_pred, test_data$y)
## Best F1 Score: x = 0.485
## Best cutoff for F1 Score: x = 0.53
plot_f1_vs_threshold(train_y_pred, truth_values)
## Best F1 Score: x = 0.487
## Best cutoff for F1 Score: x = 0.53
### Plot ROC-AUC
plot_roc_auc(truth_values, train_y_pred)
plot_roc_auc(test_data$y, test_y_pred)
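With only ~11% positives, the precision-recall curve complements ROC-AUC; PRROC, loaded at the top, computes the AUPR. A minimal sketch on the test-set scores:
# pr.curve() takes positive-class scores first, negative-class scores second
pr <- pr.curve(scores.class0 = test_y_pred[test_data$y == "1"],
scores.class1 = test_y_pred[test_data$y == "0"])
pr$auc.integral # area under the precision-recall curve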
# Classify at the default 0.5 cutoff (the F1 sweep above suggests ~0.53 as an alternative)
pred_values <- ifelse(train_y_pred > .5, "1", "0")
confusionMatrix(factor(pred_values, levels = levels(truth_values)), truth_values)
## Confusion Matrix and Statistics
##
## Reference
## Prediction 0 1
## 0 18249 1074
## 1 2960 1665
##
## Accuracy : 0.8316
## 95% CI : (0.8268, 0.8363)
## No Information Rate : 0.8856
## P-Value [Acc > NIR] : 1
##
## Kappa : 0.3603
##
## Mcnemar's Test P-Value : <2e-16
##
## Sensitivity : 0.8604
## Specificity : 0.6079
## Pos Pred Value : 0.9444
## Neg Pred Value : 0.3600
## Prevalence : 0.8856
## Detection Rate : 0.7620
## Detection Prevalence : 0.8069
## Balanced Accuracy : 0.7342
##
## 'Positive' Class : 0
##
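The matrix above is computed on the training set at the 0.5 cutoff. A quick sketch of the equivalent held-out check at the F1-optimal cutoff reported by the sweep (~0.53, taken from the helper output above):
# Evaluate on the held-out test set at the F1-optimal cutoff
test_pred_values <- factor(ifelse(test_y_pred > 0.53, "1", "0"),
levels = levels(test_data$y))
confusionMatrix(test_pred_values, test_data$y)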