# One-time setup: install only the packages this script needs that are not
# already present. Guarding the call avoids re-downloading and rebuilding
# every package each time the script is sourced or knitted.
# "tidyverse" and "magrittr" are attached with library() further down in this
# script but were missing from the original install list, so they are added.
required_pkgs <- c(
  "ISLR2", "skimr", "corrr", "tidymodels", "tidyverse", "magrittr",
  "gt", "ggplot2"
)
missing_pkgs <- setdiff(required_pkgs, rownames(installed.packages()))
if (length(missing_pkgs) > 0) {
  install.packages(missing_pkgs)
}
## Installing packages into '/cloud/lib/x86_64-pc-linux-gnu-library/4.4'
## (as 'lib' is unspecified)
library(ISLR2)
library(skimr)
library(corrr)
## 
## Attaching package: 'corrr'
## The following object is masked from 'package:skimr':
## 
##     focus
library(tidymodels)
## ── Attaching packages ────────────────────────────────────── tidymodels 1.3.0 ──
## ✔ broom        1.0.7     ✔ recipes      1.1.1
## ✔ dials        1.4.0     ✔ rsample      1.2.1
## ✔ dplyr        1.1.4     ✔ tibble       3.2.1
## ✔ ggplot2      3.5.1     ✔ tidyr        1.3.1
## ✔ infer        1.0.7     ✔ tune         1.3.0
## ✔ modeldata    1.4.0     ✔ workflows    1.2.0
## ✔ parsnip      1.3.1     ✔ workflowsets 1.1.0
## ✔ purrr        1.0.4     ✔ yardstick    1.3.2
## ── Conflicts ───────────────────────────────────────── tidymodels_conflicts() ──
## ✖ purrr::discard() masks scales::discard()
## ✖ dplyr::filter()  masks stats::filter()
## ✖ dplyr::lag()     masks stats::lag()
## ✖ recipes::step()  masks stats::step()
library(gt)
library(ggplot2)
library(yardstick)
# 4.7.1 The Stock Market Data ----
# Copy the Smarket data set shipped with ISLR2 into a local object so the
# rest of the script can refer to it as `smarket`.
smarket <- ISLR2::Smarket
# Print skimr's per-column summary (missing counts, quantiles, inline
# histograms). Called only for its printed report; the result is not stored.
skimr::skim(smarket)
Data summary
Name smarket
Number of rows 1250
Number of columns 9
_______________________
Column type frequency:
factor 1
numeric 8
________________________
Group variables None

Variable type: factor

skim_variable n_missing complete_rate ordered n_unique top_counts
Direction 0 1 FALSE 2 Up: 648, Dow: 602

Variable type: numeric

skim_variable n_missing complete_rate mean sd p0 p25 p50 p75 p100 hist
Year 0 1 2003.02 1.41 2001.00 2002.00 2003.00 2004.00 2005.00 ▇▇▇▇▇
Lag1 0 1 0.00 1.14 -4.92 -0.64 0.04 0.60 5.73 ▁▃▇▁▁
Lag2 0 1 0.00 1.14 -4.92 -0.64 0.04 0.60 5.73 ▁▃▇▁▁
Lag3 0 1 0.00 1.14 -4.92 -0.64 0.04 0.60 5.73 ▁▃▇▁▁
Lag4 0 1 0.00 1.14 -4.92 -0.64 0.04 0.60 5.73 ▁▃▇▁▁
Lag5 0 1 0.01 1.15 -4.92 -0.64 0.04 0.60 5.73 ▁▃▇▁▁
Volume 0 1 1.48 0.36 0.36 1.26 1.42 1.64 3.15 ▁▇▅▁▁
Today 0 1 0.00 1.14 -4.92 -0.64 0.04 0.60 5.73 ▁▃▇▁▁
# Peek at the first six observations
head(smarket, n = 6L)
##   Year   Lag1   Lag2   Lag3   Lag4   Lag5 Volume  Today Direction
## 1 2001  0.381 -0.192 -2.624 -1.055  5.010 1.1913  0.959        Up
## 2 2001  0.959  0.381 -0.192 -2.624 -1.055 1.2965  1.032        Up
## 3 2001  1.032  0.959  0.381 -0.192 -2.624 1.4112 -0.623      Down
## 4 2001 -0.623  1.032  0.959  0.381 -0.192 1.2760  0.614        Up
## 5 2001  0.614 -0.623  1.032  0.959  0.381 1.2057  0.213        Up
## 6 2001  0.213  0.614 -0.623  1.032  0.959 1.3491  1.392        Up
# Pearson correlation matrix of the numeric columns, rendered as a gt table.
# `Direction` is a factor and has to be dropped before correlating; the
# `term` column produced by correlate() supplies the row labels.
smarket %>%
  dplyr::select(-Direction) %>%
  corrr::correlate(quiet = TRUE, method = "pearson") %>%
  gt::gt(rowname_col = "term")
Year Lag1 Lag2 Lag3 Lag4 Lag5 Volume Today
Year NA 0.029699649 0.030596422 0.033194581 0.035688718 0.029787995 0.53900647 0.030095229
Lag1 0.02969965 NA -0.026294328 -0.010803402 -0.002985911 -0.005674606 0.04090991 -0.026155045
Lag2 0.03059642 -0.026294328 NA -0.025896670 -0.010853533 -0.003557949 -0.04338321 -0.010250033
Lag3 0.03319458 -0.010803402 -0.025896670 NA -0.024051036 -0.018808338 -0.04182369 -0.002447647
Lag4 0.03568872 -0.002985911 -0.010853533 -0.024051036 NA -0.027083641 -0.04841425 -0.006899527
Lag5 0.02978799 -0.005674606 -0.003557949 -0.018808338 -0.027083641 NA -0.02200231 -0.034860083
Volume 0.53900647 0.040909908 -0.043383215 -0.041823686 -0.048414246 -0.022002315 NA 0.014591823
Today 0.03009523 -0.026155045 -0.010250033 -0.002447647 -0.006899527 -0.034860083 0.01459182 NA
# Trading volume by year: jittered raw points with a narrow, semi-transparent
# boxplot drawn on top. Outlier glyphs are suppressed on the boxplot because
# every observation is already shown by the jitter layer.
ggplot(data = smarket, mapping = aes(x = factor(Year), y = Volume)) +
  geom_jitter(colour = "blue", width = 0.3) +
  geom_boxplot(width = 0.2, alpha = 0.3, outlier.shape = NA)

# 4.7.2 Logistic Regression ----
# Re-create `smarket` from the ISLR2 copy so this section starts from the
# unmodified data.
# NOTE(review): redundant when the script is run top to bottom, since the
# same assignment already appears in section 4.7.1.
smarket <- ISLR2::Smarket

# Print skimr's per-column summary again (missing counts, quantiles, inline
# histograms). Called only for its printed report; the result is not stored.
skimr::skim(smarket)
Data summary
Name smarket
Number of rows 1250
Number of columns 9
_______________________
Column type frequency:
factor 1
numeric 8
________________________
Group variables None

Variable type: factor

skim_variable n_missing complete_rate ordered n_unique top_counts
Direction 0 1 FALSE 2 Up: 648, Dow: 602

Variable type: numeric

skim_variable n_missing complete_rate mean sd p0 p25 p50 p75 p100 hist
Year 0 1 2003.02 1.41 2001.00 2002.00 2003.00 2004.00 2005.00 ▇▇▇▇▇
Lag1 0 1 0.00 1.14 -4.92 -0.64 0.04 0.60 5.73 ▁▃▇▁▁
Lag2 0 1 0.00 1.14 -4.92 -0.64 0.04 0.60 5.73 ▁▃▇▁▁
Lag3 0 1 0.00 1.14 -4.92 -0.64 0.04 0.60 5.73 ▁▃▇▁▁
Lag4 0 1 0.00 1.14 -4.92 -0.64 0.04 0.60 5.73 ▁▃▇▁▁
Lag5 0 1 0.01 1.15 -4.92 -0.64 0.04 0.60 5.73 ▁▃▇▁▁
Volume 0 1 1.48 0.36 0.36 1.26 1.42 1.64 3.15 ▁▇▅▁▁
Today 0 1 0.00 1.14 -4.92 -0.64 0.04 0.60 5.73 ▁▃▇▁▁
library(ISLR2)
library(tidyverse)
## ── Attaching core tidyverse packages ──────────────────────── tidyverse 2.0.0 ──
## ✔ forcats   1.0.0     ✔ readr     2.1.5
## ✔ lubridate 1.9.4     ✔ stringr   1.5.1
## ── Conflicts ────────────────────────────────────────── tidyverse_conflicts() ──
## ✖ readr::col_factor() masks scales::col_factor()
## ✖ purrr::discard()    masks scales::discard()
## ✖ dplyr::filter()     masks stats::filter()
## ✖ stringr::fixed()    masks recipes::fixed()
## ✖ dplyr::lag()        masks stats::lag()
## ✖ readr::spec()       masks yardstick::spec()
## ℹ Use the conflicted package (<http://conflicted.r-lib.org/>) to force all conflicts to become errors
library(tidymodels) # For modeling
library(ggplot2)
library(yardstick)
library(magrittr)
## 
## Attaching package: 'magrittr'
## 
## The following object is masked from 'package:tidyr':
## 
##     extract
## 
## The following object is masked from 'package:purrr':
## 
##     set_names
library(tidymodels)  # Loads all tidymodels sub-packages, including rsample
library(rsample)      # Explicitly load rsample to ensure initial_split() works
# ---- 1. Train/test split ----
# Fix the RNG state first so the random partition below is repeatable.
set.seed(123)

# Start from a fresh copy of the stock market data.
smarket <- ISLR2::Smarket

# 80/20 partition, stratified on `Direction`.
smarket_split <- rsample::initial_split(
  data = smarket,
  prop = 0.8,
  strata = Direction
)

# Materialise the two partitions as data frames.
smarket_train <- rsample::training(smarket_split)
smarket_test <- rsample::testing(smarket_split)

# Rows x columns of each partition.
dim(smarket_train)
## [1] 999   9
dim(smarket_test)
## [1] 251   9
# ---- 2. Fit the full logistic regression ----
# glm-backed classifier of Direction on all five lags plus Volume,
# estimated on the training partition only.
glm_direction_fit <- fit(
  logistic_reg(mode = "classification", engine = "glm"),
  Direction ~ Lag1 + Lag2 + Lag3 + Lag4 + Lag5 + Volume,
  data = smarket_train
)
# ---- 3. Training-set performance ----
# Score the training rows with the fitted model; the resulting .pred_class
# and .pred_Up columns are used by the metrics below.
glm_train_pred <- augment(glm_direction_fit, new_data = smarket_train)

# Share of training rows whose predicted class equals the observed Direction.
train_accuracy <- accuracy(
  glm_train_pred,
  truth = Direction,
  estimate = .pred_class
)

print(train_accuracy)
## # A tibble: 1 × 3
##   .metric  .estimator .estimate
##   <chr>    <chr>          <dbl>
## 1 accuracy binary         0.541
# Plot ROC Curve for Training Data
# `Direction` is a factor with levels c("Down", "Up"), and yardstick treats
# the FIRST factor level as the event unless told otherwise. The probability
# column supplied here is .pred_Up, so the event must be declared as the
# second level; without `event_level = "second"` the curve is computed for
# "Down" using P(Up) and comes out mirrored about the diagonal.
# `train_roc` holds the ggplot object produced by autoplot(), not the raw
# curve data.
train_roc <- glm_train_pred %>%
  roc_curve(truth = Direction, .pred_Up, event_level = "second") %>%
  autoplot()

print(train_roc)

# ---- 4. Held-out performance ----
# Score the test partition with the full model fitted on the training data.
glm_test_pred <- augment(glm_direction_fit, new_data = smarket_test)

# Share of test rows whose predicted class equals the observed Direction.
test_accuracy <- accuracy(
  glm_test_pred,
  truth = Direction,
  estimate = .pred_class
)

print(test_accuracy)
## # A tibble: 1 × 3
##   .metric  .estimator .estimate
##   <chr>    <chr>          <dbl>
## 1 accuracy binary         0.502
# ---- 5. Reduced model: Lag1 and Lag2 only ----
# Same glm-backed specification as the full model, restricted to the two
# most recent lags and fitted on the training partition.
glm_simple_fit <- fit(
  logistic_reg(mode = "classification", engine = "glm"),
  Direction ~ Lag1 + Lag2,
  data = smarket_train
)

# Score the test partition with the reduced model.
glm_simple_pred <- augment(glm_simple_fit, new_data = smarket_test)

# Test-set accuracy of the reduced model.
simple_model_accuracy <- accuracy(
  glm_simple_pred,
  truth = Direction,
  estimate = .pred_class
)

print(simple_model_accuracy)
## # A tibble: 1 × 3
##   .metric  .estimator .estimate
##   <chr>    <chr>          <dbl>
## 1 accuracy binary         0.494