Non-Linear Regression: Polynomial regression, Step functions, Regression splines, Smoothing splines, Local regression

  1. Polynomial regression: add extra predictors by raising X to successive powers
  2. Step functions: cut X into K regions and fit a piecewise-constant function (a minimal sketch follows this list)
  3. Regression splines: polynomial + step functions \(\Rightarrow\) divide X into K parts and fit a polynomial in each part
  4. Cubic splines: 3rd-degree polynomial splines, constrained to be continuous and smooth at the knots
  5. Local regression (loess) [Locally Estimated Scatterplot Smoothing, non-parametric]: similar to splines, but the regions are allowed to overlap in a smooth way
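
Step functions are not revisited in the code below, so here is a minimal sketch: cut() bins age into 4 intervals (break points chosen automatically) and lm() fits a separate constant in each; fit_step is just an illustrative name.

# Step function: piecewise-constant fit over 4 automatic age bins
library(ISLR)
data(Wage)
fit_step <- lm(wage ~ cut(age, 4), data = Wage)
coef(summary(fit_step))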
pacman::p_load(ISLR, jtools, splines, gam, mgcv, tidyverse)

data(Wage)
head(Wage)
##        year age           maritl     race       education             region
## 231655 2006  18 1. Never Married 1. White    1. < HS Grad 2. Middle Atlantic
## 86582  2004  24 1. Never Married 1. White 4. College Grad 2. Middle Atlantic
## 161300 2003  45       2. Married 1. White 3. Some College 2. Middle Atlantic
## 155159 2003  43       2. Married 3. Asian 4. College Grad 2. Middle Atlantic
## 11443  2005  50      4. Divorced 1. White      2. HS Grad 2. Middle Atlantic
## 376662 2008  54       2. Married 1. White 4. College Grad 2. Middle Atlantic
##              jobclass         health health_ins  logwage      wage
## 231655  1. Industrial      1. <=Good      2. No 4.318063  75.04315
## 86582  2. Information 2. >=Very Good      2. No 4.255273  70.47602
## 161300  1. Industrial      1. <=Good     1. Yes 4.875061 130.98218
## 155159 2. Information 2. >=Very Good     1. Yes 5.041393 154.68529
## 11443  2. Information      1. <=Good     1. Yes 4.318063  75.04315
## 376662 2. Information 2. >=Very Good     1. Yes 4.845098 127.11574
summary(Wage)
##       year           age                     maritl           race     
##  Min.   :2003   Min.   :18.00   1. Never Married: 648   1. White:2480  
##  1st Qu.:2004   1st Qu.:33.75   2. Married      :2074   2. Black: 293  
##  Median :2006   Median :42.00   3. Widowed      :  19   3. Asian: 190  
##  Mean   :2006   Mean   :42.41   4. Divorced     : 204   4. Other:  37  
##  3rd Qu.:2008   3rd Qu.:51.00   5. Separated    :  55                  
##  Max.   :2009   Max.   :80.00                                          
##                                                                        
##               education                     region               jobclass   
##  1. < HS Grad      :268   2. Middle Atlantic   :3000   1. Industrial :1544  
##  2. HS Grad        :971   1. New England       :   0   2. Information:1456  
##  3. Some College   :650   3. East North Central:   0                        
##  4. College Grad   :685   4. West North Central:   0                        
##  5. Advanced Degree:426   5. South Atlantic    :   0                        
##                           6. East South Central:   0                        
##                           (Other)              :   0                        
##             health      health_ins      logwage           wage       
##  1. <=Good     : 858   1. Yes:2083   Min.   :3.000   Min.   : 20.09  
##  2. >=Very Good:2142   2. No : 917   1st Qu.:4.447   1st Qu.: 85.38  
##                                      Median :4.653   Median :104.92  
##                                      Mean   :4.654   Mean   :111.70  
##                                      3rd Qu.:4.857   3rd Qu.:128.68  
##                                      Max.   :5.763   Max.   :318.34  
## 
# spline() interpolates the (x, y) pairs rather than smoothing them;
# tied age values are first collapsed to their mean wage, hence the warning
spline.df = as.data.frame(spline(Wage$age, Wage$wage))
## Warning in regularize.values(x, y, ties, missing(ties)): collapsing to unique
## 'x' values
head(spline.df)
##          x        y
## 1 18.00000 64.49306
## 2 18.34066 55.78684
## 3 18.68132 52.79561
## 4 19.02198 54.18650
## 5 19.36264 58.55654
## 6 19.70330 64.27147
ggplot(Wage, aes(x = age, y = wage)) +
  geom_point(alpha = 0.4, color = 'steelblue') +
  # local regression (loess)
  geom_smooth(method = 'loess', formula = y ~ x, color = 'grey30', se = FALSE) +
  # simple linear fit, with its confidence band
  geom_smooth(method = 'lm', formula = y ~ x, color = 'green', se = TRUE) +
  # 4th-degree polynomial
  geom_smooth(method = 'lm', formula = y ~ poly(x, 4), color = 'orange', se = FALSE) +
  # cubic B-spline with knots at ages 25, 40, 50
  geom_smooth(method = 'lm', formula = y ~ bs(x, knots = c(25, 40, 50)), se = FALSE, linewidth = 2, linetype = 'dashed') +
  # interpolating spline computed above
  geom_line(data = spline.df, aes(x = x, y = y), color = 'black', linetype = 'solid', linewidth = 1) +
  theme_apa()

# Fit linear model and check R2
fit_linear = lm(wage ~ age, Wage)
summ(fit_linear)
Observations         3000
Dependent variable   wage
Type                 OLS linear regression

F(1,2998)   119.31
R²          0.04
Adj. R²     0.04

              Est.   S.E.   t val.      p
(Intercept)  81.70   2.85    28.71   0.00
age           0.71   0.06    10.92   0.00

Standard errors: OLS
# Fit 4th degree polynomial and check R2
fit_poly = lm(wage ~ poly(age, 4), Wage)
summ(fit_poly)
Observations         3000
Dependent variable   wage
Type                 OLS linear regression

F(4,2995)   70.69
R²          0.09
Adj. R²     0.09

                   Est.    S.E.   t val.      p
(Intercept)      111.70    0.73   153.28   0.00
poly(age, 4)1    447.07   39.91    11.20   0.00
poly(age, 4)2   -478.32   39.91   -11.98   0.00
poly(age, 4)3    125.52   39.91     3.14   0.00
poly(age, 4)4    -77.91   39.91    -1.95   0.05

Standard errors: OLS
# Fit a regression spline with bs() (B-spline basis functions) and check R2
# We can set the knots manually
fit_splines = lm(wage ~ bs(age, knots = c(25, 40, 60)), Wage)
summ(fit_splines)
Observations         3000
Dependent variable   wage
Type                 OLS linear regression

F(6,2993)   47.19
R²          0.09
Adj. R²     0.08

                                    Est.    S.E.   t val.      p
(Intercept)                        60.49    9.46    6.39    0.00
bs(age, knots = c(25, 40, 60))1     3.98   12.54    0.32    0.75
bs(age, knots = c(25, 40, 60))2    44.63    9.63    4.64    0.00
bs(age, knots = c(25, 40, 60))3    62.84   10.76    5.84    0.00
bs(age, knots = c(25, 40, 60))4    55.99   10.71    5.23    0.00
bs(age, knots = c(25, 40, 60))5    50.69   14.40    3.52    0.00
bs(age, knots = c(25, 40, 60))6    16.61   19.13    0.87    0.39

Standard errors: OLS
# Fit splines using the df shortcut instead of manual knots, and check R2
# A cubic spline with three knots has 7 df (intercept + 6 basis functions)
# The df option sets knots at uniform quantiles; df = 6 places three knots
# at the age quartiles (33.75, 42, 51)

fit_splines_df = lm(wage ~ bs(age, df = 6), Wage)
summ(fit_splines_df)
Observations         3000
Dependent variable   wage
Type                 OLS linear regression

F(6,2993)   47.71
R²          0.09
Adj. R²     0.09

                     Est.    S.E.   t val.      p
(Intercept)         56.31    7.26    7.76    0.00
bs(age, df = 6)1    27.82   12.43    2.24    0.03
bs(age, df = 6)2    54.06    7.13    7.59    0.00
bs(age, df = 6)3    65.83    8.32    7.91    0.00
bs(age, df = 6)4    55.81    8.72    6.40    0.00
bs(age, df = 6)5    72.13   13.74    5.25    0.00
bs(age, df = 6)6    14.75   16.21    0.91    0.36

Standard errors: OLS
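
To confirm where the df shortcut placed the interior knots, we can inspect the "knots" attribute that bs() stores on its basis matrix; the values should match the quantiles quoted above.

# Interior knots chosen by bs() for df = 6 (uniform quantiles of age)
attr(bs(Wage$age, df = 6), "knots")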
# Natural splines (ns) are constrained to be linear beyond the boundary knots;
# the added constraints mean df = 4 yields only 4 basis coefficients
fit_splines_ns = lm(wage ~ ns(age, df = 4), Wage)
summ(fit_splines_ns)
Observations         3000
Dependent variable   wage
Type                 OLS linear regression

F(4,2995)   70.43
R²          0.09
Adj. R²     0.08

                     Est.    S.E.   t val.      p
(Intercept)         58.56    4.24   13.83    0.00
ns(age, df = 4)1    60.46    4.19   14.43    0.00
ns(age, df = 4)2    41.96    4.37    9.60    0.00
ns(age, df = 4)3    97.02   10.39    9.34    0.00
ns(age, df = 4)4     9.77    8.66    1.13    0.26

Standard errors: OLS
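
The payoff of the natural-spline constraint shows at the boundaries: the fit is forced to be linear beyond the outermost knots, which stabilizes predictions at extreme ages. A quick check (extending the grid past the observed 18-80 age range is an arbitrary choice):

# Predict a little beyond the observed age range to see the linear tails
age_grid <- data.frame(age = seq(15, 90, by = 1))
plot(age_grid$age, predict(fit_splines_ns, newdata = age_grid),
     type = 'l', xlab = 'age', ylab = 'predicted wage')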
# Fit local regression (loess) and check R2
# GAMs = Generalized Additive Models
# gam() is found in 2 packages; pick the one you want with mgcv::gam or gam::gam
# span is the hyper-parameter of loess: the fraction of points used in each
# local fit. We can vary it and compare results to find a good span value
# (a quick scan is sketched after the summary below).

fit_loess = loess(wage ~ age, data = Wage, span = 0.5)
summary(fit_loess)
## Call:
## loess(formula = wage ~ age, data = Wage, span = 0.5)
## 
## Number of Observations: 3000 
## Equivalent Number of Parameters: 7.13 
## Residual Standard Error: 39.89 
## Trace of smoother matrix: 7.85  (exact)
## 
## Control settings:
##   span     :  0.5 
##   degree   :  2 
##   family   :  gaussian
##   surface  :  interpolate      cell = 0.2
##   normalize:  TRUE
##  parametric:  FALSE
## drop.square:  FALSE
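
As promised above, a quick in-sample scan over a few candidate spans (the grid is arbitrary; training error alone will always favor small spans, so treat this as exploratory):

# Smaller span = wigglier local fits; compare training RMSE across spans
for (sp in c(0.2, 0.5, 0.8)) {
  fit <- loess(wage ~ age, data = Wage, span = sp)
  cat('span =', sp, ' RMSE =', round(sqrt(mean(fit$residuals^2)), 3), '\n')
}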
# lo() = local regression term from the gam package
# Caution: mgcv::gam does not recognize lo() as a smooth, so age enters this
# model as an ordinary linear term (note the single parametric coefficient
# below, which matches fit_linear). See the gam::gam sketch after the summary.
fit_loess_gam = mgcv::gam(wage ~ lo(age, span = 0.7), data = Wage)

# Check the adjusted R2
summary(fit_loess_gam)
## 
## Family: gaussian 
## Link function: identity 
## 
## Formula:
## wage ~ lo(age, span = 0.7)
## 
## Parametric coefficients:
##                     Estimate Std. Error t value Pr(>|t|)    
## (Intercept)         81.70474    2.84624   28.71   <2e-16 ***
## lo(age, span = 0.7)  0.70728    0.06475   10.92   <2e-16 ***
## ---
## Signif. codes:  0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1
## 
## 
## R-sq.(adj) =  0.038   Deviance explained = 3.83%
## GCV = 1676.3  Scale est. = 1675.2    n = 3000
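
For the local regression the formula intends, the gam package's own gam() does interpret lo(); a sketch of that call (output not shown here):

# gam::gam fits lo() as a true local-regression smooth
fit_lo_gam <- gam::gam(wage ~ lo(age, span = 0.7), data = Wage)
summary(fit_lo_gam)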
# Fit a smooth of year and lo(age) in the same model
# As above, mgcv treats lo(age) as a plain linear term, so only s(year) is a
# true smooth; the adjusted R² is slightly better than the age-only GAM above
fit_slo <- mgcv::gam(wage ~ s(year, k=4) + lo(age, span = 0.5),  data = Wage)

summary(fit_slo)
## 
## Family: gaussian 
## Link function: identity 
## 
## Formula:
## wage ~ s(year, k = 4) + lo(age, span = 0.5)
## 
## Parametric coefficients:
##                     Estimate Std. Error t value Pr(>|t|)    
## (Intercept)          82.0340     2.8437   28.85   <2e-16 ***
## lo(age, span = 0.5)   0.6995     0.0647   10.81   <2e-16 ***
## ---
## Signif. codes:  0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1
## 
## Approximate significance of smooth terms:
##           edf Ref.df     F p-value   
## s(year) 1.206  1.379 6.934 0.00308 **
## ---
## Signif. codes:  0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1
## 
## R-sq.(adj) =  0.0411   Deviance explained = 4.18%
## GCV = 1671.5  Scale est. = 1669.7    n = 3000
plot(fit_slo)
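
Staying entirely within mgcv, replacing lo(age) with s(age) makes both terms penalized smooths and lets mgcv choose their wiggliness by GCV; a sketch (the k value is an assumption, not tuned):

# All-mgcv version: penalized regression splines for both predictors
fit_ss <- mgcv::gam(wage ~ s(year, k = 4) + s(age), data = Wage)
summary(fit_ss)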

Now compare the RMSE of the different models and select the one with the lowest RMSE.

# Linear
fit_linear$residuals^2 %>% mean() %>% sqrt()
## [1] 40.91543
# Polynomial
fit_poly$residuals^2 %>% mean() %>% sqrt()
## [1] 39.88151
# B-spline
fit_splines$residuals^2 %>% mean() %>% sqrt()
## [1] 39.87805
# B-spline df
fit_splines_df$residuals^2 %>% mean() %>% sqrt()
## [1] 39.8591
# Natural spline
fit_splines_ns$residuals^2 %>% mean() %>% sqrt()
## [1] 39.8878
# Loess
fit_loess$residuals^2 %>% mean() %>% sqrt()
## [1] 39.83456
# Loess GAM (lo() entered linearly, so this matches the plain linear model)
fit_loess_gam$residuals^2 %>% mean() %>% sqrt()
## [1] 40.91543
# Spline (year) + lo(age) GAM; age again enters linearly
fit_slo$residuals^2 %>% mean() %>% sqrt()
## [1] 40.84052

So plain loess regression attains the lowest training RMSE here (39.83). Note, though, that these are all in-sample errors, which favor flexible fits; a held-out comparison is sketched below.
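
A minimal 5-fold cross-validation sketch for the loess model (fold count, seed, and span are arbitrary choices; surface = 'direct' lets loess predict on new data):

# Held-out RMSE for loess via 5-fold CV
set.seed(1)
folds <- sample(rep(1:5, length.out = nrow(Wage)))
cv_rmse <- sapply(1:5, function(k) {
  fit <- loess(wage ~ age, data = Wage[folds != k, ], span = 0.5,
               control = loess.control(surface = 'direct'))
  pred <- predict(fit, newdata = Wage[folds == k, ])
  sqrt(mean((Wage$wage[folds == k] - pred)^2))
})
mean(cv_rmse)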