# load needed libraries
library(tidyverse)
library(gtsummary)
library(tidymodels)
library(themis)
Random forests for binary classification: predicting legal status of trees in San Francisco
Data
- Data source: San Francisco Department of Public Works (DPW) Bureau of Urban Forestry
- Outcome: legal status of trees (DPW Maintained vs otherwise)
- Features: plot size, species, site info, etc.
sf_trees <- read_csv("https://raw.githubusercontent.com/rfordatascience/tidytuesday/master/data/2020/2020-01-28/sf_trees.csv")

trees_df <- sf_trees %>%
  mutate(
    legal_status = case_when(
      legal_status == "DPW Maintained" ~ legal_status,
      TRUE ~ "Other"
    ),
    plot_size = parse_number(plot_size)
  ) %>%
  select(-address) %>%
  na.omit() %>%
  mutate_if(is.character, factor) # convert categorical variables to factors
Descriptive analysis
Location map of trees by legal status:
trees_df %>%
  ggplot(aes(longitude, latitude, color = legal_status)) +
  geom_point(size = 0.5, alpha = 0.4) +
  labs(color = NULL) +
  theme_minimal()
# extract year from date
trees_df <- trees_df |> mutate(year = year(date))

# summary table by legal status
trees_df %>%
  select(!c(tree_id, date, species, site_info, caretaker)) %>%
  tbl_summary(by = legal_status) %>%
  add_p()
| Characteristic | DPW Maintained, N = 18,098¹ | Other, N = 5,744¹ | p-value² |
|---|---|---|---|
| site_order | 1.0 (1.0, 3.0) | 2.0 (1.0, 4.0) | <0.001 |
| dbh | 3.0 (3.0, 6.0) | 3.0 (3.0, 8.0) | <0.001 |
| plot_size | 3.00 (3.00, 3.00) | 3.00 (3.00, 3.00) | <0.001 |
| latitude | 37.763 (37.738, 37.779) | 37.761 (37.741, 37.780) | 0.4 |
| longitude | -122.43 (-122.47, -122.41) | -122.42 (-122.44, -122.41) | <0.001 |
| year | 2,008 (2,000, 2,013) | 2,008 (1,999, 2,016) | <0.001 |

¹ Median (IQR)
² Wilcoxon rank sum test
trees_df %>%
  count(legal_status, caretaker) %>%
  add_count(caretaker, wt = n, name = "caretaker_count") %>%
  filter(caretaker_count > 50) %>%
  group_by(legal_status) %>%
  mutate(percent_legal = n / sum(n)) %>%
  ggplot(aes(percent_legal, caretaker, fill = legal_status)) +
  geom_col(position = "dodge") +
  labs(
    fill = NULL,
    x = "% of trees in each category"
  ) +
  theme_classic()
Random forests
Model building
# data splitting
set.seed(123)
trees_split <- initial_split(trees_df, strata = legal_status)
trees_train <- training(trees_split)
trees_test <- testing(trees_split)
Data preprocessing (recipe):
- Specify the model formula.
- Assign tree_id the role of an ID variable.
- Use step_other() to collapse rare categorical levels for species, caretaker, and site_info.
- Extract the year from date.
- There are many more DPW Maintained trees than not, so let's downsample the data for training (a quick check of the imbalance is sketched below).
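A quick sketch of that imbalance, using the trees_df prepared above:

# count trees by legal status to see the imbalance that motivates downsampling
trees_df %>%
  count(legal_status)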
tree_rec <- recipe(legal_status ~ ., data = trees_train) %>%
  update_role(tree_id, new_role = "ID") %>%
  step_other(species, caretaker, threshold = 0.01) %>%
  step_other(site_info, threshold = 0.005) %>%
  step_dummy(all_nominal(), -all_outcomes()) %>%
  step_date(date, features = c("year")) %>%
  step_rm(date) %>%
  step_downsample(legal_status)
# ?step_downsample

tree_prep <- prep(tree_rec)
juiced <- juice(tree_prep)
# ?prep
# ?juice
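As a sanity check, counting the outcome classes in the juiced training data should show the two levels of legal_status balanced after step_downsample() (a quick sketch):

# both legal_status levels should now have the same number of rows
juiced %>%
  count(legal_status)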
Model specification:
- 1000 trees;
- mtry (number of variables randomly sampled at each split): tune;
- min_n (minimum number of observations in terminal nodes): tune.
tune_spec <- rand_forest(
  mtry = tune(),
  trees = 1000,
  min_n = tune()
) %>%
  set_mode("classification") %>%
  set_engine("ranger", importance = "permutation")
Workflow:
tree_wf <- workflow() %>%
  add_recipe(tree_rec) %>%
  add_model(tune_spec)
Model tuning
Set up cross-validation (CV) resamples (10-fold CV):
set.seed(234)
trees_folds <- vfold_cv(trees_train, v = 10, strata = legal_status)
Tune models on 20 random combinations of mtry and min_n:
doParallel::registerDoParallel()

set.seed(345)
tune_res <- tree_wf |>
  tune_grid(
    resamples = trees_folds,
    grid = 20
  )
saveRDS(tune_res, "tune_res.rds")
Model evaluation and finalization
Select best model by AUC:
# prediction results
tune_res %>%
  collect_metrics() %>%
  filter(.metric == "roc_auc") %>%
  arrange(desc(mean))
# A tibble: 20 × 8
mtry min_n .metric .estimator mean n std_err .config
<int> <int> <chr> <chr> <dbl> <int> <dbl> <chr>
1 14 3 roc_auc binary 0.943 10 0.00131 Preprocessor1_Model14
2 19 8 roc_auc binary 0.941 10 0.00149 Preprocessor1_Model19
3 36 5 roc_auc binary 0.941 10 0.00174 Preprocessor1_Model18
4 18 11 roc_auc binary 0.940 10 0.00144 Preprocessor1_Model13
5 36 12 roc_auc binary 0.939 10 0.00174 Preprocessor1_Model20
6 10 17 roc_auc binary 0.938 10 0.00113 Preprocessor1_Model06
7 8 15 roc_auc binary 0.938 10 0.00104 Preprocessor1_Model17
8 28 19 roc_auc binary 0.937 10 0.00163 Preprocessor1_Model02
9 32 20 roc_auc binary 0.937 10 0.00171 Preprocessor1_Model04
10 26 26 roc_auc binary 0.936 10 0.00160 Preprocessor1_Model12
11 32 24 roc_auc binary 0.936 10 0.00166 Preprocessor1_Model15
12 40 22 roc_auc binary 0.935 10 0.00182 Preprocessor1_Model16
13 22 28 roc_auc binary 0.935 10 0.00160 Preprocessor1_Model09
14 12 30 roc_auc binary 0.935 10 0.00129 Preprocessor1_Model11
15 15 32 roc_auc binary 0.935 10 0.00142 Preprocessor1_Model05
16 23 31 roc_auc binary 0.935 10 0.00161 Preprocessor1_Model10
17 28 37 roc_auc binary 0.933 10 0.00167 Preprocessor1_Model03
18 7 35 roc_auc binary 0.933 10 0.000951 Preprocessor1_Model01
19 4 39 roc_auc binary 0.926 10 0.000856 Preprocessor1_Model07
20 2 6 roc_auc binary 0.904 10 0.00125 Preprocessor1_Model08
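One way to see how AUC moves with the two tuned parameters is to reshape the collected metrics and plot them (a sketch using tidyverse functions already loaded):

tune_res %>%
  collect_metrics() %>%
  filter(.metric == "roc_auc") %>%
  pivot_longer(c(mtry, min_n), names_to = "parameter", values_to = "value") %>%
  ggplot(aes(value, mean, color = parameter)) +
  geom_point(show.legend = FALSE) +
  facet_wrap(~ parameter, scales = "free_x") +
  labs(x = NULL, y = "AUC") +
  theme_minimal()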
# best model
best_rf <- tune_res %>%
  select_best(metric = "roc_auc")
Finalize model workflow and fit:
# finalize model workflow
final_rf <- tree_wf |>
  finalize_workflow(best_rf)

final_rf
══ Workflow ════════════════════════════════════════════════════════════════════
Preprocessor: Recipe
Model: rand_forest()
── Preprocessor ────────────────────────────────────────────────────────────────
6 Recipe Steps
• step_other()
• step_other()
• step_dummy()
• step_date()
• step_rm()
• step_downsample()
── Model ───────────────────────────────────────────────────────────────────────
Random Forest Model Specification (classification)
Main Arguments:
mtry = 14
trees = 1000
min_n = 3
Engine-Specific Arguments:
importance = permutation
Computational engine: ranger
# final fit
final_fit <- final_rf |>
  last_fit(trees_split)
Performance on test set:
final_fit %>%
  collect_metrics() %>%
  filter(.metric == "roc_auc")
# A tibble: 1 × 4
.metric .estimator .estimate .config
<chr> <chr> <dbl> <chr>
1 roc_auc binary 0.946 Preprocessor1_Model1
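Beyond AUC, the test-set predictions can also be cross-tabulated in a confusion matrix (a quick sketch using yardstick's conf_mat(), loaded with tidymodels):

# predicted vs. observed legal status on the test set
final_fit %>%
  collect_predictions() %>%
  conf_mat(truth = legal_status, estimate = .pred_class)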
Assess variable importance:
library(vip)

# trained model workflow
trained_rf <- extract_workflow(final_fit)

vip_rf <- trained_rf |>
  extract_fit_parsnip() |>
  vip() +
  theme_classic()

vip_rf
Classification trees
Specify the model and workflow (tune cost_complexity and min_n):
ct_mod_spec <- decision_tree(
  cost_complexity = tune(),
  min_n = tune()
) %>%
  set_mode("classification") %>%
  set_engine("rpart")

ct_wf <- workflow() %>%
  add_recipe(tree_rec) %>%
  add_model(ct_mod_spec)
Set up a grid for the tuning parameters and tune the model:
set.seed(456)
# ?grid_regular
cost_complexity_grid <- grid_regular(
  cost_complexity(),
  levels = 20
)

# note: grid = 20 asks tune_grid() to generate 20 candidate parameter combinations;
# cost_complexity_grid is not passed here
ct_tune_res <- ct_wf |>
  tune_grid(
    resamples = trees_folds,
    grid = 20
  )
Select best model workflow:
# plot CV AUC as a function of cost_complexity
ct_tune_res %>%
  collect_metrics() %>%
  filter(.metric == "roc_auc") %>%
  ggplot(aes(cost_complexity, mean)) +
  geom_point() +
  geom_line() +
  scale_x_log10() +
  theme_minimal()
# select best model
best_ct <- ct_tune_res %>%
  select_best(metric = "roc_auc")
Final model fit and test results:
# finalize workflow
final_ct <- ct_wf |>
  finalize_workflow(best_ct)

# final fit
final_ct_fit <- final_ct |>
  last_fit(trees_split)

# performance on test set
final_ct_fit %>%
  collect_metrics() %>%
  filter(.metric == "roc_auc")
# A tibble: 1 × 4
.metric .estimator .estimate .config
<chr> <chr> <dbl> <chr>
1 roc_auc binary 0.892 Preprocessor1_Model1
# plot roc curve
final_ct_fit %>%
  collect_predictions() %>%
  roc_curve(truth = legal_status, `.pred_DPW Maintained`) %>%
autoplot()
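To compare the two models directly, their test-set ROC curves can be overlaid (a sketch; the model labels here are ad hoc):

# bind predictions from both final fits and plot per-model ROC curves
bind_rows(
  collect_predictions(final_fit) %>% mutate(model = "random forest"),
  collect_predictions(final_ct_fit) %>% mutate(model = "classification tree")
) %>%
  group_by(model) %>%
  roc_curve(truth = legal_status, `.pred_DPW Maintained`) %>%
  autoplot()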
Assessing variable importance:
library(rpart.plot) # for visualizing a decision tree
# trained tree workflow
trained_tr <- extract_workflow(final_ct_fit)
# trained_tr

# Plot tree structure
trained_tr %>%
  extract_fit_engine() %>%
  rpart.plot(roundint = FALSE, fallen.leaves = FALSE)

# Variable importance
vip_tr <- trained_tr %>%
  extract_fit_parsnip() %>%
  vip() +
  theme_classic()
Plot VIP graphs for random forests and classification trees side by side:
library(patchwork)
vip_rf + vip_tr