I rewrote the R programs from Chapter 15 of HSDM in a tidy style: tidyverse for data manipulation, tidymodels for building and training the models, and DALEX for interpreting the fitted models. In this section we build species distribution models for the red fox (Vulpes vulpes) with three methods (decision tree, random forest, and boosted regression trees) and evaluate their performance.
library(tidyverse)
## -- Attaching packages --------------------------------------- tidyverse 1.3.0 --
## v ggplot2 3.3.2 v purrr 0.3.4
## v tibble 3.0.4 v dplyr 1.0.2
## v tidyr 1.1.2 v stringr 1.4.0
## v readr 1.4.0 v forcats 0.5.0
## -- Conflicts ------------------------------------------ tidyverse_conflicts() --
## x dplyr::filter() masks stats::filter()
## x dplyr::lag() masks stats::lag()
library(tidymodels)
## -- Attaching packages -------------------------------------- tidymodels 0.1.2 --
## v broom 0.7.2 v recipes 0.1.15
## v dials 0.0.9 v rsample 0.0.8
## v infer 0.5.4 v tune 0.1.2
## v modeldata 0.1.0 v workflows 0.2.1
## v parsnip 0.1.5 v yardstick 0.0.7
## -- Conflicts ----------------------------------------- tidymodels_conflicts() --
## x scales::discard() masks purrr::discard()
## x dplyr::filter() masks stats::filter()
## x recipes::fixed() masks stringr::fixed()
## x dplyr::lag() masks stats::lag()
## x yardstick::spec() masks readr::spec()
## x recipes::step() masks stats::step()
library(gridExtra)
##
## Attaching package: 'gridExtra'
## The following object is masked from 'package:dplyr':
##
## combine
First, we read in the data and build the prediction models.
mammals_data <- read_csv(
"data/tabular/species/mammals_and_bioclim_table.csv"
) %>%
rename(x = X_WGS84, y = Y_WGS84) %>%
select(x, y, VulpesVulpes, bio3, bio7, bio11, bio12)
## Warning: Missing column names filled in: 'X1' [1]
##
## -- Column specification --------------------------------------------------------
## cols(
## X1 = col_double(),
## X_WGS84 = col_double(),
## Y_WGS84 = col_double(),
## ConnochaetesGnou = col_double(),
## GuloGulo = col_double(),
## PantheraOnca = col_double(),
## PteropusGiganteus = col_double(),
## TenrecEcaudatus = col_double(),
## VulpesVulpes = col_double(),
## bio3 = col_double(),
## bio4 = col_double(),
## bio7 = col_double(),
## bio11 = col_double(),
## bio12 = col_double()
## )
split_data <- mammals_data %>%
  select(-x, -y) %>%
  initial_split(prop = 0.8)
train_data <- training(split_data)
test_data <- testing(split_data)
rec <- recipe(VulpesVulpes ~ bio3 + bio7 + bio11 + bio12,
data = train_data) %>%
step_normalize(all_predictors())
rec_preped <- rec %>% prep(train_data)
rec_preped %>% juice()
## # A tibble: 6,834 x 5
## bio3 bio7 bio11 bio12 VulpesVulpes
## <dbl> <dbl> <dbl> <dbl> <dbl>
## 1 -1.28 1.14 -1.82 -0.936 0
## 2 -1.30 1.08 -1.85 -0.892 0
## 3 -1.29 1.05 -1.81 -0.898 0
## 4 -1.31 0.999 -1.85 -0.852 0
## 5 -1.30 0.947 -1.84 -0.835 0
## 6 -1.30 0.895 -1.82 -0.818 0
## 7 -1.24 0.842 -1.72 -0.874 0
## 8 -1.28 0.676 -1.76 -0.748 0
## 9 -1.26 0.644 -1.72 -0.752 0
## 10 -1.20 0.626 -1.61 -0.816 0
## # ... with 6,824 more rows
cv_data <- vfold_cv(train_data, v = 10) %>%
  # prep the recipe within each fold (kept for reference; tune_grid()
  # below re-preps the recipe on each resample itself)
  mutate(recipes = map(splits, prepper, recipe = rec))
# baked copy of the full dataset; each model's predictions are appended below
prediction_results <- bake(rec_preped, mammals_data)
dt <- decision_tree(
mode = "regression",
cost_complexity = tune(),
tree_depth = tune(),
min_n = tune()
) %>%
set_engine("rpart")
## Train
### Hyper-parameters to be tuned
params <- list(
cost_complexity(),
tree_depth(),
min_n()
)
### Grid searching
grid <- params %>%
grid_regular(levels = 3)
tuned_dt <- tune_grid(
object = dt,
preprocessor = rec,
resamples = cv_data,
grid = grid,
metrics = metric_set(rmse),
control = control_grid(verbose = FALSE)
)
##
## Attaching package: 'rlang'
## The following objects are masked from 'package:purrr':
##
## %@%, as_function, flatten, flatten_chr, flatten_dbl, flatten_int,
## flatten_lgl, flatten_raw, invoke, list_along, modify, prepend,
## splice
##
## Attaching package: 'vctrs'
## The following object is masked from 'package:dplyr':
##
## data_frame
## The following object is masked from 'package:tibble':
##
## data_frame
##
## Attaching package: 'rpart'
## The following object is masked from 'package:dials':
##
## prune
best_params <- tuned_dt %>%
select_best("rmse") %>%
select(cost_complexity, tree_depth, min_n)
### Train with tuned parameters
dt_tuned <- update(dt, best_params)
trained <- dt_tuned %>%
fit(VulpesVulpes ~ .,
data = rec_preped %>%
juice())
### Model testing
bake(rec_preped, test_data) %>%
bind_cols(predict(trained, .)) %>%
metrics(VulpesVulpes, .pred)
## # A tibble: 3 x 3
## .metric .estimator .estimate
## <chr> <chr> <dbl>
## 1 rmse standard 0.253
## 2 rsq standard 0.747
## 3 mae standard 0.0995
### Prediction result
prediction_results <- prediction_results %>%
bind_cols(predict(trained, .)) %>%
rename(decision_tree = .pred)
rf <- rand_forest(
mode = "regression",
trees = tune(),
min_n = tune(),
mtry = tune()
) %>%
set_engine("ranger", num.threads = parallel::detectCores())
## Train
### Hyper-parameters to be tuned
params <- list(
trees = trees(),
min_n = min_n(),
mtry = finalize(mtry(),
rec_preped %>% juice() %>%
select(-VulpesVulpes))
)
### Grid searching
grid <- params %>%
grid_regular(levels = 3)
tuned_rf <- tune_grid(
object = rf,
preprocessor = rec,
resamples = cv_data,
grid = grid,
metrics = metric_set(rmse),
control = control_grid(verbose = FALSE)
)
best_params <- tuned_rf %>%
select_best("rmse")
### Train with tuned parameters
rf_tuned <- update(rf,
best_params %>%
select(mtry, trees, min_n))
trained <- rf_tuned %>%
fit(VulpesVulpes ~ .,
data = rec_preped %>%
juice())
### Model testing
bake(rec_preped, test_data) %>%
bind_cols(predict(trained, .)) %>%
metrics(VulpesVulpes, .pred)
## # A tibble: 3 x 3
## .metric .estimator .estimate
## <chr> <chr> <dbl>
## 1 rmse standard 0.206
## 2 rsq standard 0.829
## 3 mae standard 0.0948
### Prediction result
prediction_results <- prediction_results %>%
bind_cols(predict(trained, .)) %>%
rename(random_forest = .pred)
bt <- boost_tree(
mode = "regression",
mtry = tune(),
trees = tune(),
tree_depth = tune()
) %>%
set_engine("xgboost", nthread = parallel::detectCores())
## Train
### Hyper-parameters to be tuned
params <- list(
mtry = finalize(mtry(),
rec_preped %>% juice() %>%
select(-VulpesVulpes)),
trees = trees(),
tree_depth = tree_depth()
)
### Grid searching
grid <- params %>%
grid_regular(levels = 3)
tuned_bt <- tune_grid(
object = bt,
preprocessor = rec,
resamples = cv_data,
grid = grid,
metrics = metric_set(rmse),
control = control_grid(verbose = FALSE)
)
##
## Attaching package: 'xgboost'
## The following object is masked from 'package:dplyr':
##
## slice
best_params <- tuned_bt %>%
select_best("rmse")
### Train with tuned parameters
bt_tuned <- update(bt,
best_params %>%
select(mtry, trees, tree_depth))
trained <- bt_tuned %>%
fit(VulpesVulpes ~ .,
data = rec_preped %>%
juice())
### Model testing
bake(rec_preped, test_data) %>%
bind_cols(predict(trained, .)) %>%
metrics(VulpesVulpes, .pred)
## # A tibble: 3 x 3
## .metric .estimator .estimate
## <chr> <chr> <dbl>
## 1 rmse standard 0.221
## 2 rsq standard 0.804
## 3 mae standard 0.104
### Prediction result
prediction_results <- prediction_results %>%
bind_cols(predict(trained, .)) %>%
rename(boosted_trees = .pred)
prediction_results <- prediction_results %>%
rename(observed = VulpesVulpes) %>%
select(observed, decision_tree, boosted_trees, random_forest) %>%
pivot_longer(-observed, names_to = "method", values_to = "predicted")
prediction_results %>%
write_csv("prediction_results.csv")
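The introduction mentions DALEX for model interpretation. As a minimal sketch of how that could start (illustrative only; it assumes the DALEX package is installed, and trained still holds the boosted-tree fit at this point):
library(DALEX)
test_baked <- bake(rec_preped, test_data)
explainer <- explain(
  trained,
  data = test_baked %>% select(-VulpesVulpes),
  y = test_baked$VulpesVulpes,
  # parsnip fits need an explicit predict function returning a numeric vector
  predict_function = function(model, newdata) predict(model, newdata)$.pred,
  label = "boosted_trees"
)
model_parts(explainer) %>% plot()  # permutation-based variable importance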
In 15.1, we evaluate the models using presence-absence data.
### 15.1.1 Calibration Plot

Here we visualize the model predictions by plotting the predicted presence probability against the observed presence rate.
n_bin <- 10
breaks <- seq(-0.1, 1.1, length = n_bin + 1)
p <- prediction_results %>%
  mutate(
    predicted_cut = cut(predicted, breaks = breaks, include.lowest = TRUE)
  ) %>%
  group_by(method, predicted_cut) %>%
  summarise(
    predicted_mean = mean(predicted),
    num_data = length(predicted),
    observed_rate = sum(observed == 1) / num_data
  ) %>%
  do(
    plots = ggplot(data = .,
                   aes(x = observed_rate, y = predicted_mean)) +
      geom_point(aes(size = num_data)) +
      geom_abline(linetype = "dashed") +
      labs(
        x = "Observed Occurrence Rate",
        y = "Predicted Probability"
      ) +
      ggtitle(unique(.$method))
  )
## `summarise()` regrouping output by 'method' (override with `.groups` argument)
grid.arrange(p$plots[[1]], p$plots[[2]], p$plots[[3]], nrow = 2)
The textbook says that points lying on the diagonal indicate a good model, but there does not appear to be any relationship between this property and accuracy. What does seem related is how strongly the predicted presence probabilities concentrate at the two extremes, 0 and 1.
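To see that concentration directly, a quick histogram of the predicted values (a small sketch reusing the prediction_results tibble built above) is enough:
prediction_results %>%
  ggplot(aes(x = predicted)) +
  geom_histogram(bins = 30) +
  facet_wrap(~method)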
Here we select the threshold used to convert the predicted presence probabilities into binary predictions. In practice, I think it would be better to fit a classification model instead of a regression model so that the predictions come out binary in the first place, as sketched below.
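As a rough sketch of that alternative (assuming VulpesVulpes is recoded as a factor; tuning would mirror the regression workflow above):
dt_class <- decision_tree(
  mode = "classification",
  cost_complexity = tune(),
  tree_depth = tune(),
  min_n = tune()
) %>%
  set_engine("rpart")
rec_class <- recipe(
  VulpesVulpes ~ bio3 + bio7 + bio11 + bio12,
  data = train_data %>% mutate(VulpesVulpes = as.factor(VulpesVulpes))
) %>%
  step_normalize(all_predictors())
# after fitting, predict(fit, new_data, type = "class") returns
# presence/absence labels directly, with no threshold step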
Once the predictions are binarized, we can build a confusion matrix.
th <- 0.5
conf <- prediction_results %>%
mutate(
observed = if_else(observed > th, "presence", "absence") %>%
as.factor(),
predicted_class = if_else(predicted > th, "presence", "absence") %>%
as.factor()
) %>%
group_by(method) %>%
conf_mat(observed, predicted_class)
conf$method
## [1] "boosted_trees" "decision_tree" "random_forest"
conf$conf_mat
## [[1]]
## Truth
## Prediction absence presence
## absence 4277 51
## presence 62 4152
##
## [[2]]
## Truth
## Prediction absence presence
## absence 4063 211
## presence 276 3992
##
## [[3]]
## Truth
## Prediction absence presence
## absence 4281 38
## presence 58 4165
We slide the threshold to select the best one, visualizing how various metrics change as the threshold moves.
ths <- 1:9 / 10
# threshold-dependent metrics for a given threshold `th`;
# relies on prediction_results from the enclosing environment
metrics_df <- function(th) {
prediction_results %>%
mutate(
observed = if_else(observed > th, "presence", "absence") %>%
as.factor(),
predicted_class = if_else(predicted > th, "presence", "absence") %>%
as.factor()
) %>%
group_by(method) %>%
summarise(
specificity = specificity_vec(observed, predicted_class),
acc = accuracy_vec(observed, predicted_class),
sensitivity = sensitivity_vec(observed, predicted_class),
kappa = kap_vec(observed, predicted_class)
) %>%
mutate(threshold = th)
}
sliding_th <- map(ths, metrics_df) %>%
reduce(bind_rows) %>%
arrange(method)
## `summarise()` ungrouping output (override with `.groups` argument)
sliding_th %>%
print(n = nrow(.))
## # A tibble: 27 x 6
## method specificity acc sensitivity kappa threshold
## <chr> <dbl> <dbl> <dbl> <dbl> <dbl>
## 1 boosted_trees 0.998 0.977 0.957 0.954 0.1
## 2 boosted_trees 0.996 0.982 0.969 0.965 0.2
## 3 boosted_trees 0.995 0.986 0.977 0.972 0.3
## 4 boosted_trees 0.993 0.987 0.981 0.974 0.4
## 5 boosted_trees 0.988 0.987 0.986 0.974 0.5
## 6 boosted_trees 0.984 0.987 0.990 0.974 0.6
## 7 boosted_trees 0.979 0.985 0.991 0.970 0.7
## 8 boosted_trees 0.971 0.983 0.995 0.966 0.8
## 9 boosted_trees 0.954 0.976 0.998 0.953 0.9
## 10 decision_tree 0.994 0.906 0.822 0.813 0.1
## 11 decision_tree 0.983 0.932 0.881 0.863 0.2
## 12 decision_tree 0.976 0.938 0.901 0.876 0.3
## 13 decision_tree 0.961 0.943 0.926 0.886 0.4
## 14 decision_tree 0.950 0.943 0.936 0.886 0.5
## 15 decision_tree 0.926 0.941 0.955 0.882 0.6
## 16 decision_tree 0.903 0.937 0.969 0.873 0.7
## 17 decision_tree 0.868 0.925 0.979 0.849 0.8
## 18 decision_tree 0.812 0.902 0.990 0.804 0.9
## 19 random_forest 1.00 0.931 0.864 0.862 0.1
## 20 random_forest 0.998 0.959 0.921 0.919 0.2
## 21 random_forest 0.996 0.975 0.955 0.951 0.3
## 22 random_forest 0.994 0.988 0.983 0.976 0.4
## 23 random_forest 0.991 0.989 0.987 0.978 0.5
## 24 random_forest 0.985 0.988 0.990 0.976 0.6
## 25 random_forest 0.968 0.980 0.993 0.961 0.7
## 26 random_forest 0.930 0.963 0.995 0.927 0.8
## 27 random_forest 0.852 0.926 0.999 0.852 0.9
# Pick the threshold with the best accuracy.
sliding_th %>%
group_by(method) %>%
slice_max(acc)
## # A tibble: 3 x 6
## # Groups: method [3]
## method specificity acc sensitivity kappa threshold
## <chr> <dbl> <dbl> <dbl> <dbl> <dbl>
## 1 boosted_trees 0.984 0.987 0.990 0.974 0.6
## 2 decision_tree 0.950 0.943 0.936 0.886 0.5
## 3 random_forest 0.991 0.989 0.987 0.978 0.5
plots <- sliding_th %>%
pivot_longer(
cols = c(specificity, acc, sensitivity, kappa),
names_to = "metrics",
values_to = "evaluation"
) %>%
group_by(method) %>%
do(
plots =
ggplot(data = .,
aes(x = threshold, y = evaluation)) +
geom_line(aes(color = metrics)) +
geom_point(aes(color = metrics)) +
ggtitle(.$method) +
ylim(0.75, 1.0)
)
grid.arrange(
plots$plots[[1]],
plots$plots[[2]],
plots$plots[[3]],
nrow = 2
)
AUC is one metric that does not depend on the threshold.
prediction_results %>%
  mutate(
    # flip 0/1 so that presence becomes the first factor level,
    # which yardstick treats as the event by default
    observed = as.factor(1 - observed)
  ) %>%
  group_by(method) %>%
  roc_curve(observed, predicted) %>%
  autoplot()
prediction_results %>%
  mutate(observed = as.factor(1 - observed)) %>%
  group_by(method) %>%
  roc_auc(observed, predicted)
## # A tibble: 3 x 4
## method .metric .estimator .estimate
## <chr> <chr> <chr> <dbl>
## 1 boosted_trees roc_auc binary 0.998
## 2 decision_tree roc_auc binary 0.986
## 3 random_forest roc_auc binary 0.999
When the continuous variable (here, the predicted probability) can be assumed to follow a normal distribution, the Pearson correlation can be used as a measure of how good the predictions are.
prediction_results %>%
group_by(method) %>%
summarise(correlation = cor(observed, predicted))
## `summarise()` ungrouping output (override with `.groups` argument)
## # A tibble: 3 x 2
## method correlation
## <chr> <dbl>
## 1 boosted_trees 0.980
## 2 decision_tree 0.915
## 3 random_forest 0.973
Metrics like those introduced in 15.1 are difficult to apply to pseudo-absence data. Here we introduce accuracy metrics that can be applied to presence-only data.
AVI (Absolute Validation Index) shows the proportion of actual presence sites that were predicted as presences. CVI (Contrast Validation Index) is described as "the AVI corrected by the theoretical value for a model that predicts the species as present everywhere", though I am not sure I follow that.
prediction_results %>%
  # predicted probability at presence sites, 0 at absence sites
  mutate(obs = predicted * observed) %>%
  group_by(method) %>%
  summarise(
    AVI = sum(obs > 0.5) / n(),
    CVI = sum(observed) / n() - AVI
  )
## `summarise()` ungrouping output (override with `.groups` argument)
## # A tibble: 3 x 3
## method AVI CVI
## <chr> <dbl> <dbl>
## 1 boosted_trees 0.486 0.00597
## 2 decision_tree 0.467 0.0247
## 3 random_forest 0.488 0.00445
AVI and CVI are threshold-dependent metrics, so they share the problem that a threshold must be chosen. The Boyce index is one of the threshold-independent metrics, and it resembles the calibration plot made in 15.1. In the calibration plot, the predicted values were split into classes and the observed presence/absence rate was computed within each class; the Boyce plot instead uses the number of presence records divided by the total number of records. The Spearman correlation between the class values of the predictions and this presence rate is then used as the model evaluation metric. Instead of splitting into discrete classes, the same plot can also be produced by smoothing with a sliding window.
prediction_results %>%
  group_by(method) %>%
  summarise(
    boyce = ecospat::ecospat.boyce(
      predicted,
      # predicted values at the presence sites within the current group
      predicted[observed == 1],
      nclass = 0,
      window.w = "default",
      res = 100,
      PEplot = TRUE
    )$Spearman.cor
  )
## Registered S3 methods overwritten by 'adehabitatMA':
## method from
## print.SpatialPixelsDataFrame sp
## print.SpatialPixels sp
## Registered S3 method overwritten by 'vegan':
## method from
## print.nullmodel parsnip
## `summarise()` ungrouping output (override with `.groups` argument)
## # A tibble: 3 x 2
## method boyce
## <chr> <dbl>
## 1 boosted_trees 0.756
## 2 decision_tree 0.981
## 3 random_forest 0.619