0. 本記事の説明

サンプルデータを使用して、MICE, missRanger, missForestによる欠測値補完の結果(補完精度)を比較します。


1. データの準備

# パッケージのロード
pacman::p_load(tidyverse, 
               mice, 
               missForest, 
               doParallel, 
               missRanger)

# Fix the global RNG seed.
# As a precaution, mice and missRanger are also passed the same value through
# their own `seed` arguments.
set.seed(seed = 1234)
# Bind the bare name `select` to dplyr::select
# (presumably so another attached package cannot mask it).
select <- dplyr::select
# Build the analysis data set from the diamonds data shipped with ggplot2.
dat <- ggplot2::diamonds %>%

  # Ordered factors need extra handling during imputation, so drop the
  # ordering of `cut`. The level sequence Fair ... Ideal is kept; only the
  # "ordered" class is removed.
  mutate(cut = factor(as.character(cut),
                      levels = c("Fair", "Good", "Very Good", "Premium", "Ideal"))) %>%

  # Remove the original depth and table columns (the author treats them as
  # functions of the dimensions x, y, z).
  select(-depth, -table) %>%

  # Optional: x, y, z are hard to read, so give them descriptive names.
  # From here on `depth` is the former z column, not the removed depth column.
  rename(length = x,
         width = y,
         depth = z) %>%

  # Collapse color into a binary factor.
  # In diamonds, color runs from D (best) to J (worst), so D-G forms the
  # "Better" group and H-J the "Worse" group. The level order
  # (Worse, Better) matches the direction used for clarity_cat below.
  mutate(color_cat = case_when(color %in% c("H", "I", "J") ~ 0,
                               color %in% c("D", "E", "F", "G") ~ 1),
         color_cat = factor(color_cat,
                            levels = c(0, 1),
                            labels = c("Worse", "Better"))) %>%

  # Collapse clarity into a three-level factor.
  mutate(clarity_cat = case_when(clarity %in% c("I1", "SI2", "SI1") ~ 0,
                                 clarity %in% c("VS2", "VS1", "VVS2") ~ 1,
                                 clarity %in% c("VVS1", "IF") ~ 2),
         clarity_cat = factor(clarity_cat,
                              levels = c(0, 1, 2),
                              labels = c("Worse", "Moderate", "Better"))) %>%

  # Drop the source columns that are no longer needed.
  select(-color, -clarity) %>%

  # missForest errors when given a tibble; converting with as.data.frame()
  # is the safe default.
  as.data.frame()
# Checking imputation on the training data alone would probably be enough,
# but the sample is reasonably large, so a test set is held out as well.

# train : test = 4 : 1
# Random row indices for the test set.
test_indices <- sample(seq_len(nrow(dat)), size = nrow(dat) %/% 5)

# Split the data frame in two by those indices.
# Each part gets a Method label because everything is stacked into one long
# data frame at the end (the same is done for the imputed data sets below).
dat_test <- mutate(dat[test_indices, ], Method = "Test_original")

dat_train <- mutate(dat[-test_indices, ], Method = "Train_original")

# Introduce missing values at a 20% rate (Method is dropped first so the
# label column is not part of the data to be imputed).
dat_train_NA <- generateNA(select(dat_train, -Method), p = 0.2)


2. MICE

# Dry run (maxit = 0) to inspect the settings mice would use.
# Check the imputation methods: pmm for continuous variables, logreg for the
# binary variable, polyreg for multi-category variables.
# The predictor matrix is left at its default.
dat_train_mice_0 <- mice(data = dat_train_NA, maxit = 0, m = 1)
method_train_mice <- dat_train_mice_0[["method"]]
predictorMatrix_train_mice <- dat_train_mice_0[["predictorMatrix"]]
method_train_mice
##       carat         cut       price      length       width       depth 
##       "pmm"   "polyreg"       "pmm"       "pmm"       "pmm"       "pmm" 
##   color_cat clarity_cat 
##    "logreg"   "polyreg"
predictorMatrix_train_mice
##             carat cut price length width depth color_cat clarity_cat
## carat           0   1     1      1     1     1         1           1
## cut             1   0     1      1     1     1         1           1
## price           1   1     0      1     1     1         1           1
## length          1   1     1      0     1     1         1           1
## width           1   1     1      1     0     1         1           1
## depth           1   1     1      1     1     0         1           1
## color_cat       1   1     1      1     1     1         0           1
## clarity_cat     1   1     1      1     1     1         1           0
# maxit is set to 10 to match the iteration count of the random-forest methods.
dat_mice <- mice(
  data = dat_train_NA,
  method = method_train_mice,
  predictorMatrix = predictorMatrix_train_mice,
  m = 1,
  maxit = 10,
  seed = 1234,
  printFlag = FALSE
)

# Pull the completed data set out of the mice result and label it.
dat_mice_imp <- mutate(mice::complete(dat_mice, action = 1), Method = "MICE")


3. missRanger

# missRanger is fast, so try both 100 trees and 500 trees (the ranger default),
# each with and without predictive mean matching (pmm.k = 3 vs. pmm.k = 0).

# 100 trees, no pmm
dat_ranger_100 <- mutate(
  missRanger(
    data = dat_train_NA,
    formula = . ~ .,
    num.trees = 100,
    pmm.k = 0,
    maxiter = 10,
    seed = 1234,
    verbose = 0,
    data_only = TRUE
  ),
  Method = "Ranger_100"
)

# 100 trees, with pmm
dat_ranger_100_pmm <- mutate(
  missRanger(
    data = dat_train_NA,
    formula = . ~ .,
    num.trees = 100,
    pmm.k = 3,
    maxiter = 10,
    seed = 1234,
    verbose = 0,
    data_only = TRUE
  ),
  Method = "Ranger_100_pmm"
)

# 500 trees, no pmm
dat_ranger_500 <- mutate(
  missRanger(
    data = dat_train_NA,
    formula = . ~ .,
    num.trees = 500,
    pmm.k = 0,
    maxiter = 10,
    seed = 1234,
    verbose = 0,
    data_only = TRUE
  ),
  Method = "Ranger_500"
)

# 500 trees, with pmm
dat_ranger_500_pmm <- mutate(
  missRanger(
    data = dat_train_NA,
    formula = . ~ .,
    num.trees = 500,
    pmm.k = 3,
    maxiter = 10,
    seed = 1234,
    verbose = 0,
    data_only = TRUE
  ),
  Method = "Ranger_500_pmm"
)


4. missForest

# The goal here is to examine imputation accuracy, so run in parallel.
# ntree is left at the missForest default of 100.
cores <- detectCores(logical = FALSE)
# With parallelize = "variables" there is no use for more workers than
# columns, and starting more workers than physical cores only oversubscribes
# the machine. detectCores() can return NA, hence na.rm = TRUE.
n_workers <- min(cores, ncol(dat_train_NA), na.rm = TRUE)
cl <- makeCluster(n_workers)
# registerDoParallel() ignores `cores` when it is handed a cluster object,
# so only the cluster is passed.
registerDoParallel(cl)

# NOTE(review): no seed is passed here and the forests are grown on the
# cluster workers, so this result is presumably not reproducible from the
# set.seed() call at the top alone -- confirm if reproducibility matters.
dat_forest <- tryCatch(
  missForest(dat_train_NA,
             maxiter = 10,
             ntree = 100,
             parallelize = "variables",
             verbose = TRUE),
  finally = {
    # Always release the worker processes, even if missForest() fails, and
    # fall back to the sequential backend so later foreach code does not
    # point at a stopped cluster.
    stopCluster(cl)
    foreach::registerDoSEQ()
  }
)
##   parallelizing over the variables of the input data matrix 'xmis'
##   missForest iteration 1 in progress...done!
##     estimated error(s): 0.2118306 0.3393295 
##     difference(s): 0.09026568 0.1014399 
##     time: 2326.76 seconds
## 
##   missForest iteration 2 in progress...done!
##     estimated error(s): 0.1674313 0.250169 
##     difference(s): 0.003039284 0.04458658 
##     time: 1681.33 seconds
## 
##   missForest iteration 3 in progress...done!
##     estimated error(s): 0.1645239 0.2323446 
##     difference(s): 0.003483392 0.0330228 
##     time: 1566.73 seconds
## 
##   missForest iteration 4 in progress...done!
##     estimated error(s): 0.1624285 0.2253442 
##     difference(s): 0.003591173 0.0300102 
##     time: 1595.6 seconds
## 
##   missForest iteration 5 in progress...done!
##     estimated error(s): 0.162385 0.2261939 
##     difference(s): 0.003743502 0.02870473 
##     time: 1580.06 seconds
## 
##   missForest iteration 6 in progress...done!
##     estimated error(s): 0.162552 0.2244366 
##     difference(s): 0.00379941 0.02859659 
##     time: 1511.84 seconds
## 
##   missForest iteration 7 in progress...done!
##     estimated error(s): 0.162654 0.2245911 
##     difference(s): 0.003812132 0.02760011 
##     time: 1511 seconds
## 
##   missForest iteration 8 in progress...done!
##     estimated error(s): 0.1622832 0.2229593 
##     difference(s): 0.003924854 0.02666543 
##     time: 1438.93 seconds
## 
##   missForest iteration 9 in progress...done!
##     estimated error(s): 0.1621532 0.2229786 
##     difference(s): 0.003974263 0.02691262 
##     time: 1401.79 seconds
# The imputed data frame is stored in the `ximp` element of the result.
dat_forest_imp <- dat_forest$ximp %>%
  mutate(Method = "Forest_100")


5. 結果のまとめ

まず、連続変数です。

# Stack the test data, the original training data and every imputed data set
# into one long data frame; the Method column identifies each source.
dat_summary <- bind_rows(
  dat_test,
  dat_train,
  dat_mice_imp,
  dat_forest_imp,
  dat_ranger_100,
  dat_ranger_100_pmm,
  dat_ranger_500,
  dat_ranger_500_pmm
)
# Reshape the continuous variables to long form so they can be shown together
# with facet_wrap.
plotdat_continuous <- pivot_longer(
  dat_summary,
  cols = c(carat, price, length, width, depth),
  names_to = "continuous_vars",
  values_to = "value"
)

# One box plot panel per variable, each with its own y scale.
plot_continuous <- ggplot(plotdat_continuous, aes(x = Method, y = value)) +
  geom_boxplot() +
  facet_wrap(~continuous_vars, scales = "free_y", nrow = 3, ncol = 3) +
  theme_classic() +
  theme(axis.text.x = element_text(hjust = 1, angle = 45))

plot_continuous


次に、カテゴリ変数です。

# cut: count rows per Method and level, then turn the counts into the share
# within each Method (the result of summarise() is still grouped by Method,
# so sum(n) is the per-Method total).
plotdat_cut <- dat_summary %>%
  group_by(Method, cut) %>%
  summarise(n = n()) %>%
  mutate(
    pct = n / sum(n),
    lbl = scales::percent(pct)
  )
## `summarise()` has grouped output by 'Method'. You can override using the
## `.groups` argument.
# Stacked 100% bars per Method with the percentage printed inside each segment.
plot_cut <- ggplot(plotdat_cut, aes(x = Method, y = pct, fill = cut)) +
  geom_col(position = "fill") +
  geom_text(aes(label = lbl), position = position_stack(vjust = 0.5), size = 3) +
  scale_y_continuous(breaks = seq(0, 1, by = 0.2)) +
  scale_fill_brewer(palette = "Set2") +
  labs(
    title = "Imputation Results for Cut",
    x = "Method",
    y = "Percent",
    fill = "Cut quality"
  ) +
  theme_classic() +
  theme(axis.text.x = element_text(hjust = 1, angle = 45))

plot_cut

# color_cat: count rows per Method and level, then turn the counts into the
# share within each Method (the result of summarise() is still grouped by
# Method, so sum(n) is the per-Method total).
plotdat_color_cat <- dat_summary %>%
  group_by(Method, color_cat) %>%
  summarise(n = n()) %>%
  mutate(
    pct = n / sum(n),
    lbl = scales::percent(pct)
  )
## `summarise()` has grouped output by 'Method'. You can override using the
## `.groups` argument.
# Stacked 100% bars per Method with the percentage printed inside each segment.
plot_color_cat <- ggplot(plotdat_color_cat, aes(x = Method, y = pct, fill = color_cat)) +
  geom_col(position = "fill") +
  geom_text(aes(label = lbl), position = position_stack(vjust = 0.5), size = 3) +
  scale_y_continuous(breaks = seq(0, 1, by = 0.2)) +
  scale_fill_brewer(palette = "Set2") +
  labs(
    title = "Imputation Results for Color",
    x = "Method",
    y = "Percent",
    fill = "Color quality"
  ) +
  theme_classic() +
  theme(axis.text.x = element_text(hjust = 1, angle = 45))

plot_color_cat

# clarity_cat: count rows per Method and level, then turn the counts into the
# share within each Method (the result of summarise() is still grouped by
# Method, so sum(n) is the per-Method total).
plotdat_clarity_cat <- dat_summary %>%
  group_by(Method, clarity_cat) %>%
  summarise(n = n()) %>%
  mutate(
    pct = n / sum(n),
    lbl = scales::percent(pct)
  )
## `summarise()` has grouped output by 'Method'. You can override using the
## `.groups` argument.
# Stacked 100% bars per Method with the percentage printed inside each segment.
plot_clarity_cat <- ggplot(plotdat_clarity_cat, aes(x = Method, y = pct, fill = clarity_cat)) +
  geom_col(position = "fill") +
  geom_text(aes(label = lbl), position = position_stack(vjust = 0.5), size = 3) +
  scale_y_continuous(breaks = seq(0, 1, by = 0.2)) +
  scale_fill_brewer(palette = "Set2") +
  labs(
    title = "Imputation Results for Clarity",
    x = "Method",
    y = "Percent",
    fill = "Clarity quality"
  ) +
  theme_classic() +
  theme(axis.text.x = element_text(hjust = 1, angle = 45))

plot_clarity_cat


6. ざっくり考察

これだけみると、mice関数を使用した(pmmやロジスティック回帰などによる)単一補完でいいのでは……という気もしてきます。
少なくとも、回帰なら多重補完が必要!ランダムフォレストなら単一補完でもOK!とはいいがたいように思います。
データの複雑さにもよるとは思いますが、この程度のデータだとツリー数は100でもいいかもしれません。
連続変数についてはpmmを併用しても精度に問題はないようです。
missRangerのマニュアルにもある通り、pmmを併用する方が変数としての解釈性を維持しやすいです。
一方、連続変数の補完にpmmを併用すると、カテゴリ変数の推定精度にばらつきが生じるのかもしれません。