library(tidymodels)
## ── Attaching packages ────────────────────────────────────── tidymodels 1.2.0 ──
## ✔ broom 1.0.5 ✔ recipes 1.0.10
## ✔ dials 1.2.1 ✔ rsample 1.2.1
## ✔ dplyr 1.1.4 ✔ tibble 3.2.1
## ✔ ggplot2 3.5.1 ✔ tidyr 1.3.1
## ✔ infer 1.0.7 ✔ tune 1.2.1
## ✔ modeldata 1.3.0 ✔ workflows 1.1.4
## ✔ parsnip 1.2.1 ✔ workflowsets 1.1.0
## ✔ purrr 1.0.2 ✔ yardstick 1.3.1
## ── Conflicts ───────────────────────────────────────── tidymodels_conflicts() ──
## ✖ purrr::discard() masks scales::discard()
## ✖ dplyr::filter() masks stats::filter()
## ✖ dplyr::lag() masks stats::lag()
## ✖ recipes::step() masks stats::step()
## • Dig deeper into tidy modeling with R at https://www.tmwr.org
library(ISLR)
# Wrap the ISLR data frames as tibbles for the rest of the script.
portfolio <- tibble(Portfolio)
auto <- tibble(Auto)

# Use the black-and-white ggplot2 theme for every plot drawn below.
theme_set(theme_bw())

# Quick column-by-column overview of the Auto data.
glimpse(auto)
## Rows: 392
## Columns: 9
## $ mpg <dbl> 18, 15, 18, 16, 17, 15, 14, 14, 14, 15, 15, 14, 15, 14, 2…
## $ cylinders <dbl> 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 4, 6, 6, 6, 4, …
## $ displacement <dbl> 307, 350, 318, 304, 302, 429, 454, 440, 455, 390, 383, 34…
## $ horsepower <dbl> 130, 165, 150, 150, 140, 198, 220, 215, 225, 190, 170, 16…
## $ weight <dbl> 3504, 3693, 3436, 3433, 3449, 4341, 4354, 4312, 4425, 385…
## $ acceleration <dbl> 12.0, 11.5, 11.0, 12.0, 10.5, 10.0, 9.0, 8.5, 10.0, 8.5, …
## $ year <dbl> 70, 70, 70, 70, 70, 70, 70, 70, 70, 70, 70, 70, 70, 70, 7…
## $ origin <dbl> 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 3, 1, 1, 1, 3, …
## $ name <fct> chevrolet chevelle malibu, buick skylark 320, plymouth sa…
# Same overview for the Portfolio data.
portfolio %>% glimpse()
## Rows: 100
## Columns: 2
## $ X <dbl> -0.89525089, -1.56245433, -0.41708988, 1.04435573, -0.31556841, -1.7…
## $ Y <dbl> -0.2349235, -0.8851760, 0.2718880, -0.7341975, 0.8419834, -2.0371910…
# Fix the RNG state so the random split below is reproducible.
set.seed(1)

# Validation-set approach: hold out half of the rows, stratifying on the
# outcome `mpg`. The split is taken from the `auto` tibble created earlier
# (not the raw `Auto` data frame) so every downstream object is a tibble.
auto_split <- initial_split(auto, prop = 0.5, strata = mpg)
auto_train <- training(auto_split)
auto_test <- testing(auto_split)
# Ordinary least squares specification, fitted with the "lm" engine.
lm_model <- set_engine(linear_reg(), "lm")

# Simple linear regression of mpg on horsepower, training half only.
lm_fit <- lm_model %>%
  fit(mpg ~ horsepower, data = auto_train)

lm_fit
## parsnip model object
##
##
## Call:
## stats::lm(formula = mpg ~ horsepower, data = data)
##
## Coefficients:
## (Intercept) horsepower
## 39.5424 -0.1567
# Summarise the underlying engine fit. extract_fit_engine() is the documented
# accessor for the engine object stored inside a parsnip model fit, so use it
# instead of indexing the `fit` element directly.
lm_fit %>%
  extract_fit_engine() %>%
  summary()
##
## Call:
## stats::lm(formula = mpg ~ horsepower, data = data)
##
## Residuals:
## Min 1Q Median 3Q Max
## -13.2574 -3.0897 -0.0029 2.6124 13.8695
##
## Coefficients:
## Estimate Std. Error t value Pr(>|t|)
## (Intercept) 39.542424 1.035803 38.18 <2e-16 ***
## horsepower -0.156736 0.009431 -16.62 <2e-16 ***
## ---
## Signif. codes: 0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1
##
## Residual standard error: 4.76 on 192 degrees of freedom
## Multiple R-squared: 0.5899, Adjusted R-squared: 0.5878
## F-statistic: 276.2 on 1 and 192 DF, p-value: < 2.2e-16
# Append predictions (.pred) for the held-out half to the test data.
pred <- lm_fit %>% augment(new_data = auto_test)

# Test-set root mean squared error of the linear fit.
rmse(pred, truth = mpg, estimate = .pred)
## # A tibble: 1 × 3
## .metric .estimator .estimate
## <chr> <chr> <dbl>
## 1 rmse standard 5.06
# Test-set R-squared of the linear fit.
rsq(pred, truth = mpg, estimate = .pred)
## # A tibble: 1 × 3
## .metric .estimator .estimate
## <chr> <chr> <dbl>
## 1 rsq standard 0.621
# Recipe that expands horsepower into degree-2 polynomial terms.
poly_rec <- recipe(mpg ~ horsepower, data = auto_train) %>%
  step_poly(horsepower, degree = 2)

# Bundle the recipe with the linear model specification.
poly_wf <- workflow(preprocessor = poly_rec, spec = lm_model)

# Estimate the recipe and the model on the training half.
poly_fit <- poly_wf %>% fit(data = auto_train)

poly_fit
## ══ Workflow [trained] ══════════════════════════════════════════════════════════
## Preprocessor: Recipe
## Model: linear_reg()
##
## ── Preprocessor ────────────────────────────────────────────────────────────────
## 1 Recipe Step
##
## • step_poly()
##
## ── Model ───────────────────────────────────────────────────────────────────────
##
## Call:
## stats::lm(formula = ..y ~ ., data = data)
##
## Coefficients:
## (Intercept) horsepower_poly_1 horsepower_poly_2
## 23.29 -79.11 25.10
# Predictions of the quadratic workflow on the held-out half.
pred_poly <- poly_fit %>% augment(new_data = auto_test)

# Test-set root mean squared error of the quadratic fit.
pred_poly %>% rmse(truth = mpg, estimate = .pred)
## # A tibble: 1 × 3
## .metric .estimator .estimate
## <chr> <chr> <dbl>
## 1 rmse standard 4.37
# Test-set R-squared of the quadratic fit.
pred_poly %>% rsq(truth = mpg, estimate = .pred)
## # A tibble: 1 × 3
## .metric .estimator .estimate
## <chr> <chr> <dbl>
## 1 rsq standard 0.722
# Recipe whose polynomial degree is left as a tuning placeholder.
tuned_poly_rec <- recipe(mpg ~ horsepower, data = auto_train) %>%
  step_poly(horsepower, degree = tune())

# 10-fold cross-validation folds on the training half; seeded so the fold
# assignment is reproducible.
set.seed(2021)
auto_cv <- auto_train %>% vfold_cv(v = 10)
# List the parameters marked for tuning in the recipe. parameters() is
# deprecated (tune 0.1.6.9003); extract_parameter_set_dials() is its
# replacement and prints the same parameter collection without the warning.
tuned_poly_rec %>% hardhat::extract_parameter_set_dials()
## Collection of 1 parameters for tuning
##
## identifier type object
## degree degree nparam[+]
# Inspect the default dials range attached to the `degree` placeholder.
hardhat::extract_parameter_dials(tuned_poly_rec, "degree")
## Polynomial Degree (quantitative)
## Range: [1, 3]
# Workflow pairing the tunable recipe with the linear model specification.
tuned_poly_wf <- workflow(preprocessor = tuned_poly_rec, spec = lm_model)
# Candidate polynomial degrees: widen the parameter's range to [1, 10] and
# take 10 regularly spaced levels.
poly_grid <- grid_regular(degree(range = c(1, 10)), levels = 10)
# Alternative hand-built grid, kept for reference:
# poly_grid_tbl <- crossing(degree = 1:10)

# Keep the out-of-sample predictions and the workflow with the tuning results.
# Spell out TRUE: `T` is an ordinary variable that can be reassigned.
ctrl <- control_grid(save_pred = TRUE, save_workflow = TRUE)
# Evaluate every candidate degree on the cross-validation folds, recording
# RMSE, R-squared and MAE. Note that this rebinds `poly_fit`, which previously
# held the fitted quadratic workflow, to the tuning results.
poly_fit <- tuned_poly_wf %>%
  tune_grid(
    resamples = auto_cv,
    grid = poly_grid,
    metrics = metric_set(rmse, rsq, mae),
    control = ctrl
  )

# Resampled performance per degree and metric.
collect_metrics(poly_fit)
## # A tibble: 30 × 7
## degree .metric .estimator mean n std_err .config
## <dbl> <chr> <chr> <dbl> <int> <dbl> <chr>
## 1 1 mae standard 3.65 10 0.241 Preprocessor01_Model1
## 2 1 rmse standard 4.68 10 0.322 Preprocessor01_Model1
## 3 1 rsq standard 0.605 10 0.0468 Preprocessor01_Model1
## 4 2 mae standard 3.23 10 0.230 Preprocessor02_Model1
## 5 2 rmse standard 4.33 10 0.357 Preprocessor02_Model1
## 6 2 rsq standard 0.662 10 0.0559 Preprocessor02_Model1
## 7 3 mae standard 3.25 10 0.227 Preprocessor03_Model1
## 8 3 rmse standard 4.36 10 0.361 Preprocessor03_Model1
## 9 3 rsq standard 0.658 10 0.0562 Preprocessor03_Model1
## 10 4 mae standard 3.28 10 0.227 Preprocessor04_Model1
## # ℹ 20 more rows
# Plot each metric against the polynomial degree, one axis break per degree.
poly_fit %>% autoplot() +
  scale_x_continuous(breaks = seq_len(10))

# Top candidates ranked by cross-validated RMSE.
show_best(poly_fit, metric = "rmse")
## # A tibble: 5 × 7
## degree .metric .estimator mean n std_err .config
## <dbl> <chr> <chr> <dbl> <int> <dbl> <chr>
## 1 2 rmse standard 4.33 10 0.357 Preprocessor02_Model1
## 2 3 rmse standard 4.36 10 0.361 Preprocessor03_Model1
## 3 4 rmse standard 4.41 10 0.362 Preprocessor04_Model1
## 4 5 rmse standard 4.44 10 0.369 Preprocessor05_Model1
## 5 7 rmse standard 4.45 10 0.367 Preprocessor07_Model1
# Degree with the lowest cross-validated RMSE; used to finalize the workflow.
best_tune <- select_best(poly_fit, metric = "rmse")

# For comparison: the candidate chosen by the one-standard-error rule, with
# candidates ordered by `degree`. The result is printed, not stored.
select_by_one_std_err(poly_fit, degree, metric = "rmse")
## # A tibble: 1 × 2
## degree .config
## <dbl> <chr>
## 1 1 Preprocessor01_Model1
# Substitute the selected degree for the tune() placeholder.
poly_final_wf <- finalize_workflow(tuned_poly_wf, best_tune)

# Fit on the training half and evaluate on the test half of the original split.
final_poly_fit <- poly_final_wf %>% last_fit(split = auto_split)

# Test-set metrics of the finalized workflow.
collect_metrics(final_poly_fit)
## # A tibble: 2 × 4
## .metric .estimator .estimate .config
## <chr> <chr> <dbl> <chr>
## 1 rmse standard 4.37 Preprocessor1_Model1
## 2 rsq standard 0.722 Preprocessor1_Model1
# Observed mpg against test-set predictions, with a linear trend line.
ggplot(collect_predictions(final_poly_fit), aes(x = .pred, y = mpg)) +
  geom_point() +
  geom_smooth(method = "lm", fill = "lightblue")
## `geom_smooth()` using formula = 'y ~ x'

# Record the R version and package versions used for this run.
utils::sessionInfo()
## R version 4.4.0 (2024-04-24)
## Platform: x86_64-pc-linux-gnu
## Running under: Ubuntu 20.04.6 LTS
##
## Matrix products: default
## BLAS/LAPACK: /usr/lib/x86_64-linux-gnu/openblas-pthread/libopenblasp-r0.3.8.so; LAPACK version 3.9.0
##
## locale:
## [1] LC_CTYPE=C.UTF-8 LC_NUMERIC=C LC_TIME=C.UTF-8
## [4] LC_COLLATE=C.UTF-8 LC_MONETARY=C.UTF-8 LC_MESSAGES=C.UTF-8
## [7] LC_PAPER=C.UTF-8 LC_NAME=C LC_ADDRESS=C
## [10] LC_TELEPHONE=C LC_MEASUREMENT=C.UTF-8 LC_IDENTIFICATION=C
##
## time zone: UTC
## tzcode source: system (glibc)
##
## attached base packages:
## [1] stats graphics grDevices utils datasets methods base
##
## other attached packages:
## [1] ISLR_1.4 yardstick_1.3.1 workflowsets_1.1.0 workflows_1.1.4
## [5] tune_1.2.1 tidyr_1.3.1 tibble_3.2.1 rsample_1.2.1
## [9] recipes_1.0.10 purrr_1.0.2 parsnip_1.2.1 modeldata_1.3.0
## [13] infer_1.0.7 ggplot2_3.5.1 dplyr_1.1.4 dials_1.2.1
## [17] scales_1.3.0 broom_1.0.5 tidymodels_1.2.0
##
## loaded via a namespace (and not attached):
## [1] tidyselect_1.2.1 timeDate_4032.109 farver_2.1.2
## [4] fastmap_1.1.1 digest_0.6.35 rpart_4.1.23
## [7] timechange_0.3.0 lifecycle_1.0.4 ellipsis_0.3.2
## [10] survival_3.5-8 magrittr_2.0.3 compiler_4.4.0
## [13] rlang_1.1.3 sass_0.4.9 tools_4.4.0
## [16] utf8_1.2.4 yaml_2.3.8 data.table_1.15.4
## [19] knitr_1.46 labeling_0.4.3 DiceDesign_1.10
## [22] withr_3.0.0 nnet_7.3-19 grid_4.4.0
## [25] fansi_1.0.6 colorspace_2.1-0 future_1.33.2
## [28] globals_0.16.3 iterators_1.0.14 MASS_7.3-60.2
## [31] cli_3.6.2 rmarkdown_2.26 generics_0.1.3
## [34] rstudioapi_0.16.0 future.apply_1.11.2 cachem_1.0.8
## [37] splines_4.4.0 parallel_4.4.0 vctrs_0.6.5
## [40] hardhat_1.3.1 Matrix_1.7-0 jsonlite_1.8.8
## [43] listenv_0.9.1 foreach_1.5.2 gower_1.0.1
## [46] jquerylib_0.1.4 glue_1.7.0 parallelly_1.37.1
## [49] codetools_0.2-20 lubridate_1.9.3 gtable_0.3.5
## [52] munsell_0.5.1 GPfit_1.0-8 pillar_1.9.0
## [55] furrr_0.3.1 htmltools_0.5.8.1 ipred_0.9-14
## [58] lava_1.8.0 R6_2.5.1 lhs_1.1.6
## [61] evaluate_0.23 lattice_0.22-6 highr_0.10
## [64] backports_1.4.1 bslib_0.7.0 class_7.3-22
## [67] Rcpp_1.0.12 nlme_3.1-164 prodlim_2023.08.28
## [70] mgcv_1.9-1 xfun_0.43 pkgconfig_2.0.3