About this document

This document rewrites the R programs from Chapter 11 of HSDM in a tidy style. Data manipulation is done with the tidyverse, model building and training with tidymodels, and interpretation of the fitted models with DALEX.

library(tidyverse)
## -- Attaching packages --------------------------------------- tidyverse 1.3.0 --
## v ggplot2 3.3.2     v purrr   0.3.4
## v tibble  3.0.4     v dplyr   1.0.2
## v tidyr   1.1.2     v stringr 1.4.0
## v readr   1.4.0     v forcats 0.5.0
## -- Conflicts ------------------------------------------ tidyverse_conflicts() --
## x dplyr::filter() masks stats::filter()
## x dplyr::lag()    masks stats::lag()
library(tidymodels)
## -- Attaching packages -------------------------------------- tidymodels 0.1.2 --
## v broom     0.7.2      v recipes   0.1.15
## v dials     0.0.9      v rsample   0.0.8 
## v infer     0.5.4      v tune      0.1.2 
## v modeldata 0.1.0      v workflows 0.2.1 
## v parsnip   0.1.5      v yardstick 0.0.7
## -- Conflicts ----------------------------------------- tidymodels_conflicts() --
## x scales::discard() masks purrr::discard()
## x dplyr::filter()   masks stats::filter()
## x recipes::fixed()  masks stringr::fixed()
## x dplyr::lag()      masks stats::lag()
## x yardstick::spec() masks readr::spec()
## x recipes::step()   masks stats::step()
library(sf)
## Linking to GEOS 3.8.1, GDAL 3.0.4, PROJ 6.3.2
library(gridExtra)
## 
## Attaching package: 'gridExtra'
## The following object is masked from 'package:dplyr':
## 
##     combine
library(discrim)
## 
## Attaching package: 'discrim'
## The following object is masked from 'package:dials':
## 
##     smoothness
library(DALEX)
## Welcome to DALEX (version: 2.1.1).
## Find examples and detailed introduction at: http://ema.drwhy.ai/
## Additional features will be available after installation of: ggpubr.
## Use 'install_dependencies()' to get all suggested dependencies
## 
## Attaching package: 'DALEX'
## The following object is masked from 'package:dplyr':
## 
##     explain

Data preparation

First, load the data and split it into training and test sets.

mammals_data <- read_csv(
  "data/tabular/species/mammals_and_bioclim_table.csv"
  ) %>%
  rename(x = X_WGS84, y = Y_WGS84) %>%
  select(x, y, VulpesVulpes, bio3, bio7, bio11, bio12)
## Warning: Missing column names filled in: 'X1' [1]
## 
## -- Column specification --------------------------------------------------------
## cols(
##   X1 = col_double(),
##   X_WGS84 = col_double(),
##   Y_WGS84 = col_double(),
##   ConnochaetesGnou = col_double(),
##   GuloGulo = col_double(),
##   PantheraOnca = col_double(),
##   PteropusGiganteus = col_double(),
##   TenrecEcaudatus = col_double(),
##   VulpesVulpes = col_double(),
##   bio3 = col_double(),
##   bio4 = col_double(),
##   bio7 = col_double(),
##   bio11 = col_double(),
##   bio12 = col_double()
## )
split_data <- mammals_data %>%
  select(-x, -y) %>%
  initial_split(prop = 0.8)

train_data <- training(split_data)
test_data <- testing(split_data)

Next, define a recipe for feature engineering. Here, each predictor is normalized and the response variable is converted to a factor.

rec <- recipe(VulpesVulpes ~ bio3 + bio7 + bio11 + bio12,
              data = train_data) %>%
  step_normalize(all_predictors()) %>%
  step_bin2factor(VulpesVulpes)

# Fit the recipe to train_data. Estimates such as the normalization statistics depend on the data they are computed from, so they are fixed here.
rec_preped <- rec %>% prep(train_data)
# The processed train_data can be extracted with juice().
rec_preped %>% juice()
## # A tibble: 6,834 x 5
##     bio3  bio7 bio11  bio12 VulpesVulpes
##    <dbl> <dbl> <dbl>  <dbl> <fct>       
##  1 -1.28 1.14  -1.81 -0.923 no          
##  2 -1.28 1.05  -1.81 -0.885 no          
##  3 -1.30 0.950 -1.84 -0.823 no          
##  4 -1.30 0.898 -1.82 -0.806 no          
##  5 -1.24 0.845 -1.71 -0.862 no          
##  6 -1.24 0.790 -1.70 -0.838 no          
##  7 -1.23 0.736 -1.68 -0.818 no          
##  8 -1.28 0.678 -1.76 -0.735 no          
##  9 -1.26 0.645 -1.72 -0.740 no          
## 10 -1.20 0.628 -1.60 -0.804 no          
## # ... with 6,824 more rows
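
Because the recipe was prepped once on train_data, the same learned statistics are reused when transforming any new data. bake() applies them, and the per-variable statistics stored by step_normalize() can be inspected with recipes' tidy() method. A quick check, not part of the original workflow:

# Apply the transformations learned from train_data to the test set.
bake(rec_preped, new_data = test_data)
# Inspect the means/SDs stored by step_normalize() (step 1 of the recipe).
tidy(rec_preped, number = 1)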

To tune the hyperparameters, create the folds for cross-validation.

cv_data <- vfold_cv(train_data, v = 10) %>%
  mutate(recipes = map(splits, prepper, recipe = rec))
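
Each row of cv_data holds one resample; rsample's analysis() and assessment() extract the in-fold and held-out parts, which is handy for a sanity check. (Note that tune_grid() preprocesses within each resample on its own, so the recipes column above is mainly for reference.) Illustrative only:

# Sizes of the training and validation parts of the first fold.
cv_data$splits[[1]] %>% analysis() %>% dim()
cv_data$splits[[1]] %>% assessment() %>% dim()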

A function to plot the predicted results alongside the original data.

plot_predicted <- function(model) {
  # Bake the full data set, attach predictions and coordinates,
  # and convert to an sf object for mapping.
  p <- bake(rec_preped, mammals_data) %>%
    bind_cols(predict(model, .)) %>%
    mutate(x = mammals_data$x,
           y = mammals_data$y) %>%
    mutate(original_data = as.factor(VulpesVulpes)) %>%
    rename(predicted = .pred_class) %>%
    st_as_sf(coords = c("x", "y")) %>%
    st_set_crs(4326)
  # Plot the observed presence/absence and the predictions, stacked vertically.
  g1 <- p %>% ggplot() +
    geom_sf(aes(color = original_data))
  g2 <- p %>% ggplot() +
    geom_sf(aes(color = predicted))
  grid.arrange(g1, g2, nrow = 2)
}

Decision tree

Define the model. Give tune() as a placeholder for each hyperparameter to be tuned.

dt <- decision_tree(
  mode = "classification",
  cost_complexity = tune(),
  tree_depth = tune(),
  min_n = tune()
  ) %>%
  set_engine("rpart")

Perform hyperparameter tuning. Here, a grid search is used.

params <- list(
  cost_complexity(),
  tree_depth(),
  min_n()
)

### Grid searching
grid <- params %>%
  grid_regular(levels = 3)

tuned <- tune_grid(
  object = dt,
  preprocessor = rec,
  resamples = cv_data,
  grid = grid,
  metrics = metric_set(accuracy, kap),
  control = control_grid(verbose = FALSE)
)
## 
## Attaching package: 'rlang'
## The following objects are masked from 'package:purrr':
## 
##     %@%, as_function, flatten, flatten_chr, flatten_dbl, flatten_int,
##     flatten_lgl, flatten_raw, invoke, list_along, modify, prepend,
##     splice
## 
## Attaching package: 'vctrs'
## The following object is masked from 'package:dplyr':
## 
##     data_frame
## The following object is masked from 'package:tibble':
## 
##     data_frame
## 
## Attaching package: 'rpart'
## The following object is masked from 'package:dials':
## 
##     prune
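
Before selecting a single candidate, the full grid results can be reviewed; tune's show_best() and autoplot() summarize the results. A quick look at the same tuned object:

# Top five candidates by accuracy, and metric profiles across the grid.
tuned %>% show_best("accuracy", n = 5)
autoplot(tuned)
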
best_params <- tuned %>%
  select_best("accuracy") %>%
  select(cost_complexity, tree_depth, min_n)

best_params
## # A tibble: 1 x 3
##   cost_complexity tree_depth min_n
##             <dbl>      <int> <int>
## 1    0.0000000001         15    21

Build a model with the tuned hyperparameters and retrain it on the whole training data.

dt_tuned <- update(dt, best_params)

trained <- dt_tuned %>%
  fit(VulpesVulpes ~ .,
      data = rec_preped %>%
        juice())

### Model testing
bake(rec_preped, test_data) %>%
  bind_cols(predict(trained, .)) %>%
  metrics(VulpesVulpes, .pred_class)
## # A tibble: 2 x 3
##   .metric  .estimator .estimate
##   <chr>    <chr>          <dbl>
## 1 accuracy binary         0.917
## 2 kap      binary         0.834
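
Beyond accuracy and kappa, a confusion matrix shows how the errors split across classes. A minimal sketch using yardstick's conf_mat():

# Cross-tabulate truth against prediction on the test set.
bake(rec_preped, test_data) %>%
  bind_cols(predict(trained, .)) %>%
  conf_mat(VulpesVulpes, .pred_class)
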
plot_predicted(trained)

plot(trained$fit)
text(trained$fit, cex = 0.8)

Model interpretation. Feature Importance is covered in detail here, and Partial Dependence Plots here.

explainer <- trained %>%
  explain(
    data = bake(rec_preped, mammals_data) %>%
      select(-VulpesVulpes),
    y = mammals_data %>% pull(VulpesVulpes),
    label = "Decision Tree")
## Preparation of a new explainer is initiated
##   -> model label       :  Decision Tree 
##   -> data              :  8542  rows  4  cols 
##   -> data              :  tibble converted into a data.frame 
##   -> target variable   :  8542  values 
##   -> predict function  :  yhat.model_fit  will be used (  default  )
##   -> predicted values  :  No value for predict function target column. (  default  )
##   -> model_info        :  package parsnip , ver. 0.1.5 , task classification (  default  ) 
##   -> predicted values  :  numerical, min =  0 , mean =  0.5056124 , max =  1  
##   -> residual function :  difference between y and yhat (  default  )
##   -> residuals         :  numerical, min =  -1 , mean =  -0.01357311 , max =  1  
##   A new explainer has been created! 
explainer %>%
  ingredients::feature_importance(type = "ratio", B = 1) %>%
  plot()

explainer %>%
  ingredients::partial_dependency(variable_type = "numerical") %>%
  plot()
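
The same diagnostics are also exposed through DALEX itself: model_parts() and model_profile() largely wrap the ingredients functions used above. An equivalent route, shown for reference:

# Permutation importance and partial dependence via the DALEX wrappers.
explainer %>% model_parts(type = "ratio", B = 1) %>% plot()
explainer %>% model_profile() %>% plot()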

Flexible Discriminant Analysis

fda <- discrim_flexible(
  mode = "classification",
  num_terms = tune(),
  prod_degree = tune(),
  prune_method = tune()
) %>%
  set_engine("earth")

## Train
### Hyper-parameters to be tuned
params <- list(
  finalize(num_terms(), rec_preped %>% juice()),
  prod_degree(),
  prune_method()
)

### Grid searching
grid <- params %>%
  grid_regular(levels = 3)

tuned <- tune_grid(
  object = fda,
  preprocessor = rec,
  resamples = cv_data,
  grid = grid,
  metrics = metric_set(accuracy, kap),
  control = control_grid(verbose = FALSE)
)
## Loading required package: Formula
## Loading required package: plotmo
## Loading required package: plotrix
## 
## Attaching package: 'plotrix'
## The following object is masked from 'package:scales':
## 
##     rescale
## Loading required package: TeachingDemos
## Loading required package: class
## Loaded mda 0.5-2
## 
## Attaching package: 'mda'
## The following object is masked from 'package:parsnip':
## 
##     mars
## x Fold01: preprocessor 1/1, model 13/18: Error: Fortran routine XHAUST returned er...
## x Fold01: preprocessor 1/1, model 16/18: Error: Fortran routine XHAUST returned er...
## ! Fold02: preprocessor 1/1, model 1/18: degenerate problem; no discrimination
## ! Fold02: preprocessor 1/1, model 1/18 (predictions): no non-missing arguments to ...
## x Fold02: preprocessor 1/1, model 1/18 (predictions): Error in terms.default(objec...
## ! Fold02: preprocessor 1/1, model 4/18: degenerate problem; no discrimination
## ! Fold02: preprocessor 1/1, model 4/18 (predictions): no non-missing arguments to ...
## x Fold02: preprocessor 1/1, model 4/18 (predictions): Error in terms.default(objec...
## ! Fold02: preprocessor 1/1, model 7/18: degenerate problem; no discrimination
## ! Fold02: preprocessor 1/1, model 7/18 (predictions): no non-missing arguments to ...
## x Fold02: preprocessor 1/1, model 7/18 (predictions): Error in terms.default(objec...
## ! Fold02: preprocessor 1/1, model 10/18: degenerate problem; no discrimination
## ! Fold02: preprocessor 1/1, model 10/18 (predictions): no non-missing arguments to...
## x Fold02: preprocessor 1/1, model 10/18 (predictions): Error in terms.default(obje...
## x Fold02: preprocessor 1/1, model 13/18: Error: Fortran routine XHAUST returned er...
## x Fold02: preprocessor 1/1, model 16/18: Error: Fortran routine XHAUST returned er...
## ! Fold03: preprocessor 1/1, model 1/18: degenerate problem; no discrimination
## ! Fold03: preprocessor 1/1, model 1/18 (predictions): no non-missing arguments to ...
## x Fold03: preprocessor 1/1, model 1/18 (predictions): Error in terms.default(objec...
## ! Fold03: preprocessor 1/1, model 4/18: degenerate problem; no discrimination
## ! Fold03: preprocessor 1/1, model 4/18 (predictions): no non-missing arguments to ...
## x Fold03: preprocessor 1/1, model 4/18 (predictions): Error in terms.default(objec...
## ! Fold03: preprocessor 1/1, model 7/18: degenerate problem; no discrimination
## ! Fold03: preprocessor 1/1, model 7/18 (predictions): no non-missing arguments to ...
## x Fold03: preprocessor 1/1, model 7/18 (predictions): Error in terms.default(objec...
## ! Fold03: preprocessor 1/1, model 10/18: degenerate problem; no discrimination
## ! Fold03: preprocessor 1/1, model 10/18 (predictions): no non-missing arguments to...
## x Fold03: preprocessor 1/1, model 10/18 (predictions): Error in terms.default(obje...
## x Fold03: preprocessor 1/1, model 13/18: Error: Fortran routine XHAUST returned er...
## x Fold03: preprocessor 1/1, model 16/18: Error: Fortran routine XHAUST returned er...
## ! Fold04: preprocessor 1/1, model 1/18: degenerate problem; no discrimination
## ! Fold04: preprocessor 1/1, model 1/18 (predictions): no non-missing arguments to ...
## x Fold04: preprocessor 1/1, model 1/18 (predictions): Error in terms.default(objec...
## ! Fold04: preprocessor 1/1, model 4/18: degenerate problem; no discrimination
## ! Fold04: preprocessor 1/1, model 4/18 (predictions): no non-missing arguments to ...
## x Fold04: preprocessor 1/1, model 4/18 (predictions): Error in terms.default(objec...
## ! Fold04: preprocessor 1/1, model 7/18: degenerate problem; no discrimination
## ! Fold04: preprocessor 1/1, model 7/18 (predictions): no non-missing arguments to ...
## x Fold04: preprocessor 1/1, model 7/18 (predictions): Error in terms.default(objec...
## ! Fold04: preprocessor 1/1, model 10/18: degenerate problem; no discrimination
## ! Fold04: preprocessor 1/1, model 10/18 (predictions): no non-missing arguments to...
## x Fold04: preprocessor 1/1, model 10/18 (predictions): Error in terms.default(obje...
## x Fold04: preprocessor 1/1, model 13/18: Error: Fortran routine XHAUST returned er...
## x Fold04: preprocessor 1/1, model 16/18: Error: Fortran routine XHAUST returned er...
## x Fold05: preprocessor 1/1, model 13/18: Error: Fortran routine XHAUST returned er...
## x Fold05: preprocessor 1/1, model 16/18: Error: Fortran routine XHAUST returned er...
## ! Fold06: preprocessor 1/1, model 1/18: degenerate problem; no discrimination
## ! Fold06: preprocessor 1/1, model 1/18 (predictions): no non-missing arguments to ...
## x Fold06: preprocessor 1/1, model 1/18 (predictions): Error in terms.default(objec...
## ! Fold06: preprocessor 1/1, model 4/18: degenerate problem; no discrimination
## ! Fold06: preprocessor 1/1, model 4/18 (predictions): no non-missing arguments to ...
## x Fold06: preprocessor 1/1, model 4/18 (predictions): Error in terms.default(objec...
## ! Fold06: preprocessor 1/1, model 7/18: degenerate problem; no discrimination
## ! Fold06: preprocessor 1/1, model 7/18 (predictions): no non-missing arguments to ...
## x Fold06: preprocessor 1/1, model 7/18 (predictions): Error in terms.default(objec...
## ! Fold06: preprocessor 1/1, model 10/18: degenerate problem; no discrimination
## ! Fold06: preprocessor 1/1, model 10/18 (predictions): no non-missing arguments to...
## x Fold06: preprocessor 1/1, model 10/18 (predictions): Error in terms.default(obje...
## x Fold06: preprocessor 1/1, model 13/18: Error: Fortran routine XHAUST returned er...
## x Fold06: preprocessor 1/1, model 16/18: Error: Fortran routine XHAUST returned er...
## x Fold07: preprocessor 1/1, model 13/18: Error: Fortran routine XHAUST returned er...
## x Fold07: preprocessor 1/1, model 16/18: Error: Fortran routine XHAUST returned er...
## ! Fold08: preprocessor 1/1, model 1/18: degenerate problem; no discrimination
## ! Fold08: preprocessor 1/1, model 1/18 (predictions): no non-missing arguments to ...
## x Fold08: preprocessor 1/1, model 1/18 (predictions): Error in terms.default(objec...
## ! Fold08: preprocessor 1/1, model 4/18: degenerate problem; no discrimination
## ! Fold08: preprocessor 1/1, model 4/18 (predictions): no non-missing arguments to ...
## x Fold08: preprocessor 1/1, model 4/18 (predictions): Error in terms.default(objec...
## ! Fold08: preprocessor 1/1, model 7/18: degenerate problem; no discrimination
## ! Fold08: preprocessor 1/1, model 7/18 (predictions): no non-missing arguments to ...
## x Fold08: preprocessor 1/1, model 7/18 (predictions): Error in terms.default(objec...
## ! Fold08: preprocessor 1/1, model 10/18: degenerate problem; no discrimination
## ! Fold08: preprocessor 1/1, model 10/18 (predictions): no non-missing arguments to...
## x Fold08: preprocessor 1/1, model 10/18 (predictions): Error in terms.default(obje...
## x Fold08: preprocessor 1/1, model 13/18: Error: Fortran routine XHAUST returned er...
## x Fold08: preprocessor 1/1, model 16/18: Error: Fortran routine XHAUST returned er...
## ! Fold09: preprocessor 1/1, model 1/18: degenerate problem; no discrimination
## ! Fold09: preprocessor 1/1, model 1/18 (predictions): no non-missing arguments to ...
## x Fold09: preprocessor 1/1, model 1/18 (predictions): Error in terms.default(objec...
## ! Fold09: preprocessor 1/1, model 4/18: degenerate problem; no discrimination
## ! Fold09: preprocessor 1/1, model 4/18 (predictions): no non-missing arguments to ...
## x Fold09: preprocessor 1/1, model 4/18 (predictions): Error in terms.default(objec...
## ! Fold09: preprocessor 1/1, model 7/18: degenerate problem; no discrimination
## ! Fold09: preprocessor 1/1, model 7/18 (predictions): no non-missing arguments to ...
## x Fold09: preprocessor 1/1, model 7/18 (predictions): Error in terms.default(objec...
## ! Fold09: preprocessor 1/1, model 10/18: degenerate problem; no discrimination
## ! Fold09: preprocessor 1/1, model 10/18 (predictions): no non-missing arguments to...
## x Fold09: preprocessor 1/1, model 10/18 (predictions): Error in terms.default(obje...
## x Fold09: preprocessor 1/1, model 13/18: Error: Fortran routine XHAUST returned er...
## x Fold09: preprocessor 1/1, model 16/18: Error: Fortran routine XHAUST returned er...
## x Fold10: preprocessor 1/1, model 13/18: Error: Fortran routine XHAUST returned er...
## x Fold10: preprocessor 1/1, model 16/18: Error: Fortran routine XHAUST returned er...
best_params <- tuned %>%
  select_best("accuracy") 

### Train with tuned parameters
# update() does not seem to work for this spec, so the model is re-specified with the selected values.

fda_tuned <- discrim_flexible(
  mode = "classification",
  num_terms = best_params$num_terms,
  prod_degree = best_params$prod_degree,
  prune_method = best_params$prune_method
) %>%
  set_engine("earth")

trained <- fda_tuned %>%
  fit(VulpesVulpes ~ .,
      data = rec_preped %>%
        juice())

### Model testing
bake(rec_preped, test_data) %>%
  bind_cols(predict(trained, .)) %>%
  metrics(VulpesVulpes, .pred_class)
## # A tibble: 2 x 3
##   .metric  .estimator .estimate
##   <chr>    <chr>          <dbl>
## 1 accuracy binary         0.898
## 2 kap      binary         0.796
plot_predicted(trained)

### Model explanation
explainer <- trained %>%
  explain(
    data = bake(rec_preped, mammals_data) %>%
      select(-VulpesVulpes),
    y = mammals_data %>% pull(VulpesVulpes),
    label = "Flexible Discriminant Analysis")
## Preparation of a new explainer is initiated
##   -> model label       :  Flexible Discriminant Analysis 
##   -> data              :  8542  rows  4  cols 
##   -> data              :  tibble converted into a data.frame 
##   -> target variable   :  8542  values 
##   -> predict function  :  yhat.model_fit  will be used (  default  )
##   -> predicted values  :  No value for predict function target column. (  default  )
##   -> model_info        :  package parsnip , ver. 0.1.5 , task classification (  default  ) 
##   -> predicted values  :  numerical, min =  9.002539e-05 , mean =  0.5170832 , max =  0.9999689  
##   -> residual function :  difference between y and yhat (  default  )
##   -> residuals         :  numerical, min =  -0.9999689 , mean =  -0.02504384 , max =  0.99991  
##   A new explainer has been created! 
explainer %>%
  ingredients::feature_importance(type = "ratio", B = 1) %>%
  plot()

explainer %>%
  ingredients::partial_dependency(variable_type = "numerical") %>%
  plot()

Multi-Layer Perceptron

# Neural Net
nn <- mlp(
  mode = "classification",
  hidden_units = tune(),
  penalty = tune(),
  epochs = 20  # kept small and fixed during tuning; raised to 100 for the final fit below
) %>%
  set_engine("nnet") %>%
  translate()

## Train
### Hyper-parameters to be tuned
params <- list(
  hidden_units(),
  penalty()
) 

### Grid searching
grid <- params %>%
  grid_regular(levels = 3)

tuned <- tune_grid(
  object = nn,
  preprocessor = rec,
  resamples = cv_data,
  grid = grid,
  metrics = metric_set(accuracy, kap),
  control = control_grid(verbose = FALSE)
)

best_params <- tuned %>%
  select_best("accuracy") %>%
  select(hidden_units, penalty)

### Train with tuned parameters
# update() does not work well here either; re-specify with the tuned values (or use finalize_model() as shown above).
nn_tuned <- mlp(
  mode = "classification",
  hidden_units = best_params$hidden_units,
  penalty = best_params$penalty,
  epochs = 100
) %>%
  set_engine("nnet") %>%
  translate()

trained <- nn_tuned %>%
  fit(VulpesVulpes ~ .,
      data = rec_preped %>%
        juice())

### Model testing
bake(rec_preped, test_data) %>%
  bind_cols(predict(trained, .)) %>%
  metrics(VulpesVulpes, .pred_class)
## # A tibble: 2 x 3
##   .metric  .estimator .estimate
##   <chr>    <chr>          <dbl>
## 1 accuracy binary         0.923
## 2 kap      binary         0.847
plot_predicted(trained)

### Model explanation
explainer <- trained %>%
  explain(
    data = bake(rec_preped, mammals_data) %>%
      select(-VulpesVulpes),
    y = mammals_data %>% pull(VulpesVulpes),
    label = "Multi Layer Perceptron")
## Preparation of a new explainer is initiated
##   -> model label       :  Multi Layer Perceptron 
##   -> data              :  8542  rows  4  cols 
##   -> data              :  tibble converted into a data.frame 
##   -> target variable   :  8542  values 
##   -> predict function  :  yhat.model_fit  will be used (  default  )
##   -> predicted values  :  No value for predict function target column. (  default  )
##   -> model_info        :  package parsnip , ver. 0.1.5 , task classification (  default  ) 
##   -> predicted values  :  numerical, min =  0.2689418 , mean =  0.5020381 , max =  0.7310586  
##   -> residual function :  difference between y and yhat (  default  )
##   -> residuals         :  numerical, min =  -0.7310586 , mean =  -0.009998775 , max =  0.7310582  
##   A new explainer has been created! 
explainer %>%
  ingredients::feature_importance(type = "ratio", B = 1) %>%
  plot()

explainer %>%
  ingredients::partial_dependency(variable_type = "numerical") %>%
  plot()