knitr::opts_chunk$set(echo = TRUE,
fig.align = 'center')
# Install the FNN package if you haven't already
# Load the packages here:
pacman::p_load(tidyverse, skimr, caTools, FNN)
# Changing the default theme
theme_set(theme_bw())
# Read in the two data sets and merge the econ and edu data sets together
counties_reg <-
# Merging the econ and edu data sets together based on the fips column
inner_join(x = read.csv("education.csv"),
y = read.csv("economic.csv"),
by = c("fips", "county", "state")) |>
# adding the county, state combo as the row names:
mutate(county_state = paste0(county, ", ", state)) |>
column_to_rownames(var = "county_state") |>
# Changing MHI to be in 1000s
mutate(med_home_income = med_home_income/1000) |>
# Just picking the columns we'll use in regression
dplyr::select(med_home_income, no_hs:poverty, -metro, -some_college) |>
# Removing the outliers with MHI above 150k and no_hs above 60
filter(
med_home_income < 150,
no_hs < 60
)
# Normalize function:
normalize <- function(x){return((x - min(x))/(max(x)-min(x)))}
# Fit stats function
reg_fit_stats <- function(y, y_hat){
error <- y - y_hat
SSE <- sum(error^2)
SST <- sum((y - mean(y))^2)
return(
list(
R2 = 1 - SSE/SST,
rmse = sqrt(SSE/(length(y) - 1)),
MAE = mean(abs(y - y_hat))
)
)
}
For kNN regression, we make a prediction by averaging the response from the k nearest neighbors, determined by the predictors/features:
\[\hat{y} = \frac{1}{k}\sum_{i = 1}^k y_i\]
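To make the formula concrete, here is a minimal sketch (not part of the original workflow; the object names neighbors and y_hat_1 are just illustrative) that predicts the MHI of the first county in counties_reg by averaging the MHI of its 10 nearest neighbors. Rescaling of the predictors is ignored here; we come back to that below.
# A minimal sketch of the kNN averaging formula for a single county
k <- 10
X <- as.matrix(counties_reg |> dplyr::select(-med_home_income))
y <- counties_reg$med_home_income
# Euclidean distance from county 1 to every county
dists <- sqrt(rowSums(sweep(X, MARGIN = 2, STATS = X[1, ], FUN = "-")^2))
# Drop county 1 itself, find its k closest counties, then average their MHI
neighbors <- order(dists[-1])[1:k]
y_hat_1 <- mean(y[-1][neighbors])
y_hat_1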
How does the choice of k affect the complexity of our model?
Reminder: as complexity increases, the bias of our model decreases but the variance increases. Likewise, as complexity decreases, the bias increases and the variance decreases.
For kNN, choosing higher values of k decreases the complexity of the model, lowering the variance but increasing the bias.
It is easier to demonstrate what happens to the variance as k changes than what happens to the bias. That's because we can use math and probability to determine the variance of an average!
The variance of our prediction is
\[\text{var}(\hat{y}) = \text{var}\left(\frac{1}{k}\sum_{i = 1}^k y_i\right) = \frac{\sigma^2}{k}\]
\[(k \uparrow) \implies \textrm{variance} \downarrow \\ (k \downarrow) \implies \textrm{variance} \uparrow\]
So as we increase the number of neighbors, k, we lower the variance of our predictions! If we lower the variance, then the bias increases :(
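A quick simulation sketch of this result (the values of sigma and n_sims below are arbitrary choices for illustration): the variance of an average of k independent responses is about \(\sigma^2/k\).
# Simulation: the variance of an average of k draws is roughly sigma^2 / k
set.seed(1234)
sigma <- 5       # arbitrary standard deviation of the response
n_sims <- 10000  # number of simulated averages per choice of k
sapply(
  c(1, 5, 25, 100),
  FUN = function(k) var(replicate(n_sims, mean(rnorm(n = k, sd = sigma))))
)
# The results should be close to sigma^2 / k: about 25, 5, 1, 0.25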
If we want to avoid overfitting, we choose a larger value of k: lower variance, but higher bias
If we want to avoid underfitting, we choose a smaller value of k: higher variance, but lower bias
How do we decide on the correct choice of k to prevent both over- and underfitting? Try out a bunch of different k's, test them using cross-validation, and use the choice of k and rescale method that minimizes SSE (equivalently, maximizes \(R^2\)) or minimizes MAE.
But like we did in part 2, we need to start by rescaling the data!
Since we have to find the k nearest neighbors using distance, we need to rescale the explanatory variables in some way:
standardize: \[\frac{x - \bar{x}}{s}\]
normalize: \[\frac{x - \min(x)}{\max(x) - \min(x)}\]
NOTE: we only rescale the explanatory variables, not the response variable!
Let's normalize and standardize the explanatory variables of the counties data set. We can standardize the data using mutate(), across(), and scale(), but to normalize the data we need to create our own normalize() function (defined above).
# Standardize the data using scale
counties_stan <-
counties_reg |>
mutate(
across(
.cols = -med_home_income,
.fns = scale
)
)
# Normalize the data
counties_norm <-
counties_reg |>
mutate(
across(
.cols = -med_home_income,
.fns = normalize
)
)
Reminder: to use knn.reg() from FNN, we specify 4 arguments:
- train = the data set that will be used to make the predictions (the neighbors)
- test = the data set with the rows we want to predict, in the same format as the train data set. Both test and train should have no response column and no categorical columns
- y = a vector of the response variable for the train data. The length of the y vector needs to be the same as the number of rows in train
- k = the choice of nearest neighbors to make the prediction
If we are using leave-one-out cross-validation, all we need to do is not specify a test = argument!
Let’s specify k = 10 for the example of how to use the
function and save the result as MHI_knn10
MHI_knn10 <-
knn.reg(
train = counties_stan |> dplyr::select(-med_home_income),
y = counties_reg$med_home_income,
k = 10
)
# knn.reg returns the leave-one-out predictions (the average income of the 10 nearest neighbors) in the $pred element
reg_fit_stats(
y = counties_reg$med_home_income,
y_hat = MHI_knn10$pred
)
## $R2
## [1] 0.8069286
##
## $rmse
## [1] 6.303123
##
## $MAE
## [1] 4.550066
With an unbiased (leave-one-out) \(R^2\) of about 0.81, standardizing the data and using the 10 nearest counties does pretty well! But can we do better by using the normalized data and/or a different choice of k?
We want to find the choice of k that maximizes \(R^2\) or minimizes MAE (ideally both!)
# Let's look over k = 1 to 100
# Start by creating a data.frame named k_search to store the results in:
# It should have three columns: k, R2_norm, MAE_norm
k_search <-
data.frame(
k = 1:100,
R2_norm = -1,
MAE_norm = -1
)
# Looping through the results
for (i in 1:nrow(k_search)){
# Getting the predictions for the ith choice of k
MHI_knn_loop <-
knn.reg(
train = counties_norm |> dplyr::select(-med_home_income),
y = counties_reg$med_home_income,
k = k_search$k[i]
)
# Getting the fit stats
fit_loop <- reg_fit_stats(y = counties_reg$med_home_income,
y_hat = MHI_knn_loop$pred)
# Saving the fit stats
k_search[i, 'R2_norm'] <- fit_loop$R2
k_search[i, 'MAE_norm'] <- fit_loop$MAE
}
k_search
## k R2_norm MAE_norm
## 1 1 0.6675446 6.065318
## 2 2 0.7519314 5.254435
## 3 3 0.7784314 4.935914
## 4 4 0.7900614 4.798888
## 5 5 0.7975102 4.717639
## 6 6 0.7992006 4.670024
## 7 7 0.8007122 4.641139
## 8 8 0.8020705 4.607212
## 9 9 0.8043544 4.589459
## 10 10 0.8062021 4.572159
## 11 11 0.8073656 4.558041
## 12 12 0.8075505 4.543290
## 13 13 0.8083859 4.522559
## 14 14 0.8085412 4.520290
## 15 15 0.8091266 4.512808
## 16 16 0.8082996 4.515934
## 17 17 0.8080665 4.516167
## 18 18 0.8082743 4.505875
## 19 19 0.8083740 4.504779
## 20 20 0.8077975 4.506387
## 21 21 0.8080527 4.505873
## 22 22 0.8085495 4.496427
## 23 23 0.8084738 4.497647
## 24 24 0.8083111 4.498970
## 25 25 0.8076457 4.508050
## 26 26 0.8069016 4.515985
## 27 27 0.8073420 4.510846
## 28 28 0.8069721 4.517150
## 29 29 0.8069940 4.518389
## 30 30 0.8068552 4.515148
## 31 31 0.8063516 4.519297
## 32 32 0.8056909 4.523845
## 33 33 0.8051419 4.524988
## 34 34 0.8041944 4.531141
## 35 35 0.8037763 4.536351
## 36 36 0.8032568 4.542465
## 37 37 0.8027109 4.545915
## 38 38 0.8021257 4.553090
## 39 39 0.8013187 4.559743
## 40 40 0.8010498 4.564400
## 41 41 0.8004806 4.568733
## 42 42 0.8001547 4.566668
## 43 43 0.7993733 4.573258
## 44 44 0.7988755 4.578294
## 45 45 0.7983016 4.582875
## 46 46 0.7979474 4.585823
## 47 47 0.7975303 4.588726
## 48 48 0.7974054 4.589637
## 49 49 0.7968148 4.596466
## 50 50 0.7965134 4.596714
## 51 51 0.7964426 4.599608
## 52 52 0.7959088 4.601997
## 53 53 0.7952178 4.607589
## 54 54 0.7947587 4.609052
## 55 55 0.7943993 4.612234
## 56 56 0.7937193 4.617787
## 57 57 0.7931903 4.623895
## 58 58 0.7928769 4.628721
## 59 59 0.7920557 4.635297
## 60 60 0.7913138 4.641476
## 61 61 0.7907559 4.645431
## 62 62 0.7902528 4.649338
## 63 63 0.7896423 4.653338
## 64 64 0.7890344 4.657682
## 65 65 0.7886638 4.660962
## 66 66 0.7882476 4.664379
## 67 67 0.7878553 4.669323
## 68 68 0.7869796 4.676911
## 69 69 0.7867974 4.680483
## 70 70 0.7861977 4.686460
## 71 71 0.7856718 4.691023
## 72 72 0.7851955 4.695172
## 73 73 0.7849979 4.697802
## 74 74 0.7843509 4.703671
## 75 75 0.7837675 4.710843
## 76 76 0.7834731 4.712923
## 77 77 0.7830754 4.716820
## 78 78 0.7825155 4.720410
## 79 79 0.7821669 4.724352
## 80 80 0.7817866 4.726910
## 81 81 0.7814563 4.727483
## 82 82 0.7807919 4.733292
## 83 83 0.7803109 4.737750
## 84 84 0.7800229 4.741671
## 85 85 0.7793791 4.746023
## 86 86 0.7788405 4.750171
## 87 87 0.7781709 4.756394
## 88 88 0.7777381 4.761858
## 89 89 0.7771281 4.766331
## 90 90 0.7767533 4.770355
## 91 91 0.7764516 4.774732
## 92 92 0.7760304 4.777757
## 93 93 0.7757606 4.780175
## 94 94 0.7752593 4.785054
## 95 95 0.7750488 4.788318
## 96 96 0.7745923 4.793723
## 97 97 0.7741423 4.797630
## 98 98 0.7735078 4.804740
## 99 99 0.7732228 4.807620
## 100 100 0.7727865 4.809856
Plot the results for both \(R^2\) and MAE:
# Plotting the results:
k_search |>
# Placing R2 and MAE in the same column
pivot_longer(
cols = -k,
names_to = 'fit_stat',
values_to = 'stat'
) |>
ggplot(
mapping = aes(
x = k,
y = stat
)
) +
geom_line() +
# separate graph for R2 and MAE
facet_wrap(
facets = vars(fit_stat),
scales = 'free_y',
ncol = 1
)
If we normalize the data, what is the best choice of k?
k_search |>
filter(
R2_norm == max(R2_norm) | MAE_norm == min(MAE_norm)
)
## k R2_norm MAE_norm
## 1 15 0.8091266 4.512808
## 2 22 0.8085495 4.496427
If we normalize the data, k = 15 maximizes \(R^2\) (minimizes SSE) and k = 22 minimizes MAE. So which one do we choose? It depends on your choice of objective function: \((y - \hat{y})^2\) vs \(|y - \hat{y}|\). You should decide on an objective function before running the methods!
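As a quick illustration of how the two objective functions can disagree (the error vectors below are made up for the example), squared error punishes a single large miss much more heavily than absolute error does:
# Two made-up sets of prediction errors with the same MAE but very different SSE
errors_a <- c(5, 5, 5, 5)   # four moderate misses
errors_b <- c(1, 1, 1, 17)  # three small misses and one large miss
c(SSE_a = sum(errors_a^2), SSE_b = sum(errors_b^2))           # 100 vs 292
c(MAE_a = mean(abs(errors_a)), MAE_b = mean(abs(errors_b)))   # both are 5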
What if we standardize the data instead?
# Add two columns to k_search: R2_stan, MAE_stan
k_search <-
k_search |>
mutate(
R2_stan = -1,
MAE_stan = -1
)
# Performing our grid search:
for (i in 1:nrow(k_search)){
# Getting the predictions for the ith choice of k
MHI_knn_loop <-
knn.reg(
train = counties_stan |> dplyr::select(-med_home_income),
y = counties_reg$med_home_income,
k = k_search$k[i]
)
# Getting the fit stats
fit_loop <- reg_fit_stats(y = counties_reg$med_home_income,
y_hat = MHI_knn_loop$pred)
# Saving the fit stats
k_search[i, 'R2_stan'] <- fit_loop$R2
k_search[i, 'MAE_stan'] <- fit_loop$MAE
}
k_search
## k R2_norm MAE_norm R2_stan MAE_stan
## 1 1 0.6675446 6.065318 0.6699069 6.014459
## 2 2 0.7519314 5.254435 0.7542423 5.202664
## 3 3 0.7784314 4.935914 0.7772593 4.941954
## 4 4 0.7900614 4.798888 0.7930770 4.758674
## 5 5 0.7975102 4.717639 0.7992423 4.690358
## 6 6 0.7992006 4.670024 0.8004058 4.647234
## 7 7 0.8007122 4.641139 0.8025986 4.620702
## 8 8 0.8020705 4.607212 0.8041850 4.604911
## 9 9 0.8043544 4.589459 0.8040486 4.588234
## 10 10 0.8062021 4.572159 0.8069286 4.550066
## 11 11 0.8073656 4.558041 0.8080286 4.539754
## 12 12 0.8075505 4.543290 0.8087393 4.540150
## 13 13 0.8083859 4.522559 0.8088272 4.530470
## 14 14 0.8085412 4.520290 0.8097894 4.517681
## 15 15 0.8091266 4.512808 0.8096112 4.516198
## 16 16 0.8082996 4.515934 0.8095091 4.505667
## 17 17 0.8080665 4.516167 0.8095752 4.500479
## 18 18 0.8082743 4.505875 0.8085804 4.506002
## 19 19 0.8083740 4.504779 0.8091728 4.493432
## 20 20 0.8077975 4.506387 0.8102357 4.484447
## 21 21 0.8080527 4.505873 0.8099625 4.491988
## 22 22 0.8085495 4.496427 0.8099426 4.489322
## 23 23 0.8084738 4.497647 0.8096919 4.492594
## 24 24 0.8083111 4.498970 0.8091688 4.492659
## 25 25 0.8076457 4.508050 0.8092626 4.490492
## 26 26 0.8069016 4.515985 0.8091778 4.493355
## 27 27 0.8073420 4.510846 0.8091095 4.493986
## 28 28 0.8069721 4.517150 0.8089509 4.493806
## 29 29 0.8069940 4.518389 0.8086961 4.496555
## 30 30 0.8068552 4.515148 0.8079859 4.501962
## 31 31 0.8063516 4.519297 0.8076841 4.504433
## 32 32 0.8056909 4.523845 0.8075931 4.507040
## 33 33 0.8051419 4.524988 0.8069314 4.508688
## 34 34 0.8041944 4.531141 0.8062657 4.511804
## 35 35 0.8037763 4.536351 0.8051942 4.518190
## 36 36 0.8032568 4.542465 0.8045361 4.520402
## 37 37 0.8027109 4.545915 0.8042531 4.525309
## 38 38 0.8021257 4.553090 0.8035881 4.528298
## 39 39 0.8013187 4.559743 0.8028229 4.537386
## 40 40 0.8010498 4.564400 0.8025711 4.538888
## 41 41 0.8004806 4.568733 0.8019254 4.544112
## 42 42 0.8001547 4.566668 0.8010487 4.551455
## 43 43 0.7993733 4.573258 0.8004058 4.556670
## 44 44 0.7988755 4.578294 0.7995960 4.563862
## 45 45 0.7983016 4.582875 0.7994763 4.567491
## 46 46 0.7979474 4.585823 0.7990485 4.571937
## 47 47 0.7975303 4.588726 0.7983721 4.578422
## 48 48 0.7974054 4.589637 0.7979168 4.583525
## 49 49 0.7968148 4.596466 0.7976584 4.584299
## 50 50 0.7965134 4.596714 0.7972432 4.586348
## 51 51 0.7964426 4.599608 0.7967747 4.590616
## 52 52 0.7959088 4.601997 0.7964695 4.592776
## 53 53 0.7952178 4.607589 0.7960195 4.595082
## 54 54 0.7947587 4.609052 0.7956490 4.598514
## 55 55 0.7943993 4.612234 0.7949235 4.607545
## 56 56 0.7937193 4.617787 0.7948087 4.609059
## 57 57 0.7931903 4.623895 0.7944486 4.612077
## 58 58 0.7928769 4.628721 0.7937655 4.617521
## 59 59 0.7920557 4.635297 0.7935557 4.619113
## 60 60 0.7913138 4.641476 0.7928315 4.625538
## 61 61 0.7907559 4.645431 0.7922046 4.629485
## 62 62 0.7902528 4.649338 0.7916994 4.633592
## 63 63 0.7896423 4.653338 0.7907909 4.640984
## 64 64 0.7890344 4.657682 0.7902763 4.646448
## 65 65 0.7886638 4.660962 0.7897435 4.652882
## 66 66 0.7882476 4.664379 0.7891464 4.658201
## 67 67 0.7878553 4.669323 0.7888886 4.659932
## 68 68 0.7869796 4.676911 0.7881654 4.668349
## 69 69 0.7867974 4.680483 0.7878368 4.670391
## 70 70 0.7861977 4.686460 0.7876446 4.671053
## 71 71 0.7856718 4.691023 0.7871535 4.674955
## 72 72 0.7851955 4.695172 0.7867203 4.676978
## 73 73 0.7849979 4.697802 0.7862640 4.682160
## 74 74 0.7843509 4.703671 0.7856588 4.684782
## 75 75 0.7837675 4.710843 0.7851139 4.691063
## 76 76 0.7834731 4.712923 0.7844048 4.697678
## 77 77 0.7830754 4.716820 0.7840680 4.700343
## 78 78 0.7825155 4.720410 0.7835654 4.704095
## 79 79 0.7821669 4.724352 0.7831503 4.707264
## 80 80 0.7817866 4.726910 0.7828418 4.711663
## 81 81 0.7814563 4.727483 0.7823444 4.716123
## 82 82 0.7807919 4.733292 0.7819249 4.721178
## 83 83 0.7803109 4.737750 0.7818300 4.725003
## 84 84 0.7800229 4.741671 0.7813097 4.730577
## 85 85 0.7793791 4.746023 0.7808430 4.733186
## 86 86 0.7788405 4.750171 0.7803085 4.738028
## 87 87 0.7781709 4.756394 0.7797002 4.742454
## 88 88 0.7777381 4.761858 0.7793457 4.747220
## 89 89 0.7771281 4.766331 0.7787723 4.752488
## 90 90 0.7767533 4.770355 0.7785207 4.756924
## 91 91 0.7764516 4.774732 0.7780599 4.762046
## 92 92 0.7760304 4.777757 0.7775551 4.766998
## 93 93 0.7757606 4.780175 0.7768993 4.774099
## 94 94 0.7752593 4.785054 0.7763276 4.779416
## 95 95 0.7750488 4.788318 0.7759662 4.782436
## 96 96 0.7745923 4.793723 0.7755486 4.787427
## 97 97 0.7741423 4.797630 0.7751063 4.791559
## 98 98 0.7735078 4.804740 0.7745101 4.798249
## 99 99 0.7732228 4.807620 0.7739883 4.803043
## 100 100 0.7727865 4.809856 0.7736391 4.805495
Now, plot the results:
k_search |>
# Placing R2 and MAE in the same column
pivot_longer(
cols = -k,
names_to = 'fit_stat',
values_to = 'stat'
) |>
# Separating the rescale method from the fit stat type using separate
separate(
col = fit_stat,
into = c('fit_stat', 'rescale'),
sep = '_'
) |>
# Creating the line graph
ggplot(
mapping = aes(
x = k,
y = stat,
color = rescale
)
) +
geom_line() +
# separate graph for R2 and MAE
facet_wrap(
facets = vars(fit_stat),
scales = 'free_y',
ncol = 1
)
What is the optimal choice of rescale method and k combination?
k_search |>
filter(
R2_norm == max(R2_norm) | R2_stan == max(R2_stan) |
MAE_norm == min(MAE_norm) | MAE_stan == min(MAE_stan)
)
## k R2_norm MAE_norm R2_stan MAE_stan
## 1 15 0.8091266 4.512808 0.8096112 4.516198
## 2 20 0.8077975 4.506387 0.8102357 4.484447
## 3 22 0.8085495 4.496427 0.8099426 4.489322
Let's write a function called knn_reg_search() with 3 arguments:
- X = an unscaled data frame of the predictors only (note it is capital X)
- y = a vector of the response variable to be predicted
- k = a vector of the choices of k to search over
It should output a data frame with 5 columns: k, R2_stan, R2_norm, MAE_stan, MAE_norm
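Here is one possible sketch of knn_reg_search(), assuming the normalize() and reg_fit_stats() functions defined above are available; the internal object names (X_stan, X_norm, results) are just illustrative choices:
knn_reg_search <- function(X, y, k){
  # Rescale the predictors both ways
  X_stan <- X |> mutate(across(.cols = everything(), .fns = scale))
  X_norm <- X |> mutate(across(.cols = everything(), .fns = normalize))
  # Data frame to store the results for each choice of k
  results <- data.frame(k = k, R2_stan = -1, R2_norm = -1, MAE_stan = -1, MAE_norm = -1)
  for (i in 1:nrow(results)){
    # Leave-one-out kNN predictions for the ith choice of k
    pred_stan <- knn.reg(train = X_stan, y = y, k = results$k[i])$pred
    pred_norm <- knn.reg(train = X_norm, y = y, k = results$k[i])$pred
    # Fit stats for both rescale methods
    fit_stan <- reg_fit_stats(y = y, y_hat = pred_stan)
    fit_norm <- reg_fit_stats(y = y, y_hat = pred_norm)
    results[i, 'R2_stan'] <- fit_stan$R2
    results[i, 'MAE_stan'] <- fit_stan$MAE
    results[i, 'R2_norm'] <- fit_norm$R2
    results[i, 'MAE_norm'] <- fit_norm$MAE
  }
  return(results)
}
Calling knn_reg_search(X = counties_reg |> dplyr::select(-med_home_income), y = counties_reg$med_home_income, k = 1:100) should roughly reproduce the grid search results above.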