library(tidymodels)
## ── Attaching packages ────────────────────────────────────── tidymodels 1.2.0 ──
## ✔ broom        1.0.5      ✔ recipes      1.0.10
## ✔ dials        1.2.1      ✔ rsample      1.2.1 
## ✔ dplyr        1.1.4      ✔ tibble       3.2.1 
## ✔ ggplot2      3.5.1      ✔ tidyr        1.3.1 
## ✔ infer        1.0.7      ✔ tune         1.2.1 
## ✔ modeldata    1.3.0      ✔ workflows    1.1.4 
## ✔ parsnip      1.2.1      ✔ workflowsets 1.1.0 
## ✔ purrr        1.0.2      ✔ yardstick    1.3.1
## ── Conflicts ───────────────────────────────────────── tidymodels_conflicts() ──
## ✖ purrr::discard() masks scales::discard()
## ✖ dplyr::filter()  masks stats::filter()
## ✖ dplyr::lag()     masks stats::lag()
## ✖ recipes::step()  masks stats::step()
## • Dig deeper into tidy modeling with R at https://www.tmwr.org
library(ISLR)

# Black-and-white ggplot2 theme for every plot produced by this script.
theme_set(theme_bw())

# Wrap the two ISLR data frames as tibbles.
auto <- Auto %>% tibble()
portfolio <- Portfolio %>% tibble()

# Column types and leading values of the Auto data.
auto %>% glimpse()
## Rows: 392
## Columns: 9
## $ mpg          <dbl> 18, 15, 18, 16, 17, 15, 14, 14, 14, 15, 15, 14, 15, 14, 2…
## $ cylinders    <dbl> 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 4, 6, 6, 6, 4, …
## $ displacement <dbl> 307, 350, 318, 304, 302, 429, 454, 440, 455, 390, 383, 34…
## $ horsepower   <dbl> 130, 165, 150, 150, 140, 198, 220, 215, 225, 190, 170, 16…
## $ weight       <dbl> 3504, 3693, 3436, 3433, 3449, 4341, 4354, 4312, 4425, 385…
## $ acceleration <dbl> 12.0, 11.5, 11.0, 12.0, 10.5, 10.0, 9.0, 8.5, 10.0, 8.5, …
## $ year         <dbl> 70, 70, 70, 70, 70, 70, 70, 70, 70, 70, 70, 70, 70, 70, 7…
## $ origin       <dbl> 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 3, 1, 1, 1, 3, …
## $ name         <fct> chevrolet chevelle malibu, buick skylark 320, plymouth sa…
# Column types and leading values of the Portfolio data.
portfolio %>% glimpse()
## Rows: 100
## Columns: 2
## $ X <dbl> -0.89525089, -1.56245433, -0.41708988, 1.04435573, -0.31556841, -1.7…
## $ Y <dbl> -0.2349235, -0.8851760, 0.2718880, -0.7341975, 0.8419834, -2.0371910…
# Fix the RNG so the train/test partition is reproducible.
set.seed(1)

# 50/50 split stratified on the outcome. Split the `auto` tibble created
# earlier rather than the raw `Auto` data frame, so the analysis consistently
# works on one object (the rows, and therefore the partition, are the same).
auto_split <- initial_split(auto, prop = 0.5, strata = mpg)

auto_train <- training(auto_split)
auto_test <- testing(auto_split)

# Linear regression specification using the "lm" engine.
lm_model <-
  linear_reg() %>%
  set_engine("lm")

# Regress mpg on horsepower using the training half only.
lm_fit <-
  fit(
    lm_model,
    mpg ~ horsepower,
    data = auto_train
  )

# Print the fitted parsnip model.
lm_fit
## parsnip model object
## 
## 
## Call:
## stats::lm(formula = mpg ~ horsepower, data = data)
## 
## Coefficients:
## (Intercept)   horsepower  
##     39.5424      -0.1567
# Coefficient table and fit statistics of the underlying engine fit.
# `extract_fit_engine()` is the documented accessor for it; `pluck("fit")`
# depended on the internal layout of the parsnip object.
lm_fit %>%
  extract_fit_engine() %>%
  summary()
## 
## Call:
## stats::lm(formula = mpg ~ horsepower, data = data)
## 
## Residuals:
##      Min       1Q   Median       3Q      Max 
## -13.2574  -3.0897  -0.0029   2.6124  13.8695 
## 
## Coefficients:
##              Estimate Std. Error t value Pr(>|t|)    
## (Intercept) 39.542424   1.035803   38.18   <2e-16 ***
## horsepower  -0.156736   0.009431  -16.62   <2e-16 ***
## ---
## Signif. codes:  0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1
## 
## Residual standard error: 4.76 on 192 degrees of freedom
## Multiple R-squared:  0.5899, Adjusted R-squared:  0.5878 
## F-statistic: 276.2 on 1 and 192 DF,  p-value: < 2.2e-16
# Add the model's predictions (`.pred`) to the held-out rows.
pred <- lm_fit %>% augment(new_data = auto_test)

# Root mean squared error of the linear fit on the test set.
rmse(pred, truth = mpg, estimate = .pred)
## # A tibble: 1 × 3
##   .metric .estimator .estimate
##   <chr>   <chr>          <dbl>
## 1 rmse    standard        5.06
# R-squared of the linear fit on the test set.
rsq(pred, truth = mpg, estimate = .pred)
## # A tibble: 1 × 3
##   .metric .estimator .estimate
##   <chr>   <chr>          <dbl>
## 1 rsq     standard       0.621
# Preprocessing: replace horsepower with a degree-2 polynomial expansion.
poly_rec <- step_poly(
  recipe(mpg ~ horsepower, data = auto_train),
  horsepower,
  degree = 2
)

# Bundle the recipe and the linear model specification into one workflow.
poly_wf <- workflow(preprocessor = poly_rec, spec = lm_model)

# Train the recipe and the model on the training data.
poly_fit <- poly_wf %>% fit(data = auto_train)

# Print the trained workflow.
poly_fit
## ══ Workflow [trained] ══════════════════════════════════════════════════════════
## Preprocessor: Recipe
## Model: linear_reg()
## 
## ── Preprocessor ────────────────────────────────────────────────────────────────
## 1 Recipe Step
## 
## • step_poly()
## 
## ── Model ───────────────────────────────────────────────────────────────────────
## 
## Call:
## stats::lm(formula = ..y ~ ., data = data)
## 
## Coefficients:
##       (Intercept)  horsepower_poly_1  horsepower_poly_2  
##             23.29             -79.11              25.10
# Test-set predictions from the quadratic workflow.
pred_poly <- poly_fit %>% augment(new_data = auto_test)

# Root mean squared error of the quadratic fit on the test set.
pred_poly %>%
  rmse(truth = mpg, estimate = .pred)
## # A tibble: 1 × 3
##   .metric .estimator .estimate
##   <chr>   <chr>          <dbl>
## 1 rmse    standard        4.37
# R-squared of the quadratic fit on the test set.
pred_poly %>%
  rsq(truth = mpg, estimate = .pred)
## # A tibble: 1 × 3
##   .metric .estimator .estimate
##   <chr>   <chr>          <dbl>
## 1 rsq     standard       0.722
# Separate seed so the fold assignment is reproducible.
set.seed(2021)

# 10-fold cross-validation resamples drawn from the training half.
auto_cv <- vfold_cv(auto_train, v = 10)

# Polynomial recipe whose degree is left as a `tune()` placeholder, to be
# filled in by the grid search.
tuned_poly_rec <-
  recipe(mpg ~ horsepower, data = auto_train) %>%
  step_poly(horsepower, degree = tune())

# List the tunable parameters of the recipe. `parameters()` is deprecated for
# this (since tune 0.1.6.9003) and raised a lifecycle warning;
# `hardhat::extract_parameter_set_dials()` is the documented replacement.
tuned_poly_rec %>% hardhat::extract_parameter_set_dials()
## Collection of 1 parameters for tuning
## 
##  identifier   type    object
##      degree degree nparam[+]
# Show the dials parameter object (type and range) behind the `degree` placeholder.
hardhat::extract_parameter_dials(tuned_poly_rec, parameter = "degree")
## Polynomial Degree (quantitative)
## Range: [1, 3]
# Workflow whose recipe still contains the `tune()` placeholder for degree.
tuned_poly_wf <-
  workflow() %>%
  add_recipe(tuned_poly_rec) %>%
  add_model(lm_model)

# Regular grid of 10 levels over the degree range [1, 10].
poly_grid <- grid_regular(degree(range = c(1, 10)), levels = 10)

# Alternative hand-written grid, kept for reference:
# poly_grid_tbl <- crossing(degree = 1:10)

# Keep the out-of-fold predictions and the workflow with the tuning results.
# `TRUE` is spelled out because `T` is an ordinary, reassignable variable.
ctrl <- control_grid(save_pred = TRUE, save_workflow = TRUE)

# Evaluate every candidate degree on the cross-validation folds.
# NOTE: this reassigns `poly_fit`, which previously held the fitted
# degree-2 workflow; from here on it holds the tuning results.
poly_fit <-
  tune_grid(
    tuned_poly_wf,
    resamples = auto_cv,
    grid = poly_grid,
    metrics = metric_set(rmse, rsq, mae),
    control = ctrl
  )

# Resampled performance for each degree and metric.
poly_fit %>% collect_metrics()
## # A tibble: 30 × 7
##    degree .metric .estimator  mean     n std_err .config              
##     <dbl> <chr>   <chr>      <dbl> <int>   <dbl> <chr>                
##  1      1 mae     standard   3.65     10  0.241  Preprocessor01_Model1
##  2      1 rmse    standard   4.68     10  0.322  Preprocessor01_Model1
##  3      1 rsq     standard   0.605    10  0.0468 Preprocessor01_Model1
##  4      2 mae     standard   3.23     10  0.230  Preprocessor02_Model1
##  5      2 rmse    standard   4.33     10  0.357  Preprocessor02_Model1
##  6      2 rsq     standard   0.662    10  0.0559 Preprocessor02_Model1
##  7      3 mae     standard   3.25     10  0.227  Preprocessor03_Model1
##  8      3 rmse    standard   4.36     10  0.361  Preprocessor03_Model1
##  9      3 rsq     standard   0.658    10  0.0562 Preprocessor03_Model1
## 10      4 mae     standard   3.28     10  0.227  Preprocessor04_Model1
## # ℹ 20 more rows
# Plot the tuning results, with an x-axis tick at each integer from 1 to 10.
autoplot(poly_fit) +
  scale_x_continuous(breaks = seq_len(10))

# Best candidates ranked by cross-validated RMSE.
show_best(poly_fit, metric = "rmse")
## # A tibble: 5 × 7
##   degree .metric .estimator  mean     n std_err .config              
##    <dbl> <chr>   <chr>      <dbl> <int>   <dbl> <chr>                
## 1      2 rmse    standard    4.33    10   0.357 Preprocessor02_Model1
## 2      3 rmse    standard    4.36    10   0.361 Preprocessor03_Model1
## 3      4 rmse    standard    4.41    10   0.362 Preprocessor04_Model1
## 4      5 rmse    standard    4.44    10   0.369 Preprocessor05_Model1
## 5      7 rmse    standard    4.45    10   0.367 Preprocessor07_Model1
# Candidate with the best cross-validated RMSE; used to finalize the workflow.
best_tune <- select_best(poly_fit, metric = "rmse")

# For comparison only (result is printed, not stored): one-standard-error
# selection with candidates ordered by degree.
select_by_one_std_err(poly_fit, degree, metric = "rmse")
## # A tibble: 1 × 2
##   degree .config              
##    <dbl> <chr>                
## 1      1 Preprocessor01_Model1
# Replace the `tune()` placeholder with the selected degree.
poly_final_wf <- finalize_workflow(tuned_poly_wf, best_tune)

# Fit the finalized workflow on the training portion of the split and
# evaluate it on the testing portion.
final_poly_fit <- poly_final_wf %>% last_fit(split = auto_split)

# Test-set metrics of the final model.
collect_metrics(final_poly_fit)
## # A tibble: 2 × 4
##   .metric .estimator .estimate .config             
##   <chr>   <chr>          <dbl> <chr>               
## 1 rmse    standard       4.37  Preprocessor1_Model1
## 2 rsq     standard       0.722 Preprocessor1_Model1
# Observed mpg against the final model's test-set predictions, with a
# linear trend line and a light-blue confidence band.
ggplot(
  collect_predictions(final_poly_fit),
  aes(x = .pred, y = mpg)
) +
  geom_point() +
  geom_smooth(method = "lm", fill = "lightblue")
## `geom_smooth()` using formula = 'y ~ x'

# Print the R version, platform, and attached/loaded package versions for this run.
sessionInfo()
## R version 4.4.0 (2024-04-24)
## Platform: x86_64-pc-linux-gnu
## Running under: Ubuntu 20.04.6 LTS
## 
## Matrix products: default
## BLAS/LAPACK: /usr/lib/x86_64-linux-gnu/openblas-pthread/libopenblasp-r0.3.8.so;  LAPACK version 3.9.0
## 
## locale:
##  [1] LC_CTYPE=C.UTF-8       LC_NUMERIC=C           LC_TIME=C.UTF-8       
##  [4] LC_COLLATE=C.UTF-8     LC_MONETARY=C.UTF-8    LC_MESSAGES=C.UTF-8   
##  [7] LC_PAPER=C.UTF-8       LC_NAME=C              LC_ADDRESS=C          
## [10] LC_TELEPHONE=C         LC_MEASUREMENT=C.UTF-8 LC_IDENTIFICATION=C   
## 
## time zone: UTC
## tzcode source: system (glibc)
## 
## attached base packages:
## [1] stats     graphics  grDevices utils     datasets  methods   base     
## 
## other attached packages:
##  [1] ISLR_1.4           yardstick_1.3.1    workflowsets_1.1.0 workflows_1.1.4   
##  [5] tune_1.2.1         tidyr_1.3.1        tibble_3.2.1       rsample_1.2.1     
##  [9] recipes_1.0.10     purrr_1.0.2        parsnip_1.2.1      modeldata_1.3.0   
## [13] infer_1.0.7        ggplot2_3.5.1      dplyr_1.1.4        dials_1.2.1       
## [17] scales_1.3.0       broom_1.0.5        tidymodels_1.2.0  
## 
## loaded via a namespace (and not attached):
##  [1] tidyselect_1.2.1    timeDate_4032.109   farver_2.1.2       
##  [4] fastmap_1.1.1       digest_0.6.35       rpart_4.1.23       
##  [7] timechange_0.3.0    lifecycle_1.0.4     ellipsis_0.3.2     
## [10] survival_3.5-8      magrittr_2.0.3      compiler_4.4.0     
## [13] rlang_1.1.3         sass_0.4.9          tools_4.4.0        
## [16] utf8_1.2.4          yaml_2.3.8          data.table_1.15.4  
## [19] knitr_1.46          labeling_0.4.3      DiceDesign_1.10    
## [22] withr_3.0.0         nnet_7.3-19         grid_4.4.0         
## [25] fansi_1.0.6         colorspace_2.1-0    future_1.33.2      
## [28] globals_0.16.3      iterators_1.0.14    MASS_7.3-60.2      
## [31] cli_3.6.2           rmarkdown_2.26      generics_0.1.3     
## [34] rstudioapi_0.16.0   future.apply_1.11.2 cachem_1.0.8       
## [37] splines_4.4.0       parallel_4.4.0      vctrs_0.6.5        
## [40] hardhat_1.3.1       Matrix_1.7-0        jsonlite_1.8.8     
## [43] listenv_0.9.1       foreach_1.5.2       gower_1.0.1        
## [46] jquerylib_0.1.4     glue_1.7.0          parallelly_1.37.1  
## [49] codetools_0.2-20    lubridate_1.9.3     gtable_0.3.5       
## [52] munsell_0.5.1       GPfit_1.0-8         pillar_1.9.0       
## [55] furrr_0.3.1         htmltools_0.5.8.1   ipred_0.9-14       
## [58] lava_1.8.0          R6_2.5.1            lhs_1.1.6          
## [61] evaluate_0.23       lattice_0.22-6      highr_0.10         
## [64] backports_1.4.1     bslib_0.7.0         class_7.3-22       
## [67] Rcpp_1.0.12         nlme_3.1-164        prodlim_2023.08.28 
## [70] mgcv_1.9-1          xfun_0.43           pkgconfig_2.0.3