サンプルデータを使用して、MICE、missRanger、
missForestによる欠測値補完の精度を比較します。
# ---- Setup ----
# Load all packages in one call; attach order is tidyverse, mice, missForest,
# doParallel, missRanger (later packages mask earlier ones on name clashes).
pacman::p_load(tidyverse, mice, missForest, doParallel, missRanger)

# Pin `select` to dplyr's implementation so a later-attached package cannot
# mask it.
select <- dplyr::select

# Global RNG seed. mice() and missRanger() below also receive seed = 1234
# explicitly, as a safeguard.
set.seed(1234)
# ---- Data preparation ----
# Use the diamonds data shipped with ggplot2.
dat <- ggplot2::diamonds %>%
  # Ordered factors add friction during imputation, so drop the ordering of
  # `cut` while keeping its levels in quality order (Fair -> Ideal).
  mutate(cut = factor(as.character(cut)),
         cut = fct_relevel(cut, "Fair", "Good", "Very Good", "Premium", "Ideal")) %>%
  # Drop depth (a percentage derived from x, y, z) and table (width of the top
  # relative to the widest point); neither is used in this comparison.
  select(-depth, -table) %>%
  # x, y, z are hard to read, so give them descriptive names (optional).
  rename(length = x,
         width = y,
         depth = z) %>%
  # Collapse color into a binary variable. In diamonds, D is the best color and
  # J the worst, so D-G form the "Better" group and H-J the "Worse" group.
  # (The previous mapping had these labels inverted.)
  mutate(color_cat = case_when(color %in% c("D", "E", "F", "G") ~ "Better",
                               color %in% c("H", "I", "J") ~ "Worse"),
         color_cat = factor(color_cat, levels = c("Worse", "Better"))) %>%
  # Collapse clarity (I1 = worst ... IF = best) into three groups.
  mutate(clarity_cat = case_when(clarity %in% c("I1", "SI2", "SI1") ~ "Worse",
                                 clarity %in% c("VS2", "VS1", "VVS2") ~ "Moderate",
                                 clarity %in% c("VVS1", "IF") ~ "Better"),
         clarity_cat = factor(clarity_cat,
                              levels = c("Worse", "Moderate", "Better"))) %>%
  # Remove the original columns that were recoded above.
  select(-color, -clarity) %>%
  # missForest errors on tibbles, so convert to a plain data.frame.
  as.data.frame()
# ---- Train / test split ----
# Imputation is mainly judged on the training data, but N is large enough to
# also hold out a test set. train:test = 4:1.
n_test <- floor(nrow(dat) / 5)
test_indices <- sample(seq_len(nrow(dat)), size = n_test)
# Select training rows with a logical mask rather than dat[-test_indices, ]:
# negative indexing with an empty index vector selects zero rows, not all rows.
is_test <- seq_len(nrow(dat)) %in% test_indices
# A Method column is added so everything can later be stacked into long format
# (the same is done for every imputed data set below).
dat_test <- dat[test_indices, ] %>%
  mutate(Method = "Test_original")
dat_train <- dat[!is_test, ] %>%
  mutate(Method = "Train_original")
# ---- Inject missingness ----
# Drop the bookkeeping Method column, then add missing values with p = 0.2
# (20% missingness) using missRanger::generateNA().
dat_train_NA <- generateNA(select(dat_train, -Method), p = 0.2)
# ---- mice: inspect defaults ----
# A dry run (maxit = 0) only builds the mids object so its default imputation
# methods and predictor matrix can be inspected and reused.
# Check: continuous -> pmm, binary -> logreg, multi-category -> polyreg.
# The predictor matrix is used unchanged.
dat_train_mice_0 <- mice(data = dat_train_NA, m = 1, maxit = 0)
method_train_mice <- dat_train_mice_0[["method"]]
predictorMatrix_train_mice <- dat_train_mice_0[["predictorMatrix"]]

method_train_mice
## carat cut price length width depth
## "pmm" "polyreg" "pmm" "pmm" "pmm" "pmm"
## color_cat clarity_cat
## "logreg" "polyreg"

predictorMatrix_train_mice
## carat cut price length width depth color_cat clarity_cat
## carat 0 1 1 1 1 1 1 1
## cut 1 0 1 1 1 1 1 1
## price 1 1 0 1 1 1 1 1
## length 1 1 1 0 1 1 1 1
## width 1 1 1 1 0 1 1 1
## depth 1 1 1 1 1 0 1 1
## color_cat 1 1 1 1 1 1 0 1
## clarity_cat 1 1 1 1 1 1 1 0
# ---- mice: single imputation ----
# maxit = 10 matches the iteration cap used for the random-forest methods.
dat_mice <- mice(data = dat_train_NA,
                 m = 1,
                 maxit = 10,
                 method = method_train_mice,
                 predictorMatrix = predictorMatrix_train_mice,
                 seed = 1234,
                 printFlag = FALSE)
# Extract the completed data set. Call mice::complete explicitly: tidyr
# (attached via tidyverse) also exports complete(), so the bare name would
# depend on package attach order.
dat_mice_imp <- mice::complete(dat_mice, action = 1) %>%
  mutate(Method = "MICE")
# ---- missRanger ----
# missRanger is fast, so compare 100 trees with 500 (ranger's default), each
# with and without predictive mean matching (pmm.k = 3 vs. 0).

# Run missRanger on `data` with the settings shared by every run here; only the
# number of trees and pmm.k vary. Returns the imputed data with a Method label.
impute_ranger <- function(data, num_trees, pmm_k, label) {
  missRanger(data,
             formula = . ~ .,
             maxiter = 10,
             pmm.k = pmm_k,
             seed = 1234,
             verbose = 0,
             data_only = TRUE,
             num.trees = num_trees) %>%
    mutate(Method = label)
}

dat_ranger_100 <- impute_ranger(dat_train_NA, num_trees = 100, pmm_k = 0,
                                label = "Ranger_100")
dat_ranger_100_pmm <- impute_ranger(dat_train_NA, num_trees = 100, pmm_k = 3,
                                    label = "Ranger_100_pmm")
dat_ranger_500 <- impute_ranger(dat_train_NA, num_trees = 500, pmm_k = 0,
                                label = "Ranger_500")
dat_ranger_500_pmm <- impute_ranger(dat_train_NA, num_trees = 500, pmm_k = 3,
                                    label = "Ranger_500_pmm")
# ---- missForest ----
# Imputation accuracy (not speed) is the focus here, but missForest is slow,
# so parallelize over variables. ntree = 100 is missForest's default.

# Use at most one worker per physical core and no more workers than variables.
# detectCores() can return NA on some platforms; fall back to 1 in that case.
cores <- detectCores(logical = FALSE)
if (is.na(cores)) cores <- 1L
n_workers <- min(cores, ncol(dat_train_NA))
# Parallel mode needs a registered foreach backend; with a single worker there
# is nothing to gain, so run serially instead.
use_parallel <- n_workers >= 2
if (use_parallel) {
  cl <- makeCluster(n_workers)
  registerDoParallel(cl)
}
dat_forest <- tryCatch(
  missForest(dat_train_NA,
             maxiter = 10,
             ntree = 100,
             parallelize = if (use_parallel) "variables" else "no",
             verbose = TRUE),
  # Always shut the workers down and restore the sequential backend, even if
  # missForest() errors, so no orphaned R worker processes are left running.
  finally = if (use_parallel) {
    stopCluster(cl)
    registerDoSEQ()
  }
)
## parallelizing over the variables of the input data matrix 'xmis'
## missForest iteration 1 in progress...done!
## estimated error(s): 0.2118306 0.3393295
## difference(s): 0.09026568 0.1014399
## time: 2326.76 seconds
##
## missForest iteration 2 in progress...done!
## estimated error(s): 0.1674313 0.250169
## difference(s): 0.003039284 0.04458658
## time: 1681.33 seconds
##
## missForest iteration 3 in progress...done!
## estimated error(s): 0.1645239 0.2323446
## difference(s): 0.003483392 0.0330228
## time: 1566.73 seconds
##
## missForest iteration 4 in progress...done!
## estimated error(s): 0.1624285 0.2253442
## difference(s): 0.003591173 0.0300102
## time: 1595.6 seconds
##
## missForest iteration 5 in progress...done!
## estimated error(s): 0.162385 0.2261939
## difference(s): 0.003743502 0.02870473
## time: 1580.06 seconds
##
## missForest iteration 6 in progress...done!
## estimated error(s): 0.162552 0.2244366
## difference(s): 0.00379941 0.02859659
## time: 1511.84 seconds
##
## missForest iteration 7 in progress...done!
## estimated error(s): 0.162654 0.2245911
## difference(s): 0.003812132 0.02760011
## time: 1511 seconds
##
## missForest iteration 8 in progress...done!
## estimated error(s): 0.1622832 0.2229593
## difference(s): 0.003924854 0.02666543
## time: 1438.93 seconds
##
## missForest iteration 9 in progress...done!
## estimated error(s): 0.1621532 0.2229786
## difference(s): 0.003974263 0.02691262
## time: 1401.79 seconds
# The imputed data frame is stored in the `ximp` element of the result.
dat_forest_imp <- dat_forest$ximp %>%
  mutate(Method = "Forest_100")
まず、連続変数です。
# ---- Continuous variables ----
# Stack the original data and every imputed version into one data frame,
# distinguished by the Method column.
dat_summary <- bind_rows(dat_test, dat_train, dat_mice_imp, dat_forest_imp,
                         dat_ranger_100, dat_ranger_100_pmm,
                         dat_ranger_500, dat_ranger_500_pmm)

# Reshape the continuous columns to long format for a single faceted plot.
plotdat_continuous <- pivot_longer(dat_summary,
                                   cols = c(carat, price, length, width, depth),
                                   names_to = "continuous_vars",
                                   values_to = "value")

# One boxplot panel per continuous variable, each with its own y scale.
plot_continuous <- plotdat_continuous %>%
  ggplot(aes(x = Method, y = value)) +
  geom_boxplot() +
  facet_wrap(~continuous_vars, ncol = 3, nrow = 3, scales = "free_y") +
  theme_classic() +
  theme(axis.text.x = element_text(angle = 45, hjust = 1))
plot_continuous
次に、カテゴリ変数です。
# ---- Categorical variables ----

# Per-method counts and shares for one categorical column `var`.
# Returns an ungrouped data frame with columns Method, <var>, n, pct, lbl.
summarize_category <- function(data, var) {
  data %>%
    group_by(Method, {{ var }}) %>%
    # "drop_last" keeps the grouping by Method, so pct sums to 1 per method;
    # stating it explicitly also silences summarise()'s grouping message.
    summarize(n = n(), .groups = "drop_last") %>%
    mutate(pct = n / sum(n),
           lbl = scales::percent(pct)) %>%
    ungroup()
}

# 100% stacked bar chart of category shares per method, labelled with percents.
plot_category <- function(plotdat, var, fill_label, title) {
  ggplot(data = plotdat, aes(x = Method, y = pct, fill = {{ var }})) +
    geom_bar(stat = "identity", position = "fill") +
    scale_y_continuous(breaks = seq(0, 1, 0.2)) +
    geom_text(aes(label = lbl), size = 3, position = position_stack(vjust = 0.5)) +
    scale_fill_brewer(palette = "Set2") +
    labs(y = "Percent",
         fill = fill_label,
         x = "Method",
         title = title) +
    theme_classic() +
    theme(axis.text.x = element_text(angle = 45, hjust = 1))
}

# cut
plotdat_cut <- summarize_category(dat_summary, cut)
plot_cut <- plot_category(plotdat_cut, cut,
                          fill_label = "Cut quality",
                          title = "Imputation Results for Cut")
plot_cut

# color_cat
plotdat_color_cat <- summarize_category(dat_summary, color_cat)
plot_color_cat <- plot_category(plotdat_color_cat, color_cat,
                                fill_label = "Color quality",
                                title = "Imputation Results for Color")
plot_color_cat

# clarity_cat
plotdat_clarity_cat <- summarize_category(dat_summary, clarity_cat)
plot_clarity_cat <- plot_category(plotdat_clarity_cat, clarity_cat,
                                  fill_label = "Clarity quality",
                                  title = "Imputation Results for Clarity")
plot_clarity_cat
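これだけみると、mice関数による単一補完(連続変数はpmm、カテゴリ変数はロジスティック回帰/多項ロジスティック回帰)で十分なのでは・・・という気もしてきます。ただし、ここで比較しているのは変数ごとの分布(箱ひげ図や構成比)であり、欠測させた個々のセルの補完値が真値と一致しているかどうかまでは評価していない点に注意が必要です。
少なくとも、「回帰なら多重補完が必要、ランダムフォレストなら単一補完でもOK」とは言い切れないように思います。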
データの複雑さにもよるとは思いますが、この程度のデータだとツリー数は100でもいいかもしれません。
連続変数についてはpmmを併用しても精度に問題はないようです。
missRangerのマニュアルにもある通り、pmmを併用する方が変数としての解釈性を維持しやすいです。
一方、連続変数の補完にpmmを併用すると、カテゴリ変数の推定精度にばらつきが生じるのかもしれません。