I rewrote the R programs from Chapter 11 of HSDM in a tidy style: the tidyverse for data manipulation, tidymodels for building and training the models, and DALEX for interpreting the fitted models.
library(tidyverse)
## -- Attaching packages --------------------------------------- tidyverse 1.3.0 --
## v ggplot2 3.3.2 v purrr 0.3.4
## v tibble 3.0.4 v dplyr 1.0.2
## v tidyr 1.1.2 v stringr 1.4.0
## v readr 1.4.0 v forcats 0.5.0
## -- Conflicts ------------------------------------------ tidyverse_conflicts() --
## x dplyr::filter() masks stats::filter()
## x dplyr::lag() masks stats::lag()
library(tidymodels)
## -- Attaching packages -------------------------------------- tidymodels 0.1.2 --
## v broom 0.7.2 v recipes 0.1.15
## v dials 0.0.9 v rsample 0.0.8
## v infer 0.5.4 v tune 0.1.2
## v modeldata 0.1.0 v workflows 0.2.1
## v parsnip 0.1.5 v yardstick 0.0.7
## -- Conflicts ----------------------------------------- tidymodels_conflicts() --
## x scales::discard() masks purrr::discard()
## x dplyr::filter() masks stats::filter()
## x recipes::fixed() masks stringr::fixed()
## x dplyr::lag() masks stats::lag()
## x yardstick::spec() masks readr::spec()
## x recipes::step() masks stats::step()
library(sf)
## Linking to GEOS 3.8.1, GDAL 3.0.4, PROJ 6.3.2
library(gridExtra)
##
## Attaching package: 'gridExtra'
## The following object is masked from 'package:dplyr':
##
## combine
library(discrim)
##
## Attaching package: 'discrim'
## The following object is masked from 'package:dials':
##
## smoothness
library(DALEX)
## Welcome to DALEX (version: 2.1.1).
## Find examples and detailed introduction at: http://ema.drwhy.ai/
## Additional features will be available after installation of: ggpubr.
## Use 'install_dependencies()' to get all suggested dependencies
##
## Attaching package: 'DALEX'
## The following object is masked from 'package:dplyr':
##
## explain
First, read the data and split it into training and test sets.
mammals_data <- read_csv(
"data/tabular/species/mammals_and_bioclim_table.csv"
) %>%
rename(x = X_WGS84, y = Y_WGS84) %>%
select(x, y, VulpesVulpes, bio3, bio7, bio11, bio12)
## Warning: Missing column names filled in: 'X1' [1]
##
## -- Column specification --------------------------------------------------------
## cols(
## X1 = col_double(),
## X_WGS84 = col_double(),
## Y_WGS84 = col_double(),
## ConnochaetesGnou = col_double(),
## GuloGulo = col_double(),
## PantheraOnca = col_double(),
## PteropusGiganteus = col_double(),
## TenrecEcaudatus = col_double(),
## VulpesVulpes = col_double(),
## bio3 = col_double(),
## bio4 = col_double(),
## bio7 = col_double(),
## bio11 = col_double(),
## bio12 = col_double()
## )
split_data <- mammals_data %>%
select(-x, -y) %>%
  initial_split(prop = 0.8)
train_data <- training(split_data)
test_data <- testing(split_data)
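Note that initial_split() samples rows at random, so the exact split (and every result below) varies from run to run. A minimal sketch to make it reproducible; the seed value is arbitrary:
# Run this before initial_split(); any fixed value works.
set.seed(42)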
Next, define a recipe for the feature engineering. Here each predictor is normalized and the response variable is converted to a factor.
rec <- recipe(VulpesVulpes ~ bio3 + bio7 + bio11 + bio12,
data = train_data) %>%
step_normalize(all_predictors()) %>%
step_bin2factor(VulpesVulpes)
# Fit (prep) the recipe on train_data. Quantities such as the normalization statistics depend on the data, so they are fixed here.
rec_preped <- rec %>% prep(train_data)
# The preprocessed train_data is extracted with juice().
rec_preped %>% juice()
## # A tibble: 6,834 x 5
## bio3 bio7 bio11 bio12 VulpesVulpes
## <dbl> <dbl> <dbl> <dbl> <fct>
## 1 -1.28 1.14 -1.81 -0.923 no
## 2 -1.28 1.05 -1.81 -0.885 no
## 3 -1.30 0.950 -1.84 -0.823 no
## 4 -1.30 0.898 -1.82 -0.806 no
## 5 -1.24 0.845 -1.71 -0.862 no
## 6 -1.24 0.790 -1.70 -0.838 no
## 7 -1.23 0.736 -1.68 -0.818 no
## 8 -1.28 0.678 -1.76 -0.735 no
## 9 -1.26 0.645 -1.72 -0.740 no
## 10 -1.20 0.628 -1.60 -0.804 no
## # ... with 6,824 more rows
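The same frozen preprocessing can be applied to any new data with bake(); juice() is effectively a shortcut for the training data passed to prep(). A one-line sketch:
# Apply the prepped recipe to data it was not trained on
bake(rec_preped, new_data = test_data)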
To tune the hyperparameters, create the folds for cross-validation.
cv_data <- vfold_cv(train_data, v = 10) %>%
  mutate(recipes = map(splits, prepper, recipe = rec))
# (tune_grid() preps the recipe within each fold on its own, so the
# recipes column above is kept only for reference.)
A function to plot the original data and the predictions.
plot_predicted <- function(model) {
p <- bake(rec_preped, mammals_data) %>%
bind_cols(predict(model, .)) %>%
mutate(x = mammals_data$x,
y = mammals_data$y) %>%
mutate(original_data = as.factor(VulpesVulpes)) %>%
rename(predicted = .pred_class) %>%
st_as_sf(coords = c("x", "y")) %>%
    st_set_crs(4326)
g1 <- p %>% ggplot() +
geom_sf(aes(color = original_data))
g2 <- p %>% ggplot() +
geom_sf(aes(color = predicted))
grid.arrange(g1, g2, nrow = 2)
}
Specify the model. Hyperparameters to be tuned are given tune() as a placeholder.
dt <- decision_tree(
mode = "classification",
cost_complexity = tune(),
tree_depth = tune(),
min_n = tune()
) %>%
set_engine("rpart")
Now tune the hyperparameters; here a grid search is used.
params <- list(
cost_complexity(),
tree_depth(),
min_n()
)
### Grid searching
# 3 levels for each of the 3 parameters -> 27 candidate combinations
grid <- params %>%
  grid_regular(levels = 3)
tuned <- tune_grid(
object = dt,
preprocessor = rec,
resamples = cv_data,
grid = grid,
metrics = metric_set(accuracy, kap),
  control = control_grid(verbose = FALSE)
)
##
## Attaching package: 'rlang'
## The following objects are masked from 'package:purrr':
##
## %@%, as_function, flatten, flatten_chr, flatten_dbl, flatten_int,
## flatten_lgl, flatten_raw, invoke, list_along, modify, prepend,
## splice
##
## Attaching package: 'vctrs'
## The following object is masked from 'package:dplyr':
##
## data_frame
## The following object is masked from 'package:tibble':
##
## data_frame
##
## Attaching package: 'rpart'
## The following object is masked from 'package:dials':
##
## prune
best_params <- tuned %>%
select_best("accuracy") %>%
select(cost_complexity, tree_depth, min_n)
best_params
## # A tibble: 1 x 3
## cost_complexity tree_depth min_n
## <dbl> <int> <int>
## 1 0.0000000001 15 21
Rebuild the model with the tuned hyperparameters and refit it on the full training set.
dt_tuned <- update(dt, best_params)
trained <- dt_tuned %>%
fit(VulpesVulpes ~ .,
data = rec_preped %>%
juice())
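As an aside, tune also provides finalize_model(), which splices tuned values into a model specification and may read more clearly than update(); a sketch, assuming the tune 0.1.x API:
# Equivalent to update(dt, best_params): takes the spec and a
# one-row tibble of tuned parameter values.
dt_tuned <- finalize_model(dt, best_params)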
### Model testing
bake(rec_preped, test_data) %>%
bind_cols(predict(trained, .)) %>%
metrics(VulpesVulpes, .pred_class)
## # A tibble: 2 x 3
## .metric .estimator .estimate
## <chr> <chr> <dbl>
## 1 accuracy binary 0.917
## 2 kap binary 0.834
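A confusion matrix gives a fuller picture than accuracy and kappa alone; a sketch using yardstick::conf_mat():
# Cross-tabulate observed vs. predicted classes on the test set
bake(rec_preped, test_data) %>%
  bind_cols(predict(trained, .)) %>%
  conf_mat(VulpesVulpes, .pred_class)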
plot_predicted(trained)
plot(trained$fit)
text(trained$fit, cex = 0.8)
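If the rpart.plot package happens to be installed, it draws a considerably more readable tree than base plot()/text(); an optional alternative, not used above:
# Prettier rendering of the underlying rpart object
rpart.plot::rpart.plot(trained$fit)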
Model interpretation. See here for details on Feature Importance and here for Partial Dependence Plots.
explainer <- trained %>%
explain(
data = bake(rec_preped, mammals_data) %>%
select(-VulpesVulpes),
y = mammals_data %>% pull(VulpesVulpes),
label = "Decision Tree")
## Preparation of a new explainer is initiated
## -> model label : Decision Tree
## -> data : 8542 rows 4 cols
## -> data : tibble converted into a data.frame
## -> target variable : 8542 values
## -> predict function : yhat.model_fit will be used ( default )
## -> predicted values : No value for predict function target column. ( default )
## -> model_info : package parsnip , ver. 0.1.5 , task classification ( default )
## -> predicted values : numerical, min = 0 , mean = 0.5056124 , max = 1
## -> residual function : difference between y and yhat ( default )
## -> residuals : numerical, min = -1 , mean = -0.01357311 , max = 1
## A new explainer has been created!
explainer %>%
ingredients::feature_importance(type = "ratio", B = 1) %>%
plot()
explainer %>%
  ingredients::partial_dependency(variable_type = "numerical") %>%
plot()
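DALEX 2.x also ships its own wrappers around ingredients, which should be interchangeable with the calls above; a sketch, assuming DALEX 2.x:
# model_parts() wraps feature_importance(); model_profile() wraps
# the partial dependence computation.
explainer %>% model_parts(type = "ratio", B = 1) %>% plot()
explainer %>% model_profile() %>% plot()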
fda <- discrim_flexible(
mode = "classification",
num_terms = tune(),
prod_degree = tune(),
prune_method = tune()
) %>%
set_engine("earth")
## Train
### Hyper-parameters to be tuned
params <- list(
  # num_terms has a data-dependent upper bound, so finalize() sets
  # its range from the training data.
  finalize(num_terms(), rec_preped %>% juice()),
  prod_degree(),
  prune_method()
)
### Grid searching
grid <- params %>%
grid_regular(levels = 3)
tuned <- tune_grid(
object = fda,
preprocessor = rec,
resamples = cv_data,
grid = grid,
metrics = metric_set(accuracy, kap),
  control = control_grid(verbose = FALSE)
)
## Loading required package: Formula
## Loading required package: plotmo
## Loading required package: plotrix
##
## Attaching package: 'plotrix'
## The following object is masked from 'package:scales':
##
## rescale
## Loading required package: TeachingDemos
## Loading required package: class
## Loaded mda 0.5-2
##
## Attaching package: 'mda'
## The following object is masked from 'package:parsnip':
##
## mars
## x Fold01: preprocessor 1/1, model 13/18: Error: Fortran routine XHAUST returned er...
## x Fold01: preprocessor 1/1, model 16/18: Error: Fortran routine XHAUST returned er...
## ! Fold02: preprocessor 1/1, model 1/18: degenerate problem; no discrimination
## ! Fold02: preprocessor 1/1, model 1/18 (predictions): no non-missing arguments to ...
## x Fold02: preprocessor 1/1, model 1/18 (predictions): Error in terms.default(objec...
## (similar warnings and errors repeat for models 4, 7, 10, 13, and 16 across the remaining folds)
Some fold/model combinations fail to fit, but tune_grid() records each failure and moves on, so the combinations that do fit can still be ranked below.
best_params <- tuned %>%
select_best("accuracy")
### Train with tuned parameters
# update() does not work for discrim_flexible, so re-specify the
# model with the tuned values.
fda_tuned <- discrim_flexible(
mode = "classification",
num_terms = best_params$num_terms,
prod_degree = best_params$prod_degree,
prune_method = best_params$prune_method
) %>%
set_engine("earth")
trained <- fda_tuned %>%
fit(VulpesVulpes ~ .,
data = rec_preped %>%
juice())
### Model testing
bake(rec_preped, test_data) %>%
bind_cols(predict(trained, .)) %>%
metrics(VulpesVulpes, .pred_class)
## # A tibble: 2 x 3
## .metric .estimator .estimate
## <chr> <chr> <dbl>
## 1 accuracy binary 0.898
## 2 kap binary 0.796
plot_predicted(trained)
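accuracy and kap are computed from hard class predictions; a threshold-free metric such as AUC can be added from probability predictions. A sketch, assuming the first factor level ("yes") is the event of interest:
# type = "prob" yields .pred_yes / .pred_no columns
bake(rec_preped, test_data) %>%
  bind_cols(predict(trained, ., type = "prob")) %>%
  roc_auc(VulpesVulpes, .pred_yes)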
### Model explanation
explainer <- trained %>%
explain(
data = bake(rec_preped, mammals_data) %>%
select(-VulpesVulpes),
y = mammals_data %>% pull(VulpesVulpes),
label = "Flexible Discriminant Analysis")
## Preparation of a new explainer is initiated
## -> model label : Flexible Discriminant Analysis
## -> data : 8542 rows 4 cols
## -> data : tibble converted into a data.frame
## -> target variable : 8542 values
## -> predict function : yhat.model_fit will be used ( default )
## -> predicted values : No value for predict function target column. ( default )
## -> model_info : package parsnip , ver. 0.1.5 , task classification ( default )
## -> predicted values : numerical, min = 9.002539e-05 , mean = 0.5170832 , max = 0.9999689
## -> residual function : difference between y and yhat ( default )
## -> residuals : numerical, min = -0.9999689 , mean = -0.02504384 , max = 0.99991
## A new explainer has been created!
explainer %>%
ingredients::feature_importance(type = "ratio", B = 1) %>%
plot()
explainer %>%
  ingredients::partial_dependency(variable_type = "numerical") %>%
plot()
# Neural Net
nn <- mlp(
  mode = "classification",
  hidden_units = tune(),
  penalty = tune(),
  epochs = 20
) %>%
  set_engine("nnet") %>%
  # translate() resolves the spec into the underlying nnet call
  translate()
## Train
### Hyper-parameters to be tuned
params <- list(
hidden_units(),
penalty()
)
### Grid searching
grid <- params %>%
grid_regular(levels = 3)
tuned <- tune_grid(
object = nn,
preprocessor = rec,
resamples = cv_data,
grid = grid,
metrics = metric_set(accuracy, kap),
  control = control_grid(verbose = FALSE)
)
best_params <- tuned %>%
select_best("accuracy") %>%
select(hidden_units, penalty)
### Train with tuned parameters
# update() does not work well here either, so re-specify the model.
nn_tuned <- mlp(
  mode = "classification",
  hidden_units = best_params$hidden_units,
  penalty = best_params$penalty,
  epochs = 100  # increased from the 20 used during tuning
) %>%
  set_engine("nnet") %>%
  translate()
trained <- nn_tuned %>%
fit(VulpesVulpes ~ .,
data = rec_preped %>%
juice())
### Model testing
bake(rec_preped, test_data) %>%
bind_cols(predict(trained, .)) %>%
metrics(VulpesVulpes, .pred_class)
## # A tibble: 2 x 3
## .metric .estimator .estimate
## <chr> <chr> <dbl>
## 1 accuracy binary 0.923
## 2 kap binary 0.847
plot_predicted(trained)
### Model explanation
explainer <- trained %>%
explain(
data = bake(rec_preped, mammals_data) %>%
select(-VulpesVulpes),
y = mammals_data %>% pull(VulpesVulpes),
label = "Multi Layer Perceptron")
## Preparation of a new explainer is initiated
## -> model label : Multi Layer Perceptron
## -> data : 8542 rows 4 cols
## -> data : tibble converted into a data.frame
## -> target variable : 8542 values
## -> predict function : yhat.model_fit will be used ( default )
## -> predicted values : No value for predict function target column. ( default )
## -> model_info : package parsnip , ver. 0.1.5 , task classification ( default )
## -> predicted values : numerical, min = 0.2689418 , mean = 0.5020381 , max = 0.7310586
## -> residual function : difference between y and yhat ( default )
## -> residuals : numerical, min = -0.7310586 , mean = -0.009998775 , max = 0.7310582
## A new explainer has been created!
explainer %>%
ingredients::feature_importance(type = "ratio", B = 1) %>%
plot()
explainer %>%
  ingredients::partial_dependency(variable_type = "numerical") %>%
plot()
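Finally, DALEX makes side-by-side model comparison easy. The explainer object was overwritten for each model above; a sketch assuming the three explainers had instead been kept as explainer_dt, explainer_fda, and explainer_nn (hypothetical names):
# model_performance() summarizes residuals per explainer; passing
# several results to plot() overlays them for comparison.
mp_dt  <- model_performance(explainer_dt)
mp_fda <- model_performance(explainer_fda)
mp_nn  <- model_performance(explainer_nn)
plot(mp_dt, mp_fda, mp_nn, geom = "boxplot")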