library(tidyverse)
## ── Attaching core tidyverse packages ──────────────────────── tidyverse 2.0.0 ──
## ✔ dplyr     1.1.4     ✔ readr     2.1.5
## ✔ forcats   1.0.0     ✔ stringr   1.5.2
## ✔ ggplot2   4.0.0     ✔ tibble    3.3.0
## ✔ lubridate 1.9.4     ✔ tidyr     1.3.1
## ✔ purrr     1.1.0     
## ── Conflicts ────────────────────────────────────────── tidyverse_conflicts() ──
## ✖ dplyr::filter() masks stats::filter()
## ✖ dplyr::lag()    masks stats::lag()
## ℹ Use the conflicted package (<http://conflicted.r-lib.org/>) to force all conflicts to become errors
library(tidytext)
## Warning: package 'tidytext' was built under R version 4.5.2
library(tidymodels)
## Warning: package 'tidymodels' was built under R version 4.5.2
## ── Attaching packages ────────────────────────────────────── tidymodels 1.4.1 ──
## ✔ broom        1.0.10     ✔ rsample      1.3.1 
## ✔ dials        1.4.2      ✔ tailor       0.1.0 
## ✔ infer        1.1.0      ✔ tune         2.0.1 
## ✔ modeldata    1.5.1      ✔ workflows    1.3.0 
## ✔ parsnip      1.4.1      ✔ workflowsets 1.1.1 
## ✔ recipes      1.3.1      ✔ yardstick    1.3.2
## Warning: package 'dials' was built under R version 4.5.2
## Warning: package 'infer' was built under R version 4.5.2
## Warning: package 'modeldata' was built under R version 4.5.2
## Warning: package 'parsnip' was built under R version 4.5.2
## Warning: package 'tailor' was built under R version 4.5.2
## Warning: package 'tune' was built under R version 4.5.2
## Warning: package 'workflows' was built under R version 4.5.2
## Warning: package 'workflowsets' was built under R version 4.5.2
## Warning: package 'yardstick' was built under R version 4.5.2
## ── Conflicts ───────────────────────────────────────── tidymodels_conflicts() ──
## ✖ scales::discard() masks purrr::discard()
## ✖ dplyr::filter()   masks stats::filter()
## ✖ recipes::fixed()  masks stringr::fixed()
## ✖ dplyr::lag()      masks stats::lag()
## ✖ yardstick::spec() masks readr::spec()
## ✖ recipes::step()   masks stats::step()
library(textrecipes)
## Warning: package 'textrecipes' was built under R version 4.5.2
# Read the chocolate ratings CSV straight from the TidyTuesday GitHub repository
url <- "https://raw.githubusercontent.com/rfordatascience/tidytuesday/master/data/2022/2022-01-18/chocolate.csv"
chocolate <- readr::read_csv(file = url)
## Rows: 2530 Columns: 10
## ── Column specification ────────────────────────────────────────────────────────
## Delimiter: ","
## chr (7): company_manufacturer, company_location, country_of_bean_origin, spe...
## dbl (3): ref, review_date, rating
## 
## ℹ Use `spec()` to retrieve the full column specification for this data.
## ℹ Specify the column types or set `show_col_types = FALSE` to quiet this message.

Exploratory data analysis (EDA)

# Distribution of `rating`, drawn as a histogram with 15 bins
ggplot(chocolate, aes(x = rating)) +
  geom_histogram(bins = 15)

# Split `most_memorable_characteristics` into one token per row, stored in a
# new `word` column
tidy_chocolate <- unnest_tokens(
  chocolate,
  output = word,
  input = most_memorable_characteristics
)

# Tally how often each word occurs, most frequent first
count(tidy_chocolate, word, sort = TRUE)
## # A tibble: 547 × 2
##    word        n
##    <chr>   <int>
##  1 cocoa     419
##  2 sweet     318
##  3 nutty     278
##  4 fruit     273
##  5 roasty    228
##  6 mild      226
##  7 sour      208
##  8 earthy    199
##  9 creamy    189
## 10 intense   178
## # ℹ 537 more rows

Visualize

# Mean rating per word plotted against how often the word occurs (x axis on a
# log10 scale). The dashed horizontal line marks the mean rating over all rows
# of `chocolate`.
tidy_chocolate %>%
  group_by(word) %>%
  summarise(
    n = n(),
    rating = mean(rating)
  ) %>%
  ggplot(aes(n, rating)) +
  # `linewidth` replaces `size` for lines (`size` deprecated in ggplot2 3.4.0)
  geom_hline(
    yintercept = mean(chocolate$rating),
    lty = 2, color = "gray50", linewidth = 1.5
  ) +
  geom_jitter(color = "midnightblue", alpha = 0.7) +
  # check_overlap drops labels that would be drawn over an earlier label
  geom_text(
    aes(label = word),
    check_overlap = TRUE, vjust = "top", hjust = "left"
  ) +
  scale_x_log10()

build models

# Train/test split stratified on `rating`; seeded so the split is reproducible
set.seed(123)
choco_split <- initial_split(chocolate, strata = rating)
choco_train <- training(choco_split)
choco_test <- testing(choco_split)

# Cross-validation folds of the training data, also stratified on `rating`
set.seed(234)
choco_folds <- vfold_cv(choco_train, strata = rating)

# Preprocessing recipe, built one step at a time: tokenize the text column,
# cap the vocabulary with max_tokens = 100, then convert tokens to tf-idf
choco_rec <- recipe(
  rating ~ most_memorable_characteristics,
  data = choco_train
)
choco_rec <- step_tokenize(choco_rec, most_memorable_characteristics)
choco_rec <- step_tokenfilter(
  choco_rec,
  most_memorable_characteristics,
  max_tokens = 100
)
choco_rec <- step_tfidf(choco_rec, most_memorable_characteristics)


# Estimate the recipe and show the processed training data (new_data = NULL)
bake(prep(choco_rec), new_data = NULL)
## # A tibble: 1,896 × 101
##    rating tfidf_most_memorable_c…¹ tfidf_most_memorable…² tfidf_most_memorable…³
##     <dbl>                    <dbl>                  <dbl>                  <dbl>
##  1   3                        0                         0                      0
##  2   2.75                     0                         0                      0
##  3   3                        0                         0                      0
##  4   3                        0                         0                      0
##  5   2.75                     0                         0                      0
##  6   3                        1.38                      0                      0
##  7   2.75                     0                         0                      0
##  8   2.5                      0                         0                      0
##  9   2.75                     0                         0                      0
## 10   3                        0                         0                      0
## # ℹ 1,886 more rows
## # ℹ abbreviated names: ¹​tfidf_most_memorable_characteristics_acidic,
## #   ²​tfidf_most_memorable_characteristics_and,
## #   ³​tfidf_most_memorable_characteristics_astringent
## # ℹ 97 more variables: tfidf_most_memorable_characteristics_banana <dbl>,
## #   tfidf_most_memorable_characteristics_base <dbl>,
## #   tfidf_most_memorable_characteristics_basic <dbl>, …
# Random forest with 500 trees, ranger engine, regression mode
rf_spec <- set_mode(
  set_engine(rand_forest(trees = 500), engine = "ranger"),
  mode = "regression"
)

# Linear support vector machine, LiblineaR engine, regression mode
svm_spec <- set_mode(
  set_engine(svm_linear(), engine = "LiblineaR"),
  mode = "regression"
)

workflow

# Pair the shared text recipe with each model specification
svm_wf <- workflow(preprocessor = choco_rec, spec = svm_spec)
rf_wf <- workflow(preprocessor = choco_rec, spec = rf_spec)

evaluate

# Register a foreach parallel backend using doParallel's defaults.
# NOTE(review): recent tune releases document future/mirai as the supported
# parallel backends -- confirm this registration still has an effect with the
# tune version attached above.
doParallel::registerDoParallel()

# Keep the out-of-sample predictions from every resample
contrl_preds <- control_resamples(save_pred = TRUE)

# Fit and assess each workflow on the same cross-validation folds
svm_rs <- svm_wf %>%
  fit_resamples(resamples = choco_folds, control = contrl_preds)

ranger_rs <- rf_wf %>%
  fit_resamples(resamples = choco_folds, control = contrl_preds)

metrics

# Resampled performance metrics for the SVM workflow
svm_rs %>% collect_metrics()
## # A tibble: 2 × 6
##   .metric .estimator  mean     n std_err .config        
##   <chr>   <chr>      <dbl> <int>   <dbl> <chr>          
## 1 rmse    standard   0.347    10 0.00656 pre0_mod0_post0
## 2 rsq     standard   0.367    10 0.0181  pre0_mod0_post0
# Resampled performance metrics for the random forest workflow
ranger_rs %>% collect_metrics()
## # A tibble: 2 × 6
##   .metric .estimator  mean     n std_err .config        
##   <chr>   <chr>      <dbl> <int>   <dbl> <chr>          
## 1 rmse    standard   0.350    10 0.00669 pre0_mod0_post0
## 2 rsq     standard   0.357    10 0.0163  pre0_mod0_post0

Visualize predictions

# Predicted (`.pred`) against observed `rating` for the saved resample
# predictions of both models: one facet per model, colored by resample `id`.
# The dashed line is geom_abline()'s default reference line.
bind_rows(
  collect_predictions(svm_rs) %>% mutate(mod = "SVM"),
  collect_predictions(ranger_rs) %>% mutate(mod = "ranger")
) %>%
  ggplot(aes(rating, .pred, color = id)) +
  # `linewidth` replaces `size` for lines (`size` deprecated in ggplot2 3.4.0)
  geom_abline(lty = 2, color = "gray50", linewidth = 1.2) +
  geom_jitter(width = 0.5, alpha = 0.5) +
  facet_wrap(vars(mod)) +
  coord_fixed()

# Run last_fit() for the SVM workflow on the initial split, then report the
# metrics it collected
final_fitted <- svm_wf %>% last_fit(split = choco_split)
final_fitted %>% collect_metrics()
## # A tibble: 2 × 4
##   .metric .estimator .estimate .config        
##   <chr>   <chr>          <dbl> <chr>          
## 1 rmse    standard       0.381 pre0_mod0_post0
## 2 rsq     standard       0.348 pre0_mod0_post0
# Pull the fitted workflow out of the last_fit() result
final_wf <- extract_workflow(final_fitted)

# Ten largest-magnitude coefficients on each side of zero, excluding the
# "Bias" term, with the tf-idf prefix stripped so only the word remains.
final_wf %>%
  tidy() %>%
  filter(term != "Bias") %>%
  group_by(estimate > 0) %>%
  slice_max(abs(estimate), n = 10) %>%
  ungroup() %>%
  mutate(term = str_remove(term, "tfidf_most_memorable_characteristics_")) %>%
  ggplot(aes(estimate, fct_reorder(term, estimate), fill = estimate > 0)) +
  geom_col(alpha = 0.8) +
  # The fill is the logical `estimate > 0`; without labels the legend keys
  # read FALSE/TRUE under the "More from..." title. FALSE sorts first.
  scale_fill_discrete(labels = c("low ratings", "high ratings")) +
  labs(y = NULL, fill = "More from...")