We will use the ames data set from the modeldata package. It can be loaded using the following code:

Ames Housing Data

# Attach the modelling and data-wrangling packages used throughout.
library(tidymodels)
library(tidyverse)
# Load the ames housing data set into the session.
data("ames")
# Print the tibble to inspect its dimensions and column types.
ames
## # A tibble: 2,930 x 74
##    MS_SubClass      MS_Zoning    Lot_Frontage Lot_Area Street Alley   Lot_Shape 
##  * <fct>            <fct>               <dbl>    <int> <fct>  <fct>   <fct>     
##  1 One_Story_1946_… Residential…          141    31770 Pave   No_All… Slightly_…
##  2 One_Story_1946_… Residential…           80    11622 Pave   No_All… Regular   
##  3 One_Story_1946_… Residential…           81    14267 Pave   No_All… Slightly_…
##  4 One_Story_1946_… Residential…           93    11160 Pave   No_All… Regular   
##  5 Two_Story_1946_… Residential…           74    13830 Pave   No_All… Slightly_…
##  6 Two_Story_1946_… Residential…           78     9978 Pave   No_All… Slightly_…
##  7 One_Story_PUD_1… Residential…           41     4920 Pave   No_All… Regular   
##  8 One_Story_PUD_1… Residential…           43     5005 Pave   No_All… Slightly_…
##  9 One_Story_PUD_1… Residential…           39     5389 Pave   No_All… Slightly_…
## 10 Two_Story_1946_… Residential…           60     7500 Pave   No_All… Regular   
## # … with 2,920 more rows, and 67 more variables: Land_Contour <fct>,
## #   Utilities <fct>, Lot_Config <fct>, Land_Slope <fct>, Neighborhood <fct>,
## #   Condition_1 <fct>, Condition_2 <fct>, Bldg_Type <fct>, House_Style <fct>,
## #   Overall_Cond <fct>, Year_Built <int>, Year_Remod_Add <int>,
## #   Roof_Style <fct>, Roof_Matl <fct>, Exterior_1st <fct>, Exterior_2nd <fct>,
## #   Mas_Vnr_Type <fct>, Mas_Vnr_Area <dbl>, Exter_Cond <fct>, Foundation <fct>,
## #   Bsmt_Cond <fct>, Bsmt_Exposure <fct>, BsmtFin_Type_1 <fct>,
## #   BsmtFin_SF_1 <dbl>, BsmtFin_Type_2 <fct>, BsmtFin_SF_2 <dbl>,
## #   Bsmt_Unf_SF <dbl>, Total_Bsmt_SF <dbl>, Heating <fct>, Heating_QC <fct>,
## #   Central_Air <fct>, Electrical <fct>, First_Flr_SF <int>,
## #   Second_Flr_SF <int>, Gr_Liv_Area <int>, Bsmt_Full_Bath <dbl>,
## #   Bsmt_Half_Bath <dbl>, Full_Bath <int>, Half_Bath <int>,
## #   Bedroom_AbvGr <int>, Kitchen_AbvGr <int>, TotRms_AbvGrd <int>,
## #   Functional <fct>, Fireplaces <int>, Garage_Type <fct>, Garage_Finish <fct>,
## #   Garage_Cars <dbl>, Garage_Area <dbl>, Garage_Cond <fct>, Paved_Drive <fct>,
## #   Wood_Deck_SF <int>, Open_Porch_SF <int>, Enclosed_Porch <int>,
## #   Three_season_porch <int>, Screen_Porch <int>, Pool_Area <int>,
## #   Pool_QC <fct>, Fence <fct>, Misc_Feature <fct>, Misc_Val <int>,
## #   Mo_Sold <int>, Year_Sold <int>, Sale_Type <fct>, Sale_Condition <fct>,
## #   Sale_Price <int>, Longitude <dbl>, Latitude <dbl>

We will try to predict the Sale_Price of a house from the Longitude of its location (this would not be the best idea alone, but serves as an example). We use step_bs() to fit a spline to Longitude, and use cross-validation to find the value of degree where the model performs best.

Data Visualization

The plot below shows the underlying relationship between Longitude and Sale_Price; it does not follow a linear trend.

# Scatter plot of sale price against longitude with a smoothed trend line
# (no confidence band) to show how non-linear the relationship is.
ames %>%
  ggplot(aes(x = Longitude, y = Sale_Price)) +
  geom_point() +
  geom_smooth(se = FALSE) +
  theme_bw()
## `geom_smooth()` using method = 'gam' and formula 'y ~ s(x, bs = "cs")'

step_bs()

To begin with, we will split ames into training and testing sets.

# Fix the RNG so the train/test partition is reproducible, then split ames
# using initial_split()'s default settings.
set.seed(1234)
ames_split <- ames %>% initial_split()
ames_train <- ames_split %>% training()
ames_test <- ames_split %>% testing()

Construct the linear regression model specification, the basis-spline recipe, and the basis-spline workflow.

# Ordinary least-squares regression, fitted with the "lm" engine.
lm_spec <- linear_reg() %>%
  set_engine("lm") %>%
  set_mode("regression")

# Expand Longitude into a B-spline basis; the degree is left as a tuning
# parameter to be chosen by cross-validation.
rec_bs <- recipe(Sale_Price ~ Longitude, data = ames_train) %>%
  step_bs(Longitude, degree = tune())

# Bundle the preprocessing recipe and the model into a single workflow.
wf_bs <- workflow() %>%
  add_recipe(rec_bs) %>%
  add_model(lm_spec)

Create 10-fold cross-validation resamples from the training data set.

# Seed the RNG so fold assignment is reproducible, then build 10 folds
# stratified on the outcome.
set.seed(4321)
ames_folds <- ames_train %>%
  vfold_cv(v = 10, strata = Sale_Price)
ames_folds
## #  10-fold cross-validation using stratification 
## # A tibble: 10 x 2
##    splits             id    
##    <list>             <chr> 
##  1 <split [1977/221]> Fold01
##  2 <split [1977/221]> Fold02
##  3 <split [1977/221]> Fold03
##  4 <split [1977/221]> Fold04
##  5 <split [1978/220]> Fold05
##  6 <split [1978/220]> Fold06
##  7 <split [1979/219]> Fold07
##  8 <split [1979/219]> Fold08
##  9 <split [1980/218]> Fold09
## 10 <split [1980/218]> Fold10
# param_grid <- grid_regular(degree_int(range = c(1, 5)), levels = 5)

Make a data frame of candidate degree values ranging from 1 to 10. Each candidate will be evaluated with cross-validation to find the spline degree that fits best.

# Candidate spline degrees to evaluate: the integers 1 through 10.
# An alternative that was tried and left disabled:
# param_grid <- grid_regular(degree_int(range = c(1, 10)), levels = 20)
param_grid <- tibble(degree = seq_len(10))
param_grid
## # A tibble: 10 x 1
##    degree
##     <int>
##  1      1
##  2      2
##  3      3
##  4      4
##  5      5
##  6      6
##  7      7
##  8      8
##  9      9
## 10     10
# Fit the spline workflow for every candidate degree on every fold,
# printing progress as it goes.
tune_res <- wf_bs %>%
  tune_grid(
    resamples = ames_folds,
    grid = param_grid,
    control = control_grid(verbose = TRUE)
  )
# Uncomment to inspect the notes recorded for the first resample:
# tune_res$.notes[[1]]

Our goal is to find the lowest value of rmse and the highest value of rsq. Looking at the results and the plot below, a spline degree of 10 gives the best rmse and rsq.

# Resample-averaged rmse and rsq for each candidate degree.
collect_metrics(tune_res)
## # A tibble: 20 x 7
##    degree .metric .estimator       mean     n   std_err .config              
##     <int> <chr>   <chr>           <dbl> <int>     <dbl> <chr>                
##  1      1 rmse    standard   79375.        10 2817.     Preprocessor01_Model1
##  2      1 rsq     standard       0.0690    10    0.0107 Preprocessor01_Model1
##  3      2 rmse    standard   76476.        10 2743.     Preprocessor02_Model1
##  4      2 rsq     standard       0.135     10    0.0140 Preprocessor02_Model1
##  5      3 rmse    standard   76434.        10 2726.     Preprocessor03_Model1
##  6      3 rsq     standard       0.137     10    0.0155 Preprocessor03_Model1
##  7      4 rmse    standard   73537.        10 2793.     Preprocessor04_Model1
##  8      4 rsq     standard       0.201     10    0.0173 Preprocessor04_Model1
##  9      5 rmse    standard   73445.        10 2737.     Preprocessor05_Model1
## 10      5 rsq     standard       0.204     10    0.0180 Preprocessor05_Model1
## 11      6 rmse    standard   72440.        10 2867.     Preprocessor06_Model1
## 12      6 rsq     standard       0.225     10    0.0199 Preprocessor06_Model1
## 13      7 rmse    standard   71347.        10 2837.     Preprocessor07_Model1
## 14      7 rsq     standard       0.250     10    0.0216 Preprocessor07_Model1
## 15      8 rmse    standard   71355.        10 2767.     Preprocessor08_Model1
## 16      8 rsq     standard       0.250     10    0.0210 Preprocessor08_Model1
## 17      9 rmse    standard   71082.        10 2773.     Preprocessor09_Model1
## 18      9 rsq     standard       0.256     10    0.0217 Preprocessor09_Model1
## 19     10 rmse    standard   70819.        10 2704.     Preprocessor10_Model1
## 20     10 rsq     standard       0.261     10    0.0209 Preprocessor10_Model1
# Plot the metrics against degree and mark the chosen degree (10) in red.
tune_res %>%
  autoplot() +
  geom_vline(xintercept = 10, color = "red")

# Top candidates ranked by cross-validated rmse (lower is better).
show_best(tune_res, metric = "rmse")
## # A tibble: 5 x 7
##   degree .metric .estimator   mean     n std_err .config              
##    <int> <chr>   <chr>       <dbl> <int>   <dbl> <chr>                
## 1     10 rmse    standard   70819.    10   2704. Preprocessor10_Model1
## 2      9 rmse    standard   71082.    10   2773. Preprocessor09_Model1
## 3      7 rmse    standard   71347.    10   2837. Preprocessor07_Model1
## 4      8 rmse    standard   71355.    10   2767. Preprocessor08_Model1
## 5      6 rmse    standard   72440.    10   2867. Preprocessor06_Model1
# Top candidates ranked by cross-validated rsq (higher is better).
show_best(tune_res, metric = "rsq")
## # A tibble: 5 x 7
##   degree .metric .estimator  mean     n std_err .config              
##    <int> <chr>   <chr>      <dbl> <int>   <dbl> <chr>                
## 1     10 rsq     standard   0.261    10  0.0209 Preprocessor10_Model1
## 2      9 rsq     standard   0.256    10  0.0217 Preprocessor09_Model1
## 3      7 rsq     standard   0.250    10  0.0216 Preprocessor07_Model1
## 4      8 rsq     standard   0.250    10  0.0210 Preprocessor08_Model1
## 5      6 rsq     standard   0.225    10  0.0199 Preprocessor06_Model1

We will take the best root mean squared error solution to fit the model. That is, the spline degree is 10.

# Plug the rmse-optimal degree into the workflow, refit on the full
# training set, and list the fitted coefficients.
final_wf_bs <- wf_bs %>%
  finalize_workflow(select_best(tune_res, metric = "rmse"))
final_fit_bs <- final_wf_bs %>%
  fit(data = ames_train)
tidy(final_fit_bs)
## # A tibble: 11 x 5
##    term             estimate std.error statistic  p.value
##    <chr>               <dbl>     <dbl>     <dbl>    <dbl>
##  1 (Intercept)       205881.    20599.     9.99  4.98e-23
##  2 Longitude_bs_01  -264013.   139249.    -1.90  5.81e- 2
##  3 Longitude_bs_02  1784511.   420087.     4.25  2.25e- 5
##  4 Longitude_bs_03 -5852191.  1040461.    -5.62  2.10e- 8
##  5 Longitude_bs_04  9640798.  1677426.     5.75  1.03e- 8
##  6 Longitude_bs_05 -9625273.  2026118.    -4.75  2.16e- 6
##  7 Longitude_bs_06  6759761.  1768430.     3.82  1.36e- 4
##  8 Longitude_bs_07 -3507083.  1155499.    -3.04  2.43e- 3
##  9 Longitude_bs_08  1123711.   535723.     2.10  3.61e- 2
## 10 Longitude_bs_09  -354138.   189119.    -1.87  6.13e- 2
## 11 Longitude_bs_10   -22570.    50954.    -0.443 6.58e- 1
# Score the finalized spline fit on the held-out test set and tag the row
# with a label so it can be compared with the other models.
# Left-assignment keeps the target name visible at the start of the pipe.
bs_model <- augment(final_fit_bs, new_data = ames_test) %>%
  rmse(truth = Sale_Price, estimate = .pred) %>%
  mutate(note = "bs model")
## Warning in bs(x = c(-93.638925, -93.636947, -93.638647, -93.622971,
## -93.639366, : some 'x' values beyond boundary knots may cause ill-conditioned
## bases
bs_model
## # A tibble: 1 x 4
##   .metric .estimator .estimate note    
##   <chr>   <chr>          <dbl> <chr>   
## 1 rmse    standard    1193156. bs model

Next, we will use step_discretize() and step_cut() to fit step functions to Longitude and see whether that works better.

step_discretize()

# Bin Longitude into a tunable number of intervals (num_breaks is chosen
# by cross-validation) and pair that recipe with the linear model.
rec_discretize <- recipe(Sale_Price ~ Longitude, data = ames_train) %>%
  step_discretize(Longitude, num_breaks = tune())
wf_discretize <- workflow() %>%
  add_recipe(rec_discretize) %>%
  add_model(lm_spec)
# Candidate numbers of bins for step_discretize(). The grid starts at 2
# because num_breaks = 1 is rejected ("There should be at least 2 cuts"),
# so that candidate would fail in every fold and drop out of the results.
param_grid2 <- tibble(num_breaks = 2:10)
param_grid2
## # A tibble: 9 x 1
##   num_breaks
##        <int>
## 1          2
## 2          3
## 3          4
## 4          5
## 5          6
## 6          7
## 7          8
## 8          9
## 9         10
# Evaluate every candidate num_breaks on the same folds used for the
# spline model, printing progress as it goes.
tune_res2 <- wf_discretize %>%
  tune_grid(
    resamples = ames_folds,
    grid = param_grid2,
    control = control_grid(verbose = TRUE)
  )

# tune_res2$.notes # Error: There should be at least 2 cuts

We will take the best root mean squared error solution to fit the model. That is, the number of breaks (num_breaks) is 8.

# Resample-averaged rmse and rsq for each candidate num_breaks.
collect_metrics(tune_res2)
## # A tibble: 18 x 7
##    num_breaks .metric .estimator       mean     n   std_err .config             
##         <int> <chr>   <chr>           <dbl> <int>     <dbl> <chr>               
##  1          2 rmse    standard   78278.        10 2719.     Preprocessor02_Mode…
##  2          2 rsq     standard       0.0948    10    0.0121 Preprocessor02_Mode…
##  3          3 rmse    standard   76033.        10 2820.     Preprocessor03_Mode…
##  4          3 rsq     standard       0.144     10    0.0128 Preprocessor03_Mode…
##  5          4 rmse    standard   73555.        10 2443.     Preprocessor04_Mode…
##  6          4 rsq     standard       0.201     10    0.0143 Preprocessor04_Mode…
##  7          5 rmse    standard   73965.        10 2682.     Preprocessor05_Mode…
##  8          5 rsq     standard       0.195     10    0.0220 Preprocessor05_Mode…
##  9          6 rmse    standard   75489.        10 2726.     Preprocessor06_Mode…
## 10          6 rsq     standard       0.157     10    0.0142 Preprocessor06_Mode…
## 11          7 rmse    standard   72754.        10 3146.     Preprocessor07_Mode…
## 12          7 rsq     standard       0.220     10    0.0268 Preprocessor07_Mode…
## 13          8 rmse    standard   68920.        10 2522.     Preprocessor08_Mode…
## 14          8 rsq     standard       0.303     10    0.0248 Preprocessor08_Mode…
## 15          9 rmse    standard   72079.        10 2528.     Preprocessor09_Mode…
## 16          9 rsq     standard       0.235     10    0.0190 Preprocessor09_Mode…
## 17         10 rmse    standard   72025.        10 2914.     Preprocessor10_Mode…
## 18         10 rsq     standard       0.235     10    0.0223 Preprocessor10_Mode…
# Plot the metrics against num_breaks and mark the chosen value (8) in red.
tune_res2 %>%
  autoplot() +
  geom_vline(xintercept = 8, color = "red")

# Top candidates ranked by cross-validated rmse (lower is better).
show_best(tune_res2, metric = "rmse")
## # A tibble: 5 x 7
##   num_breaks .metric .estimator   mean     n std_err .config              
##        <int> <chr>   <chr>       <dbl> <int>   <dbl> <chr>                
## 1          8 rmse    standard   68920.    10   2522. Preprocessor08_Model1
## 2         10 rmse    standard   72025.    10   2914. Preprocessor10_Model1
## 3          9 rmse    standard   72079.    10   2528. Preprocessor09_Model1
## 4          7 rmse    standard   72754.    10   3146. Preprocessor07_Model1
## 5          4 rmse    standard   73555.    10   2443. Preprocessor04_Model1
# Top candidates ranked by cross-validated rsq (higher is better).
show_best(tune_res2, metric = "rsq")
## # A tibble: 5 x 7
##   num_breaks .metric .estimator  mean     n std_err .config              
##        <int> <chr>   <chr>      <dbl> <int>   <dbl> <chr>                
## 1          8 rsq     standard   0.303    10  0.0248 Preprocessor08_Model1
## 2          9 rsq     standard   0.235    10  0.0190 Preprocessor09_Model1
## 3         10 rsq     standard   0.235    10  0.0223 Preprocessor10_Model1
## 4          7 rsq     standard   0.220    10  0.0268 Preprocessor07_Model1
## 5          4 rsq     standard   0.201    10  0.0143 Preprocessor04_Model1
# Plug the rmse-optimal num_breaks into the workflow, refit on the full
# training set, and list the fitted coefficients.
final_wf_discretize <- wf_discretize %>%
  finalize_workflow(select_best(tune_res2, metric = "rmse"))
final_fit_discretize <- final_wf_discretize %>%
  fit(data = ames_train)
tidy(final_fit_discretize)
## # A tibble: 8 x 5
##   term          estimate std.error statistic  p.value
##   <chr>            <dbl>     <dbl>     <dbl>    <dbl>
## 1 (Intercept)    198258.     4204.    47.2   0       
## 2 Longitudebin2  -48830.     5946.    -8.21  3.66e-16
## 3 Longitudebin3   76389.     5951.    12.8   2.06e-36
## 4 Longitudebin4    5200.     5946.     0.875 3.82e- 1
## 5 Longitudebin5    1796.     5946.     0.302 7.63e- 1
## 6 Longitudebin6  -46094.     5951.    -7.75  1.45e-14
## 7 Longitudebin7  -59490.     5946.   -10.0   4.49e-23
## 8 Longitudebin8  -61947.     5946.   -10.4   7.70e-25
# Score the finalized discretize fit on the held-out test set and tag the
# row with a label so it can be compared with the other models.
# Left-assignment keeps the target name visible at the start of the pipe.
discretize_model <- augment(final_fit_discretize, new_data = ames_test) %>%
  rmse(truth = Sale_Price, estimate = .pred) %>%
  mutate(note = "discretize model")
discretize_model
## # A tibble: 1 x 4
##   .metric .estimator .estimate note            
##   <chr>   <chr>          <dbl> <chr>           
## 1 rmse    standard      57687. discretize model

step_cut()

Now, we can supply the breaks manually.

# Step-function recipe with hand-picked cut points on Longitude.
rec_cut <- recipe(Sale_Price ~ Longitude, data = ames_train) %>%
  step_cut(Longitude, breaks = c(-93.675, -93.650, -93.625))

# Uncomment to list the smallest Longitude values in the training set:
# ames_train %>%
#   select(Longitude) %>%
#   arrange(desc(Longitude)) %>%
#   tail()

# Nothing to tune here, so fit the workflow directly on the training set
# and list the fitted coefficients.
wf_cut <- workflow() %>%
  add_recipe(rec_cut) %>%
  add_model(lm_spec)
final_fit_cut <- wf_cut %>%
  fit(data = ames_train)
tidy(final_fit_cut)
## # A tibble: 4 x 5
##   term                     estimate std.error statistic  p.value
##   <chr>                       <dbl>     <dbl>     <dbl>    <dbl>
## 1 (Intercept)               188229.     4027.    46.7   0       
## 2 Longitude(-93.67,-93.65]   34441.     5383.     6.40  1.91e-10
## 3 Longitude(-93.65,-93.62]    4212.     4907.     0.858 3.91e- 1
## 4 Longitude(-93.62,-93.6]   -51522.     5019.   -10.3   3.50e-24
# Score the manual-cut fit on the held-out test set and tag the row with a
# label so it can be compared with the other models.
# Left-assignment keeps the target name visible at the start of the pipe.
cut_model <- augment(final_fit_cut, new_data = ames_test) %>%
  rmse(truth = Sale_Price, estimate = .pred) %>%
  mutate(note = "cut model")
cut_model
## # A tibble: 1 x 4
##   .metric .estimator .estimate note     
##   <chr>   <chr>          <dbl> <chr>    
## 1 rmse    standard      66978. cut model

Comparison

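The discretize model has the best performance in predicting Sale_Price based on Longitude because it has the smallest test-set rmse. Note that the spline model's test rmse is far larger than its cross-validated rmse (about 70,819). This is most likely because some test-set Longitude values fall outside the boundary knots estimated from the training data (see the warning printed when the spline predictions were made), where a degree-10 basis extrapolates very poorly.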

# Stack the three one-row test-set rmse summaries into a single table.
list(bs_model, discretize_model, cut_model) %>%
  bind_rows()
## # A tibble: 3 x 4
##   .metric .estimator .estimate note            
##   <chr>   <chr>          <dbl> <chr>           
## 1 rmse    standard    1193156. bs model        
## 2 rmse    standard      57687. discretize model
## 3 rmse    standard      66978. cut model