1 Variational Inference

Consider a fully Bayesian probabilistic model in which we denote all observed variables by \(\mathbf{X}\) and all latent variables by \(\mathbf{Z}\). Note that the latent random variables \(\mathbf{Z}\) include both parameters that might govern all the data, as found in Bayesian models, and latent variables that are “local” to individual data points. Our probabilistic model specifies the joint density \(p(\mathbf{X}, \mathbf{Z})\), and our goal is to compute the posterior distribution \(p(\mathbf{Z} | \mathbf{X})\). This conditional can be used to produce point or interval estimates of the latent variables and to form predictive densities of new data. Using Bayes' rule we can write it as \[ p(\mathbf{Z} | \mathbf{X}) = \frac{p(\mathbf{X}, \mathbf{Z})}{p(\mathbf{X})} = \frac{p(\mathbf{X} | \mathbf{Z}) p(\mathbf{Z})}{p(\mathbf{X})} \] The denominator is the marginal likelihood (or model evidence), obtained by marginalizing the latent variables out of the joint density \[ p(\mathbf{X}) = \int p(\mathbf{X}, \mathbf{Z}) d\mathbf{Z} \] For many models, this evidence integral is unavailable in closed form or requires exponential time to compute. Since the evidence is what we need to compute the conditional density from the joint, inference in such models is hard.

1.1 The evidence lower bound (ELBO)

In variational inference, we specify a family \(\mathcal{Q}\) of densities over the latent variables. Each \(q(\mathbf{Z}) \in \mathcal{Q}\) is a candidate approximation to the exact posterior density. Our goal is to find the best candidate, i.e. the one that minimizes the Kullback-Leibler (\(\mathcal{KL}\)) divergence to the exact posterior. Hence, inference now amounts to solving the following optimization problem \[ \begin{aligned} q^{*}(\mathbf{Z}) & = \underset{q(\mathbf{Z}) \in \mathcal{Q}} {\mathrm{argmin}}\; \mathcal{KL}(q(\mathbf{Z})\; || \; p(\mathbf{Z} | \mathbf{X})) \end{aligned} \] Note that the complexity of the family \(\mathcal{Q}\) determines the complexity of this optimization. Expanding the \(\mathcal{KL}\) divergence between these two distributions gives \[ \begin{aligned} \mathcal{KL}(q(\mathbf{Z})\; || \; p(\mathbf{Z} | \mathbf{X})) & = - \int q(\mathbf{Z}) \ln\frac{p(\mathbf{Z} | \mathbf{X})}{q(\mathbf{Z})} d\mathbf{Z} \\ & = \int q(\mathbf{Z})\ln q(\mathbf{Z}) d\mathbf{Z} - \int q(\mathbf{Z}) \ln p(\mathbf{Z} | \mathbf{X})d\mathbf{Z} \\ & = \mathbb{E}_{q(\mathbf{Z})}\Big[ \ln q(\mathbf{Z})\Big] - \int q(\mathbf{Z}) \ln \frac{p(\mathbf{X}, \mathbf{Z})}{p(\mathbf{X})} d\mathbf{Z} \\ & = \mathbb{E}_{q(\mathbf{Z})}\Big[\ln q(\mathbf{Z})\Big] + \int q(\mathbf{Z}) \ln p(\mathbf{X})d\mathbf{Z} - \int q(\mathbf{Z}) \ln p(\mathbf{X}, \mathbf{Z})d\mathbf{Z}\\ & = \mathbb{E}_{q(\mathbf{Z})}\Big[\ln q(\mathbf{Z})\Big] + \mathbb{E}_{q(\mathbf{Z})}\Big[\ln p(\mathbf{X})\Big] - \mathbb{E}_{q(\mathbf{Z})}\Big[\ln p(\mathbf{X}, \mathbf{Z})\Big] \\ & = \mathbb{E}_{q(\mathbf{Z})}\Big[\ln q(\mathbf{Z})\Big] - \mathbb{E}_{q(\mathbf{Z})}\Big[\ln p(\mathbf{X}, \mathbf{Z})\Big] + \ln p(\mathbf{X}) \\ \end{aligned} \] This quantity is intractable to compute since it depends on the model evidence \(\ln p(\mathbf{X})\). The evidence, however, does not depend on the choice of the variational distribution \(q(\mathbf{Z})\).
Instead, we optimize an alternative objective function which is equivalent to the negative \(\mathcal{KL}\) up to an added constant (i.e. the log model evidence) \[ \begin{aligned} q_{ELBO}(Z) & = \mathbb{E}_{q(\mathbf{Z})}\Big[\ln p(\mathbf{X}, \mathbf{Z})\Big] - \mathbb{E}_{q(\mathbf{Z})}\Big[\ln q(\mathbf{Z})\Big] \end{aligned} \] This function is called the evidence lower bound (ELBO). Maximizing the ELBO is equivalent to minimizing the \(\mathcal{KL}\) divergence defined above, since the ELBO is the negative \(\mathcal{KL}\) plus the log model evidence \(\ln p(\mathbf{X})\), which is constant with respect to \(q(\mathbf{Z})\).

Examining the ELBO gives intuitions about the optimal variational density. We rewrite the ELBO as the expected log likelihood of the data minus the \(\mathcal{KL}\) divergence between \(q(\mathbf{Z})\) and the prior \(p(\mathbf{Z})\) \[ \begin{aligned} q_{ELBO}(Z) & = \mathbb{E}_{q(\mathbf{Z})}\Big[\ln p(\mathbf{X}, \mathbf{Z})\Big] - \mathbb{E}_{q(\mathbf{Z})}\Big[\ln q(\mathbf{Z})\Big] \\ & = \mathbb{E}_{q(\mathbf{Z})}\Big[\ln p(\mathbf{X} | \mathbf{Z})\Big] + \mathbb{E}_{q(\mathbf{Z})}\Big[\ln p(\mathbf{Z})\Big] - \mathbb{E}_{q(\mathbf{Z})}\Big[\ln q(\mathbf{Z})\Big] \\ & = \mathbb{E}_{q(\mathbf{Z})}\Big[\ln p(\mathbf{X} | \mathbf{Z})\Big] - \mathcal{KL} (q(\mathbf{Z}) \; ||\; p(\mathbf{Z}))\\ \end{aligned} \] The first term is an expected log likelihood; it encourages densities that place their mass on configurations of the latent variables that explain the observed data. The second term is the negative divergence between the variational density and the prior; it encourages densities close to the prior. Thus, the variational objective mirrors the usual balance between likelihood and prior.

Another property of the ELBO is that it lower-bounds the (log) evidence, \(\ln p(\mathbf{X}) \geq q_{ELBO}(Z)\) for any \(q(\mathbf{Z})\). To see this, notice that we can write the log evidence as \[ \ln p(\mathbf{X}) = \mathcal{KL}(q(\mathbf{Z})\; || \; p(\mathbf{Z} | \mathbf{X})) + q_{ELBO}(Z) \] The bound then follows from the fact that \(\mathcal{KL}(q \;||\;p) \geq 0\). The ELBO attains the bound only when the \(\mathcal{KL}\) divergence vanishes, that is \(\mathcal{KL}\big[q(\mathbf{Z})\;||\; p(\mathbf{Z} | \mathbf{X}) \big] = 0\). Since \(\mathcal{KL}(q \;||\;p)\) is zero if and only if \(q = p\), this happens only when the variational density equals the exact posterior. The same result can be derived through Jensen’s inequality (Jordan et al., 1999).
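The decomposition of the log evidence into a KL term plus the ELBO can be checked numerically on a tiny discrete model. The sketch below is only an illustration: the single binary latent variable and the joint probabilities in `joint` are made-up assumptions, not part of the model discussed in this document.

```python
import math

# Hypothetical toy model: one binary latent Z, with X fixed at its observed value.
# joint[z] = p(X = x_obs, Z = z); the numbers are made up for illustration.
joint = [0.12, 0.28]
evidence = sum(joint)                      # p(X) = sum_z p(X, Z = z)
posterior = [p / evidence for p in joint]  # p(Z | X)

def elbo(q):
    """E_q[ln p(X, Z)] - E_q[ln q(Z)] for a discrete q."""
    return sum(qz * (math.log(pz) - math.log(qz))
               for qz, pz in zip(q, joint) if qz > 0)

def kl(q, p):
    """KL(q || p) for two discrete distributions on the same support."""
    return sum(qz * math.log(qz / pz) for qz, pz in zip(q, p) if qz > 0)

q = [0.5, 0.5]  # an arbitrary candidate approximation
# ln p(X) - (KL + ELBO) should vanish up to floating-point error.
gap = math.log(evidence) - (kl(q, posterior) + elbo(q))
```

Setting `q` equal to `posterior` makes the KL term zero, so the ELBO then coincides with the log evidence.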

1.2 The mean-field variational family

We now describe a variational family \(\mathcal{Q}\) to complete the specification of the optimization problem. The complexity of the family determines the complexity of the optimization; it is more difficult to optimize over a complex family than a simple family. Here we focus on the mean-field variational family, where the latent variables are mutually independent and each is governed by a distinct factor in the variational density. A generic member of the mean-field variational family is \[ q(\mathbf{Z}) = \prod_{m=1}^{M} q_{m}(Z_{m}) \] Each latent variable \(Z_{m}\) is governed by its own variational factor, the density \(q_{m}(Z_{m})\). We emphasize that the variational family is not a model of the observed data; indeed, the data \(\mathbf{X}\) does not appear in the above equation. Instead, it is the ELBO, and the corresponding KL minimization problem, that connects the fitted variational density to the data and model. Also, notice we have not specified the parametric form of the individual variational factors. In principle, each can take on any parametric form appropriate to the corresponding random variable.

1.3 Coordinate ascent mean-field variational inference

Using the ELBO and the mean-field family, we have cast approximate conditional inference as an optimization problem. One of the most commonly used algorithms for solving this optimization problem is coordinate ascent variational inference (CAVI) (Bishop, 2006). CAVI iteratively optimizes each factor of the mean-field variational density, while holding the others fixed. It climbs the ELBO to a local optimum.

Consider the \(m^{th}\) latent variable \(Z_{m}\). The complete conditional of \(Z_{m}\) is its conditional density given all of the other latent variables in the model and the observations, i.e. \(p(Z_{m} | \mathbf{Z}_{-m}, \mathbf{X})\). Fix the other variational factors \(q_{j}(Z_{j})\), \(j \neq m\), then the log of the optimal solution for factor \(q_{m}(Z_{m})\) is \[ \ln q_{m}^{*}(Z_{m}) = \mathbb{E}_{j \neq m}\Big[\ln p(\mathbf{X}, \mathbf{Z})\Big] + \text{const} \] That is, the log of the optimal solution for factor \(q_{m}\), is obtained simply by considering the log of the joint distribution over all hidden and observed variables and then taking the expectation w.r.t all of the other factors \(\{q_{j}\}\) for \(j \neq m\).

To obtain the form of \(q^{*}_{m}(Z_{m})\) we exponentiate on both sides and normalize to obtain \[ q_{m}^{*}(Z_{m}) = \frac{\exp\left\{ \mathbb{E}_{j \neq m}\Big[\ln p(\mathbf{X}, \mathbf{Z})\Big] \right\} }{\int\exp\left\{ \mathbb{E}_{j \neq m}\Big[\ln p(\mathbf{X}, \mathbf{Z})\Big] \right\}d Z_{m}} \]
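To make the coordinate update concrete, here is a minimal sketch of CAVI on a hypothetical model with two binary latent variables. The table `joint`, the initial factors and the helper names are assumptions chosen for illustration. Each update exponentiates the expected log joint under the other factor and normalizes, and the ELBO is recorded after every sweep.

```python
import numpy as np

# Hypothetical table: joint[z1, z2] = p(X = x_obs, Z1 = z1, Z2 = z2).
joint = np.array([[0.30, 0.05],
                  [0.05, 0.20]])
log_joint = np.log(joint)

def elbo(q1, q2):
    """ELBO of the mean-field density q(z1, z2) = q1(z1) q2(z2)."""
    q = np.outer(q1, q2)
    return float(np.sum(q * (log_joint - np.log(q))))

def normalize_exp(log_unnorm):
    """Exponentiate and normalize, subtracting the max for stability."""
    w = np.exp(log_unnorm - log_unnorm.max())
    return w / w.sum()

q1 = np.array([0.6, 0.4])   # arbitrary initial factors
q2 = np.array([0.5, 0.5])
history = [elbo(q1, q2)]
for _ in range(20):
    q1 = normalize_exp(log_joint @ q2)     # E_{q2}[ln p(X, z1, Z2)]
    q2 = normalize_exp(log_joint.T @ q1)   # E_{q1}[ln p(X, Z1, z2)]
    history.append(elbo(q1, q2))
```

Because each step maximizes the ELBO over one factor with the other held fixed, the values in `history` should never decrease, and they stay below the log evidence `np.log(joint.sum())`.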

2 Linear regression

Linear regression assumes \[ y_{i} = \mathbf{w}^{T}\mathbf{x}_{i} + \epsilon_{i} \] where the response \(y_{i}\) is a linear function of the covariate \(\mathbf{x}_{i} \in \mathbb{R}^{D}\) and is linear in the parameters \(\mathbf{w}\) as well. Collecting \(I\) response variables \(\mathbf{y} \in \mathbb{R}^{I}\) we have \[ \mathbf{y} = \mathbf{X}\mathbf{w} + \pmb{\epsilon} \] where \(\mathbf{X}\in \mathbb{R}^{I\times D}\) is referred to as the design matrix. Given a training set \(\mathcal{D} = \{\mathbf{X}, \mathbf{y}\}\), the task is to estimate \(\mathbf{w}\) so that the response \(y_{*}\) to a new data point \(\mathbf{x}_{*}\) can be predicted, i.e., \(\mathbb{E}[y_{*} | \mathbf{x}_{*}] = \mathbf{w}^{T}\mathbf{x}_{*}\).
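As a small illustration of this setup, the sketch below simulates data from the linear model and estimates \(\mathbf{w}\) by least squares. The text does not name an estimator, so least squares is an assumption here, as are the sizes, the noise level and the weights `w_true`.

```python
import numpy as np

rng = np.random.default_rng(1)
I, D = 200, 3                                    # hypothetical sizes
w_true = np.array([0.5, -1.0, 2.0])              # made-up weights used to simulate data
X = rng.normal(size=(I, D))                      # design matrix
y = X @ w_true + rng.normal(scale=0.1, size=I)   # y = X w + eps

# Least-squares estimate of w (an assumed estimator, not one named in the text).
w_hat, *_ = np.linalg.lstsq(X, y, rcond=None)

x_new = np.array([1.0, 0.0, -1.0])
y_pred = w_hat @ x_new                           # E[y_* | x_*] = w^T x_*
```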

3 Bayesian Linear Regression

The likelihood function for \(\mathbf{w}\) and the prior over \(\mathbf{w}\) are given by \[ \begin{aligned} p(\mathbf{y} | \mathbf{X}, \mathbf{w}) & = \prod_{i=1}^{I} \mathcal{N}(y_{i} | \mathbf{w}^{T}\mathbf{x}_{i}, \lambda^{-1}) \\ p(\mathbf{w} | \tau) &= \mathcal{N}(\mathbf{w} | \mathbf{m}_{0}, \tau^{-1} \mathbf{I}) \end{aligned} \] where \(\lambda\) is the noise precision parameter, assumed to be known for simplicity, although the framework is easily extended to include a distribution over \(\lambda\). A typical choice for the prior mean parameter would be \(\mathbf{m}_{0} = \mathbf{0}\). We also introduce a prior distribution for the (hyper)-parameter \(\tau\), which is the precision of the Gaussian prior over the weights \[ p(\tau) = \mathcal{G}amma(\tau | \alpha_{0}, \beta_{0}) \] Thus, the joint distribution over all the variables is given by the following factorization. \[ p(\mathbf{y}, \mathbf{w}, \tau | \mathbf{X}) = p(\mathbf{y} | \mathbf{X}, \mathbf{w})\;p(\mathbf{w} | \tau)\;p(\tau) \] The probabilistic graphical model of the Bayesian linear regression model is shown below

4 Bayesian Linear Regression Mixture Model

Imagine that each of our observations follows its own regression model and our goal is to cluster together observations whose regression models have similar patterns. For example, we may have observed time series of gene expression levels for different genes, and we want to group genes whose expression levels follow similar patterns over time.

Assume that our observations are \(\mathbf{Y} \in \mathbb{R}^{N\times I_{n}}\) and each observation \(\mathbf{y}_{n} \in \mathbb{R}^{I_{n}}\) depends on covariates \(\mathbf{X}_{n}\in\mathbb{R}^{I_{n} \times D}\) and a corresponding latent variable \(\mathbf{c}_{n}\) comprising a 1-of-\(K\) binary vector with elements \(c_{nk}\) for \(k = 1, \ldots, K\). The conditional distribution of \(\mathbf{C}\), given the mixing coefficients \(\pmb{\pi}\), is given by \[ p(\mathbf{C} | \pmb{\pi}) = \prod_{n=1}^{N}\prod_{k=1}^{K}\pi_{k}^{c_{nk}} \] The conditional distribution of the observed data \(\mathbf{Y}\), given the latent variables \(\mathbf{C}\) and the component parameters \(\mathbf{w}\) is \[ p(\mathbf{Y} | \mathbf{C}, \mathbf{w}, \mathbf{X}) = \prod_{n=1}^{N}\prod_{k=1}^{K}\mathcal{N}(\mathbf{y}_{n} | \mathbf{X}_{n} \mathbf{w}_{k}, \lambda^{-1}\mathbf{I}_{n})^{c_{nk}} \] where \(\mathbf{w} = \{\mathbf{w}_{k}\}\). Next we introduce priors over the parameters. We choose a Dirichlet distribution over the mixing proportions \(\pmb{\pi}\) \[ p(\pmb{\pi}) = \mathcal{D}ir( \pmb{\pi}| \pmb{\delta}_{0}) = C(\pmb{\delta}_{0})\prod_{k=1}^{K}\pi_{k}^{\delta_{0_{k}} - 1} \] where \(C(\pmb{\delta}_{0})\) is the normalization constant for the Dirichlet distribution and we have chosen the same parameter \(\delta_{0_{k}}\) for each of the mixture components to obtain a symmetric Dirichlet distribution. We introduce an independent Gaussian prior over the coefficients \(\mathbf{w}_{k}\) as we did in the Bayesian linear regression, i.e. \[ \begin{aligned} p(\mathbf{w} | \pmb{\tau}) & = \prod_{k=1}^{K} \mathcal{N}(\mathbf{w}_{k} | \mathbf{0}, \tau_{k}^{-1} \mathbf{I}) \end{aligned} \]

We also introduce a prior distribution for the (hyper)-parameter \(\pmb{\tau}\), which is the precision of the Gaussian prior over the weights and we assume that each cluster has its own precision parameter \(\tau_{k}\), that is \[ p(\pmb{\tau}) = \prod_{k=1}^{K} \mathcal{G}amma(\tau_{k} | \alpha_{0}, \beta_{0}) = \prod_{k=1}^{K} \frac{1}{\Gamma(\alpha_{0})}\beta_{0}^{\alpha_{0}}\tau_{k}^{\alpha_{0}-1}e^{-\beta_{0}\tau_{k}} \]

The joint distribution over all the variables is given by the following factorization. \[ \begin{aligned} p(\mathbf{Y}, \mathbf{C}, \pmb{\pi}, \mathbf{w}, \pmb{\tau} | \mathbf{X}) & = p(\mathbf{Y} | \mathbf{C}, \pmb{\pi}, \mathbf{w}, \pmb{\tau}, \mathbf{X}) p(\mathbf{C} | \pmb{\pi}, \mathbf{w}, \pmb{\tau}) p(\pmb{\pi} | \mathbf{w}, \pmb{\tau}) p(\mathbf{w} | \pmb{\tau}) p(\pmb{\tau}) \\ & = p(\mathbf{Y} | \mathbf{C}, \mathbf{w}, \mathbf{X}) p(\mathbf{C} | \pmb{\pi}) p(\pmb{\pi}) p(\mathbf{w} | \pmb{\tau}) p(\pmb{\tau}) \end{aligned} \] where the decomposition corresponds to the probabilistic graphical model shown below
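The factorization of the joint can be read as a recipe for generating synthetic data. The following sketch samples each variable in the order given by the second line of the factorization. The sizes, hyperparameter values and the simplification that every \(I_{n}\) equals a common `I` are assumptions made only for this illustration.

```python
import numpy as np

rng = np.random.default_rng(0)
N, I, D, K = 50, 10, 2, 3            # hypothetical sizes; every I_n is set to I
lam, alpha0, beta0 = 4.0, 2.0, 1.0   # made-up noise precision and Gamma parameters
delta0 = np.ones(K)                  # symmetric Dirichlet parameter

pi = rng.dirichlet(delta0)                                      # p(pi)
tau = rng.gamma(shape=alpha0, scale=1.0 / beta0, size=K)        # p(tau), rate beta0
w = rng.normal(0.0, 1.0 / np.sqrt(tau)[:, None], size=(K, D))   # p(w | tau)
c = rng.choice(K, size=N, p=pi)                                 # p(C | pi), stored as labels
X = rng.normal(size=(N, I, D))                                  # covariates (assumed Gaussian)
noise = rng.normal(scale=lam ** -0.5, size=(N, I))
Y = np.einsum("nid,nd->ni", X, w[c]) + noise                    # p(Y | C, w, X)
```

Here `c` holds the index of the active component for each observation, which is equivalent to the 1-of-K vector \(\mathbf{c}_{n}\).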

5 Variational Bayesian Linear Regression Mixture Model

To apply the variational inference machinery, we divide our parameters into three blocks: the latent variables \(\mathbf{C}\), the parameters \((\mathbf{w}, \pmb{\pi})\) and the parameter \(\pmb{\tau}\). Hence, the variational posterior distribution is given by the factorised expression \[ q(\mathbf{C}, \pmb{\pi}, \mathbf{w}, \pmb{\tau}) = q(\mathbf{C}) q(\pmb{\pi},\mathbf{w}) q(\pmb{\tau}) \] It should be noted that this is the only assumption that we need to make in order to obtain a tractable practical solution to our Bayesian mixture model. In particular, the functional form of the factors \(q(\mathbf{C})\), \(q(\pmb{\pi},\mathbf{w})\) and \(q(\pmb{\tau})\) will be determined automatically by optimization of the variational distribution.

5.1 Update factor \(q(\mathbf{C})\)

Let us consider the derivation of the update equation for the factor \(q(\mathbf{C})\). The log of the optimized factor is given by \[ \begin{align} \ln q^{*}(\mathbf{C}) & = \mathbb{E}_{q(\pmb{\pi},\mathbf{w},\pmb{\tau})}\Big[\ln p(\mathbf{Y}, \mathbf{C}, \pmb{\pi}, \mathbf{w}, \pmb{\tau} | \mathbf{X}) \Big] + \text{const} \\ & = \mathbb{E}_{q(\mathbf{w})}\Big[\ln p(\mathbf{Y} | \mathbf{C}, \mathbf{w}, \mathbf{X})\Big] + \mathbb{E}_{q(\pmb{\pi})} \Big[\ln p(\mathbf{C} | \pmb{\pi})\Big] + \underbrace{\mathbb{E}_{q(\pmb{\pi})}\Big[\ln p(\pmb{\pi})\Big]}_{\text{const}} + \underbrace{\mathbb{E}_{q(\mathbf{w},\pmb{\tau})}\Big[\ln p(\mathbf{w} | \pmb{\tau})\Big]}_{\text{const}} + \underbrace{\mathbb{E}_{q(\pmb{\tau})}\Big[\ln p(\pmb{\tau})\Big]}_{\text{const}} + \text{const} \\ & = \mathbb{E}_{q(\mathbf{w})}\Big[\ln p(\mathbf{Y} | \mathbf{C}, \mathbf{w}, \mathbf{X})\Big] + \mathbb{E}_{q(\pmb{\pi})} \Big[\ln p(\mathbf{C} | \pmb{\pi})\Big] + \text{const} \\ & = \mathbb{E}_{q(\mathbf{w})}\Big[\ln\prod_{n=1}^{N}\prod_{k=1}^{K}\mathcal{N}(\mathbf{y}_{n} | \mathbf{X}_{n} \mathbf{w}_{k}, \lambda^{-1}\mathbf{I}_{n})^{c_{nk}}\Big] + \mathbb{E}_{q(\pmb{\pi})}\Big[\ln \prod_{n=1}^{N}\prod_{k=1}^{K}\pi_{k}^{c_{nk}} \Big] + \text{const} \\ & = \sum_{n=1}^{N}\sum_{k=1}^{K} c_{nk}\;\mathbb{E}_{q(\mathbf{w}_{k})}\Big[\ln\mathcal{N}(\mathbf{y}_{n} | \mathbf{X}_{n} \mathbf{w}_{k}, \lambda^{-1}\mathbf{I}_{n})\Big] + \sum_{n=1}^{N}\sum_{k=1}^{K}c_{nk}\;\mathbb{E}_{q(\pi_{k})}\Big[\ln\pi_{k} \Big] + \text{const} \\ & = \sum_{n=1}^{N}\sum_{k=1}^{K} c_{nk}\;\left\{\mathbb{E}_{q(\mathbf{w}_{k})}\left[-\frac{\lambda}{2}\Big(\mathbf{y}_{n} - \mathbf{X}_{n}\mathbf{w}_{k}\Big)^{T}\Big(\mathbf{y}_{n} - \mathbf{X}_{n}\mathbf{w}_{k}\Big) \right] + \mathbb{E}_{q(\pi_{k})}\Big[\ln\pi_{k} \Big] \right\} + \text{const} \\ & = \sum_{n=1}^{N}\sum_{k=1}^{K}c_{nk}\;\ln\rho_{nk} + \text{const} \end{align} \] where we have defined \[ \ln\rho_{nk} = 
\mathbb{E}_{q(\mathbf{w}_{k})}\left[-\frac{\lambda}{2}\Big(\mathbf{y}_{n} - \mathbf{X}_{n}\mathbf{w}_{k}\Big)^{T}\Big(\mathbf{y}_{n} - \mathbf{X}_{n}\mathbf{w}_{k}\Big) \right] + \mathbb{E}_{q(\pi_{k})}\Big[\ln\pi_{k} \Big] \] Note that we are only interested in the functional dependence on the variable \(\mathbf{C}\), thus any terms that do not depend on \(\mathbf{C}\) can be absorbed into the additive normalization constant. In particular, the Gaussian log-normalizer \(-\frac{I_{n}}{2}\ln(2\pi\lambda^{-1})\) does not depend on \(k\) and, because \(\sum_{k} c_{nk} = 1\), contributes only a constant. Taking the exponential on both sides we obtain \[ q^{*}(\mathbf{C}) \propto \prod_{n=1}^{N}\prod_{k=1}^{K}\;\rho_{nk}^{c_{nk}} \] Requiring that this distribution be normalized, and noting that for each value of \(n\) the quantities \(c_{nk}\) are binary and sum to \(1\) over all values of \(k\), we obtain \[ q^{*}(\mathbf{C}) = \prod_{n=1}^{N}\prod_{k=1}^{K}\;r_{nk}^{c_{nk}} \quad\quad\quad \text{where} \quad r_{nk} = \frac{\rho_{nk}}{\sum_{j=1}^{K}\rho_{nj}} \] We note that the factor \(q(\mathbf{C})\) takes the same functional form as the prior \(p(\mathbf{C}|\pmb{\pi})\). For the discrete distribution \(q^{*}(\mathbf{C})\) we have the standard result for the expected value of a multinomial variable \[ \mathbb{E}_{q(c_{nk})}\big[c_{nk}\big] = r_{nk} \] from which we see that the quantities \(r_{nk}\) play the role of responsibilities. Note that the optimal solution for \(q^{*}(\mathbf{C})\) depends on moments evaluated w.r.t. the distributions of other variables, and so again the variational update equations are coupled and must be solved iteratively.
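The responsibility update can be sketched in NumPy as follows. This is an illustration under stated assumptions: every observation has the same number of responses so that the data fit in dense arrays, the function name `responsibilities` and the array layout are hypothetical, and the expected quadratic term is evaluated with the Gaussian identity \(\mathbb{E}[\mathbf{w}_{k}\mathbf{w}_{k}^{T}] = \mathbf{m}_{k}\mathbf{m}_{k}^{T} + \mathbf{S}_{k}\) for \(q(\mathbf{w}_{k}) = \mathcal{N}(\mathbf{m}_{k}, \mathbf{S}_{k})\). Normalization is done in log space to avoid underflow.

```python
import numpy as np

def responsibilities(X, Y, m, S, e_ln_pi, lam):
    """X: (N, I, D), Y: (N, I), m: (K, D), S: (K, D, D), e_ln_pi: (K,)."""
    XtX = np.einsum("nid,nie->nde", X, X)          # X_n^T X_n, shape (N, D, D)
    Xty = np.einsum("nid,ni->nd", X, Y)            # X_n^T y_n, shape (N, D)
    second = S + np.einsum("kd,ke->kde", m, m)     # E[w_k w_k^T] = S_k + m_k m_k^T
    yy = np.sum(Y * Y, axis=1)                     # y_n^T y_n, shape (N,)
    quad = np.einsum("nde,ked->nk", XtX, second)   # tr(X_n^T X_n E[w_k w_k^T])
    ln_rho = -0.5 * lam * (yy[:, None] - 2.0 * Xty @ m.T + quad) + e_ln_pi
    ln_rho -= ln_rho.max(axis=1, keepdims=True)    # stabilize before exponentiating
    r = np.exp(ln_rho)
    return r / r.sum(axis=1, keepdims=True)        # r_nk = rho_nk / sum_j rho_nj
```

Subtracting the row-wise maximum of \(\ln\rho_{nk}\) leaves \(r_{nk}\) unchanged, since the shift cancels in the ratio.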

5.2 Update factor \(q(\pmb{\tau})\)

The log of the optimized factor is given by \[ \begin{align} \ln q^{*}(\pmb{\tau}) & = \mathbb{E}_{q(\mathbf{C},\pmb{\pi},\mathbf{w})}\Big[\ln p(\mathbf{Y}, \mathbf{C}, \pmb{\pi}, \mathbf{w}, \pmb{\tau} | \mathbf{X}) \Big] + \text{const} \\ & = \underbrace{\mathbb{E}_{q(\mathbf{C},\mathbf{w})}\Big[\ln p(\mathbf{Y} | \mathbf{C}, \mathbf{w}, \mathbf{X})\Big]}_{\text{const}} + \underbrace{\mathbb{E}_{q(\mathbf{C},\pmb{\pi})} \Big[\ln p(\mathbf{C} | \pmb{\pi})\Big]}_{\text{const}} + \underbrace{\mathbb{E}_{q(\pmb{\pi})}\Big[\ln p(\pmb{\pi})\Big]}_{\text{const}} + \mathbb{E}_{q(\mathbf{w})}\Big[\ln p(\mathbf{w} | \pmb{\tau})\Big] + \ln p(\pmb{\tau}) + \text{const} \\ & = \mathbb{E}_{q(\mathbf{w})}\Big[\ln p(\mathbf{w} | \pmb{\tau})\Big] + \ln p(\pmb{\tau}) + \text{const} \\ & = \sum_{k=1}^{K}\mathbb{E}_{q(\mathbf{w}_{k})}\Big[\ln p(\mathbf{w}_{k} | \tau_{k})\Big] + \sum_{k = 1}^{K}\ln p(\tau_{k}) + \text{const} \end{align} \] Hence we observe that the right-hand side comprises a sum over \(k\) in which each term involves a single \(\tau_{k}\), leading to the further factorization \[ q(\pmb{\tau}) = \prod_{k=1}^{K}q(\tau_{k}) \] We refer to these additional factorizations as induced factorizations because they arise from an interaction between the factorization assumed in the variational posterior distribution and the conditional independence properties of the true joint distribution.

Hence, \[ \begin{aligned} \ln q^{*}(\tau_{k}) & = \mathbb{E}_{q(\mathbf{w}_{k})}\Big[\ln p(\mathbf{w}_{k} | \tau_{k})\Big] + \ln p(\tau_{k}) + \text{const} \\ & = \underbrace{\frac{D}{2}\ln\tau_{k} - \frac{\tau_{k}}{2}\mathbb{E}_{q(\mathbf{w}_{k})}[\mathbf{w}_{k}^{T}\mathbf{w}_{k}]}_{\text{Gaussian PDF}} + \underbrace{(\alpha_{0} - 1)\ln\tau_{k} - \beta_{0}\tau_{k}}_{\text{Gamma PDF}}\\ & = \underbrace{(\alpha_{0} + \frac{D}{2} - 1)}_{\alpha_{k}\;\text{parameter}}\ln\tau_{k} - \Big(\underbrace{\beta_{0} + \frac{1}{2}\mathbb{E}_{q(\mathbf{w}_{k})}[\mathbf{w}_{k}^{T}\mathbf{w}_{k}]}_{\beta_{k}\;\text{parameter}}\Big)\tau_{k} \end{aligned} \] which is the log of the (unnormalized) Gamma distribution, leading to \[ \begin{aligned} q^{*}(\tau_{k}) & = \mathcal{G}\text{amma}(\tau_{k} | \alpha_{k}, \beta_{k}) \\ \alpha_{k} & = \alpha_{0} + \frac{D}{2} \\ \beta_{k} & = \beta_{0} + \frac{1}{2}\mathbb{E}_{q(\mathbf{w}_{k})}[\mathbf{w}_{k}^{T}\mathbf{w}_{k}] \end{aligned} \]
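A minimal sketch of this Gamma update is given below. The function name `update_tau` and the array shapes are assumptions, and the expectation \(\mathbb{E}[\mathbf{w}_{k}^{T}\mathbf{w}_{k}]\) is computed as \(\mathbf{m}_{k}^{T}\mathbf{m}_{k} + \text{tr}(\mathbf{S}_{k})\), which holds when \(q(\mathbf{w}_{k})\) is Gaussian with mean \(\mathbf{m}_{k}\) and covariance \(\mathbf{S}_{k}\).

```python
import numpy as np

def update_tau(m, S, alpha0, beta0):
    """m: (K, D) means and S: (K, D, D) covariances of q(w_k)."""
    K, D = m.shape
    # E[w_k^T w_k] = m_k^T m_k + tr(S_k)
    e_wtw = np.sum(m * m, axis=1) + np.trace(S, axis1=1, axis2=2)
    alpha = np.full(K, alpha0 + 0.5 * D)   # alpha_k = alpha_0 + D / 2
    beta = beta0 + 0.5 * e_wtw             # beta_k = beta_0 + E[w_k^T w_k] / 2
    return alpha, beta
```

The ratio `alpha / beta` then gives the mean of each \(q(\tau_{k})\).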

5.3 Update factor \(q(\pmb{\pi}, \mathbf{w})\)

Now let us consider the factor \(q(\pmb{\pi}, \mathbf{w})\) in the variational posterior distribution. Using again the general expression for the optimized factor we have \[ \begin{align} \ln q^{*}(\pmb{\pi}, \mathbf{w}) & = \mathbb{E}_{q(\mathbf{C},\pmb{\tau})}\Big[\ln p(\mathbf{Y}, \mathbf{C}, \pmb{\pi}, \mathbf{w}, \pmb{\tau} | \mathbf{X}) \Big] + \text{const} \\ & = \mathbb{E}_{q(\mathbf{C})}\Big[\ln p(\mathbf{Y} | \mathbf{C}, \mathbf{w}, \mathbf{X})\Big] + \mathbb{E}_{q(\mathbf{C})} \Big[\ln p(\mathbf{C} | \pmb{\pi})\Big] + \ln p(\pmb{\pi}) + \mathbb{E}_{q(\pmb{\tau})}\Big[\ln p(\mathbf{w} | \pmb{\tau})\Big]+ \underbrace{\mathbb{E}_{q(\pmb{\tau})}\Big[\ln p(\pmb{\tau})\Big]}_{\text{const}} + \text{const} \\ & = \underbrace{\mathbb{E}_{q(C)}\Big[\ln p(\mathbf{Y} | \mathbf{C}, \mathbf{w}, \mathbf{X})\Big] + \mathbb{E}_{q(\pmb{\tau})}\Big[\ln p(\mathbf{w} | \pmb{\tau}) \Big]}_{\ln q(\mathbf{w})} + \underbrace{\mathbb{E}_{q(\mathbf{C})} \Big[\ln p(\mathbf{C} | \pmb{\pi})\Big] + \ln p(\pmb{\pi})}_{\ln q(\pmb{\pi})} + \text{const} \\ \end{align} \]

We observe that the right-hand side of this expression decomposes into a sum of terms involving only \(\pmb{\pi}\) together with terms only involving \(\mathbf{w}\), which implies that the variational posterior \(q(\pmb{\pi}, \mathbf{w})\) factorizes to give \(q(\pmb{\pi})q(\mathbf{w})\). Furthermore, the terms involving \(\mathbf{w}\) themselves comprise a sum over \(k\), leading to the further factorization \[ q(\pmb{\pi}, \mathbf{w}) = q(\pmb{\pi})\prod_{k=1}^{K} q(\mathbf{w}_{k}) \]

5.3.1 Update factor \(q(\pmb{\pi})\)

Since this factor depends only on the latent variables \(\mathbf{C}\) and parameters \(\pmb{\pi}\), the optimal variational factor will be the same for any mixture model, irrespective of the observation model. Hence, we have \[ \begin{align} \ln q^{*}(\pmb{\pi}) & = \text{ln}\;p(\pmb{\pi}) + \mathbb{E}_{q(\mathbf{C})}\Big[\ln p(\mathbf{C} | \pmb{\pi})\Big] + \text{const} \\ & = \underbrace{\ln C(\pmb{\delta}_{0})}_{\text{const}} + \sum_{k=1}^{K}\ln\pi_{k}^{\delta_{0_{k}}-1} + \sum_{k=1}^{K}\sum_{n=1}^{N}\underbrace{\mathbb{E}_{q(c_{nk})}\big[c_{nk}\big]}_{r_{nk}}\ln\pi_{k} + \text{const} \\ & = \sum_{k=1}^{K}\ln\pi_{k}^{\delta_{0_{k}}-1} + \sum_{k=1}^{K}\sum_{n=1}^{N}{r_{nk}}\ln\pi_{k} + \text{const} \\ \end{align} \] Taking the exponential on both sides turns the sums into products, and we observe that \(q^{*}(\pmb{\pi})\) is a Dirichlet distribution \[ \begin{align} q^{*}(\pmb{\pi}) & \propto \prod_{k=1}^{K}\pi_{k}^{\delta_{0_{k}} - 1} \prod_{k=1}^{K}\prod_{n=1}^{N}\pi_{k}^{r_{nk}} \\ & = \prod_{k=1}^{K}\pi_{k}^{\delta_{0_{k}} - 1} \prod_{k=1}^{K}\pi_{k}^{\sum_{n=1}^{N}r_{nk}} \\ & = \prod_{k=1}^{K}\pi_{k}^{(\delta_{0_{k}} + \sum_{n=1}^{N}r_{nk} - 1)} \\ & \propto \mathcal{D}ir(\pmb{\pi} | \pmb{\delta}) \end{align} \] where \(\pmb{\delta}\) has components \(\delta_{k}\) given by \(\delta_{k} = \delta_{0_{k}} + \sum_{n=1}^{N}r_{nk}\).
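In code, the Dirichlet update amounts to adding the column sums of the responsibilities to the prior parameters. The sketch below uses a hypothetical helper `update_pi` and made-up responsibilities for three observations and two components.

```python
import numpy as np

def update_pi(r, delta0):
    """r: (N, K) responsibilities, delta0: (K,) prior Dirichlet parameters."""
    return delta0 + r.sum(axis=0)   # delta_k = delta_0k + sum_n r_nk

# Made-up responsibilities, each row sums to one.
r = np.array([[0.9, 0.1],
              [0.2, 0.8],
              [0.7, 0.3]])
delta = update_pi(r, delta0=np.array([1.0, 1.0]))
mean_pi = delta / delta.sum()       # mean of the Dirichlet, E[pi_k]
```

With these numbers the effective counts are 1.8 and 1.2, so `delta` is `[2.8, 2.2]` and the posterior mean of the mixing proportions is `[0.56, 0.44]`.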

5.3.2 Update factor \(q(\mathbf{w}_{k})\)

Finally, let us consider the derivation of the update equation for the factor \(q(\mathbf{w}_{k})\). The log of the optimized factor is given by \[ \begin{align} \ln q^{*}(\mathbf{w}_{k}) & = \mathbb{E}_{q(\mathbf{C})}\Big[\ln p(\mathbf{Y} | \mathbf{C}, \mathbf{w}_{k}, \mathbf{X})\Big] + \mathbb{E}_{q(\tau_{k})}\Big[\ln p(\mathbf{w}_{k} | \tau_{k})\Big] + \text{const} \\ & = \sum_{n=1}^{N} \mathbb{E}_{q(c_{nk})}\big[c_{nk}\big]\ln\mathcal{N}(\mathbf{y}_{n} | \mathbf{X}_{n} \mathbf{w}_{k}, \lambda^{-1}\mathbf{I}_{n}) + \mathbb{E}_{q(\tau_{k})}\Big[\ln p(\mathbf{w}_{k} | \tau_{k})\Big] + \text{const} \\ & = \sum_{n=1}^{N} r_{nk} \left\{-\frac{\lambda}{2}\Big(\mathbf{y}_{n} - \mathbf{X}_{n}\mathbf{w}_{k}\Big)^{T}\Big(\mathbf{y}_{n} - \mathbf{X}_{n}\mathbf{w}_{k}\Big) \right\} - \frac{1}{2} \mathbb{E}_{q(\tau_{k})}\big[\tau_{k}\big] \mathbf{w}_{k}^{T}\mathbf{w}_{k} + \text{const} \\ & = \sum_{n=1}^{N} r_{nk} \left\{-\frac{\lambda}{2}\Big(\underbrace{\mathbf{y}_{n}^{T}\mathbf{y}_{n}}_{\text{const}} - 2\mathbf{w}_{k}^{T} \mathbf{X}_{n}^{T}\mathbf{y}_{n} + \mathbf{w}_{k}^{T} \mathbf{X}_{n}^{T} \mathbf{X}_{n} \mathbf{w}_{k} \Big)\right\} - \frac{1}{2} \mathbb{E}_{q(\tau_{k})}\big[\tau_{k}\big] \mathbf{w}_{k}^{T}\mathbf{w}_{k} + \text{const} \\ & = \lambda\mathbf{w}_{k}^{T}\sum_{n=1}^{N} r_{nk} \mathbf{X}_{n}^{T}\mathbf{y}_{n} - \frac{\lambda}{2} \mathbf{w}_{k}^{T}\sum_{n=1}^{N} \Big\{r_{nk} \mathbf{X}_{n}^{T} \mathbf{X}_{n} \Big\} \mathbf{w}_{k} - \frac{1}{2} \mathbb{E}_{q(\tau_{k})}\big[\tau_{k}\big] \mathbf{w}_{k}^{T}\mathbf{w}_{k} + \text{const} \\ & = \lambda\mathbf{w}_{k}^{T}\sum_{n=1}^{N} r_{nk} \mathbf{X}_{n}^{T}\mathbf{y}_{n} - \frac{1}{2} \mathbf{w}_{k}^{T}\left\{ \mathbb{E}_{q(\tau_{k})}\big[\tau_{k}\big]\mathbf{I} + \lambda \sum_{n=1}^{N} r_{nk} \mathbf{X}_{n}^{T} \mathbf{X}_{n} \right\}\mathbf{w}_{k} + \text{const} \\ \end{align} \] Because this is a quadratic form, the distribution \(q^{*}(\mathbf{w}_{k})\) is a Gaussian distribution and we can complete the 
square to identify the mean and covariance \[ \begin{aligned} q^{*}(\mathbf{w}_{k}) & = \mathcal{N}(\mathbf{w}_{k} | \mathbf{m}_{k}, \mathbf{S}_{k}) \\ \mathbf{m}_{k} & = \lambda\mathbf{S}_{k}\sum_{n=1}^{N}r_{nk}\mathbf{X}_{n}^{T}\mathbf{y}_{n} \\ \mathbf{S}_{k} & = \left(\mathbb{E}_{q(\tau_{k})}\big[\tau_{k}\big]\mathbf{I} + \lambda\sum_{n=1}^{N}r_{nk} \mathbf{X}_{n}^{T}\mathbf{X}_{n}\right)^{-1} \end{aligned} \] To complete the square we make use of the fact that the exponent in a general Gaussian distribution \(\mathcal{N}(\mathbf{x} | \pmb{\mu}, \pmb{\Sigma})\) can be written \[ -\frac{1}{2}(\mathbf{x}-\pmb{\mu})^{T}\pmb{\Sigma}^{-1}(\mathbf{x}-\pmb{\mu}) = -\frac{1}{2}\mathbf{x}^{T}\pmb{\Sigma}^{-1}\mathbf{x} + \mathbf{x}^{T}\pmb{\Sigma}^{-1}\pmb{\mu} + \text{const} \] where const denotes terms that are independent of \(\mathbf{x}\), and we have made use of the symmetry of \(\pmb{\Sigma}\).
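The Gaussian update for the regression weights can be sketched as below. As before, the function name `update_w`, the dense array layout and the assumption that all observations share the same number of responses are illustrative choices rather than part of the derivation. The argument `e_tau` holds the expected precisions \(\mathbb{E}[\tau_{k}]\).

```python
import numpy as np

def update_w(X, Y, r, e_tau, lam):
    """X: (N, I, D), Y: (N, I), r: (N, K), e_tau: (K,).

    Returns the means m (K, D) and covariances S (K, D, D) of q(w_k).
    """
    D = X.shape[2]
    XtX = np.einsum("nid,nie->nde", X, X)   # X_n^T X_n
    Xty = np.einsum("nid,ni->nd", X, Y)     # X_n^T y_n
    # Precision of q(w_k): E[tau_k] I + lam * sum_n r_nk X_n^T X_n
    precision = e_tau[:, None, None] * np.eye(D) \
        + lam * np.einsum("nk,nde->kde", r, XtX)
    S = np.linalg.inv(precision)
    # m_k = lam * S_k * sum_n r_nk X_n^T y_n
    m = lam * np.einsum("kde,ke->kd", S, r.T @ Xty)
    return m, S
```

For larger \(D\), solving the linear system with `np.linalg.solve` instead of forming the inverse is usually preferable, although the covariance \(\mathbf{S}_{k}\) itself is needed for the other updates.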

5.4 Computing expectations

When deriving the optimized variational factors, the derivations involved expectations with respect to the variational distributions. These expectations are computed as follows

Term \(\mathbb{E}_{q(\tau_{k})}\big[\tau_{k}\big]\): The factor \(q(\tau_{k})\) is a Gamma distribution \(\mathcal{G}amma(\tau_{k} | \alpha_{k}, \beta_{k})\), hence its expected value is \[ \mathbb{E}_{q(\tau_{k})}\big[\tau_{k}\big] = \frac{\alpha_{k}}{\beta_{k}} \]

Term \(\mathbb{E}_{q(\mathbf{w}_{k})}\big[\mathbf{w}_{k}^{T}\mathbf{w}_{k}\big]\): The factor \(q(\mathbf{w}_{k})\) is a Gaussian distribution \(\mathcal{N}(\mathbf{w}_{k} | \mathbf{m}_{k}, \mathbf{S}_{k})\) hence we have \[ \begin{align} \mathbb{E}_{q(\mathbf{w}_{k})}\big[\mathbf{w}_{k}^{T}\mathbf{w}_{k}\big] = \text{tr}\Big(\mathbb{E}_{q(\mathbf{w}_{k})} \big[\mathbf{w}_{k}\mathbf{w}_{k}^{T} \big]\Big) = \text{tr}\Big(\mathbf{m}_{k}\mathbf{m}_{k}^{T} + \mathbf{S}_{k}\Big) = \text{tr}\left(\mathbf{m}_{k}\mathbf{m}_{k}^{T}\right) + \text{tr}(\mathbf{S}_{k}) = \mathbf{m}_{k}^{T}\mathbf{m}_{k} + \text{tr}(\mathbf{S}_{k}) \\ \end{align} \]

Term \(\mathbb{E}_{q(\pi_{k})}\big[\ln\pi_{k}\big]\): The factor \(q(\pmb{\pi})\) is a Dirichlet distribution \(\mathcal{D}ir(\pmb{\pi} | \pmb{\delta})\) and from standard results (see Bishop, 2006, Eq. B.21) we obtain \[ \mathbb{E}_{q(\pi_{k})}\big[\ln\pi_{k}\big] = \psi(\delta_{k}) - \psi(\hat{\delta}) \quad\quad \left(\text{where} \;\; \hat{\delta} = \sum_{k=1}^{K}\delta_{k} \right) \] where \(\psi(\delta) \equiv \frac{d}{d\delta}\ln\Gamma(\delta)\) is the digamma function.
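This expectation is straightforward to evaluate with SciPy's digamma function. The helper name `expected_log_pi` and the example parameters are assumptions for illustration.

```python
import numpy as np
from scipy.special import digamma

def expected_log_pi(delta):
    """E[ln pi_k] = psi(delta_k) - psi(sum_j delta_j) for pi ~ Dir(delta)."""
    delta = np.asarray(delta, dtype=float)
    return digamma(delta) - digamma(delta.sum())

e_ln_pi = expected_log_pi([2.8, 2.2])   # made-up Dirichlet parameters
```

A quick sanity check uses the recurrence \(\psi(x + 1) = \psi(x) + 1/x\): for \(\pmb{\delta} = (1, 1)\) the result is \(\psi(1) - \psi(2) = -1\) for both components.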

Term \(\mathbb{E}_{q(c_{nk})}\big[c_{nk}\big]\): The factor \(q(c_{nk})\) is a Multinomial distribution \(\mathcal{M}ultin(c_{nk} | r_{nk})\), hence its expected value is \[ \mathbb{E}_{q(c_{nk})}\big[c_{nk}\big] = r_{nk} \]

Term \(\mathbb{E}_{q(\mathbf{w}_{k})}\left[-\frac{\lambda}{2}\Big(\mathbf{y}_{n} - \mathbf{X}_{n}\mathbf{w}_{k}\Big)^{T}\Big(\mathbf{y}_{n} - \mathbf{X}_{n}\mathbf{w}_{k}\Big) \right]\): The factor \(q(\mathbf{w}_{k})\) is a Gaussian distribution \(\mathcal{N}(\mathbf{w}_{k} | \mathbf{m}_{k}, \mathbf{S}_{k})\), hence we have \[ \begin{align} \mathbb{E}_{q(\mathbf{w}_{k})}\left[-\frac{\lambda}{2}\Big(\mathbf{y}_{n} - \mathbf{X}_{n}\mathbf{w}_{k}\Big)^{T}\Big(\mathbf{y}_{n} - \mathbf{X}_{n}\mathbf{w}_{k}\Big) \right] & = \mathbb{E}_{q(\mathbf{w}_{k})}\left[ -\frac{\lambda}{2}\Big(\mathbf{y}_{n}^{T}\mathbf{y}_{n}- 2\mathbf{w}_{k}^{T} \mathbf{X}_{n}^{T}\mathbf{y}_{n} + \mathbf{w}_{k}^{T} \mathbf{X}_{n}^{T} \mathbf{X}_{n} \mathbf{w}_{k} \Big)\right] \\ & = -\frac{\lambda}{2}\mathbf{y}_{n}^{T}\mathbf{y}_{n} + \lambda \mathbb{E}_{q(\mathbf{w}_{k})}\left[\mathbf{w}_{k}^{T}\right] \mathbf{X}_{n}^{T}\mathbf{y}_{n} -\frac{\lambda}{2} \mathbb{E}_{q(\mathbf{w}_{k})}\left[ \mathbf{w}_{k}^{T} \mathbf{X}_{n}^{T} \mathbf{X}_{n} \mathbf{w}_{k}\right] \\ & = -\frac{\lambda}{2}\mathbf{y}_{n}^{T}\mathbf{y}_{n} + \lambda \mathbf{m}_{k}^{T} \mathbf{X}_{n}^{T}\mathbf{y}_{n} -\frac{\lambda}{2} \text{tr}\left( \mathbf{X}_{n}^{T} \mathbf{X}_{n} \mathbb{E}_{q(\mathbf{w}_{k})}\left[\mathbf{w}_{k}\mathbf{w}_{k}^{T}\right] \right) \\ & = -\frac{\lambda}{2}\mathbf{y}_{n}^{T}\mathbf{y}_{n} + \lambda \mathbf{m}_{k}^{T} \mathbf{X}_{n}^{T}\mathbf{y}_{n} -\frac{\lambda}{2} \text{tr}\left( \mathbf{X}_{n}^{T} \mathbf{X}_{n} \big(\mathbf{m}_{k}\mathbf{m}_{k}^{T} + S_{k} \big)\right) \\ \end{align} \]

5.5 Variational lower bound

The variational lower bound \(\mathcal{L}(q)\) (i.e. evidence lower bound (ELBO) ) is given by \[ \begin{aligned} \mathcal{L}(q) & = \sum_{\mathbf{C}}\int\int\int q(\mathbf{C}, \pmb{\pi}, \mathbf{w}, \pmb{\tau}) \ln\left\{\frac{p(\mathbf{Y}, \mathbf{C}, \pmb{\pi}, \mathbf{w}, \pmb{\tau} | \mathbf{X})}{q(\mathbf{C}, \pmb{\pi}, \mathbf{w}, \pmb{\tau})}\right\}d \pmb{\pi}\;d\mathbf{w}\;d\pmb{\tau} \\ & = \;\mathbb{E}_{q(\mathbf{C}, \pmb{\pi}, \mathbf{w}, \pmb{\tau})}\Big[\ln p(\mathbf{Y}, \mathbf{C}, \pmb{\pi}, \mathbf{w}, \pmb{\tau} | \mathbf{X})\Big] - \mathbb{E}_{q(\mathbf{C}, \pmb{\pi}, \mathbf{w}, \pmb{\tau})}\Big[\ln q(\mathbf{C}, \pmb{\pi}, \mathbf{w}, \pmb{\tau})\Big] \\ & = \;\mathbb{E}_{q(\mathbf{C}, \mathbf{w})}\Big[\ln p(\mathbf{Y} | \mathbf{C}, \mathbf{w}, \mathbf{X})\Big] + \mathbb{E}_{q(\mathbf{C}, \pmb{\pi})}\Big[\ln p(\mathbf{C} | \pmb{\pi})\Big] + \mathbb{E}_{q(\pmb{\pi})}\Big[\ln p(\pmb{\pi})\Big] + \mathbb{E}_{q(\mathbf{w}, \pmb{\tau})} \Big[\ln p(\mathbf{w} | \pmb{\tau})\Big] + \mathbb{E}_{q(\pmb{\tau}) }\Big[\ln p(\pmb{\tau})\Big] \\ & \quad \qquad- \mathbb{E}_{q(\mathbf{C})}\Big[\ln q(\mathbf{C})\Big] - \mathbb{E}_{q(\pmb{\pi})}\Big[\ln q(\pmb{\pi})\Big] - \mathbb{E}_{q(\mathbf{w})}\Big[\ln q(\mathbf{w})\Big] - \mathbb{E}_{q(\pmb{\tau})} \Big[\ln q(\pmb{\tau})\Big] \end{aligned} \]

Note that the terms involving expectations of \(\ln q(\cdot)\) distributions simply represent the negative information entropies \(\mathbb{H}(\cdot)\) of those distributions. The various terms in the ELBO are evaluated to give \[ \begin{align} \mathbb{E}_{q(\mathbf{C}, \mathbf{w})}\Big[\ln p(\mathbf{Y} | \mathbf{C}, \mathbf{w}, \mathbf{X})\Big] & = \sum_{n=1}^{N}\sum_{k=1}^{K} \mathbb{E}_{q(c_{nk})} \big[c_{nk}\big]\mathbb{E}_{q(\mathbf{w}_{k})}\big[\ln\mathcal{N}(\mathbf{y}_{n}|\mathbf{X}_{n} \mathbf{w}_{k},\lambda^{-1}\mathbf{I}_{n}) \big] \\ & = \sum_{n=1}^{N}\sum_{k=1}^{K} r_{nk}\left\{-\frac{I_{n}}{2}\ln \left(2\pi\lambda^{-1}\right) - \frac{\lambda}{2} \mathbb{E}_{q(\mathbf{w}_{k})} \left[ \left( \mathbf{y}_{n} - \mathbf{X}_{n} \mathbf{w}_{k}\right)^{T}\left(\mathbf{y}_{n} - \mathbf{X}_{n} \mathbf{w}_{k} \right) \right] \right\}\\ & = \sum_{n=1}^{N}\sum_{k=1}^{K} r_{nk}\left\{-\frac{I_{n}}{2}\ln \left(2\pi\lambda^{-1}\right) -\frac{\lambda}{2} \mathbf{y}_{n}^{T}\mathbf{y}_{n} + \lambda \mathbf{m}_{k}^{T} \mathbf{X}_{n}^{T}\mathbf{y}_{n} -\frac{\lambda}{2} \text{tr}\left( \mathbf{X}_{n}^{T} \mathbf{X}_{n} \big(\mathbf{m}_{k}\mathbf{m}_{k}^{T} + S_{k} \big)\right) \right\}\\ \mathbb{E}_{q(\mathbf{C}, \pmb{\pi})}\Big[\ln p(\mathbf{C} | \pmb{\pi})\Big] & = \sum_{n = 1}^{N}\sum_{k=1}^{K} \mathbb{E}_{q(c_{nk})}\big [c_{nk}\big ]\mathbb{E}_{q(\pi_{k})}\big[\ln\pi_{k}\big] \\ & = \sum_{n = 1}^{N}\sum_{k=1}^{K} r_{nk}\left\{\psi(\delta_{k}) - \psi(\hat{\delta})\right\} \\ \mathbb{E}_{q(\pmb{\pi})}\Big[\ln p(\pmb{\pi})\Big] & = \ln C(\pmb{\delta}_{0}) + \sum_{k=1}^{K} (\delta_{0_{k}} - 1)\left\{\psi(\delta_{k}) - \psi(\hat{\delta})\right\} \\ \mathbb{E}_{q(\mathbf{w}, \pmb{\tau})}\Big[\ln p(\mathbf{w} | \pmb{\tau})\Big] & = \sum_{k=1}^{K}\Bigg\{ \mathbb{E}_{q(\mathbf{w}_{k}, \tau_{k})} \left[ -\frac{D}{2}\ln 2\pi - \frac{D}{2}\ln \tau_{k}^{-1} - \frac{1}{2\tau_{k}^{-1}} \mathbf{w}_{k}^{T}\mathbf{w}_{k}\right]\Bigg\} \\ & = \sum_{k=1}^{K}\Bigg\{ -\frac{D}{2}\ln 2\pi + 
\frac{D}{2}\mathbb{E}_{q(\tau_{k})} \big[\ln \tau_{k}\big] - \frac{1}{2}\mathbb{E}_{q(\tau_{k})}\big[\tau_{k}\big] \mathbb{E}_{q(\mathbf{w}_{k})} \big[\mathbf{w}_{k}^{T}\mathbf{w}_{k}\big]\Bigg\} \\ & = \sum_{k=1}^{K}\Bigg\{ -\frac{D}{2}\ln 2\pi + \frac{D}{2}\Big(\psi(\alpha_{k}) - \ln \beta_{k}\Big) - \frac{\alpha_{k}}{2\beta_{k}} \Big(\mathbf{m}_{k}^{T}\mathbf{m}_{k} + \text{tr}(\mathbf{S}_{k})\Big)\Bigg\} \\ \mathbb{E}_{q(\pmb{\tau})}\Big[\ln p(\pmb{\tau})\Big] & = \sum_{k=1}^{K}\Bigg\{ \alpha_{0}\ln \beta_{0} + (\alpha_{0} - 1)\Big(\psi(\alpha_{k}) - \ln \beta_{k}\Big) - \beta_{0}\frac{\alpha_{k}}{\beta_{k}} - \ln\Gamma(\alpha_{0})\Bigg\} \\ \mathbb{E}_{q(\mathbf{C})}\Big[\ln q(\mathbf{C})\Big] & = \sum_{n=1}^{N}\sum_{k=1}^{K}r_{nk}\ln r_{nk} \\ \mathbb{E}_{q(\pmb{\pi})}\Big[\ln q(\pmb{\pi})\Big] & = \ln C(\pmb{\delta}) + \sum_{k=1}^{K}(\delta_{k} - 1)\left\{\psi(\delta_{k}) - \psi(\hat{\delta})\right\} \\ \mathbb{E}_{q(\mathbf{w})}\Big[\ln q(\mathbf{w})\Big] & = \sum_{k=1}^{K}\Big\{-\frac{1}{2}\ln |\mathbf{S}_{k}| - \frac{D}{2}(1 + \ln 2\pi)\Big\} \\ \mathbb{E}_{q(\pmb{\tau})}\Big[\ln q(\pmb{\tau})\Big] & = \sum_{k=1}^{K}\Big\{-\ln\Gamma(\alpha_{k}) + (\alpha_{k} - 1)\psi(\alpha_{k}) + \ln \beta_{k} - \alpha_{k}\Big\} \end{align} \]
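
Two expectations recur throughout the terms above: \(\mathbb{E}[\ln \pi_{k}] = \psi(\delta_{k}) - \psi(\hat{\delta})\) under a Dirichlet and \(\mathbb{E}[\ln \tau_{k}] = \psi(\alpha_{k}) - \ln \beta_{k}\) under a Gamma with shape \(\alpha_{k}\) and rate \(\beta_{k}\). The following is a minimal sketch that checks both identities by Monte Carlo. The parameter values (`delta`, `alpha`, `beta`) and the sample size are arbitrary assumptions chosen for illustration; they are not taken from the model fitted later.

```r
# Monte Carlo sanity check of two expectations used in the ELBO terms.
# All numeric values below are illustrative assumptions.
set.seed(1)
n_samples <- 200000

# Dirichlet(delta): E[ln pi_k] = digamma(delta_k) - digamma(sum(delta))
delta <- c(2, 5, 3)
# Draw Dirichlet samples by normalising independent Gamma(delta_k, 1) draws;
# column k of `g` uses shape delta[k]
g <- matrix(rgamma(n_samples * length(delta),
                   shape = rep(delta, each = n_samples)),
            nrow = n_samples)
pi_samples  <- g / rowSums(g)
e_log_pi_mc <- colMeans(log(pi_samples))
e_log_pi    <- digamma(delta) - digamma(sum(delta))

# Gamma(shape = alpha, rate = beta): E[ln tau] = digamma(alpha) - log(beta)
alpha <- 3
beta  <- 2
e_log_tau_mc <- mean(log(rgamma(n_samples, shape = alpha, rate = beta)))
e_log_tau    <- digamma(alpha) - log(beta)

print(rbind(analytic = e_log_pi, monte_carlo = e_log_pi_mc))
print(c(analytic = e_log_tau, monte_carlo = e_log_tau_mc))
```

The analytic and sampled values should agree to roughly two decimal places, which is a quick way to catch a shape/rate mix-up when implementing the bound.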

5.6 Predictive density

The predictive density of a new observation \(\mathbf{y}_{*}\) which will be associated with a latent variable \(\mathbf{c}_{*}\) and covariates \(\mathbf{X}_{*}\) is given by \[ \begin{align} p(\mathbf{y}_{*} | \mathbf{X}_{*}, \mathbf{Y}, \mathbf{X}) & = \sum_{\mathbf{c}_{*}}\int\int\int p(\mathbf{y}_{*}, \mathbf{c}_{*}, \pmb{\pi}, \mathbf{w}, \pmb{\tau} | \mathbf{X}_{*}, \mathbf{Y}, \mathbf{X}) \;d\pmb{\pi}\;d\mathbf{w}\;d\pmb{\tau} \\ & = \sum_{\mathbf{c}_{*}}\int\int\int p(\mathbf{y}_{*} | \mathbf{c}_{*}, \mathbf{w}, \mathbf{X}_{*}) p(\mathbf{c}_{*} | \pmb{\pi}) p(\mathbf{w}, \pmb{\pi}, \pmb{\tau} | \mathbf{Y}, \mathbf{X}) \;d\pmb{\pi}\;d\mathbf{w}\;d\pmb{\tau} \\ & = \sum_{k=1}^{K}\int\int\int p(\mathbf{y}_{*} | \mathbf{w}_{k}, \mathbf{X}_{*}) \pi_{k}\; p(\mathbf{w}, \pmb{\pi}, \pmb{\tau} | \mathbf{Y}, \mathbf{X}) \;d\pmb{\pi}\;d\mathbf{w}\;d\pmb{\tau} \qquad\qquad(\text{Summation over}\; \mathbf{c}_{*}) \\ & = \sum_{k=1}^{K}\int\int\int p(\mathbf{y}_{*} | \mathbf{w}_{k}, \mathbf{X}_{*}) \pi_{k}\; q(\pmb{\pi}) q(\mathbf{w}_{k}) q(\tau_{k}) \;d\pmb{\pi}\;d\mathbf{w}_{k}\;d\tau_{k} \qquad\quad(\text{Variational approximation}) \\ & = \sum_{k=1}^{K}\frac{\delta_{k}}{\hat{\delta}}\int\int p(\mathbf{y}_{*} | \mathbf{w}_{k}, \mathbf{X}_{*}) q(\mathbf{w}_{k}) q(\tau_{k}) \;d\mathbf{w}_{k}\;d\tau_{k} \qquad\qquad\qquad\quad(\text{Integrate}\; \pmb{\pi})\\ & = \sum_{k=1}^{K}\frac{\delta_{k}}{\hat{\delta}}\int \mathcal{N}(\mathbf{y}_{*} | \mathbf{X}_{*}\mathbf{w}_{k}, \lambda^{-1}\mathbf{I}_{n}) \mathcal{N}(\mathbf{w}_{k} | \mathbf{m}_{k}, \mathbf{S}_{k}) \;d\mathbf{w}_{k} \quad\qquad\qquad\;\;(\text{Integrate}\; \tau_{k}) \\ & = \sum_{k=1}^{K}\frac{\delta_{k}}{\hat{\delta}} \mathcal{N}\left(\mathbf{y}_{*} | \mathbf{X}_{*}\mathbf{m}_{k},\; \lambda^{-1}\mathbf{I}_{n} + \mathbf{X}_{*} \mathbf{S}_{k} \mathbf{X}_{*}^{T}\right) \end{align} \] The marginal predictive variance of each element of \(\mathbf{y}_{*}\) is the corresponding diagonal entry of this covariance, \(\lambda^{-1} + \text{diag}\big(\mathbf{X}_{*} \mathbf{S}_{k} \mathbf{X}_{*}^{T}\big)\), which is the quantity the implementation in Section 6 computes.
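
The last step of the derivation, integrating out \(\mathbf{w}_{k}\), can be checked numerically for a single cluster. The sketch below samples \(\mathbf{w}_{k} \sim \mathcal{N}(\mathbf{m}_{k}, \mathbf{S}_{k})\), generates \(\mathbf{y}_{*}\) from the likelihood, and compares the empirical mean and marginal variances with \(\mathbf{X}_{*}\mathbf{m}_{k}\) and \(\lambda^{-1} + \text{diag}(\mathbf{X}_{*}\mathbf{S}_{k}\mathbf{X}_{*}^{T})\). The values of `m`, `S`, `lambda` and `X_star` are made up for illustration and are not outputs of the fitted model.

```r
# Monte Carlo check of the Gaussian marginalisation for one cluster k.
# m, S, lambda and X_star are made-up illustrative values.
set.seed(2)
D      <- 3
lambda <- 5
m <- c(0.5, -1, 2)
A <- matrix(c(0.4, 0.1, 0.0,
              0.1, 0.3, 0.1,
              0.0, 0.1, 0.5), nrow = D)
S <- crossprod(A)                               # symmetric positive definite
X_star <- cbind(1, c(-0.5, 0.2), c(0.3, 0.9))   # 2 test inputs, D features

# Analytic predictive mean and marginal variances
mu_pred  <- c(X_star %*% m)
var_pred <- 1 / lambda + diag(X_star %*% S %*% t(X_star))

# Sample w ~ N(m, S) via the Cholesky factor, then y* ~ N(X* w, I / lambda)
n_samples <- 100000
w <- m + t(chol(S)) %*% matrix(rnorm(D * n_samples), nrow = D)
y <- X_star %*% w +
  matrix(rnorm(nrow(X_star) * n_samples, sd = sqrt(1 / lambda)),
         nrow = nrow(X_star))

print(rbind(analytic_mean = mu_pred,  mc_mean = rowMeans(y)))
print(rbind(analytic_var  = var_pred, mc_var  = apply(y, 1, var)))
```

Only the marginal variances are compared here because those are what a per-point uncertainty band needs; the off-diagonal entries of \(\mathbf{X}_{*}\mathbf{S}_{k}\mathbf{X}_{*}^{T}\) describe how predictions at different inputs co-vary.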

5.7 Update equations summary

Below we summarize the variational Bayes update equations, which we split into variational E and M steps by analogy with the Expectation Maximization algorithm.

5.7.1 Variational E-step

Compute responsibilities \[ r_{nk} = \frac{\rho_{nk}}{\sum_{j=1}^{K}\rho_{nj}} \] where \[ \rho_{nk} \propto \exp\left\{\psi(\delta_{k}) - \psi(\hat{\delta}) + \lambda \mathbf{m}_{k}^{T} \mathbf{X}_{n}^{T}\mathbf{y}_{n} -\frac{\lambda}{2} \text{tr}\left( \mathbf{X}_{n}^{T} \mathbf{X}_{n} \big(\mathbf{m}_{k}\mathbf{m}_{k}^{T} + \mathbf{S}_{k} \big)\right) \right\} \]
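
In practice \(\ln\rho_{nk}\) can be a large negative number, so exponentiating it directly underflows to zero. A common remedy, sketched below, is to normalise in log space by subtracting the row maximum before exponentiating. The function name `normalise_log_rho` and the example matrix are hypothetical and serve only to illustrate the idea.

```r
# Turn unnormalised log responsibilities (N x K) into responsibilities r_nk,
# normalising each row in log space to avoid numerical underflow.
normalise_log_rho <- function(log_rho) {
  row_max <- apply(log_rho, 1, max)
  # log of the row-wise normalising constant, ln sum_j rho_nj
  log_Z <- row_max + log(rowSums(exp(log_rho - row_max)))
  exp(log_rho - log_Z)
}

# Hypothetical values: the first row would underflow if exponentiated directly
log_rho <- rbind(c(-1000, -1001, -1005),
                 c(-2, -1, -3))
r_nk <- normalise_log_rho(log_rho)

print(exp(log_rho[1, ]))   # naive exponentiation: all zeros
print(round(r_nk, 4))      # stable responsibilities
print(rowSums(r_nk))       # each row sums to one
```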

5.7.2 Variational M-step

Compute variational distribution parameters \[ \begin{align} \delta_{k} & = \delta_{0_{k}} + \sum_{n=1}^{N}r_{nk} \\ \mathbf{m}_{k} & = \lambda\mathbf{S}_{k}\sum_{n=1}^{N}r_{nk}\mathbf{X}_{n}^{T}\mathbf{y}_{n} \\ \mathbf{S}_{k} & = \left(\frac{\alpha_{k}}{\beta_{k}}\mathbf{I} + \lambda\sum_{n=1}^{N}r_{nk} \mathbf{X}_{n}^{T}\mathbf{X}_{n}\right)^{-1} \\ \alpha_{k} & = \alpha_{0} + \frac{D}{2} \\ \beta_{k} & = \beta_{0} + \frac{1}{2}\left(\mathbf{m}_{k}^{T}\mathbf{m}_{k} + \text{tr}(\mathbf{S}_{k})\right)\\ \end{align} \]
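
To make the M-step concrete, here is a minimal sketch of a single pass of these updates on simulated data. It is a simplified illustration, not the full fitting function given in Section 6: the helper name `m_step`, the toy data, and the use of hard (0/1) responsibilities are all assumptions made for the example. Because \(\mathbf{S}_{k}\) depends on \(\beta_{k}\) and \(\beta_{k}\) depends on \(\mathbf{m}_{k}\) and \(\mathbf{S}_{k}\), the sketch uses the current value of \(\beta_{k}\) when updating \(\mathbf{S}_{k}\) and refreshes \(\beta_{k}\) afterwards.

```r
# One pass of the variational M-step for given responsibilities r_nk (N x K).
# X and y are lists with one design matrix / response vector per observation.
m_step <- function(X, y, r_nk, beta_k, lambda = 1, delta_0 = 1,
                   alpha_0 = 0.1, beta_0 = 0.1) {
  N <- length(X); K <- ncol(r_nk); D <- ncol(X[[1]])
  delta_k <- delta_0 + colSums(r_nk)
  alpha_k <- rep(alpha_0 + D / 2, K)
  m_k <- matrix(0, D, K)
  S_k <- array(0, dim = c(D, D, K))
  for (k in seq_len(K)) {
    # Responsibility-weighted sums of X_n'X_n and X_n'y_n
    XX <- Reduce(`+`, lapply(seq_len(N), function(n)
      r_nk[n, k] * crossprod(X[[n]])))
    Xy <- Reduce(`+`, lapply(seq_len(N), function(n)
      r_nk[n, k] * crossprod(X[[n]], y[[n]])))
    S_k[, , k] <- solve(diag(alpha_k[k] / beta_k[k], D) + lambda * XX)
    m_k[, k]   <- lambda * S_k[, , k] %*% Xy
    beta_k[k]  <- beta_0 + 0.5 * (sum(m_k[, k]^2) + sum(diag(S_k[, , k])))
  }
  list(delta = delta_k, m = m_k, S = S_k, alpha = alpha_k, beta = beta_k)
}

# Toy data: two hypothetical clusters of linear regressions
set.seed(3)
N <- 40
w_true <- cbind(c(1, 2), c(-1, -2))
z <- rep(1:2, each = N / 2)
X <- lapply(seq_len(N), function(n) cbind(1, runif(10, -1, 1)))
y <- lapply(seq_len(N), function(n)
  c(X[[n]] %*% w_true[, z[n]]) + rnorm(10, sd = 0.3))
# Hard assignments stand in for the E-step output in this illustration
r_nk <- cbind(as.numeric(z == 1), as.numeric(z == 2))

fit <- m_step(X, y, r_nk, beta_k = rep(0.1, 2), lambda = 1 / 0.3^2)
print(round(fit$m, 2))   # columns should be close to w_true
print(fit$delta)
```

With the true assignments plugged in as responsibilities, the posterior means in `fit$m` should land close to the generating weights, slightly shrunk towards zero by the prior.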

6 Code implementation in R

Load the required packages and define helper functions

suppressPackageStartupMessages(require(BPRMeth))
suppressPackageStartupMessages(require(matrixcalc))
suppressPackageStartupMessages(library(ggplot2))
suppressPackageStartupMessages(library(data.table))
suppressPackageStartupMessages(library(purrr))
suppressPackageStartupMessages(library(mvtnorm))
suppressPackageStartupMessages(library(Matrix))

# Define ggplot2 theme
gg_theme <- function(){
  p <- theme(
      plot.title = element_text(size = 20,face = 'bold',
                                margin = margin(0,0,3,0), hjust = 0.5),
      axis.text = element_text(size = rel(1.05), color = 'black'),
      axis.title = element_text(size = rel(1.45), color = 'black'),
      axis.title.y = element_text(margin = margin(0,10,0,0)),
      axis.title.x = element_text(margin = margin(10,0,0,0)),
      axis.ticks.x = element_line(colour = "black", size = rel(0.8)),
      axis.ticks.y = element_blank(),
      legend.position = "right",
      legend.key.size = unit(1.4, 'lines'),
      legend.title = element_text(size = 12, face = 'bold'),
      legend.text = element_text(size = 12),
      panel.border = element_blank(),
      panel.grid.major = element_line(colour = "gainsboro"),
      panel.background = element_blank()
    )
  return(p)
}

# Plot the predictive distribution
draw_predictive <- function(xs, pred, title="", ...){
  K <- NCOL(pred$mu_pred)
  cl_names <- paste0("C", seq_len(K))
  # Reshape a (length(xs) x K) matrix to long format, stacking the clusters
  to_long <- function(mat, value_name) {
    data.table(mat) %>% 
      setnames(cl_names) %>% 
      melt(measure.vars = cl_names, variable.name = "Cluster", 
           value.name = value_name)
  }
  # Predictive mean and mean +/- 2 standard deviations for each cluster
  dt <- to_long(pred$mu_pred, "ys")
  # Repeat xs explicitly for each cluster (data.table does not recycle it)
  dt[, xs := rep(xs, times = K)]
  dt[, ys_low  := to_long(pred$mu_pred - 2*pred$s_pred, "ys_low")$ys_low]
  dt[, ys_high := to_long(pred$mu_pred + 2*pred$s_pred, "ys_high")$ys_high]
  
  p <- ggplot(dt, aes(x = xs, y = ys, color = Cluster)) +
    geom_line(size = 2) +
    geom_ribbon(aes(ymin = ys_low, ymax = ys_high, fill = Cluster), 
                alpha = 0.23, size = 0.1) +
    scale_x_continuous(limits = c(-1, 1), 
                       labels = c("-5kb", "", "TSS", "", "+5kb")) + 
    scale_color_brewer(palette = "Dark2") +
    scale_fill_brewer(palette = "Dark2") +
    labs(title = title, x = "x", y = "y") + gg_theme()
  return(p)
}

# Use the log sum exp trick for having numeric stability
log_sum_exp <- function(x) {
  # Computes log(sum(exp(x)))
  offset <- max(x)
  s <- log(sum(exp(x - offset))) + offset
  i <- which(!is.finite(s))
  if (length(i) > 0) { s[i] <- offset }
  return(s)
}

## A general-purpose adder:
add_func <- function(x) Reduce("+", x)

# Compute predictive distribution of VB_LRMM model
vb_lrmm_predictive <- function(model, X_test){
  mu_pred = s_pred <- matrix(0, ncol = model$K, nrow = NROW(X_test))
  for (k in 1:model$K) {
    # Predictive mean
    mu_pred[,k] <- c(X_test %*% model$m[,k])
    # Predictive standard deviation
    s_pred[,k] <- sqrt(1/model$lambda + 
                           diag(X_test %*% model$S[,,k] %*% t(X_test)))
  }
  pi_k <- model$delta / sum(model$delta)
  return(list(mu_pred = mu_pred, s_pred = s_pred, pi_k = pi_k))
}

Function for fitting the variational Bayesian linear regression mixture model.

# Fit VB_LRMM model
vb_lrmm <- function(x, K=3, basis, lambda=1, delta_0=rep(1/K, K), 
                    alpha_0=1e-1, beta_0=1e-1, max_iter = 500, 
                    epsilon_conv = 1e-4, is_animation = FALSE, 
                    is_verbose = FALSE){

  assertthat::assert_that(is.list(x))
  N <- length(x)            # Number of observations
  D <- basis$M + 1          # Number of features
  # Extract responses y_{n}
  y <- lapply(X = x, FUN = function(x) x[,2])   
  # Create design matrix X 
  X <- lapply(X = x, FUN = function(x) 
      design_matrix(obj = basis, obs = x[,1])$H) 
  # Compute X_{n}'X_{n}
  XX <- lapply(X = X, FUN = function(x) crossprod(x))           
  # Compute y_{n}'y_{n} 
  yy <- unlist(lapply(X = y, FUN = function(y) c(crossprod(y))))     
  # Compute X_{n}'y_{n}
  Xy <- lapply(X = 1:N, FUN = function(i) crossprod(X[[i]], y[[i]]))    
  # Number of observations in each y_{n}
  len_y <- unlist(lapply(X = y, FUN = function(y) length(y)))               
  L <- rep(-Inf, max_iter)  # Store the lower bounds
  # Matrices for responsibilities
  r_nk = log_r_nk = log_rho_nk <- matrix(0,nrow = N,ncol = K) 
  E_ww <- vector("numeric", length = K)
  # Compute \alpha_k parameter of Gamma
  alpha_k <- rep(alpha_0 + D/2, K)   
  W <- infer_profiles_mle(X = x, model = "gaussian", basis = basis, H = X,
                          lambda = 1e-2)$W
  # Use Kmeans with random starts
  cl  <- stats::kmeans(W, K, nstart = 25)    
  # Mean for each cluster
  m_k <- t(cl$centers)                                     
  # m_k <- matrix(0, ncol=K, nrow=D)
  # for (k in 1:K) {m_k[,k] <- rnorm(D)}
  # Covariance of each cluster
  S_k <- array(0, dim = c(D,D,K))
  for (k in 1:K){ S_k[,,k] <- solve(diag(2, D))}      
  # Rate parameter of the Gamma distribution over \tau_k
  beta_k   <- rep(beta_0, K)   
  # Dirichlet parameter
  delta_k  <- delta_0                        
  # Expectation of log Dirichlet  
  e_log_pi <- digamma(delta_k) - digamma(sum(delta_k))     
  mk_Sk    <- lapply(X = 1:K, function(k) tcrossprod(m_k[, k]) + S_k[,,k])
  
  # Iterate to find optimal parameters
  for (i in 2:max_iter) {
    ##-------------------------------
    # Variational E-Step
    ##-------------------------------
    for (k in 1:K) {
      log_rho_nk[,k] <- e_log_pi[k] + lambda*sapply(1:N, function(n) 
          m_k[,k] %*% Xy[[n]] - 0.5*matrix.trace(XX[[n]] %*% mk_Sk[[k]]))
    }
    # Calculate probabilities using the logSumExp trick for numerical stability
    Z        <- apply(log_rho_nk, 1, log_sum_exp)
    log_r_nk <- log_rho_nk - Z
    r_nk     <- apply(log_r_nk, 2, exp)
    
    ##-------------------------------
    # Variational M-Step
    ##-------------------------------
    # Update Dirichlet parameter
    delta_k <- delta_0 + colSums(r_nk)
    for (k in 1:K) {
      # Update covariance for Gaussian
      w_XX <- lapply(X = 1:N, function(x) XX[[x]]*r_nk[x,k])
      S_k[,,k] <- solve(diag(alpha_k[k]/beta_k[k], D) + lambda * 
                            add_func(w_XX))
      # Update mean for Gaussian
      w_Xy <- lapply(X = 1:N, function(x) Xy[[x]]*r_nk[x,k])
      m_k[,k] <- lambda * S_k[,,k] %*% add_func(w_Xy)
      # Update \beta_k parameter for Gamma
      E_ww[k] <- crossprod(m_k[,k]) + matrix.trace(S_k[,,k])
      beta_k[k]  <- beta_0 + 0.5*E_ww[k]
    }
    # Compute expected value of mixing proportions, E[\pi_k]
    pi_k <- delta_k / sum(delta_k)
    # Update expectations over \ln\pi
    e_log_pi <- digamma(delta_k) - digamma(sum(delta_k))
    # Compute expectation E[\tau_k] = \alpha_k / \beta_k
    E_alpha <- alpha_k / beta_k
    
    ##-------------------------------
    # Variational lower bound
    ##-------------------------------
    mk_Sk <- lapply(X = 1:K, function(k) tcrossprod(m_k[, k]) + S_k[,,k])
    
    lb_p_y <- -0.5*sum(len_y)*log(2*pi*(1/lambda)) - 0.5*lambda*sum(yy) + 
        sum(sapply(1:K, function(k) lambda*(sum(sapply(1:N, function(n) 
        r_nk[n,k]*(m_k[,k] %*% Xy[[n]] - 
        0.5*matrix.trace(XX[[n]] %*% mk_Sk[[k]])))))))
    lb_p_w   <- sum(-0.5*D*log(2*pi) + 0.5*D*(digamma(alpha_k) - 
        log(beta_k)) - 0.5*E_alpha*E_ww)
    lb_p_c   <- sum(r_nk %*% e_log_pi)   
    lb_p_pi  <- sum((delta_0 - 1)*e_log_pi) + lgamma(sum(delta_0)) - 
        sum(lgamma(delta_0))
    lb_p_tau <- sum(alpha_0*log(beta_0) + (alpha_0 - 1)*(digamma(alpha_k) - 
        log(beta_k)) - beta_0*E_alpha - lgamma(alpha_0))
    lb_q_c   <- sum(r_nk*log_r_nk)  
    lb_q_pi  <- sum((delta_k - 1)*e_log_pi) + lgamma(sum(delta_k)) - 
        sum(lgamma(delta_k))
    lb_q_w   <- sum(-0.5*log(sapply(X = 1:K,function(k) det(S_k[,,k]))) - 
        0.5*D*(1 + log(2*pi)))
    lb_q_tau <- sum(-lgamma(alpha_k) + (alpha_k - 1)*digamma(alpha_k) + 
        log(beta_k) - alpha_k)
    # Sum all parts to compute lower bound
    L[i] <- lb_p_y + lb_p_c + lb_p_pi + lb_p_w + lb_p_tau - lb_q_c - 
        lb_q_pi - lb_q_w - lb_q_tau
    # Show VB difference
    if (is_verbose) { 
      cat("It:\t",i,"\tLB:\t",L[i],"\tLB_diff:\t",L[i] - L[i - 1],"\n")
      cat("Lik: ",lb_p_y,"\tC: ",lb_p_c - lb_q_c,"\tW: ",lb_p_w - lb_q_w,
          "\tPi: ",lb_p_pi - lb_q_pi,"\tTau: ",lb_p_tau - lb_q_tau,"\n")
    }
    # Check if lower bound decreases
    if (L[i] < L[i - 1]) { message("Warning: Lower bound decreases!\n")} 
    # Check for convergence
    if (abs(L[i] - L[i - 1]) < epsilon_conv) { break }
    # Check if VB converged in the given maximum iterations
    if (i == max_iter) {warning("VB did not converge!\n")}
  }
  
  obj <- structure(list(m = m_k, S = S_k, delta = delta_k, r_nk = r_nk, 
                        lambda = lambda, pi_k = pi_k, beta = beta_k, 
                        alpha = alpha_k, L = L[2:i], K = K, N = N, D = D), 
                  class = "vb_lrmm")
  return(obj)
}

6.1 Cluster synthetic regression data

set.seed(10)  # For reproducibility
x <- BPRMeth::gaussian_data # Observations
K <- 3        # Number of clusters
basis <- create_rbf_object(M = 3) # Basis function object
# Run the VB-LRMM model
vb_lrmm_model <- vb_lrmm(x = x, K = K,lambda = 5,delta_0 = rep(1e-5,K), 
                         basis = basis, max_iter = 101,is_verbose = TRUE)
## It:   2  LB:  -9175.325  LB_diff:     Inf 
## Lik:  -8785.944  C:  -306.6153   W:  -53.00976   Pi:  -25.19721  Tau:  -4.55834 
## It:   3  LB:  -9152.844  LB_diff:     22.48042 
## Lik:  -8763.344  C:  -306.6153   W:  -53.10368   Pi:  -25.19721  Tau:  -4.583825 
## It:   4  LB:  -9152.844  LB_diff:     0.0003043586 
## Lik:  -8763.343  C:  -306.6153   W:  -53.10494   Pi:  -25.19721  Tau:  -4.583919 
## It:   5  LB:  -9152.844  LB_diff:     4.098183e-09 
## Lik:  -8763.343  C:  -306.6153   W:  -53.10495   Pi:  -25.19721  Tau:  -4.583919

Let’s examine the posterior predictive distribution.

xs <- seq(-1, 1, len = 100) # create test X values
# Estimate predictive distribution
pred <- vb_lrmm_predictive(model = vb_lrmm_model, 
                           X_test = design_matrix(basis, xs)$H)
# Create plot
p <- draw_predictive(xs = xs, pred = pred, 
            title = "Mixture of linear basis function regressions")
print(p)

6.2 Perform model selection using the lower bound

We keep the data from the previous example and fit a sequence of models with an increasing number of clusters \(K\), observing how the lower bound \(\mathcal{L}(q)\) changes.

set.seed(1234)
K_model <- 10
L <- vector(mode = "numeric", length = K_model - 1)
for (k in 2:K_model) {
  # Run the VB-LRMM model with k clusters
  vb_lrmm_model <- vb_lrmm(x = x, K = k, lambda = 5, delta_0 = rep(1e-5,k),
                         basis = basis, max_iter = 101, is_verbose = FALSE)
  # Store lower bound for each model and subtract \ln(K!)
  L[k - 1] <- tail(vb_lrmm_model$L, n = 1) - log(factorial(k))
}
## Warning in vb_lrmm(x = x, K = k, lambda = 5, delta_0 = rep(1e-05, k), basis = basis, : VB did not converge!

## Warning in vb_lrmm(x = x, K = k, lambda = 5, delta_0 = rep(1e-05, k), basis = basis, : VB did not converge!

## Warning in vb_lrmm(x = x, K = k, lambda = 5, delta_0 = rep(1e-05, k), basis = basis, : VB did not converge!

## Warning in vb_lrmm(x = x, K = k, lambda = 5, delta_0 = rep(1e-05, k), basis = basis, : VB did not converge!

## Warning in vb_lrmm(x = x, K = k, lambda = 5, delta_0 = rep(1e-05, k), basis = basis, : VB did not converge!

## Warning in vb_lrmm(x = x, K = k, lambda = 5, delta_0 = rep(1e-05, k), basis = basis, : VB did not converge!
dt <- data.table("Model" = 2:K_model, "LB" = L)
ggplot(data = dt, aes(x = Model, y = LB, group = 1)) + geom_line() + 
    geom_point() + scale_x_continuous(breaks = seq_len(K_model)) + gg_theme() 

Plot of the lower bound \(\mathcal{L}(q)\) versus the model complexity \(\mathcal{M}\) (i.e. the number of clusters). The lower bound approximates the log model evidence \(\text{ln}\;p(\mathbf{y} | \mathcal{M})\) from below, and we see that it peaks at \(\mathcal{M} = 3\) (i.e. three clusters), corresponding to the true model that generated the data.

We can use the same approach to test how many basis functions \(M\) are needed to model the data, keeping the number of clusters fixed at \(K = 3\).

7 Conclusions

This tutorial-like document showed how to perform Variational Bayes for a mixture of Bayesian linear regression models.

If you found this document useful, check my homepage at the University of Edinburgh for links to other tutorials.