In this example, we illustrate how the Dirichlet distribution serves as the conjugate prior for the Multinomial likelihood.
We define prior parameters for a three-category multinomial process, simulate data, and update the posterior distribution accordingly.
The figure visualizes conjugacy in Bayesian statistics, with the Dirichlet distribution as the prior and the Multinomial distribution supplying the likelihood.
The core inference is that data drives the posterior distribution to become more concentrated (less spread out) and shifts its center toward the observed sample proportions.
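To make "more concentrated" concrete, the standard Dirichlet moments can be written out. This is a sketch in notation assumed here rather than taken from the figure: \(\alpha_0 = \sum_k \alpha_k\) is the sum of the Dirichlet parameters, \(x_k\) is the observed count in category \(k\), and \(n = \sum_k x_k\).
\[ \mathbb{E}[p_k] = \frac{\alpha_k}{\alpha_0}, \qquad \operatorname{Var}[p_k] = \frac{\mathbb{E}[p_k]\,\bigl(1 - \mathbb{E}[p_k]\bigr)}{\alpha_0 + 1} \le \frac{1}{4(\alpha_0 + 1)} \]
For a Dirichlet(2, 2, 2) prior, \(\alpha_0 = 6\), every mean is \(1/3\), and every variance is \((1/3)(2/3)/7 = 2/63 \approx 0.032\). Under the conjugate update the counts are added to the parameters, so the total becomes \(\alpha_0 + n\):
\[ \mathbb{E}[p_k \mid \mathbf{x}] = \frac{\alpha_k + x_k}{\alpha_0 + n}, \qquad \operatorname{Var}[p_k \mid \mathbf{x}] \le \frac{1}{4(\alpha_0 + n + 1)} \]
With \(n = 100\) observations the variance bound drops to \(1/428 \approx 0.0023\), and the posterior mean moves toward the sample proportion \(x_k / n\) as \(n\) grows.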
The plot has two rows of three panels for a \(K=3\) scenario, where the three components are the probabilities \(p_1\), \(p_2\), and \(p_3\) (such that \(p_1 + p_2 + p_3 = 1\)). Each panel is a pairwise scatter plot with density contours (e.g., \(p_1\) vs \(p_2\)) rather than a true ternary plot; because the probabilities sum to one, the points in every panel lie in the triangle where the two plotted probabilities sum to at most 1.
The top row shows the Prior Distribution, denoted as: \[ \text{Prior} \sim \text{Dirichlet}(2, 2, 2) \]
The bottom row shows the Posterior Distribution, which is the result of observing the counts \(\mathbf{x} = (x_1, x_2, x_3)\) and updating the prior. The posterior has the same three components as the prior, and its parameters are much larger because the observed counts are added to the prior parameters: \[ \text{Posterior} \sim \text{Dirichlet}(2 + x_1,\; 2 + x_2,\; 2 + x_3) = \text{Dirichlet}(\boldsymbol{\alpha} + \mathbf{x}) \]
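Why the posterior stays in the Dirichlet family can be seen by multiplying the two kernels. The following is a short derivation sketch that drops normalizing constants and writes \(x_k\) for the observed count in category \(k\) and \(\alpha_k\) for the corresponding prior parameter:
\[ p(\mathbf{p} \mid \mathbf{x}) \;\propto\; \underbrace{\prod_{k=1}^{3} p_k^{x_k}}_{\text{Multinomial likelihood}} \times \underbrace{\prod_{k=1}^{3} p_k^{\alpha_k - 1}}_{\text{Dirichlet prior}} \;=\; \prod_{k=1}^{3} p_k^{(\alpha_k + x_k) - 1} \]
The right-hand side is the kernel of a Dirichlet density with parameters \(\alpha_k + x_k\), so no integration is needed to identify the posterior.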
The plot is a visual illustration of the Dirichlet-Multinomial conjugacy:
Analytical Tractability: Because the posterior (Dirichlet) is the same family as the prior (Dirichlet), the update is simple: you just add the observed counts (\(\mathbf{x}\), derived from the Multinomial likelihood) to the prior pseudocounts (\(\boldsymbol{\alpha}\)). \[\text{Posterior Parameter } \boldsymbol{\alpha}' = \boldsymbol{\alpha} + \mathbf{x}\]
Information Accumulation: The \(\alpha\) parameters in the Dirichlet distribution represent prior knowledge or pseudocounts. Every time new data (Multinomial counts) is observed, these counts are simply added to the \(\alpha\)s.
Prior Influence vs. Data Influence:
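The three points above can be checked numerically with a few lines of base R. This is a minimal sketch, not part of the main example: the two batches of counts are hypothetical values chosen for illustration, and the `_demo` names are introduced here so they do not clash with the variables defined in the script below.

```r
# Hypothetical inputs for illustration only (not the simulated data)
alpha_prior_demo <- c(2, 2, 2)     # Dirichlet(2, 2, 2) prior pseudocounts
batch_1 <- c(12, 6, 2)             # hypothetical first batch  (n = 20)
batch_2 <- c(48, 24, 8)            # hypothetical second batch (n = 80)
counts_demo <- batch_1 + batch_2   # pooled counts: 60 30 10

# Analytical tractability: posterior parameters = prior parameters + counts
alpha_post_demo <- alpha_prior_demo + counts_demo

# Information accumulation: updating batch by batch gives the same posterior
alpha_seq_demo <- (alpha_prior_demo + batch_1) + batch_2
stopifnot(identical(alpha_seq_demo, alpha_post_demo))

# Prior vs. data influence: the posterior mean is a weighted average of the
# prior mean and the sample proportions; the prior's weight is
# alpha_0 / (alpha_0 + n), which shrinks as more data arrive
alpha_0 <- sum(alpha_prior_demo)
n <- sum(counts_demo)
w_prior <- alpha_0 / (alpha_0 + n)
prior_mean_demo <- alpha_prior_demo / alpha_0
sample_prop_demo <- counts_demo / n
post_mean_demo <- alpha_post_demo / sum(alpha_post_demo)
blend_demo <- w_prior * prior_mean_demo + (1 - w_prior) * sample_prop_demo
stopifnot(isTRUE(all.equal(post_mean_demo, blend_demo)))

print(alpha_post_demo)           # 62 32 12
print(round(post_mean_demo, 3))  # 0.585 0.302 0.113
print(round(w_prior, 3))         # 0.057
```

With 100 hypothetical observations the prior carries only about 6% of the weight, which is why the posterior mean sits close to the sample proportions (0.6, 0.3, 0.1) rather than the prior mean of 1/3 per category.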
# Load required packages
library(ggplot2)
library(gtools)
library(patchwork)
# -----------------------------
# Step 1: Define Prior
# -----------------------------
alpha_prior <- c(2, 2, 2) # Prior parameters (Dirichlet(2,2,2))
categories <- length(alpha_prior)
# -----------------------------
# Step 2: Simulate Multinomial Data
# -----------------------------
set.seed(123)
true_p <- c(0.6, 0.3, 0.1) # True category probabilities
n_obs <- 100
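# rmultinom(n, size, prob): n is the number of draws, size the trials per draw.
# One draw of n_obs trials returns a 3 x 1 matrix holding one count per category.
counts <- rmultinom(n = 1, size = n_obs, prob = true_p)
counts <- as.vector(counts) # length-3 vector of category counts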
# -----------------------------
# Step 3: Posterior Update (Conjugate Rule)
# -----------------------------
alpha_post <- alpha_prior + counts
# -----------------------------
# Step 4: Sample from Prior and Posterior
# -----------------------------
n_samples <- 5000
prior_samples <- gtools::rdirichlet(n_samples, alpha_prior)
post_samples <- gtools::rdirichlet(n_samples, alpha_post)
prior_df <- data.frame(prior_samples)
post_df <- data.frame(post_samples)
colnames(prior_df) <- colnames(post_df) <- c("p1", "p2", "p3")
# -----------------------------
# Step 5: Visualization
# -----------------------------
# Helper function to create pairwise scatter + density plots
make_pair_plot <- function(df, title, fill_color) {
  # Build one scatter + 2D-density panel for a pair of columns named x and y
  pair_panel <- function(x, y) {
    ggplot(df, aes(x = .data[[x]], y = .data[[y]])) +
      geom_point(alpha = 0.3, color = fill_color) +
      geom_density_2d(color = "black", linewidth = 0.3) +
      theme_minimal() +
      labs(x = x, y = y, title = paste0(title, " (", x, " vs ", y, ")")) +
      coord_equal()
  }
  # Arrange the three pairwise panels side by side (patchwork `|`)
  pair_panel("p1", "p2") | pair_panel("p1", "p3") | pair_panel("p2", "p3")
}
# Create prior and posterior comparison plots
prior_plot <- make_pair_plot(prior_df, "Prior (Dirichlet(2,2,2))", "skyblue")
# Build the posterior title from the updated parameters, closing both parentheses
post_title <- paste0("Posterior (Dirichlet(", paste(alpha_post, collapse = ","), "))")
post_plot <- make_pair_plot(post_df, post_title, "salmon")
# Combine for visualization
prior_plot / post_plot
# -----------------------------
# Step 6: Display Summary Results
# -----------------------------
posterior_mean <- alpha_post / sum(alpha_post)
# Each line below should print exactly three values, one per category:
# the counts sum to n_obs and the posterior mean sums to 1.
cat("Observed counts:", counts, "\n")
cat("Posterior Parameters:", alpha_post, "\n")
cat("Posterior Mean (Expected p):", round(posterior_mean, 3), "\n")
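The summary step reports only the posterior mean. One possible extension, sketched here with base R only, adds marginal 95% credible intervals by using the fact that each component of a Dirichlet is marginally Beta distributed, \(p_k \sim \text{Beta}(\alpha_k, \alpha_0 - \alpha_k)\). The parameter vector below is hypothetical (a Dirichlet(2, 2, 2) prior plus illustrative counts 60, 30, 10), not the output of the simulation, and the `ci_` names are introduced only for this sketch.

```r
# Hypothetical posterior parameters for illustration only
ci_alpha <- c(62, 32, 12)
ci_alpha_0 <- sum(ci_alpha)

# Posterior mean and standard deviation of each component
ci_mean <- ci_alpha / ci_alpha_0
ci_sd <- sqrt(ci_mean * (1 - ci_mean) / (ci_alpha_0 + 1))

# Marginal distribution of p_k is Beta(alpha_k, alpha_0 - alpha_k),
# so equal-tailed 95% intervals come directly from qbeta()
ci_lower <- qbeta(0.025, ci_alpha, ci_alpha_0 - ci_alpha)
ci_upper <- qbeta(0.975, ci_alpha, ci_alpha_0 - ci_alpha)

ci_summary <- data.frame(
  category = c("p1", "p2", "p3"),
  mean = round(ci_mean, 3),
  sd = round(ci_sd, 3),
  lower = round(ci_lower, 3),
  upper = round(ci_upper, 3)
)
print(ci_summary)
```

In the main script, `alpha_post` could be substituted for the hypothetical `ci_alpha`. These intervals are marginal, so they describe each probability separately and do not capture the negative dependence between components induced by the sum-to-one constraint.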