0. 本記事の説明

サンプルデータを使用して、データ、欠測率20%のcomplete case analysis (CCA)、missRangerによる単一のchained random forestによる補完、m=10, 100でのmissRangerによる多重補完を比較します。

1. データの準備

# パッケージのロード
pacman::p_load(rio, 
               here, 
               tidyverse, 
               missRanger, 
               ggpubr, 
               svglite, 
               mitools, 
               ggpubr)

# dplyr::selectを指定
select <- dplyr::select

# シードの固定
set.seed(1234)

# 一つのシード値で複数の乱数を発生させる関数
generate_random_numbers <- function(seed, n){
  set.seed(seed)
  return(sample(1000:9999, n, replace = FALSE))
}
# データのロード
dat <- ggplot2::diamonds %>% 
  
  # Ordered factorは欠測補完の時に手間が増えるので、cutの順序情報を削除する
  mutate(cut = factor(as.character(cut)),  
         cut = fct_relevel(cut, "Fair", "Good", "Very Good", "Premium", "Ideal")) %>% 
  
  # depthとtableはダイヤモンドの寸法(x, y, z)に基づく関数なので、ここでは除外
  select(-depth, -table) %>% 
  
  # x, y, zだと名前が分かりづらいのでrename(任意)
  rename(length = x, 
         width = y, 
         depth = z) %>% 
  
  # colorはバイナリ変数にする
  mutate(color_cat = case_when(color %in% c("D", "E", "F", "G") ~ 0, 
                               color %in% c("H", "I", "J") ~ 1), 
         color_cat = factor(color_cat, levels = c(0, 1), labels = c("Worse", "Better"))) %>% 
  
  # clarityは3値のカテゴリ変数にする
  mutate(clarity_cat = case_when(clarity %in% c("I1", "SI2", "SI1") ~ 0, 
                                 clarity %in% c("VS2", "VS1", "VVS2") ~ 1, 
                                 clarity %in% c("VVS1", "IF") ~ 2), 
         clarity_cat = factor(clarity_cat, levels = c(0, 1, 2), labels = c("Worse", "Moderate", "Better"))) %>% 
  
  # 使用しないものを除く
  select(-color, -clarity) %>% 
  
  # missForestはDFがtibble型だとエラーになる。基本as.data.frameをかますのが無難
  as.data.frame()
# 欠測率20%での欠測とする
# 多重補完の性能を検討するにあたり、priceを予測する
dat_NA <- dat %>% 
  select(price, everything()) %>% 
  generateNA(p = c(0, rep(0.2, times = ncol(dat) - 1)))


2. 実行

オリジナルデータを使用した、priceを目的変数とした線形回帰

reg_original <- lm(price ~ carat + cut + length + width + depth + color_cat + clarity_cat, 
                   data = dat)


CCA

reg_cca <- lm(price ~ carat + cut + length + width + depth + color_cat + clarity_cat, 
              data = dat_NA %>% drop_na())


single chained random forests
多重補完との比較が目的なので、ハイパーパラメータはmaxiter=10, tree=100, pmm.k=3で固定します。

dat_ranger_single <- missRanger(dat_NA,
                                formula = . ~ .,
                                maxiter = 10,
                                pmm.k = 3,
                                seed = 1234,
                                verbose = 0,
                                data_only = TRUE,
                                num.trees = 100)

reg_single <- lm(price ~ carat + cut + length + width + depth + color_cat + clarity_cat, 
                 data = dat_ranger_single)


multiple chained random forest, m=10

dat_ranger_multi_10 <- tibble()
seed_10 <- generate_random_numbers(1234, 10)

for (i in 1:10){
  dat_tmp <- missRanger(dat_NA, 
                        formula = . ~ ., 
                        maxiter = 10, 
                        pmm.k = 3, 
                        seed = seed_10[i], 
                        verbose = 0, 
                        data_only = TRUE, 
                        num.trees = 100) %>% 
    mutate(iter = i)
  
  dat_ranger_multi_10 <- dat_ranger_multi_10 %>% 
    bind_rows(dat_tmp)
  
  # print progress
  if(i %% 2==0){cat("procedure", i, "of 10 has been completed. \n")}
}
## procedure 2 of 10 has been completed. 
## procedure 4 of 10 has been completed. 
## procedure 6 of 10 has been completed. 
## procedure 8 of 10 has been completed. 
## procedure 10 of 10 has been completed.
fit_ranger_10 <- dat_ranger_multi_10 %>%
  # Stacked up dataset
  group_by(iter) %>%
  
  # Nested data frame
  nest() %>%
  
  # linear reg
  mutate(fit = map(data, function(data) {
    reg_tmp <- lm(price ~ carat + cut + length + width + depth + color_cat + clarity_cat, 
                  data = data)
    return(reg_tmp)
  })) %>%
  
  mutate(coef_fit = map(fit, coef),
         vcov_fit = map(fit, vcov))

# The results can be combined using the mitools package
combine_result_10 <- mitools::MIcombine(results = fit_ranger_10$coef_fit,
                                        variances = fit_ranger_10$vcov_fit)
summary_combine_result_10 <- summary(combine_result_10)
## Multiple imputation results:
##       MIcombine.default(results = fit_ranger_10$coef_fit, variances = fit_ranger_10$vcov_fit)
##                         results        se      (lower      upper) missInfo
## (Intercept)         -1483.98008 773.17606 -3224.53129   256.57113     99 %
## carat               10082.77887 487.26106  8984.63701 11180.92073     99 %
## cutGood               956.84993 226.01575   450.09052  1463.60933     98 %
## cutVery Good         1082.07699 235.25164   553.89366  1610.26032     98 %
## cutPremium           1059.22291 234.40595   532.96724  1585.47859     98 %
## cutIdeal             1208.50435 238.39890   673.04566  1743.96303     98 %
## length               -613.95560 189.33581 -1038.22314  -189.68806     97 %
## width                  47.20303  35.78666   -27.09203   121.49808     67 %
## depth                -247.96322  83.73484  -429.80296   -66.12348     87 %
## color_catBetter      -757.89749 133.03368 -1058.04434  -457.75064     99 %
## clarity_catModerate  1086.56276 132.21036   788.30852  1384.81700     99 %
## clarity_catBetter    1578.32995 150.27679  1240.32858  1916.33132     98 %


multiple chained random forest, m=100

dat_ranger_multi_100 <- tibble()
seed_100 <- generate_random_numbers(1234, 100)

for (i in 1:100){
  dat_tmp <- missRanger(dat_NA, 
                        formula = . ~ ., 
                        maxiter = 10, 
                        pmm.k = 3, 
                        seed = seed_100[i], 
                        verbose = 0, 
                        data_only = TRUE, 
                        num.trees = 100) %>% 
    mutate(iter = i)
  
  dat_ranger_multi_100 <- dat_ranger_multi_100 %>% 
    bind_rows(dat_tmp)
  
  # print progress
  if(i %% 20==0){cat("procedure", i, "of 100 has been completed. \n")}
}
## procedure 20 of 100 has been completed. 
## procedure 40 of 100 has been completed. 
## procedure 60 of 100 has been completed. 
## procedure 80 of 100 has been completed. 
## procedure 100 of 100 has been completed.
fit_ranger_100 <- dat_ranger_multi_100 %>%
  # Stacked up dataset
  group_by(iter) %>%
  
  # Nested data frame
  nest() %>%
  
  # linear reg
  mutate(fit = map(data, function(data) {
    reg_tmp <- lm(price ~ carat + cut + length + width + depth + color_cat + clarity_cat, 
                  data = data)
    return(reg_tmp)
  })) %>%
  
  mutate(coef_fit = map(fit, coef),
         vcov_fit = map(fit, vcov))

# The results can be combined using the mitools package
combine_result_100 <- mitools::MIcombine(results = fit_ranger_100$coef_fit,
                                         variances = fit_ranger_100$vcov_fit)
summary_combine_result_100 <- summary(combine_result_100)
## Multiple imputation results:
##       MIcombine.default(results = fit_ranger_100$coef_fit, variances = fit_ranger_100$vcov_fit)
##                        results        se      (lower      upper) missInfo
## (Intercept)         -1534.6689 792.78900 -3107.14746    37.80968     99 %
## carat               10116.5909 493.51464  9137.63788 11095.54399     99 %
## cutGood               979.2295 207.64673   567.56913  1390.88992     97 %
## cutVery Good         1112.1952 212.40134   691.06434  1533.32599     97 %
## cutPremium           1088.2246 214.63880   662.65215  1513.79697     97 %
## cutIdeal             1234.5331 214.98986   808.25496  1660.81121     97 %
## length               -605.5088 198.58843  -999.27528  -211.74236     97 %
## width                  39.9322  41.74024   -42.42721   122.29162     74 %
## depth                -247.5604  65.67109  -377.17699  -117.94388     76 %
## color_catBetter      -837.5948 109.17611 -1054.15161  -621.03794     99 %
## clarity_catModerate  1081.0826 110.32512   862.24268  1299.92247     99 %
## clarity_catBetter    1588.9413 120.71885  1349.58791  1828.29471     97 %


3. 結果の図示

100回のmissRangerによる補完を、オリジナルデータと比べます。
連続変数はほぼ変わらなかったので、ここではカテゴリ変数のみ図示します。

dat_ranger_multi_100_plot <- dat_ranger_multi_100 %>% 
  mutate(Method = paste0("multi_", iter))

dat_plot <- dat %>% 
  mutate(Method = "original", 
         iter = 1) %>% 
  bind_rows(dat_ranger_multi_100_plot) %>% 
  mutate(Method = factor(Method, levels = c("original", paste0("multi_", 1:100))))

# purrr::mapで扱うために、カテゴリ変数について縦もちにする。
dat_plot_cat <- dat_plot %>% 
  pivot_longer(cols = c(cut, color_cat, clarity_cat), 
               names_to = "vars", 
               values_to = "value")

# 反復処理
my_plots_cat <- purrr::map(
  .x = unique(dat_plot_cat$vars),
  .f = ~{
    ggplot(data = dat_plot_cat %>% 
             filter(vars == .x) %>% 
             group_by(Method, value) %>%
             summarize(n = n()) %>% 
             mutate(pct = n/sum(n)), 
           aes(x = Method, y = pct, fill = value)) +
      geom_bar(stat = "identity", position = "fill") +
      scale_y_continuous(breaks = seq(0, 1, 0.2)) +
      scale_fill_brewer(palette = "Set2") +
      labs(y = "Percent", 
           fill = paste0(.x, " quality"),
           x = "Method",
           title = paste0("Imputation Results for ", .x)) +
      theme_classic() + 
      theme(axis.text.x = element_text(size = 6, angle = 45, hjust = 1))
  })
## `summarise()` has grouped output by 'Method'. You can override using the
## `.groups` argument.
## `summarise()` has grouped output by 'Method'. You can override using the
## `.groups` argument.
## `summarise()` has grouped output by 'Method'. You can override using the
## `.groups` argument.
# ggpubr::ggarangeは、リスト型で格納されたプロットデータを一括図示する
figure_cat <- ggarrange(plotlist = my_plots_cat, ncol = 1, nrow = 3)

figure_cat


次に、priceを目的変数とした線形回帰分析について、各々の係数の点推定値を比べてみます。

compare_results <- tibble(vars = names(coef(reg_original)), 
                          original = coef(reg_original),
                          cca = coef(reg_cca),
                          single = coef(reg_single),
                          mi_10 = summary_combine_result_10$results, 
                          mi_100 = summary_combine_result_100$results)

compare_results
## # A tibble: 12 × 6
##    vars                original    cca  single   mi_10  mi_100
##    <chr>                  <dbl>  <dbl>   <dbl>   <dbl>   <dbl>
##  1 (Intercept)          -1115.   -489. -2332.  -1484.  -1535. 
##  2 carat                10627.  11042.  9610.  10083.  10117. 
##  3 cutGood               1047.   1025.  1168.    957.    979. 
##  4 cutVery Good          1303.   1261.  1198.   1082.   1112. 
##  5 cutPremium            1287.   1197.  1212.   1059.   1088. 
##  6 cutIdeal              1456.   1384.  1330.   1209.   1235. 
##  7 length                -820.   -880.  -402.   -614.   -606. 
##  8 width                   59.8   112.    47.2    47.2    39.9
##  9 depth                 -230.   -462.  -260.   -248.   -248. 
## 10 color_catBetter      -1038.  -1100.  -879.   -758.   -838. 
## 11 clarity_catModerate   1221.   1244.   957.   1087.   1081. 
## 12 clarity_catBetter     1783.   1782.  1522.   1578.   1589.


最後に、オリジナルデータとの差の絶対値をとり、図示します。
小さいほうが、オリジナルに近い値を示していることになります。

compare_results_2 <- compare_results %>% 
  mutate(across(c(cca, single, mi_10, mi_100), ~{round(abs(.x - original), 1)})) %>% 
  select(-original) %>% 
  pivot_longer(cols = c(-vars), 
               names_to = "variables", 
               values_to = "Distance_from_original") %>% 
  mutate(variables = factor(variables, levels = c("cca", "single", "mi_10", "mi_100")))
  
figure <- ggbarplot(
  data = compare_results_2, 
  x = "vars", 
  y = "Distance_from_original",
  fill = "variables", 
  color = "variables",
  position = position_dodge(0.9), 
  ggtheme = theme_classic() + theme(axis.text.x = element_text(angle = 45, hjust = 1)), 
  label = FALSE)

figure


最後の図をみると、必ずしも多重補完をしたほうが、オリジナルの値に近づくというわけでもないようです。
お疲れ様でした。