サンプルデータを使用して、オリジナルデータ、欠測率20%でのcomplete case analysis
(CCA)、missRangerによる単一のchained random forestによる補完、m=10,
100でのmissRangerによる多重補完を比較します。
# パッケージのロード
pacman::p_load(rio,
here,
tidyverse,
missRanger,
ggpubr,
svglite,
mitools,
ggpubr)
# Bind `select` explicitly to dplyr::select (presumably so that a same-named
# function from another attached package cannot mask it)
select <- dplyr::select
# Fix the global RNG seed so that the script is reproducible
set.seed(1234)
# Derive `n` distinct four-digit integers from a single seed value.
#
# Arguments:
#   seed  Seed passed to set.seed(); the same seed always yields the same draw.
#   n     Number of values to draw (sampling is without replacement from
#         1000:9999, so sample() errors if n exceeds 9000).
# Returns:
#   An integer vector of length `n` with no duplicates.
#
# The caller's global RNG state is saved on entry and restored on exit, so
# calling this helper does not silently reseed the rest of the script.
generate_random_numbers <- function(seed, n) {
  if (exists(".Random.seed", envir = globalenv(), inherits = FALSE)) {
    old_state <- get(".Random.seed", envir = globalenv(), inherits = FALSE)
    on.exit(assign(".Random.seed", old_state, envir = globalenv()), add = TRUE)
  } else {
    # The RNG had not been initialised yet: return to that state on exit
    on.exit(rm(".Random.seed", envir = globalenv()), add = TRUE)
  }
  set.seed(seed)
  sample(1000:9999, n, replace = FALSE)
}
# Build the analysis data set from ggplot2::diamonds
dat <- ggplot2::diamonds %>%
  # Ordered factors add work at imputation time, so turn `cut` into a plain
  # (unordered) factor while keeping the level order Fair ... Ideal
  mutate(cut = factor(as.character(cut),
                      levels = c("Fair", "Good", "Very Good", "Premium", "Ideal"))) %>%
  # depth and table are derived from the diamond's dimensions (x, y, z),
  # so they are excluded here
  select(-depth, -table) %>%
  # x, y, z are not self-explanatory names, so rename them (optional)
  rename(length = x,
         width = y,
         depth = z) %>%
  # Collapse color into a binary variable. In ggplot2::diamonds, D is the best
  # colour grade and J the worst, so D-G is "Better" and H-J is "Worse".
  # "Worse" stays the first (reference) level, as for clarity_cat below.
  # NOTE: output rendered before this labelling was corrected reports
  # color_catBetter with the groups swapped; re-knit to refresh it.
  mutate(color_cat = factor(
    case_when(color %in% c("D", "E", "F", "G") ~ "Better",
              color %in% c("H", "I", "J") ~ "Worse"),
    levels = c("Worse", "Better"))) %>%
  # Collapse clarity into a three-level categorical variable
  mutate(clarity_cat = factor(
    case_when(clarity %in% c("I1", "SI2", "SI1") ~ "Worse",
              clarity %in% c("VS2", "VS1", "VVS2") ~ "Moderate",
              clarity %in% c("VVS1", "IF") ~ "Better"),
    levels = c("Worse", "Moderate", "Better"))) %>%
  # Drop the original columns that are no longer used
  select(-color, -clarity) %>%
  # Convert to a plain data.frame: per the original author's note, imputation
  # packages such as missForest error on tibbles, so as.data.frame() is the
  # safe default before imputing
  as.data.frame()
# Introduce missing values at a rate of 20% per column.
# price is the regression outcome used to assess the imputations, so it is
# moved to the first column and given a missingness probability of 0.
dat_NA <- generateNA(
  select(dat, price, everything()),
  p = c(0, rep(0.2, times = ncol(dat) - 1))
)
オリジナルデータを使用した、priceを目的変数とした線形回帰
# Reference model: linear regression of price fitted to the complete
# (original) data
reg_original <- lm(
  price ~ carat + cut + length + width + depth + color_cat + clarity_cat,
  data = dat
)
CCA
# Complete case analysis: same model, fitted only to rows without any
# missing value
reg_cca <- lm(
  price ~ carat + cut + length + width + depth + color_cat + clarity_cat,
  data = drop_na(dat_NA)
)
single chained random forests
多重補完との比較が目的なので、ハイパーパラメータはmaxiter=10, num.trees=100,
pmm.k=3で固定します。
# Single imputation with chained random forests (fixed hyperparameters)
dat_ranger_single <- missRanger(
  dat_NA,
  formula = . ~ .,
  num.trees = 100,
  maxiter = 10,
  pmm.k = 3,
  seed = 1234,
  data_only = TRUE,
  verbose = 0
)
# Same regression model, fitted to the singly imputed data
reg_single <- lm(
  price ~ carat + cut + length + width + depth + color_cat + clarity_cat,
  data = dat_ranger_single
)
multiple chained random forest, m=10
# Draw one seed per imputation from a single master seed, so that each run is
# reproducible but differs from the others
seed_10 <- generate_random_numbers(1234, 10)
# Run missRanger once per seed. The completed data sets are collected in a
# list and stacked with one bind_rows() call; growing a data frame inside a
# loop would re-copy every accumulated row on each iteration.
dat_ranger_multi_10 <- purrr::map(seq_along(seed_10), function(i) {
  dat_tmp <- missRanger(dat_NA,
                        formula = . ~ .,
                        maxiter = 10,
                        pmm.k = 3,
                        seed = seed_10[i],
                        verbose = 0,
                        data_only = TRUE,
                        num.trees = 100) %>%
    # Tag each row with the index of the imputation it belongs to
    mutate(iter = i)
  # print progress
  if (i %% 2 == 0) {
    cat("procedure", i, "of", length(seed_10), "has been completed. \n")
  }
  dat_tmp
}) %>%
  bind_rows() %>%
  # Return the stacked data as a tibble
  as_tibble()
## procedure 2 of 10 has been completed.
## procedure 4 of 10 has been completed.
## procedure 6 of 10 has been completed.
## procedure 8 of 10 has been completed.
## procedure 10 of 10 has been completed.
# Fit the analysis model separately on each of the 10 imputed data sets
fit_ranger_10 <- dat_ranger_multi_10 %>%
  # One row per imputation, with its completed data in the list-column `data`.
  # nest(data = -iter) returns an ungrouped tibble, so no grouping by `iter`
  # lingers on the result (as it would after group_by() %>% nest()).
  nest(data = -iter) %>%
  # linear reg
  mutate(fit = map(data, function(imputed) {
    lm(price ~ carat + cut + length + width + depth + color_cat + clarity_cat,
       data = imputed)
  })) %>%
  # Extract the coefficient vector and its covariance matrix from every fit
  mutate(coef_fit = map(fit, coef),
         vcov_fit = map(fit, vcov))
# The results can be combined using the mitools package
combine_result_10 <- mitools::MIcombine(results = fit_ranger_10$coef_fit,
                                        variances = fit_ranger_10$vcov_fit)
# summary() prints the pooled table; its return value is stored for later use
summary_combine_result_10 <- summary(combine_result_10)
## Multiple imputation results:
## MIcombine.default(results = fit_ranger_10$coef_fit, variances = fit_ranger_10$vcov_fit)
## results se (lower upper) missInfo
## (Intercept) -1483.98008 773.17606 -3224.53129 256.57113 99 %
## carat 10082.77887 487.26106 8984.63701 11180.92073 99 %
## cutGood 956.84993 226.01575 450.09052 1463.60933 98 %
## cutVery Good 1082.07699 235.25164 553.89366 1610.26032 98 %
## cutPremium 1059.22291 234.40595 532.96724 1585.47859 98 %
## cutIdeal 1208.50435 238.39890 673.04566 1743.96303 98 %
## length -613.95560 189.33581 -1038.22314 -189.68806 97 %
## width 47.20303 35.78666 -27.09203 121.49808 67 %
## depth -247.96322 83.73484 -429.80296 -66.12348 87 %
## color_catBetter -757.89749 133.03368 -1058.04434 -457.75064 99 %
## clarity_catModerate 1086.56276 132.21036 788.30852 1384.81700 99 %
## clarity_catBetter 1578.32995 150.27679 1240.32858 1916.33132 98 %
multiple chained random forest, m=100
# Draw one seed per imputation from a single master seed, so that each run is
# reproducible but differs from the others
seed_100 <- generate_random_numbers(1234, 100)
# Run missRanger once per seed. The completed data sets are collected in a
# list and stacked with one bind_rows() call; growing a data frame inside a
# loop would re-copy every accumulated row on each of the 100 iterations.
dat_ranger_multi_100 <- purrr::map(seq_along(seed_100), function(i) {
  dat_tmp <- missRanger(dat_NA,
                        formula = . ~ .,
                        maxiter = 10,
                        pmm.k = 3,
                        seed = seed_100[i],
                        verbose = 0,
                        data_only = TRUE,
                        num.trees = 100) %>%
    # Tag each row with the index of the imputation it belongs to
    mutate(iter = i)
  # print progress
  if (i %% 20 == 0) {
    cat("procedure", i, "of", length(seed_100), "has been completed. \n")
  }
  dat_tmp
}) %>%
  bind_rows() %>%
  # Return the stacked data as a tibble
  as_tibble()
## procedure 20 of 100 has been completed.
## procedure 40 of 100 has been completed.
## procedure 60 of 100 has been completed.
## procedure 80 of 100 has been completed.
## procedure 100 of 100 has been completed.
# Fit the analysis model separately on each of the 100 imputed data sets
fit_ranger_100 <- dat_ranger_multi_100 %>%
  # One row per imputation, with its completed data in the list-column `data`.
  # nest(data = -iter) returns an ungrouped tibble, so no grouping by `iter`
  # lingers on the result (as it would after group_by() %>% nest()).
  nest(data = -iter) %>%
  # linear reg
  mutate(fit = map(data, function(imputed) {
    lm(price ~ carat + cut + length + width + depth + color_cat + clarity_cat,
       data = imputed)
  })) %>%
  # Extract the coefficient vector and its covariance matrix from every fit
  mutate(coef_fit = map(fit, coef),
         vcov_fit = map(fit, vcov))
# The results can be combined using the mitools package
combine_result_100 <- mitools::MIcombine(results = fit_ranger_100$coef_fit,
                                         variances = fit_ranger_100$vcov_fit)
# summary() prints the pooled table; its return value is stored for later use
summary_combine_result_100 <- summary(combine_result_100)
## Multiple imputation results:
## MIcombine.default(results = fit_ranger_100$coef_fit, variances = fit_ranger_100$vcov_fit)
## results se (lower upper) missInfo
## (Intercept) -1534.6689 792.78900 -3107.14746 37.80968 99 %
## carat 10116.5909 493.51464 9137.63788 11095.54399 99 %
## cutGood 979.2295 207.64673 567.56913 1390.88992 97 %
## cutVery Good 1112.1952 212.40134 691.06434 1533.32599 97 %
## cutPremium 1088.2246 214.63880 662.65215 1513.79697 97 %
## cutIdeal 1234.5331 214.98986 808.25496 1660.81121 97 %
## length -605.5088 198.58843 -999.27528 -211.74236 97 %
## width 39.9322 41.74024 -42.42721 122.29162 74 %
## depth -247.5604 65.67109 -377.17699 -117.94388 76 %
## color_catBetter -837.5948 109.17611 -1054.15161 -621.03794 99 %
## clarity_catModerate 1081.0826 110.32512 862.24268 1299.92247 99 %
## clarity_catBetter 1588.9413 120.71885 1349.58791 1828.29471 97 %
100回のmissRangerによる補完を、オリジナルデータと比べます。
連続変数はほぼ変わらなかったので、ここではカテゴリ変数のみ図示します。
# Label every imputed data set with its own method name ("multi_<iter>")
dat_ranger_multi_100_plot <- mutate(dat_ranger_multi_100,
                                    Method = paste0("multi_", iter))
# Put the original data first, then the 100 imputed data sets, and fix the
# display order of Method accordingly
dat_plot <- bind_rows(
  mutate(dat, Method = "original", iter = 1),
  dat_ranger_multi_100_plot
) %>%
  mutate(Method = factor(Method,
                         levels = c("original", paste0("multi_", 1:100))))
# Reshape the categorical variables to long format so that they can be
# iterated over with purrr::map
dat_plot_cat <- pivot_longer(
  dat_plot,
  cols = c(cut, color_cat, clarity_cat),
  names_to = "vars",
  values_to = "value"
)
# Build one stacked bar chart per categorical variable, showing the share of
# each category for the original data and for every imputed data set
my_plots_cat <- purrr::map(
  .x = unique(dat_plot_cat$vars),
  .f = function(var_name) {
    # Share of each category within each Method. The grouping left after
    # summarize() is stated explicitly (this also avoids the informational
    # message about grouped output) and removed once pct is computed.
    share_by_method <- dat_plot_cat %>%
      filter(vars == var_name) %>%
      group_by(Method, value) %>%
      summarize(n = n(), .groups = "drop_last") %>%
      mutate(pct = n / sum(n)) %>%
      ungroup()
    ggplot(data = share_by_method,
           aes(x = Method, y = pct, fill = value)) +
      geom_bar(stat = "identity", position = "fill") +
      scale_y_continuous(breaks = seq(0, 1, 0.2)) +
      scale_fill_brewer(palette = "Set2") +
      labs(y = "Percent",
           fill = paste0(var_name, " quality"),
           x = "Method",
           title = paste0("Imputation Results for ", var_name)) +
      theme_classic() +
      theme(axis.text.x = element_text(size = 6, angle = 45, hjust = 1))
  })
## `summarise()` has grouped output by 'Method'. You can override using the
## `.groups` argument.
## `summarise()` has grouped output by 'Method'. You can override using the
## `.groups` argument.
## `summarise()` has grouped output by 'Method'. You can override using the
## `.groups` argument.
# ggpubr::ggarrange draws all plots stored in a list as a single figure;
# here the three plots are stacked in one column
figure_cat <- ggarrange(
  plotlist = my_plots_cat,
  nrow = 3,
  ncol = 1
)
figure_cat
次に、priceを目的変数とした線形回帰分析について、各々の係数の点推定値を比べてみます。
# Collect the point estimates of every approach side by side.
# Each column is aligned to `vars` by coefficient name rather than by
# position, so a model whose coefficients are ordered differently (or lack a
# term) yields NA in the right row instead of silently misaligned values.
compare_results <- tibble(
  vars = names(coef(reg_original)),
  original = unname(coef(reg_original)[vars]),
  cca = unname(coef(reg_cca)[vars]),
  single = unname(coef(reg_single)[vars]),
  # The pooled summaries are data frames with coefficient names as row names
  mi_10 = summary_combine_result_10$results[
    match(vars, rownames(summary_combine_result_10))],
  mi_100 = summary_combine_result_100$results[
    match(vars, rownames(summary_combine_result_100))]
)
compare_results
## # A tibble: 12 × 6
## vars original cca single mi_10 mi_100
## <chr> <dbl> <dbl> <dbl> <dbl> <dbl>
## 1 (Intercept) -1115. -489. -2332. -1484. -1535.
## 2 carat 10627. 11042. 9610. 10083. 10117.
## 3 cutGood 1047. 1025. 1168. 957. 979.
## 4 cutVery Good 1303. 1261. 1198. 1082. 1112.
## 5 cutPremium 1287. 1197. 1212. 1059. 1088.
## 6 cutIdeal 1456. 1384. 1330. 1209. 1235.
## 7 length -820. -880. -402. -614. -606.
## 8 width 59.8 112. 47.2 47.2 39.9
## 9 depth -230. -462. -260. -248. -248.
## 10 color_catBetter -1038. -1100. -879. -758. -838.
## 11 clarity_catModerate 1221. 1244. 957. 1087. 1081.
## 12 clarity_catBetter 1783. 1782. 1522. 1578. 1589.
最後に、オリジナルデータとの差の絶対値をとり、図示します。
小さいほうが、オリジナルに近い値を示していることになります。
# Absolute distance of each estimate from the original-data estimate,
# rounded to one decimal and reshaped to long format for plotting
compare_results_2 <- compare_results %>%
  mutate(across(
    c(cca, single, mi_10, mi_100),
    ~ round(abs(.x - original), 1)
  )) %>%
  select(-original) %>%
  pivot_longer(
    cols = -vars,
    names_to = "variables",
    values_to = "Distance_from_original"
  ) %>%
  # Fix the order in which the methods appear in the plot
  mutate(variables = factor(variables,
                            levels = c("cca", "single", "mi_10", "mi_100")))
# Dodged bar chart: one group of bars per coefficient, one bar per method
figure <- ggbarplot(
  data = compare_results_2,
  x = "vars",
  y = "Distance_from_original",
  color = "variables",
  fill = "variables",
  label = FALSE,
  position = position_dodge(0.9),
  ggtheme = theme_classic() +
    theme(axis.text.x = element_text(angle = 45, hjust = 1))
)
figure
最後の図をみると、必ずしも多重補完をしたほうが、オリジナルの値に近づくというわけでもないようです。
お疲れ様でした。