library(tidymodels)
## Warning: package 'tidymodels' was built under R version 4.4.2
## ── Attaching packages ────────────────────────────────────── tidymodels 1.2.0 ──
## ✔ broom 1.0.6 ✔ rsample 1.2.1
## ✔ dials 1.3.0 ✔ tune 1.2.1
## ✔ infer 1.0.7 ✔ workflows 1.1.4
## ✔ modeldata 1.4.0 ✔ workflowsets 1.1.0
## ✔ parsnip 1.2.1 ✔ yardstick 1.3.2
## ✔ recipes 1.1.0
## Warning: package 'dials' was built under R version 4.4.2
## Warning: package 'infer' was built under R version 4.4.2
## Warning: package 'modeldata' was built under R version 4.4.2
## Warning: package 'parsnip' was built under R version 4.4.2
## Warning: package 'tune' was built under R version 4.4.2
## Warning: package 'workflows' was built under R version 4.4.2
## Warning: package 'workflowsets' was built under R version 4.4.2
## Warning: package 'yardstick' was built under R version 4.4.2
## ── Conflicts ───────────────────────────────────────── tidymodels_conflicts() ──
## ✖ scales::discard() masks purrr::discard()
## ✖ dplyr::filter() masks stats::filter()
## ✖ recipes::fixed() masks stringr::fixed()
## ✖ dplyr::lag() masks stats::lag()
## ✖ yardstick::spec() masks readr::spec()
## ✖ recipes::step() masks stats::step()
## • Search for functions across packages at https://www.tidymodels.org/find/
# Tidy Tuesday (2022-01-18) chocolate ratings, read directly from GitHub.
url <- "https://raw.githubusercontent.com/rfordatascience/tidytuesday/master/data/2022/2022-01-18/chocolate.csv"
chocolate <- readr::read_csv(file = url)
## Rows: 2530 Columns: 10
## ── Column specification ────────────────────────────────────────────────────────
## Delimiter: ","
## chr (7): company_manufacturer, company_location, country_of_bean_origin, spe...
## dbl (3): ref, review_date, rating
##
## ℹ Use `spec()` to retrieve the full column specification for this data.
## ℹ Specify the column types or set `show_col_types = FALSE` to quiet this message.
# Histogram of the `rating` column across all chocolates, using 15 bins.
ggplot(data = chocolate, mapping = aes(x = rating)) +
  geom_histogram(bins = 15)
library(tidytext)
## Warning: package 'tidytext' was built under R version 4.4.2
# Tokenise the free-text tasting notes into a `word` column.
tidy_chocolate <- unnest_tokens(
  tbl = chocolate,
  output = word,
  input = most_memorable_characteristics
)

# Word frequencies, most common first.
count(tidy_chocolate, word, sort = TRUE)
## # A tibble: 547 × 2
## word n
## <chr> <int>
## 1 cocoa 419
## 2 sweet 318
## 3 nutty 278
## 4 fruit 273
## 5 roasty 228
## 6 mild 226
## 7 sour 208
## 8 earthy 199
## 9 creamy 189
## 10 intense 178
## # ℹ 537 more rows
# For every word, plot how often it occurs (log-scaled x axis) against the
# mean rating of the chocolates it describes. The dashed horizontal line marks
# the mean rating over all chocolates.
tidy_chocolate %>%
  group_by(word) %>%
  summarise(
    n = n(),
    rating = mean(rating)
  ) %>%
  ggplot(aes(n, rating)) +
  # `linewidth` replaces `size` for line geoms (`size` is deprecated for lines
  # since ggplot2 3.4.0 and triggers a lifecycle warning).
  geom_hline(
    yintercept = mean(chocolate$rating),
    lty = 2, color = "gray50", linewidth = 1.5
  ) +
  geom_point(color = "midnightblue", alpha = 0.7) +
  # Font family was misspelled as "IBMPlexSons", which can never resolve.
  # NOTE(review): "IBMPlexSans" must still be installed/registered on the
  # machine, otherwise the device falls back to its default font.
  geom_text(
    aes(label = word),
    family = "IBMPlexSans",
    check_overlap = TRUE,
    vjust = "top",
    hjust = "left"
  ) +
  scale_x_log10()
## Warning: Using `size` aesthetic for lines was deprecated in ggplot2 3.4.0.
## ℹ Please use `linewidth` instead.
## This warning is displayed once every 8 hours.
## Call `lifecycle::last_lifecycle_warnings()` to see where this warning was
## generated.
## Warning in grid.Call.graphics(C_text, as.graphicsAnnot(x$label), x$x, x$y, :
## font family not found in Windows font database
library(tidymodels)
# Train/test split stratified on the outcome; the seed makes it repeatable.
set.seed(123)
choco_split <- chocolate %>% initial_split(strata = rating)
choco_train <- choco_split %>% training()
choco_test <- choco_split %>% testing()

# Cross-validation folds built from the training set only, again stratified
# on `rating` and seeded separately.
set.seed(234)
choco_folds <- choco_train %>% vfold_cv(strata = rating)
choco_folds
## # 10-fold cross-validation using stratification
## # A tibble: 10 × 2
## splits id
## <list> <chr>
## 1 <split [1705/191]> Fold01
## 2 <split [1705/191]> Fold02
## 3 <split [1705/191]> Fold03
## 4 <split [1706/190]> Fold04
## 5 <split [1706/190]> Fold05
## 6 <split [1706/190]> Fold06
## 7 <split [1707/189]> Fold07
## 8 <split [1707/189]> Fold08
## 9 <split [1708/188]> Fold09
## 10 <split [1709/187]> Fold10
# Preprocessing ----
library(textrecipes)
## Warning: package 'textrecipes' was built under R version 4.4.2
# Feature engineering: predict `rating` from the tasting-notes text only.
choco_rec <- recipe(rating ~ most_memorable_characteristics, data = choco_train)
choco_rec <- choco_rec %>%
  # Split the text into tokens, ...
  step_tokenize(most_memorable_characteristics) %>%
  # ... keep at most 100 tokens, ...
  step_tokenfilter(most_memorable_characteristics, max_tokens = 100) %>%
  # ... and turn them into tf-idf columns.
  step_tfidf(most_memorable_characteristics)

## just to check this works
choco_rec %>%
  prep() %>%
  bake(new_data = NULL)
## # A tibble: 1,896 × 101
## rating tfidf_most_memorable_c…¹ tfidf_most_memorable…² tfidf_most_memorable…³
## <dbl> <dbl> <dbl> <dbl>
## 1 3 0 0 0
## 2 2.75 0 0 0
## 3 3 0 0 0
## 4 3 0 0 0
## 5 2.75 0 0 0
## 6 3 1.38 0 0
## 7 2.75 0 0 0
## 8 2.5 0 0 0
## 9 2.75 0 0 0
## 10 3 0 0 0
## # ℹ 1,886 more rows
## # ℹ abbreviated names: ¹tfidf_most_memorable_characteristics_acidic,
## # ²tfidf_most_memorable_characteristics_and,
## # ³tfidf_most_memorable_characteristics_astringent
## # ℹ 97 more variables: tfidf_most_memorable_characteristics_banana <dbl>,
## # tfidf_most_memorable_characteristics_base <dbl>,
## # tfidf_most_memorable_characteristics_basic <dbl>, …
# Random forest regression specification with 500 trees.
rf_spec <- set_mode(rand_forest(trees = 500), "regression")
rf_spec
## Random Forest Model Specification (regression)
##
## Main Arguments:
## trees = 500
##
## Computational engine: ranger
# Linear support vector machine specification, used for regression.
svm_spec <- set_mode(svm_linear(), "regression")
svm_spec
## Linear Support Vector Machine Model Specification (regression)
##
## Computational engine: LiblineaR
# Pair the shared text recipe with each model specification.
svm_wf <- workflow(preprocessor = choco_rec, spec = svm_spec)
rf_wf <- workflow(preprocessor = choco_rec, spec = rf_spec)
library(LiblineaR)
## Warning: package 'LiblineaR' was built under R version 4.4.2
library(ranger)
## Warning: package 'ranger' was built under R version 4.4.2
library(doParallel)
## Warning: package 'doParallel' was built under R version 4.4.2
## Loading required package: foreach
## Warning: package 'foreach' was built under R version 4.4.2
##
## Attaching package: 'foreach'
## The following objects are masked from 'package:purrr':
##
## accumulate, when
## Loading required package: iterators
## Warning: package 'iterators' was built under R version 4.4.2
## Loading required package: parallel
# Register a parallel backend before the resampling fits below.
doParallel::registerDoParallel()

# Keep the out-of-fold predictions so they can be collected and plotted later.
contrl_preds <- control_resamples(save_pred = TRUE)

# Fit each workflow on the same cross-validation folds.
svm_rs <- svm_wf %>%
  fit_resamples(resamples = choco_folds, control = contrl_preds)

ranger_rs <- rf_wf %>%
  fit_resamples(resamples = choco_folds, control = contrl_preds)
svm_rs %>% collect_metrics() # resampled metrics for the SVM workflow
## # A tibble: 2 × 6
## .metric .estimator mean n std_err .config
## <chr> <chr> <dbl> <int> <dbl> <chr>
## 1 rmse standard 0.347 10 0.00656 Preprocessor1_Model1
## 2 rsq standard 0.367 10 0.0181 Preprocessor1_Model1
ranger_rs %>% collect_metrics() # resampled metrics for the random forest workflow
## # A tibble: 2 × 6
## .metric .estimator mean n std_err .config
## <chr> <chr> <dbl> <int> <dbl> <chr>
## 1 rmse standard 0.350 10 0.00666 Preprocessor1_Model1
## 2 rsq standard 0.358 10 0.0161 Preprocessor1_Model1
# Compare out-of-fold predictions with the observed ratings, one panel per
# model, coloured by resampling fold. The dashed line is the geom_abline()
# reference line (default slope 1, intercept 0).
bind_rows(
  collect_predictions(svm_rs) %>%
    mutate(mod = "SVM"),
  collect_predictions(ranger_rs) %>%
    mutate(mod = "ranger")
) %>%
  ggplot(aes(rating, .pred, color = id)) +
  # `linewidth` replaces `size` for line geoms (`size` is deprecated for lines
  # since ggplot2 3.4.0 and triggers a lifecycle warning).
  geom_abline(lty = 2, color = "gray50", linewidth = 1.2) +
  # Horizontal jitter is applied to the points.
  geom_jitter(width = 0.5, alpha = 0.5) +
  facet_wrap(vars(mod)) +
  coord_fixed()
# Final fit of the SVM workflow using the original train/test split object.
final_fitted <- svm_wf %>% last_fit(split = choco_split)
final_fitted %>% collect_metrics() ## metrics evaluated on the *testing* data
## # A tibble: 2 × 4
## .metric .estimator .estimate .config
## <chr> <chr> <dbl> <chr>
## 1 rmse standard 0.381 Preprocessor1_Model1
## 2 rsq standard 0.348 Preprocessor1_Model1
# Pull out the fitted workflow and predict for a single test-set row (row 55).
final_wf <- final_fitted %>% extract_workflow()
final_wf %>% predict(new_data = choco_test[55, ])
## # A tibble: 1 × 1
## .pred
## <dbl>
## 1 3.51
# Plot the largest model coefficients in each direction: the 10 largest
# positive and 10 largest negative estimates (by absolute value), with the
# "Bias" term removed.
final_fitted %>%
  extract_workflow() %>%
  tidy() %>%
  filter(term != "Bias") %>%
  # Group by sign so slice_max() picks the top 10 within each direction.
  group_by(estimate > 0) %>%
  slice_max(abs(estimate), n = 10) %>%
  ungroup() %>%
  # Strip the recipe-generated prefix so only the word itself is shown.
  mutate(
    term = str_remove(term, pattern = "tfidf_most_memorable_characteristics_")
  ) %>%
  ggplot(
    aes(x = estimate, y = fct_reorder(term, estimate), fill = estimate > 0)
  ) +
  geom_col(alpha = 0.8) +
  scale_fill_discrete(labels = c("low ratings", "high ratings")) +
  labs(y = NULL, fill = "More from...")