Goal: To figure out how to deliver more high-capacity transit projects for a fraction of the cost in countries like the United States. Click here for the data.
# Read the TidyTuesday 2021-01-05 transit cost CSV directly from GitHub.
transit_cost <- readr::read_csv(
  "https://raw.githubusercontent.com/rfordatascience/tidytuesday/master/data/2021/2021-01-05/transit_cost.csv"
)
## Rows: 544 Columns: 20
## ── Column specification ────────────────────────────────────────────────────────
## Delimiter: ","
## chr (11): country, city, line, start_year, end_year, tunnel_per, source1, cu...
## dbl (9): e, rr, length, tunnel, stations, cost, year, ppp_rate, cost_km_mil...
##
## ℹ Use `spec()` to retrieve the full column specification for this data.
## ℹ Specify the column types or set `show_col_types = FALSE` to quiet this message.
# Summarise every column (missingness, uniqueness, distribution) of the raw data.
skimr::skim(transit_cost)
Name | transit_cost |
Number of rows | 544 |
Number of columns | 20 |
_______________________ | |
Column type frequency: | |
character | 11 |
numeric | 9 |
________________________ | |
Group variables | None |
Variable type: character
skim_variable | n_missing | complete_rate | min | max | empty | n_unique | whitespace |
---|---|---|---|---|---|---|---|
country | 7 | 0.99 | 2 | 2 | 0 | 56 | 0 |
city | 7 | 0.99 | 4 | 16 | 0 | 140 | 0 |
line | 7 | 0.99 | 2 | 46 | 0 | 366 | 0 |
start_year | 53 | 0.90 | 4 | 9 | 0 | 40 | 0 |
end_year | 71 | 0.87 | 1 | 4 | 0 | 36 | 0 |
tunnel_per | 32 | 0.94 | 5 | 7 | 0 | 134 | 0 |
source1 | 12 | 0.98 | 4 | 54 | 0 | 17 | 0 |
currency | 7 | 0.99 | 2 | 3 | 0 | 39 | 0 |
real_cost | 0 | 1.00 | 1 | 10 | 0 | 534 | 0 |
source2 | 10 | 0.98 | 3 | 16 | 0 | 12 | 0 |
reference | 19 | 0.97 | 3 | 302 | 0 | 350 | 0 |
Variable type: numeric
skim_variable | n_missing | complete_rate | mean | sd | p0 | p25 | p50 | p75 | p100 | hist |
---|---|---|---|---|---|---|---|---|---|---|
e | 7 | 0.99 | 7738.76 | 463.23 | 7136.00 | 7403.00 | 7705.00 | 7977.00 | 9510.00 | ▇▇▂▁▁ |
rr | 8 | 0.99 | 0.06 | 0.24 | 0.00 | 0.00 | 0.00 | 0.00 | 1.00 | ▇▁▁▁▁ |
length | 5 | 0.99 | 58.34 | 621.20 | 0.60 | 6.50 | 15.77 | 29.08 | 12256.98 | ▇▁▁▁▁ |
tunnel | 32 | 0.94 | 29.38 | 344.04 | 0.00 | 3.40 | 8.91 | 21.52 | 7790.78 | ▇▁▁▁▁ |
stations | 15 | 0.97 | 13.81 | 13.70 | 0.00 | 4.00 | 10.00 | 20.00 | 128.00 | ▇▁▁▁▁ |
cost | 7 | 0.99 | 805438.12 | 6708033.07 | 0.00 | 2289.00 | 11000.00 | 27000.00 | 90000000.00 | ▇▁▁▁▁ |
year | 7 | 0.99 | 2014.91 | 5.64 | 1987.00 | 2012.00 | 2016.00 | 2019.00 | 2027.00 | ▁▁▂▇▂ |
ppp_rate | 9 | 0.98 | 0.66 | 0.87 | 0.00 | 0.24 | 0.26 | 1.00 | 5.00 | ▇▂▁▁▁ |
cost_km_millions | 2 | 1.00 | 232.98 | 257.22 | 7.79 | 134.86 | 181.25 | 241.43 | 3928.57 | ▇▁▁▁▁ |
# Build the modelling table: keep `e`, the target and the candidate
# predictors, drop incomplete rows, and put the target on the log scale.
data <- transit_cost %>%
  select(e, cost_km_millions, country, city, year, rr, stations) %>%
  # rr is a 0/1 flag (1 = railroad); treat it as categorical
  mutate(rr = as.factor(rr)) %>%
  # Drop every row with a missing value in any of the selected columns
  na.omit() %>%
  # Log-transform the target variable
  mutate(cost_km_millions = log(cost_km_millions))
# Names of the 20 cities with the most rows in `data`, most frequent first.
top20_cities_vec <- data %>%
  count(city) %>%
  arrange(desc(n)) %>%
  head(20) %>%
  pull(city)

top20_cities_vec
## [1] "Shanghai" "Beijing" "Wuhan" "Istanbul" "Shenzhen" "Changsha"
## [7] "Mumbai" "Nanjing" "Chengdu" "Chongqing" "Hangzhou" "Paris"
## [13] "Guangzhou" "Hefei" "Tokyo" "Kunming" "Taipei" "Tianjin"
## [19] "Bangkok" "Changchun"
# Distribution of (log) cost per km for the 20 most frequent cities,
# with cities ordered by their cost values.
data %>%
  filter(city %in% top20_cities_vec) %>%
  ggplot(
    aes(
      x = cost_km_millions,
      y = fct_reorder(city, cost_km_millions)
    )
  ) +
  geom_boxplot()
Identify good predictors.
EDA Shortcut
# Step 1 Prepare
# Drop `e`, then convert every remaining column into 0/1 indicator columns
# (numeric columns are binned, categorical columns are one-hot encoded).
data_binarized_tbl <- data %>%
  select(-e) %>%
  binarize()

glimpse(data_binarized_tbl)
## Rows: 526
## Columns: 59
## $ `cost_km_millions__-Inf_4.90431298715264` <dbl> 0, 0, 0, 0, 0, 0, …
## $ cost_km_millions__4.90431298715264_5.19655135163657 <dbl> 0, 0, 0, 0, 0, 0, …
## $ cost_km_millions__5.19655135163657_5.47823632836085 <dbl> 0, 0, 0, 0, 0, 0, …
## $ cost_km_millions__5.47823632836085_Inf <dbl> 1, 1, 1, 1, 1, 1, …
## $ country__BG <dbl> 0, 0, 0, 0, 0, 0, …
## $ country__CA <dbl> 1, 1, 1, 1, 1, 0, …
## $ country__CN <dbl> 0, 0, 0, 0, 0, 0, …
## $ country__DE <dbl> 0, 0, 0, 0, 0, 0, …
## $ country__ES <dbl> 0, 0, 0, 0, 0, 0, …
## $ country__FR <dbl> 0, 0, 0, 0, 0, 0, …
## $ country__IN <dbl> 0, 0, 0, 0, 0, 0, …
## $ country__IT <dbl> 0, 0, 0, 0, 0, 0, …
## $ country__JP <dbl> 0, 0, 0, 0, 0, 0, …
## $ country__KR <dbl> 0, 0, 0, 0, 0, 0, …
## $ country__SA <dbl> 0, 0, 0, 0, 0, 0, …
## $ country__TH <dbl> 0, 0, 0, 0, 0, 0, …
## $ country__TR <dbl> 0, 0, 0, 0, 0, 0, …
## $ country__TW <dbl> 0, 0, 0, 0, 0, 0, …
## $ country__US <dbl> 0, 0, 0, 0, 0, 0, …
## $ `country__-OTHER` <dbl> 0, 0, 0, 0, 0, 1, …
## $ city__Bangkok <dbl> 0, 0, 0, 0, 0, 0, …
## $ city__Barcelona <dbl> 0, 0, 0, 0, 0, 0, …
## $ city__Beijing <dbl> 0, 0, 0, 0, 0, 0, …
## $ city__Changchun <dbl> 0, 0, 0, 0, 0, 0, …
## $ city__Changsha <dbl> 0, 0, 0, 0, 0, 0, …
## $ city__Chengdu <dbl> 0, 0, 0, 0, 0, 0, …
## $ city__Chongqing <dbl> 0, 0, 0, 0, 0, 0, …
## $ city__Dongguan <dbl> 0, 0, 0, 0, 0, 0, …
## $ city__Guangzhou <dbl> 0, 0, 0, 0, 0, 0, …
## $ city__Hangzhou <dbl> 0, 0, 0, 0, 0, 0, …
## $ city__Hefei <dbl> 0, 0, 0, 0, 0, 0, …
## $ city__Istanbul <dbl> 0, 0, 0, 0, 0, 0, …
## $ city__Kunming <dbl> 0, 0, 0, 0, 0, 0, …
## $ city__Madrid <dbl> 0, 0, 0, 0, 0, 0, …
## $ city__Mumbai <dbl> 0, 0, 0, 0, 0, 0, …
## $ city__Nanjing <dbl> 0, 0, 0, 0, 0, 0, …
## $ city__Paris <dbl> 0, 0, 0, 0, 0, 0, …
## $ city__Riyadh <dbl> 0, 0, 0, 0, 0, 0, …
## $ city__Shanghai <dbl> 0, 0, 0, 0, 0, 0, …
## $ city__Shenzhen <dbl> 0, 0, 0, 0, 0, 0, …
## $ city__Sofia <dbl> 0, 0, 0, 0, 0, 0, …
## $ city__Taipei <dbl> 0, 0, 0, 0, 0, 0, …
## $ city__Tianjin <dbl> 0, 0, 0, 0, 0, 0, …
## $ city__Tokyo <dbl> 0, 0, 0, 0, 0, 0, …
## $ city__Toronto <dbl> 0, 1, 1, 1, 1, 0, …
## $ city__Wuhan <dbl> 0, 0, 0, 0, 0, 0, …
## $ `city__Xi'an` <dbl> 0, 0, 0, 0, 0, 0, …
## $ city__Zhengzhou <dbl> 0, 0, 0, 0, 0, 0, …
## $ `city__-OTHER` <dbl> 1, 0, 0, 0, 0, 1, …
## $ `year__-Inf_2013` <dbl> 0, 1, 0, 0, 0, 1, …
## $ year__2013_2016 <dbl> 0, 0, 0, 0, 0, 0, …
## $ year__2016_2019 <dbl> 1, 0, 1, 1, 0, 0, …
## $ year__2019_Inf <dbl> 0, 0, 0, 0, 1, 0, …
## $ rr__0 <dbl> 1, 1, 1, 1, 1, 1, …
## $ rr__1 <dbl> 0, 0, 0, 0, 0, 0, …
## $ `stations__-Inf_4` <dbl> 0, 0, 1, 0, 0, 0, …
## $ stations__4_10 <dbl> 1, 1, 0, 0, 1, 1, …
## $ stations__10_20 <dbl> 0, 0, 0, 1, 0, 0, …
## $ stations__20_Inf <dbl> 0, 0, 0, 0, 0, 0, …
# Step 2 Correlate
# Correlate every binarized feature with the highest cost-per-km bin.
data_corr_tbl <- correlate(
  data_binarized_tbl,
  target = cost_km_millions__5.47823632836085_Inf
)

glimpse(data_corr_tbl)
## Rows: 59
## Columns: 3
## $ feature <fct> cost_km_millions, cost_km_millions, cost_km_millions, cost…
## $ bin <chr> "5.47823632836085_Inf", "-Inf_4.90431298715264", "4.904312…
## $ correlation <dbl> 1.00000000, -0.33502538, -0.33333119, -0.33333119, 0.27502…
# Step 3 Plot
# Visualise the feature/bin correlations as a correlation funnel.
plot_correlation_funnel(data_corr_tbl)
## Warning: ggrepel: 40 unlabeled data points (too many overlaps). Consider
## increasing max.overlaps
Split data
# Earlier experiment with a smaller subsample, kept for reference:
# data <- sample_n(data, 100)
# Subsample the data, then split it into training and test sets.
# NOTE(review): `data` is overwritten with a random 200-row subsample here;
# the reason is not stated (presumably to keep tuning fast) -- confirm.
set.seed(12345)
data <- sample_n(data, 200)
# Train/test split stratified on the (log) target; in the rendered run this
# yields 148 training rows and 52 test rows.
set.seed(123)
transit_split <- initial_split(data, strata = cost_km_millions)
transit_train <- training(transit_split)
transit_test <- testing(transit_split)
# Resample the training set for hyperparameter tuning. These are stratified
# bootstrap resamples (25 in the rendered output), not cross-validation folds.
set.seed(234)
transit_folds <- bootstraps(transit_train, strata = cost_km_millions)
transit_folds
## # Bootstrap sampling using stratification
## # A tibble: 25 × 2
## splits id
## <list> <chr>
## 1 <split [148/52]> Bootstrap01
## 2 <split [148/57]> Bootstrap02
## 3 <split [148/49]> Bootstrap03
## 4 <split [148/54]> Bootstrap04
## 5 <split [148/58]> Bootstrap05
## 6 <split [148/54]> Bootstrap06
## 7 <split [148/49]> Bootstrap07
## 8 <split [148/61]> Bootstrap08
## 9 <split [148/54]> Bootstrap09
## 10 <split [148/56]> Bootstrap10
## # ℹ 15 more rows
# Preprocessing recipe: predict cost_km_millions from all other columns.
xgboost_recipe <-
  recipe(formula = cost_km_millions ~ ., data = transit_train) %>%
  # Keep `e` in the data but give it a non-predictor role
  recipes::update_role(e, new_role = "id variable") %>%
  # Pool infrequent country / city levels into an "other" level
  step_other(country, city) %>%
  # One-hot encode the categorical predictors (one column per level)
  step_dummy(all_nominal_predictors(), one_hot = TRUE) %>%
  # Yeo-Johnson transform of the station count
  step_YeoJohnson(stations)

# Inspect the preprocessed training data
xgboost_recipe %>%
  prep() %>%
  bake(new_data = NULL) %>%
  glimpse()
## Rows: 148
## Columns: 11
## $ e <dbl> 7210, 8107, 7216, 8139, 7379, 7208, 7601, 8097, 7560,…
## $ year <dbl> 2020, 2015, 2012, 2005, 2001, 2005, 1998, 2018, 2016,…
## $ stations <dbl> 1.9354396, 2.6443424, 2.2016661, 1.9354396, 5.7189018…
## $ cost_km_millions <dbl> 4.756603, 4.837789, 4.023306, 2.052841, 3.963033, 4.5…
## $ country_CN <dbl> 0, 1, 0, 1, 0, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0,…
## $ country_IN <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,…
## $ country_other <dbl> 1, 0, 1, 0, 1, 1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 1, 0, 1,…
## $ city_Shanghai <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0,…
## $ city_other <dbl> 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1,…
## $ rr_X0 <dbl> 1, 1, 0, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1,…
## $ rr_X1 <dbl> 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0,…
# Model specification: boosted trees (xgboost engine) for regression,
# with four hyperparameters left open for tuning.
xgboost_spec <-
  boost_tree(
    trees = tune(),
    min_n = tune(),
    mtry = tune(),
    learn_rate = tune()
  ) %>%
  set_mode("regression") %>%
  set_engine("xgboost")
# Bundle the preprocessing recipe and the model spec into a single workflow
xgboost_workflow <-
  workflow() %>%
  add_recipe(xgboost_recipe) %>%
  add_model(xgboost_spec)
# Tune the hyperparameters: evaluate 10 candidate combinations on the
# bootstrap resamples of the training set.
set.seed(344)
xgboost_tune <- tune_grid(
  xgboost_workflow,
  resamples = transit_folds,
  grid = 10
)
## i Creating pre-processing data to finalize unknown parameter: mtry
## → A | warning: A correlation computation is required, but `estimate` is constant and has 0
## standard deviation, resulting in a divide by 0 error. `NA` will be returned.
##
There were issues with some computations A: x1
There were issues with some computations A: x2
There were issues with some computations A: x3
There were issues with some computations A: x4
There were issues with some computations A: x5
There were issues with some computations A: x6
There were issues with some computations A: x7
There were issues with some computations A: x8
There were issues with some computations A: x9
There were issues with some computations A: x10
There were issues with some computations A: x11
There were issues with some computations A: x12
There were issues with some computations A: x13
There were issues with some computations A: x14
There were issues with some computations A: x15
There were issues with some computations A: x16
There were issues with some computations A: x17
There were issues with some computations A: x18
There were issues with some computations A: x19
There were issues with some computations A: x20
There were issues with some computations A: x21
There were issues with some computations A: x22
There were issues with some computations A: x23
There were issues with some computations A: x23
# Best candidates ranked by resampled RMSE
xgboost_tune %>%
  tune::show_best(metric = "rmse")
## # A tibble: 5 × 10
## mtry trees min_n learn_rate .metric .estimator mean n std_err .config
## <int> <int> <int> <dbl> <chr> <chr> <dbl> <int> <dbl> <chr>
## 1 5 770 40 0.00622 rmse standard 0.672 25 0.0143 Preproces…
## 2 6 928 24 0.0171 rmse standard 0.704 25 0.0127 Preproces…
## 3 7 1740 35 0.0677 rmse standard 0.733 25 0.0124 Preproces…
## 4 4 1820 16 0.0223 rmse standard 0.759 25 0.0140 Preproces…
## 5 6 255 19 0.211 rmse standard 0.783 25 0.0142 Preproces…
# Performance of every tuned hyperparameter combination
xgboost_tune %>%
  autoplot()
# Plug the best (lowest-RMSE) hyperparameters into the workflow.
# NOTE: despite its name, `final_rf` is the finalized xgboost workflow.
final_rf <- finalize_workflow(
  xgboost_workflow,
  select_best(xgboost_tune, metric = "rmse")
)

# Fit once on the full training set and evaluate on the held-out test set
transit_fit <- last_fit(final_rf, transit_split)

transit_fit
## # Resampling results
## # Manual resampling
## # A tibble: 1 × 6
## splits id .metrics .notes .predictions .workflow
## <list> <chr> <list> <list> <list> <list>
## 1 <split [148/52]> train/test split <tibble> <tibble> <tibble> <workflow>
# Test-set metrics of the final fit. In the rendered run rsq is close to
# zero, i.e. the predictions explain almost none of the test-set variance.
collect_metrics(transit_fit)
## # A tibble: 2 × 4
## .metric .estimator .estimate .config
## <chr> <chr> <dbl> <chr>
## 1 rmse standard 0.640 Preprocessor1_Model1
## 2 rsq standard 0.00116 Preprocessor1_Model1
# Test-set predictions (.pred) next to the observed log cost per km.
collect_predictions(transit_fit)
## # A tibble: 52 × 5
## .pred id .row cost_km_millions .config
## <dbl> <chr> <int> <dbl> <chr>
## 1 4.89 train/test split 1 5.61 Preprocessor1_Model1
## 2 5.45 train/test split 2 4.98 Preprocessor1_Model1
## 3 5.43 train/test split 3 4.73 Preprocessor1_Model1
## 4 5.30 train/test split 5 4.28 Preprocessor1_Model1
## 5 5.05 train/test split 12 4.91 Preprocessor1_Model1
## 6 4.77 train/test split 15 4.55 Preprocessor1_Model1
## 7 5.22 train/test split 24 5.36 Preprocessor1_Model1
## 8 4.89 train/test split 29 5.17 Preprocessor1_Model1
## 9 5.17 train/test split 31 5.25 Preprocessor1_Model1
## 10 5.44 train/test split 36 6.51 Preprocessor1_Model1
## # ℹ 42 more rows
# Predicted vs. observed (log) cost per km on the test set.
# Points on the dashed identity line are perfect predictions.
collect_predictions(transit_fit) %>%
  ggplot(aes(x = cost_km_millions, y = .pred)) +
  # The default point shape is drawn with `colour` only, so the original
  # `fill = "skyblue"` had no visible effect; set the colour instead.
  geom_point(alpha = 0.5, color = "skyblue") +
  geom_abline(lty = 2, color = "blue") +
  # Same scale on both axes so the identity line sits at 45 degrees
  coord_fixed()
# Make predictions