This template offers an opinionated guide on how to structure a modeling analysis. Your individual modeling analysis may require you to add to, subtract from, or otherwise change this structure, but consider this a general framework to start from. If you want to learn more about using tidymodels, check out our Getting Started guide.
In this analysis, we will model each chocolate bar’s rating using the words in its most memorable characteristics (the most_memorable_characteristics column).
library(tidyverse)
url <- "https://raw.githubusercontent.com/rfordatascience/tidytuesday/master/data/2022/2022-01-18/chocolate.csv"
chocolate <- read_csv(url)
## Rows: 2530 Columns: 10
## ── Column specification ────────────────────────────────────────────────────────
## Delimiter: ","
## chr (7): company_manufacturer, company_location, country_of_bean_origin, spe...
## dbl (3): ref, review_date, rating
##
## ℹ Use `spec()` to retrieve the full column specification for this data.
## ℹ Specify the column types or set `show_col_types = FALSE` to quiet this message.
Exploratory data analysis (EDA) is an important part of the modeling process. Let’s start with the distribution of the outcome, rating:
chocolate %>%
  ggplot(aes(rating)) +
  geom_histogram(bins = 15)
library(tidytext)
## Warning: package 'tidytext' was built under R version 4.4.1
tidy_chocolate <- chocolate %>%
  unnest_tokens(word, most_memorable_characteristics)
tidy_chocolate %>%
  count(word, sort = TRUE)
## # A tibble: 547 × 2
## word n
## <chr> <int>
## 1 cocoa 419
## 2 sweet 318
## 3 nutty 278
## 4 fruit 273
## 5 roasty 228
## 6 mild 226
## 7 sour 208
## 8 earthy 199
## 9 creamy 189
## 10 intense 178
## # ℹ 537 more rows
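By default, unnest_tokens() lowercases the text, strips punctuation such as the commas between descriptors, and returns one row per word. The base-R sketch below approximates that step on three invented strings, not rows from the dataset, to show how counts like the ones above are produced. Its regular-expression tokenizer (the hypothetical helper tokenize_words_simple()) is a simplification and is not tidytext’s actual tokenizer.

```r
# Invented example strings in the style of most_memorable_characteristics
notes <- c("cocoa, sweet, nutty", "Fruit, sour, cocoa", "creamy,  mild")

# Simplified word tokenizer: lowercase, then split on runs of non-letters
tokenize_words_simple <- function(x) {
  words <- strsplit(tolower(x), "[^a-z]+")[[1]]
  words[words != ""]  # drop empty strings left by leading separators
}

# One element per word, similar to one row per word from unnest_tokens()
tokens <- unlist(lapply(notes, tokenize_words_simple))

# Frequency table in decreasing order, like count(word, sort = TRUE)
word_counts <- sort(table(tokens), decreasing = TRUE)
word_counts
```

"Fruit" and "fruit" are counted as the same word because of lowercasing. The double space in the last string produces no empty token.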
tidy_chocolate %>%
  group_by(word) %>%
  summarise(n = n(),
            rating = mean(rating)) %>%
  ggplot(aes(n, rating)) +
  geom_hline(yintercept = mean(chocolate$rating),
             lty = 2, color = "gray50", linewidth = 1.5) +
  geom_point(color = "midnightblue", alpha = 0.7) +
  geom_text(aes(label = word),
            check_overlap = TRUE, vjust = "top", hjust = "left") +
  scale_x_log10()
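Each point in this plot is one word. Its x position is how many times the word appears, and its y position is the mean rating of the bars that mention it. Words above the dashed line, which marks the overall mean rating, are associated with better-than-average bars. The sketch below shows the same group-and-summarize step in base R on a tiny invented token table. aggregate() and table() stand in for group_by() and summarise(); this is an illustration only.

```r
# Invented token-level data: one row per (bar, word), as unnest_tokens() returns
tokens_df <- data.frame(
  word   = c("cocoa", "sweet", "cocoa", "sour", "sweet", "cocoa"),
  rating = c(3.50, 2.75, 3.25, 2.50, 3.00, 3.00)
)

# Mean rating per word, like group_by(word) %>% summarise(rating = mean(rating))
word_summary <- aggregate(rating ~ word, data = tokens_df, FUN = mean)

# Rows per word, like n = n(); aggregate() returns words in sorted order
word_summary$n <- as.vector(table(tokens_df$word)[word_summary$word])

word_summary
```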
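Let’s consider how to spend our data budget. We hold out a testing set (stratified by rating) and create 10-fold cross-validation resamples from the training set: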
library(tidymodels)
## Warning: package 'tidymodels' was built under R version 4.4.1
## ── Attaching packages ────────────────────────────────────── tidymodels 1.2.0 ──
## ✔ broom 1.0.6 ✔ rsample 1.2.1
## ✔ dials 1.3.0 ✔ tune 1.2.1
## ✔ infer 1.0.7 ✔ workflows 1.1.4
## ✔ modeldata 1.4.0 ✔ workflowsets 1.1.0
## ✔ parsnip 1.2.1 ✔ yardstick 1.3.1
## ✔ recipes 1.0.10
## Warning: package 'dials' was built under R version 4.4.1
## Warning: package 'infer' was built under R version 4.4.1
## Warning: package 'modeldata' was built under R version 4.4.1
## Warning: package 'parsnip' was built under R version 4.4.1
## Warning: package 'tune' was built under R version 4.4.1
## Warning: package 'workflows' was built under R version 4.4.1
## Warning: package 'workflowsets' was built under R version 4.4.1
## Warning: package 'yardstick' was built under R version 4.4.1
## ── Conflicts ───────────────────────────────────────── tidymodels_conflicts() ──
## ✖ scales::discard() masks purrr::discard()
## ✖ dplyr::filter() masks stats::filter()
## ✖ recipes::fixed() masks stringr::fixed()
## ✖ dplyr::lag() masks stats::lag()
## ✖ yardstick::spec() masks readr::spec()
## ✖ recipes::step() masks stats::step()
## • Use tidymodels_prefer() to resolve common conflicts.
set.seed(123)
choco_split <- initial_split(chocolate, strata = rating)
choco_train <- training(choco_split)
choco_test <- testing(choco_split)
set.seed(234)
choco_folds <- vfold_cv(choco_train, strata = rating)
choco_folds
## # 10-fold cross-validation using stratification
## # A tibble: 10 × 2
## splits id
## <list> <chr>
## 1 <split [1705/191]> Fold01
## 2 <split [1705/191]> Fold02
## 3 <split [1705/191]> Fold03
## 4 <split [1706/190]> Fold04
## 5 <split [1706/190]> Fold05
## 6 <split [1706/190]> Fold06
## 7 <split [1707/189]> Fold07
## 8 <split [1707/189]> Fold08
## 9 <split [1708/188]> Fold09
## 10 <split [1709/187]> Fold10
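The fold output lets us check the data budget. Every fold adds up to 1705 + 191 = 1896 rows, which is the full training set. That leaves 2530 - 1896 = 634 rows for testing, roughly the 3/4 to 1/4 split that initial_split() uses by default. Each assessment set gets about 1896 / 10 ≈ 190 rows. For a numeric strata variable such as rating, rsample bins the values into quartiles by default and samples within each bin.

The base-R sketch below illustrates that idea on invented ratings. It is not rsample’s implementation; for example, rsample also pools very small strata.

```r
set.seed(123)
# Invented ratings on the chocolate scale (1 to 4 in steps of 0.25), not the real data
rating <- sample(seq(1, 4, by = 0.25), size = 200, replace = TRUE)

# Bin the numeric outcome into quartile-based strata
breaks <- unique(quantile(rating, probs = seq(0, 1, by = 0.25)))
strata <- droplevels(cut(rating, breaks = breaks, include.lowest = TRUE))

# Within each stratum, send about 3/4 of the rows to training
prop <- 3 / 4
train_idx <- unlist(lapply(split(seq_along(rating), strata), function(idx) {
  idx[sample.int(length(idx), floor(length(idx) * prop))]
}), use.names = FALSE)
test_idx <- setdiff(seq_along(rating), train_idx)

# Share of each stratum that landed in training (all close to 0.75)
round(tapply(seq_along(rating) %in% train_idx, strata, mean), 2)
```

Sampling within each bin means that low- and high-rated bars are represented in the training and testing sets in similar proportions.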
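Let’s set up our preprocessing: tokenize most_memorable_characteristics, keep the 100 most frequent tokens, and turn them into term-frequency features: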
library(textrecipes)
## Warning: package 'textrecipes' was built under R version 4.4.1
choco_rec <-
  recipe(rating ~ most_memorable_characteristics, data = choco_train) %>%
  step_tokenize(most_memorable_characteristics) %>%
  step_tokenfilter(most_memorable_characteristics, max_tokens = 100) %>%
  step_tf(most_memorable_characteristics)
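Conceptually, this recipe turns each bar’s text into a row of word counts. The base-R sketch below mimics step_tokenize(), step_tokenfilter(max_tokens = k) and step_tf() on three invented strings with k = 3. First it learns the vocabulary from the training text only, the way a recipe is prepped on choco_train. Then it counts each kept token per document and ignores all other tokens. The column names follow the tf_most_memorable_characteristics_ prefix that the coefficient plot at the end of this analysis strips off. The helpers tokenize() and term_freq() are hypothetical simplifications, not textrecipes code.

```r
# Invented training text (not from the dataset)
train_text <- c("cocoa, sweet, nutty", "fruit, sour, cocoa", "creamy, fruit, cocoa, sweet")

tokenize <- function(x) {
  words <- strsplit(tolower(x), "[^a-z]+")[[1]]
  words[words != ""]
}

# "Prep": learn the k most frequent tokens from the training text only
k <- 3
train_tokens <- lapply(train_text, tokenize)
counts <- sort(table(unlist(train_tokens)), decreasing = TRUE)
vocab <- names(counts)[seq_len(k)]

# "Bake": count each vocabulary token per document; other tokens are ignored
term_freq <- function(text, vocab) {
  tf <- t(vapply(text, function(x) {
    as.numeric(table(factor(tokenize(x), levels = vocab)))
  }, numeric(length(vocab)), USE.NAMES = FALSE))
  colnames(tf) <- paste0("tf_most_memorable_characteristics_", vocab)
  tf
}

train_tf <- term_freq(train_text, vocab)
new_tf <- term_freq("sweet, sweet, spicy", vocab)  # "spicy" is not in the vocabulary
train_tf
```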
Let’s create a model specification for each model we want to try:
ranger_spec <-
  rand_forest(trees = 500) %>%
  # set_engine("ranger") is omitted because ranger is the default engine
  set_mode("regression")
ranger_spec
## Random Forest Model Specification (regression)
##
## Main Arguments:
## trees = 500
##
## Computational engine: ranger
svm_spec <-
  svm_linear() %>%
  # set_engine("LiblineaR") is omitted because LiblineaR is the default engine
  set_mode("regression")
svm_spec
## Linear Support Vector Machine Model Specification (regression)
##
## Computational engine: LiblineaR
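The printed specifications confirm that ranger and LiblineaR are the default engines, so writing set_engine() out explicitly produces equivalent specifications. Being explicit can make the code easier to read. The sketch assumes only that parsnip is installed; the engine packages themselves are needed when the models are fit, not when the specifications are created.

```r
library(parsnip)

# Same specifications as above, with the default engines written out explicitly
ranger_spec_explicit <-
  rand_forest(trees = 500) %>%
  set_engine("ranger") %>%
  set_mode("regression")

svm_spec_explicit <-
  svm_linear() %>%
  set_engine("LiblineaR") %>%
  set_mode("regression")

# Show the underlying ranger::ranger() call template that parsnip will use
translate(ranger_spec_explicit)
```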
To set up your modeling code, consider using the parsnip addin or the usemodels package.
Now let’s build a model workflow combining each model specification with a data preprocessor:
ranger_wf <- workflow(choco_rec, ranger_spec)
svm_wf <- workflow(choco_rec, svm_spec)
If your feature engineering needs are more complex than a formula like rating ~ . can express, use a recipe. Here we need one to tokenize the free-text most_memorable_characteristics column. Read more about feature engineering with recipes to learn how they work.
We have not marked any hyperparameters for tuning, so we can evaluate these models as they are. Learn more about tuning hyperparameters in the tidymodels documentation.
library(doParallel)
## Warning: package 'doParallel' was built under R version 4.4.1
## Loading required package: foreach
## Warning: package 'foreach' was built under R version 4.4.1
##
## Attaching package: 'foreach'
## The following objects are masked from 'package:purrr':
##
## accumulate, when
## Loading required package: iterators
## Warning: package 'iterators' was built under R version 4.4.1
## Loading required package: parallel
library(LiblineaR)
## Warning: package 'LiblineaR' was built under R version 4.4.1
doParallel::registerDoParallel()
choco_ctrl <- control_resamples(save_pred = TRUE)  # keep out-of-fold predictions
svm_rs <- fit_resamples(
  svm_wf,
  resamples = choco_folds,
  control = choco_ctrl
)
ranger_rs <- fit_resamples(
  ranger_wf,
  resamples = choco_folds,
  control = choco_ctrl
)
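Conceptually, fit_resamples() repeats one loop for every fold: fit on the analysis set, predict the held-out assessment set, and compute metrics. collect_metrics() then summarizes the fold estimates with a mean and a standard error. The base-R sketch below runs that loop by hand with lm() on the built-in mtcars data. mtcars is used only so the example runs on its own. The sketch uses simple unstratified folds, computes the standard error as the standard deviation of the fold estimates divided by the square root of the number of folds, and is not how tune implements resampling.

```r
set.seed(234)
# Stand-in data: mtcars, not the chocolate data
v <- 5
n <- nrow(mtcars)
fold_id <- sample(rep(seq_len(v), length.out = n))  # each row is assessed exactly once

fold_rmse <- vapply(seq_len(v), function(k) {
  analysis   <- mtcars[fold_id != k, ]
  assessment <- mtcars[fold_id == k, ]
  fit  <- lm(mpg ~ wt + hp, data = analysis)
  pred <- predict(fit, newdata = assessment)
  sqrt(mean((assessment$mpg - pred)^2))
}, numeric(1))

# Summaries analogous to the mean and std_err columns of collect_metrics()
cv_summary <- c(mean = mean(fold_rmse), std_err = sd(fold_rmse) / sqrt(v))
cv_summary
```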
How did these two models compare?
collect_metrics(svm_rs)
## # A tibble: 2 × 6
## .metric .estimator mean n std_err .config
## <chr> <chr> <dbl> <int> <dbl> <chr>
## 1 rmse standard 0.348 10 0.00704 Preprocessor1_Model1
## 2 rsq standard 0.365 10 0.0146 Preprocessor1_Model1
collect_metrics(ranger_rs)
## # A tibble: 2 × 6
## .metric .estimator mean n std_err .config
## <chr> <chr> <dbl> <int> <dbl> <chr>
## 1 rmse standard 0.344 10 0.00715 Preprocessor1_Model1
## 2 rsq standard 0.379 10 0.0151 Preprocessor1_Model1
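Both metrics are easy to compute by hand. RMSE is in the units of rating, and yardstick’s rsq() is the squared correlation between observed and predicted values. The invented numbers below, computed in base R rather than with yardstick, show why the two metrics can disagree. These predictions are perfectly correlated with the truth but are shrunk toward the mean and shifted upward, so they still have a nonzero error.

```r
# Invented observed ratings and predictions for one assessment set
truth    <- c(3.00, 3.50, 2.75, 3.25, 2.50)
estimate <- c(3.10, 3.30, 3.00, 3.20, 2.90)

rmse <- sqrt(mean((truth - estimate)^2))  # same units as rating
rsq  <- cor(truth, estimate)^2            # squared Pearson correlation

c(rmse = rmse, rsq = rsq)
```

Here rsq is 1, while the RMSE is sqrt(0.055), about 0.23 rating points. This is why it helps to look at both columns in the tables above.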
We can visualize the out-of-fold predictions from each model against the observed ratings:
bind_rows(
  collect_predictions(svm_rs) %>%
    mutate(mod = "svm"),
  collect_predictions(ranger_rs) %>%
    mutate(mod = "ranger")
) %>%
  ggplot(aes(rating, .pred, color = id)) +
  geom_abline(lty = 2, color = "gray50", linewidth = 1.2) +
  geom_jitter(width = 0.5, alpha = 0.5) +
  facet_wrap(vars(mod)) +
  coord_fixed()
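These models perform very similarly. The difference in cross-validated RMSE (0.348 for the SVM vs. 0.344 for the random forest) is smaller than one standard error (about 0.007), so we choose the simpler, linear model. The function last_fit() fits this model one final time on the full training set and evaluates it on the testing set. This is the first time we have used the testing data.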
final_fitted <- last_fit(svm_wf, choco_split)
collect_metrics(final_fitted)  # metrics evaluated on the *testing* data
## # A tibble: 2 × 4
## .metric .estimator .estimate .config
## <chr> <chr> <dbl> <chr>
## 1 rmse standard 0.385 Preprocessor1_Model1
## 2 rsq standard 0.340 Preprocessor1_Model1
This object contains a fitted workflow that we can use for prediction.
final_wf <- extract_workflow(final_fitted)
predict(final_wf, choco_test[55,])
## # A tibble: 1 × 1
## .pred
## <dbl>
## 1 3.70
You can save this fitted final_wf object to use later with new data, for example with readr::write_rds().
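Here is a minimal sketch of the save-and-reload round trip using base R’s saveRDS() and readRDS(); readr::write_rds() and readr::read_rds() are wrappers around them. final_wf only exists after running the whole pipeline, so the example uses a small lm() model on mtcars as a stand-in. With the real workflow, you would save final_wf and call predict() on the reloaded object with new data that has a most_memorable_characteristics column.

```r
# Stand-in model; in this analysis you would save final_wf instead
fit <- lm(mpg ~ wt, data = mtcars)

path <- tempfile(fileext = ".rds")
saveRDS(fit, path)

# Later, possibly in a new R session: reload and predict on new data
reloaded <- readRDS(path)
new_cars <- data.frame(wt = c(2.5, 3.0))
predict(reloaded, newdata = new_cars)
```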
# Largest positive and negative word coefficients in the fitted linear SVM
extract_workflow(final_fitted) %>%
  tidy() %>%
  filter(term != "Bias") %>%
  group_by(estimate > 0) %>%
  slice_max(abs(estimate), n = 10) %>%
  ungroup() %>%
  mutate(term = str_remove(term, "tf_most_memorable_characteristics_")) %>%
  ggplot(aes(estimate, fct_reorder(term, estimate), fill = estimate > 0)) +
  geom_col(alpha = 0.8)
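The final model is linear in the term-frequency features. Assuming the Bias term (filtered out of the plot above) enters additively as an intercept, a bar’s predicted rating is that intercept plus the sum of each word’s coefficient times the number of times the word appears. Words on the positive side of the plot push predictions up, and words on the negative side pull them down. The sketch below does that arithmetic with invented coefficients and counts, not the fitted values.

```r
# Invented coefficients in the shape of tidy() output for the linear SVM
coefs <- c(Bias = 3.00, creamy = 0.15, cocoa = 0.10, sour = -0.20, bitter = -0.25)

# Invented term frequencies for one bar (counts of the retained tokens)
tf_row <- c(creamy = 1, cocoa = 1, sour = 0, bitter = 1)

# Linear prediction: intercept plus coefficient-weighted token counts
contributions <- coefs[names(tf_row)] * tf_row
pred <- coefs[["Bias"]] + sum(contributions)

contributions
pred
```

In this made-up case, the positive contributions of creamy and cocoa are exactly cancelled by bitter, so the prediction equals the Bias term.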