Computational Concerns with Mixed Models


Introduction

\[ \renewcommand\vec{\boldsymbol} \def\bigO#1{\mathcal{O}(#1)} \def\Cond#1#2{\left(#1\,\middle|\, #2\right)} \def\mat#1{\boldsymbol{#1}} \def\der{{\mathop{}\!\mathrm{d}}} \def\argmax{\text{arg}\,\text{max}} \def\Prob{\text{P}} \def\Expec{\text{E}} \def\logit{\text{logit}} \def\diag{\text{diag}} \]

Running Example

Students attempt multiple tasks. We want to know:

  1. How hard is each task?
  2. How much does the skill level differ between students?

Running Example (Cont.)

Unobserved \(\vec U_i\) is the skill level of student \(i\).

Assume that each \(\vec U_i\) has density \(g(\cdot;\vec\theta)\).

Parts of \(\vec\theta\) quantify the difference in skill levels.

Given the skill level \(\vec U_i\), the outcomes \(\vec Y_i\) (whether each task is completed) have density \(h(\cdot\mid\vec u;\vec\theta)\).

Parts of \(\vec\theta\) quantify the difficulty of each task.

We often only care about \(\vec\theta\) and not \(\vec U_i\).

Common Mixed Effect Model

The likelihood for individual \(i\) is

\[ \exp f(\vec y_i, \vec u_i;\vec\theta) = g(\vec u_i;\vec\theta) h(\vec y_i\mid\vec u_i;\vec\theta) \]

\(g(\vec u_i;\vec\theta)\) is similar to a prior in a Bayesian analysis.

The log marginal likelihood for individual \(i\) is

\[ l(\vec\theta;\vec y_i) = \log \int \exp(f(\vec y_i, \vec u;\vec\theta))\der \vec u \]

Needed to find the maximum likelihood estimator of \(\vec\theta\).

Running Example (Cont.)

Assume \(\vec y_i\in\{0,1\}^{n}\) is whether the answer to each task is correct.

Use a mixed logistic regression

\[ \begin{align*} f(\vec y_i,u_i;\vec\theta) &= \log\phi(u_i;\sigma^2) \\ &\hspace{25pt} +\sum_{j = 1}^{n} \Big((\eta_j + u_i)y_{ij} -\log(1 + \exp(\eta_j + u_i))\Big) \end{align*} \]

\(\eta_j\): difficulty of task \(j\).

\(\sigma^2\): the variation in students’ skill levels.

Concrete example

Three tasks with \((\eta_1,\eta_2,\eta_3) = (-2, 0, 2)\).

These are assumed known here, but in practice they are almost never known and have to be estimated.

Need an approximation for the log marginal likelihood \(l(\vec\theta;\vec y_i) = l(\sigma^2;\vec y_i)\)

We will suppress the dependence on \(\vec y_i\) and \(i\) and write e.g. \(f(u;\vec\theta)\).

Use one dimensional \(U\in\mathbb R\) for simplicity.
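To make the later comparisons concrete, here is a minimal R sketch of the running example. The object names (`etas`, `f_joint`, `true_log_marg`, `y_obs`) and the response vector are illustrative only; the "true" log marginal likelihood is obtained by brute-force numerical integration.

```r
# log joint density f(y, u; theta) for one student in the running example
etas <- c(-2, 0, 2)            # the assumed task difficulties from the slide
f_joint <- function(u, y, sigma2)
  dnorm(u, sd = sqrt(sigma2), log = TRUE) +
    sum((etas + u) * y - log(1 + exp(etas + u)))

# "true" log marginal likelihood by brute-force numerical integration
true_log_marg <- function(y, sigma2)
  log(integrate(
    function(u) sapply(u, function(ui) exp(f_joint(ui, y, sigma2))),
    -Inf, Inf)$value)

y_obs <- c(1L, 1L, 0L)         # a hypothetical response vector
true_log_marg(y_obs, sigma2 = 4)
```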

Laplace Approximation

Laplace Approximation

\[ \begin{align*} l(\vec\theta) &= \log\int\exp(f(\vec u;\vec\theta))\der \vec u \\ &\approx \frac K2\log 2\pi - \frac 12\log \lVert -f''_{\vec u\vec u}(\vec u_0(\vec\theta);\vec\theta)\rVert + f(\vec u_0(\vec\theta);\vec\theta) \\ &= \tilde L(\vec\theta) \end{align*} \]

where \(\vec u \in \mathbb R^K\).

\[\vec u_0(\vec\theta) = \text{arg max}_\vec u f(\vec u;\vec\theta)\]

and \(f''_{\vec u\vec u}(\vec u_0;\vec\theta)\) is negative definite.

Needed for a maximum and in the proof of the approximation.

Running Example

\[ \begin{align*} f'_u(\vec y,u;\vec\theta) &= -\frac u{\sigma^2} + \sum_{j = 1}^{n}\left( y_j - \frac{\exp(\eta_j + u)}{1 + \exp(\eta_j + u)}\right) \\ f''_{uu}(\vec y,u;\vec\theta) &= -\frac 1{\sigma^2} - \sum_{j = 1}^{n} \frac{\exp(\eta_j + u)}{(1 + \exp(\eta_j + u))^2} \\ \tilde L(\vec\theta) &= \frac 12\log 2\pi - \frac 12\log \lVert -f''_{uu}(u_0(\vec\theta);\vec\theta)\rVert + f(u_0(\vec\theta);\vec\theta) \end{align*} \]
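Below is a minimal sketch of the Laplace approximation for the running example, reusing `etas`, `f_joint`, and `y_obs` from the earlier sketch. The mode is found numerically with `optimize`, and the second derivative is the analytic expression above.

```r
# Laplace approximation of l(sigma^2; y) for the running example
laplace_approx <- function(y, sigma2){
  # u_0(theta) = arg max_u f(y, u; theta)
  opt <- optimize(f_joint, c(-20, 20), y = y, sigma2 = sigma2, maximum = TRUE)
  u0 <- opt$maximum
  # analytic second derivative f''_{uu} evaluated at u_0
  p   <- 1 / (1 + exp(-(etas + u0)))
  d2f <- -1 / sigma2 - sum(p * (1 - p))
  # the approximation with K = 1
  .5 * log(2 * pi) - .5 * log(-d2f) + opt$objective
}

laplace_approx(y_obs, sigma2 = 4)
```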

Example: Laplace Approximation

The continuous line is the true curve and the dashed line is the Laplace approximation.

Remarks

\[ \begin{align*} l(\vec\theta) \approx \frac K2\log 2\pi - \frac 12\log \lVert -f''_{\vec u\vec u}(\vec u_0(\vec\theta);\vec\theta)\rVert + f(\vec u_0(\vec\theta);\vec\theta) \end{align*} \]

Finding \(\vec u_0(\vec\theta) = \text{arg max}_\vec u f(\vec u;\vec\theta)\) is often fast.

The Laplace approximation is fast but may be biased.

Has nice asymptotics in many cases.

In the running example, this is the case when each student attempts many tasks.

Laplace Approximation (Cont.)

\[ \begin{multline*} \tilde L'_{\vec\theta}(\vec\theta) = f'_{\vec\theta}(\vec u_0(\vec\theta);\vec\theta) +\frac 12\left( \nabla_{\vec u}\log \lVert -f''_{\vec u\vec u}(\vec u_0(\vec\theta);\vec\theta)\rVert \right) \\ f''_{\vec u\vec u}(\vec u_0(\vec\theta);\vec\theta)^{-1} f''_{\vec u\vec\theta}(\vec u_0(\vec\theta);\vec\theta) \end{multline*} \]

Point: requires 3rd order derivatives for the gradient.

Automatic differentiation implementations with higher order derivatives often have a big overhead.

A general approach has been implemented in the TMB package.

Gaussian Quadrature

Gaussian Quadrature

\[ \int \omega(u)g(u)\der u \approx \sum_{i = 1}^m w_ig(u_i) \]

with nodes \(u_1,\dots,u_m\) and weights \(w_1,\dots,w_m\).

The approximation is exact for polynomials of degree \(2m - 1\) or less.

Gauss–Hermite quadrature, where \(\omega(u) = \exp(-u^2)\), is often used.
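As a small check of the exactness claim, the sketch below (assuming the statmod package, which is not referenced in the talk) integrates \(x^4\) against \(\exp(-x^2)\) with only three nodes; the exact value is \(\Gamma(5/2) = 3\sqrt\pi/4\).

```r
# m = 3 nodes are exact for polynomials up to degree 2 * 3 - 1 = 5
gh <- statmod::gauss.quad(3, kind = "hermite")
sum(gh$weights * gh$nodes^4)   # quadrature estimate of int x^4 exp(-x^2) dx
3 * sqrt(pi) / 4               # the exact value
```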

Application

\[ \begin{align*} \int\exp f(\vec y,u;\vec\theta) \der u &= \frac 1{\sqrt{2\pi}\sigma}\int \exp\left( -\frac{u^2}{2\sigma^2}\right) h(\vec y\mid u;\vec\theta)\der u \\ &= \frac 1{\sqrt{\pi}}\int \underbrace{\exp\left(-u^2\right)}_{\omega(u)} h(\vec y\mid \sqrt 2 \sigma u;\vec\theta)\der u \end{align*} \]
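A sketch of the non-adaptive Gauss–Hermite rule for the running example, using the substitution above. `statmod::gauss.quad` is assumed to supply the nodes and weights; `etas` and `y_obs` are from the earlier sketch.

```r
ghq_log_marg <- function(y, sigma2, m = 20){
  gh <- statmod::gauss.quad(m, kind = "hermite")  # weight function exp(-x^2)
  # h(y | u; theta): the conditional density of the outcomes
  h <- function(u) exp(sum((etas + u) * y - log(1 + exp(etas + u))))
  # apply the substitution u -> sqrt(2) * sigma * u from the slide
  log(sum(gh$weights * sapply(gh$nodes, function(z)
    h(sqrt(2 * sigma2) * z))) / sqrt(pi))
}

ghq_log_marg(y_obs, sigma2 = 4)
```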

Application (Cont.)

The continuous line is the true curve, the dashed line is the Laplace approximation, and the dotted line is the Gauss–Hermite quadrature approximation.

Adaptive Version

The integrand (continuous line) and \(g(u;\vec\theta)\) (dashed line). The ticks on the x-axis are the quadrature nodes \(u_i\).

Adaptive Version

The integrand (continuous line) and the shifted and re-scaled weight function (dashed line). The ticks on the x-axis are the quadrature nodes \(u_i\).

Adaptive Version (Cont.)

\[ \begin{align*} \int \exp\left(f\left(u;\vec\theta\right)\right)\der u &= \sqrt{2\hat\sigma^2(\vec\theta)}\int \exp(-u^2)\\ &\hspace{25pt} \cdot \exp\left(u^2 + f\left( \sqrt{2\hat\sigma^2(\vec\theta)} u + u_0(\vec\theta);\vec\theta\right)\right) \der u \\ u_0(\vec\theta) &= \text{arg max}_u f\left(u;\vec\theta\right) \\ \hat\sigma^2(\vec\theta) &= (-f''_{uu}(u_0(\vec\theta);\vec\theta))^{-1} \end{align*} \]
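A minimal sketch of the adaptive version for the running example, again reusing `etas`, `f_joint`, and `y_obs` and assuming the statmod package. The nodes are shifted to the mode \(u_0(\vec\theta)\) and scaled by \(\hat\sigma(\vec\theta)\).

```r
adaptive_ghq_log_marg <- function(y, sigma2, m = 10){
  # mode u_0(theta) and curvature-based scale sigma_hat^2(theta)
  opt <- optimize(f_joint, c(-20, 20), y = y, sigma2 = sigma2, maximum = TRUE)
  u0  <- opt$maximum
  p   <- 1 / (1 + exp(-(etas + u0)))
  sigma2_hat <- 1 / (1 / sigma2 + sum(p * (1 - p)))  # (-f''_{uu}(u_0))^{-1}

  gh <- statmod::gauss.quad(m, kind = "hermite")
  vals <- sapply(gh$nodes, function(z)
    exp(z^2 + f_joint(sqrt(2 * sigma2_hat) * z + u0, y, sigma2)))
  log(sqrt(2 * sigma2_hat) * sum(gh$weights * vals))
}

adaptive_ghq_log_marg(y_obs, sigma2 = 4)
```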

Application of the Adaptive Version

The continuous line is the true curve, the dashed line is the Laplace approximation, the dotted line is the Gauss–Hermite quadrature approximation, and the green line is the adaptive version.

Remarks

Often works very well with common models and few random effects.

If \(\vec U\in\mathbb R^K\) then say \(K \leq 5\).

The computational complexity is \(\bigO{m^K}\).

\(m\) is the number of quadrature nodes.

But many problems do not have many random effects, i.e. \(K\) is small.

Monte Carlo Approximation

Simple Monte Carlo

\[ \int g(\vec u;\vec\theta) h(\vec y\mid\vec u;\vec\theta) \der\vec u \]

Draw from \(g\), evaluate \(h\), and take the average.
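A minimal sketch of the simple Monte Carlo estimator for the running example, reusing `etas` and `y_obs` from before.

```r
# simple Monte Carlo: draw U from g (the N(0, sigma^2) prior) and average h
mc_log_marg <- function(y, sigma2, m = 1000L){
  us <- rnorm(m, sd = sqrt(sigma2))
  hs <- sapply(us, function(u)
    exp(sum((etas + u) * y - log(1 + exp(etas + u)))))
  log(mean(hs))
}

set.seed(1)
mc_log_marg(y_obs, sigma2 = 4)
```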

Application of Monte Carlo

The continuous line is the true curve, the dashed line is the Laplace approximation, and the dots are Monte Carlo estimators. The dotted line is an average estimate.

Importance Sampling

\[ \begin{align*} \int \exp\left(f\left(u;\vec\theta\right)\right)\der u &= \int \phi\left(u;u_0(\vec\theta),\hat\sigma^2(\vec\theta)\right) \frac{\exp\left(f\left( u;\vec\theta\right)\right)} {\phi\left(u;u_0(\vec\theta),\hat\sigma^2(\vec\theta)\right)} \der u \\ u_0(\vec\theta) &= \text{arg max}_u f\left(u;\vec\theta\right) \\ \hat\sigma^2(\vec\theta) &= (-f''_{uu}(u_0(\vec\theta);\vec\theta))^{-1} \end{align*} \]

This is like the shifting and re-scaling used in the adaptive quadrature.
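A sketch of the importance sampler with the normal proposal from the slide, reusing `etas`, `f_joint`, and `y_obs` from the earlier sketches.

```r
# importance sampling with a N(u_0(theta), sigma_hat^2(theta)) proposal
is_log_marg <- function(y, sigma2, m = 1000L){
  opt <- optimize(f_joint, c(-20, 20), y = y, sigma2 = sigma2, maximum = TRUE)
  u0  <- opt$maximum
  p   <- 1 / (1 + exp(-(etas + u0)))
  sigma_hat <- sqrt(1 / (1 / sigma2 + sum(p * (1 - p))))

  us <- rnorm(m, mean = u0, sd = sigma_hat)
  ws <- sapply(us, function(u) exp(f_joint(u, y, sigma2))) /
    dnorm(us, mean = u0, sd = sigma_hat)
  log(mean(ws))
}

set.seed(1)
is_log_marg(y_obs, sigma2 = 4)
```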

Application of Importance Sampling

The continuous line is the true curve, the dashed line is the Laplace approximation, and the dots are importance sampling estimators. The dotted line is an average estimate.

Error Term

The error of the estimated marginal likelihood versus the number of samples. The Monte Carlo estimates are black and the importance sampling estimates are blue. Lines are from log-log regressions.

Remarks

Both estimates are \(\bigO{m^{-1/2}}\) where \(m\) is the number of samples but the constant may differ greatly.

Very efficient importance samplers exist for some models.

Various variance reduction methods exist.

E.g. antithetic variates and control variates. These only affect the constant; a sketch of antithetic variates is given at the end of these remarks.

The estimate of the log marginal likelihood is biased in finite samples, since the logarithm of an unbiased estimator is biased. This is usually not an issue.

Markov chain Monte Carlo and related methods are alternatives for higher dimensional problems.
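As an illustration of one of the variance reduction methods above, here is a minimal sketch of antithetic variates applied to the simple Monte Carlo estimator, reusing `etas` and `y_obs`. It exploits that the prior is symmetric about zero, so \(-U\) has the same distribution as \(U\).

```r
# antithetic variates for the simple Monte Carlo estimator: pair each draw u
# with -u, which has the same distribution since the prior is symmetric
mc_antithetic_log_marg <- function(y, sigma2, m = 500L){
  us <- rnorm(m, sd = sqrt(sigma2))
  h  <- function(u) exp(sum((etas + u) * y - log(1 + exp(etas + u))))
  log(mean((sapply(us, h) + sapply(-us, h)) / 2))
}

set.seed(1)
mc_antithetic_log_marg(y_obs, sigma2 = 4)
```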

Quasi-Monte Carlo

Instead of random numbers, use a deterministic low-discrepancy sequence.

Estimates are \(\bigO{m^{-1 + \epsilon}}\) where \(m\) is the length of the sequence.
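A sketch of a quasi-Monte Carlo version of the simple Monte Carlo estimator. A base-2 van der Corput sequence is one simple choice of low-discrepancy sequence (the talk's plots may use a different one); the points are mapped to the prior through the normal quantile function. `etas` and `y_obs` are again from the earlier sketch.

```r
# one-dimensional van der Corput (base-2) low-discrepancy sequence
van_der_corput <- function(n, base = 2L){
  sapply(seq_len(n), function(i){
    out <- 0; f <- 1 / base
    while(i > 0){
      out <- out + f * (i %% base)
      i   <- i %/% base
      f   <- f / base
    }
    out
  })
}

# quasi-Monte Carlo version of the simple Monte Carlo estimator: map the
# low-discrepancy points to N(0, sigma^2) values through the quantile function
qmc_log_marg <- function(y, sigma2, m = 1000L){
  us <- qnorm(van_der_corput(m)) * sqrt(sigma2)
  hs <- sapply(us, function(u)
    exp(sum((etas + u) * y - log(1 + exp(etas + u)))))
  log(mean(hs))
}

qmc_log_marg(y_obs, sigma2 = 4)
```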

Quasi-Monte Carlo

The continuous line is the true curve, the dashed line is the Laplace approximation, and the dotted line is the quasi-Monte Carlo estimate.

Error Term: with Quasi-Monte Carlo

The error of the estimated marginal likelihood versus the number of samples. The Monte Carlo estimates are black and the importance sampling estimates are blue. Lines are from log-log regressions. A line is also added for the quasi-Monte Carlo estimate.

Variational Approximations

Variational Approximations

Pick some variational distribution with density \(q(\cdot;\vec\omega)\) for some \(\vec\omega\in\Omega\), and write \(p(\vec y, \vec u;\vec\theta) = \exp f(\vec y, \vec u;\vec\theta)\) for the joint density. Then:

\[ \begin{align*} l(\vec\theta) &= \int q(\vec u;\vec\omega) \log \left(\frac{p(\vec y, \vec u;\vec\theta)\big/ q(\vec u;\vec\omega)} {p(\vec u\mid\vec y;\vec\theta)\big/ q(\vec u;\vec\omega)}\right) \der \vec u\\ &= \int q(\vec u;\vec\omega) \left( \log \left(\frac{p(\vec y, \vec u;\vec\theta)} {q(\vec u;\vec\omega)}\right) + \log \left(\frac{q(\vec u;\vec\omega)} {p(\vec u\mid\vec y;\vec\theta)}\right) \right) \der \vec u \\ &\geq \int q(\vec u;\vec\omega) \log \left(\frac{p(\vec y, \vec u;\vec\theta)} {q(\vec u;\vec\omega)}\right) \der \vec u = \tilde l(\vec\theta, \vec\omega) \end{align*} \]

Variational Approximations (Cont.)

\[ l(\vec\theta) \geq \max_{\vec\omega} \tilde l(\vec\theta, \vec\omega) \]

May be very tight (almost equal).

Running Example

We need to evaluate:

The entropy: \(-E_{q(\cdot;\vec\omega)}(\log(q(U;\vec\omega)))\).

The first two moments: \(E_{q(\cdot;\vec\omega)}(U^2)\big/(2\sigma^2)\) and \(E_{q(\cdot;\vec\omega)}(U)y_j\).

The following expectation: \(E_{q(\cdot;\vec\omega)}(\log(1 + \exp(\eta_j + U)))\), which can be handled with one-dimensional quadrature (see the sketch below).
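A minimal sketch of the resulting lower bound with a Gaussian variational distribution \(q = N(\mu, \tau^2)\), reusing `etas` and `y_obs` and assuming the statmod package; the function name and parameterization are illustrative only. The intractable expectation is approximated with Gauss–Hermite quadrature, and the bound is maximized over the variational parameters.

```r
# lower bound for a Gaussian variational distribution q = N(mu, tau^2); the
# intractable term E_q log(1 + exp(eta_j + U)) is handled with Gauss-Hermite
# quadrature, and the bound is maximized over the variational parameters
gva_lower_bound <- function(y, sigma2, m = 20){
  gh <- statmod::gauss.quad(m, kind = "hermite")
  elbo <- function(par){
    mu <- par[1]; tau <- exp(par[2])
    e_log_h <- sum((etas + mu) * y) - sum(sapply(etas, function(eta)
      sum(gh$weights * log1p(exp(eta + mu + sqrt(2) * tau * gh$nodes))) /
        sqrt(pi)))
    e_log_g <- -.5 * log(2 * pi * sigma2) - (mu^2 + tau^2) / (2 * sigma2)
    entropy <- .5 * (log(2 * pi) + 1) + log(tau)
    e_log_h + e_log_g + entropy
  }
  # maximize the lower bound over (mu, log tau)
  optim(c(0, 0), elbo, control = list(fnscale = -1))$value
}

gva_lower_bound(y_obs, sigma2 = 4)
```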

Gaussian Variational Approximation

The continuous line is the true curve, the dashed line is the Laplace approximation, and the dotted line is the lower bound from the Gaussian variational approximation.

Skew-normal VA

The continuous line is the true curve, the dashed line is the Laplace approximation, the dotted line is the lower bound from the Gaussian variational approximation, and the blue dotted line is the lower bound from the skew-normal variational approximation.

Remarks

Works quite well in this case but this is not universally true.

The computational complexity of the lower bound is often \(\bigO{K^2}\) or \(\bigO{K^3}\) at worst

where \(\vec U\in\mathbb R^K\).

A more flexible variational distribution can be used if the lower bound is not tight enough.

Usually, one is left with no more than intractable one-dimensional integrals

where quadrature can be used.

Summary

Summary

An intractable likelihood is common with mixed models.

The Laplace approximation is fast but may be biased.

Gaussian quadrature works well for many low dimensional problems.

Monte Carlo approximations are widely applicable but naive applications can have high variance.

Variational approximations are fast but may be biased.

Thank You!

The presentation is at rpubs.com/boennecd/Comp-Mixed-Models.

The markdown is at github.com/boennecd/Talks.