Goal: Build a regression model to predict chocolate ratings based on each chocolate's main characteristics.
Click here for the data.
# Read the TidyTuesday chocolate ratings data (2022-01-18 release) from GitHub.
chocolate <- readr::read_csv(
  file = "https://raw.githubusercontent.com/rfordatascience/tidytuesday/main/data/2022/2022-01-18/chocolate.csv"
)
## Rows: 2530 Columns: 10
## ── Column specification ────────────────────────────────────────────────────────
## Delimiter: ","
## chr (7): company_manufacturer, company_location, country_of_bean_origin, spe...
## dbl (3): ref, review_date, rating
##
## ℹ Use `spec()` to retrieve the full column specification for this data.
## ℹ Specify the column types or set `show_col_types = FALSE` to quiet this message.
# Summarise every column: missing values, unique counts, and distribution statistics.
skimr::skim(data = chocolate)
Name | chocolate |
Number of rows | 2530 |
Number of columns | 10 |
_______________________ | |
Column type frequency: | |
character | 7 |
numeric | 3 |
________________________ | |
Group variables | None |
Variable type: character
skim_variable | n_missing | complete_rate | min | max | empty | n_unique | whitespace |
---|---|---|---|---|---|---|---|
company_manufacturer | 0 | 1.00 | 2 | 39 | 0 | 580 | 0 |
company_location | 0 | 1.00 | 4 | 21 | 0 | 67 | 0 |
country_of_bean_origin | 0 | 1.00 | 4 | 21 | 0 | 62 | 0 |
specific_bean_origin_or_bar_name | 0 | 1.00 | 3 | 51 | 0 | 1605 | 0 |
cocoa_percent | 0 | 1.00 | 3 | 6 | 0 | 46 | 0 |
ingredients | 87 | 0.97 | 4 | 14 | 0 | 21 | 0 |
most_memorable_characteristics | 0 | 1.00 | 3 | 37 | 0 | 2487 | 0 |
Variable type: numeric
skim_variable | n_missing | complete_rate | mean | sd | p0 | p25 | p50 | p75 | p100 | hist |
---|---|---|---|---|---|---|---|---|---|---|
ref | 0 | 1 | 1429.80 | 757.65 | 5 | 802 | 1454.00 | 2079.0 | 2712 | ▆▇▇▇▇ |
review_date | 0 | 1 | 2014.37 | 3.97 | 2006 | 2012 | 2015.00 | 2018.0 | 2021 | ▃▅▇▆▅ |
rating | 0 | 1 | 3.20 | 0.45 | 1 | 3 | 3.25 | 3.5 | 4 | ▁▁▅▇▇ |
# Clean the raw chocolate data. The result has one row per
# (chocolate, ingredient code, memorable characteristic) combination.
data1 <- chocolate %>%
  # Drop rows that contain any missing value
  na.omit() %>%
  # Split `ingredients` at the dash into the ingredient count (`n_ing`) and
  # the comma-separated ingredient codes (`ing`)
  separate(ingredients, into = c("n_ing", "ing"), sep = "-( |)") %>%
  # One row per ingredient code
  separate_rows(ing, sep = ",") %>%
  # Ingredient count: character -> numeric
  mutate(n_ing = as.numeric(n_ing)) %>%
  # One row per memorable characteristic
  separate_rows(most_memorable_characteristics, sep = ",( |)") %>%
  # Drop any rows left with missing values by the steps above
  na.omit() %>%
  # Exclude the bar name and the review date from the modelling data
  select(-c(specific_bean_origin_or_bar_name, review_date)) %>%
  # Cocoa percentage: strip the "%" sign, then character -> numeric
  mutate(cocoa_percent = as.numeric(str_remove(cocoa_percent, "%"))) %>%
  # Encode the remaining character columns as factors
  mutate(across(where(is.character), factor))

data1
## # A tibble: 20,815 × 9
## ref company_manufacturer company_location country_of_bean_origin
## <dbl> <fct> <fct> <fct>
## 1 2454 5150 U.S.A. Tanzania
## 2 2454 5150 U.S.A. Tanzania
## 3 2454 5150 U.S.A. Tanzania
## 4 2454 5150 U.S.A. Tanzania
## 5 2454 5150 U.S.A. Tanzania
## 6 2454 5150 U.S.A. Tanzania
## 7 2454 5150 U.S.A. Tanzania
## 8 2454 5150 U.S.A. Tanzania
## 9 2454 5150 U.S.A. Tanzania
## 10 2458 5150 U.S.A. Dominican Republic
## # ℹ 20,805 more rows
## # ℹ 5 more variables: cocoa_percent <dbl>, n_ing <dbl>, ing <fct>,
## # most_memorable_characteristics <fct>, rating <dbl>
Identify good predictors.
ref
# Scatter plot of `ref` (log10-scaled y-axis) against `rating`
ggplot(data1, aes(x = rating, y = ref)) +
  geom_point() +
  scale_y_log10()
cocoa_percent
# Scatter plot of `cocoa_percent` (log10-scaled y-axis) against `rating`
ggplot(data1, aes(x = rating, y = cocoa_percent)) +
  geom_point() +
  scale_y_log10()
EDA Shortcut
# Step 1: Prepare data
# Drop the `ref` column, then convert the remaining columns into 0/1 indicators
data_binarized_tbl <- binarize(select(data1, -ref))

glimpse(data_binarized_tbl)
## Rows: 20,815
## Columns: 90
## $ company_manufacturer__A._Morin <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ company_manufacturer__Bonnat <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ company_manufacturer__Fresco <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ company_manufacturer__Guittard <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ company_manufacturer__Pralus <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ company_manufacturer__Scharffen_Berger <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ company_manufacturer__Soma <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ company_manufacturer__Valrhona <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ company_manufacturer__Zotter <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ `company_manufacturer__-OTHER` <dbl> 1, 1, 1, 1, 1, 1, 1, 1, 1, …
## $ company_location__Australia <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ company_location__Austria <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ company_location__Belgium <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ company_location__Brazil <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ company_location__Canada <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ company_location__Colombia <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ company_location__Denmark <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ company_location__Ecuador <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ company_location__France <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ company_location__Germany <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ company_location__Italy <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ company_location__Japan <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ company_location__Spain <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ company_location__Switzerland <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ company_location__U.K. <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ company_location__U.S.A. <dbl> 1, 1, 1, 1, 1, 1, 1, 1, 1, …
## $ company_location__Venezuela <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ `company_location__-OTHER` <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ country_of_bean_origin__Belize <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ country_of_bean_origin__Blend <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ country_of_bean_origin__Bolivia <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ country_of_bean_origin__Brazil <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ country_of_bean_origin__Colombia <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ country_of_bean_origin__Costa_Rica <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ country_of_bean_origin__Dominican_Republic <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ country_of_bean_origin__Ecuador <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ country_of_bean_origin__Ghana <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ country_of_bean_origin__Guatemala <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ country_of_bean_origin__Haiti <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ country_of_bean_origin__India <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ country_of_bean_origin__Jamaica <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ country_of_bean_origin__Madagascar <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ country_of_bean_origin__Mexico <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ country_of_bean_origin__Nicaragua <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ country_of_bean_origin__Papua_New_Guinea <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ country_of_bean_origin__Peru <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ country_of_bean_origin__Tanzania <dbl> 1, 1, 1, 1, 1, 1, 1, 1, 1, …
## $ country_of_bean_origin__Trinidad <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ country_of_bean_origin__U.S.A. <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ country_of_bean_origin__Venezuela <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ country_of_bean_origin__Vietnam <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ `country_of_bean_origin__-OTHER` <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ `cocoa_percent__-Inf_70` <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ cocoa_percent__70_74 <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ cocoa_percent__74_Inf <dbl> 1, 1, 1, 1, 1, 1, 1, 1, 1, …
## $ n_ing__2 <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ n_ing__3 <dbl> 1, 1, 1, 1, 1, 1, 1, 1, 1, …
## $ n_ing__4 <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ n_ing__5 <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ `n_ing__-OTHER` <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ ing__B <dbl> 1, 1, 1, 0, 0, 0, 0, 0, 0, …
## $ ing__C <dbl> 0, 0, 0, 0, 0, 0, 1, 1, 1, …
## $ ing__L <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ ing__S <dbl> 0, 0, 0, 1, 1, 1, 0, 0, 0, …
## $ `ing__S*` <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ ing__V <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ `ing__-OTHER` <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ most_memorable_characteristics__cocoa <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ most_memorable_characteristics__creamy <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ most_memorable_characteristics__earthy <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ most_memorable_characteristics__fatty <dbl> 0, 1, 0, 0, 1, 0, 0, 1, 0, …
## $ most_memorable_characteristics__floral <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ most_memorable_characteristics__fruit <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ most_memorable_characteristics__intense <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ most_memorable_characteristics__molasses <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ most_memorable_characteristics__nutty <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ most_memorable_characteristics__rich <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ most_memorable_characteristics__roasty <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ most_memorable_characteristics__sandy <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ most_memorable_characteristics__sour <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ most_memorable_characteristics__spicy <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ most_memorable_characteristics__sticky <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ most_memorable_characteristics__sweet <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ most_memorable_characteristics__vanilla <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ most_memorable_characteristics__woody <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ `most_memorable_characteristics__-OTHER` <dbl> 1, 0, 1, 1, 0, 1, 1, 0, 1, …
## $ `rating__-Inf_3` <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ rating__3_3.25 <dbl> 1, 1, 1, 1, 1, 1, 1, 1, 1, …
## $ rating__3.25_3.5 <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ rating__3.5_Inf <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, …
# Step 2: Correlate
# Correlate every binary feature with the `rating__3.5_Inf` indicator
# (the highest rating bin)
data_corr_tbl <- correlate(data_binarized_tbl, rating__3.5_Inf)

data_corr_tbl
## # A tibble: 90 × 3
## feature bin correlation
## <fct> <chr> <dbl>
## 1 rating 3.5_Inf 1
## 2 rating -Inf_3 -0.398
## 3 rating 3.25_3.5 -0.230
## 4 rating 3_3.25 -0.202
## 5 company_manufacturer -OTHER -0.154
## 6 company_manufacturer Soma 0.122
## 7 most_memorable_characteristics creamy 0.119
## 8 company_manufacturer Bonnat 0.103
## 9 company_manufacturer Scharffen_Berger 0.0979
## 10 company_manufacturer A._Morin 0.0757
## # ℹ 80 more rows
# Step 3: Plot
# Draw the correlation funnel for the correlations computed in Step 2
plot_correlation_funnel(data_corr_tbl)
## Warning: ggrepel: 77 unlabeled data points (too many overlaps). Consider
## increasing max.overlaps
Split Data
# Work with a random subsample of 300 rows.
# The seed is set BEFORE sampling: without it the subsample changes on every
# run, so the split, the folds, and all reported metrics are not reproducible.
set.seed(1234)
data1 <- sample_n(data1, 300)
# NOTE(review): data1 holds several rows per chocolate (one per ingredient x
# characteristic combination), so rows from the same bar can land in both the
# training and the test set -- consider a grouped split; confirm before relying
# on the test metrics.

# Split into train and test datasets
set.seed(1235)
data_split1 <- rsample::initial_split(data1)
data_train1 <- training(data_split1)
data_test1 <- testing(data_split1)

# Further split the training dataset for cross-validation
set.seed(2345)
data_cv1 <- rsample::vfold_cv(data_train1)
data_cv1
## # 10-fold cross-validation
## # A tibble: 10 × 2
## splits id
## <list> <chr>
## 1 <split [202/23]> Fold01
## 2 <split [202/23]> Fold02
## 3 <split [202/23]> Fold03
## 4 <split [202/23]> Fold04
## 5 <split [202/23]> Fold05
## 6 <split [203/22]> Fold06
## 7 <split [203/22]> Fold07
## 8 <split [203/22]> Fold08
## 9 <split [203/22]> Fold09
## 10 <split [203/22]> Fold10
# Print template recipe / model / workflow code for xgboost as a starting point
library(usemodels)
use_xgboost(formula = rating ~ ., data = data_train1)
## xgboost_recipe <-
## recipe(formula = rating ~ ., data = data_train1) %>%
## step_zv(all_predictors())
##
## xgboost_spec <-
## boost_tree(trees = tune(), min_n = tune(), tree_depth = tune(), learn_rate = tune(),
## loss_reduction = tune(), sample_size = tune()) %>%
## set_mode("classification") %>%
## set_engine("xgboost")
##
## xgboost_workflow <-
## workflow() %>%
## add_recipe(xgboost_recipe) %>%
## add_model(xgboost_spec)
##
## set.seed(87619)
## xgboost_tune <-
## tune_grid(xgboost_workflow, resamples = stop("add your rsample object"), grid = stop("add number of candidate points"))
# Specify Recipe
xgboost_recipe <- recipe(rating ~ ., data = data_train1) %>%
  # Keep `ref` in the data but give it an id role so it is not a predictor
  recipes::update_role(ref, new_role = "id variable") %>%
  # Pool infrequent factor levels (below the 0.02 threshold) into "other"
  step_other(
    company_manufacturer, company_location, country_of_bean_origin,
    ing, most_memorable_characteristics,
    threshold = 0.02
  ) %>%
  # One-hot encode the factors: one indicator column per level
  step_dummy(
    company_manufacturer, company_location, country_of_bean_origin,
    ing, most_memorable_characteristics,
    one_hot = TRUE
  ) %>%
  # Log-transform the cocoa percentage
  step_log(cocoa_percent)

# Inspect the preprocessed training data
xgboost_recipe %>%
  prep() %>%
  bake(new_data = NULL) %>%
  glimpse()
## Rows: 225
## Columns: 54
## $ ref <dbl> 572, 2138, 2414, 1732, 2230,…
## $ cocoa_percent <dbl> 4.317488, 4.248495, 4.317488…
## $ n_ing <dbl> 4, 4, 4, 3, 3, 2, 3, 3, 5, 4…
## $ rating <dbl> 3.75, 3.25, 3.25, 2.50, 3.00…
## $ company_manufacturer_Bonnat <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ company_manufacturer_Fresco <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ company_manufacturer_Soma <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ company_manufacturer_Valrhona <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ company_manufacturer_Zotter <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ company_manufacturer_other <dbl> 1, 1, 1, 1, 1, 1, 1, 1, 1, 1…
## $ company_location_Austria <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ company_location_Belgium <dbl> 0, 0, 0, 0, 1, 0, 0, 0, 0, 0…
## $ company_location_Canada <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ company_location_France <dbl> 0, 0, 1, 0, 0, 0, 0, 0, 0, 0…
## $ company_location_Italy <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ company_location_U.K. <dbl> 0, 0, 0, 0, 0, 0, 1, 0, 0, 0…
## $ company_location_U.S.A. <dbl> 0, 1, 0, 1, 0, 0, 0, 1, 0, 1…
## $ company_location_other <dbl> 1, 0, 0, 0, 0, 1, 0, 0, 1, 0…
## $ country_of_bean_origin_Belize <dbl> 0, 1, 0, 0, 0, 0, 0, 0, 0, 0…
## $ country_of_bean_origin_Blend <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ country_of_bean_origin_Bolivia <dbl> 0, 0, 0, 1, 0, 0, 0, 0, 0, 0…
## $ country_of_bean_origin_Brazil <dbl> 1, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ country_of_bean_origin_Colombia <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 1, 0…
## $ country_of_bean_origin_Dominican.Republic <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ country_of_bean_origin_Ecuador <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ country_of_bean_origin_Guatemala <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ country_of_bean_origin_Haiti <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ country_of_bean_origin_Indonesia <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ country_of_bean_origin_Madagascar <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ country_of_bean_origin_Mexico <dbl> 0, 0, 0, 0, 0, 1, 0, 0, 0, 0…
## $ country_of_bean_origin_Nicaragua <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ country_of_bean_origin_Papua.New.Guinea <dbl> 0, 0, 0, 0, 0, 0, 1, 0, 0, 0…
## $ country_of_bean_origin_Peru <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ country_of_bean_origin_Tanzania <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ country_of_bean_origin_Venezuela <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ country_of_bean_origin_Vietnam <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ country_of_bean_origin_other <dbl> 0, 0, 1, 0, 1, 0, 0, 1, 0, 1…
## $ ing_B <dbl> 1, 0, 0, 0, 1, 0, 0, 0, 0, 0…
## $ ing_C <dbl> 0, 0, 1, 1, 0, 0, 1, 1, 0, 0…
## $ ing_L <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ ing_S <dbl> 0, 1, 0, 0, 0, 1, 0, 0, 0, 0…
## $ ing_V <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 1, 1…
## $ ing_other <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ most_memorable_characteristics_cocoa <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ most_memorable_characteristics_creamy <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ most_memorable_characteristics_fatty <dbl> 0, 0, 1, 0, 1, 0, 0, 0, 0, 0…
## $ most_memorable_characteristics_fruity <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ most_memorable_characteristics_nutty <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ most_memorable_characteristics_oily <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ most_memorable_characteristics_roasty <dbl> 1, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ most_memorable_characteristics_sandy <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ most_memorable_characteristics_sour <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ most_memorable_characteristics_sweet <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 1…
## $ most_memorable_characteristics_other <dbl> 0, 1, 0, 1, 0, 1, 1, 1, 1, 0…
# Specify Model
# Boosted-tree regression on the xgboost engine; trees, min_n, mtry and
# learn_rate are left as tune() placeholders to be filled in by tuning
xgboost_spec <- boost_tree(
  mode = "regression",
  mtry = tune(),
  trees = tune(),
  min_n = tune(),
  learn_rate = tune()
) %>%
  set_engine("xgboost")
# Bundle the preprocessing recipe and the model specification into one workflow
xgboost_workflow <- workflow()
xgboost_workflow <- add_recipe(xgboost_workflow, xgboost_recipe)
xgboost_workflow <- add_model(xgboost_workflow, xgboost_spec)
# Tune hyperparameters
# Evaluate a grid of 5 candidate parameter sets on the cross-validation folds
set.seed(12782)
xgboost_tune <- xgboost_workflow %>%
  tune_grid(resamples = data_cv1, grid = 5)
## i Creating pre-processing data to finalize unknown parameter: mtry
# Show the best-performing candidates according to RMSE
xgboost_tune %>%
  tune::show_best(metric = "rmse")
## # A tibble: 5 × 10
## mtry trees min_n learn_rate .metric .estimator mean n std_err .config
## <int> <int> <int> <dbl> <chr> <chr> <dbl> <int> <dbl> <chr>
## 1 6 66 24 0.277 rmse standard 0.426 10 0.0226 Preproces…
## 2 19 694 12 0.0217 rmse standard 0.429 10 0.0246 Preproces…
## 3 51 1723 9 0.00428 rmse standard 0.429 10 0.0238 Preproces…
## 4 37 1359 38 0.0583 rmse standard 0.442 10 0.0247 Preproces…
## 5 31 1157 26 0.00113 rmse standard 0.856 10 0.0275 Preproces…
# Finalize the workflow with the hyperparameters that gave the best RMSE.
xgboost_fw1 <- tune::finalize_workflow(
  x = xgboost_workflow,
  parameters = tune::select_best(xgboost_tune, metric = "rmse")
)

# Fit on the entire training set and evaluate once on the held-out test set.
data_fit1 <- xgboost_fw1 %>%
  tune::last_fit(split = data_split1)

data_fit1 %>%
  tune::collect_metrics()
## # A tibble: 2 × 4
## .metric .estimator .estimate .config
## <chr> <chr> <dbl> <chr>
## 1 rmse standard 0.430 Preprocessor1_Model1
## 2 rsq standard 0.0536 Preprocessor1_Model1
# Predicted vs. observed ratings on the test set; the dashed reference line
# marks perfect prediction (.pred == rating).
tune::collect_predictions(data_fit1) %>%
  ggplot(aes(x = rating, y = .pred)) +
  # `color`, not `fill`: the default point shape has no fill, so `fill` was
  # silently ignored and the points were drawn in black
  geom_point(alpha = 0.3, color = "midnightblue") +
  geom_abline(lty = 2, color = "gray50") +
  coord_fixed()
I tried to improve my model by removing another variable. I first removed review_date from the predictors used by the machine learning algorithm, since I didn't think the review date would have much of an impact on the rating. I then changed the plot of the relationship between review_date and rating to one of cocoa_percent and rating. I also removed review_date from the step_log() call in the recipe specification. These changes increased the RMSE from 0.281 to 0.457 and decreased the R-squared (rsq) from 0.584 to 0.0281. This RMSE would still fall within a good range, but my model was stronger with review_date included, so I would probably have left my model as it was.