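Transit Costs: Why do transit-infrastructure projects in New York cost 20 times more per kilometer than those in Seoul? Build a regression model to predict project cost (the prompt names real_cost) using the transit_cost dataset. Because the question is about per-kilometer cost, the model below predicts the log of cost per kilometer (cost_km_millions) rather than real_cost.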

Import Data

# Read the TidyTuesday 2021-01-05 transit cost data directly from GitHub.
transit_cost <- readr::read_csv(
  file = "https://raw.githubusercontent.com/rfordatascience/tidytuesday/master/data/2021/2021-01-05/transit_cost.csv"
)
## Rows: 544 Columns: 20
## ── Column specification ────────────────────────────────────────────────────────
## Delimiter: ","
## chr (11): country, city, line, start_year, end_year, tunnel_per, source1, cu...
## dbl  (9): e, rr, length, tunnel, stations, cost, year, ppp_rate, cost_km_mil...
## 
## ℹ Use `spec()` to retrieve the full column specification for this data.
## ℹ Specify the column types or set `show_col_types = FALSE` to quiet this message.
# Per-column completeness and distribution summary.
transit_cost %>%
  skimr::skim()
Data summary
Name transit_cost
Number of rows 544
Number of columns 20
_______________________
Column type frequency:
character 11
numeric 9
________________________
Group variables None

Variable type: character

skim_variable n_missing complete_rate min max empty n_unique whitespace
country 7 0.99 2 2 0 56 0
city 7 0.99 4 16 0 140 0
line 7 0.99 2 46 0 366 0
start_year 53 0.90 4 9 0 40 0
end_year 71 0.87 1 4 0 36 0
tunnel_per 32 0.94 5 7 0 134 0
source1 12 0.98 4 54 0 17 0
currency 7 0.99 2 3 0 39 0
real_cost 0 1.00 1 10 0 534 0
source2 10 0.98 3 16 0 12 0
reference 19 0.97 3 302 0 350 0

Variable type: numeric

skim_variable n_missing complete_rate mean sd p0 p25 p50 p75 p100 hist
e 7 0.99 7738.76 463.23 7136.00 7403.00 7705.00 7977.00 9510.00 ▇▇▂▁▁
rr 8 0.99 0.06 0.24 0.00 0.00 0.00 0.00 1.00 ▇▁▁▁▁
length 5 0.99 58.34 621.20 0.60 6.50 15.77 29.08 12256.98 ▇▁▁▁▁
tunnel 32 0.94 29.38 344.04 0.00 3.40 8.91 21.52 7790.78 ▇▁▁▁▁
stations 15 0.97 13.81 13.70 0.00 4.00 10.00 20.00 128.00 ▇▁▁▁▁
cost 7 0.99 805438.12 6708033.07 0.00 2289.00 11000.00 27000.00 90000000.00 ▇▁▁▁▁
year 7 0.99 2014.91 5.64 1987.00 2012.00 2016.00 2019.00 2027.00 ▁▁▂▇▂
ppp_rate 9 0.98 0.66 0.87 0.00 0.24 0.26 1.00 5.00 ▇▂▁▁▁
cost_km_millions 2 1.00 232.98 257.22 7.79 134.86 181.25 241.43 3928.57 ▇▁▁▁▁
# Modelling table: a few candidate predictors plus the target, complete rows
# only, with the target moved to a log scale.
data <- transit_cost %>%
  # `e` travels along as a row identifier (it is given a non-predictor role
  # in the recipe); `cost_km_millions` is the target.
  select(e, cost_km_millions, country, city, year, rr, stations) %>%
  # `rr` is a 0/1 flag (1 = railroad), so treat it as categorical.
  mutate(rr = as.factor(rr)) %>%
  # Keep only rows with no missing values in the selected columns.
  na.omit() %>%
  # Cost per km is heavily right-skewed, so model its natural log.
  mutate(cost_km_millions = log(cost_km_millions))

Explore Data

Relationship between city and cost: distribution of log cost per km for the 20 cities with the most projects

# Names of the 20 cities with the most projects in `data`, most frequent first.
top20_cities_vec <- data %>%
  count(city, sort = TRUE) %>%
  slice_head(n = 20) %>%
  pull(city)

print(top20_cities_vec)
##  [1] "Shanghai"  "Beijing"   "Wuhan"     "Istanbul"  "Shenzhen"  "Changsha" 
##  [7] "Mumbai"    "Nanjing"   "Chengdu"   "Chongqing" "Hangzhou"  "Paris"    
## [13] "Guangzhou" "Hefei"     "Tokyo"     "Kunming"   "Taipei"    "Tianjin"  
## [19] "Bangkok"   "Changchun"
# Boxplots of (log) cost per km for the 20 most frequent cities, with cities
# reordered by their cost values via fct_reorder(). The axes get readable
# labels; the defaults were the raw expressions, and the x-axis is on a
# log scale, which was not obvious.
data %>%
  filter(city %in% top20_cities_vec) %>%
  ggplot(aes(
    x = cost_km_millions,
    y = fct_reorder(city, cost_km_millions)
  )) +
  geom_boxplot() +
  labs(
    x = "log(cost per km, millions)",
    y = NULL
  )

EDA shortcut

# Step 1: Prepare data
# Binarize every column (numeric columns are split into bins, categoricals are
# one-hot encoded, as the glimpse output shows) so each feature can be
# correlated with a single binary target bin. `e` is an identifier, not a
# feature, so it is left out.
data_binarized_tbl <- binarize(select(data, -e))

glimpse(data_binarized_tbl)
## Rows: 526
## Columns: 59
## $ `cost_km_millions__-Inf_4.90431298715264`           <dbl> 0, 0, 0, 0, 0, 0, …
## $ cost_km_millions__4.90431298715264_5.19655135163657 <dbl> 0, 0, 0, 0, 0, 0, …
## $ cost_km_millions__5.19655135163657_5.47823632836085 <dbl> 0, 0, 0, 0, 0, 0, …
## $ cost_km_millions__5.47823632836085_Inf              <dbl> 1, 1, 1, 1, 1, 1, …
## $ country__BG                                         <dbl> 0, 0, 0, 0, 0, 0, …
## $ country__CA                                         <dbl> 1, 1, 1, 1, 1, 0, …
## $ country__CN                                         <dbl> 0, 0, 0, 0, 0, 0, …
## $ country__DE                                         <dbl> 0, 0, 0, 0, 0, 0, …
## $ country__ES                                         <dbl> 0, 0, 0, 0, 0, 0, …
## $ country__FR                                         <dbl> 0, 0, 0, 0, 0, 0, …
## $ country__IN                                         <dbl> 0, 0, 0, 0, 0, 0, …
## $ country__IT                                         <dbl> 0, 0, 0, 0, 0, 0, …
## $ country__JP                                         <dbl> 0, 0, 0, 0, 0, 0, …
## $ country__KR                                         <dbl> 0, 0, 0, 0, 0, 0, …
## $ country__SA                                         <dbl> 0, 0, 0, 0, 0, 0, …
## $ country__TH                                         <dbl> 0, 0, 0, 0, 0, 0, …
## $ country__TR                                         <dbl> 0, 0, 0, 0, 0, 0, …
## $ country__TW                                         <dbl> 0, 0, 0, 0, 0, 0, …
## $ country__US                                         <dbl> 0, 0, 0, 0, 0, 0, …
## $ `country__-OTHER`                                   <dbl> 0, 0, 0, 0, 0, 1, …
## $ city__Bangkok                                       <dbl> 0, 0, 0, 0, 0, 0, …
## $ city__Barcelona                                     <dbl> 0, 0, 0, 0, 0, 0, …
## $ city__Beijing                                       <dbl> 0, 0, 0, 0, 0, 0, …
## $ city__Changchun                                     <dbl> 0, 0, 0, 0, 0, 0, …
## $ city__Changsha                                      <dbl> 0, 0, 0, 0, 0, 0, …
## $ city__Chengdu                                       <dbl> 0, 0, 0, 0, 0, 0, …
## $ city__Chongqing                                     <dbl> 0, 0, 0, 0, 0, 0, …
## $ city__Dongguan                                      <dbl> 0, 0, 0, 0, 0, 0, …
## $ city__Guangzhou                                     <dbl> 0, 0, 0, 0, 0, 0, …
## $ city__Hangzhou                                      <dbl> 0, 0, 0, 0, 0, 0, …
## $ city__Hefei                                         <dbl> 0, 0, 0, 0, 0, 0, …
## $ city__Istanbul                                      <dbl> 0, 0, 0, 0, 0, 0, …
## $ city__Kunming                                       <dbl> 0, 0, 0, 0, 0, 0, …
## $ city__Madrid                                        <dbl> 0, 0, 0, 0, 0, 0, …
## $ city__Mumbai                                        <dbl> 0, 0, 0, 0, 0, 0, …
## $ city__Nanjing                                       <dbl> 0, 0, 0, 0, 0, 0, …
## $ city__Paris                                         <dbl> 0, 0, 0, 0, 0, 0, …
## $ city__Riyadh                                        <dbl> 0, 0, 0, 0, 0, 0, …
## $ city__Shanghai                                      <dbl> 0, 0, 0, 0, 0, 0, …
## $ city__Shenzhen                                      <dbl> 0, 0, 0, 0, 0, 0, …
## $ city__Sofia                                         <dbl> 0, 0, 0, 0, 0, 0, …
## $ city__Taipei                                        <dbl> 0, 0, 0, 0, 0, 0, …
## $ city__Tianjin                                       <dbl> 0, 0, 0, 0, 0, 0, …
## $ city__Tokyo                                         <dbl> 0, 0, 0, 0, 0, 0, …
## $ city__Toronto                                       <dbl> 0, 1, 1, 1, 1, 0, …
## $ city__Wuhan                                         <dbl> 0, 0, 0, 0, 0, 0, …
## $ `city__Xi'an`                                       <dbl> 0, 0, 0, 0, 0, 0, …
## $ city__Zhengzhou                                     <dbl> 0, 0, 0, 0, 0, 0, …
## $ `city__-OTHER`                                      <dbl> 1, 0, 0, 0, 0, 1, …
## $ `year__-Inf_2013`                                   <dbl> 0, 1, 0, 0, 0, 1, …
## $ year__2013_2016                                     <dbl> 0, 0, 0, 0, 0, 0, …
## $ year__2016_2019                                     <dbl> 1, 0, 1, 1, 0, 0, …
## $ year__2019_Inf                                      <dbl> 0, 0, 0, 0, 1, 0, …
## $ rr__0                                               <dbl> 1, 1, 1, 1, 1, 1, …
## $ rr__1                                               <dbl> 0, 0, 0, 0, 0, 0, …
## $ `stations__-Inf_4`                                  <dbl> 0, 0, 1, 0, 0, 0, …
## $ stations__4_10                                      <dbl> 1, 1, 0, 0, 1, 1, …
## $ stations__10_20                                     <dbl> 0, 0, 0, 1, 0, 0, …
## $ stations__20_Inf                                    <dbl> 0, 0, 0, 0, 0, 0, …
# Step 2: Correlate
# Correlate every binarized feature with the highest cost-per-km bin.
# binarize() embeds quantile cut points in the bin's column name
# (e.g. `cost_km_millions__5.478..._Inf`), and those cut points change whenever
# the underlying rows change. So look the name up instead of hard-coding it.
top_cost_bin <- grep(
  "^cost_km_millions__.*_Inf$",
  names(data_binarized_tbl),
  value = TRUE
)
# Exactly one bin is open-ended at the top; fail loudly if that ever changes.
stopifnot(length(top_cost_bin) == 1L)

data_corr_tbl <- data_binarized_tbl %>%
  correlate(target = !!rlang::sym(top_cost_bin))

print(data_corr_tbl)
## # A tibble: 59 × 3
##    feature          bin                               correlation
##    <fct>            <chr>                                   <dbl>
##  1 cost_km_millions 5.47823632836085_Inf                    1    
##  2 cost_km_millions -Inf_4.90431298715264                  -0.335
##  3 cost_km_millions 4.90431298715264_5.19655135163657      -0.333
##  4 cost_km_millions 5.19655135163657_5.47823632836085      -0.333
##  5 country          US                                      0.275
##  6 country          CN                                     -0.233
##  7 country          -OTHER                                  0.219
##  8 year             2019_Inf                                0.174
##  9 city             -OTHER                                  0.167
## 10 rr               0                                      -0.128
## # ℹ 49 more rows
# Step 3: Plot
# Visualize the feature/target-bin correlations computed in Step 2.
plot_correlation_funnel(data_corr_tbl)

Build a Model

# Down-sample the modelling table to 200 rows.
# slice_sample() replaces the superseded sample_n(). Both draw rows with
# sample.int(), so the same seed should select the same rows.
# NOTE(review): this discards most complete rows (526 -> 200 per the glimpse
# above), and the final test R^2 is near zero. The small sample may contribute
# to that; consider modelling on the full `data` instead.
set.seed(12345)
data <- slice_sample(data, n = 200)

# Train/test split stratified on the target (148/52 for 200 rows, per the
# output below).
set.seed(123)
transit_split <- data %>%
  initial_split(strata = cost_km_millions)
transit_train <- transit_split %>% training()
transit_test <- transit_split %>% testing()

# Bootstrap resamples of the training set, stratified on the target, for
# tuning (25 resamples, per the printed output).
set.seed(234)
transit_folds <- transit_train %>%
  bootstraps(strata = cost_km_millions)
print(transit_folds)
## # Bootstrap sampling using stratification 
## # A tibble: 25 × 2
##    splits           id         
##    <list>           <chr>      
##  1 <split [148/52]> Bootstrap01
##  2 <split [148/57]> Bootstrap02
##  3 <split [148/49]> Bootstrap03
##  4 <split [148/54]> Bootstrap04
##  5 <split [148/58]> Bootstrap05
##  6 <split [148/54]> Bootstrap06
##  7 <split [148/49]> Bootstrap07
##  8 <split [148/61]> Bootstrap08
##  9 <split [148/54]> Bootstrap09
## 10 <split [148/56]> Bootstrap10
## # ℹ 15 more rows
# Preprocessing for the boosted-tree model:
# - `e` stays in the data but is moved to a non-predictor ("id variable") role;
# - infrequent countries/cities are pooled into an "other" level;
# - every nominal predictor is one-hot encoded;
# - `stations` gets a Yeo-Johnson transform to tame its skew.
xgboost_recipe <- recipe(cost_km_millions ~ ., data = transit_train) %>%
  recipes::update_role(e, new_role = "id variable") %>%
  step_other(country, city) %>%
  step_dummy(all_nominal_predictors(), one_hot = TRUE) %>%
  step_YeoJohnson(stations)

# Inspect the processed training data as the model will receive it.
xgboost_recipe %>%
  prep() %>%
  bake(new_data = NULL) %>%
  glimpse()
## Rows: 148
## Columns: 11
## $ e                <dbl> 7210, 8107, 7216, 8139, 7379, 7208, 7601, 8097, 7560,…
## $ year             <dbl> 2020, 2015, 2012, 2005, 2001, 2005, 1998, 2018, 2016,…
## $ stations         <dbl> 1.9354396, 2.6443424, 2.2016661, 1.9354396, 5.7189018…
## $ cost_km_millions <dbl> 4.756603, 4.837789, 4.023306, 2.052841, 3.963033, 4.5…
## $ country_CN       <dbl> 0, 1, 0, 1, 0, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0,…
## $ country_IN       <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,…
## $ country_other    <dbl> 1, 0, 1, 0, 1, 1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 1, 0, 1,…
## $ city_Shanghai    <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0,…
## $ city_other       <dbl> 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1,…
## $ rr_X0            <dbl> 1, 1, 0, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1,…
## $ rr_X1            <dbl> 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0,…
# Specify model
# Boosted regression trees; mtry, trees, min_n and learn_rate are tuned below.
xgboost_spec <- boost_tree(
  mtry = tune(),
  trees = tune(),
  min_n = tune(),
  learn_rate = tune()
) %>%
  set_mode("regression") %>%
  set_engine("xgboost")

# Combine recipe and model using workflow
xgboost_workflow <- workflow() %>%
  add_model(xgboost_spec) %>%
  add_recipe(xgboost_recipe)

# Tune hyperparameters
# Evaluate an automatically generated grid of 5 candidate parameter sets on
# every bootstrap resample. The printed warning comes from the rsq metric when
# a candidate's predictions are constant (zero standard deviation).
set.seed(344)
xgboost_tune <- xgboost_workflow %>%
  tune_grid(
    resamples = transit_folds,
    grid = 5
  )
## i Creating pre-processing data to finalize unknown parameter: mtry
## → A | warning: A correlation computation is required, but `estimate` is constant and has 0
##                standard deviation, resulting in a divide by 0 error. `NA` will be returned.
## There were issues with some computations   A: x1There were issues with some computations   A: x2There were issues with some computations   A: x3There were issues with some computations   A: x4There were issues with some computations   A: x5There were issues with some computations   A: x6There were issues with some computations   A: x6

Explore Results

# Candidates ranked by mean resampled RMSE (best first).
xgboost_tune %>%
  show_best(metric = "rmse")
## # A tibble: 5 × 10
##    mtry trees min_n learn_rate .metric .estimator  mean     n std_err .config   
##   <int> <int> <int>      <dbl> <chr>   <chr>      <dbl> <int>   <dbl> <chr>     
## 1     4  1104    28    0.00484 rmse    standard   0.669    25  0.0147 Preproces…
## 2     1  1613    36    0.0290  rmse    standard   0.674    25  0.0134 Preproces…
## 3     7  1524    23    0.0836  rmse    standard   0.811    25  0.0157 Preproces…
## 4     6   768    12    0.112   rmse    standard   0.866    25  0.0187 Preproces…
## 5     8   162     7    0.00108 rmse    standard   3.98     25  0.0106 Preproces…
# How did all the possible parameter combinations do?
xgboost_tune %>%
  autoplot()

We can finalize our XGBoost workflow with the best-performing parameters (lowest RMSE).

# Plug the lowest-RMSE hyperparameters into the workflow. Despite the `_rf`
# suffix this is the tuned XGBoost workflow; the name is kept because
# last_fit() below refers to it.
final_rf <- finalize_workflow(
  xgboost_workflow,
  select_best(xgboost_tune, metric = "rmse")
)

The function last_fit() fits this finalized XGBoost workflow one last time on the full training data and evaluates it once on the testing data.

# Refit on the whole training set and score once on the held-out test set.
transit_fit <- final_rf %>%
  last_fit(split = transit_split)
print(transit_fit)
## # Resampling results
## # Manual resampling 
## # A tibble: 1 × 6
##   splits           id               .metrics .notes   .predictions .workflow 
##   <list>           <chr>            <list>   <list>   <list>       <list>    
## 1 <split [148/52]> train/test split <tibble> <tibble> <tibble>     <workflow>

Evaluate model

# Test-set metrics for the final model.
transit_fit %>%
  collect_metrics()
## # A tibble: 2 × 4
##   .metric .estimator .estimate .config             
##   <chr>   <chr>          <dbl> <chr>               
## 1 rmse    standard   0.650     Preprocessor1_Model1
## 2 rsq     standard   0.0000113 Preprocessor1_Model1
# Row-level test-set predictions alongside the observed target.
transit_fit %>%
  collect_predictions()
## # A tibble: 52 × 5
##    .pred id                .row cost_km_millions .config             
##    <dbl> <chr>            <int>            <dbl> <chr>               
##  1  4.83 train/test split     1             5.61 Preprocessor1_Model1
##  2  5.45 train/test split     2             4.98 Preprocessor1_Model1
##  3  5.55 train/test split     3             4.73 Preprocessor1_Model1
##  4  5.23 train/test split     5             4.28 Preprocessor1_Model1
##  5  5.02 train/test split    12             4.91 Preprocessor1_Model1
##  6  4.63 train/test split    15             4.55 Preprocessor1_Model1
##  7  5.29 train/test split    24             5.36 Preprocessor1_Model1
##  8  4.83 train/test split    29             5.17 Preprocessor1_Model1
##  9  5.11 train/test split    31             5.25 Preprocessor1_Model1
## 10  5.41 train/test split    36             6.51 Preprocessor1_Model1
## # ℹ 42 more rows
# Predicted vs. observed (log) cost per km on the test set. Points on the
# dashed identity line are perfect predictions; coord_fixed() keeps both axes
# on the same scale so deviations are not visually distorted.
collect_predictions(transit_fit) %>%
  ggplot(aes(x = cost_km_millions, y = .pred)) +
  # Use `color`, not `fill`: the default point shape is drawn with colour only,
  # so `fill = "midnightblue"` had no visible effect.
  geom_point(alpha = 0.5, color = "midnightblue") +
  geom_abline(lty = 2, color = "gray50") +
  coord_fixed() +
  labs(
    x = "Observed log(cost per km)",
    y = "Predicted log(cost per km)"
  )