library(tidyverse)
## library(DPpackage)
library(MCMCpack)
Here we will use the T cell receptor example in BNPDA (p. 10) to demonstrate the Dirichlet process prior.
The data are based on Guindani et al. (2014). Note that a count of 0 cannot occur in these data by design, and counts of 5 or greater were not observed.
tcell <- tribble(
~count, ~frequency,
1, 37,
2, 11,
3, 5,
4, 2,
## >= 5, 0,
)
tcell
## # A tibble: 4 x 2
## count frequency
## <dbl> <dbl>
## 1 1 37
## 2 2 11
## 3 3 5
## 4 4 2
The model is the following.
\[ y_{i} \,|\, G \overset{iid}{\sim} G \\ G \sim DP(M, G_{0}) \]
The notation \(y_{i} | G \overset{iid}{\sim} G\) puzzled me for a while, but the data-given-parameter notation on the left-hand side is necessary because in Bayesian statistics we consider the joint distribution of the data and the parameter. An iid statement, as seen in the frequentist paradigm, is only possible after fixing (conditioning on) the parameter. Importantly, once we marginalize this conditioning over the prior distribution, exchangeability remains, but not, in general, independence.
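To see why the sequence is not marginally iid, it helps to marginalize \(G\) out of the DP, which gives the Pólya urn predictive rule: \(y_{n+1}\) is a fresh draw from \(G_{0}\) with probability \(M / (M + n)\), and a copy of a uniformly chosen previous observation otherwise. Below is a minimal sketch of this scheme (my addition; polya_urn_draws is a hypothetical helper, not from any package, and a plain untruncated Poisson(2) stands in for \(G_{0}\) for brevity).
polya_urn_draws <- function(n, M, G0_sampler) {
    y <- numeric(n)
    for (i in seq_len(n)) {
        ## With probability M / (M + i - 1), draw fresh from G0
        if (runif(1) < M / (M + i - 1)) {
            y[i] <- G0_sampler()
        } else {
            ## Otherwise copy a previous draw chosen uniformly at random
            y[i] <- y[sample.int(i - 1, size = 1)]
        }
    }
    y
}
set.seed(1)
## Early values tend to get copied, so the draws are dependent, not iid.
polya_urn_draws(n = 10, M = 1, G0_sampler = function() rpois(n = 1, lambda = 2))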
In the Dirichlet process formulation, each random probability measure \(G\) has the same support as the centering measure \(G_{0}\). Here we will choose a Poisson distribution with mean 2, truncated below at 1 (zeros are removed). Because the centering measure is discrete, the partitioning property is relatively easy to see. If we collapse all values beyond 8 into an 8+ bin, the distribution of the resulting probability vector is the following.
\[ \begin{bmatrix} G(1)\\ G(2)\\ \vdots\\ G(8+)\\ \end{bmatrix} \sim Dirichlet \begin{bmatrix} M G_{0}(1)\\ M G_{0}(2)\\ \vdots\\ M G_{0}(8+)\\ \end{bmatrix} \]
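Each coordinate of a Dirichlet vector is marginally Beta distributed, so \(E[G(j)] = G_{0}(j)\) and \(Var[G(j)] = G_{0}(j)(1 - G_{0}(j)) / (M + 1)\). As a quick sanity check of this partitioning property (my addition, using a small toy centering vector rather than the Poisson one defined below), we can compare Monte Carlo moments of Dirichlet draws with these formulas.
set.seed(1)
M_check <- 5
g0_check <- c(0.5, 0.3, 0.2)
dir_draws <- MCMCpack::rdirichlet(10000, alpha = M_check * g0_check)
## Empirical means should be close to g0_check
colMeans(dir_draws)
## Empirical variances should be close to g0_check * (1 - g0_check) / (M_check + 1)
apply(dir_draws, 2, var)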
Now we define a function that returns the appropriately collapsed finite probability vector for the centering measure.
G0_vector <- function(lambda, min, max) {
## values below min are truncated
## values above max are collapsed
## Probabilities from Poisson(lambda) for min:max
p_vec <- dpois(x = min:max, lambda = lambda)
## Probability mass above max
p_upper_tail <- 1 - ppois(q = max, lambda = lambda)
## Collapse the upper tail into the max bin
p_vec[length(p_vec)] <- p_vec[length(p_vec)] + p_upper_tail
## Renormalize so the distribution truncated below min sums to 1
p_vec <- p_vec / sum(p_vec)
## Name the entries by their values
names(p_vec) <- min:max
##
return(p_vec)
}
## Collapse all values beyond 8 to 8+ bin
G0_vector(lambda = 2, min = 1, max = 8)
## 1 2 3 4 5 6 7 8
## 0.313035285 0.313035285 0.208690190 0.104345095 0.041738038 0.013912679 0.003975051 0.001268375
Now we can define a function to create random Dirichlet draws given the probability vector and mass \(M\).
## This function gives random Dirichlet draws via normalization of Gammas.
MCMCpack::rdirichlet
## function (n, alpha)
## {
## l <- length(alpha)
## x <- matrix(rgamma(l * n, alpha), ncol = l, byrow = TRUE)
## sm <- x %*% rep(1, l)
## return(x/as.vector(sm))
## }
## <bytecode: 0x7f9fc9bea560>
## <environment: namespace:MCMCpack>
## Make sure G0_vector has integer names
draw_dirichlet <- function(n_draws, M, G0_vector) {
MCMCpack::rdirichlet(n_draws, alpha = M * G0_vector) %>%
t %>%
as_data_frame %>%
mutate(y = names(G0_vector)) %>%
gather(key = .iter, value = p, -y) %>%
mutate(.iter = as.integer(gsub("V", "", .iter)),
M = M)
}
draw_dirichlet(n_draws = 10, M = 1, G0_vector = G0_vector(lambda = 2, min = 1, max = 8))
## # A tibble: 80 x 4
## y .iter p M
## <chr> <int> <dbl> <dbl>
## 1 1 1 2.82e- 1 1
## 2 2 1 4.26e- 1 1
## 3 3 1 2.03e- 1 1
## 4 4 1 8.96e- 2 1
## 5 5 1 5.77e-20 1
## 6 6 1 9.89e-44 1
## 7 7 1 9.17e-25 1
## 8 8 1 2.20e-82 1
## 9 1 2 8.04e- 1 1
## 10 2 2 5.90e- 4 1
## # ... with 70 more rows
Let us visualize some random draws from the prior at different \(M\).
G0_vector_values <- G0_vector(lambda = 2, min = 1, max = 8)
prior_mean <- data_frame(y = names(G0_vector_values),
p = G0_vector_values,
.iter = 0)
n_draws <- 50
prior_draws <- bind_rows(
draw_dirichlet(n_draws = n_draws, M = 0.1, G0_vector = G0_vector_values),
draw_dirichlet(n_draws = n_draws, M = 1, G0_vector = G0_vector_values),
draw_dirichlet(n_draws = n_draws, M = 10, G0_vector = G0_vector_values),
draw_dirichlet(n_draws = n_draws, M = 100, G0_vector = G0_vector_values),
)
prior_draws %>%
ggplot(mapping = aes(x = y, y = p, group = .iter)) +
geom_line(data = prior_mean, size = 1) +
geom_line(size = 0.1) +
facet_wrap(~ M) +
theme_bw() +
theme(axis.text.x = element_text(angle = 90, vjust = 0.5),
legend.key = element_blank(),
plot.title = element_text(hjust = 0.5),
strip.background = element_blank())
The thicker line is \(G_{0}\). We can see that as the mass parameter \(M\) increases, the prior draws of \(G\) (thin lines) concentrate more tightly around \(G_{0}\).
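We can quantify this concentration. By the marginal Beta distribution noted earlier, the prior standard deviation of \(G(j)\) is \(\sqrt{G_{0}(j)(1 - G_{0}(j)) / (M + 1)}\), which shrinks roughly like \(1/\sqrt{M}\). The following summary (my addition; a rough check, since we only have 50 draws per \(M\)) compares the empirical spread of the draws against this formula.
prior_draws %>%
    group_by(M, y) %>%
    summarize(empirical_sd = sd(p)) %>%
    left_join(data_frame(y = names(G0_vector_values),
                         G0 = G0_vector_values),
              by = "y") %>%
    mutate(theoretical_sd = sqrt(G0 * (1 - G0) / (M + 1)))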
Now we can update our model with the observed data.
df1 <- data_frame(y = c(1:4),
count = c(37, 11, 5, 2)) %>%
mutate(total_n = sum(count))
df1
## # A tibble: 4 x 3
## y count total_n
## <int> <dbl> <dbl>
## 1 1 37 55
## 2 2 11 55
## 3 3 5 55
## 4 4 2 55
The total sample size is \(n = 55\). The value 1 was observed most frequently, and counts of 5 or greater were not observed. The empirical distribution is the following.
df1_extended <- df1 %>%
bind_rows(data_frame(y = 5:8,
count = 0)) %>%
mutate(total_n = sum(count),
p = count/total_n,
.iter = 0)
df1_extended %>%
ggplot(mapping = aes(x = y, y = p)) +
geom_line() +
theme_bw() +
scale_y_continuous(limits = c(0,1)) +
scale_x_discrete(limits = c(1:8)) +
theme(axis.text.x = element_text(angle = 90, vjust = 0.5),
legend.key = element_blank(),
plot.title = element_text(hjust = 0.5),
strip.background = element_blank())
The Dirichlet process prior is conjugate to iid sampling, so we get a Dirichlet process posterior.
\[ G | \mathbf{y} \sim DP \left( M+n, \frac{M}{M+n}G_{0} + \frac{n}{M+n}\widehat{G} \right) \]
The updated centering measure is a weighted mean of the prior centering measure \(G_{0}\) and the empirical measure (empirical probability mass function) \(\widehat{G}\). The mass (concentration) parameter becomes \(M+n\). \(M\) can roughly be interpreted as the “sample size” of the prior information.
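On the collapsed bins, this DP update reduces to the familiar Dirichlet-multinomial conjugate update: the posterior Dirichlet parameter is the prior parameter plus the observed counts, \(M G_{0}(j) + n_{j}\), and normalizing it recovers the updated centering measure. A quick sketch (my addition) for \(M = 10\); the result should match the `M = 10` column computed below.
M_check <- 10
counts <- c(37, 11, 5, 2, 0, 0, 0, 0)
alpha_posterior <- M_check * G0_vector_values + counts
## Normalizing the posterior Dirichlet parameter gives the updated centering measure
alpha_posterior / sum(alpha_posterior)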
df1_posterior <- data_frame(y = as.integer(names(G0_vector_values)),
G0 = G0_vector_values,
n = 55,
G_hat = c(37, 11, 5, 2, 0, 0, 0, 0) / 55,
## Weighted average at different M
`M = 0.1` = (0.1 * G0 + n * G_hat) / (0.1 + n),
`M = 1` = (1 * G0 + n * G_hat) / (1 + n),
`M = 10` = (10 * G0 + n * G_hat) / (10 + n),
`M = 100` = (100 * G0 + n * G_hat) / (100 + n),
)
G0_vector_new_M0.1 <- df1_posterior$`M = 0.1`
G0_vector_new_M1 <- df1_posterior$`M = 1`
G0_vector_new_M10 <- df1_posterior$`M = 10`
G0_vector_new_M100 <- df1_posterior$`M = 100`
names(G0_vector_new_M0.1) <- 1:8
names(G0_vector_new_M1) <- 1:8
names(G0_vector_new_M10) <- 1:8
names(G0_vector_new_M100) <- 1:8
df1_posterior
## # A tibble: 8 x 8
## y G0 n G_hat `M = 0.1` `M = 1` `M = 10` `M = 100`
## <int> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl>
## 1 1 0.313 55 0.673 0.672 0.666 0.617 0.441
## 2 2 0.313 55 0.2 0.200 0.202 0.217 0.273
## 3 3 0.209 55 0.0909 0.0911 0.0930 0.109 0.167
## 4 4 0.104 55 0.0364 0.0365 0.0376 0.0468 0.0802
## 5 5 0.0417 55 0 0.0000757 0.000745 0.00642 0.0269
## 6 6 0.0139 55 0 0.0000252 0.000248 0.00214 0.00898
## 7 7 0.00398 55 0 0.00000721 0.0000710 0.000612 0.00256
## 8 8 0.00127 55 0 0.00000230 0.0000226 0.000195 0.000818
n <- 55
n_draws <- 50
posterior_draws <- bind_rows(
draw_dirichlet(n_draws = n_draws, M = n + 0.1, G0_vector = G0_vector_new_M0.1) %>% mutate(M = 0.1),
draw_dirichlet(n_draws = n_draws, M = n + 1, G0_vector = G0_vector_new_M1) %>% mutate(M = 1),
draw_dirichlet(n_draws = n_draws, M = n + 10, G0_vector = G0_vector_new_M10) %>% mutate(M = 10),
draw_dirichlet(n_draws = n_draws, M = n + 100, G0_vector = G0_vector_new_M100) %>% mutate(M = 100),
)
posterior_draws %>%
ggplot(mapping = aes(x = y, y = p, group = .iter)) +
## Draws
geom_line(size = 0.1) +
## Prior centering measure
geom_line(data = prior_mean, size = 1) +
## Empirical distribution
geom_line(data = df1_extended, size = 1, color = "red") +
scale_y_continuous(limits = c(0,1)) +
facet_wrap(~ M) +
theme_bw() +
theme(axis.text.x = element_text(angle = 90, vjust = 0.5),
legend.key = element_blank(),
plot.title = element_text(hjust = 0.5),
strip.background = element_blank())
Here the thick red line is the empirical measure \(\widehat{G}\) and the thick black line is the prior centering measure \(G_{0}\). The thin lines are the posterior samples. As \(M\) increases, the posterior draws approach \(G_{0}\) (more influence of the prior).
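Because the posterior predictive distribution of a new observation is exactly the updated centering measure, we can also read off how much probability each choice of \(M\) leaves on the never-observed counts \(y \geq 5\).
df1_posterior %>%
    filter(y >= 5) %>%
    summarize_at(vars(`M = 0.1`, `M = 1`, `M = 10`, `M = 100`), sum)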
print(sessionInfo())
## R version 3.5.1 (2018-07-02)
## Platform: x86_64-apple-darwin15.6.0 (64-bit)
## Running under: macOS 10.14.1
##
## Matrix products: default
## BLAS: /Library/Frameworks/R.framework/Versions/3.5/Resources/lib/libRblas.0.dylib
## LAPACK: /Library/Frameworks/R.framework/Versions/3.5/Resources/lib/libRlapack.dylib
##
## locale:
## [1] en_US.UTF-8/en_US.UTF-8/en_US.UTF-8/C/en_US.UTF-8/en_US.UTF-8
##
## attached base packages:
## [1] parallel stats graphics grDevices utils datasets methods base
##
## other attached packages:
## [1] bindrcpp_0.2.2 MCMCpack_1.4-4 MASS_7.3-51 coda_0.19-2 forcats_0.3.0 stringr_1.3.1
## [7] dplyr_0.7.7 purrr_0.2.5 readr_1.1.1 tidyr_0.8.2 tibble_1.4.2 ggplot2_3.1.0
## [13] tidyverse_1.2.1 doRNG_1.7.1 rngtools_1.3.1 pkgmaker_0.27 registry_0.5 doParallel_1.0.14
## [19] iterators_1.0.10 foreach_1.4.4 knitr_1.20
##
## loaded via a namespace (and not attached):
## [1] Rcpp_0.12.19 lubridate_1.7.4 lattice_0.20-35 utf8_1.1.4 assertthat_0.2.0
## [6] rprojroot_1.3-2 digest_0.6.18 R6_2.3.0 cellranger_1.1.0 plyr_1.8.4
## [11] MatrixModels_0.4-1 backports_1.1.2 evaluate_0.12 httr_1.3.1 pillar_1.3.0
## [16] rlang_0.3.0.1 lazyeval_0.2.1 readxl_1.1.0 rstudioapi_0.8 SparseM_1.77
## [21] Matrix_1.2-14 rmarkdown_1.10 labeling_0.3 munsell_0.5.0 broom_0.5.0
## [26] compiler_3.5.1 modelr_0.1.2 pkgconfig_2.0.2 mcmc_0.9-5 htmltools_0.3.6
## [31] tidyselect_0.2.5 codetools_0.2-15 fansi_0.4.0 crayon_1.3.4 withr_2.1.2
## [36] grid_3.5.1 nlme_3.1-137 jsonlite_1.5 xtable_1.8-3 gtable_0.2.0
## [41] magrittr_1.5 scales_1.0.0 bibtex_0.4.2 cli_1.0.1 stringi_1.2.4
## [46] xml2_1.2.0 tools_3.5.1 glue_1.3.0 hms_0.4.2 yaml_2.2.0
## [51] colorspace_1.3-2 rvest_0.3.2 bindr_0.1.1 haven_1.1.2 quantreg_5.36
## Record execution time and multicore use
end_time <- Sys.time()
diff_time <- difftime(end_time, start_time, units = "auto")
cat("Started ", as.character(start_time), "\n",
"Finished ", as.character(end_time), "\n",
"Time difference of ", diff_time, " ", attr(diff_time, "units"), "\n",
"Used ", foreach::getDoParWorkers(), " cores\n",
"Used ", foreach::getDoParName(), " as backend\n",
sep = "")
## Started 2018-11-21 22:39:16
## Finished 2018-11-21 22:39:21
## Time difference of 4.524615 secs
## Used 12 cores
## Used doParallelMC as backend