# 4.4.1 Linear Discriminant Analysis for p = 1
# Two classes whose single predictor is normal within each class:
# means -1.25 (class 1) and 1.25 (class 2), common standard deviation 1.
mu1 <- -1.25
mu2 <- 1.25
sigma1 <- sigma2 <- 1
# Midpoint of the two means: the Bayes decision boundary for two
# equal-variance normal classes with equal prior probabilities.
bayes_boundary <- (mu2 + mu1) / 2
library(ggplot2)
library(tidyverse)
## ── Attaching core tidyverse packages ──────────────────────── tidyverse 2.0.0 ──
## ✔ dplyr 1.1.4 ✔ readr 2.1.5
## ✔ forcats 1.0.0 ✔ stringr 1.5.1
## ✔ lubridate 1.9.4 ✔ tibble 3.2.1
## ✔ purrr 1.0.4 ✔ tidyr 1.3.1
## ── Conflicts ────────────────────────────────────────── tidyverse_conflicts() ──
## ✖ dplyr::filter() masks stats::filter()
## ✖ dplyr::lag() masks stats::lag()
## ℹ Use the conflicted package (<http://conflicted.r-lib.org/>) to force all conflicts to become errors
library(tidyverse) # Loads ggplot2, dplyr, and other useful packages
# Class-conditional densities of the two classes with the Bayes decision
# boundary drawn as a dashed vertical line.
# mu1, mu2, sigma1, sigma2 and bayes_boundary are the parameters assigned at
# the top of this section; they are deliberately not re-assigned here so that
# there is a single source of truth for them.
# `linewidth` replaces the line `size` aesthetic deprecated in ggplot2 3.4.0.
p1 <- ggplot(data = tibble(x = seq(-4, 4, 0.1)), aes(x)) +
  stat_function(fun = dnorm, args = list(mean = mu1, sd = sigma1),
                geom = "line", linewidth = 1.5, color = "green") + # class 1
  stat_function(fun = dnorm, args = list(mean = mu2, sd = sigma2),
                geom = "line", linewidth = 1.5, color = "purple") + # class 2
  geom_vline(xintercept = bayes_boundary, lty = 2, linewidth = 1.5) +
  # Hide the y-axis text, ticks and title
  theme(axis.text.y = element_blank(), axis.ticks.y = element_blank(),
        axis.title.y = element_blank())
## Warning: Using `size` aesthetic for lines was deprecated in ggplot2 3.4.0.
## ℹ Please use `linewidth` instead.
## This warning is displayed once every 8 hours.
## Call `lifecycle::last_lifecycle_warnings()` to see where this warning was
## generated.
# Draw 20 observations per class (class 1 first, then class 2, so the random
# stream is consumed in the same order) into a long tibble with columns
# `class` and `x`.
set.seed(42)
d <- tibble(
  class = rep(c(1, 2), each = 20),
  x = c(
    rnorm(20, mean = mu1, sd = sigma1),
    rnorm(20, mean = mu2, sd = sigma2)
  )
)
# Estimated LDA boundary: midpoint of the two sample class means.
lda_boundary <- (mean(d$x[d$class == 1]) + mean(d$x[d$class == 2])) / 2
# Overlaid histograms of the simulated points in `d`, coloured by class, with
# the Bayes boundary (dashed) and the sample-based LDA boundary (solid).
# `linewidth` replaces the line `size` aesthetic deprecated in ggplot2 3.4.0.
p2 <- d %>%
  ggplot(aes(x, fill = factor(class), color = factor(class))) +
  geom_histogram(bins = 13, alpha = 0.5, position = "identity") +
  geom_vline(xintercept = bayes_boundary, lty = 2, linewidth = 1.5) +
  geom_vline(xintercept = lda_boundary, lty = 1, linewidth = 1.5) +
  scale_fill_manual(values = c("green", "purple")) +
  scale_color_manual(values = c("green", "purple")) +
  theme(legend.position = "none")
# Install patchwork only when it is missing, so that re-running or re-knitting
# the script does not reinstall it (and need network access) every time.
if (!requireNamespace("patchwork", quietly = TRUE)) {
  install.packages("patchwork")
}
## Installing package into '/cloud/lib/x86_64-pc-linux-gnu-library/4.4'
## (as 'lib' is unspecified)
library(patchwork)
p1 | p2 # Combining the plots side by side (patchwork)
# Larger simulation: 1000 observations per class from the same two normals,
# drawn class 1 first and then class 2.
set.seed(2021)
d <- tibble(
  class = rep(c(1, 2), each = 1e3),
  x = c(
    rnorm(1e3, mean = mu1, sd = sigma1),
    rnorm(1e3, mean = mu2, sd = sigma2)
  )
)
# The LDA boundary must be recomputed with the new data: it is the midpoint
# of the two sample class means.
lda_boundary <- (mean(d$x[d$class == 1]) + mean(d$x[d$class == 2])) / 2
# 4.4.2 Linear Discriminant Analysis for p > 1
library(tidyr)
library(dplyr)
library(mvtnorm)
library(ggplot2)
library(patchwork)
# Evaluate two bivariate normal densities, both centred at the origin, on a
# regular grid over [-2, 2] x [-2, 2] with step 0.1:
#   d1 - identity covariance (uncorrelated features, unit variances)
#   d2 - unit variances with correlation 0.7
d <- crossing(x1 = seq(-2, 2, 0.1), x2 = seq(-2, 2, 0.1))
d1 <- d %>%
  mutate(prob = mvtnorm::dmvnorm(
    cbind(x1, x2),
    mean = c(0, 0),
    sigma = diag(2)
  ))
d2 <- d %>%
  mutate(prob = mvtnorm::dmvnorm(
    cbind(x1, x2),
    mean = c(0, 0),
    sigma = matrix(c(1, 0.7, 0.7, 1), nrow = 2)
  ))
# Heat map of a density grid with columns x1, x2 and prob: tiles filled by
# `prob`, no padding around the axes and no legend. Shared by both panels so
# that their styling cannot drift apart.
plot_density_tile <- function(grid) {
  ggplot(grid, aes(x = x1, y = x2)) +
    geom_tile(aes(fill = prob)) +
    scale_x_continuous(expand = c(0, 0)) +
    scale_y_continuous(expand = c(0, 0)) +
    theme(legend.position = "none")
}

# Left: uncorrelated features (d1); right: correlation 0.7 (d2).
p1 <- plot_density_tile(d1)
p2 <- plot_density_tile(d2)
p1 | p2
# 4.5 A Comparison of Classification Methods
# Simulate labelled point clouds ("blobs") for classification experiments.
#
# Points are assigned to classes round-robin (1, 2, ..., k, 1, 2, ...), where
# k = nrow(cluster_centers). Each point is a zero-centred draw from a
# multivariate normal (dist = "norm") or multivariate t (dist = "t", with
# `t_df` degrees of freedom) using `cluster_covar`. The draw is then shifted by
# the centre of its class.
#
# Arguments:
#   n_samples       total number of points to generate.
#   n_features      number of dimensions; must equal ncol(cluster_centers) and
#                   both dimensions of cluster_covar.
#   cluster_centers k x n_features matrix, one row per class centre.
#   cluster_covar   n_features x n_features matrix passed as `Sigma` to
#                   MASS::mvrnorm() or as `sigma` to mvtnorm::rmvt().
#   dist            "norm" or "t".
#   t_df            degrees of freedom used when dist = "t".
#
# Returns a tibble with one column per feature and a factor `category` holding
# the class index. The feature columns are named `x`, `y` when
# n_features == 2, and `x1`, `x2`, ... otherwise.
make_blobs <- function(
    n_samples = 40, n_features = 2,
    # By default, class 1 is centered at (0, 0) and class 2 at (1, 1)
    cluster_centers = matrix(c(0, 0, 1, 1), nrow = 2, byrow = TRUE),
    # By default, the two features are uncorrelated with variance = 1
    cluster_covar = matrix(c(1, 0, 0, 1), nrow = 2),
    dist = c("norm", "t"), t_df = 5
) {
  if (ncol(cluster_centers) != n_features) {
    stop("Dimensionality of centers must equal number of features")
  }
  if (nrow(cluster_covar) != n_features || ncol(cluster_covar) != n_features) {
    stop("Dimensionality of covariance matrix must match number of features")
  }
  dist <- match.arg(dist)

  # Equally divides each of `n_samples` into the different categories
  # according to the number of provided classes
  categories <- rep(seq_len(nrow(cluster_centers)), length.out = n_samples)

  # Zero mean in every dimension (previously hard-coded to 2 dimensions,
  # which made any n_features != 2 fail)
  zero_mean <- rep(0, n_features)
  if (dist == "norm") {
    points <- MASS::mvrnorm(n = n_samples, mu = zero_mean,
                            Sigma = cluster_covar)
  } else {
    points <- mvtnorm::rmvt(n = n_samples, delta = zero_mean, df = t_df,
                            sigma = cluster_covar)
  }
  # mvrnorm() returns a plain vector when n = 1; restore the n x p shape
  points <- matrix(points, nrow = n_samples, ncol = n_features)
  # drop = FALSE keeps the centre lookup a matrix even for a single sample
  points <- points + cluster_centers[categories, , drop = FALSE]

  colnames(points) <- if (n_features == 2) {
    c("x", "y")
  } else {
    paste0("x", seq_len(n_features))
  }
  as_tibble(points) %>%
    bind_cols(category = factor(categories))
}
library(dplyr)
library(tidyr)
# Training sets for the three linear scenarios: each scenario is replicated
# for sim = 1..100. Every replicate gets a 40-point make_blobs() data set
# whose two features have correlation `corr` and noise distribution `dist`.
# Map() walks the rows in order, so the random draws happen in the same
# sequence as a row-by-row evaluation.
sim_linear_train <- tribble(
  ~ scenario, ~ n_samples, ~ corr, ~ dist,
  "Scenario 1", 40, 0.0, "norm",
  "Scenario 2", 40, -0.5, "norm",
  "Scenario 3", 40, -0.5, "t"
) %>%
  crossing(sim = 1:100) %>%
  mutate(
    train_data = Map(
      function(size, rho, noise) {
        make_blobs(
          n_samples = size,
          cluster_covar = matrix(c(1, rho, rho, 1), nrow = 2),
          dist = noise
        )
      },
      n_samples, corr, dist
    )
  )
# One large (1000-point) test set per distinct scenario, generated with the
# same correlation and noise distribution as that scenario's training sets.
sim_linear_test <- sim_linear_train %>%
  distinct(scenario, corr, dist) %>%
  mutate(
    test_data = Map(
      function(rho, noise) {
        make_blobs(
          n_samples = 1000,
          cluster_covar = matrix(c(1, rho, rho, 1), nrow = 2),
          dist = noise
        )
      },
      corr, dist
    )
  )
# A helper function for fitting on a training set and getting accuracy from
# a testing set.
#
# Builds a workflow from the recipe `category ~ x + y` and `model`, fits it on
# `train_data`, predicts on `test_data` and returns the test accuracy as a
# single number. When `model_label` is "KNN-CV", the `neighbors` parameter is
# first tuned by 5-fold cross-validation over 1..10. The value with the best
# CV accuracy is then used for the final fit.
#
# Arguments:
#   model_label  character label; only "KNN-CV" triggers tuning.
#   train_data   data frame with predictors x, y and factor outcome category.
#   test_data    data frame with the same columns.
#   model        parsnip model specification (with a tunable `neighbors`
#                parameter when model_label is "KNN-CV").
#
# NOTE(review): relies on tidymodels (workflows, recipes, rsample, tune,
# yardstick) being attached; no library(tidymodels) call is visible in this
# file -- confirm it is loaded before this helper is called.
calc_test_accuracy <- function(model_label, train_data, test_data, model) {
  wf <- workflow() %>%
    add_recipe(recipe(category ~ x + y, data = train_data)) %>%
    add_model(model)

  # identical() always yields a single TRUE/FALSE, even if model_label is NA
  # or not of length 1, so the `if` cannot error on a malformed label
  if (identical(model_label, "KNN-CV")) {
    # 5 fold cross-validation
    train_data_folds <- vfold_cv(train_data, v = 5)
    tune_res <- wf %>%
      tune_grid(
        resamples = train_data_folds,
        # Try 1 to 10 neighbors
        grid = tibble(neighbors = 1:10)
      )
    # Overwrite the workflow with the best `neighbors` value by CV accuracy.
    # `metric` is passed by name: recent tune releases expect it named rather
    # than positional in select_best().
    wf <- finalize_workflow(wf, select_best(tune_res, metric = "accuracy"))
  }

  wf %>%
    fit(data = train_data) %>%
    augment(test_data) %>%
    accuracy(truth = category, estimate = .pred_class) %>%
    pull(.estimate)
}
library(tictoc)
# Start a tictoc timer.
# NOTE(review): no matching toc() appears in this file as shown, so the
# elapsed time is never reported -- confirm whether the timed code is missing.
tictoc::tic()
# 4.8 Exercises
library(ISLR2)
library(skimr)
# Local copy of the ISLR2 `Weekly` data set; skim() prints a per-column
# summary (missingness, distribution statistics, histograms).
weekly <- ISLR2::Weekly
skim(weekly)
## Name | weekly |
## Number of rows | 1089 |
## Number of columns | 9 |
## _______________________ | |
## Column type frequency: | |
## factor | 1 |
## numeric | 8 |
## ________________________ | |
## Group variables | None |
## Variable type: factor
## skim_variable | n_missing | complete_rate | ordered | n_unique | top_counts |
## ---|---|---|---|---|---|
## Direction | 0 | 1 | FALSE | 2 | Up: 605, Dow: 484 |
## Variable type: numeric
## skim_variable | n_missing | complete_rate | mean | sd | p0 | p25 | p50 | p75 | p100 | hist |
## ---|---|---|---|---|---|---|---|---|---|---|
## Year | 0 | 1 | 2000.05 | 6.03 | 1990.00 | 1995.00 | 2000.00 | 2005.00 | 2010.00 | ▇▆▆▆▆ |
## Lag1 | 0 | 1 | 0.15 | 2.36 | -18.20 | -1.15 | 0.24 | 1.41 | 12.03 | ▁▁▆▇▁ |
## Lag2 | 0 | 1 | 0.15 | 2.36 | -18.20 | -1.15 | 0.24 | 1.41 | 12.03 | ▁▁▆▇▁ |
## Lag3 | 0 | 1 | 0.15 | 2.36 | -18.20 | -1.16 | 0.24 | 1.41 | 12.03 | ▁▁▆▇▁ |
## Lag4 | 0 | 1 | 0.15 | 2.36 | -18.20 | -1.16 | 0.24 | 1.41 | 12.03 | ▁▁▆▇▁ |
## Lag5 | 0 | 1 | 0.14 | 2.36 | -18.20 | -1.17 | 0.23 | 1.41 | 12.03 | ▁▁▆▇▁ |
## Volume | 0 | 1 | 1.57 | 1.69 | 0.09 | 0.33 | 1.00 | 2.05 | 9.33 | ▇▂▁▁▁ |
## Today | 0 | 1 | 0.15 | 2.36 | -18.20 | -1.15 | 0.24 | 1.41 | 12.03 | ▁▁▆▇▁ |
# ISLR2 Auto data plus `mpg01`: a factor that is "1" when mpg is above the
# median mpg and "0" otherwise.
auto <- ISLR2::Auto %>%
  mutate(mpg01 = factor(as.integer(mpg > median(mpg))))
glimpse(auto)
## Rows: 392
## Columns: 10
## $ mpg <dbl> 18, 15, 18, 16, 17, 15, 14, 14, 14, 15, 15, 14, 15, 14, 2…
## $ cylinders <int> 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 4, 6, 6, 6, 4, …
## $ displacement <dbl> 307, 350, 318, 304, 302, 429, 454, 440, 455, 390, 383, 34…
## $ horsepower <int> 130, 165, 150, 150, 140, 198, 220, 215, 225, 190, 170, 16…
## $ weight <int> 3504, 3693, 3436, 3433, 3449, 4341, 4354, 4312, 4425, 385…
## $ acceleration <dbl> 12.0, 11.5, 11.0, 12.0, 10.5, 10.0, 9.0, 8.5, 10.0, 8.5, …
## $ year <int> 70, 70, 70, 70, 70, 70, 70, 70, 70, 70, 70, 70, 70, 70, 7…
## $ origin <int> 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 3, 1, 1, 1, 3, …
## $ name <fct> chevrolet chevelle malibu, buick skylark 320, plymouth sa…
## $ mpg01 <fct> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1, …
# Replace the numeric origin codes 1, 2, 3 with region labels.
auto <- mutate(
  auto,
  origin = factor(origin, levels = 1:3,
                  labels = c("American", "European", "Japanese"))
)

# Heat map of the number of rows (cars) in every origin x mpg01 combination,
# with each tile's count printed on it in white.
ggplot(count(auto, origin, mpg01), aes(x = mpg01, y = origin)) +
  geom_tile(aes(fill = n)) +
  geom_text(aes(label = n), color = "white") +
  scale_x_discrete(expand = c(0, 0)) +
  scale_y_discrete(expand = c(0, 0)) +
  theme(legend.position = "none")
# ISLR2 Boston data with two changes:
#   crim01 - factor, "1" when crim is above the median crim, else "0"
#   chas   - converted from the 0/1 code to TRUE/FALSE (TRUE where chas == 1)
boston <- mutate(
  ISLR2::Boston,
  crim01 = factor(as.integer(crim > median(crim))),
  chas = chas == 1
)
glimpse(boston)
## Rows: 506
## Columns: 14
## $ crim <dbl> 0.00632, 0.02731, 0.02729, 0.03237, 0.06905, 0.02985, 0.08829,…
## $ zn <dbl> 18.0, 0.0, 0.0, 0.0, 0.0, 0.0, 12.5, 12.5, 12.5, 12.5, 12.5, 1…
## $ indus <dbl> 2.31, 7.07, 7.07, 2.18, 2.18, 2.18, 7.87, 7.87, 7.87, 7.87, 7.…
## $ chas <lgl> FALSE, FALSE, FALSE, FALSE, FALSE, FALSE, FALSE, FALSE, FALSE,…
## $ nox <dbl> 0.538, 0.469, 0.469, 0.458, 0.458, 0.458, 0.524, 0.524, 0.524,…
## $ rm <dbl> 6.575, 6.421, 7.185, 6.998, 7.147, 6.430, 6.012, 6.172, 5.631,…
## $ age <dbl> 65.2, 78.9, 61.1, 45.8, 54.2, 58.7, 66.6, 96.1, 100.0, 85.9, 9…
## $ dis <dbl> 4.0900, 4.9671, 4.9671, 6.0622, 6.0622, 6.0622, 5.5605, 5.9505…
## $ rad <int> 1, 2, 2, 3, 3, 3, 5, 5, 5, 5, 5, 5, 5, 4, 4, 4, 4, 4, 4, 4, 4,…
## $ tax <dbl> 296, 242, 242, 222, 222, 222, 311, 311, 311, 311, 311, 311, 31…
## $ ptratio <dbl> 15.3, 17.8, 17.8, 18.7, 18.7, 18.7, 15.2, 15.2, 15.2, 15.2, 15…
## $ lstat <dbl> 4.98, 9.14, 4.03, 2.94, 5.33, 5.21, 12.43, 19.15, 29.93, 17.10…
## $ medv <dbl> 24.0, 21.6, 34.7, 33.4, 36.2, 28.7, 22.9, 27.1, 16.5, 18.9, 15…
## $ crim01 <fct> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1,…
# Heat map of the number of rows in every chas x crim01 combination, with
# each tile's count printed on it in white.
ggplot(count(boston, chas, crim01), aes(x = crim01, y = chas)) +
  geom_tile(aes(fill = n)) +
  geom_text(aes(label = n), color = "white") +
  scale_x_discrete(expand = c(0, 0)) +
  scale_y_discrete(expand = c(0, 0)) +
  theme(legend.position = "none")