0. 本記事の説明

サンプルデータを使用して、MICE, missRangerの分類精度を比較します。
purrr::mapによる一括処理を行いました。


1. データの準備

# パッケージのロード
pacman::p_load(tidyverse, 
               mice, 
               missForest, 
               doParallel, 
               missRanger, 
               ggpubr)

# dplyr::selectを指定
select <- dplyr::select

# シードの固定
set.seed(1234)
# データのロード
dat <- ggplot2::diamonds %>% 
  
  # Ordered factorは欠測補完の時に手間が増えるので、cutの順序情報を削除する
  mutate(cut = factor(as.character(cut)),  
         cut = fct_relevel(cut, "Fair", "Good", "Very Good", "Premium", "Ideal")) %>% 
  
  # depthとtableはダイヤモンドの寸法(x, y, z)に基づく関数なので、ここでは除外
  select(-depth, -table) %>% 
  
  # x, y, zだと名前が分かりづらいのでrename(任意)
  rename(length = x, 
         width = y, 
         depth = z) %>% 
  
  # colorはバイナリ変数にする
  mutate(color_cat = case_when(color %in% c("D", "E", "F", "G") ~ 0, 
                               color %in% c("H", "I", "J") ~ 1), 
         color_cat = factor(color_cat, levels = c(0, 1), labels = c("Worse", "Better"))) %>% 
  
  # clarityは3値のカテゴリ変数にする
  mutate(clarity_cat = case_when(clarity %in% c("I1", "SI2", "SI1") ~ 0, 
                                 clarity %in% c("VS2", "VS1", "VVS2") ~ 1, 
                                 clarity %in% c("VVS1", "IF") ~ 2), 
         clarity_cat = factor(clarity_cat, levels = c(0, 1, 2), labels = c("Worse", "Moderate", "Better"))) %>% 
  
  # 使用しないものを除く
  select(-color, -clarity) %>% 
  
  # missForestはDFがtibble型だとエラーになる。基本as.data.frameをかますのが無難
  as.data.frame()

###

# 欠測値補完は基本的にはトレーニングデータでの内的妥当性について検証すればよいと思われるが、
# N数もそれなりにあるので、試しにテストデータをホールドアウトしておく。

# train:test = 4:1 とする
# データ分割用のランダムインデックス
test_indices <- sample(1:nrow(dat), size = floor(nrow(dat)/5))

# ランダムなインデックスに基づいてデータフレームを2つに分割
dat_test <- dat[test_indices, ] %>% 
  mutate(Method = "Test_original")

dat_train <- dat[-test_indices, ] %>% 
  mutate(Method = "Train_original")

# 欠測率20%での欠測とする
dat_train_NA <- dat_train %>% 
  select(-Method) %>% 
  generateNA(p = 0.2)


2. MICE

# ダミーのmiceデータを作成し、miceの引数を確認
# 補完方法:連続変数はpmm、2値変数はlogreg、多カテゴリ変数はpolyregになっていることを確認
# predictor_matrixはデフォルトのまま
dat_train_mice_0 <- mice(dat_train_NA, m = 1, maxit = 0)
method_train_mice <- dat_train_mice_0$method
predictorMatrix_train_mice <- dat_train_mice_0$predictorMatrix
method_train_mice
##       carat         cut       price      length       width       depth 
##       "pmm"   "polyreg"       "pmm"       "pmm"       "pmm"       "pmm" 
##   color_cat clarity_cat 
##    "logreg"   "polyreg"
predictorMatrix_train_mice
##             carat cut price length width depth color_cat clarity_cat
## carat           0   1     1      1     1     1         1           1
## cut             1   0     1      1     1     1         1           1
## price           1   1     0      1     1     1         1           1
## length          1   1     1      0     1     1         1           1
## width           1   1     1      1     0     1         1           1
## depth           1   1     1      1     1     0         1           1
## color_cat       1   1     1      1     1     1         0           1
## clarity_cat     1   1     1      1     1     1         1           0
# maxitはランダムフォレストのiterationにあわせて10とする
dat_mice <- mice(data = dat_train_NA,
                 m = 1, 
                 maxit = 10,
                 method = method_train_mice,
                 predictorMatrix = predictorMatrix_train_mice,
                 seed = 1234,
                 printFlag = FALSE)

# mice::completeでデータを取り出し
dat_mice_imp <- complete(dat_mice, action = 1) %>% 
  mutate(Method = "MICE")


3. missRanger

# missRangerははやいので、ツリー数は100と500(rangerのデフォルト)にしてみる
# pmmあり・なしも比較する
dat_ranger_100 <- missRanger(dat_train_NA, 
                             formula = . ~ ., 
                             maxiter = 10, 
                             pmm.k = 0, 
                             seed = 1234, 
                             verbose = 0, 
                             data_only = TRUE, 
                             num.trees = 100) %>% 
  mutate(Method = "Ranger_100")

dat_ranger_100_pmm <- missRanger(dat_train_NA, 
                                 formula = . ~ ., 
                                 maxiter = 10, 
                                 pmm.k = 3, 
                                 seed = 1234, 
                                 verbose = 0, 
                                 data_only = TRUE, 
                                 num.trees = 100) %>% 
  mutate(Method = "Ranger_100_pmm")

dat_ranger_500 <- missRanger(dat_train_NA, 
                             formula = . ~ ., 
                             maxiter = 10, 
                             pmm.k = 0, 
                             seed = 1234, 
                             verbose = 0, 
                             data_only = TRUE, 
                             num.trees = 500) %>% 
  mutate(Method = "Ranger_500")

dat_ranger_500_pmm <- missRanger(dat_train_NA, 
                                 formula = . ~ ., 
                                 maxiter = 10, 
                                 pmm.k = 3, 
                                 seed = 1234, 
                                 verbose = 0, 
                                 data_only = TRUE, 
                                 num.trees = 500) %>% 
  mutate(Method = "Ranger_500_pmm")


4. missForest

# 図示が目的なので省略

# # ここでは補完精度を検討するので、並列処理する
# # ntreeはmissForestのデフォルトの100とする
# cores <- detectCores(logical = FALSE)
# cl <- makeCluster(ncol(dat_train_NA))
# registerDoParallel(cl, cores = cores)
# 
# dat_forest <- missForest(dat_train_NA,
#                          maxiter = 10,
#                          ntree = 100,
#                          parallelize = "variables",
#                          verbose = TRUE)
# 
# dat_forest_imp <- dat_forest$ximp %>%
#   mutate(Method = "Forest_100")


5. 結果のまとめ

まず、連続変数です。

# 各解析データをまとめる
dat_plot <- dat_test %>% 
  bind_rows(dat_train, dat_mice_imp, dat_ranger_100, dat_ranger_100_pmm, dat_ranger_500, dat_ranger_500_pmm)
# purrr::mapで処理する。
dat_plot_con <- dat_plot %>% 
  pivot_longer(cols = c(carat, price, length, width, depth), 
               names_to = "vars", 
               values_to = "value")

# purrr::mapに渡す.xは、図示したい連続変数名
names_plot_con <- unique(dat_plot_con$vars)

# 反復処理
my_plots_con <- purrr::map(
  .x = names_plot_con,
  .f = ~{
    ggplot(data = dat_plot_con %>% filter(vars == .x), 
           aes(x = Method, y = value)) + 
      geom_boxplot() + 
      labs(x = "Method",
           title = paste0("Imputation Results for ", .x)) +
      theme_classic() + 
      theme(axis.text.x = element_text(size = 6, angle = 45, hjust = 1))
  })

# ggpubr::ggarangeは、リスト型で格納されたプロットデータを一括図示する
figure_con <- ggarrange(plotlist = my_plots_con, ncol = 3, nrow = 2)

figure_con


次に、カテゴリ変数です。

# purrr::map→ggpubr::ggarangeで一括処理する。
# purrr::mapで扱うために、カテゴリ変数について縦もちにする。
dat_plot_cat <- dat_plot %>% 
  pivot_longer(cols = c(cut, color_cat, clarity_cat), 
               names_to = "vars", 
               values_to = "value")

# purrr::mapに渡す.xは、図示したいカテゴリ変数名
names_plot_cat <- unique(dat_plot_cat$vars)

# 反復処理
my_plots_cat <- purrr::map(
  .x = names_plot_cat,
  .f = ~{
    ggplot(data = dat_plot_cat %>% 
             filter(vars == .x) %>% 
             group_by(Method, value) %>%
             summarize(n = n()) %>% 
             mutate(pct = n/sum(n), lbl = scales::percent(pct)), 
           aes(x = Method, y = pct, fill = value)) +
      geom_bar(stat = "identity", position = "fill") +
      scale_y_continuous(breaks = seq(0, 1, 0.2)) +
      geom_text(aes(label = lbl), size = 2, position = position_stack(vjust = 0.5)) +
      scale_fill_brewer(palette = "Set2") +
      labs(y = "Percent", 
           fill = paste0(.x, " quality"),
           x = "Method",
           title = paste0("Imputation Results for ", .x)) +
      theme_classic() + 
      theme(axis.text.x = element_text(size = 6, angle = 45, hjust = 1))
    })
## `summarise()` has grouped output by 'Method'. You can override using the
## `.groups` argument.
## `summarise()` has grouped output by 'Method'. You can override using the
## `.groups` argument.
## `summarise()` has grouped output by 'Method'. You can override using the
## `.groups` argument.
# ggpubr::ggarangeは、リスト型で格納されたプロットデータを一括図示する
figure_cat <- ggarrange(plotlist = my_plots_cat, ncol = 2, nrow = 2)

figure_cat


6. ざっくり考察

purrr::map使いこなすと便利ですね~