library(tidymodels)
## ── Attaching packages ────────────────────────────────────── tidymodels 1.4.1 ──
## ✔ broom        1.0.10     ✔ recipes      1.3.1 
## ✔ dials        1.4.2      ✔ rsample      1.3.1 
## ✔ dplyr        1.1.4      ✔ tailor       0.1.0 
## ✔ ggplot2      4.0.2      ✔ tidyr        1.3.1 
## ✔ infer        1.1.0      ✔ tune         2.0.1 
## ✔ modeldata    1.5.1      ✔ workflows    1.3.0 
## ✔ parsnip      1.4.1      ✔ workflowsets 1.1.1 
## ✔ purrr        1.1.0      ✔ yardstick    1.3.2
## Warning: package 'ggplot2' was built under R version 4.5.2
## Warning: package 'infer' was built under R version 4.5.2
## Warning: package 'parsnip' was built under R version 4.5.2
## ── Conflicts ───────────────────────────────────────── tidymodels_conflicts() ──
## ✖ purrr::discard() masks scales::discard()
## ✖ dplyr::filter()  masks stats::filter()
## ✖ dplyr::lag()     masks stats::lag()
## ✖ recipes::step()  masks stats::step()
library(ISLR) # For the Smarket data set
library(ISLR2) # For the Bikeshare data set
## 
## Attaching package: 'ISLR2'
## The following objects are masked from 'package:ISLR':
## 
##     Auto, Credit
library(discrim)
## Warning: package 'discrim' was built under R version 4.5.2
## 
## Attaching package: 'discrim'
## The following object is masked from 'package:dials':
## 
##     smoothness
library(poissonreg)
library(corrr)
# Correlation table for every Smarket column except the Direction label.
cor_Smarket <- correlate(select(Smarket, -Direction))
## Correlation computed with
## • Method: 'pearson'
## • Missing treated using: 'pairwise.complete.obs'
# Plot the correlation table using a three-colour gradient.
cor_Smarket %>%
  rplot(colours = c("indianred2", "black", "skyblue1"))
## Warning: `aes_string()` was deprecated in ggplot2 3.0.0.
## ℹ Please use tidy evaluation idioms with `aes()`.
## ℹ See also `vignette("ggplot2-in-packages")` for more information.
## ℹ The deprecated feature was likely used in the corrr package.
##   Please report the issue at <https://github.com/tidymodels/corrr/issues>.
## This warning is displayed once every 8 hours.
## Call `lifecycle::last_lifecycle_warnings()` to see where this warning was
## generated.

library(paletteer)
## Warning: package 'paletteer' was built under R version 4.5.2
# Heatmap of the pairwise correlations: stretch() reshapes the correlation
# table into long form (columns x, y, r), one tile per variable pair, with the
# formatted correlation printed on top of each tile.
ggplot(stretch(cor_Smarket), aes(x = x, y = y, fill = r)) +
  geom_tile() +
  geom_text(aes(label = as.character(fashion(r)))) +
  # Pin the colour scale to the full [-1, 1] correlation range and reverse
  # the palette direction.
  scale_fill_paletteer_c(
    palette = "scico::roma",
    limits = c(-1, 1),
    direction = -1
  )

# Volume against Year; height = 0 restricts the jitter to the horizontal
# axis so the plotted Volume values are not perturbed.
Smarket %>%
  ggplot(aes(x = Year, y = Volume)) +
  geom_jitter(height = 0)

# Load the Smarket data set into the workspace.
# Both ISLR and ISLR2 are attached and each ships a `Smarket` data set, so the
# source package is named explicitly (ISLR is the one loaded for Smarket above)
# instead of depending on the order in which the packages were attached.
data(Smarket, package = "ISLR")
# Make sure the response is a factor before modelling.
# NOTE(review): Direction may already be a factor in the packaged data, in
# which case this is a harmless no-op kept as a safeguard.
Smarket$Direction <- as.factor(Smarket$Direction)
# Inspect column types and the first few values of each column.
glimpse(Smarket)
## Rows: 1,250
## Columns: 9
## $ Year      <dbl> 2001, 2001, 2001, 2001, 2001, 2001, 2001, 2001, 2001, 2001, …
## $ Lag1      <dbl> 0.381, 0.959, 1.032, -0.623, 0.614, 0.213, 1.392, -0.403, 0.…
## $ Lag2      <dbl> -0.192, 0.381, 0.959, 1.032, -0.623, 0.614, 0.213, 1.392, -0…
## $ Lag3      <dbl> -2.624, -0.192, 0.381, 0.959, 1.032, -0.623, 0.614, 0.213, 1…
## $ Lag4      <dbl> -1.055, -2.624, -0.192, 0.381, 0.959, 1.032, -0.623, 0.614, …
## $ Lag5      <dbl> 5.010, -1.055, -2.624, -0.192, 0.381, 0.959, 1.032, -0.623, …
## $ Volume    <dbl> 1.1913, 1.2965, 1.4112, 1.2760, 1.2057, 1.3491, 1.4450, 1.40…
## $ Today     <dbl> 0.959, 1.032, -0.623, 0.614, 0.213, 1.392, -0.403, 0.027, 1.…
## $ Direction <fct> Up, Up, Down, Up, Up, Up, Down, Up, Up, Up, Down, Down, Up, …
# Model specification: logistic regression for classification, fitted with
# the "glm" engine. Nothing is estimated at this point.
lr_spec <- logistic_reg() %>%
  set_mode(mode = "classification") %>%
  set_engine(engine = "glm")
# Fit the specification to the full Smarket data: Direction is modelled from
# the five lagged values and Volume.
lr_fit <- fit(
  lr_spec,
  formula = Direction ~ Lag1 + Lag2 + Lag3 + Lag4 + Lag5 + Volume,
  data = Smarket
)
# Print the engine-level model summary (coefficients, deviance, AIC).
# extract_fit_engine() is the supported accessor for the underlying engine
# object of a parsnip fit; it replaces plucking the internal "fit" element by
# name, so the code does not depend on the fit object's internal layout.
lr_fit %>%
  extract_fit_engine() %>%
  summary()
## 
## Call:
## stats::glm(formula = Direction ~ Lag1 + Lag2 + Lag3 + Lag4 + 
##     Lag5 + Volume, family = stats::binomial, data = data)
## 
## Coefficients:
##              Estimate Std. Error z value Pr(>|z|)
## (Intercept) -0.126000   0.240736  -0.523    0.601
## Lag1        -0.073074   0.050167  -1.457    0.145
## Lag2        -0.042301   0.050086  -0.845    0.398
## Lag3         0.011085   0.049939   0.222    0.824
## Lag4         0.009359   0.049974   0.187    0.851
## Lag5         0.010313   0.049511   0.208    0.835
## Volume       0.135441   0.158360   0.855    0.392
## 
## (Dispersion parameter for binomial family taken to be 1)
## 
##     Null deviance: 1731.2  on 1249  degrees of freedom
## Residual deviance: 1727.6  on 1243  degrees of freedom
## AIC: 1741.6
## 
## Number of Fisher Scoring iterations: 3
# Confusion matrix drawn as a heatmap. augment() appends the model's
# predictions to Smarket, which is the same data lr_fit was fitted on, so
# this reflects training performance.
lr_fit %>%
  augment(new_data = Smarket) %>%
  conf_mat(truth = Direction, estimate = .pred_class) %>%
  autoplot(type = "heatmap")

# Coefficient table as a tibble: one row per model term.
lr_fit %>%
  tidy()
## # A tibble: 7 × 5
##   term        estimate std.error statistic p.value
##   <chr>          <dbl>     <dbl>     <dbl>   <dbl>
## 1 (Intercept) -0.126      0.241     -0.523   0.601
## 2 Lag1        -0.0731     0.0502    -1.46    0.145
## 3 Lag2        -0.0423     0.0501    -0.845   0.398
## 4 Lag3         0.0111     0.0499     0.222   0.824
## 5 Lag4         0.00936    0.0500     0.187   0.851
## 6 Lag5         0.0103     0.0495     0.208   0.835
## 7 Volume       0.135      0.158      0.855   0.392
# Predicted class labels for every row of Smarket (the data the model was
# fitted on), followed by a preview of the first rows.
pred_class <- lr_fit %>%
  predict(new_data = Smarket)
pred_class %>%
  head()
## # A tibble: 6 × 1
##   .pred_class
##   <fct>      
## 1 Up         
## 2 Down       
## 3 Down       
## 4 Up         
## 5 Up         
## 6 Up
# Predicted class probabilities for the same rows, one column per class,
# followed by a preview of the first rows.
pred_prob <- lr_fit %>%
  predict(new_data = Smarket, type = "prob")
pred_prob %>%
  head()
## # A tibble: 6 × 2
##   .pred_Down .pred_Up
##        <dbl>    <dbl>
## 1      0.493    0.507
## 2      0.519    0.481
## 3      0.519    0.481
## 4      0.485    0.515
## 5      0.489    0.511
## 6      0.493    0.507
# Attach the predicted classes to the original columns so the truth and the
# estimate sit side by side in one data frame.
Smarket_pred <- Smarket %>%
  bind_cols(pred_class)

# Cross-tabulate predicted against observed Direction.
Smarket_pred %>%
  conf_mat(truth = Direction, estimate = .pred_class)
##           Truth
## Prediction Down  Up
##       Down  145 141
##       Up    457 507
# Share of rows where the predicted class equals the observed Direction.
# The predictions were made on the data used for fitting, so this is
# training accuracy.
Smarket_pred %>%
  accuracy(truth = Direction, estimate = .pred_class)
## # A tibble: 1 × 3
##   .metric  .estimator .estimate
##   <chr>    <chr>          <dbl>
## 1 accuracy binary         0.522