Libraries

library(tidymodels)
library(ISLR)      # Smarket dataset
library(corrr)

4.1 The Stock Market Data

We examine the Smarket dataset, which contains daily percentage returns for the S&P 500 index over 2001–2005, along with a Direction variable ("Up" / "Down").

Correlation Matrix

We remove the non-numeric Direction column, then compute pairwise Pearson correlations.

# Pairwise Pearson correlations between every numeric column of Smarket.
# Direction is the only non-numeric column, so it is the one left out.
cor_Smarket <- Smarket %>%
  select(where(is.numeric)) %>%
  correlate(method = "pearson")

cor_Smarket
## # A tibble: 8 × 9
##   term      Year     Lag1     Lag2     Lag3     Lag4     Lag5  Volume    Today
##   <chr>    <dbl>    <dbl>    <dbl>    <dbl>    <dbl>    <dbl>   <dbl>    <dbl>
## 1 Year   NA       0.0297   0.0306   0.0332   0.0357   0.0298   0.539   0.0301 
## 2 Lag1    0.0297 NA       -0.0263  -0.0108  -0.00299 -0.00567  0.0409 -0.0262 
## 3 Lag2    0.0306 -0.0263  NA       -0.0259  -0.0109  -0.00356 -0.0434 -0.0103 
## 4 Lag3    0.0332 -0.0108  -0.0259  NA       -0.0241  -0.0188  -0.0418 -0.00245
## 5 Lag4    0.0357 -0.00299 -0.0109  -0.0241  NA       -0.0271  -0.0484 -0.00690
## 6 Lag5    0.0298 -0.00567 -0.00356 -0.0188  -0.0271  NA       -0.0220 -0.0349 
## 7 Volume  0.539   0.0409  -0.0434  -0.0418  -0.0484  -0.0220  NA       0.0146 
## 8 Today   0.0301 -0.0262  -0.0103  -0.00245 -0.00690 -0.0349   0.0146 NA

Correlation Plot

# Dot plot of the correlation matrix: red = negative, blue = positive.
cor_Smarket %>% rplot(colours = c("indianred2", "black", "skyblue1"))

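Most variables are nearly uncorrelated with each other. In particular, the lagged returns are almost uncorrelated with Today's return (every |r| < 0.04), which already hints that they will be weak predictors. The notable exception is Year and Volume (r ≈ 0.54), which show a moderate positive relationship.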

Heatmap-style Correlation Chart

# stretch() turns the correlation matrix into long form (columns x, y, r),
# one row per variable pair; diagonal entries have r = NA.
# Each pair becomes a tile coloured on a diverging red-white-blue scale
# fixed to [-1, 1], labelled with the formatted correlation.
stretch(cor_Smarket) %>%
  ggplot(aes(x = x, y = y, fill = r)) +
  geom_tile() +
  geom_text(aes(label = as.character(fashion(r)))) +
  scale_fill_gradient2(
    low = "indianred2",
    mid = "white",
    high = "skyblue1",
    midpoint = 0,
    limits = c(-1, 1)
  ) +
  labs(title = "Correlation heatmap – Smarket", x = NULL, y = NULL)

Volume Over Time

Plotting Year vs Volume confirms an upward trend in trading volume.

# Daily trading volume for each year. Points are jittered horizontally only
# (height = 0), so the Volume values are plotted exactly; the default width
# spreads them within each year. The fixed seed makes the jitter, and hence
# the rendered figure, identical on every re-knit. geom_jitter() would draw a
# fresh random layout each time.
ggplot(Smarket, aes(Year, Volume)) +
  geom_point(
    position = position_jitter(height = 0, seed = 1),
    alpha = 0.4,
    colour = "steelblue"
  ) +
  labs(
    title = "Trading Volume by Year",
    x = "Year",
    y = "Volume (billions of shares traded)"
  )


4.2 Logistic Regression

Model Specification

We use logistic_reg() from parsnip with the glm engine (the default for logistic regression).

# Logistic-regression specification: classification mode, fitted with
# stats::glm() via the "glm" engine.
lr_spec <- logistic_reg(mode = "classification", engine = "glm")

Fit on Full Data

We model Direction using the returns from the previous five trading days (Lag1–Lag5) plus Volume.

# Fit the specification on all 1,250 days: Direction against the five
# lagged returns and Volume.
lr_fit <- fit(
  lr_spec,
  Direction ~ Lag1 + Lag2 + Lag3 + Lag4 + Lag5 + Volume,
  data = Smarket
)

lr_fit
## parsnip model object
## 
## 
## Call:  stats::glm(formula = Direction ~ Lag1 + Lag2 + Lag3 + Lag4 + 
##     Lag5 + Volume, family = stats::binomial, data = data)
## 
## Coefficients:
## (Intercept)         Lag1         Lag2         Lag3         Lag4         Lag5  
##   -0.126000    -0.073074    -0.042301     0.011085     0.009359     0.010313  
##      Volume  
##    0.135441  
## 
## Degrees of Freedom: 1249 Total (i.e. Null);  1243 Residual
## Null Deviance:       1731 
## Residual Deviance: 1728  AIC: 1742

Summary

# Pull the underlying stats::glm object out of the parsnip wrapper and show
# its standard summary (estimates, standard errors, z values, p-values).
summary(extract_fit_engine(lr_fit))
## 
## Call:
## stats::glm(formula = Direction ~ Lag1 + Lag2 + Lag3 + Lag4 + 
##     Lag5 + Volume, family = stats::binomial, data = data)
## 
## Coefficients:
##              Estimate Std. Error z value Pr(>|z|)
## (Intercept) -0.126000   0.240736  -0.523    0.601
## Lag1        -0.073074   0.050167  -1.457    0.145
## Lag2        -0.042301   0.050086  -0.845    0.398
## Lag3         0.011085   0.049939   0.222    0.824
## Lag4         0.009359   0.049974   0.187    0.851
## Lag5         0.010313   0.049511   0.208    0.835
## Volume       0.135441   0.158360   0.855    0.392
## 
## (Dispersion parameter for binomial family taken to be 1)
## 
##     Null deviance: 1731.2  on 1249  degrees of freedom
## Residual deviance: 1727.6  on 1243  degrees of freedom
## AIC: 1741.6
## 
## Number of Fisher Scoring iterations: 3

Tidy Coefficients

# Coefficient table as a tibble: one row per term.
lr_fit %>% tidy()
## # A tibble: 7 × 5
##   term        estimate std.error statistic p.value
##   <chr>          <dbl>     <dbl>     <dbl>   <dbl>
## 1 (Intercept) -0.126      0.241     -0.523   0.601
## 2 Lag1        -0.0731     0.0502    -1.46    0.145
## 3 Lag2        -0.0423     0.0501    -0.845   0.398
## 4 Lag3         0.0111     0.0499     0.222   0.824
## 5 Lag4         0.00936    0.0500     0.187   0.851
## 6 Lag5         0.0103     0.0495     0.208   0.835
## 7 Volume       0.135      0.158      0.855   0.392

None of the predictors is statistically significant at conventional levels; the smallest p-value, for Lag1, is 0.145. This suggests it is hard to predict market direction from these lagged returns.

Predictions

Class predictions

# Hard class labels (.pred_class) for every day in the data.
predict(lr_fit, new_data = Smarket, type = "class")
## # A tibble: 1,250 × 1
##    .pred_class
##    <fct>      
##  1 Up         
##  2 Down       
##  3 Down       
##  4 Up         
##  5 Up         
##  6 Up         
##  7 Down       
##  8 Up         
##  9 Up         
## 10 Down       
## # ℹ 1,240 more rows

Probability predictions

# Class probabilities: one .pred_<level> column per Direction level.
lr_fit %>% predict(new_data = Smarket, type = "prob")
## # A tibble: 1,250 × 2
##    .pred_Down .pred_Up
##         <dbl>    <dbl>
##  1      0.493    0.507
##  2      0.519    0.481
##  3      0.519    0.481
##  4      0.485    0.515
##  5      0.489    0.511
##  6      0.493    0.507
##  7      0.507    0.493
##  8      0.491    0.509
##  9      0.482    0.518
## 10      0.511    0.489
## # ℹ 1,240 more rows

Confusion Matrix (training data)

# Score the training data once and reuse the result (Smarket plus the model's
# predictions, including .pred_class) for every summary below, instead of
# re-running augment() for each one.
lr_train_aug <- augment(lr_fit, new_data = Smarket)

# Build the confusion matrix once, then print it and plot it as a heatmap.
lr_train_cm <- lr_train_aug %>%
  conf_mat(truth = Direction, estimate = .pred_class)

lr_train_cm
##           Truth
## Prediction Down  Up
##       Down  145 141
##       Up    457 507
autoplot(lr_train_cm, type = "heatmap") +
  labs(title = "Confusion Matrix – Logistic Regression (training data)")

lr_train_aug %>%
  accuracy(truth = Direction, estimate = .pred_class)
## # A tibble: 1 × 3
##   .metric  .estimator .estimate
##   <chr>    <chr>          <dbl>
## 1 accuracy binary         0.522

Training accuracy is only ~52%, barely better than random guessing. It is also optimistic, because it is measured on the same days that were used to fit the model.


Train / Test Split by Year

Since the data has a time component, we split by year: train on 2001–2004, test on 2005.

# Hold out the final year as the test set; every other year is training data.
test_year <- 2005

Smarket_train <- filter(Smarket, Year != test_year)
Smarket_test  <- filter(Smarket, Year == test_year)

dim(Smarket_train)
## [1] 998   9
dim(Smarket_test)
## [1] 252   9

Fit on Training Data

# Same six-predictor model, now fitted on the 2001–2004 training years only.
lr_fit2 <- fit(
  lr_spec,
  Direction ~ Lag1 + Lag2 + Lag3 + Lag4 + Lag5 + Volume,
  data = Smarket_train
)

Evaluate on Test Data

# Score the held-out 2005 data once, then compute both metrics from it
# rather than calling augment() separately for each.
lr2_test_aug <- augment(lr_fit2, new_data = Smarket_test)

lr2_test_aug %>%
  conf_mat(truth = Direction, estimate = .pred_class)
##           Truth
## Prediction Down Up
##       Down   77 97
##       Up     34 44
lr2_test_aug %>%
  accuracy(truth = Direction, estimate = .pred_class)
## # A tibble: 1 × 3
##   .metric  .estimator .estimate
##   <chr>    <chr>          <dbl>
## 1 accuracy binary         0.480

Test accuracy drops to ~48%, below the 50% expected from random guessing. Evaluating on held-out data exposes how poorly the model generalises.


Reduced Model: Lag1 + Lag2 Only

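Predictors with no real relationship to the response tend to increase variance without a corresponding reduction in bias, which hurts test performance. We therefore drop Lag3–Lag5 and Volume and keep only the two most recent returns, Lag1 and Lag2. Lag1 had the smallest p-value in the full model.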

# Reduced model: only the two most recent lagged returns, training years only.
lr_fit3 <- fit(lr_spec, Direction ~ Lag1 + Lag2, data = Smarket_train)

# Score the held-out 2005 data once with the reduced model and reuse it for
# both the confusion matrix and the accuracy.
lr3_test_aug <- augment(lr_fit3, new_data = Smarket_test)

lr3_test_aug %>%
  conf_mat(truth = Direction, estimate = .pred_class)
##           Truth
## Prediction Down  Up
##       Down   35  35
##       Up     76 106
lr3_test_aug %>%
  accuracy(truth = Direction, estimate = .pred_class)
## # A tibble: 1 × 3
##   .metric  .estimator .estimate
##   <chr>    <chr>          <dbl>
## 1 accuracy binary         0.560

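Test accuracy improves to ~56% (141/252). However, 141 of the 252 days in 2005 were Up days (35 + 106), so the naive rule "always predict Up" also scores ~56%. On overall accuracy, the reduced model is no better than that baseline. When it does predict Up, it is right 106/182 ≈ 58% of the time, which hints at a small but uncertain edge.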


Predicting for New Observations

We predict market direction for two hypothetical days:

| Day | Lag1 | Lag2 |
|-----|------|------|
| 1   | 1.2  | 1.1  |
| 2   | 1.5  | -0.8 |

# Two hypothetical days, one row each; only the predictors used by lr_fit3.
Smarket_new <- tibble(Lag1 = c(1.2, 1.5), Lag2 = c(1.1, -0.8))

# Probability of each Direction level for the two hypothetical days.
lr_fit3 %>% predict(new_data = Smarket_new, type = "prob")
## # A tibble: 2 × 2
##   .pred_Down .pred_Up
##        <dbl>    <dbl>
## 1      0.521    0.479
## 2      0.504    0.496

For both days the predicted probability of Down is only slightly above 0.5 (0.521 and 0.504), so the model would classify both as Down. The near-50/50 probabilities reflect how hard it is to predict short-term market movements.


## R version 4.5.1 (2025-06-13 ucrt)
## Platform: x86_64-w64-mingw32/x64
## Running under: Windows 10 x64 (build 19045)
## 
## Matrix products: default
##   LAPACK version 3.12.1
## 
## locale:
## [1] LC_COLLATE=English_United States.utf8 
## [2] LC_CTYPE=English_United States.utf8   
## [3] LC_MONETARY=English_United States.utf8
## [4] LC_NUMERIC=C                          
## [5] LC_TIME=English_United States.utf8    
## 
## time zone: Asia/Taipei
## tzcode source: internal
## 
## attached base packages:
## [1] stats     graphics  grDevices utils     datasets  methods   base     
## 
## other attached packages:
##  [1] corrr_0.4.5        ISLR_1.4           yardstick_1.3.2    workflowsets_1.1.1
##  [5] workflows_1.3.0    tune_2.0.1         tidyr_1.3.1        tailor_0.1.0      
##  [9] rsample_1.3.1      recipes_1.3.1      purrr_1.1.0        parsnip_1.4.1     
## [13] modeldata_1.5.1    infer_1.1.0        ggplot2_4.0.0      dplyr_1.1.4       
## [17] dials_1.4.2        scales_1.4.0       broom_1.0.10       tidymodels_1.4.1  
## 
## loaded via a namespace (and not attached):
##  [1] tidyselect_1.2.1    timeDate_4041.110   farver_2.1.2       
##  [4] S7_0.2.0            fastmap_1.2.0       digest_0.6.37      
##  [7] rpart_4.1.24        timechange_0.3.0    lifecycle_1.0.4    
## [10] survival_3.8-3      magrittr_2.0.4      compiler_4.5.1     
## [13] rlang_1.1.6         sass_0.4.10         tools_4.5.1        
## [16] utf8_1.2.6          yaml_2.3.10         data.table_1.17.8  
## [19] knitr_1.50          labeling_0.4.3      DiceDesign_1.10    
## [22] RColorBrewer_1.1-3  withr_3.0.2         nnet_7.3-20        
## [25] grid_4.5.1          sparsevctrs_0.3.4   future_1.67.0      
## [28] globals_0.18.0      MASS_7.3-65         cli_3.6.5          
## [31] rmarkdown_2.29      generics_0.1.4      rstudioapi_0.17.1  
## [34] future.apply_1.20.0 cachem_1.1.0        splines_4.5.1      
## [37] parallel_4.5.1      vctrs_0.6.5         hardhat_1.4.2      
## [40] Matrix_1.7-3        jsonlite_2.0.0      listenv_0.9.1      
## [43] gower_1.0.2         jquerylib_0.1.4     glue_1.8.0         
## [46] parallelly_1.45.1   codetools_0.2-20    lubridate_1.9.4    
## [49] gtable_0.3.6        GPfit_1.0-9         tibble_3.3.0       
## [52] pillar_1.11.1       furrr_0.3.1         htmltools_0.5.8.1  
## [55] ipred_0.9-15        lava_1.8.1          R6_2.6.1           
## [58] lhs_1.2.1           evaluate_1.0.5      lattice_0.22-7     
## [61] backports_1.5.0     bslib_0.9.0         class_7.3-23       
## [64] Rcpp_1.1.0          prodlim_2025.04.28  xfun_0.52          
## [67] pkgconfig_2.0.3