# Comparing KNN and Linear Regression on Credit Data
# load libraries
library(kknn)
## Warning: package 'kknn' was built under R version 4.4.3
library(ISLR)
## Warning: package 'ISLR' was built under R version 4.4.3
library(glmtoolbox)
## Warning: package 'glmtoolbox' was built under R version 4.4.3
library(DALEX)
## Warning: package 'DALEX' was built under R version 4.4.3
## Welcome to DALEX (version: 2.4.3).
## Find examples and detailed introduction at: http://ema.drwhy.ai/
library(MASS)
library(lmtest)
## Warning: package 'lmtest' was built under R version 4.4.3
## Loading required package: zoo
##
## Attaching package: 'zoo'
## The following objects are masked from 'package:base':
##
## as.Date, as.Date.numeric
library(jtools)
library(broom)
library(car)
## Loading required package: carData
library(vip)
## Warning: package 'vip' was built under R version 4.4.3
##
## Attaching package: 'vip'
## The following object is masked from 'package:DALEX':
##
## titanic
## The following object is masked from 'package:utils':
##
## vi
library(cowplot)
library(caret)
## Warning: package 'caret' was built under R version 4.4.3
## Loading required package: ggplot2
## Loading required package: lattice
##
## Attaching package: 'caret'
## The following object is masked from 'package:kknn':
##
## contr.dummy
library(glmnet)
## Warning: package 'glmnet' was built under R version 4.4.3
## Loading required package: Matrix
## Loaded glmnet 4.1-9
library(tidymodels)
## Warning: package 'tidymodels' was built under R version 4.4.3
## ── Attaching packages ────────────────────────────────────── tidymodels 1.3.0 ──
## ✔ dials 1.4.0 ✔ rsample 1.2.1
## ✔ dplyr 1.1.4 ✔ tibble 3.2.1
## ✔ infer 1.0.7 ✔ tidyr 1.3.1
## ✔ modeldata 1.4.0 ✔ tune 1.3.0
## ✔ parsnip 1.3.1 ✔ workflows 1.2.0
## ✔ purrr 1.0.4 ✔ workflowsets 1.1.0
## ✔ recipes 1.1.1 ✔ yardstick 1.3.2
## Warning: package 'dials' was built under R version 4.4.3
## Warning: package 'infer' was built under R version 4.4.3
## Warning: package 'modeldata' was built under R version 4.4.3
## Warning: package 'parsnip' was built under R version 4.4.3
## Warning: package 'recipes' was built under R version 4.4.3
## Warning: package 'rsample' was built under R version 4.4.3
## Warning: package 'tune' was built under R version 4.4.3
## Warning: package 'workflows' was built under R version 4.4.3
## Warning: package 'workflowsets' was built under R version 4.4.3
## Warning: package 'yardstick' was built under R version 4.4.3
## ── Conflicts ───────────────────────────────────────── tidymodels_conflicts() ──
## ✖ purrr::discard() masks scales::discard()
## ✖ tidyr::expand() masks Matrix::expand()
## ✖ dplyr::explain() masks DALEX::explain()
## ✖ dplyr::filter() masks stats::filter()
## ✖ yardstick::get_weights() masks jtools::get_weights()
## ✖ dplyr::lag() masks stats::lag()
## ✖ purrr::lift() masks caret::lift()
## ✖ tidyr::pack() masks Matrix::pack()
## ✖ yardstick::precision() masks caret::precision()
## ✖ yardstick::recall() masks caret::recall()
## ✖ dplyr::recode() masks car::recode()
## ✖ dplyr::select() masks MASS::select()
## ✖ yardstick::sensitivity() masks caret::sensitivity()
## ✖ purrr::some() masks car::some()
## ✖ yardstick::specificity() masks caret::specificity()
## ✖ recipes::step() masks stats::step()
## ✖ tidyr::unpack() masks Matrix::unpack()
## ✖ recipes::update() masks Matrix::update(), stats::update()
library(tidyverse)
## ── Attaching core tidyverse packages ──────────────────────── tidyverse 2.0.0 ──
## ✔ forcats 1.0.0 ✔ readr 2.1.5
## ✔ lubridate 1.9.4 ✔ stringr 1.5.1
## ── Conflicts ────────────────────────────────────────── tidyverse_conflicts() ──
## ✖ readr::col_factor() masks scales::col_factor()
## ✖ purrr::discard() masks scales::discard()
## ✖ tidyr::expand() masks Matrix::expand()
## ✖ dplyr::explain() masks DALEX::explain()
## ✖ dplyr::filter() masks stats::filter()
## ✖ stringr::fixed() masks recipes::fixed()
## ✖ dplyr::lag() masks stats::lag()
## ✖ purrr::lift() masks caret::lift()
## ✖ tidyr::pack() masks Matrix::pack()
## ✖ dplyr::recode() masks car::recode()
## ✖ dplyr::select() masks MASS::select()
## ✖ purrr::some() masks car::some()
## ✖ readr::spec() masks yardstick::spec()
## ✖ lubridate::stamp() masks cowplot::stamp()
## ✖ tidyr::unpack() masks Matrix::unpack()
## ℹ Use the conflicted package (<http://conflicted.r-lib.org/>) to force all conflicts to become errors
# Resolve function-name conflicts in favour of the tidymodels packages
tidymodels_prefer()
# Make black-and-white (theme_bw) the default ggplot2 theme for all plots
ggplot2::theme_set(ggplot2::theme_bw())
## Step 1: Load and prepare data
# Load the Credit data set and print summary statistics for every column
data("Credit")
Credit %>% summary()
## ID Income Limit Rating
## Min. : 1.0 Min. : 10.35 Min. : 855 Min. : 93.0
## 1st Qu.:100.8 1st Qu.: 21.01 1st Qu.: 3088 1st Qu.:247.2
## Median :200.5 Median : 33.12 Median : 4622 Median :344.0
## Mean :200.5 Mean : 45.22 Mean : 4736 Mean :354.9
## 3rd Qu.:300.2 3rd Qu.: 57.47 3rd Qu.: 5873 3rd Qu.:437.2
## Max. :400.0 Max. :186.63 Max. :13913 Max. :982.0
## Cards Age Education Gender Student
## Min. :1.000 Min. :23.00 Min. : 5.00 Male :193 No :360
## 1st Qu.:2.000 1st Qu.:41.75 1st Qu.:11.00 Female:207 Yes: 40
## Median :3.000 Median :56.00 Median :14.00
## Mean :2.958 Mean :55.67 Mean :13.45
## 3rd Qu.:4.000 3rd Qu.:70.00 3rd Qu.:16.00
## Max. :9.000 Max. :98.00 Max. :20.00
## Married Ethnicity Balance
## No :155 African American: 99 Min. : 0.00
## Yes:245 Asian :102 1st Qu.: 68.75
## Caucasian :199 Median : 459.50
## Mean : 520.01
## 3rd Qu.: 863.00
## Max. :1999.00
# Drop the ID column, then discard any rows that contain missing values
Credit <- tidyr::drop_na(dplyr::select(Credit, -ID))
### Step 2: Split data into training and testing sets
# Fix the RNG seed so the 70/30 split (stratified on Balance) is reproducible
set.seed(123)
split <- Credit %>%
  initial_split(prop = 0.7, strata = Balance)
train_data <- split %>% training()
test_data <- split %>% testing()
### Step 3: Set common recipe and preprocess train_data
# Shared preprocessing for both models: Balance is the outcome and every other
# column is a predictor. Numeric predictors are normalized first, then nominal
# predictors are expanded into dummy variables.
# Calls are namespace-qualified because the result is stored under the name
# `recipe`, the same name as the constructor function.
recipe <- recipes::recipe(Balance ~ ., data = train_data) %>%
  recipes::step_normalize(recipes::all_numeric_predictors()) %>%
  recipes::step_dummy(recipes::all_nominal_predictors())
### Step 4: Define KNN model specification
# Nearest-neighbour regression spec. The number of neighbours, the weighting
# function and the distance power are left as tune() placeholders so they can
# be chosen during tuning.
knn_spec <- nearest_neighbor(
  neighbors = tune(),
  weight_func = tune(),
  dist_power = tune()
) %>%
  set_mode("regression")
### Step 5: Set initial workflow
# Bundle the shared recipe and the tunable KNN spec into a single workflow
knn_workflow <- workflow(preprocessor = recipe, spec = knn_spec)
### Step 6: Create folds and set tuning grid
# Re-seed, then build 10-fold cross-validation resamples of the training data,
# stratified on Balance
set.seed(123)
folds <- train_data %>%
  vfold_cv(v = 10, strata = Balance)
# Alternative (kept for reference): a regular grid over the same parameters
# knn_grid <- grid_regular(
#   neighbors(range = c(1, 20)),
#   weight_func(values = c("rectangular", "triangular", "optimal")),
#   dist_power(range = c(1, 2)),
#   levels = 10
# )
# Space-filling (Latin hypercube) grid of 10 candidate combinations over
# neighbors in [1, 20], three weighting functions and dist_power in [1, 2].
# grid_latin_hypercube() is deprecated as of dials 1.3.0; its replacement is
# grid_space_filling() with type = "latin_hypercube".
# NOTE(review): the sampled candidates may not match the deprecated call's
# draw exactly - re-run the tuning steps and refresh any recorded output.
knn_grid <- grid_space_filling(
  neighbors(range = c(1, 20)),
  weight_func(values = c("rectangular", "triangular", "optimal")),
  dist_power(range = c(1, 2)),
  size = 10,
  type = "latin_hypercube"
)
## Warning: `grid_latin_hypercube()` was deprecated in dials 1.3.0.
## ℹ Please use `grid_space_filling()` instead.
## This warning is displayed once every 8 hours.
## Call `lifecycle::last_lifecycle_warnings()` to see where this warning was
## generated.
### Step 7: Tune KNN model
# Evaluate every candidate in knn_grid on each resample in folds, scoring
# with RMSE and R-squared
knn_tune <- knn_workflow %>%
  tune_grid(
    resamples = folds,
    grid = knn_grid,
    metrics = metric_set(rmse, rsq)
  )
# Resampled metrics for each candidate (one row per candidate and metric),
# sorted ascending by the mean column; rsq and rmse rows are sorted together
collect_metrics(knn_tune) %>%
  arrange(mean)
## # A tibble: 20 × 9
## neighbors weight_func dist_power .metric .estimator mean n std_err
## <int> <chr> <dbl> <chr> <chr> <dbl> <int> <dbl>
## 1 3 rectangular 1.92 rsq standard 0.637 10 0.0368
## 2 3 triangular 1.71 rsq standard 0.643 10 0.0349
## 3 11 rectangular 1.87 rsq standard 0.720 10 0.0290
## 4 8 rectangular 1.35 rsq standard 0.728 10 0.0322
## 5 17 optimal 1.67 rsq standard 0.740 10 0.0281
## 6 6 triangular 1.18 rsq standard 0.747 10 0.0273
## 7 16 optimal 1.48 rsq standard 0.750 10 0.0294
## 8 13 optimal 1.26 rsq standard 0.758 10 0.0280
## 9 19 triangular 1.50 rsq standard 0.759 10 0.0260
## 10 10 optimal 1.05 rsq standard 0.773 10 0.0274
## 11 10 optimal 1.05 rmse standard 225. 10 9.64
## 12 6 triangular 1.18 rmse standard 231. 10 10.1
## 13 13 optimal 1.26 rmse standard 235. 10 9.15
## 14 16 optimal 1.48 rmse standard 241. 10 8.90
## 15 19 triangular 1.50 rmse standard 242. 10 8.10
## 16 17 optimal 1.67 rmse standard 247. 10 8.97
## 17 8 rectangular 1.35 rmse standard 247. 10 10.0
## 18 11 rectangular 1.87 rmse standard 257. 10 8.96
## 19 3 triangular 1.71 rmse standard 272. 10 13.8
## 20 3 rectangular 1.92 rmse standard 273. 10 11.2
## # ℹ 1 more variable: .config <chr>
# Plot tuning results
autoplot(knn_tune)

# Select the candidate with the best (lowest) resampled RMSE and display it
best_knn <- select_best(knn_tune, metric = "rmse")
best_knn
## # A tibble: 1 × 4
## neighbors weight_func dist_power .config
## <int> <chr> <dbl> <chr>
## 1 10 optimal 1.05 Preprocessor1_Model01
### Step 8: Finalize the KNN workflow
# Replace the tune() placeholders with the selected hyperparameters
final_knn_workflow <- finalize_workflow(knn_workflow, best_knn)
### Step 9: Fit the final KNN model on the training data
final_knn_fit <- fit(final_knn_workflow, data = train_data)
### Step 10: Evaluate the KNN model on the test data
# Predict Balance for the held-out rows; the .pred column (predicted Balance)
# is placed in front of the original test columns
knn_predictions <- bind_cols(
  predict(final_knn_fit, new_data = test_data),
  test_data
)
knn_predictions
## # A tibble: 121 × 12
## .pred Income Limit Rating Cards Age Education Gender Student Married
## <dbl> <dbl> <int> <int> <int> <int> <int> <fct> <fct> <fct>
## 1 1193. 106. 6645 483 3 82 15 "Female" Yes Yes
## 2 362. 55.9 4897 357 2 68 16 " Male" No Yes
## 3 96.2 21.0 3388 259 2 37 12 "Female" No No
## 4 0 15.0 1311 138 3 64 16 " Male" No No
## 5 60.5 20.1 2525 200 3 57 15 "Female" No Yes
## 6 285. 53.6 3714 286 3 73 17 "Female" No Yes
## 7 631. 42.1 6626 479 2 44 9 " Male" No No
## 8 896. 37.3 6378 458 1 72 17 "Female" No No
## 9 355. 14.1 4323 326 5 25 16 "Female" No Yes
## 10 525. 42.5 3625 289 6 44 12 "Female" Yes No
## # ℹ 111 more rows
## # ℹ 2 more variables: Ethnicity <fct>, Balance <int>
# Test-set metrics for the KNN predictions (printed: rmse, rsq and mae)
knn_metrics <- metrics(knn_predictions, truth = Balance, estimate = .pred)
knn_metrics
## # A tibble: 3 × 3
## .metric .estimator .estimate
## <chr> <chr> <dbl>
## 1 rmse standard 234.
## 2 rsq standard 0.793
## 3 mae standard 172.
# Scatter plot of predicted against actual Balance with a fitted linear trend
knn_predictions %>%
  ggplot(aes(x = Balance, y = .pred)) +
  geom_point() +
  geom_smooth(method = "lm") +
  labs(
    title = "KNN Predictions vs Actual Values",
    x = "Actual Balance",
    y = "Predicted Balance"
  )
## `geom_smooth()` using formula = 'y ~ x'

### Run linear regression model using tidymodels approach
# Linear regression specification
lin_reg_spec <- linear_reg(mode = "regression")
# Workflow pairing the shared recipe with the linear regression spec
lin_reg_workflow <- workflow(preprocessor = recipe, spec = lin_reg_spec)
# Fit the linear regression workflow on the training data
final_lm <- fit(lin_reg_workflow, data = train_data)
# Predict Balance for the test rows and attach the original test columns
lm_predictions <- bind_cols(
  predict(final_lm, new_data = test_data),
  test_data
)
# Test-set metrics for linear regression (printed: rmse, rsq and mae)
lm_predictions %>%
  metrics(truth = Balance, estimate = .pred)
## # A tibble: 3 × 3
## .metric .estimator .estimate
## <chr> <chr> <dbl>
## 1 rmse standard 92.3
## 2 rsq standard 0.964
## 3 mae standard 75.8
# Test-set metrics for KNN regression, shown again for comparison
knn_predictions %>%
  metrics(truth = Balance, estimate = .pred)
## # A tibble: 3 × 3
## .metric .estimator .estimate
## <chr> <chr> <dbl>
## 1 rmse standard 234.
## 2 rsq standard 0.793
## 3 mae standard 172.
# Append both models' predictions to the test data as side-by-side columns
test_results <- mutate(
  test_data,
  .pred_knn = knn_predictions[[".pred"]],
  .pred_lm = lm_predictions[[".pred"]]
)
test_results %>% head()
## Income Limit Rating Cards Age Education Gender Student Married
## 1 106.025 6645 483 3 82 15 Female Yes Yes
## 2 55.882 4897 357 2 68 16 Male No Yes
## 3 20.996 3388 259 2 37 12 Female No No
## 4 15.045 1311 138 3 64 16 Male No No
## 5 20.089 2525 200 3 57 15 Female No Yes
## 6 53.598 3714 286 3 73 17 Female No Yes
## Ethnicity Balance .pred_knn .pred_lm
## 1 Asian 903 1193.08906 914.26371
## 2 Caucasian 331 361.90311 409.04140
## 3 African American 203 96.21087 291.12924
## 4 Caucasian 0 0.00000 -174.83619
## 5 African American 0 60.53106 63.96688
## 6 African American 0 285.21505 113.45598
# Compute each model's test metrics, tag them with the model name, stack the
# two tables and display them grouped by metric
lm_metrics <- lm_predictions %>%
  metrics(truth = Balance, estimate = .pred) %>%
  mutate(model = "Linear Regression")
knn_metrics <- knn_predictions %>%
  metrics(truth = Balance, estimate = .pred) %>%
  mutate(model = "KNN")
combined_metrics <- bind_rows(lm_metrics, knn_metrics)
arrange(combined_metrics, .metric)
## # A tibble: 6 × 4
## .metric .estimator .estimate model
## <chr> <chr> <dbl> <chr>
## 1 mae standard 75.8 Linear Regression
## 2 mae standard 172. KNN
## 3 rmse standard 92.3 Linear Regression
## 4 rmse standard 234. KNN
## 5 rsq standard 0.964 Linear Regression
## 6 rsq standard 0.793 KNN
# Visual comparison of the test results: preview the data being plotted
test_results %>% head()
## Income Limit Rating Cards Age Education Gender Student Married
## 1 106.025 6645 483 3 82 15 Female Yes Yes
## 2 55.882 4897 357 2 68 16 Male No Yes
## 3 20.996 3388 259 2 37 12 Female No No
## 4 15.045 1311 138 3 64 16 Male No No
## 5 20.089 2525 200 3 57 15 Female No Yes
## 6 53.598 3714 286 3 73 17 Female No Yes
## Ethnicity Balance .pred_knn .pred_lm
## 1 Asian 903 1193.08906 914.26371
## 2 Caucasian 331 361.90311 409.04140
## 3 African American 203 96.21087 291.12924
## 4 Caucasian 0 0.00000 -174.83619
## 5 African American 0 60.53106 63.96688
## 6 African American 0 285.21505 113.45598
# Overlay both models' predictions against actual Balance: one point layer
# and one fitted linear trend per model (blue = KNN, red = linear regression)
test_results %>%
  ggplot(aes(x = Balance)) +
  geom_point(aes(y = .pred_knn, color = "KNN")) +
  geom_point(aes(y = .pred_lm, color = "Linear Regression")) +
  geom_smooth(aes(y = .pred_knn), method = "lm", color = "blue") +
  geom_smooth(aes(y = .pred_lm), method = "lm", color = "red") +
  scale_color_manual(
    values = c("KNN" = "blue", "Linear Regression" = "red")
  ) +
  labs(
    title = "KNN vs Linear Regression Predictions",
    x = "Actual Balance",
    y = "Predicted Balance"
  ) +
  theme_minimal()
## `geom_smooth()` using formula = 'y ~ x'
## `geom_smooth()` using formula = 'y ~ x'
