# Comparing KNN and Linear Regression on Credit Data

# load libraries
library(kknn)
## Warning: package 'kknn' was built under R version 4.4.3
library(ISLR)
## Warning: package 'ISLR' was built under R version 4.4.3
library(glmtoolbox)
## Warning: package 'glmtoolbox' was built under R version 4.4.3
library(DALEX)
## Warning: package 'DALEX' was built under R version 4.4.3
## Welcome to DALEX (version: 2.4.3).
## Find examples and detailed introduction at: http://ema.drwhy.ai/
library(MASS)
library(lmtest)
## Warning: package 'lmtest' was built under R version 4.4.3
## Loading required package: zoo
## 
## Attaching package: 'zoo'
## The following objects are masked from 'package:base':
## 
##     as.Date, as.Date.numeric
library(jtools)
library(broom)
library(car)
## Loading required package: carData
library(vip)
## Warning: package 'vip' was built under R version 4.4.3
## 
## Attaching package: 'vip'
## The following object is masked from 'package:DALEX':
## 
##     titanic
## The following object is masked from 'package:utils':
## 
##     vi
library(cowplot)
library(caret)
## Warning: package 'caret' was built under R version 4.4.3
## Loading required package: ggplot2
## Loading required package: lattice
## 
## Attaching package: 'caret'
## The following object is masked from 'package:kknn':
## 
##     contr.dummy
library(glmnet)
## Warning: package 'glmnet' was built under R version 4.4.3
## Loading required package: Matrix
## Loaded glmnet 4.1-9
library(tidymodels)
## Warning: package 'tidymodels' was built under R version 4.4.3
## ── Attaching packages ────────────────────────────────────── tidymodels 1.3.0 ──
## ✔ dials        1.4.0     ✔ rsample      1.2.1
## ✔ dplyr        1.1.4     ✔ tibble       3.2.1
## ✔ infer        1.0.7     ✔ tidyr        1.3.1
## ✔ modeldata    1.4.0     ✔ tune         1.3.0
## ✔ parsnip      1.3.1     ✔ workflows    1.2.0
## ✔ purrr        1.0.4     ✔ workflowsets 1.1.0
## ✔ recipes      1.1.1     ✔ yardstick    1.3.2
## Warning: package 'dials' was built under R version 4.4.3
## Warning: package 'infer' was built under R version 4.4.3
## Warning: package 'modeldata' was built under R version 4.4.3
## Warning: package 'parsnip' was built under R version 4.4.3
## Warning: package 'recipes' was built under R version 4.4.3
## Warning: package 'rsample' was built under R version 4.4.3
## Warning: package 'tune' was built under R version 4.4.3
## Warning: package 'workflows' was built under R version 4.4.3
## Warning: package 'workflowsets' was built under R version 4.4.3
## Warning: package 'yardstick' was built under R version 4.4.3
## ── Conflicts ───────────────────────────────────────── tidymodels_conflicts() ──
## ✖ purrr::discard()         masks scales::discard()
## ✖ tidyr::expand()          masks Matrix::expand()
## ✖ dplyr::explain()         masks DALEX::explain()
## ✖ dplyr::filter()          masks stats::filter()
## ✖ yardstick::get_weights() masks jtools::get_weights()
## ✖ dplyr::lag()             masks stats::lag()
## ✖ purrr::lift()            masks caret::lift()
## ✖ tidyr::pack()            masks Matrix::pack()
## ✖ yardstick::precision()   masks caret::precision()
## ✖ yardstick::recall()      masks caret::recall()
## ✖ dplyr::recode()          masks car::recode()
## ✖ dplyr::select()          masks MASS::select()
## ✖ yardstick::sensitivity() masks caret::sensitivity()
## ✖ purrr::some()            masks car::some()
## ✖ yardstick::specificity() masks caret::specificity()
## ✖ recipes::step()          masks stats::step()
## ✖ tidyr::unpack()          masks Matrix::unpack()
## ✖ recipes::update()        masks Matrix::update(), stats::update()
library(tidyverse)
## ── Attaching core tidyverse packages ──────────────────────── tidyverse 2.0.0 ──
## ✔ forcats   1.0.0     ✔ readr     2.1.5
## ✔ lubridate 1.9.4     ✔ stringr   1.5.1
## ── Conflicts ────────────────────────────────────────── tidyverse_conflicts() ──
## ✖ readr::col_factor() masks scales::col_factor()
## ✖ purrr::discard()    masks scales::discard()
## ✖ tidyr::expand()     masks Matrix::expand()
## ✖ dplyr::explain()    masks DALEX::explain()
## ✖ dplyr::filter()     masks stats::filter()
## ✖ stringr::fixed()    masks recipes::fixed()
## ✖ dplyr::lag()        masks stats::lag()
## ✖ purrr::lift()       masks caret::lift()
## ✖ tidyr::pack()       masks Matrix::pack()
## ✖ dplyr::recode()     masks car::recode()
## ✖ dplyr::select()     masks MASS::select()
## ✖ purrr::some()       masks car::some()
## ✖ readr::spec()       masks yardstick::spec()
## ✖ lubridate::stamp()  masks cowplot::stamp()
## ✖ tidyr::unpack()     masks Matrix::unpack()
## ℹ Use the conflicted package (<http://conflicted.r-lib.org/>) to force all conflicts to become errors
# Resolve function-name conflicts in favour of the tidymodels packages
# (several clashes, e.g. dplyr::select() vs MASS::select(), are reported when
# the libraries are attached).
tidymodels_prefer()

# Set the default ggplot2 theme to black-and-white (theme_bw)
theme_set(theme_bw())

## Step 1: Load and prepare data
# Load the Credit data set (provided by the ISLR package attached above)
data("Credit", package = "ISLR")

# Quick per-column summary to check ranges and factor levels
Credit |> summary()
##        ID            Income           Limit           Rating     
##  Min.   :  1.0   Min.   : 10.35   Min.   :  855   Min.   : 93.0  
##  1st Qu.:100.8   1st Qu.: 21.01   1st Qu.: 3088   1st Qu.:247.2  
##  Median :200.5   Median : 33.12   Median : 4622   Median :344.0  
##  Mean   :200.5   Mean   : 45.22   Mean   : 4736   Mean   :354.9  
##  3rd Qu.:300.2   3rd Qu.: 57.47   3rd Qu.: 5873   3rd Qu.:437.2  
##  Max.   :400.0   Max.   :186.63   Max.   :13913   Max.   :982.0  
##      Cards            Age          Education        Gender    Student  
##  Min.   :1.000   Min.   :23.00   Min.   : 5.00    Male :193   No :360  
##  1st Qu.:2.000   1st Qu.:41.75   1st Qu.:11.00   Female:207   Yes: 40  
##  Median :3.000   Median :56.00   Median :14.00                         
##  Mean   :2.958   Mean   :55.67   Mean   :13.45                         
##  3rd Qu.:4.000   3rd Qu.:70.00   3rd Qu.:16.00                         
##  Max.   :9.000   Max.   :98.00   Max.   :20.00                         
##  Married              Ethnicity      Balance       
##  No :155   African American: 99   Min.   :   0.00  
##  Yes:245   Asian           :102   1st Qu.:  68.75  
##            Caucasian       :199   Median : 459.50  
##                                   Mean   : 520.01  
##                                   3rd Qu.: 863.00  
##                                   Max.   :1999.00
# Drop the ID column, then discard any rows that contain missing values.
# NOTE(review): printed tibbles later in this script show a Gender level with
# a leading space (" Male"); it is left untouched here -- consider trimming.
Credit <- drop_na(select(Credit, -ID))


### Step 2: Split data into training and testing sets
# Fix the RNG seed so the same rows land in each partition on every run
set.seed(123)

# 70% of the rows go to training; the split is stratified on Balance.
# NOTE(review): the variable name `split` shadows base::split().
split <- Credit |> initial_split(prop = 0.7, strata = Balance)
train_data <- split |> training()
test_data <- split |> testing()

### Step 3: Shared preprocessing recipe, specified from train_data
# Balance is the outcome and every other column is a predictor. Numeric
# predictors are normalized first, then nominal predictors are dummy-coded.
# NOTE(review): the object is named `recipe`, the same as the function that
# creates it; a more specific name would be clearer.
recipe <- recipe(Balance ~ ., data = train_data) |>
  step_normalize(all_numeric_predictors()) |>
  step_dummy(all_nominal_predictors())

### Step 4: Define KNN model specification
# Regression-mode nearest-neighbour model. All three hyperparameters are
# marked with tune(), so their values are left open here and chosen during
# the tuning step.
knn_spec <- nearest_neighbor(
  neighbors = tune(),
  weight_func = tune(),
  dist_power = tune(),
  mode = "regression"
)

### Step 5: Set initial workflow
# Bundle the shared recipe with the (still untuned) KNN specification
knn_workflow <- workflow(preprocessor = recipe, spec = knn_spec)

### Step 6: Create folds and set tuning grid
# Re-seed so the fold assignment is reproducible
set.seed(123)

# 10-fold cross-validation on the training data, stratified on Balance
folds <- train_data |> vfold_cv(v = 10, strata = Balance)

# knn_grid = grid_regular(
#   neighbors(range = c(1, 20)),
#   weight_func(values = c("rectangular", "triangular", "optimal")),
#   dist_power(range = c(1, 2)),
#   levels = 10
# )

# OR using a space-filling Latin hypercube design.
# grid_latin_hypercube() is deprecated as of dials 1.3.0 in favour of
# grid_space_filling(); type = "latin_hypercube" requests the same kind of
# design explicitly instead of relying on the function's default type.
# The design is drawn at random, so the grid depends on the RNG state at this
# point (the seed is set just before the folds are created above).

knn_grid <- grid_space_filling(
  neighbors(range = c(1, 20)),
  weight_func(values = c("rectangular", "triangular", "optimal")),
  dist_power(range = c(1, 2)),
  size = 10,
  type = "latin_hypercube"
)
### Step 7: Tune KNN model
# Evaluate every candidate in knn_grid on the resamples, recording RMSE and
# R-squared for each.
knn_tune <- knn_workflow |>
  tune_grid(
    resamples = folds,
    grid = knn_grid,
    metrics = metric_set(rmse, rsq)
  )

# Resampled metrics for each candidate, sorted by the `mean` column
# (rows for rmse and rsq are sorted together in one table)
arrange(collect_metrics(knn_tune), mean)
## # A tibble: 20 × 9
##    neighbors weight_func dist_power .metric .estimator    mean     n std_err
##        <int> <chr>            <dbl> <chr>   <chr>        <dbl> <int>   <dbl>
##  1         3 rectangular       1.92 rsq     standard     0.637    10  0.0368
##  2         3 triangular        1.71 rsq     standard     0.643    10  0.0349
##  3        11 rectangular       1.87 rsq     standard     0.720    10  0.0290
##  4         8 rectangular       1.35 rsq     standard     0.728    10  0.0322
##  5        17 optimal           1.67 rsq     standard     0.740    10  0.0281
##  6         6 triangular        1.18 rsq     standard     0.747    10  0.0273
##  7        16 optimal           1.48 rsq     standard     0.750    10  0.0294
##  8        13 optimal           1.26 rsq     standard     0.758    10  0.0280
##  9        19 triangular        1.50 rsq     standard     0.759    10  0.0260
## 10        10 optimal           1.05 rsq     standard     0.773    10  0.0274
## 11        10 optimal           1.05 rmse    standard   225.       10  9.64  
## 12         6 triangular        1.18 rmse    standard   231.       10 10.1   
## 13        13 optimal           1.26 rmse    standard   235.       10  9.15  
## 14        16 optimal           1.48 rmse    standard   241.       10  8.90  
## 15        19 triangular        1.50 rmse    standard   242.       10  8.10  
## 16        17 optimal           1.67 rmse    standard   247.       10  8.97  
## 17         8 rectangular       1.35 rmse    standard   247.       10 10.0   
## 18        11 rectangular       1.87 rmse    standard   257.       10  8.96  
## 19         3 triangular        1.71 rmse    standard   272.       10 13.8   
## 20         3 rectangular       1.92 rmse    standard   273.       10 11.2   
## # ℹ 1 more variable: .config <chr>
# Plot tuning results
autoplot(knn_tune)

# Pick the candidate with the best resampled RMSE
best_knn <- select_best(knn_tune, metric = "rmse")

# Print the selected hyperparameter values
best_knn
## # A tibble: 1 × 4
##   neighbors weight_func dist_power .config              
##       <int> <chr>            <dbl> <chr>                
## 1        10 optimal           1.05 Preprocessor1_Model01
### Step 8: Finalize the KNN workflow
# Fill the tune() placeholders with the selected hyperparameter values
final_knn_workflow <- finalize_workflow(knn_workflow, best_knn)

### Step 9: Fit the final KNN model on the training data
final_knn_fit <- fit(final_knn_workflow, data = train_data)

### Step 10: Evaluate the KNN model on the test data
# Predict on the held-out rows, then attach the original test columns
knn_predictions <- bind_cols(
  predict(final_knn_fit, new_data = test_data),
  test_data
)

# The .pred column holds the predicted Balance for each test row
knn_predictions
## # A tibble: 121 × 12
##     .pred Income Limit Rating Cards   Age Education Gender   Student Married
##     <dbl>  <dbl> <int>  <int> <int> <int>     <int> <fct>    <fct>   <fct>  
##  1 1193.   106.   6645    483     3    82        15 "Female" Yes     Yes    
##  2  362.    55.9  4897    357     2    68        16 " Male"  No      Yes    
##  3   96.2   21.0  3388    259     2    37        12 "Female" No      No     
##  4    0     15.0  1311    138     3    64        16 " Male"  No      No     
##  5   60.5   20.1  2525    200     3    57        15 "Female" No      Yes    
##  6  285.    53.6  3714    286     3    73        17 "Female" No      Yes    
##  7  631.    42.1  6626    479     2    44         9 " Male"  No      No     
##  8  896.    37.3  6378    458     1    72        17 "Female" No      No     
##  9  355.    14.1  4323    326     5    25        16 "Female" No      Yes    
## 10  525.    42.5  3625    289     6    44        12 "Female" Yes     No     
## # ℹ 111 more rows
## # ℹ 2 more variables: Ethnicity <fct>, Balance <int>
# Test-set performance of the tuned KNN model
# (the printed result lists rmse, rsq and mae)
knn_metrics <- metrics(knn_predictions, truth = Balance, estimate = .pred)

knn_metrics
## # A tibble: 3 × 3
##   .metric .estimator .estimate
##   <chr>   <chr>          <dbl>
## 1 rmse    standard     234.   
## 2 rsq     standard       0.793
## 3 mae     standard     172.
# Predicted vs actual Balance on the test set, with a fitted straight line
knn_predictions |>
  ggplot(aes(x = Balance, y = .pred)) +
  geom_point() +
  geom_smooth(method = "lm") +
  labs(
    x = "Actual Balance",
    y = "Predicted Balance",
    title = "KNN Predictions vs Actual Values"
  )
## `geom_smooth()` using formula = 'y ~ x'

### Run linear regression model using tidymodels approach

# Linear regression specification
lin_reg_spec <- linear_reg(mode = "regression")

# Linear regression workflow, reusing the same recipe as the KNN workflow
lin_reg_workflow <- workflow(preprocessor = recipe, spec = lin_reg_spec)

# Fit the linear regression workflow on the training data
final_lm <- fit(lin_reg_workflow, data = train_data)

# Predict on the test data and attach the original test columns
lm_predictions <- bind_cols(
  predict(final_lm, new_data = test_data),
  test_data
)

# Test-set metrics for the linear regression model
lm_predictions |> metrics(truth = Balance, estimate = .pred)
## # A tibble: 3 × 3
##   .metric .estimator .estimate
##   <chr>   <chr>          <dbl>
## 1 rmse    standard      92.3  
## 2 rsq     standard       0.964
## 3 mae     standard      75.8
# Test-set metrics for the KNN regression model
knn_predictions |> metrics(truth = Balance, estimate = .pred)
## # A tibble: 3 × 3
##   .metric .estimator .estimate
##   <chr>   <chr>          <dbl>
## 1 rmse    standard     234.   
## 2 rsq     standard       0.793
## 3 mae     standard     172.
# Put both models' test-set predictions side by side with the test data
test_results <- mutate(
  test_data,
  .pred_knn = knn_predictions[[".pred"]],
  .pred_lm = lm_predictions[[".pred"]]
)

test_results |> head()
##    Income Limit Rating Cards Age Education Gender Student Married
## 1 106.025  6645    483     3  82        15 Female     Yes     Yes
## 2  55.882  4897    357     2  68        16   Male      No     Yes
## 3  20.996  3388    259     2  37        12 Female      No      No
## 4  15.045  1311    138     3  64        16   Male      No      No
## 5  20.089  2525    200     3  57        15 Female      No     Yes
## 6  53.598  3714    286     3  73        17 Female      No     Yes
##          Ethnicity Balance  .pred_knn   .pred_lm
## 1            Asian     903 1193.08906  914.26371
## 2        Caucasian     331  361.90311  409.04140
## 3 African American     203   96.21087  291.12924
## 4        Caucasian       0    0.00000 -174.83619
## 5 African American       0   60.53106   63.96688
## 6 African American       0  285.21505  113.45598
# Combine the test-set metrics of both models into one table

# Each model's metrics, labelled with the model name
lm_metrics <- lm_predictions |>
  metrics(truth = Balance, estimate = .pred) |>
  mutate(model = "Linear Regression")

knn_metrics <- knn_predictions |>
  metrics(truth = Balance, estimate = .pred) |>
  mutate(model = "KNN")

# Stack the two tables and show them grouped by metric
combined_metrics <- bind_rows(lm_metrics, knn_metrics)
arrange(combined_metrics, .metric)
## # A tibble: 6 × 4
##   .metric .estimator .estimate model            
##   <chr>   <chr>          <dbl> <chr>            
## 1 mae     standard      75.8   Linear Regression
## 2 mae     standard     172.    KNN              
## 3 rmse    standard      92.3   Linear Regression
## 4 rmse    standard     234.    KNN              
## 5 rsq     standard       0.964 Linear Regression
## 6 rsq     standard       0.793 KNN
# First rows of the combined test results, ahead of the visual comparison
test_results |> head()
##    Income Limit Rating Cards Age Education Gender Student Married
## 1 106.025  6645    483     3  82        15 Female     Yes     Yes
## 2  55.882  4897    357     2  68        16   Male      No     Yes
## 3  20.996  3388    259     2  37        12 Female      No      No
## 4  15.045  1311    138     3  64        16   Male      No      No
## 5  20.089  2525    200     3  57        15 Female      No     Yes
## 6  53.598  3714    286     3  73        17 Female      No     Yes
##          Ethnicity Balance  .pred_knn   .pred_lm
## 1            Asian     903 1193.08906  914.26371
## 2        Caucasian     331  361.90311  409.04140
## 3 African American     203   96.21087  291.12924
## 4        Caucasian       0    0.00000 -174.83619
## 5 African American       0   60.53106   63.96688
## 6 African American       0  285.21505  113.45598
# Visual comparison: both models' predictions against the actual Balance.
# Points take their colour from the manual scale below; each straight-line
# smoother is drawn in the matching fixed colour. theme_minimal() is added to
# this plot only.
test_results |>
  ggplot(aes(x = Balance)) +
  geom_point(aes(y = .pred_knn, color = "KNN")) +
  geom_point(aes(y = .pred_lm, color = "Linear Regression")) +
  geom_smooth(aes(y = .pred_knn), color = "blue", method = "lm") +
  geom_smooth(aes(y = .pred_lm), color = "red", method = "lm") +
  scale_color_manual(
    values = c("KNN" = "blue", "Linear Regression" = "red")
  ) +
  labs(
    x = "Actual Balance",
    y = "Predicted Balance",
    title = "KNN vs Linear Regression Predictions"
  ) +
  theme_minimal()
## `geom_smooth()` using formula = 'y ~ x'
## `geom_smooth()` using formula = 'y ~ x'