MULTINOMIAL CLASSIFICATION WITH TIDYMODELS AND TIDYTUESDAY VOLCANO ERUPTIONS
Lately I’ve been publishing screencasts demonstrating how to use the tidymodels framework, from first steps in modeling to how to evaluate complex models. Today’s screencast demonstrates how to implement multiclass or multinomial classification using this week’s TidyTuesday dataset on volcanoes.
suppressMessages(library(tidyverse))
suppressMessages(library(tidymodels))
volcano_raw <- readr::read_csv("https://raw.githubusercontent.com/rfordatascience/tidytuesday/master/data/2020/2020-05-12/volcano.csv")
## Parsed with column specification:
## cols(
## .default = col_character(),
## volcano_number = col_double(),
## latitude = col_double(),
## longitude = col_double(),
## elevation = col_double(),
## population_within_5_km = col_double(),
## population_within_10_km = col_double(),
## population_within_30_km = col_double(),
## population_within_100_km = col_double()
## )
## See spec(...) for full column specifications.
volcano_raw %>%
  count(primary_volcano_type, sort = TRUE)
## # A tibble: 26 x 2
## primary_volcano_type n
## <chr> <int>
## 1 Stratovolcano 353
## 2 Stratovolcano(es) 107
## 3 Shield 85
## 4 Volcanic field 71
## 5 Pyroclastic cone(s) 70
## 6 Caldera 65
## 7 Complex 46
## 8 Shield(s) 33
## 9 Submarine 27
## 10 Lava dome(s) 26
## # ... with 16 more rows
Well, that’s probably too many types of volcanoes for us to build a model for, especially with just 958 examples. Let’s create a new volcano_type variable and build a model to distinguish between three volcano types: stratovolcano, shield volcano, and everything else (“Other”).
While we use transmute() to create this new variable, let’s also select the variables to use in modeling, like the info about the tectonics around the volcano and the most important rock type.
volcano_df <- volcano_raw %>%
  transmute(
    volcano_type = case_when(
      str_detect(primary_volcano_type, "Stratovolcano") ~ "Stratovolcano",
      str_detect(primary_volcano_type, "Shield") ~ "Shield",
      TRUE ~ "Other"
    ),
    volcano_number, latitude, longitude, elevation,
    tectonic_settings, major_rock_1
  ) %>%
  mutate_if(is.character, factor)
volcano_df %>% count(volcano_type, sort = TRUE)
## # A tibble: 3 x 2
## volcano_type n
## <fct> <int>
## 1 Stratovolcano 461
## 2 Other 379
## 3 Shield 118
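Because case_when() uses the first condition that matches, the order of the tests matters: a (hypothetical) type string mentioning both patterns would be labeled Stratovolcano. Here is a base R sketch of the same collapsing, using grepl() in place of str_detect() and a few made-up type strings (`types` and `collapse_type()` are illustrative names, not part of the analysis):

```r
# Made-up strings in the style of primary_volcano_type
types <- c("Stratovolcano", "Stratovolcano(es)", "Shield", "Shield(s)",
           "Caldera", "Lava dome(s)")

# Nested ifelse() mirrors case_when(): the first matching test wins
collapse_type <- function(x) {
  ifelse(grepl("Stratovolcano", x, fixed = TRUE), "Stratovolcano",
         ifelse(grepl("Shield", x, fixed = TRUE), "Shield", "Other"))
}

collapsed <- collapse_type(types)
table(collapsed)

# A hypothetical string matching both patterns lands in the first class tested
collapse_type("Shield / Stratovolcano")   # "Stratovolcano"

# The three class counts above account for every row
461 + 379 + 118   # 958
```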
This is not a lot of data to build a random forest model with, TBH, but it’s great data for demonstrating how to make a map!
world <- map_data('world')
ggplot() +
  geom_map(
    data = world, map = world,
    aes(x = long, y = lat, map_id = region),
    color = 'white', fill = 'gray50', size = 0.05, alpha = 0.2
  ) +
  geom_point(
    data = volcano_df,
    aes(x = longitude, y = latitude, color = volcano_type),
    alpha = 0.8
  ) +
  theme_void(base_family = 'IBMPlexSans') +
  labs(x = NULL, y = NULL, color = NULL)
## Warning: Ignoring unknown aesthetics: x, y
## Warning in grid.Call(C_stringMetric, as.graphicsAnnot(x$label)): font family not
## found in Windows font database
Instead of splitting this small-ish dataset into training and testing data, let’s create a set of bootstrap resamples.
library(tidymodels)
volcano_boot <- bootstraps(volcano_df)
volcano_boot
## # Bootstrap sampling
## # A tibble: 25 x 2
## splits id
## <list> <chr>
## 1 <split [958/349]> Bootstrap01
## 2 <split [958/370]> Bootstrap02
## 3 <split [958/370]> Bootstrap03
## 4 <split [958/368]> Bootstrap04
## 5 <split [958/355]> Bootstrap05
## 6 <split [958/358]> Bootstrap06
## 7 <split [958/346]> Bootstrap07
## 8 <split [958/355]> Bootstrap08
## 9 <split [958/355]> Bootstrap09
## 10 <split [958/357]> Bootstrap10
## # ... with 15 more rows
Let’s train our multinomial classification model on these resamples, but keep in mind that the performance estimates are probably pessimistically biased, because each bootstrap analysis set contains only about 63% of the unique rows in the data.
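Where does that 63% come from? A bootstrap analysis set draws n rows with replacement, so any given row is left out with probability (1 - 1/n)^n, which is close to 1/e, about 36.8%, for large n. With n = 958 that is roughly 352 held-out rows per resample, in line with the assessment set sizes printed above (346 to 370). Here is a base R sketch of that arithmetic plus one simulated resample (the seed is arbitrary):

```r
n <- 958

# Probability that a specific row is never drawn in n draws with replacement
p_out <- (1 - 1 / n)^n
expected_assessment <- n * p_out   # expected out-of-bag (assessment) rows, ~352
expected_assessment

# One simulated bootstrap resample
set.seed(2020)
analysis_idx <- sample.int(n, size = n, replace = TRUE)
assessment_idx <- setdiff(seq_len(n), analysis_idx)

length(unique(analysis_idx)) / n   # share of unique rows the model sees, ~0.63
length(assessment_idx)             # close to expected_assessment
```

Because every model is fit on only about 63% of the distinct volcanoes, it tends to do a bit worse than a model fit on all 958, which is the source of the pessimistic bias.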
library(themis)
## Warning: package 'themis' was built under R version 4.0.2
## Registered S3 methods overwritten by 'themis':
## method from
## bake.step_downsample recipes
## bake.step_upsample recipes
## prep.step_downsample recipes
## prep.step_upsample recipes
## tidy.step_downsample recipes
## tidy.step_upsample recipes
##
## Attaching package: 'themis'
## The following objects are masked from 'package:recipes':
##
## step_downsample, step_upsample, tunable.step_downsample,
## tunable.step_upsample
volcano_rec <- recipe(volcano_type ~ ., data = volcano_df) %>%
  update_role(volcano_number, new_role = 'id') %>%
  step_other(tectonic_settings) %>%
  step_other(major_rock_1) %>%
  step_dummy(tectonic_settings, major_rock_1) %>%
  step_zv(all_predictors()) %>%
  step_normalize(all_predictors()) %>%
  step_smote(volcano_type)
Let’s walk through the steps in this recipe.
First, we must tell the recipe() what our model is going to be (using a formula here) and what data we are using.
Next, we update the role for volcano number, since this is a variable we want to keep around for convenience as an identifier for rows but is not a predictor or outcome.
There are a lot of different tectonic settings and rock types in this dataset, so let’s collapse some of the less frequently occurring levels into an “other” category, for each predictor.
Next, we create indicator variables and remove variables with zero variance.
Before oversampling, we center and scale (i.e., normalize) all the predictors, since SMOTE finds nearest neighbors using distances between rows and those distances should not be dominated by predictors on large scales, like elevation.
Finally, we implement SMOTE oversampling so that the volcano types are balanced!
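To make the middle steps concrete, here is a minimal base R sketch of what step_other(), step_dummy(), step_zv(), and step_normalize() each do, on made-up toy data. It is not how recipes implements them: in a recipe, the levels to keep and the means and standard deviations are estimated once from the training data at prep() time and then reused on new data, which this one-shot sketch does not model. The names `df` and `collapse_rare()` are hypothetical, and the threshold is raised from step_other()’s default of 0.05 to 0.15 only because the toy data has ten rows.

```r
# Made-up toy data for illustration
df <- data.frame(
  rock = factor(c("Andesite", "Andesite", "Basalt", "Basalt", "Basalt",
                  "Dacite", "Andesite", "Basalt", "Phonolite", "Andesite")),
  elevation = c(1200, 3400, 800, 150, 2900, 1800, 2500, 600, 3100, 2000),
  constant = 1
)

# Like step_other(): pool levels rarer than `threshold` into "other"
collapse_rare <- function(f, threshold = 0.15) {
  freq <- table(f) / length(f)
  keep <- names(freq)[freq >= threshold]
  factor(ifelse(as.character(f) %in% keep, as.character(f), "other"))
}
df$rock <- collapse_rare(df$rock)   # Dacite and Phonolite become "other"

# Like step_dummy(): indicator columns, with the first level as the reference
X <- model.matrix(~ rock + elevation + constant, data = df)[, -1]

# Like step_zv(): drop columns with zero variance ("constant" goes away)
X <- X[, apply(X, 2, var) > 0, drop = FALSE]

# Like step_normalize(): center to mean 0 and scale to standard deviation 1
X <- scale(X)
colnames(X)
```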
volcano_prep <- prep(volcano_rec)
juice(volcano_prep)
## # A tibble: 1,383 x 14
## volcano_number latitude longitude elevation volcano_type tectonic_settin~
## <dbl> <dbl> <dbl> <dbl> <fct> <dbl>
## 1 213004 0.746 0.101 -0.131 Other -0.289
## 2 284141 0.172 1.11 -1.39 Other -0.289
## 3 282080 0.526 0.975 -0.535 Other -0.289
## 4 285070 0.899 1.10 -0.263 Other -0.289
## 5 320020 1.44 -1.45 0.250 Other -0.289
## 6 221060 -0.0377 0.155 -0.920 Other -0.289
## 7 273088 0.0739 0.888 0.330 Other -0.289
## 8 266020 -0.451 0.918 -0.0514 Other -0.289
## 9 233011 -0.873 0.233 -0.280 Other -0.289
## 10 257040 -0.989 1.32 -0.380 Other -0.289
## # ... with 1,373 more rows, and 8 more variables:
## # tectonic_settings_Rift.zone...Oceanic.crust....15.km. <dbl>,
## # tectonic_settings_Subduction.zone...Continental.crust...25.km. <dbl>,
## # tectonic_settings_Subduction.zone...Oceanic.crust....15.km. <dbl>,
## # tectonic_settings_other <dbl>, major_rock_1_Basalt...Picro.Basalt <dbl>,
## # major_rock_1_Dacite <dbl>,
## # major_rock_1_Trachybasalt...Tephrite.Basanite <dbl>,
## # major_rock_1_other <dbl>
Before using prep(), these steps have been defined but not actually run or implemented. The prep() function is where everything gets evaluated. You can use juice() to get the preprocessed data back out and check on your results.
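The juiced data has 1,383 rows, which is exactly 3 × 461: SMOTE has grown the Other and Shield classes up to the size of the majority Stratovolcano class by adding 82 + 343 = 425 synthetic rows. SMOTE (Synthetic Minority Over-sampling Technique) makes each new row by picking a minority-class row, picking one of its k nearest neighbors within the same class, and placing a point at a random spot on the line segment between them. Below is a minimal base R sketch of that classic idea on made-up, already-normalized data. It is not the themis implementation, and `minority` and `smote_sketch()` are hypothetical names.

```r
set.seed(1)

# Hypothetical minority-class rows, already normalized (10 rows, 2 predictors)
minority <- matrix(rnorm(20), ncol = 2,
                   dimnames = list(NULL, c("latitude", "longitude")))

# Create n_new synthetic rows by interpolating between a randomly chosen
# minority row and one of its k nearest minority-class neighbors
smote_sketch <- function(x, n_new, k = 5) {
  d <- as.matrix(dist(x))          # Euclidean distances between rows
  diag(d) <- Inf                   # a row is not its own neighbor
  k <- min(k, nrow(x) - 1)
  out <- matrix(NA_real_, nrow = n_new, ncol = ncol(x),
                dimnames = list(NULL, colnames(x)))
  for (i in seq_len(n_new)) {
    a <- sample.int(nrow(x), 1)
    neighbors <- order(d[a, ])[seq_len(k)]
    b <- neighbors[sample.int(k, 1)]
    out[i, ] <- x[a, ] + runif(1) * (x[b, ] - x[a, ])
  }
  out
}

synthetic <- smote_sketch(minority, n_new = 15)

# Rows added in the volcano data: every class is grown to the majority count
counts <- c(Stratovolcano = 461, Other = 379, Shield = 118)
sum(max(counts) - counts)                   # 425 synthetic rows
sum(counts) + sum(max(counts) - counts)     # 1383, as in the juice() output
```

Since every synthetic point sits between two real minority rows, unscaled predictors would dominate the neighbor search, which is why the recipe normalizes before step_smote().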
Now it’s time to specify our model. I am using a workflow() in this example for convenience; these are objects that can help you manage modeling pipelines more easily, with pieces that fit together like Lego blocks. This workflow() contains both the recipe and the model, a random forest classifier. The ranger implementation for random forest can handle multinomial classification without any special handling.
rf_spec <- rand_forest(trees = 1000) %>%
  set_mode('classification') %>%
  set_engine('ranger')

volcano_wf <- workflow() %>%
  add_recipe(volcano_rec) %>%
  add_model(rf_spec)
volcano_wf
## == Workflow ===================
## Preprocessor: Recipe
## Model: rand_forest()
##
## -- Preprocessor ---------------
## 6 Recipe Steps
##
## * step_other()
## * step_other()
## * step_dummy()
## * step_zv()
## * step_normalize()
## * step_smote()
##
## -- Model ----------------------
## Random Forest Model Specification (classification)
##
## Main Arguments:
## trees = 1000
##
## Computational engine: ranger
Now we can fit our workflow to our resamples.
volcano_res <- fit_resamples(
  volcano_wf,
  resamples = volcano_boot,
  control = control_resamples(save_pred = TRUE)
)
One of the biggest differences when working with multiclass problems is that your performance metrics are different. The yardstick package provides implementations for many multiclass metrics.
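In the output below, the roc_auc row is labeled hand_till: for more than two classes, yardstick’s default ROC AUC estimator is the Hand and Till generalization, which computes a pairwise AUC for every pair of classes and averages them. A from-scratch base R sketch of that definition follows; `auc_binary()` and `hand_till()` are hypothetical helper names, not yardstick functions, and the toy probabilities are made up.

```r
# Binary AUC via the Mann-Whitney statistic: the probability that a random
# positive scores higher than a random negative (ties count one half)
auc_binary <- function(score, is_pos) {
  pos <- score[is_pos]
  neg <- score[!is_pos]
  mean(outer(pos, neg, ">") + 0.5 * outer(pos, neg, "=="))
}

# Hand & Till multiclass AUC: average of A(i, j) = (A(i|j) + A(j|i)) / 2
# over all pairs of classes, using only rows whose true class is i or j
hand_till <- function(truth, probs) {
  pairs <- combn(levels(truth), 2)
  pair_auc <- apply(pairs, 2, function(p) {
    keep <- truth %in% p
    a_ij <- auc_binary(probs[keep, p[1]], truth[keep] == p[1])
    a_ji <- auc_binary(probs[keep, p[2]], truth[keep] == p[2])
    (a_ij + a_ji) / 2
  })
  mean(pair_auc)
}

truth <- factor(c("Other", "Shield", "Stratovolcano"))

# Each row puts all probability on its true class
perfect <- diag(3)
colnames(perfect) <- levels(truth)
hand_till(truth, perfect)         # 1

# Every class gets the same probability, so every comparison is a tie
uninformative <- matrix(1 / 3, nrow = 3, ncol = 3,
                        dimnames = list(NULL, levels(truth)))
hand_till(truth, uninformative)   # 0.5
```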
volcano_res %>% collect_metrics()
## # A tibble: 2 x 5
## .metric .estimator mean n std_err
## <chr> <chr> <dbl> <int> <dbl>
## 1 accuracy multiclass 0.646 25 0.00453
## 2 roc_auc hand_till 0.788 25 0.00296
We can create a confusion matrix to see how the different classes did.
volcano_res %>% collect_predictions() %>% conf_mat(volcano_type, .pred_class)
## Truth
## Prediction Other Shield Stratovolcano
## Other 1999 349 867
## Shield 264 570 227
## Stratovolcano 1259 198 3207
Even after using SMOTE oversampling, the stratovolcanoes are easiest to identify.
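The per-class numbers behind that statement can be read straight off the confusion matrix. Here is a base R sketch that re-enters the counts printed above by hand. Because conf_mat() here pooled the predictions from all 25 resamples, these figures summarize all 8,940 held-out predictions at once rather than averaging over resamples the way collect_metrics() does.

```r
# Rows = predicted class, columns = true class, as in the conf_mat() output
cm <- matrix(c(1999, 349,  867,
                264, 570,  227,
               1259, 198, 3207),
             nrow = 3, byrow = TRUE,
             dimnames = list(Prediction = c("Other", "Shield", "Stratovolcano"),
                             Truth      = c("Other", "Shield", "Stratovolcano")))

accuracy  <- sum(diag(cm)) / sum(cm)   # about 0.646
recall    <- diag(cm) / colSums(cm)    # per true class (sensitivity)
precision <- diag(cm) / rowSums(cm)    # per predicted class (PPV)
macro_ppv <- mean(precision)           # unweighted average, about 0.616

round(recall, 3)
```

Stratovolcanoes have by far the highest recall (about 0.75, versus roughly 0.57 for Other and 0.51 for Shield). The macro-averaged PPV is the kind of number ppv() reports per resample below.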
We computed accuracy and AUC during fit_resamples(), but we can always go back and compute other metrics we are interested in if we saved the predictions. We can even group_by() resample, if we like!
volcano_res %>%
  collect_predictions() %>%
  group_by(id) %>%
  ppv(volcano_type, .pred_class)
## # A tibble: 25 x 4
## id .metric .estimator .estimate
## <chr> <chr> <chr> <dbl>
## 1 Bootstrap01 ppv macro 0.645
## 2 Bootstrap02 ppv macro 0.586
## 3 Bootstrap03 ppv macro 0.613
## 4 Bootstrap04 ppv macro 0.606
## 5 Bootstrap05 ppv macro 0.600
## 6 Bootstrap06 ppv macro 0.613
## 7 Bootstrap07 ppv macro 0.646
## 8 Bootstrap08 ppv macro 0.619
## 9 Bootstrap09 ppv macro 0.670
## 10 Bootstrap10 ppv macro 0.561
## # ... with 15 more rows
What can we learn about variable importance, using the vip package?
library(vip)
## Warning: package 'vip' was built under R version 4.0.2
##
## Attaching package: 'vip'
## The following object is masked from 'package:utils':
##
## vi
rf_spec %>%
  set_engine('ranger', importance = 'permutation') %>%
  fit(
    volcano_type ~ .,
    data = juice(volcano_prep) %>%
      select(-volcano_number) %>%
      janitor::clean_names()
  ) %>%
  vip(geom = 'point')
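With importance = 'permutation', the importance of a predictor is roughly how much worse the model predicts when that predictor’s values are shuffled, breaking its link to the outcome. Below is a minimal base R sketch of the idea. It assumes a logistic regression on simulated data instead of a random forest and measures the accuracy drop on the training data; ranger computes its version on out-of-bag samples tree by tree, so its numbers are obtained differently.

```r
set.seed(42)

# Simulated data: x1 drives the outcome, x2 is pure noise
n  <- 500
x1 <- rnorm(n)
x2 <- rnorm(n)
y  <- factor(ifelse(x1 + rnorm(n, sd = 0.5) > 0, "yes", "no"))
d  <- data.frame(y, x1, x2)

fit <- glm(y ~ x1 + x2, data = d, family = binomial)

# Accuracy of the fitted model on a given data frame
accuracy_on <- function(data) {
  p <- predict(fit, newdata = data, type = "response")   # P(y == "yes")
  mean(ifelse(p > 0.5, "yes", "no") == data$y)
}

baseline <- accuracy_on(d)

# Permutation importance: accuracy lost when one predictor is shuffled
perm_importance <- sapply(c("x1", "x2"), function(v) {
  shuffled <- d
  shuffled[[v]] <- sample(shuffled[[v]])
  baseline - accuracy_on(shuffled)
})
perm_importance
```

Shuffling x1 costs a lot of accuracy while shuffling x2 costs almost none, which is the same kind of contrast the vip plot shows between the spatial predictors and the rest.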
The spatial information is really important for the model, followed by the presence of basalt. Let’s explore the spatial information a bit further, and make a map showing how right or wrong our modeling is across the world. Let’s join the predictions back to the original data.
volcano_pred <- volcano_res %>%
  collect_predictions() %>%
  mutate(correct = volcano_type == .pred_class) %>%
  left_join(volcano_df %>% mutate(.row = row_number()))
## Joining, by = c(".row", "volcano_type")
volcano_pred
## # A tibble: 8,940 x 14
## id .pred_Other .pred_Shield .pred_Stratovol~ .row .pred_class
## <chr> <dbl> <dbl> <dbl> <int> <fct>
## 1 Boot~ 0.584 0.122 0.294 1 Other
## 2 Boot~ 0.255 0.0930 0.652 5 Stratovolc~
## 3 Boot~ 0.165 0.0724 0.762 6 Stratovolc~
## 4 Boot~ 0.567 0.108 0.325 15 Other
## 5 Boot~ 0.465 0.0528 0.482 16 Stratovolc~
## 6 Boot~ 0.228 0.0511 0.721 18 Stratovolc~
## 7 Boot~ 0.124 0.603 0.273 26 Shield
## 8 Boot~ 0.214 0.0672 0.719 27 Stratovolc~
## 9 Boot~ 0.391 0.0301 0.579 30 Stratovolc~
## 10 Boot~ 0.217 0.523 0.260 32 Shield
## # ... with 8,930 more rows, and 8 more variables: volcano_type <fct>,
## # correct <lgl>, volcano_number <dbl>, latitude <dbl>, longitude <dbl>,
## # elevation <dbl>, tectonic_settings <fct>, major_rock_1 <fct>
Then, let’s make a map using stat_summary_hex(). Within each hexagon, let’s take the mean of correct to find what percentage of volcanoes were classified correctly, across all our bootstrap resamples.
ggplot() +
  geom_map(
    data = world, map = world,
    aes(long, lat, map_id = region),
    color = 'white', fill = 'gray90', size = 0.05, alpha = 0.5
  ) +
  stat_summary_hex(
    data = volcano_pred,
    aes(longitude, latitude, z = as.integer(correct)),
    fun = 'mean',
    alpha = 0.7, bins = 50
  ) +
  scale_fill_gradient(high = 'cyan', labels = scales::percent) +
  theme_void(base_family = 'IBMPlexSans') +
  labs(x = NULL, y = NULL, fill = 'Percent classified\ncorrectly')
## Warning: Ignoring unknown aesthetics: x, y
## Warning in grid.Call(C_stringMetric, as.graphicsAnnot(x$label)): font family not
## found in Windows font database
## Warning in grid.Call(C_textBounds, as.graphicsAnnot(x$label), x$x, x$y, : font
## family not found in Windows font database
## Warning in grid.Call(C_textBounds, as.graphicsAnnot(x$label), x$x, x$y, : font
## family not found in Windows font database
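The map works because as.integer(correct) is 1 for a correct prediction and 0 otherwise, so the mean within a bin is the share of predictions in that bin that were correct. Here is a base R sketch of the same summary. It swaps the hexagons for simple 30 × 30 degree square cells and uses a few made-up points (the data frame `toy` is hypothetical).

```r
# Made-up predictions: four points in one cell, two in another
toy <- data.frame(
  longitude = c(10, 12, 15, 20, -100, -105),
  latitude  = c(5, 6, 7, 8, 40, 45),
  correct   = c(TRUE, FALSE, TRUE, FALSE, TRUE, TRUE)
)

# Assign each point to a 30 x 30 degree grid cell
cell <- interaction(
  cut(toy$longitude, breaks = seq(-180, 180, by = 30)),
  cut(toy$latitude,  breaks = seq(-90, 90, by = 30)),
  drop = TRUE
)

# Mean of a 0/1 column = proportion classified correctly in each cell
pct_correct <- tapply(as.integer(toy$correct), cell, mean)
pct_correct   # one cell at 0.5, the other at 1
```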