knitr::include_graphics("reisen.png")
In this project, data from a company named “Trips & Travel.Com” is used with the goal of making marketing expenditure more efficient. The company is planning to launch a new product, a Wellness Tourism Package. Wellness Tourism is defined as travel that allows the traveler to maintain, enhance or kick-start a healthy lifestyle, and support or increase one’s sense of well-being. For this purpose, the available data shall be used to predict which potential customers are going to purchase the newly introduced travel package. A Random Forest algorithm will be used for this task.
Tasks to Solve:
First, the relevant packages are loaded: a range of libraries for general data wrangling and visualisation, together with more specialised tools for modeling and for dealing with unbalanced data.
library(here) #used for folder navigation
library(tidyverse) #used for data wrangling
library(data.table) #used for data wrangling
library(skimr) #used to get overview of data
library(yarrr) #used to create "pirateplots"
library(infer) #used for bootstrap-based inference
library(ggridges) #used for graphs
library(tidymodels) #used for modeling
library(tune) #used for hyperparameter tuning
library(themis) #used to expand recipe package for dealing with unbalanced data
library(vip) #used for variable importance
library(knitr) #used for table displaying
library(ggpubr) #used for annotating figures
The data was downloaded from kaggle and stored locally (login necessary). The here package is used to locate the files relative to the project root. We take a first look at the amount of missing data:
#get data
travel <- fread(here("Data", "Travel.csv"))
cat(dim(travel[!complete.cases(travel),])[1], "out of", nrow(travel), "observations have at least one missing value.")
760 out of 4888 observations have at least one missing value.
As a first step, let’s have a quick look at the data using the skim function:
skim(travel)
| Name | travel |
| Number of rows | 4888 |
| Number of columns | 20 |
| Key | NULL |
| _______________________ | |
| Column type frequency: | |
| character | 6 |
| numeric | 14 |
| ________________________ | |
| Group variables | None |
Variable type: character
| skim_variable | n_missing | complete_rate | min | max | empty | n_unique | whitespace |
|---|---|---|---|---|---|---|---|
| TypeofContact | 0 | 1 | 0 | 15 | 25 | 3 | 0 |
| Occupation | 0 | 1 | 8 | 14 | 0 | 4 | 0 |
| Gender | 0 | 1 | 4 | 7 | 0 | 3 | 0 |
| ProductPitched | 0 | 1 | 4 | 12 | 0 | 5 | 0 |
| MaritalStatus | 0 | 1 | 6 | 9 | 0 | 4 | 0 |
| Designation | 0 | 1 | 2 | 14 | 0 | 5 | 0 |
Variable type: numeric
| skim_variable | n_missing | complete_rate | mean | sd | p0 | p25 | p50 | p75 | p100 | hist |
|---|---|---|---|---|---|---|---|---|---|---|
| CustomerID | 0 | 1.00 | 202443.50 | 1411.19 | 200000 | 201221.8 | 202443.5 | 203665.2 | 204887 | ▇▇▇▇▇ |
| ProdTaken | 0 | 1.00 | 0.19 | 0.39 | 0 | 0.0 | 0.0 | 0.0 | 1 | ▇▁▁▁▂ |
| Age | 226 | 0.95 | 37.62 | 9.32 | 18 | 31.0 | 36.0 | 44.0 | 61 | ▂▇▆▃▂ |
| CityTier | 0 | 1.00 | 1.65 | 0.92 | 1 | 1.0 | 1.0 | 3.0 | 3 | ▇▁▁▁▃ |
| DurationOfPitch | 251 | 0.95 | 15.49 | 8.52 | 5 | 9.0 | 13.0 | 20.0 | 127 | ▇▁▁▁▁ |
| NumberOfPersonVisiting | 0 | 1.00 | 2.91 | 0.72 | 1 | 2.0 | 3.0 | 3.0 | 5 | ▁▅▇▃▁ |
| NumberOfFollowups | 45 | 0.99 | 3.71 | 1.00 | 1 | 3.0 | 4.0 | 4.0 | 6 | ▂▆▇▃▁ |
| PreferredPropertyStar | 26 | 0.99 | 3.58 | 0.80 | 3 | 3.0 | 3.0 | 4.0 | 5 | ▇▁▂▁▂ |
| NumberOfTrips | 140 | 0.97 | 3.24 | 1.85 | 1 | 2.0 | 3.0 | 4.0 | 22 | ▇▁▁▁▁ |
| Passport | 0 | 1.00 | 0.29 | 0.45 | 0 | 0.0 | 0.0 | 1.0 | 1 | ▇▁▁▁▃ |
| PitchSatisfactionScore | 0 | 1.00 | 3.08 | 1.37 | 1 | 2.0 | 3.0 | 4.0 | 5 | ▅▃▇▅▅ |
| OwnCar | 0 | 1.00 | 0.62 | 0.49 | 0 | 0.0 | 1.0 | 1.0 | 1 | ▅▁▁▁▇ |
| NumberOfChildrenVisiting | 66 | 0.99 | 1.19 | 0.86 | 0 | 1.0 | 1.0 | 2.0 | 3 | ▅▇▁▅▁ |
| MonthlyIncome | 233 | 0.95 | 23619.85 | 5380.70 | 1000 | 20346.0 | 22347.0 | 25571.0 | 98678 | ▃▇▁▁▁ |
Notes:
There is some valuable information in this quick overview. We are dealing with approximately 5% missing data in the variables Age, DurationOfPitch and MonthlyIncome. In addition, there are also (fewer) missing datapoints in the NumberOfFollowups, PreferredPropertyStar and NumberOfChildrenVisiting variables.
There are some variables classified as numeric which should be character instead (e.g. ProdTaken, Passport).
ProdTaken is the variable we want to predict. It can be 0 (Wellness Tourism Package not taken) or 1 (Wellness Tourism Package taken).
The variable names are already clean.
We can already see that we don’t have to deal with any zero variance variables.
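For reference, the per-variable missing counts seen in the skim output can also be reproduced directly — a quick sketch:
#count missing values per variable, largest first
sort(colSums(is.na(travel)), decreasing = TRUE)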
Let us have a closer look at the standard deviations of the numeric variables in our dataset, in descending order. This will help us decide if normalization is necessary. In addition, since some variables like CityTier are classified as numeric although they are truly factor type, we will also have a closer look at the number of unique values of the numeric variables. This will help to identify which variables can be converted to factor type.
data_sd <- map_dbl(travel[,-c(1,2)], sd, na.rm=TRUE) %>%
tibble(Variable = names(travel[,-c(1,2)]), sd = .) %>%
filter(!is.na(sd)) %>%
arrange(-sd)
data_unique <- map(travel[,-c(1,2)], unique) %>%
lengths(.) %>%
tibble(Variable = names(travel[,-c(1,2)]), n_unique = .)
left_join(data_sd, data_unique, by="Variable") %>% kable()
| Variable | sd | n_unique |
|---|---|---|
| MonthlyIncome | 5380.6983607 | 2476 |
| Age | 9.3163870 | 45 |
| DurationOfPitch | 8.5196426 | 35 |
| NumberOfTrips | 1.8490193 | 13 |
| PitchSatisfactionScore | 1.3657917 | 5 |
| NumberOfFollowups | 1.0025087 | 7 |
| CityTier | 0.9165834 | 3 |
| NumberOfChildrenVisiting | 0.8578612 | 5 |
| PreferredPropertyStar | 0.7980087 | 4 |
| NumberOfPersonVisiting | 0.7248906 | 5 |
| OwnCar | 0.4853632 | 2 |
| Passport | 0.4542316 | 2 |
Notes:
We have very large variability in the MonthlyIncome variable, which means the data would have to be normalized for any machine learning model that uses a gradient descent algorithm or is based on class distances. With tree-based models, we need neither normalize the data nor convert factors to dummies.
The number of unique values tells us which variables can be converted to character type.
We will convert the identified numeric variables to character class. In addition, we will fix a misspelling in the Gender variable and check that the remaining gender values are correct. We will leave the missing data for now, since we will deal with it later through linear imputation when specifying our prediction model.
#preprocess character variables
character_variables <- c("ProdTaken", "CityTier", "Passport", "OwnCar", "TypeofContact", "Gender", "Occupation", "MaritalStatus")
travel <- travel %>% mutate_at(character_variables, as.character)
travel[ProdTaken == 0, ProdTaken := "No"]
travel[ProdTaken == 1, ProdTaken := "Yes"]
#clean Gender misspelling
travel[Gender == "Fe Male", Gender := "Female"]
cat("Unique gender values:", unique(travel$Gender))Unique gender values: Female Male
Now let us also have a closer look at the binary outcome variable, which indicates whether the holiday package was taken:
#take a first look at ProdTaken
travel[order(ProdTaken),.(.N), by = ProdTaken] %>% kable(col.names = c("ProdTaken", "Frequency"))
| ProdTaken | Frequency |
|---|---|
| No | 3968 |
| Yes | 920 |
prop.table(table(travel$ProdTaken)) %>% kable(col.names = c("ProdTaken", "Proportion"))
| ProdTaken | Proportion |
|---|---|
| No | 0.811784 |
| Yes | 0.188216 |
Notes:
Before we start with the prediction, we will have a look at especially important variables and their relationship to the outcome variable. We start with MonthlyIncome, since this variable turns out to have the biggest impact on ProdTaken when using all variables in a random forest model.
#distribution of MonthlyIncome
ggplot(data = travel[!is.na(MonthlyIncome)], aes(x = MonthlyIncome)) +
geom_histogram(bins = 250) +
xlab("\nMonthlyIncome") + ylab("Count") +
theme_classic()
Notes:
All in all, the MonthlyIncome values are quite high (the center lies near 25,000). The dataset description does not state which currency we are dealing with.
Another thing to be aware of is the skew of the distribution. Salary data are often right-skewed, which means that a few very high salaries have a big influence on the mean. Using the median or a log transformation is a good way to reduce the influence of single salary datapoints when examining the relationship between MonthlyIncome and ProdTaken.
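A quick way to see the right skew is to compare the mean with the median; for right-skewed data the mean is pulled above the median. A minimal check on the data loaded above:
#right-skew check: the mean lies noticeably above the median
travel[!is.na(MonthlyIncome), .(mean_income = mean(MonthlyIncome), median_income = median(MonthlyIncome))]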
We will now plot the relationship between MonthlyIncome and whether customers took the product or not. The logarithm is used because of the skewness of the MonthlyIncome variable. Because we are still dealing with extreme outliers, a second graph is shown where the top and bottom 1% of the data are removed.
par(mfrow = c(1, 2))
pirateplot(log(MonthlyIncome) ~ ProdTaken, data = travel[!is.na(MonthlyIncome)],
theme = 1,
gl.col = "white",
xlab = "ProdTaken - Original Data",
cex.lab = 0.75,
cex.axis = 0.75,
cex.names = 0.75)
#trimming top and bottom 1% of data
travel_trimmed <- travel[MonthlyIncome %between% quantile(MonthlyIncome, c(.01, .99), na.rm = TRUE)]
pirateplot(data = travel_trimmed, log(MonthlyIncome) ~ ProdTaken,
gl.col = "white",
xlab = "ProdTaken - Trimmed Data",
cex.lab = 0.75,
cex.axis = 0.75,
cex.names = 0.75)
Notes:
Pirateplots use the mean by default. The mean of logged MonthlyIncome is higher for customers who did not take the product.
Although we used the logarithm of MonthlyIncome, we still see huge outliers in the data.
After trimming, the income is still lower for customers who took the Wellness Tourism Product.
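The group means behind the pirateplots can also be verified numerically — a minimal sketch using data.table aggregation:
#mean of logged MonthlyIncome by group, the quantity shown by the pirateplot bars
travel[!is.na(MonthlyIncome), .(mean_log_income = mean(log(MonthlyIncome))), by = ProdTaken]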
We will also have a look at the median differences because of the skewness:
ggplot(data = travel[!is.na(MonthlyIncome)], aes(x = MonthlyIncome, y = as.factor(ProdTaken), fill = as.factor(ProdTaken)))+
stat_density_ridges(quantile_lines = TRUE, quantiles = 2, scale = 3, color = "white") +
scale_fill_manual(values = c("grey30", "#9a5ea1"), guide = "none") +
labs(x = "\nMonthlyIncome", y = "ProdTaken") +
theme_classic()
Notes:
The white lines in the two graphs represent the medians, which are also higher in the customer group who did not take the Wellness Package.
The distributions of the two subgroups look quite similar.
To get a better idea whether this difference is unlikely to be the result of pure chance (sampling issues), we will do bootstrap-based inference testing. We first calculate the difference in median MonthlyIncome between the two groups (customers who bought the package vs. customers who did not). After that, we simulate a world where the actual difference in medians between these two groups is zero, by shuffling the ProdTaken labels within the existing data 5,000 times. We then plot that null distribution, place the observed difference in median income in it, and see how well it fits. We also calculate the probability of seeing a difference as big as the one observed, using the get_pvalue() function.
#calculating the median difference
diff_med <- travel[!is.na(MonthlyIncome)] %>%
specify(MonthlyIncome ~ ProdTaken) %>%
calculate("diff in medians",
order = c("Yes", "No"))
#specify null hypothesis
ProdTaken_null <- as_tibble(travel[!is.na(MonthlyIncome)]) %>%
specify(MonthlyIncome ~ ProdTaken) %>%
hypothesize(null = "independence") %>%
generate(reps = 5000, type = "permute") %>%
calculate("diff in medians",
order = c("Yes", "No"))
#get p-value
pv <- get_pvalue(ProdTaken_null,obs_stat = diff_med, direction = "both")[[1]]
#get lower and upper bound of confidence interval
diff_ConfInt <- travel[!is.na(MonthlyIncome)] %>%
specify(MonthlyIncome ~ ProdTaken) %>%
generate(reps = 5000, type = "bootstrap") %>%
calculate("diff in medians",
order = c("Yes", "No")) %>%
get_confidence_interval()
# visualization of null hypothesis distribution and actual difference in medians
p <- ProdTaken_null %>%
visualize() +
geom_vline(xintercept = diff_med$stat, color = "#FF4136", size = 1) +
labs(x = "\nDifference in median MonthlyIncome\n(Product Taken vs Product Not Taken)",
y = "Count",
subtitle = "Red line shows observed difference in median Income\n") +
theme_classic()
annotate_figure(p, bottom = text_grob(paste0("p-value: ", format.pval(pv)," \n lower bound of ci: ", round(diff_ConfInt[[1]], 0), "\n upper bound of ci: ", round(diff_ConfInt[[2]],0)),
hjust = 1, x = 1, color = "red", size = 10)
)
Notes:
The red line represents our measured difference in median MonthlyIncome between the two groups (value: -1557). This red line is pretty far in the left tail of the distribution and looks atypical, which indicates a small p-value.
The p-value is indeed far smaller than 0.05. That is pretty strong evidence, and I’d feel confident declaring that there is a statistically significant difference in median MonthlyIncome between customers who took the product and customers who did not.
According to the upper bound of the confidence interval, we can be 95% certain that the true median difference between the groups is at most -1316, i.e. the median income of buyers is at least 1316 lower.
We will only take a quick look at the character variables, by plotting the percentage of customers who took the Wellness Product for each unique value of every character variable. The y-axis is the percentage of ProdTaken in all the following graphs:
#helper function to get percentages how often the target Product was taken
get_percentage <- function(x){
length(x[x == "Yes"]) / length(x)
}
#loop through character variable names to plot each character variable in its relationship with Percentage of ProductTaken
p <- list()
j = 1
for(i in c(character_variables[-1])){
temp <- travel[, .(Percentage_ProductTaken = get_percentage(ProdTaken)), by = i]
p[[j]] <- ggplot(aes_string(x=names(temp)[1], y=names(temp)[2]), data = temp) +
geom_bar(stat="identity") +
ggtitle(paste0(names(temp[1]))) +
xlab("") +
ylab("") +
theme_classic() +
theme(plot.title = element_text(hjust = 0.5)) +
theme(axis.title.x=element_blank(), axis.text.x=element_blank(), axis.ticks.x=element_blank())
j = j+1
}
#grid.arrange comes from gridExtra, which is not attached above, so we call it via its namespace
do.call(gridExtra::grid.arrange, p)
Notes:
We finally look at Duration of Pitch before we start our prediction:
#there was one big outlier with duration > 100
pirateplot(data = travel[DurationOfPitch < 100], DurationOfPitch ~ ProdTaken,
gl.col = "white",
cex.lab = 0.75,
cex.axis = 0.75,
cex.names = 0.75)
Notes:
All in all, we see many variables which have a meaningful relationship with ProdTaken. During model testing on a cross-validation subset, a model with all variables achieved the highest accuracy and will therefore be used for our Random Forest model.
We start preparing our data by splitting it into a training and a testing set. We stratify the split to account for the unbalanced outcome variable. Resampling will be used with the training data in the form of cross-validation; this will help us evaluate our model. Our random forest model is specified with the parsnip package:
set.seed(1234)
split_travel <- initial_split(travel[,-"CustomerID"], prop = 0.7,
strata = ProdTaken)
#split to training and testing data
training_travel <- training(split_travel)
testing_travel <- testing(split_travel)
#cross-validation subset
vfold_travel <- rsample::vfold_cv(data = training_travel, v = 10)
#specify random forest model
rf <- parsnip::rand_forest(trees = 1000) %>%
set_engine("ranger") %>%
set_mode("classification")
rf
Random Forest Model Specification (classification)
Main Arguments:
trees = 1000
Computational engine: ranger
The recipes package provides an easy way to combine all the transformations and other preprocessing steps related to the model into a single block that can be applied to any other subset of the data.
For our case we tested two different recipes:
Recipe 1: linear imputation of the missing numeric values only.
Recipe 2: linear imputation, followed by dummy coding, removal of highly correlated predictors, normalization, and SMOTE upsampling of the minority class.
Since Recipe 1 achieved a much higher accuracy, we only focus on Recipe 1 in the further analyses.
data_recipe <- training_travel %>%
recipe(ProdTaken ~ .) %>%
step_impute_linear(all_numeric(),
impute_with = all_nominal())
# accuracy with the more extensive preprocessing was worse, so this second recipe was not used further:
data_recipe2 <- training_travel %>%
recipe(ProdTaken ~ .) %>%
step_impute_linear(all_numeric(), impute_with = all_nominal()) %>%
step_dummy(all_predictors()) %>%
step_corr(all_predictors(), threshold = 0.8) %>%
step_normalize(all_numeric()) %>%
step_smote(ProdTaken)
We will optionally perform the preprocessing to see how it influences the data:
prepped_rec <- prep(data_recipe, verbose = TRUE, retain = TRUE)
oper 1 step impute linear [training]
The retained training set is ~ 0.26 Mb in memory.
#Let us have a look at the preprocessed training data
preproc_train <- recipes::bake(prepped_rec, new_data = NULL)
skim(preproc_train)
| Name | preproc_train |
| Number of rows | 3421 |
| Number of columns | 19 |
| _______________________ | |
| Column type frequency: | |
| factor | 10 |
| numeric | 9 |
| ________________________ | |
| Group variables | None |
Variable type: factor
| skim_variable | n_missing | complete_rate | ordered | n_unique | top_counts |
|---|---|---|---|---|---|
| TypeofContact | 0 | 1 | FALSE | 3 | Sel: 2402, Com: 999, emp: 20 |
| CityTier | 0 | 1 | FALSE | 3 | 1: 2225, 3: 1058, 2: 138 |
| Occupation | 0 | 1 | FALSE | 4 | Sal: 1629, Sma: 1481, Lar: 310, Fre: 1 |
| Gender | 0 | 1 | FALSE | 2 | Mal: 2040, Fem: 1381 |
| ProductPitched | 0 | 1 | FALSE | 5 | Bas: 1294, Del: 1207, Sta: 537, Sup: 232 |
| MaritalStatus | 0 | 1 | FALSE | 4 | Mar: 1623, Div: 668, Sin: 638, Unm: 492 |
| Passport | 0 | 1 | FALSE | 2 | 0: 2437, 1: 984 |
| OwnCar | 0 | 1 | FALSE | 2 | 1: 2123, 0: 1298 |
| Designation | 0 | 1 | FALSE | 5 | Exe: 1294, Man: 1207, Sen: 537, AVP: 232 |
| ProdTaken | 0 | 1 | FALSE | 2 | No: 2777, Yes: 644 |
Variable type: numeric
| skim_variable | n_missing | complete_rate | mean | sd | p0 | p25 | p50 | p75 | p100 | hist |
|---|---|---|---|---|---|---|---|---|---|---|
| Age | 0 | 1 | 37.73 | 9.11 | 18 | 31 | 37 | 43 | 61 | ▂▇▇▃▂ |
| DurationOfPitch | 0 | 1 | 15.59 | 8.22 | 5 | 9 | 14 | 19 | 126 | ▇▁▁▁▁ |
| NumberOfPersonVisiting | 0 | 1 | 2.90 | 0.72 | 1 | 2 | 3 | 3 | 5 | ▁▅▇▃▁ |
| NumberOfFollowups | 0 | 1 | 3.71 | 0.99 | 1 | 3 | 4 | 4 | 6 | ▂▆▇▃▁ |
| PreferredPropertyStar | 0 | 1 | 3.58 | 0.79 | 3 | 3 | 3 | 4 | 5 | ▇▁▂▁▂ |
| NumberOfTrips | 0 | 1 | 3.25 | 1.87 | 1 | 2 | 3 | 4 | 22 | ▇▁▁▁▁ |
| PitchSatisfactionScore | 0 | 1 | 3.08 | 1.36 | 1 | 2 | 3 | 4 | 5 | ▅▃▇▅▅ |
| NumberOfChildrenVisiting | 0 | 1 | 1.18 | 0.85 | 0 | 1 | 1 | 2 | 3 | ▅▇▁▅▁ |
| MonthlyIncome | 0 | 1 | 23580.74 | 5288.28 | 4678 | 20468 | 22597 | 25388 | 98678 | ▇▆▁▁▁ |
Notes:
To organize our workflow in a structured and smooth way, we use the workflows package from the tidymodels collection. We will run a first Random Forest model on the cross-validation dataset to get an impression of the performance without hyperparameter tuning:
rf_wf <- workflows::workflow() %>%
workflows::add_recipe(data_recipe) %>%
workflows::add_model(rf)
model_rf <- rf_wf %>% fit_resamples(vfold_travel, control = control_resamples(save_pred = TRUE))
model_rf_pred <- collect_predictions(model_rf)
#confusionMatrix expects the predictions first (data) and the truth second (reference)
cm <- caret::confusionMatrix(factor(model_rf_pred$.pred_class, levels = c("Yes", "No")), factor(model_rf_pred$ProdTaken, levels = c("Yes", "No")))
t1 <- head(cm$overall, 4)
t2 <- head(cm$byClass,4)
knitr::kable(
list(t1, t2), col.names = "",
)
(Table: overall metrics and by-class metrics of the untuned model)
These are the results we get without further optimization.
Since we would already get an accuracy of approx. 81% by simply predicting No for all datapoints, let us now try to improve the performance through hyperparameter tuning. After tuning, we will finalize our model by automatically choosing the best performing parameters, fit it to the whole training set, and afterwards assess the performance on the test set.
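The approx. 81% baseline is simply the proportion of the majority class and can be verified directly — a minimal check:
#no-information baseline: accuracy of always predicting the majority class "No"
max(prop.table(table(travel$ProdTaken)))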
First, to tune the random forest hyperparameters mtry and min_n, we create a model specification that identifies which hyperparameters we plan to tune. We can’t learn suitable hyperparameter values from a single data set (such as the entire training set), so we will train many models on resampled data and see which ones turn out best. We will also create a regular grid of values to try, using the convenience functions for each hyperparameter.
Once we have our tuning results, we can explore them through visualization and then select the best result. The function collect_metrics() gives us a tidy tibble with all the results. Let us look at the tuning results in a table and in a plot:
set.seed(1234)
# tune mtry and min_n
rf_tuning <- parsnip::rand_forest(trees = 1000,
mtry = tune(),
min_n = tune()
) %>%
set_engine("ranger", importance = "impurity") %>%
set_mode("classification")
#create grid of values to try
rf_grid <- grid_regular(min_n(range = c(2, 8)),
mtry(range = c(2, 17)),
levels = 5)
recipe_tuning <- data_recipe #concentrate on the model parameters first
# put together in a workflow
rf_wf_tune <- workflows::workflow() %>%
workflows::add_recipe(recipe_tuning) %>%
workflows::add_model(rf_tuning)
#train hyperparameters
set.seed(1234)
rf_res <- rf_wf_tune %>%
tune::tune_grid(resamples = vfold_travel,
grid = rf_grid)
# save tuning results
tuning_results <- rf_res %>% collect_metrics() %>% arrange(.metric, -mean)
kable(head(tuning_results, 4))
| mtry | min_n | .metric | .estimator | mean | n | std_err | .config |
|---|---|---|---|---|---|---|---|
| 17 | 2 | accuracy | binary | 0.9279018 | 9 | 0.0048667 | Preprocessor1_Model21 |
| 13 | 2 | accuracy | binary | 0.9249759 | 9 | 0.0046226 | Preprocessor1_Model16 |
| 13 | 3 | accuracy | binary | 0.9240031 | 9 | 0.0047957 | Preprocessor1_Model17 |
| 17 | 3 | accuracy | binary | 0.9240022 | 9 | 0.0054469 | Preprocessor1_Model22 |
rf_res %>%
collect_metrics() %>%
filter(.metric == "accuracy") %>%
select(mean, min_n, mtry) %>%
pivot_longer(min_n:mtry,
values_to = "value",
names_to = "parameter"
) %>%
ggplot(aes(value, mean, color = parameter)) +
geom_point(show.legend = FALSE) +
facet_wrap(~parameter, scales = "free_x") +
labs(x = NULL, y = "Accuracy") +
theme_classic()
Notes:
For the min_n parameter, the accuracy is higher for smaller values.
For the mtry parameter, the accuracy is higher for larger values.
We got the highest accuracy with a min_n value of 2 and an mtry value of 17.
We use the select_best() function to pull out the single best set of hyperparameter values for our random forest model. We will show the accuracy on the cross-validation subset with this selection:
best_rf <- rf_res %>%
select_best("accuracy")
final_wf <- rf_wf_tune %>%
finalize_workflow(best_rf)
final_cv_performance <- final_wf %>% fit_resamples(vfold_travel, control = control_resamples(save_pred = TRUE))
kable(collect_metrics(final_cv_performance))
| .metric | .estimator | mean | n | std_err | .config |
|---|---|---|---|---|---|
| accuracy | binary | 0.9240041 | 9 | 0.0053551 | Preprocessor1_Model1 |
| roc_auc | binary | 0.9517956 | 9 | 0.0058250 | Preprocessor1_Model1 |
Notes:
It is time to assess final performance on the test set. Let’s fit this final model to the training data and use our test data to estimate the model performance we expect to see with new data:
final_fit <- final_wf %>%
last_fit(split_travel)
final_accuracy <- final_fit %>%
collect_metrics()
#get final predictions (idiomatic accessor instead of positional indexing)
final_fit_pred <- collect_predictions(final_fit)
#get confusion matrix
final_cm <- caret::confusionMatrix(factor(testing_travel$ProdTaken, levels = c("Yes", "No")), factor(final_fit_pred$.pred_class, levels = c("Yes", "No")))
final_cm
Confusion Matrix and Statistics
Reference
Prediction Yes No
Yes 200 76
No 14 1177
Accuracy : 0.9387
95% CI : (0.9251, 0.9504)
No Information Rate : 0.8541
P-Value [Acc > NIR] : < 2.2e-16
Kappa : 0.7802
Mcnemar's Test P-Value : 1.276e-10
Sensitivity : 0.9346
Specificity : 0.9393
Pos Pred Value : 0.7246
Neg Pred Value : 0.9882
Prevalence : 0.1459
Detection Rate : 0.1363
Detection Prevalence : 0.1881
Balanced Accuracy : 0.9370
'Positive' Class : Yes
Notes:
All in all, we get a final accuracy of 93.87%, which is clearly higher than just predicting No in the unbalanced data.
The reported sensitivity of 93.46% for the positive class (Yes) is quite similar to the specificity of 93.93%.
We did a good job in dealing with the imbalance in our dataset and improving the performance!
final_fit %>%
extract_workflow()%>%
extract_fit_parsnip() %>%
vip() +
theme_classic()
In this project, we used a Random Forest model to predict which customers will buy the new Wellness Tourism Package. The final accuracy is 93.87% on the test set. The most important variables for the prediction were MonthlyIncome, Age, DurationOfPitch and ProductPitched. We also found evidence that the relationship between MonthlyIncome and ProdTaken is not due to chance and is generalizable to new customers.
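To actually score future customers with the final model, the fitted workflow from last_fit() can be extracted and used with predict(). This is a minimal sketch; new_customers is a hypothetical data frame with the same predictor columns as the training data:
#extract the trained workflow (recipe + ranger model) from the last_fit object
fitted_wf <- extract_workflow(final_fit)
#new_customers is assumed to contain the same predictor columns as training_travel
#returns class membership probabilities for each potential customer
predict(fitted_wf, new_data = new_customers, type = "prob")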