Variational Approximations in Survival Analysis

Project Overview

\[ \renewcommand\vec{\boldsymbol} \def\bigO#1{\mathcal{O}(#1)} \def\Cond#1#2{\left(#1\,\middle|\, #2\right)} \def\mat#1{\boldsymbol{#1}} \def\der{{\mathop{}\!\mathrm{d}}} \def\argmax{\text{arg}\,\text{max}} \def\Prob{\text{P}} \def\Expec{\text{E}} \def\logit{\text{logit}} \def\diag{\text{diag}} \]

General Problem: Fitting Mixed Models

Have outcomes \(\vec Y\)

possibly of heterogeneous types. E.g. a complication (yes/no), a time-to-diagnosis (continuous or interval censored), or blood measurements (continuous).

Relate the outcomes to unobserved random effects \(\vec U\) with density \(h_{\vec\omega}(\vec u)\) and a conditional density \(g_{\vec\omega}(\vec y \mid \vec u)\)

depending on the model parameters \(\vec\omega\).

Thus, the log likelihood is

\[ l(\vec y,\vec\omega) = \log \int h_{\vec\omega}(\vec u)g_{\vec\omega}(\vec y \mid \vec u)d\vec u. \]

Needed for model estimation. \(g\) and \(h\) are usually easy to evaluate pointwise, but the integral has no closed-form solution.
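
To illustrate, assume a one-dimensional Gaussian random effect and Bernoulli conditional densities (a hypothetical special case, not the models from the papers). Then the integral can be approximated pointwise with Gauss-Hermite quadrature, e.g. with the statmod package:

```r
# Pointwise evaluation of l(y, omega) for a random-intercept logistic model:
# U ~ N(0, sigma^2) and Y_j | U = u ~ Bernoulli(plogis(beta + u)).
# The model and the parameter values are illustrative only.
log_marginal <- function(y, beta, sigma, n_nodes = 30L){
  gh <- statmod::gauss.quad(n_nodes, kind = "hermite") # weight function exp(-x^2)
  u <- sqrt(2) * sigma * gh$nodes                      # change of variables
  log_g <- sapply(u, function(ui)                      # log g(y | u) at the nodes
    sum(dbinom(y, 1L, plogis(beta + ui), log = TRUE)))
  log(sum(gh$weights * exp(log_g)) / sqrt(pi))         # quadrature approximation
}
log_marginal(y = c(1L, 0L, 1L), beta = .5, sigma = 1)
```

This works well in one dimension; the hard case is multivariate random effects, where such quadrature scales exponentially in the dimension.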

Issue

Current methods and implementations are often slow or too biased.

Researchers restrict their models and questions because of these constraints.

Non-Variational Approximations Work

Study of approximation methods for a commonly used and broad class of mixed models.

Draft has been submitted.

Led to an imputation method for mixed data types.

The article is accepted at the Asian Conference on Machine Learning. The software has 3197 downloads, and a draft is available as Christoffersen et al. (2021).

Non-Variational Approximations Work

Led to a fast and precise estimation method for common heritability models.

A mature draft is to be submitted. We have started collaborations with Behrang Mahjani (ISMMS, New York), Benjamin Yip (CUHK, Hong Kong), and Sven Sandin (ISMMS, New York).

Data sets include:

  • The Swedish National Patient Register.
  • The Swedish Multi-Generation Registry.

Variational Approximations

Replace the log likelihood with a lower bound for some \(\vec\theta\in\Theta\):

\[ \begin{align*} l(\vec y, \vec\omega) &\geq \tilde l(\vec y, \vec\omega, \vec\theta). \end{align*} \]
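
One standard way to obtain such a bound is Jensen's inequality: for any density \(q_{\vec\theta}\) over the random effects,

\[ l(\vec y, \vec\omega) = \log \int q_{\vec\theta}(\vec u)\frac{h_{\vec\omega}(\vec u)g_{\vec\omega}(\vec y \mid \vec u)}{q_{\vec\theta}(\vec u)}d\vec u \geq \int q_{\vec\theta}(\vec u)\log\frac{h_{\vec\omega}(\vec u)g_{\vec\omega}(\vec y \mid \vec u)}{q_{\vec\theta}(\vec u)}d\vec u = \tilde l(\vec y, \vec\omega, \vec\theta). \]

A multivariate Gaussian \(q_{\vec\theta}\) is a common choice, with \(\vec\theta\) containing its mean and covariance parameters.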

\(\tilde l(\vec y, \vec\omega, \vec\theta)\) usually requires one-dimensional integration at worst.

A “by-product” is an approximation of the conditional distribution of the random effects given the data.

Typically have \(n\) outcomes

and work with \(l^{(i)}\), \(h^{(i)}_{\vec\omega}\), \(g^{(i)}_{\vec\omega}\), \(q_{\vec\theta_i}^{(i)}\), \(\Theta_i\), and \(\tilde l^{(i)}\).

Variational Approximations (Cont.)

Approximate maximum likelihood

\[ \text{arg max}_{\vec\omega} \sum_{i = 1}^{n}l^{(i)}(\vec y_i, \vec\omega) \approx \text{arg max}_{\vec\omega} \sum_{i = 1}^{n} \max_{\vec\theta_i}\tilde l^{(i)}(\vec y_i, \vec\omega, \vec\theta_i). \]

Very fast pointwise evaluation but many more parameters!

From 10,000 to millions of variational parameters with register data sets.

Gradient descent or limited-memory BFGS are too slow.

We have developed special methods.

Comparison with Machine Learning

In machine learning, the goal is often the performance of the left-hand side of

\[ E_{q_{\hat{\vec\theta}_i}}(f_{\hat{\vec\omega}}(\vec U)) \approx E_{p_{\hat{\vec\omega}}}(f_{\hat{\vec\omega}}(\vec U)) \]

where \(q\) is a distribution chosen for the variational approximation and \(p\) is the true conditional distribution of the random effects. Other quantities based on \(q_{\hat{\vec\theta}_i}\) are also used. The precision of the approximation itself is not the focus.

In statistics, the interest is usually on a precise approximation of

\[ \text{arg max}_{\vec\omega} \sum_{i = 1}^{n}l^{(i)}(\vec y_i, \vec\omega) \approx \text{arg max}_{\vec\omega} \sum_{i = 1}^{n} \max_{\vec\theta_i}\tilde l^{(i)}(\vec y_i, \vec\omega, \vec\theta_i). \]

Often requires a tight lower bound! \(\vec\omega\) is also of much lower dimension.

Concrete Work

Developed a variational approximation for a class of mixed survival models.

A draft has been submitted.

Working on joint survival and marker models.

Seen up to two orders of magnitude faster estimation and only small bias in preliminary studies. Plan collaboration with Juan Jesus Carrero (KI) and Maya Alsheh Ali (KI).

Data sets include:

  • Electronic health records.
  • Images.

Concrete Work (cont.)

Implemented a library to optimize partially separable functions.

A header-only C++ library with an R interface and 5142 downloads.

Optimizing Variational Approximations

Example of Computation Times

Toy example: a mixed logistic regression with six random effects for each of \(n = 1000\) observations.

E.g. an observation is a student taking multiple tests, or a doctor seeing multiple patients. Random effects: student- or doctor-specific effects.

Sampled five data sets.
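
A minimal sketch of how one such data set could be simulated (the cluster size, fixed effects, and random effect covariance are assumptions for illustration, not the values used in the example):

```r
set.seed(1)
n_clusters <- 1000L; n_obs <- 10L; n_rng <- 6L
beta <- rnorm(n_rng)       # fixed effects (illustrative values)
Sigma <- diag(.5, n_rng)   # random effect covariance (illustrative)
sim_cluster <- function(){
  X <- cbind(1, matrix(rnorm(n_obs * (n_rng - 1L)), n_obs)) # design matrix
  u <- drop(crossprod(chol(Sigma), rnorm(n_rng)))           # U ~ N(0, Sigma)
  y <- rbinom(n_obs, 1L, plogis(drop(X %*% (beta + u))))    # binary outcomes
  list(y = y, X = X)
}
dat <- replicate(n_clusters, sim_cluster(), simplify = FALSE)
```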

Example of Computation Times

Method                   Mean      Median
lme4                     105.320   73.671
psqn (our)               2.562     2.423
psqn (our; 4 threads)    0.804     0.768
LBFGS                    24.636    23.298

Running times are in seconds. lme4: a Laplace approximation from the lme4 package, psqn: uses our psqn package, and LBFGS: uses limited-memory BFGS. From https://github.com/boennecd/psqn-va-ex.

The Laplace approximation even has higher bias in this example!

Methods

  1. Recursive method.
  2. Newton’s method.
  3. Quasi-Newton method exploiting the partial separability.

Recursive Method

Solve \(\widehat{\vec\theta}_i(\vec\omega) = \text{arg max}_{\vec\theta_i}\tilde l^{(i)}(\vec y_i, \vec\omega, \vec\theta_i)\) for \(i = 1,\dots,n\).

Take one step of Newton’s method to update \(\vec\omega\) to find the maximum of

\[ \text{arg max}_{\vec\omega} \sum_{i = 1}^n\tilde l^{(i)}\left(\vec y_i, \vec\omega, \widehat{\vec\theta}_i(\vec\omega) \right) \]

Repeat if not converged.

The Hessian in the second step can be computed efficiently despite the implicit dependence of \(\widehat{\vec\theta}_i(\vec\omega)\) on \(\vec\omega\). This is similar to the method suggested by Ormerod and Wand (2012).
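
A schematic R sketch of one iteration, where lb_i, lb_i_grad, and lb_i_hess are hypothetical placeholders for the i'th lower-bound term and its derivatives (they are not functions from any package):

```r
# One outer iteration of the recursive method. lb_i_hess stands in for the
# Hessian of the i'th term in omega at the profiled theta_i, which must
# account for the implicit dependence of theta_i on omega.
recursive_step <- function(omega, thetas, dat){
  # step 1: theta_i <- arg max of the i'th lower-bound term given omega
  thetas <- lapply(seq_along(dat), function(i)
    optim(thetas[[i]], lb_i, dat = dat[[i]], omega = omega,
          control = list(fnscale = -1))$par)
  # step 2: one Newton step in omega at the profiled thetas
  grad <- Reduce(`+`, lapply(seq_along(dat), function(i)
    lb_i_grad(dat[[i]], thetas[[i]], omega)))
  hess <- Reduce(`+`, lapply(seq_along(dat), function(i)
    lb_i_hess(dat[[i]], thetas[[i]], omega)))
  list(omega = omega - drop(solve(hess, grad)), thetas = thetas)
}
# iterate recursive_step until the change in omega is negligible
```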

Newton’s Method

Use Newton’s method to solve

\[ \text{arg max}_{\vec\omega,\vec\theta_1,\dots,\vec\theta_n} \sum_{i = 1}^{n} \tilde l^{(i)}(\vec y_i, \vec\omega, \vec\theta_i). \]

Solve the Newton system using e.g. the conjugate gradient method.

The Hessian is very sparse and has an arrowhead-like structure.
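
A small sketch of the sparsity pattern with illustrative dimensions (three global parameters and two variational parameters for each of four clusters):

```r
library(Matrix)
n_omega <- 3L; n_theta <- 2L; n_clusters <- 4L
# block-diagonal part: one small block per cluster's theta_i
blocks <- c(list(matrix(1, n_omega, n_omega)),
            replicate(n_clusters, matrix(1, n_theta, n_theta),
                      simplify = FALSE))
H <- as(bdiag(blocks), "CsparseMatrix")
H[seq_len(n_omega), ] <- 1 # omega interacts with every theta_i ...
H[, seq_len(n_omega)] <- 1 # ... and vice versa
print(H != 0)              # shows the arrowhead-like pattern
```

The Newton system \(\mat H\vec d = -\vec g\) is then solved iteratively, which only requires matrix-vector products with the sparse \(\mat H\).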

Using the Partial Separability

Approximate the Hessian of each \(\tilde l^{(i)}\) using BFGS.

Do not have to compute or implement the Hessian.

Use a quasi-Newton method to solve

\[ \text{arg max}_{\vec\omega,\vec\theta_1,\dots,\vec\theta_n} \sum_{i = 1}^{n} \tilde l^{(i)}(\vec y_i, \vec\omega, \vec\theta_i) \]

with the Hessian approximation. Solve with conjugate gradient.

Implemented in the psqn package. Used in the example.
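
A minimal sketch of the per-element update using the standard BFGS formula (an illustration of the idea, not the internals of the psqn package):

```r
# BFGS update of one element function's Hessian approximation B from the
# step s and gradient difference y_vec in its own (omega, theta_i)
# parameters; for maximization, apply it to the negated objective
bfgs_update <- function(B, s, y_vec){
  Bs <- drop(B %*% s)
  B - tcrossprod(Bs) / sum(s * Bs) + tcrossprod(y_vec) / sum(y_vec * s)
}
# the full Hessian approximation is never formed: conjugate gradient only
# needs products x -> sum_i t(P_i) %*% B_i %*% (P_i %*% x), where P_i
# selects (omega, theta_i) from the full parameter vector
```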

Remarks

The problems are commonly

  • embarrassingly parallel.
  • easy to make cache-friendly implementations for.

Automatic differentiation can be used for the gradient or the Hessian.

A fast implementation like Hogan (2014) adds only a moderate overhead. We use the very similar implementation from Savine (2018).

Summary

Variational approximations can be used to enable researchers to fit models that were not feasible before…

… but this requires special methods for optimizing the lower bound.

Thank You!

The presentation is at rpubs.com/boennecd/SeRC-meeting-2021.

The markdown is at github.com/boennecd/Talks.

The psqn package is on CRAN and at github.com/boennecd/psqn.

The imputation package, mdgc, is on CRAN and at github.com/boennecd/mdgc.

References are on the next slide.

References

Christoffersen, Benjamin, Mark Clements, Keith Humphreys, and Hedvig Kjellström. 2021. “Asymptotically Exact and Fast Gaussian Copula Models for Imputation of Mixed Data Types.” http://arxiv.org/abs/2102.02642.
Hogan, Robin J. 2014. “Fast Reverse-Mode Automatic Differentiation Using Expression Templates in C++.” ACM Trans. Math. Softw. 40 (4). https://doi.org/10.1145/2560359.
Ormerod, J. T., and M. P. Wand. 2012. “Gaussian Variational Approximate Inference for Generalized Linear Mixed Models.” Journal of Computational and Graphical Statistics 21 (1): 2–17. http://www.jstor.org/stable/23248820.
Savine, Antoine. 2018. Modern Computational Finance: AAD and Parallel Simulations. John Wiley & Sons.