R Markdown

Goal: To predict transit costs Click here for the data

transit_cost <- readr::read_csv('https://raw.githubusercontent.com/rfordatascience/tidytuesday/master/data/2021/2021-01-05/transit_cost.csv')

skimr::skim(transit_cost)
Data summary
Name transit_cost
Number of rows 544
Number of columns 20
_______________________
Column type frequency:
character 11
numeric 9
________________________
Group variables None

Variable type: character

skim_variable n_missing complete_rate min max empty n_unique whitespace
country 7 0.99 2 2 0 56 0
city 7 0.99 4 16 0 140 0
line 7 0.99 2 46 0 366 0
start_year 53 0.90 4 9 0 40 0
end_year 71 0.87 1 4 0 36 0
tunnel_per 32 0.94 5 7 0 134 0
source1 12 0.98 4 54 0 17 0
currency 7 0.99 2 3 0 39 0
real_cost 0 1.00 1 10 0 534 0
source2 10 0.98 3 16 0 12 0
reference 19 0.97 3 302 0 350 0

Variable type: numeric

skim_variable n_missing complete_rate mean sd p0 p25 p50 p75 p100 hist
e 7 0.99 7738.76 463.23 7136.00 7403.00 7705.00 7977.00 9510.00 ▇▇▂▁▁
rr 8 0.99 0.06 0.24 0.00 0.00 0.00 0.00 1.00 ▇▁▁▁▁
length 5 0.99 58.34 621.20 0.60 6.50 15.77 29.08 12256.98 ▇▁▁▁▁
tunnel 32 0.94 29.38 344.04 0.00 3.40 8.91 21.52 7790.78 ▇▁▁▁▁
stations 15 0.97 13.81 13.70 0.00 4.00 10.00 20.00 128.00 ▇▁▁▁▁
cost 7 0.99 805438.12 6708033.07 0.00 2289.00 11000.00 27000.00 90000000.00 ▇▁▁▁▁
year 7 0.99 2014.91 5.64 1987.00 2012.00 2016.00 2019.00 2027.00 ▁▁▂▇▂
ppp_rate 9 0.98 0.66 0.87 0.00 0.24 0.26 1.00 5.00 ▇▂▁▁▁
cost_km_millions 2 1.00 232.98 257.22 7.79 134.86 181.25 241.43 3928.57 ▇▁▁▁▁
data1 <- transit_cost %>%
  
  # Treat missing values
  select(-line, -source1, -currency, -source2, -reference, -country) %>%

  
  # log transform variables with pos-skewed distribution
  mutate(cost = log(cost + 1)) %>%
  mutate(real_cost = as.numeric(real_cost)) %>%
  mutate(tunnel_per = tunnel_per %>% str_remove("%") %>% as.numeric()) %>%
  mutate(start_year = as.numeric(start_year)) %>%
  mutate(end_year = as.numeric(end_year)) %>%
  na.omit()
## Warning: There was 1 warning in `mutate()`.
## ℹ In argument: `real_cost = as.numeric(real_cost)`.
## Caused by warning:
## ! NAs introduced by coercion
## Warning: There was 1 warning in `mutate()`.
## ℹ In argument: `start_year = as.numeric(start_year)`.
## Caused by warning:
## ! NAs introduced by coercion
## Warning: There was 1 warning in `mutate()`.
## ℹ In argument: `end_year = as.numeric(end_year)`.
## Caused by warning:
## ! NAs introduced by coercion

Identify good predictors

Year

data1 %>% 
  ggplot(aes(cost, year)) +
  scale_y_log10() +
  geom_point()

Explore Data

Correlation between city and Cost

 transit_cost <- data1 %>%
  count(e, sort = TRUE) %>%
  head(20) %>%
  pull(e)

transit_cost
##  [1] 9459 9460 7136 7137 7138 7139 7144 7145 7146 7147 7152 7153 7154 7155 7160
## [16] 7161 7162 7163 7168 7169
data1 %>%
  
  #Filter for top 20 transit costs by city
  # filter(city %in% transit_cost) %>%
  ggplot(aes(cost, fct_reorder(city, cost))) +
  geom_boxplot() 

EDA Shortcut

EDA Shortcut

# Step 1: Prepare the data

data_binarized_tbl1 <- data1 %>%
  binarize() 

data_binarized_tbl1 %>% glimpse(.)
## Rows: 436
## Columns: 76
## $ `e__-Inf_7354.75`                       <dbl> 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, …
## $ e__7354.75_7589.5                       <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ e__7589.5_7923.5                        <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ e__7923.5_Inf                           <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ city__Bangkok                           <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ city__Barcelona                         <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ city__Beijing                           <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ city__Changchun                         <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ city__Changsha                          <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ city__Chengdu                           <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ city__Chongqing                         <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ city__Dongguan                          <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ city__Guangzhou                         <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ city__Guiyang                           <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ city__Hangzhou                          <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ city__Istanbul                          <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ city__Madrid                            <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ city__Mumbai                            <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ city__Nanjing                           <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ city__New_York                          <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ city__Paris                             <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ city__Riyadh                            <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ city__Seoul                             <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ city__Shanghai                          <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ city__Shenzhen                          <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ city__Sofia                             <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ city__Taipei                            <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ city__Tianjin                           <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ city__Tokyo                             <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ city__Toronto                           <dbl> 0, 1, 1, 1, 1, 0, 0, 0, 0, 0, …
## $ city__Wuhan                             <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ `city__-OTHER`                          <dbl> 1, 0, 0, 0, 0, 1, 1, 1, 1, 1, …
## $ `start_year__-Inf_2009`                 <dbl> 0, 1, 0, 0, 0, 1, 0, 1, 0, 0, …
## $ start_year__2009_2015                   <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ start_year__2015_2018                   <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, …
## $ start_year__2018_Inf                    <dbl> 1, 0, 1, 1, 1, 0, 1, 0, 1, 0, …
## $ `end_year__-Inf_2016`                   <dbl> 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, …
## $ end_year__2016_2020                     <dbl> 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, …
## $ end_year__2020_2023                     <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ end_year__2023_Inf                      <dbl> 1, 0, 1, 1, 1, 0, 1, 0, 1, 1, …
## $ rr__0                                   <dbl> 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, …
## $ rr__1                                   <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ `length__-Inf_6.1`                      <dbl> 1, 0, 0, 0, 0, 0, 1, 1, 1, 1, …
## $ length__6.1_14.75                       <dbl> 0, 1, 1, 0, 1, 1, 0, 0, 0, 0, …
## $ length__14.75_28.2                      <dbl> 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, …
## $ length__28.2_Inf                        <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ `tunnel_per__-Inf_49.09`                <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ tunnel_per__49.09_Inf                   <dbl> 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, …
## $ `tunnel__-Inf_3.275`                    <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ tunnel__3.275_8.4                       <dbl> 1, 0, 1, 0, 1, 1, 1, 1, 1, 1, …
## $ tunnel__8.4_20                          <dbl> 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, …
## $ tunnel__20_Inf                          <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ `stations__-Inf_4`                      <dbl> 0, 0, 1, 0, 0, 0, 0, 1, 1, 1, …
## $ stations__4_10                          <dbl> 1, 1, 0, 0, 1, 1, 1, 0, 0, 0, …
## $ stations__10_20                         <dbl> 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, …
## $ stations__20_Inf                        <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ `cost__-Inf_7.4713630881871`            <dbl> 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, …
## $ cost__7.4713630881871_9.15883647922657  <dbl> 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, …
## $ cost__9.15883647922657_10.2633708195977 <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ cost__10.2633708195977_Inf              <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ `year__-Inf_2012`                       <dbl> 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, …
## $ year__2012_2016                         <dbl> 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ year__2016_2018                         <dbl> 1, 0, 1, 0, 0, 0, 1, 0, 0, 0, …
## $ year__2018_Inf                          <dbl> 0, 0, 0, 1, 1, 0, 0, 0, 1, 1, …
## $ `ppp_rate__-Inf_0.2379`                 <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ ppp_rate__0.2379_0.266                  <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ ppp_rate__0.266_1.25                    <dbl> 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, …
## $ ppp_rate__1.25_Inf                      <dbl> 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, …
## $ `real_cost__-Inf_1168.71`               <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ real_cost__1168.71_2977.92              <dbl> 1, 1, 0, 0, 0, 0, 0, 1, 0, 1, …
## $ real_cost__2977.92_5540.5525            <dbl> 0, 0, 1, 0, 1, 1, 1, 0, 1, 0, …
## $ real_cost__5540.5525_Inf                <dbl> 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, …
## $ `cost_km_millions__-Inf_132.9840103`    <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ cost_km_millions__132.9840103_184.085   <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ cost_km_millions__184.085_243.2825      <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ cost_km_millions__243.2825_Inf          <dbl> 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, …
# Step 2 Correlate
 data_corr_tbl <- data_binarized_tbl1 %>%
  correlate(cost__10.2633708195977_Inf)

data_corr_tbl
## # A tibble: 76 × 3
##    feature    bin                               correlation
##    <fct>      <chr>                                   <dbl>
##  1 cost       10.2633708195977_Inf                    1    
##  2 ppp_rate   -Inf_0.2379                             0.684
##  3 real_cost  5540.5525_Inf                           0.376
##  4 cost       -Inf_7.4713630881871                   -0.335
##  5 cost       9.15883647922657_10.2633708195977      -0.333
##  6 cost       7.4713630881871_9.15883647922657       -0.331
##  7 ppp_rate   1.25_Inf                               -0.315
##  8 tunnel_per -Inf_49.09                              0.291
##  9 tunnel_per 49.09_Inf                              -0.291
## 10 ppp_rate   0.266_1.25                             -0.287
## # ℹ 66 more rows
# Step 3 Plot 
data_corr_tbl %>%
  plot_correlation_funnel()
## Warning: ggrepel: 26 unlabeled data points (too many overlaps). Consider
## increasing max.overlaps

Prprocess Data

Build Models

Split Data

#data1 <- sample_n(data, 100)

# Split into train and test data set
set.seed(1234)
data_split <- rsample:: initial_split(data1)
data_train <- training(data_split)
data_test <- testing(data_split)

# Further split training data set for cross-validation
set.seed(2345)
data_cv <- rsample::vfold_cv(data_train)
data_cv
## #  10-fold cross-validation 
## # A tibble: 10 × 2
##    splits           id    
##    <list>           <chr> 
##  1 <split [294/33]> Fold01
##  2 <split [294/33]> Fold02
##  3 <split [294/33]> Fold03
##  4 <split [294/33]> Fold04
##  5 <split [294/33]> Fold05
##  6 <split [294/33]> Fold06
##  7 <split [294/33]> Fold07
##  8 <split [295/32]> Fold08
##  9 <split [295/32]> Fold09
## 10 <split [295/32]> Fold10
library(usemodels)
## Warning: package 'usemodels' was built under R version 4.3.2
usemodels:: use_xgboost(cost ~ ., data1)
## xgboost_recipe <- 
##   recipe(formula = cost ~ ., data = data1) %>% 
##   step_zv(all_predictors()) 
## 
## xgboost_spec <- 
##   boost_tree(trees = tune(), min_n = tune(), tree_depth = tune(), learn_rate = tune(), 
##     loss_reduction = tune(), sample_size = tune()) %>% 
##   set_mode("classification") %>% 
##   set_engine("xgboost") 
## 
## xgboost_workflow <- 
##   workflow() %>% 
##   add_recipe(xgboost_recipe) %>% 
##   add_model(xgboost_spec) 
## 
## set.seed(24468)
## xgboost_tune <-
##   tune_grid(xgboost_workflow, resamples = stop("add your rsample object"), grid = stop("add number of candidate points"))
# Specify Recipe

xgboost_recipe <- 
  recipe(formula = cost ~ ., data = data_train) %>%
  step_other(city, threshold = 0.01) %>%
  step_dummy(city) %>%
  step_YeoJohnson(year, cost, stations)

xgboost_recipe %>% prep() %>%  juice() %>% glimpse()
## Rows: 327
## Columns: 39
## $ e                <dbl> 7808, 7945, 8177, 7338, 7360, 8139, 7408, 9462, 8163,…
## $ start_year       <dbl> 2019, 2016, 2014, 2013, 2010, 2005, 2007, 2016, 2022,…
## $ end_year         <dbl> 2024, 2020, 2017, 2019, 2013, 2008, 2014, 2020, 2027,…
## $ rr               <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,…
## $ length           <dbl> 7.830, 34.800, 40.700, 37.000, 4.200, 28.100, 1.600, …
## $ tunnel_per       <dbl> 100.00, 100.00, 14.00, 100.00, 0.00, 0.00, 100.00, 10…
## $ tunnel           <dbl> 7.830, 34.800, 5.700, 37.000, 0.000, 0.000, 1.600, 38…
## $ stations         <dbl> 2.0794415, 2.8903718, 3.0445224, 3.3672958, 1.3862944…
## $ year             <dbl> 2019, 2015, 2013, 2016, 2012, 2005, 2011, 2013, 2019,…
## $ ppp_rate         <dbl> 0.2382, 0.2583, 1.8200, 1.7000, 1.3000, 0.3517, 1.000…
## $ real_cost        <dbl> 1775.13, 6174.12, 9103.64, 5100.00, 416.00, 218.89, 2…
## $ cost_km_millions <dbl> 226.71000, 177.42000, 223.67666, 137.83784, 99.04762,…
## $ cost             <dbl> 3.632871, 3.900502, 3.536830, 3.409920, 2.794208, 2.9…
## $ city_Barcelona   <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,…
## $ city_Beijing     <dbl> 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1,…
## $ city_Berlin      <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,…
## $ city_Changchun   <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,…
## $ city_Changsha    <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0,…
## $ city_Chengdu     <dbl> 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,…
## $ city_Chongqing   <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0,…
## $ city_Dongguan    <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0,…
## $ city_Guangzhou   <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,…
## $ city_Guiyang     <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,…
## $ city_Hangzhou    <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0,…
## $ city_Istanbul    <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,…
## $ city_Mumbai      <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0,…
## $ city_Nanjing     <dbl> 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,…
## $ city_New.York    <dbl> 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,…
## $ city_Paris       <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0,…
## $ city_Riyadh      <dbl> 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,…
## $ city_Rome        <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,…
## $ city_Shanghai    <dbl> 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,…
## $ city_Shenzhen    <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,…
## $ city_Stockholm   <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,…
## $ city_Taipei      <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,…
## $ city_Tianjin     <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,…
## $ city_Tokyo       <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,…
## $ city_Wuhan       <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,…
## $ city_other       <dbl> 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 0, 0,…
# Specify Model
  
xgboost_spec <- 
  boost_tree(trees = tune(), min_n = tune(), tree_depth = tune(), learn_rate = tune(), 
    loss_reduction = tune(), sample_size = tune()) %>% 
  set_mode("regression") %>% 
  set_engine("xgboost") 

# Combine recipe and model using workflow

xgboost_workflow <- 
  workflow() %>% 
  add_recipe(xgboost_recipe) %>% 
  add_model(xgboost_spec) 

# Tune hyperparamaters

set.seed(5218)
xgboost_tune <-
  tune_grid(xgboost_workflow, 
            resamples = data_cv, 
            grid = 10)
## Warning: package 'xgboost' was built under R version 4.3.2

Evaluate Models

tune::show_best(xgboost_tune, metric = "rsq")
## # A tibble: 5 × 12
##   trees min_n tree_depth learn_rate loss_reduction sample_size .metric
##   <int> <int>      <int>      <dbl>          <dbl>       <dbl> <chr>  
## 1    49     4         15    0.0723        1.71e- 8       0.966 rsq    
## 2  1978    12          8    0.00689       1.21e- 1       0.569 rsq    
## 3  1245    20         10    0.0406        1.86e- 1       0.828 rsq    
## 4   641     6          2    0.120         5.10e-10       0.110 rsq    
## 5   420    33         13    0.0206        5.48e- 3       0.783 rsq    
## # ℹ 5 more variables: .estimator <chr>, mean <dbl>, n <int>, std_err <dbl>,
## #   .config <chr>
# Update the model by selecting the best hyperparameters
xgboost_fw <- tune::finalize_workflow(xgboost_workflow,
                        tune::select_best(xgboost_tune, metric = "rsq"))

# Fit the model on the entire training data and test it on the test data.
 data_fit <- tune::last_fit(xgboost_fw, data_split)
tune::collect_metrics(data_fit)
## # A tibble: 2 × 4
##   .metric .estimator .estimate .config             
##   <chr>   <chr>          <dbl> <chr>               
## 1 rmse    standard       0.116 Preprocessor1_Model1
## 2 rsq     standard       0.986 Preprocessor1_Model1
tune::collect_predictions(data_fit) %>%
  ggplot(aes(cost, .pred)) +
  geom_point(alpha = 0.3, fill = "midnightblue") +
  geom_abline(lty = 2, color = "gray50") + 
  coord_fixed()