Goal: Build a regression model to predict chocolate ratings based on each chocolate's main characteristics.

The data come from the TidyTuesday chocolate ratings data set (2022-01-18), which is read directly from GitHub below.

Import Data

# Read the TidyTuesday (2022-01-18) chocolate ratings CSV straight from GitHub.
chocolate <- readr::read_csv(
  file = "https://raw.githubusercontent.com/rfordatascience/tidytuesday/main/data/2022/2022-01-18/chocolate.csv"
)
## Rows: 2530 Columns: 10
## ── Column specification ────────────────────────────────────────────────────────
## Delimiter: ","
## chr (7): company_manufacturer, company_location, country_of_bean_origin, spe...
## dbl (3): ref, review_date, rating
## 
## ℹ Use `spec()` to retrieve the full column specification for this data.
## ℹ Specify the column types or set `show_col_types = FALSE` to quiet this message.
# Summarise every column (type, missing values, distribution) before cleaning.
skimr::skim(chocolate)
Data summary
Name chocolate
Number of rows 2530
Number of columns 10
_______________________
Column type frequency:
character 7
numeric 3
________________________
Group variables None

Variable type: character

skim_variable n_missing complete_rate min max empty n_unique whitespace
company_manufacturer 0 1.00 2 39 0 580 0
company_location 0 1.00 4 21 0 67 0
country_of_bean_origin 0 1.00 4 21 0 62 0
specific_bean_origin_or_bar_name 0 1.00 3 51 0 1605 0
cocoa_percent 0 1.00 3 6 0 46 0
ingredients 87 0.97 4 14 0 21 0
most_memorable_characteristics 0 1.00 3 37 0 2487 0

Variable type: numeric

skim_variable n_missing complete_rate mean sd p0 p25 p50 p75 p100 hist
ref 0 1 1429.80 757.65 5 802 1454.00 2079.0 2712 ▆▇▇▇▇
review_date 0 1 2014.37 3.97 2006 2012 2015.00 2018.0 2021 ▃▅▇▆▅
rating 0 1 3.20 0.45 1 3 3.25 3.5 4 ▁▁▅▇▇
# Build the modelling data set. Each review is expanded to one row per
# (ingredient, memorable characteristic) pair, numeric fields are parsed and
# the remaining text columns are stored as factors.
data1 <- chocolate %>%
    # Drop reviews that have a missing value in any column
    na.omit() %>%
    # Split `ingredients` at the hyphen (and optional space) into the
    # ingredient count (n_ing) and the ingredient list (ing)
    separate(col = ingredients, into = c("n_ing", "ing"), sep = "-( |)") %>%
    # One row per individual ingredient
    separate_rows(ing, sep = ",") %>%
    # The ingredient count was extracted as text; make it numeric
    mutate(n_ing = as.numeric(n_ing)) %>%
    # One row per individual memorable characteristic
    separate_rows(most_memorable_characteristics, sep = ",( |)") %>%
    # Drop any rows left with missing values by the steps above
    na.omit() %>%
    # Remove specific_bean_origin_or_bar_name (its info is captured in another
    # column) and review_date (not used as a predictor)
    select(-c(specific_bean_origin_or_bar_name, review_date)) %>%
    # Strip the percent sign and store cocoa percentage as a number
    mutate(cocoa_percent = as.numeric(str_remove(cocoa_percent, "%"))) %>%
    # Every column that is still character becomes a factor
    mutate(across(where(is.character), factor))

data1
## # A tibble: 20,815 × 9
##      ref company_manufacturer company_location country_of_bean_origin
##    <dbl> <fct>                <fct>            <fct>                 
##  1  2454 5150                 U.S.A.           Tanzania              
##  2  2454 5150                 U.S.A.           Tanzania              
##  3  2454 5150                 U.S.A.           Tanzania              
##  4  2454 5150                 U.S.A.           Tanzania              
##  5  2454 5150                 U.S.A.           Tanzania              
##  6  2454 5150                 U.S.A.           Tanzania              
##  7  2454 5150                 U.S.A.           Tanzania              
##  8  2454 5150                 U.S.A.           Tanzania              
##  9  2454 5150                 U.S.A.           Tanzania              
## 10  2458 5150                 U.S.A.           Dominican Republic    
## # ℹ 20,805 more rows
## # ℹ 5 more variables: cocoa_percent <dbl>, n_ing <dbl>, ing <fct>,
## #   most_memorable_characteristics <fct>, rating <dbl>

Explore the data

Identify good predictors.

ref

# Scatter plot of ref (log10 y-axis) against rating
ggplot(data1, aes(x = rating, y = ref)) +
    geom_point() +
    scale_y_log10()

cocoa_percent

# Scatter plot of cocoa_percent (log10 y-axis) against rating
ggplot(data1, aes(x = rating, y = cocoa_percent)) +
    geom_point() +
    scale_y_log10()

EDA Shortcut

# Step 1: Prepare data
# Drop ref, then turn every remaining column into 0/1 indicator columns.
data_binarized_tbl <- binarize(select(data1, -ref))

glimpse(data_binarized_tbl)
## Rows: 20,815
## Columns: 90
## $ company_manufacturer__A._Morin             <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ company_manufacturer__Bonnat               <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ company_manufacturer__Fresco               <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ company_manufacturer__Guittard             <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ company_manufacturer__Pralus               <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ company_manufacturer__Scharffen_Berger     <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ company_manufacturer__Soma                 <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ company_manufacturer__Valrhona             <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ company_manufacturer__Zotter               <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ `company_manufacturer__-OTHER`             <dbl> 1, 1, 1, 1, 1, 1, 1, 1, 1, …
## $ company_location__Australia                <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ company_location__Austria                  <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ company_location__Belgium                  <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ company_location__Brazil                   <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ company_location__Canada                   <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ company_location__Colombia                 <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ company_location__Denmark                  <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ company_location__Ecuador                  <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ company_location__France                   <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ company_location__Germany                  <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ company_location__Italy                    <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ company_location__Japan                    <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ company_location__Spain                    <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ company_location__Switzerland              <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ company_location__U.K.                     <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ company_location__U.S.A.                   <dbl> 1, 1, 1, 1, 1, 1, 1, 1, 1, …
## $ company_location__Venezuela                <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ `company_location__-OTHER`                 <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ country_of_bean_origin__Belize             <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ country_of_bean_origin__Blend              <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ country_of_bean_origin__Bolivia            <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ country_of_bean_origin__Brazil             <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ country_of_bean_origin__Colombia           <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ country_of_bean_origin__Costa_Rica         <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ country_of_bean_origin__Dominican_Republic <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ country_of_bean_origin__Ecuador            <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ country_of_bean_origin__Ghana              <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ country_of_bean_origin__Guatemala          <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ country_of_bean_origin__Haiti              <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ country_of_bean_origin__India              <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ country_of_bean_origin__Jamaica            <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ country_of_bean_origin__Madagascar         <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ country_of_bean_origin__Mexico             <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ country_of_bean_origin__Nicaragua          <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ country_of_bean_origin__Papua_New_Guinea   <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ country_of_bean_origin__Peru               <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ country_of_bean_origin__Tanzania           <dbl> 1, 1, 1, 1, 1, 1, 1, 1, 1, …
## $ country_of_bean_origin__Trinidad           <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ country_of_bean_origin__U.S.A.             <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ country_of_bean_origin__Venezuela          <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ country_of_bean_origin__Vietnam            <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ `country_of_bean_origin__-OTHER`           <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ `cocoa_percent__-Inf_70`                   <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ cocoa_percent__70_74                       <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ cocoa_percent__74_Inf                      <dbl> 1, 1, 1, 1, 1, 1, 1, 1, 1, …
## $ n_ing__2                                   <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ n_ing__3                                   <dbl> 1, 1, 1, 1, 1, 1, 1, 1, 1, …
## $ n_ing__4                                   <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ n_ing__5                                   <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ `n_ing__-OTHER`                            <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ ing__B                                     <dbl> 1, 1, 1, 0, 0, 0, 0, 0, 0, …
## $ ing__C                                     <dbl> 0, 0, 0, 0, 0, 0, 1, 1, 1, …
## $ ing__L                                     <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ ing__S                                     <dbl> 0, 0, 0, 1, 1, 1, 0, 0, 0, …
## $ `ing__S*`                                  <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ ing__V                                     <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ `ing__-OTHER`                              <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ most_memorable_characteristics__cocoa      <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ most_memorable_characteristics__creamy     <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ most_memorable_characteristics__earthy     <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ most_memorable_characteristics__fatty      <dbl> 0, 1, 0, 0, 1, 0, 0, 1, 0, …
## $ most_memorable_characteristics__floral     <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ most_memorable_characteristics__fruit      <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ most_memorable_characteristics__intense    <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ most_memorable_characteristics__molasses   <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ most_memorable_characteristics__nutty      <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ most_memorable_characteristics__rich       <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ most_memorable_characteristics__roasty     <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ most_memorable_characteristics__sandy      <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ most_memorable_characteristics__sour       <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ most_memorable_characteristics__spicy      <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ most_memorable_characteristics__sticky     <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ most_memorable_characteristics__sweet      <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ most_memorable_characteristics__vanilla    <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ most_memorable_characteristics__woody      <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ `most_memorable_characteristics__-OTHER`   <dbl> 1, 0, 1, 1, 0, 1, 1, 0, 1, …
## $ `rating__-Inf_3`                           <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ rating__3_3.25                             <dbl> 1, 1, 1, 1, 1, 1, 1, 1, 1, …
## $ rating__3.25_3.5                           <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ rating__3.5_Inf                            <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, …
# Step 2: Correlate
# Correlate every indicator column with the highest rating bin.
data_corr_tbl <- correlate(data_binarized_tbl, target = rating__3.5_Inf)

data_corr_tbl
## # A tibble: 90 × 3
##    feature                        bin              correlation
##    <fct>                          <chr>                  <dbl>
##  1 rating                         3.5_Inf               1     
##  2 rating                         -Inf_3               -0.398 
##  3 rating                         3.25_3.5             -0.230 
##  4 rating                         3_3.25               -0.202 
##  5 company_manufacturer           -OTHER               -0.154 
##  6 company_manufacturer           Soma                  0.122 
##  7 most_memorable_characteristics creamy                0.119 
##  8 company_manufacturer           Bonnat                0.103 
##  9 company_manufacturer           Scharffen_Berger      0.0979
## 10 company_manufacturer           A._Morin              0.0757
## # ℹ 80 more rows
# Step 3: Plot
# Draw the correlation funnel from the correlations computed in Step 2.
plot_correlation_funnel(data_corr_tbl)
## Warning: ggrepel: 77 unlabeled data points (too many overlaps). Consider
## increasing max.overlaps

Build Models

Split Data

# Work on a random subset of 300 rows (presumably to keep tuning fast).
# Seed BEFORE sampling: without it the subset, and therefore every metric
# reported below, changes each time the document is knitted.
set.seed(1234)
data1 <- sample_n(data1, 300)

# NOTE(review): data1 holds several rows per chocolate bar (same ref), so rows
# from one bar can end up in both the training and the test set. Consider a
# grouped split if that leakage matters.

# Split into train and test dataset
set.seed(1235)
data_split1 <- rsample::initial_split(data1)
data_train1 <- training(data_split1)
data_test1 <- testing(data_split1)

# Further Split training dataset for cross validation
set.seed(2345)
data_cv1 <- rsample::vfold_cv(data_train1)
data_cv1
## #  10-fold cross-validation 
## # A tibble: 10 × 2
##    splits           id    
##    <list>           <chr> 
##  1 <split [202/23]> Fold01
##  2 <split [202/23]> Fold02
##  3 <split [202/23]> Fold03
##  4 <split [202/23]> Fold04
##  5 <split [202/23]> Fold05
##  6 <split [203/22]> Fold06
##  7 <split [203/22]> Fold07
##  8 <split [203/22]> Fold08
##  9 <split [203/22]> Fold09
## 10 <split [203/22]> Fold10
# Print a template xgboost recipe / model spec / workflow for this formula,
# used as the starting point for the code below.
library(usemodels)
usemodels::use_xgboost(formula = rating ~ ., data = data_train1)
## xgboost_recipe <- 
##   recipe(formula = rating ~ ., data = data_train1) %>% 
##   step_zv(all_predictors()) 
## 
## xgboost_spec <- 
##   boost_tree(trees = tune(), min_n = tune(), tree_depth = tune(), learn_rate = tune(), 
##     loss_reduction = tune(), sample_size = tune()) %>% 
##   set_mode("classification") %>% 
##   set_engine("xgboost") 
## 
## xgboost_workflow <- 
##   workflow() %>% 
##   add_recipe(xgboost_recipe) %>% 
##   add_model(xgboost_spec) 
## 
## set.seed(87619)
## xgboost_tune <-
##   tune_grid(xgboost_workflow, resamples = stop("add your rsample object"), grid = stop("add number of candidate points"))
# Specify Recipe
xgboost_recipe <-
    recipe(formula = rating ~ ., data = data_train1) %>%
    # Keep ref in the data but give it a non-predictor role
    recipes::update_role(ref, new_role = "id variable") %>%
    # Pool infrequent factor levels into a single "other" level
    step_other(
        company_manufacturer,
        company_location,
        country_of_bean_origin,
        ing,
        most_memorable_characteristics
    ) %>%
    # One indicator column per remaining level (one-hot encoding)
    step_dummy(
        company_manufacturer,
        company_location,
        country_of_bean_origin,
        ing,
        most_memorable_characteristics,
        one_hot = TRUE
    ) %>%
    # Log-transform the cocoa percentage
    step_log(cocoa_percent)

# Preview the preprocessed training data. This is its own statement, not part
# of the recipe pipeline. bake(new_data = NULL) replaces the superseded juice().
xgboost_recipe %>%
    prep() %>%
    bake(new_data = NULL) %>%
    glimpse()
## Rows: 225
## Columns: 24
## $ ref                                       <dbl> 1379, 967, 184, 2088, 2178, …
## $ cocoa_percent                             <dbl> 4.276666, 4.248495, 4.317488…
## $ n_ing                                     <dbl> 3, 3, 4, 4, 5, 3, 2, 3, 3, 3…
## $ rating                                    <dbl> 2.50, 2.75, 2.75, 3.00, 2.50…
## $ company_manufacturer_Fresco               <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 1…
## $ company_manufacturer_other                <dbl> 1, 1, 1, 1, 1, 1, 1, 1, 1, 0…
## $ company_location_Canada                   <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ company_location_France                   <dbl> 1, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ company_location_U.S.A.                   <dbl> 0, 1, 0, 0, 1, 1, 1, 0, 1, 1…
## $ company_location_other                    <dbl> 0, 0, 1, 1, 0, 0, 0, 1, 0, 0…
## $ country_of_bean_origin_Brazil             <dbl> 0, 1, 0, 0, 0, 0, 0, 0, 0, 0…
## $ country_of_bean_origin_Dominican.Republic <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ country_of_bean_origin_Ecuador            <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ country_of_bean_origin_Madagascar         <dbl> 0, 0, 1, 1, 0, 0, 0, 0, 0, 0…
## $ country_of_bean_origin_Peru               <dbl> 0, 0, 0, 0, 0, 0, 0, 1, 0, 0…
## $ country_of_bean_origin_Venezuela          <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 1…
## $ country_of_bean_origin_other              <dbl> 1, 0, 0, 0, 1, 1, 1, 0, 1, 0…
## $ ing_B                                     <dbl> 1, 0, 1, 1, 0, 0, 1, 0, 0, 1…
## $ ing_C                                     <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 1, 0…
## $ ing_L                                     <dbl> 0, 0, 0, 0, 1, 0, 0, 0, 0, 0…
## $ ing_S                                     <dbl> 0, 1, 0, 0, 0, 1, 0, 1, 0, 0…
## $ ing_other                                 <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ most_memorable_characteristics_sweet      <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ most_memorable_characteristics_other      <dbl> 1, 1, 1, 1, 1, 1, 1, 1, 1, 1…
# Specify Model: boosted-tree regression on the xgboost engine. The four
# hyperparameters marked tune() are chosen by the grid search below.
xgboost_spec <-
  boost_tree(
    trees = tune(),
    min_n = tune(),
    mtry = tune(),
    learn_rate = tune()
  ) %>%
  set_mode("regression") %>%
  set_engine("xgboost")

# Combine recipe and model using workflow
xgboost_workflow <- workflow() %>%
  add_recipe(xgboost_recipe) %>%
  add_model(xgboost_spec)

# Tune hyperparameters: 5 candidate combinations, each evaluated on the
# cross-validation folds
set.seed(12782)
xgboost_tune <- xgboost_workflow %>%
  tune_grid(resamples = data_cv1, grid = 5)
## i Creating pre-processing data to finalize unknown parameter: mtry

Evaluate Models

# Rank the tuned candidates by cross-validated RMSE (best first)
xgboost_tune %>%
  tune::show_best(metric = "rmse")
## # A tibble: 5 × 10
##    mtry trees min_n learn_rate .metric .estimator  mean     n std_err .config   
##   <int> <int> <int>      <dbl> <chr>   <chr>      <dbl> <int>   <dbl> <chr>     
## 1     8   694    12    0.0217  rmse    standard   0.425    10  0.0175 Preproces…
## 2    22  1723     9    0.00428 rmse    standard   0.427    10  0.0196 Preproces…
## 3    16  1359    38    0.0583  rmse    standard   0.430    10  0.0149 Preproces…
## 4     3    66    24    0.277   rmse    standard   0.430    10  0.0209 Preproces…
## 5    13  1157    26    0.00113 rmse    standard   0.842    10  0.0263 Preproces…
# Plug the hyperparameters with the lowest cross-validated RMSE into the
# workflow.
xgboost_fw1 <- tune::finalize_workflow(
  x = xgboost_workflow,
  parameters = tune::select_best(xgboost_tune, metric = "rmse")
)

# Fit the finalized workflow on the full training set and evaluate it once on
# the held-out test set.
data_fit1 <- xgboost_fw1 %>%
  tune::last_fit(data_split1)

data_fit1 %>%
  tune::collect_metrics()
## # A tibble: 2 × 4
##   .metric .estimator .estimate .config             
##   <chr>   <chr>          <dbl> <chr>               
## 1 rmse    standard      0.417  Preprocessor1_Model1
## 2 rsq     standard      0.0253 Preprocessor1_Model1
# Observed vs predicted ratings on the test set. The dashed reference line
# marks perfect prediction (.pred equal to rating).
tune::collect_predictions(data_fit1) %>%
    ggplot(aes(rating, .pred)) +
    # Use `color`, not `fill`: the default point shape has no fill, so the
    # original `fill = "midnightblue"` was silently ignored
    geom_point(alpha = 0.3, color = "midnightblue") +
    geom_abline(lty = 2, color = "gray50") +
    coord_fixed()

Conclusion

I tried to improve my model by removing another variable. I first removed review_date from the data used by the machine learning algorithm, since I didn't think the review date would have much of an impact. I then replaced the plot of the relationship between review_date and rating with a plot of cocoa_percent and rating. I also removed review_date from the step_log function when specifying the recipe. This increased the RMSE from .281 to .417 and decreased the R-squared from .584 to .0253 (see the test-set metrics above). This RMSE would still fall within a good range, but my model was stronger with review_date included, so I would probably have left my model as it was.