Goal: figure out how to deliver more high-capacity transit projects for a fraction of the current cost in countries like the United States. The data come from the TidyTuesday 2021-01-05 transit costs release, pulled directly from GitHub in the import step below.

Import Data
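The code in this post uses dplyr, ggplot2, and forcats verbs, the correlationfunnel helpers, and the tidymodels stack without namespace prefixes, so it assumes a setup chunk roughly like the following (not shown in the original):

library(tidyverse)          # dplyr, ggplot2, forcats, and friends
library(tidymodels)         # rsample, recipes, parsnip, workflows, tune, yardstick
library(correlationfunnel)  # binarize(), correlate(), plot_correlation_funnel()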

transit_cost <- readr::read_csv('https://raw.githubusercontent.com/rfordatascience/tidytuesday/master/data/2021/2021-01-05/transit_cost.csv')
## Rows: 544 Columns: 20
## ── Column specification ────────────────────────────────────────────────────────
## Delimiter: ","
## chr (11): country, city, line, start_year, end_year, tunnel_per, source1, cu...
## dbl  (9): e, rr, length, tunnel, stations, cost, year, ppp_rate, cost_km_mil...
## 
## ℹ Use `spec()` to retrieve the full column specification for this data.
## ℹ Specify the column types or set `show_col_types = FALSE` to quiet this message.
skimr::skim(transit_cost) 
Data summary

Name: transit_cost
Number of rows: 544
Number of columns: 20
Column type frequency: character 11, numeric 9
Group variables: None

Variable type: character

skim_variable n_missing complete_rate min max empty n_unique whitespace
country 7 0.99 2 2 0 56 0
city 7 0.99 4 16 0 140 0
line 7 0.99 2 46 0 366 0
start_year 53 0.90 4 9 0 40 0
end_year 71 0.87 1 4 0 36 0
tunnel_per 32 0.94 5 7 0 134 0
source1 12 0.98 4 54 0 17 0
currency 7 0.99 2 3 0 39 0
real_cost 0 1.00 1 10 0 534 0
source2 10 0.98 3 16 0 12 0
reference 19 0.97 3 302 0 350 0

Variable type: numeric

skim_variable n_missing complete_rate mean sd p0 p25 p50 p75 p100 hist
e 7 0.99 7738.76 463.23 7136.00 7403.00 7705.00 7977.00 9510.00 ▇▇▂▁▁
rr 8 0.99 0.06 0.24 0.00 0.00 0.00 0.00 1.00 ▇▁▁▁▁
length 5 0.99 58.34 621.20 0.60 6.50 15.77 29.08 12256.98 ▇▁▁▁▁
tunnel 32 0.94 29.38 344.04 0.00 3.40 8.91 21.52 7790.78 ▇▁▁▁▁
stations 15 0.97 13.81 13.70 0.00 4.00 10.00 20.00 128.00 ▇▁▁▁▁
cost 7 0.99 805438.12 6708033.07 0.00 2289.00 11000.00 27000.00 90000000.00 ▇▁▁▁▁
year 7 0.99 2014.91 5.64 1987.00 2012.00 2016.00 2019.00 2027.00 ▁▁▂▇▂
ppp_rate 9 0.98 0.66 0.87 0.00 0.24 0.26 1.00 5.00 ▇▂▁▁▁
cost_km_millions 2 1.00 232.98 257.22 7.79 134.86 181.25 241.43 3928.57 ▇▁▁▁▁
data <- transit_cost %>%
    
    # Select relevant variables
    select(e, cost_km_millions, country, city, year, rr, stations) %>%
    
    # Convert rr to a factor (1 = railroad, 0 = not)
    mutate(rr = as.factor(rr)) %>%
    
    # Remove missing values
    na.omit() %>%
    
    # Log-transform the target variable to reduce its right skew
    mutate(cost_km_millions = log(cost_km_millions))
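The skim output above shows that cost_km_millions is strongly right-skewed (a median around 181 against a maximum near 3,929), which is what motivates modeling it on the log scale. As a quick side check, not part of the original output, the distribution can be eyeballed like this:

transit_cost %>%
    ggplot(aes(cost_km_millions)) +
    geom_histogram(bins = 50) +
    scale_x_log10()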

Explore Data

top20_cities_vec <- data %>%
    count(city, sort = TRUE) %>%
    head(20) %>%
    pull(city)

top20_cities_vec
##  [1] "Shanghai"  "Beijing"   "Wuhan"     "Istanbul"  "Shenzhen"  "Changsha" 
##  [7] "Mumbai"    "Nanjing"   "Chengdu"   "Chongqing" "Hangzhou"  "Paris"    
## [13] "Guangzhou" "Hefei"     "Tokyo"     "Kunming"   "Taipei"    "Tianjin"  
## [19] "Bangkok"   "Changchun"
data %>%
    
    # Filter for top 20 cities
    filter(city %in% top20_cities_vec) %>%
    
    # Plot
    ggplot(aes(cost_km_millions, fct_reorder(city, cost_km_millions))) +
    geom_boxplot()

Next, identify promising predictors of cost per km.

EDA Shortcut

The correlationfunnel package offers a quick shortcut: binarize every variable, correlate each resulting binary feature against the target bin of interest (here the highest-cost bin of cost_km_millions), and plot the correlations as a funnel.

# Step 1 Prepare

data_binarized_tbl <- data %>%
    select(-e) %>%
    binarize()

data_binarized_tbl %>% glimpse()
## Rows: 526
## Columns: 59
## $ `cost_km_millions__-Inf_4.90431298715264`           <dbl> 0, 0, 0, 0, 0, 0, …
## $ cost_km_millions__4.90431298715264_5.19655135163657 <dbl> 0, 0, 0, 0, 0, 0, …
## $ cost_km_millions__5.19655135163657_5.47823632836085 <dbl> 0, 0, 0, 0, 0, 0, …
## $ cost_km_millions__5.47823632836085_Inf              <dbl> 1, 1, 1, 1, 1, 1, …
## $ country__BG                                         <dbl> 0, 0, 0, 0, 0, 0, …
## $ country__CA                                         <dbl> 1, 1, 1, 1, 1, 0, …
## $ country__CN                                         <dbl> 0, 0, 0, 0, 0, 0, …
## $ country__DE                                         <dbl> 0, 0, 0, 0, 0, 0, …
## $ country__ES                                         <dbl> 0, 0, 0, 0, 0, 0, …
## $ country__FR                                         <dbl> 0, 0, 0, 0, 0, 0, …
## $ country__IN                                         <dbl> 0, 0, 0, 0, 0, 0, …
## $ country__IT                                         <dbl> 0, 0, 0, 0, 0, 0, …
## $ country__JP                                         <dbl> 0, 0, 0, 0, 0, 0, …
## $ country__KR                                         <dbl> 0, 0, 0, 0, 0, 0, …
## $ country__SA                                         <dbl> 0, 0, 0, 0, 0, 0, …
## $ country__TH                                         <dbl> 0, 0, 0, 0, 0, 0, …
## $ country__TR                                         <dbl> 0, 0, 0, 0, 0, 0, …
## $ country__TW                                         <dbl> 0, 0, 0, 0, 0, 0, …
## $ country__US                                         <dbl> 0, 0, 0, 0, 0, 0, …
## $ `country__-OTHER`                                   <dbl> 0, 0, 0, 0, 0, 1, …
## $ city__Bangkok                                       <dbl> 0, 0, 0, 0, 0, 0, …
## $ city__Barcelona                                     <dbl> 0, 0, 0, 0, 0, 0, …
## $ city__Beijing                                       <dbl> 0, 0, 0, 0, 0, 0, …
## $ city__Changchun                                     <dbl> 0, 0, 0, 0, 0, 0, …
## $ city__Changsha                                      <dbl> 0, 0, 0, 0, 0, 0, …
## $ city__Chengdu                                       <dbl> 0, 0, 0, 0, 0, 0, …
## $ city__Chongqing                                     <dbl> 0, 0, 0, 0, 0, 0, …
## $ city__Dongguan                                      <dbl> 0, 0, 0, 0, 0, 0, …
## $ city__Guangzhou                                     <dbl> 0, 0, 0, 0, 0, 0, …
## $ city__Hangzhou                                      <dbl> 0, 0, 0, 0, 0, 0, …
## $ city__Hefei                                         <dbl> 0, 0, 0, 0, 0, 0, …
## $ city__Istanbul                                      <dbl> 0, 0, 0, 0, 0, 0, …
## $ city__Kunming                                       <dbl> 0, 0, 0, 0, 0, 0, …
## $ city__Madrid                                        <dbl> 0, 0, 0, 0, 0, 0, …
## $ city__Mumbai                                        <dbl> 0, 0, 0, 0, 0, 0, …
## $ city__Nanjing                                       <dbl> 0, 0, 0, 0, 0, 0, …
## $ city__Paris                                         <dbl> 0, 0, 0, 0, 0, 0, …
## $ city__Riyadh                                        <dbl> 0, 0, 0, 0, 0, 0, …
## $ city__Shanghai                                      <dbl> 0, 0, 0, 0, 0, 0, …
## $ city__Shenzhen                                      <dbl> 0, 0, 0, 0, 0, 0, …
## $ city__Sofia                                         <dbl> 0, 0, 0, 0, 0, 0, …
## $ city__Taipei                                        <dbl> 0, 0, 0, 0, 0, 0, …
## $ city__Tianjin                                       <dbl> 0, 0, 0, 0, 0, 0, …
## $ city__Tokyo                                         <dbl> 0, 0, 0, 0, 0, 0, …
## $ city__Toronto                                       <dbl> 0, 1, 1, 1, 1, 0, …
## $ city__Wuhan                                         <dbl> 0, 0, 0, 0, 0, 0, …
## $ `city__Xi'an`                                       <dbl> 0, 0, 0, 0, 0, 0, …
## $ city__Zhengzhou                                     <dbl> 0, 0, 0, 0, 0, 0, …
## $ `city__-OTHER`                                      <dbl> 1, 0, 0, 0, 0, 1, …
## $ `year__-Inf_2013`                                   <dbl> 0, 1, 0, 0, 0, 1, …
## $ year__2013_2016                                     <dbl> 0, 0, 0, 0, 0, 0, …
## $ year__2016_2019                                     <dbl> 1, 0, 1, 1, 0, 0, …
## $ year__2019_Inf                                      <dbl> 0, 0, 0, 0, 1, 0, …
## $ rr__0                                               <dbl> 1, 1, 1, 1, 1, 1, …
## $ rr__1                                               <dbl> 0, 0, 0, 0, 0, 0, …
## $ `stations__-Inf_4`                                  <dbl> 0, 0, 1, 0, 0, 0, …
## $ stations__4_10                                      <dbl> 1, 1, 0, 0, 1, 1, …
## $ stations__10_20                                     <dbl> 0, 0, 0, 1, 0, 0, …
## $ stations__20_Inf                                    <dbl> 0, 0, 0, 0, 0, 0, …
# Step 2 Correlate
data_corr_tbl <- data_binarized_tbl %>%
    correlate(cost_km_millions__5.47823632836085_Inf)

data_corr_tbl %>% glimpse()
## Rows: 59
## Columns: 3
## $ feature     <fct> cost_km_millions, cost_km_millions, cost_km_millions, cost…
## $ bin         <chr> "5.47823632836085_Inf", "-Inf_4.90431298715264", "4.904312…
## $ correlation <dbl> 1.00000000, -0.33502538, -0.33333119, -0.33333119, 0.27502…
# Step 3 Plot 
data_corr_tbl %>%
    plot_correlation_funnel()
## Warning: ggrepel: 40 unlabeled data points (too many overlaps). Consider
## increasing max.overlaps
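The ggrepel warning only means that some feature labels were dropped to reduce overplotting. If every label is wanted, the warning itself suggests raising the limit, which can be done globally before plotting (this works as long as the plotting function does not set max.overlaps explicitly):

options(ggrepel.max.overlaps = Inf)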

Preprocess Data

All preprocessing is specified as a recipe and combined with the model in a workflow in the next section, so the same steps are applied consistently during tuning and the final fit.

Build Models

Split Data

# Work with a random subsample of 200 projects
set.seed(12345)
data <- sample_n(data, 200)

# Split into training and test sets, stratified by the target
set.seed(123)
transit_split <- initial_split(data, strata = cost_km_millions)
transit_train <- training(transit_split)
transit_test <- testing(transit_split)

# Resample the training set with bootstraps for model tuning
set.seed(234)
transit_folds <- bootstraps(transit_train, strata = cost_km_millions)
transit_folds
transit_folds
## # Bootstrap sampling using stratification 
## # A tibble: 25 × 2
##    splits           id         
##    <list>           <chr>      
##  1 <split [148/52]> Bootstrap01
##  2 <split [148/57]> Bootstrap02
##  3 <split [148/49]> Bootstrap03
##  4 <split [148/54]> Bootstrap04
##  5 <split [148/58]> Bootstrap05
##  6 <split [148/54]> Bootstrap06
##  7 <split [148/49]> Bootstrap07
##  8 <split [148/61]> Bootstrap08
##  9 <split [148/54]> Bootstrap09
## 10 <split [148/56]> Bootstrap10
## # ℹ 15 more rows
# Specify recipe

xgboost_recipe <- 
  recipe(formula = cost_km_millions ~ ., data = transit_train) %>%
    # Keep e as an identifier rather than a predictor
    recipes::update_role(e, new_role = "id variable") %>%
    # Pool infrequent countries and cities into an "other" level
    step_other(country, city) %>%
    # One-hot encode the remaining nominal predictors
    step_dummy(all_nominal_predictors(), one_hot = TRUE) %>%
    # Yeo-Johnson transform station counts to reduce skew
    step_YeoJohnson(stations)

xgboost_recipe %>% prep() %>% bake(new_data = NULL) %>% glimpse()
## Rows: 148
## Columns: 11
## $ e                <dbl> 7210, 8107, 7216, 8139, 7379, 7208, 7601, 8097, 7560,…
## $ year             <dbl> 2020, 2015, 2012, 2005, 2001, 2005, 1998, 2018, 2016,…
## $ stations         <dbl> 1.9354396, 2.6443424, 2.2016661, 1.9354396, 5.7189018…
## $ cost_km_millions <dbl> 4.756603, 4.837789, 4.023306, 2.052841, 3.963033, 4.5…
## $ country_CN       <dbl> 0, 1, 0, 1, 0, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0,…
## $ country_IN       <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,…
## $ country_other    <dbl> 1, 0, 1, 0, 1, 1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 1, 0, 1,…
## $ city_Shanghai    <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0,…
## $ city_other       <dbl> 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1,…
## $ rr_X0            <dbl> 1, 1, 0, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1,…
## $ rr_X1            <dbl> 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0,…
# Specify model
xgboost_spec <- 
  boost_tree(trees = tune(), min_n = tune(), mtry = tune(), learn_rate = tune()) %>% 
  set_mode("regression") %>% 
  set_engine("xgboost") 

# Combine recipe and model using workflow
xgboost_workflow <- 
  workflow() %>% 
  add_recipe(xgboost_recipe) %>% 
  add_model(xgboost_spec) 

# Tune hyperparameters
set.seed(344)
xgboost_tune <-
  tune_grid(xgboost_workflow, 
            resamples = transit_folds, 
            grid = 10)
## i Creating pre-processing data to finalize unknown parameter: mtry
## → A | warning: A correlation computation is required, but `estimate` is constant and has 0
##                standard deviation, resulting in a divide by 0 error. `NA` will be returned.
## 
There were issues with some computations   A: x23
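Tuning a 10-candidate grid over 25 bootstrap resamples can take a while; a common speed-up, not used in this post, is to register a parallel backend before calling tune_grid(), for example (assuming the doParallel package is installed):

# Optional: run the tuning in parallel across local cores
doParallel::registerDoParallel(cores = parallel::detectCores() - 1)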

Evaluate Models

tune::show_best(xgboost_tune, metric = "rmse")
## # A tibble: 5 × 10
##    mtry trees min_n learn_rate .metric .estimator  mean     n std_err .config   
##   <int> <int> <int>      <dbl> <chr>   <chr>      <dbl> <int>   <dbl> <chr>     
## 1     5   770    40    0.00622 rmse    standard   0.672    25  0.0143 Preproces…
## 2     6   928    24    0.0171  rmse    standard   0.704    25  0.0127 Preproces…
## 3     7  1740    35    0.0677  rmse    standard   0.733    25  0.0124 Preproces…
## 4     4  1820    16    0.0223  rmse    standard   0.759    25  0.0140 Preproces…
## 5     6   255    19    0.211   rmse    standard   0.783    25  0.0142 Preproces…
# How did all the possible parameter combinations do?
autoplot(xgboost_tune)

# Finalize the workflow with the best hyperparameters
final_xgb <- xgboost_workflow %>% 
    finalize_workflow(select_best(xgboost_tune, metric = "rmse"))

# Fit the finalized workflow on the full training set and evaluate it on the test set
transit_fit <- last_fit(final_xgb, transit_split)
transit_fit
## # Resampling results
## # Manual resampling 
## # A tibble: 1 × 6
##   splits           id               .metrics .notes   .predictions .workflow 
##   <list>           <chr>            <list>   <list>   <list>       <list>    
## 1 <split [148/52]> train/test split <tibble> <tibble> <tibble>     <workflow>
collect_metrics(transit_fit)
## # A tibble: 2 × 4
##   .metric .estimator .estimate .config             
##   <chr>   <chr>          <dbl> <chr>               
## 1 rmse    standard     0.640   Preprocessor1_Model1
## 2 rsq     standard     0.00116 Preprocessor1_Model1
collect_predictions(transit_fit)
## # A tibble: 52 × 5
##    .pred id                .row cost_km_millions .config             
##    <dbl> <chr>            <int>            <dbl> <chr>               
##  1  4.89 train/test split     1             5.61 Preprocessor1_Model1
##  2  5.45 train/test split     2             4.98 Preprocessor1_Model1
##  3  5.43 train/test split     3             4.73 Preprocessor1_Model1
##  4  5.30 train/test split     5             4.28 Preprocessor1_Model1
##  5  5.05 train/test split    12             4.91 Preprocessor1_Model1
##  6  4.77 train/test split    15             4.55 Preprocessor1_Model1
##  7  5.22 train/test split    24             5.36 Preprocessor1_Model1
##  8  4.89 train/test split    29             5.17 Preprocessor1_Model1
##  9  5.17 train/test split    31             5.25 Preprocessor1_Model1
## 10  5.44 train/test split    36             6.51 Preprocessor1_Model1
## # ℹ 42 more rows
collect_predictions(transit_fit) %>%
    ggplot(aes(cost_km_millions, .pred)) +
    geom_point(alpha = 0.5, color = "skyblue") +
    geom_abline(lty = 2, color = "blue") +
    coord_fixed()

Make Predictions
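To score new projects, the finalized workflow fitted by last_fit() can be pulled out with extract_workflow() and used with predict(). A minimal sketch (assuming a reasonably recent version of tune), using the held-out test set as stand-in new data:

final_wf <- extract_workflow(transit_fit)

final_wf %>%
    predict(new_data = transit_test) %>%
    # Note: .pred is on the log scale; exp() converts back to cost per km in millions
    bind_cols(transit_test %>% select(e, cost_km_millions))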