Goal: Build a regression model that predicts chocolate bar ratings from the bars' main characteristics.
Data: TidyTuesday chocolate bar ratings (2022-01-18); the CSV is read directly from GitHub below.
# Load the TidyTuesday 2022-01-18 chocolate ratings table straight from GitHub.
chocolate <- readr::read_csv(
  file = "https://raw.githubusercontent.com/rfordatascience/tidytuesday/main/data/2022/2022-01-18/chocolate.csv"
)
## Rows: 2530 Columns: 10
## ── Column specification ────────────────────────────────────────────────────────
## Delimiter: ","
## chr (7): company_manufacturer, company_location, country_of_bean_origin, spe...
## dbl (3): ref, review_date, rating
##
## ℹ Use `spec()` to retrieve the full column specification for this data.
## ℹ Specify the column types or set `show_col_types = FALSE` to quiet this message.
# Summary of every column: missingness, distinct values, numeric spread.
skimr::skim(data = chocolate)
Name | chocolate |
Number of rows | 2530 |
Number of columns | 10 |
_______________________ | |
Column type frequency: | |
character | 7 |
numeric | 3 |
________________________ | |
Group variables | None |
Variable type: character
skim_variable | n_missing | complete_rate | min | max | empty | n_unique | whitespace |
---|---|---|---|---|---|---|---|
company_manufacturer | 0 | 1.00 | 2 | 39 | 0 | 580 | 0 |
company_location | 0 | 1.00 | 4 | 21 | 0 | 67 | 0 |
country_of_bean_origin | 0 | 1.00 | 4 | 21 | 0 | 62 | 0 |
specific_bean_origin_or_bar_name | 0 | 1.00 | 3 | 51 | 0 | 1605 | 0 |
cocoa_percent | 0 | 1.00 | 3 | 6 | 0 | 46 | 0 |
ingredients | 87 | 0.97 | 4 | 14 | 0 | 21 | 0 |
most_memorable_characteristics | 0 | 1.00 | 3 | 37 | 0 | 2487 | 0 |
Variable type: numeric
skim_variable | n_missing | complete_rate | mean | sd | p0 | p25 | p50 | p75 | p100 | hist |
---|---|---|---|---|---|---|---|---|---|---|
ref | 0 | 1 | 1429.80 | 757.65 | 5 | 802 | 1454.00 | 2079.0 | 2712 | ▆▇▇▇▇ |
review_date | 0 | 1 | 2014.37 | 3.97 | 2006 | 2012 | 2015.00 | 2018.0 | 2021 | ▃▅▇▆▅ |
rating | 0 | 1 | 3.20 | 0.45 | 1 | 3 | 3.25 | 3.5 | 4 | ▁▁▅▇▇ |
# Clean the raw table into a modeling data set.
# Note: the two separate_rows() calls below give each bar one row per
# (ingredient code x memorable characteristic) pair, so most bars span several
# rows that share the same `ref`. Any train/test or CV split of data1 must keep
# rows with the same `ref` together, or copies of a bar leak across the split.
data1 <- chocolate %>%
  # Treat missing values: only `ingredients` has NAs (87 rows in the skim).
  drop_na() %>%
  # `ingredients` looks like "<count>-<codes>" (optionally "<count>- <codes>"):
  # split it into the ingredient count and the comma-separated codes.
  separate(col = ingredients, into = c("n_ing", "ing"), sep = "-( |)") %>%
  # One row per ingredient code.
  separate_rows(ing, sep = ",") %>%
  # Convert number of ingredients into numeric.
  mutate(n_ing = as.numeric(n_ing)) %>%
  # One row per memorable characteristic.
  separate_rows(most_memorable_characteristics, sep = ",( |)") %>%
  # The separators above absorb at most one space after a comma; trim any
  # remaining whitespace so e.g. "sweet" and "sweet " do not become distinct
  # factor levels, and drop empty tokens left by stray trailing commas.
  mutate(across(c(ing, most_memorable_characteristics), str_trim)) %>%
  filter(ing != "", most_memorable_characteristics != "") %>%
  # Drop rows with a missing piece (e.g. an `ingredients` value without "-").
  drop_na() %>%
  # Remove the free-text bar name; the author treats its information as
  # captured by the other columns (e.g. country_of_bean_origin).
  select(-specific_bean_origin_or_bar_name) %>%
  # Convert cocoa percentage text ("70%") into a number (70).
  mutate(cocoa_percent = cocoa_percent %>% str_remove("%") %>% as.numeric()) %>%
  # Convert all remaining character columns to factors for modeling.
  mutate(across(where(is.character), factor))
data1
## # A tibble: 20,815 × 10
## ref company_manufacturer company_location review_date
## <dbl> <fct> <fct> <dbl>
## 1 2454 5150 U.S.A. 2019
## 2 2454 5150 U.S.A. 2019
## 3 2454 5150 U.S.A. 2019
## 4 2454 5150 U.S.A. 2019
## 5 2454 5150 U.S.A. 2019
## 6 2454 5150 U.S.A. 2019
## 7 2454 5150 U.S.A. 2019
## 8 2454 5150 U.S.A. 2019
## 9 2454 5150 U.S.A. 2019
## 10 2458 5150 U.S.A. 2019
## # ℹ 20,805 more rows
## # ℹ 6 more variables: country_of_bean_origin <fct>, cocoa_percent <dbl>,
## # n_ing <dbl>, ing <fct>, most_memorable_characteristics <fct>, rating <dbl>
Identify good predictors.
ref
# Rating against the `ref` ID. Ratings take only a few discrete values, so
# plain points stack on top of each other; jitter them vertically and make them
# translucent to show density. `ref` is roughly uniform (see the skim
# histogram), so no log scale is needed.
# Note: data1 repeats each bar over several rows, so dense areas partly
# reflect bars with many ingredients/characteristics.
data1 %>%
  ggplot(aes(ref, rating)) +
  geom_jitter(alpha = 0.1, width = 0, height = 0.05)
review_date
# Rating by review year. A log scale on calendar years (2006-2021) compresses
# everything into ~3.30-3.31 and hides the trend, so plot years as-is. Both
# axes are discrete-valued, so jitter and transparency reveal point density.
# Note: data1 repeats each bar over several rows, so dense areas partly
# reflect bars with many ingredients/characteristics.
data1 %>%
  ggplot(aes(review_date, rating)) +
  geom_jitter(alpha = 0.1, width = 0.3, height = 0.05)
EDA Shortcut: correlation funnel
# Step 1: Prepare data. Drop the `ref` ID, then turn every column into 0/1
# indicator columns: numeric columns are cut into bins, and factor levels get
# one column each, with rare levels pooled into "-OTHER".
data_binarized_tbl <- binarize(select(data1, -ref))
glimpse(data_binarized_tbl)
## Rows: 20,815
## Columns: 94
## $ company_manufacturer__A._Morin <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ company_manufacturer__Bonnat <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ company_manufacturer__Fresco <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ company_manufacturer__Guittard <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ company_manufacturer__Pralus <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ company_manufacturer__Scharffen_Berger <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ company_manufacturer__Soma <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ company_manufacturer__Valrhona <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ company_manufacturer__Zotter <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ `company_manufacturer__-OTHER` <dbl> 1, 1, 1, 1, 1, 1, 1, 1, 1, …
## $ company_location__Australia <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ company_location__Austria <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ company_location__Belgium <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ company_location__Brazil <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ company_location__Canada <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ company_location__Colombia <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ company_location__Denmark <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ company_location__Ecuador <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ company_location__France <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ company_location__Germany <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ company_location__Italy <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ company_location__Japan <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ company_location__Spain <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ company_location__Switzerland <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ company_location__U.K. <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ company_location__U.S.A. <dbl> 1, 1, 1, 1, 1, 1, 1, 1, 1, …
## $ company_location__Venezuela <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ `company_location__-OTHER` <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ `review_date__-Inf_2011` <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ review_date__2011_2014 <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ review_date__2014_2018 <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ review_date__2018_Inf <dbl> 1, 1, 1, 1, 1, 1, 1, 1, 1, …
## $ country_of_bean_origin__Belize <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ country_of_bean_origin__Blend <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ country_of_bean_origin__Bolivia <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ country_of_bean_origin__Brazil <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ country_of_bean_origin__Colombia <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ country_of_bean_origin__Costa_Rica <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ country_of_bean_origin__Dominican_Republic <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ country_of_bean_origin__Ecuador <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ country_of_bean_origin__Ghana <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ country_of_bean_origin__Guatemala <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ country_of_bean_origin__Haiti <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ country_of_bean_origin__India <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ country_of_bean_origin__Jamaica <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ country_of_bean_origin__Madagascar <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ country_of_bean_origin__Mexico <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ country_of_bean_origin__Nicaragua <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ country_of_bean_origin__Papua_New_Guinea <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ country_of_bean_origin__Peru <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ country_of_bean_origin__Tanzania <dbl> 1, 1, 1, 1, 1, 1, 1, 1, 1, …
## $ country_of_bean_origin__Trinidad <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ country_of_bean_origin__U.S.A. <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ country_of_bean_origin__Venezuela <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ country_of_bean_origin__Vietnam <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ `country_of_bean_origin__-OTHER` <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ `cocoa_percent__-Inf_70` <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ cocoa_percent__70_74 <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ cocoa_percent__74_Inf <dbl> 1, 1, 1, 1, 1, 1, 1, 1, 1, …
## $ n_ing__2 <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ n_ing__3 <dbl> 1, 1, 1, 1, 1, 1, 1, 1, 1, …
## $ n_ing__4 <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ n_ing__5 <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ `n_ing__-OTHER` <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ ing__B <dbl> 1, 1, 1, 0, 0, 0, 0, 0, 0, …
## $ ing__C <dbl> 0, 0, 0, 0, 0, 0, 1, 1, 1, …
## $ ing__L <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ ing__S <dbl> 0, 0, 0, 1, 1, 1, 0, 0, 0, …
## $ `ing__S*` <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ ing__V <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ `ing__-OTHER` <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ most_memorable_characteristics__cocoa <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ most_memorable_characteristics__creamy <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ most_memorable_characteristics__earthy <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ most_memorable_characteristics__fatty <dbl> 0, 1, 0, 0, 1, 0, 0, 1, 0, …
## $ most_memorable_characteristics__floral <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ most_memorable_characteristics__fruit <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ most_memorable_characteristics__intense <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ most_memorable_characteristics__molasses <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ most_memorable_characteristics__nutty <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ most_memorable_characteristics__rich <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ most_memorable_characteristics__roasty <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ most_memorable_characteristics__sandy <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ most_memorable_characteristics__sour <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ most_memorable_characteristics__spicy <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ most_memorable_characteristics__sticky <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ most_memorable_characteristics__sweet <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ most_memorable_characteristics__vanilla <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ most_memorable_characteristics__woody <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ `most_memorable_characteristics__-OTHER` <dbl> 1, 0, 1, 1, 0, 1, 1, 0, 1, …
## $ `rating__-Inf_3` <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ rating__3_3.25 <dbl> 1, 1, 1, 1, 1, 1, 1, 1, 1, …
## $ rating__3.25_3.5 <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ rating__3.5_Inf <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, …
# Step 2: Correlate every indicator column with the top rating bin
# (`rating__3.5_Inf`).
data_corr_tbl <- correlate(data_binarized_tbl, target = rating__3.5_Inf)
data_corr_tbl
## # A tibble: 94 × 3
## feature bin correlation
## <fct> <chr> <dbl>
## 1 rating 3.5_Inf 1
## 2 rating -Inf_3 -0.398
## 3 rating 3.25_3.5 -0.230
## 4 rating 3_3.25 -0.202
## 5 company_manufacturer -OTHER -0.154
## 6 company_manufacturer Soma 0.122
## 7 most_memorable_characteristics creamy 0.119
## 8 company_manufacturer Bonnat 0.103
## 9 company_manufacturer Scharffen_Berger 0.0979
## 10 company_manufacturer A._Morin 0.0757
## # ℹ 84 more rows
# Step 3: Visualize the correlations from Step 2 as a correlation funnel.
plot_correlation_funnel(data_corr_tbl)
## Warning: ggrepel: 77 unlabeled data points (too many overlaps). Consider
## increasing max.overlaps
Split Data
# Split into train and test sets.
# data1 stores each bar on several rows (one per ingredient x characteristic)
# that all share the same `ref`. A plain row-wise split would put copies of
# the same bar on both sides, leaking information and making test/CV error
# look better than it is. Splitting by `ref` groups keeps each bar's rows
# together.
set.seed(1235)
data_split1 <- rsample::group_initial_split(data1, group = ref)
data_train1 <- training(data_split1)
data_test1 <- testing(data_split1)

# Further split the training set into 10 grouped folds for cross-validation.
# `v` must be given explicitly: group_vfold_cv() otherwise defaults to
# leave-one-group-out (one fold per distinct `ref`).
set.seed(2345)
data_cv1 <- rsample::group_vfold_cv(data_train1, group = ref, v = 10)
data_cv1
## # 10-fold cross-validation
## # A tibble: 10 × 2
## splits id
## <list> <chr>
## 1 <split [14049/1562]> Fold01
## 2 <split [14050/1561]> Fold02
## 3 <split [14050/1561]> Fold03
## 4 <split [14050/1561]> Fold04
## 5 <split [14050/1561]> Fold05
## 6 <split [14050/1561]> Fold06
## 7 <split [14050/1561]> Fold07
## 8 <split [14050/1561]> Fold08
## 9 <split [14050/1561]> Fold09
## 10 <split [14050/1561]> Fold10
# Print a tidymodels xgboost template (recipe, spec, workflow, tuning call) to
# the console as a starting point. Nothing is assigned here. The template is
# adapted by hand below, so check its mode against the numeric outcome.
library(usemodels)
use_xgboost(formula = rating ~ ., data = data_train1)
## xgboost_recipe <-
## recipe(formula = rating ~ ., data = data_train1) %>%
## step_zv(all_predictors())
##
## xgboost_spec <-
## boost_tree(trees = tune(), min_n = tune(), tree_depth = tune(), learn_rate = tune(),
## loss_reduction = tune(), sample_size = tune()) %>%
## set_mode("classification") %>%
## set_engine("xgboost")
##
## xgboost_workflow <-
## workflow() %>%
## add_recipe(xgboost_recipe) %>%
## add_model(xgboost_spec)
##
## set.seed(35947)
## xgboost_tune <-
## tune_grid(xgboost_workflow, resamples = stop("add your rsample object"), grid = stop("add number of candidate points"))
# Specify Recipe: predict `rating` from every other column.
# - `ref` becomes an ID (kept for traceability, not used as a predictor).
# - Infrequent levels of every factor predictor are pooled into "other"
#   (step_other default threshold), then one-hot encoded, because xgboost
#   accepts numeric inputs only. In data1 the nominal predictors are exactly
#   company_manufacturer, company_location, country_of_bean_origin, ing and
#   most_memorable_characteristics.
# - review_date and cocoa_percent are not log-transformed: for tree boosters a
#   strictly monotonic transform leaves the possible splits unchanged, and the
#   raw year/percentage scale is easier to read.
xgboost_recipe <-
  recipe(formula = rating ~ ., data = data_train1) %>%
  recipes::update_role(ref, new_role = "id variable") %>%
  step_other(all_nominal_predictors()) %>%
  step_dummy(all_nominal_predictors(), one_hot = TRUE)

# Inspect the processed training set; bake(new_data = NULL) replaces the
# superseded juice().
xgboost_recipe %>% prep() %>% bake(new_data = NULL) %>% glimpse()
## Rows: 15,611
## Columns: 25
## $ ref <dbl> 2506, 170, 2040, 1518, 147, …
## $ review_date <dbl> 7.610853, 7.604396, 7.609862…
## $ cocoa_percent <dbl> 4.248495, 4.143135, 4.248495…
## $ n_ing <dbl> 3, 4, 3, 3, 4, 3, 3, 5, 3, 2…
## $ rating <dbl> 4.00, 3.50, 4.00, 2.75, 3.50…
## $ company_manufacturer_Soma <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ company_manufacturer_other <dbl> 1, 1, 1, 1, 1, 1, 1, 1, 1, 1…
## $ company_location_Canada <dbl> 1, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ company_location_France <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ company_location_U.S.A. <dbl> 0, 0, 1, 1, 0, 0, 1, 1, 1, 1…
## $ company_location_other <dbl> 0, 1, 0, 0, 1, 1, 0, 0, 0, 0…
## $ country_of_bean_origin_Blend <dbl> 0, 1, 0, 0, 0, 0, 0, 0, 0, 0…
## $ country_of_bean_origin_Dominican.Republic <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 1…
## $ country_of_bean_origin_Ecuador <dbl> 0, 0, 0, 0, 1, 0, 0, 1, 1, 0…
## $ country_of_bean_origin_Madagascar <dbl> 0, 0, 0, 1, 0, 0, 0, 0, 0, 0…
## $ country_of_bean_origin_Peru <dbl> 1, 0, 1, 0, 0, 0, 0, 0, 0, 0…
## $ country_of_bean_origin_Venezuela <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ country_of_bean_origin_other <dbl> 0, 0, 0, 0, 0, 1, 1, 0, 0, 0…
## $ ing_B <dbl> 1, 0, 0, 0, 0, 1, 1, 1, 0, 0…
## $ ing_C <dbl> 0, 0, 1, 1, 0, 0, 0, 0, 0, 0…
## $ ing_L <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ ing_S <dbl> 0, 0, 0, 0, 1, 0, 0, 0, 1, 1…
## $ ing_other <dbl> 0, 1, 0, 0, 0, 0, 0, 0, 0, 0…
## $ most_memorable_characteristics_sweet <dbl> 0, 0, 0, 0, 0, 1, 0, 0, 0, 0…
## $ most_memorable_characteristics_other <dbl> 1, 1, 1, 1, 1, 0, 1, 1, 1, 1…
# Specify Model: boosted trees via xgboost, in regression mode. mtry, trees,
# min_n and learn_rate are tune() placeholders filled by the grid search;
# all other hyperparameters keep their defaults.
xgboost_spec <- boost_tree(
  trees = tune(),
  min_n = tune(),
  mtry = tune(),
  learn_rate = tune()
) %>%
  set_mode("regression") %>%
  set_engine("xgboost")
# Bundle the preprocessing recipe and the model spec into one workflow object.
xgboost_workflow <- add_model(
  add_recipe(workflow(), xgboost_recipe),
  xgboost_spec
)
# Tune hyperparameters: evaluate 5 automatically generated candidate
# combinations on every cross-validation fold in data_cv1.
set.seed(12782)
xgboost_tune <- tune_grid(
  xgboost_workflow,
  resamples = data_cv1,
  grid = 5
)
## i Creating pre-processing data to finalize unknown parameter: mtry
# Candidates ranked by cross-validated RMSE (lowest first).
xgboost_tune %>% tune::show_best(metric = "rmse")
## # A tibble: 5 × 10
## mtry trees min_n learn_rate .metric .estimator mean n std_err .config
## <int> <int> <int> <dbl> <chr> <chr> <dbl> <int> <dbl> <chr>
## 1 17 1359 38 0.0583 rmse standard 0.282 10 0.00215 Preproces…
## 2 9 694 12 0.0217 rmse standard 0.331 10 0.00175 Preproces…
## 3 23 1723 9 0.00428 rmse standard 0.341 10 0.00137 Preproces…
## 4 3 66 24 0.277 rmse standard 0.366 10 0.00136 Preproces…
## 5 14 1157 26 0.00113 rmse standard 0.829 10 0.00335 Preproces…
# Finalize: substitute the lowest-RMSE hyperparameter combination from the
# grid search into the workflow's tune() placeholders.
xgboost_fw1 <- xgboost_workflow %>%
  tune::finalize_workflow(tune::select_best(xgboost_tune, metric = "rmse"))

# Refit on the full training set, then evaluate once on the held-out test set.
data_fit1 <- xgboost_fw1 %>% tune::last_fit(data_split1)
tune::collect_metrics(data_fit1)
## # A tibble: 2 × 4
## .metric .estimator .estimate .config
## <chr> <chr> <dbl> <chr>
## 1 rmse standard 0.281 Preprocessor1_Model1
## 2 rsq standard 0.584 Preprocessor1_Model1
# Observed vs. predicted ratings on the held-out test set. Points on the
# dashed identity line are perfect predictions; coord_fixed() keeps both axes
# on the same scale so deviations are not visually distorted.
tune::collect_predictions(data_fit1) %>%
  ggplot(aes(rating, .pred)) +
  # Use `color`, not `fill`: the default point shape (19) has no fill, so
  # `fill = "midnightblue"` was silently ignored and points stayed black.
  geom_point(alpha = 0.3, color = "midnightblue") +
  geom_abline(lty = 2, color = "gray50") +
  coord_fixed()