Load packages

library(tidyverse)
## library(DPpackage)
library(MCMCpack)

Dirichlet Process

Here we will use the T cell receptor example in BNPDA (p10) to demonstrate the Dirichlet process prior.

Data

The data are based on Guindani 2014. Note that a count of 0 cannot occur in the data by design, and counts of 5 or greater were not observed.

tcell <- tribble(
    ~count, ~frequency,
    1, 37,
    2, 11,
    3, 5,
    4, 2,
    ## >= 5, 0,
    )
tcell
## # A tibble: 4 x 2
##   count frequency
##   <dbl>     <dbl>
## 1     1        37
## 2     2        11
## 3     3         5
## 4     4         2

Model

The model is the following.

\[ y_{i} \mid G \overset{iid}{\sim} G, \qquad G \sim DP(M, G_{0}) \]

The notation \(y_{i} \mid G \overset{iid}{\sim} G\) puzzled me for a while, but the data \(\mid\) parameter notation on the left-hand side is necessary because in Bayesian statistics we consider the joint distribution of data and parameters. An iid statement, as in the frequentist paradigm, is only possible after fixing (conditioning on) the parameter. Importantly, once we marginalize this conditioning over the prior distribution, exchangeability remains, but not iid in general.
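
As a sanity check on this point, here is a minimal simulation sketch (the three-category \(G_{0}\) and \(M = 1\) are arbitrary illustrative choices, and a finite Dirichlet stands in for the DP): two observations drawn from the same random \(G\) are positively correlated once \(G\) is marginalized out, so they are exchangeable but not independent.

set.seed(1)
## Draw a random G, then two observations from that same G.
sims <- replicate(10000, {
    G <- as.vector(MCMCpack::rdirichlet(1, alpha = 1 * c(0.5, 0.3, 0.2)))
    sample(x = 1:3, size = 2, replace = TRUE, prob = G)
})
## Marginal correlation between the two observations; iid would give ~0.
cor(sims[1, ], sims[2, ])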

In the Dirichlet process formulation, each random probability measure \(G\) retains the same support as the centering measure \(G_{0}\). Here we choose as \(G_{0}\) a Poisson with mean 2, truncated below at 1 (zeros are removed). Because the centering measure is discrete, the partitioning property is relatively clear: if we collapse all values beyond 8 into an 8+ bin, the induced finite-dimensional distribution is the following.

\[ \begin{bmatrix} G(1)\\ G(2)\\ \vdots\\ G(8+)\\ \end{bmatrix} \sim \text{Dirichlet} \begin{bmatrix} M G_{0}(1)\\ M G_{0}(2)\\ \vdots\\ M G_{0}(8+)\\ \end{bmatrix} \]
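
This relies on the aggregation property of the Dirichlet distribution: collapsing bins yields another Dirichlet whose parameters are the corresponding sums, which is what justifies the 8+ bin. A quick simulation sketch with arbitrary parameters:

set.seed(1)
draws <- MCMCpack::rdirichlet(10000, alpha = c(1, 2, 3))
## Collapsing the first two bins: p1 + p2 should be Beta(1 + 2, 3).
c(simulated   = sd(draws[, 1] + draws[, 2]),
  theoretical = sqrt((3 * 3) / (6^2 * (6 + 1))))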

Prior

Now we define a function that returns the appropriately truncated and collapsed finite probability vector for the centering measure.

G0_vector <- function(lambda, min, max) {
    ## values below min are truncated
    ## values above max are collapsed

    ## Probabilities from Poisson(lambda) for min:max
    p_vec <- dpois(x = min:max, lambda = lambda)
    ## Probability for max+1 ...
    p_upper_tail <- 1 - ppois(q = max, lambda = lambda)
    ## Collapse to max
    p_vec[length(p_vec)] <- p_vec[length(p_vec)] + p_upper_tail
    ## Renormalize
    p_vec <- p_vec / sum(p_vec)
    ## Name
    names(p_vec) <- min:max
    ##
    return(p_vec)
}

## Collapse all values beyond 8 to 8+ bin
G0_vector(lambda = 2, min = 1, max = 8)
##           1           2           3           4           5           6           7           8 
## 0.313035285 0.313035285 0.208690190 0.104345095 0.041738038 0.013912679 0.003975051 0.001268375
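
As a quick check of the truncation, the first element is the untruncated Poisson probability at 1 renormalized by \(1 - P(Y = 0)\):

dpois(x = 1, lambda = 2) / (1 - dpois(x = 0, lambda = 2))
## 0.3130353, matching the first element above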

Now we can define a function to create random Dirichlet draws given the probability vector and mass \(M\).

## This function gives random Dirichlet draws via normalization of Gammas.
MCMCpack::rdirichlet
## function (n, alpha) 
## {
##     l <- length(alpha)
##     x <- matrix(rgamma(l * n, alpha), ncol = l, byrow = TRUE)
##     sm <- x %*% rep(1, l)
##     return(x/as.vector(sm))
## }
## <bytecode: 0x7f9fc9bea560>
## <environment: namespace:MCMCpack>
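
The underlying fact is that if \(X_{j} \sim Gamma(a_{j}, 1)\) independently, then \(X / \sum_{j} X_{j} \sim Dirichlet(a_{1}, \dots, a_{k})\). A minimal hand-rolled sketch of the same trick (rdirichlet_manual is just an illustrative name):

rdirichlet_manual <- function(n, alpha) {
    ## Independent Gammas with shapes alpha, one row per draw
    x <- matrix(rgamma(n * length(alpha), shape = alpha),
                ncol = length(alpha), byrow = TRUE)
    ## Normalize each row to sum to 1
    x / rowSums(x)
}
set.seed(1)
## Column means should be close to alpha / sum(alpha) = (0.2, 0.3, 0.5)
colMeans(rdirichlet_manual(10000, alpha = c(2, 3, 5)))
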
## Note: G0_vector must be a named probability vector; its names become the y values
draw_dirichlet <- function(n_draws, M, G0_vector) {
    MCMCpack::rdirichlet(n_draws, alpha = M * G0_vector) %>%
    t %>%
    as_data_frame %>%
    mutate(y = names(G0_vector)) %>%
    gather(key = .iter, value = p, -y) %>%
    mutate(.iter = as.integer(gsub("V", "", .iter)),
           M = M)
}

draw_dirichlet(n_draws = 10, M = 1, G0_vector = G0_vector(lambda = 2, min = 1, max = 8))
## # A tibble: 80 x 4
##    y     .iter        p     M
##    <chr> <int>    <dbl> <dbl>
##  1 1         1 2.82e- 1     1
##  2 2         1 4.26e- 1     1
##  3 3         1 2.03e- 1     1
##  4 4         1 8.96e- 2     1
##  5 5         1 5.77e-20     1
##  6 6         1 9.89e-44     1
##  7 7         1 9.17e-25     1
##  8 8         1 2.20e-82     1
##  9 1         2 8.04e- 1     1
## 10 2         2 5.90e- 4     1
## # ... with 70 more rows

Let us visualize some random draws from the prior at different \(M\).

G0_vector_values <- G0_vector(lambda = 2, min = 1, max = 8)

prior_mean <- data_frame(y = names(G0_vector_values),
                         p = G0_vector_values,
                         .iter = 0)

n_draws <- 50
prior_draws <- bind_rows(
    draw_dirichlet(n_draws = n_draws, M = 0.1, G0_vector = G0_vector_values),
    draw_dirichlet(n_draws = n_draws, M = 1,   G0_vector = G0_vector_values),
    draw_dirichlet(n_draws = n_draws, M = 10,  G0_vector = G0_vector_values),
    draw_dirichlet(n_draws = n_draws, M = 100, G0_vector = G0_vector_values),
    )

prior_draws %>%
    ggplot(mapping = aes(x = y, y = p, group = .iter)) +
    geom_line(data = prior_mean, size = 1) +
    geom_line(size = 0.1) +
    facet_wrap(~ M) +
    theme_bw() +
    theme(axis.text.x = element_text(angle = 90, vjust = 0.5),
          legend.key = element_blank(),
          plot.title = element_text(hjust = 0.5),
          strip.background = element_blank())

The thicker line is \(G_{0}\). We can see that as the mass parameter \(M\) increases, the prior draws of \(G\) (thin lines) pack more tightly around \(G_{0}\).
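
This tightening can be quantified. Marginally, \(G(y) \sim Beta(M G_{0}(y), M(1 - G_{0}(y)))\), so the prior standard deviation of \(G(y)\) is \(\sqrt{G_{0}(y)(1 - G_{0}(y))/(M + 1)}\), shrinking like \(1/\sqrt{M + 1}\). For the first bin:

## Marginal prior sd of G(1) at the M values plotted above
G0_1 <- unname(G0_vector_values["1"])
sapply(c(0.1, 1, 10, 100), function(M) sqrt(G0_1 * (1 - G0_1) / (M + 1)))
## roughly 0.442, 0.328, 0.140, 0.046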

Posterior inference

Now we can update our model with the observed data.

df1 <- data_frame(y = c(1:4),
                  count = c(37, 11, 5, 2)) %>%
    mutate(total_n = sum(count))
df1
## # A tibble: 4 x 3
##       y count total_n
##   <int> <dbl>   <dbl>
## 1     1    37      55
## 2     2    11      55
## 3     3     5      55
## 4     4     2      55

The total sample size is \(n = 55\). The value 1 was observed most frequently, and values of 5 or greater were not observed. The empirical distribution is the following.

df1_extended <-  df1 %>%
    bind_rows(data_frame(y = 5:8,
                         count = 0)) %>%
    mutate(total_n = sum(count),
           p = count/total_n,
           .iter = 0)

df1_extended %>%
    ggplot(mapping = aes(x = y, y = p)) +
    geom_line() +
    theme_bw() +
    scale_y_continuous(limits = c(0,1)) +
    scale_x_discrete(limits = c(1:8)) +
    theme(axis.text.x = element_text(angle = 90, vjust = 0.5),
          legend.key = element_blank(),
          plot.title = element_text(hjust = 0.5),
          strip.background = element_blank())

A Dirichlet process prior is conjugate to iid sampling, so we obtain a Dirichlet process posterior.

\[ G | \mathbf{y} \sim DP \left( M+n, \frac{M}{M+n}G_{0} + \frac{n}{M+n}\widehat{G} \right) \]

The updated centering measure is a weighted mean of the prior centering measure \(G_{0}\) and the empirical measure (empirical probability mass function) \(\widehat{G}\). The mass (concentration) parameter becomes \(M+n\). \(M\) can roughly be interpreted as the “sample size” of the prior information.
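
With \(n = 55\), the prior weight \(M / (M + n)\) is negligible for small \(M\) and substantial only at \(M = 100\):

## Prior weight M / (M + n) at n = 55 for the M values used below
M_values <- c(0.1, 1, 10, 100)
round(M_values / (M_values + 55), 3)
## 0.002 0.018 0.154 0.645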

df1_posterior <- data_frame(y = as.integer(names(G0_vector_values)),
                            G0 = G0_vector_values,
                            n = 55,
                            G_hat = c(37, 11, 5, 2, 0, 0, 0, 0) / 55,
                            ## Weighted average at different M
                            `M = 0.1` = (0.1 * G0 + n * G_hat) / (0.1 + n),
                            `M = 1` = (1 * G0 + n * G_hat) / (1 + n),
                            `M = 10` = (10 * G0 + n * G_hat) / (10 + n),
                            `M = 100` = (100 * G0 + n * G_hat) / (100 + n),
                            )
G0_vector_new_M0.1 <- setNames(df1_posterior$`M = 0.1`, 1:8)
G0_vector_new_M1   <- setNames(df1_posterior$`M = 1`,   1:8)
G0_vector_new_M10  <- setNames(df1_posterior$`M = 10`,  1:8)
G0_vector_new_M100 <- setNames(df1_posterior$`M = 100`, 1:8)
df1_posterior
## # A tibble: 8 x 8
##       y      G0     n  G_hat  `M = 0.1`   `M = 1` `M = 10` `M = 100`
##   <int>   <dbl> <dbl>  <dbl>      <dbl>     <dbl>    <dbl>     <dbl>
## 1     1 0.313      55 0.673  0.672      0.666     0.617     0.441   
## 2     2 0.313      55 0.2    0.200      0.202     0.217     0.273   
## 3     3 0.209      55 0.0909 0.0911     0.0930    0.109     0.167   
## 4     4 0.104      55 0.0364 0.0365     0.0376    0.0468    0.0802  
## 5     5 0.0417     55 0      0.0000757  0.000745  0.00642   0.0269  
## 6     6 0.0139     55 0      0.0000252  0.000248  0.00214   0.00898 
## 7     7 0.00398    55 0      0.00000721 0.0000710 0.000612  0.00256 
## 8     8 0.00127    55 0      0.00000230 0.0000226 0.000195  0.000818
n <- 55
n_draws <- 50
posterior_draws <- bind_rows(
    draw_dirichlet(n_draws = n_draws, M = n + 0.1, G0_vector = G0_vector_new_M0.1) %>% mutate(M = 0.1),
    draw_dirichlet(n_draws = n_draws, M = n + 1,   G0_vector = G0_vector_new_M1) %>% mutate(M = 1),
    draw_dirichlet(n_draws = n_draws, M = n + 10,  G0_vector = G0_vector_new_M10) %>% mutate(M = 10),
    draw_dirichlet(n_draws = n_draws, M = n + 100, G0_vector = G0_vector_new_M100) %>% mutate(M = 100),
    )

posterior_draws %>%
    ggplot(mapping = aes(x = y, y = p, group = .iter)) +
    ## Draws
    geom_line(size = 0.1) +
    ## Prior centering measure
    geom_line(data = prior_mean, size = 1) +
    ## Empirical distribution
    geom_line(data = df1_extended, size = 1, color = "red") +
    scale_y_continuous(limits = c(0,1)) +
    facet_wrap(~ M) +
    theme_bw() +
    theme(axis.text.x = element_text(angle = 90, vjust = 0.5),
          legend.key = element_blank(),
          plot.title = element_text(hjust = 0.5),
          strip.background = element_blank())

Here the thick red line is the empirical measure \(\widehat{G}\), the thick black line is the prior centering measure \(G_{0}\), and the thin lines are posterior draws. As \(M\) increases, the posterior draws move toward \(G_{0}\) (more influence of the prior).
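
One practical payoff is that the posterior mean of \(G\) (the updated centering measure) assigns positive probability to the counts 5 through 8+ that the empirical measure sets to exactly zero. Summing the relevant rows of the df1_posterior table above:

## Posterior mean probability of a count of 5 or greater at each M
df1_posterior %>%
    filter(y >= 5) %>%
    summarize_at(vars(starts_with("M = ")), sum)
## roughly 0.0001, 0.0011, 0.0094, and 0.0393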


print(sessionInfo())
## R version 3.5.1 (2018-07-02)
## Platform: x86_64-apple-darwin15.6.0 (64-bit)
## Running under: macOS  10.14.1
## 
## Matrix products: default
## BLAS: /Library/Frameworks/R.framework/Versions/3.5/Resources/lib/libRblas.0.dylib
## LAPACK: /Library/Frameworks/R.framework/Versions/3.5/Resources/lib/libRlapack.dylib
## 
## locale:
## [1] en_US.UTF-8/en_US.UTF-8/en_US.UTF-8/C/en_US.UTF-8/en_US.UTF-8
## 
## attached base packages:
## [1] parallel  stats     graphics  grDevices utils     datasets  methods   base     
## 
## other attached packages:
##  [1] bindrcpp_0.2.2    MCMCpack_1.4-4    MASS_7.3-51       coda_0.19-2       forcats_0.3.0     stringr_1.3.1    
##  [7] dplyr_0.7.7       purrr_0.2.5       readr_1.1.1       tidyr_0.8.2       tibble_1.4.2      ggplot2_3.1.0    
## [13] tidyverse_1.2.1   doRNG_1.7.1       rngtools_1.3.1    pkgmaker_0.27     registry_0.5      doParallel_1.0.14
## [19] iterators_1.0.10  foreach_1.4.4     knitr_1.20       
## 
## loaded via a namespace (and not attached):
##  [1] Rcpp_0.12.19       lubridate_1.7.4    lattice_0.20-35    utf8_1.1.4         assertthat_0.2.0  
##  [6] rprojroot_1.3-2    digest_0.6.18      R6_2.3.0           cellranger_1.1.0   plyr_1.8.4        
## [11] MatrixModels_0.4-1 backports_1.1.2    evaluate_0.12      httr_1.3.1         pillar_1.3.0      
## [16] rlang_0.3.0.1      lazyeval_0.2.1     readxl_1.1.0       rstudioapi_0.8     SparseM_1.77      
## [21] Matrix_1.2-14      rmarkdown_1.10     labeling_0.3       munsell_0.5.0      broom_0.5.0       
## [26] compiler_3.5.1     modelr_0.1.2       pkgconfig_2.0.2    mcmc_0.9-5         htmltools_0.3.6   
## [31] tidyselect_0.2.5   codetools_0.2-15   fansi_0.4.0        crayon_1.3.4       withr_2.1.2       
## [36] grid_3.5.1         nlme_3.1-137       jsonlite_1.5       xtable_1.8-3       gtable_0.2.0      
## [41] magrittr_1.5       scales_1.0.0       bibtex_0.4.2       cli_1.0.1          stringi_1.2.4     
## [46] xml2_1.2.0         tools_3.5.1        glue_1.3.0         hms_0.4.2          yaml_2.2.0        
## [51] colorspace_1.3-2   rvest_0.3.2        bindr_0.1.1        haven_1.1.2        quantreg_5.36

## Record execution time and multicore use
end_time <- Sys.time()
diff_time <- difftime(end_time, start_time, units = "auto")
cat("Started  ", as.character(start_time), "\n",
    "Finished ", as.character(end_time), "\n",
    "Time difference of ", diff_time, " ", attr(diff_time, "units"), "\n",
    "Used ", foreach::getDoParWorkers(), " cores\n",
    "Used ", foreach::getDoParName(), " as backend\n",
    sep = "")
## Started  2018-11-21 22:39:16
## Finished 2018-11-21 22:39:21
## Time difference of 4.524615 secs
## Used 12 cores
## Used doParallelMC as backend