Variational Approximations


Introduction

\[ \definecolor{gray}{RGB}{192,192,192} \renewcommand\vec{\boldsymbol} \def\bigO#1{\mathcal{O}(#1)} \def\Cond#1#2{\left(#1\,\middle|\, #2\right)} \def\mat#1{\boldsymbol{#1}} \def\der{{\mathop{}\!\mathrm{d}}} \def\argmax{\text{arg}\,\text{max}} \def\expec{\text{E}} \def\prob{\text{P}} \def\trace{\text{tr}} \]

Presentation Outline

Survival analysis setup.

Computational issues in survival analysis.

Applications of variational approximations (VAs).

Future work.

Survival Analysis

Notation

Let \(T_{ki}^*\) be the event time of individual \(i\) in the \(k\)th cluster.

Observe \(T_{ki} = \min (T_{ki}^*, C_{ki})\) where \(C_{ki}\) is the assumed independent censoring time.

Let \(D_{ki} = 1_{\{T_{ki}^* < C_{ki}\}}\) be the event indicator.

\(p\) will denote a (conditional) density function

whose specification is implicit from the context.

Proportional Hazards (PH) Models

A typical choice is

\[ \lambda\Cond{t}{\vec x} = \lambda_0(t)\exp\left(\vec\beta^\top\vec x\right) \]

Generalizations may be intractable because the survival function

\[S\Cond{t}{\vec x} = \exp\left(-\int_0^t\lambda\Cond{s}{\vec x}\der s\right)\]

generally requires one-dimensional numerical integration.
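
For instance, a minimal sketch of the one-dimensional integration, assuming a hypothetical, vectorized hazard function hazard(s, x):

# S(t | x) = exp(-int_0^t lambda(s | x) ds) by numerical integration
surv_func <- function(t, x, hazard)
  exp(-integrate(hazard, lower = 0, upper = t, x = x)$value)

# example with a Weibull-type baseline and a log-linear covariate effect
haz <- function(s, x) 1.5 * sqrt(s) * exp(sum(.5 * x))
surv_func(2, x = c(1, -1), hazard = haz)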

Generalized Survival Models (GSMs)

Let

\[\begin{align*} g\left(S\Cond{t}{\vec x}\right) = g\left(S_0(t)\right) + \vec\beta^\top\vec x \end{align*}\]

where \(g\) is a link function and \(S_0\) is a baseline survival function

e.g., see Royston and Parmar (2002).

Avoids the integration.
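
As a sketch of why no integration is needed, the survival function can be evaluated directly through the link. The names below are hypothetical and \(g(s) = \log(-\log(s))\), which recovers a PH model, is only one choice:

# S(t | x) = g^{-1}(g(S_0(t)) + beta^T x) with g(s) = log(-log(s))
g     <- function(s) log(-log(s))
g_inv <- function(e) exp(-exp(e))

surv_gsm <- function(t, x, beta, S0)
  g_inv(g(S0(t)) + sum(beta * x))

# example with a Weibull baseline survival function
surv_gsm(2, x = c(1, -1), beta = c(.5, .5), S0 = function(t) exp(-t^1.5))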

Computational Issues

Random Effects

We may have unobserved factors that we want to account for or that are of direct interest

e.g., twins who share genetic background and environment.

GSMs with Random Effects

\[ \begin{align*} g\left(S\Cond{t_{ki}}{\vec x_{ki}, \vec z_{ki}, \vec u_k}\right) &= g\left(S_0(t_{ki})\right) + \vec\beta^\top\vec x_{ki} + \vec z^\top_{ki}\vec u_k \\ &= g\left(S_0(t_{ki}; \vec x_{ki})\right) + \vec z^\top_{ki}\vec u_k \\ \vec U_k &\sim h(\vec \theta) \end{align*} \]
\(\vec z_{ki}\) is a vector of known covariates and \(\vec u_k\in\mathbb{R}^K\) is group \(k\)’s random effect.

\(k=1,\dots,m\) indicates the group and group \(k\) has \(i=1,\dots,n_k\) members.

In particular, we consider

\[\vec U_k \sim N(\vec 0, \mat \Sigma)\]

Notation

Let functions be applied elementwise

\[ \begin{align*} \vec x &= (x_1, x_2)^\top \\ f(\vec x) &= (f(x_1), f(x_2))^\top \end{align*} \]

Further, let

\[ \begin{align*} \vec t_k &= (t_{k1}, \dots, t_{kn_k})^\top, & \vec d_k &= (d_{k1}, \dots, d_{kn_k})^\top \\ \mat X_k &= (\vec x_{k1}, \dots, \vec x_{kn_k})^\top, & \mat Z_k &= (\vec z_{k1}, \dots, \vec z_{kn_k})^\top \end{align*} \]

Marginal Log-likelihood

The marginal log-likelihood (or model evidence) term for each group \(k\) is

\[ \begin{align*} l_k(\vec\beta, \mat\Sigma) &= \log \int \exp\left(h_k(\vec\beta, \mat\Sigma, \vec u) + \log \phi(\vec u;\mat \Sigma)\right) \der \vec u \\ h_k(\vec\beta, \mat\Sigma, \vec u) &= \vec d^\top_k \log \lambda\Cond{\vec t_k}{\mat X_k, \mat Z_k, \vec u} \\ &\hspace{20pt} + \vec 1^\top\log S\Cond{\vec t_k}{\mat X_k, \mat Z_k, \vec u} \end{align*} \]

where \(\phi(\cdot ;\mat \Sigma)\) is the density function of the multivariate normal distribution with a zero mean vector and covariance matrix \(\mat \Sigma\). \(\vec 1\) is a vector of ones.

\(l_k(\vec\beta, \mat\Sigma)\) is intractable in general.

Common Approximations

Laplace Approximation.

Fast, scales well, but may perform poorly, e.g., when groups are small (\(n_k\) small).

Adaptive Gaussian quadrature (AGQ).

Fast in low dimensions, scales poorly, and performs well. See, e.g., Q. Liu and Pierce (1994) and Pinheiro and Bates (1995).

Monte Carlo methods.

Slow, scales poorly, and performs well.
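
As an illustration of the latter, a naive Monte Carlo estimator of the marginal log-likelihood \(l_k\) can be sketched as below, assuming a user-supplied function h_k which evaluates \(h_k(\vec\beta, \mat\Sigma, \vec u)\) at a given \(\vec u\) (hypothetical names):

# naive Monte Carlo estimate of l_k = log E(exp(h_k(U))) with U ~ N(0, Sigma)
mc_marginal_ll <- function(h_k, Sigma, n_sim = 10000L){
  K <- NCOL(Sigma)
  # draw n_sim random effects from N(0, Sigma)
  us <- matrix(rnorm(n_sim * K), n_sim) %*% chol(Sigma)
  log(mean(exp(apply(us, 1, h_k))))
}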

Adaptive Gaussian Quadrature

\[\begin{align*} l &= \int c(\vec u)\der\vec u \qquad\qquad\qquad\qquad\qquad \text{(intractable)} \\ &= \int \frac{c(\vec u)}{\phi(\vec u; \vec\mu, \mat\Lambda)} \phi(\vec u; \vec\mu, \mat\Lambda)\der\vec u \\ &= \int\underbrace{(2\pi)^{K/2} \lvert\mat\Lambda\rvert^{1/2} c(\vec\mu + \mat\Lambda^{1/2}\vec s) \exp\left(\vec s^\top\vec s / 2\right)}_{ \tilde c(\vec s; \vec\mu, \mat\Lambda)} \phi(\vec s; \vec 0, \mat I)\der\vec s \end{align*}\]

where \(\vec u\in \mathbb{R}^K\), \(\phi(\cdot; \vec\mu, \mat\Lambda)\) is the density of a multivariate normal distribution with mean \(\vec\mu\) and covariance matrix \(\mat\Lambda\), and \(\mat I\) is the identity matrix.

Adaptive Gaussian Quadrature

Apply Gauss-Hermite quadrature to each coordinate of \(\vec s\).

Each coordinate, \(s_l\), is evaluated at \(b\) quadrature points.

Let \((g_i, w_i)\), \(i = 1,\dots,b\), be the node values and corresponding weights for each coordinate, \(s_l\), and

\[\vec s_{j_1,\dots,j_K} = (g_{j_1},\dots,g_{j_K})^\top\]

Adaptive Gaussian Quadrature

Repeatedly applying Gauss-Hermite quadrature yields

\[\begin{align*} l &= \int \tilde c(\vec s; \vec\mu, \mat\Lambda) \phi(\vec s; \vec 0, \mat I)\der\vec s \\ &\approx \sum_{j_1=1}^b\cdots\sum_{j_K=1}^b \tilde c(\vec s_{j_1,\dots,j_K}; \vec\mu, \mat\Lambda) \prod_{q = 1}^K w_{j_q} \end{align*}\]
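
A minimal sketch of this product rule, assuming the transformed integrand \(\tilde c\) is supplied as an R function (the names below are hypothetical):

library(statmod)

# product-rule Gauss-Hermite quadrature of int tilde_c(s) phi(s; 0, I) ds.
# gauss.quad uses the weight function exp(-x^2), so the substitution
# s = sqrt(2) x turns the weight into the standard normal density
aprx_gh <- function(tilde_c, K, b){
  rule <- gauss.quad(b, kind = "hermite")
  # all b^K combinations of node indices
  idx <- do.call(expand.grid, replicate(K, seq_len(b), simplify = FALSE))
  vals <- apply(idx, 1, function(j)
    prod(rule$weights[j]) * tilde_c(sqrt(2) * rule$nodes[j]))
  sum(vals) / pi^(K / 2)
}

aprx_gh(function(s) 1, K = 3, b = 10) # equals one for a constant integrand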

Scales poorly in the dimension of the random effect, \(K\).

Fast alternatives are attractive.

Variational Approximations

Lower Bound

A lower bound of the marginal log-likelihood is

\[ \begin{align*} \log p\left(\vec t_k, \vec d_k\right) &= \int q(\vec u;\vec\theta_k) \log \left(\frac {p\left(\vec t_k, \vec d_k, \vec u\right) / q(\vec u;\vec\theta_k)} {p\Cond{\vec u}{\vec t_k, \vec d_k} / q(\vec u;\vec\theta_k)} \right)\der\vec u \\ &\geq \int q(\vec u;\vec\theta_k) \log \left(\frac {p\left(\vec t_k, \vec d_k, \vec u\right)} {q(\vec u;\vec\theta_k)} \right)\der\vec u \\ &= \log \tilde p \left(\vec t_k, \vec d_k;\vec\theta_k\right) \end{align*} \]

for some density function \(q\). Equality holds if and only if

\[ q(\vec u_k;\vec\theta_k) = p\Cond{\vec u_k}{\vec t_k, \vec d_k} \]
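
The gap in the inequality is the Kullback-Leibler divergence between \(q\) and the conditional density of the random effects

\[ \log p\left(\vec t_k, \vec d_k\right) - \log \tilde p \left(\vec t_k, \vec d_k;\vec\theta_k\right) = \int q(\vec u;\vec\theta_k) \log\frac{q(\vec u;\vec\theta_k)} {p\Cond{\vec u}{\vec t_k, \vec d_k}}\der\vec u \geq 0 \]

which is zero exactly when the two densities coincide.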

Lower Bound

The approximate maximum likelihood estimator is

\[\argmax_{\vec\beta, \mat \Sigma, \vec\theta_1, \cdots, \vec\theta_m} \sum_{k=1}^m\log \tilde p \left(\vec t_k, \vec d_k;\vec\theta_k\right)\]

Marginal Log-likelihood

Recall the marginal log-likelihood

\[ \begin{align*} l_k(\vec\beta, \mat\Sigma) &= \log \int \exp\underbrace{\left(h_k(\vec\beta, \mat\Sigma, \vec u) + \log \phi(\vec u;\mat \Sigma)\right)}_{ \log p(\vec t_k, \vec d_k, \vec u)} \der \vec u \\ h_k(\vec\beta, \mat\Sigma, \vec u) &= \vec d^\top_k \log \lambda\Cond{\vec t_k}{\mat X_k, \mat Z_k, \vec u} \\ &\hspace{20pt} + \vec 1^\top\log S\Cond{\vec t_k}{\mat X_k, \mat Z_k, \vec u} \end{align*} \]

Applying the Lower Bound

The \(\log \phi(\vec u;\mat \Sigma)\) term requires \(\expec(\vec U_k\vec U_k^\top)\) under the variational distribution

or, equivalently, its mean and covariance.

We need to compute the entropy term from \(- \log q(\vec u; \vec\theta_k)\).

The remaining two types of terms are

\[\vec d^\top_k \log \lambda\Cond{\vec t_k}{\mat X_k, \mat Z_k, \vec U_k} + \vec 1^\top\log S\Cond{\vec t_k}{\mat X_k, \mat Z_k, \vec U_k}\]

Applying the Lower Bound

\[ \lambda\Cond{\vec t_k}{\mat X_k, \mat Z_k, \vec u_k} = s\left(\vec t_k; \eta_1(\mat X_k) + \mat Z_k\vec u_k\right) \]

for some function \(s\). Then, if we know the distribution of

\[ U_k^{\vec z_{ki}} = \vec z_{ki}^\top\vec U_k \]

we can compute

\[ \begin{align*} \expec\left(\log\lambda\Cond{t_{ki}}{\vec x_{ki}, \vec z_{ki}, \vec U_k}\right) &= \expec\left( \log s \left(t_{ki};\eta_1(\vec x_{ki}) + U_k^{\vec z_{ki}}\right) \right) \end{align*} \]

That is, \(n_k\) one-dimensional integrals

instead of one \(K\)-dimensional integral, as emphasized by Ormerod and Wand (2012) for generalized linear mixed models.
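
As an illustration, a sketch of one such one-dimensional integral with Gauss-Hermite quadrature when \(U_k^{\vec z_{ki}}\) is Gaussian, e.g., under the Gaussian variational approximation covered next (log_s is a hypothetical, vectorized user-supplied function):

library(statmod)

# E(log s(t; eta + U)) with U ~ N(z^T mu, z^T Lambda z) by one-dimensional
# Gauss-Hermite quadrature
expec_log_s <- function(log_s, t, eta, z, mu, Lambda, b = 30){
  rule <- gauss.quad(b, kind = "hermite")
  mea <- drop(z %*% mu)                  # mean of z^T U
  sd  <- sqrt(drop(z %*% Lambda %*% z))  # standard deviation of z^T U
  sum(rule$weights * log_s(t, eta + mea + sqrt(2) * sd * rule$nodes)) / sqrt(pi)
}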

Gaussian Variational Approximations

Gaussian Variational Approximation (GVA)

\[ q(\vec u;\vec\mu_k, \mat\Lambda_k) = \phi\left(\vec u; \vec\mu_k, \mat\Lambda_k\right) \]

Closed-form entropy and expected outer product.

\[ \begin{align*} \hat l_k(\vec \beta, \mat\Sigma;\vec\mu_k,\mat\Lambda_k) &= \int h_k(\vec\beta, \mat\Sigma, \vec u) \phi\left(\vec u; \vec\mu_k, \mat\Lambda_k\right) \der\vec u \\ &\hspace{20pt} + \frac 12\bigg( \log\lvert\mat\Sigma^{-1}\mat\Lambda_k\rvert - \trace \mat\Sigma^{-1}\mat\Lambda_k \\ &\hspace{85pt}- \vec\mu_k^\top\mat\Sigma^{-1}\vec\mu_k + K \bigg) \end{align*} \]
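
A minimal sketch of the closed-form terms on the last two lines (the function name is hypothetical):

# closed-form part of the GVA lower bound stemming from the log prior
# density and the entropy
gva_gauss_terms <- function(mu, Lambda, Sigma){
  K <- length(mu)
  Sigma_inv <- solve(Sigma)
  (c(determinant(Sigma_inv %*% Lambda)$modulus) - # log|Sigma^-1 Lambda|
     sum(diag(Sigma_inv %*% Lambda)) -            # tr(Sigma^-1 Lambda)
     drop(mu %*% Sigma_inv %*% mu) + K) / 2
}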

Generalized PH Model

\[ \lambda\Cond{s}{\vec x_{ki}, \vec z_{ki}, \vec u_k} = \lambda_0(s ;\vec x_{ki}) \exp\left(\vec f(s, \vec z_{ki})^\top\vec u_k\right) \]

where \(\vec f:\, [0,\infty)\times \mathbb{R}^v\rightarrow\mathbb{R}^K\) is not applied elementwise. Then

\[\begin{align*} \log S\Cond{t_{ki}}{\vec x_{ki}, \vec z_{ki}, \vec u_k} &= -\int_0^{t_{ki}} \lambda_0(s ;\vec x_{ki}) \exp\left(\vec f(s, \vec z_{ki})^\top\vec u_k\right) \der s \\ \log\lambda\Cond{t_{ki}}{\vec x_{ki}, \vec z_{ki}, \vec u_k} &= \log \lambda_0(t_{ki};\vec x_{ki}) + \vec f(t_{ki}, \vec z_{ki})^\top\vec u_k \end{align*}\]

Generalized PH Model

\[ \begin{align*} \int \log S\Cond{t_{ki}}{\vec x_{ki}, \vec z_{ki}, \vec u} \phi(\vec u; \vec\mu_k, \mat\Lambda_k)\der\vec u \hspace{-60pt}& \\ &= -\int_0^{t_{ki}} \lambda_0(s ;\vec x_{ki}) \left(\int \exp\left(\vec f(s, \vec z_{ki})^\top\vec u\right) \phi(\vec u; \vec\mu_k, \mat\Lambda_k) \der\vec u\right)\der s \end{align*} \]

assuming we can change the order of integration. This is similar to a result by Yue and Kontar (2019).
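
The inner integral is the moment generating function of the \(N(\vec\mu_k, \mat\Lambda_k)\) distribution evaluated at \(\vec f(s, \vec z_{ki})\) and is therefore available in closed form

\[ \int \exp\left(\vec f(s, \vec z_{ki})^\top\vec u\right) \phi(\vec u; \vec\mu_k, \mat\Lambda_k)\der\vec u = \exp\left(\vec f(s, \vec z_{ki})^\top\vec\mu_k + \frac 12 \vec f(s, \vec z_{ki})^\top\mat\Lambda_k \vec f(s, \vec z_{ki})\right) \]

so only a one-dimensional integral over \(s\) remains.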

Example

\[ \begin{align*} g(S(t_{ki} \mid \vec x_{ki}, \vec z_{ki}, u_k)) &= g(S_0(t_{ki})) + \vec x_{ki}^\top\vec\beta + u_k \\ U_k &\sim N(0, \sigma^2) \\ S_0(t) &= c \exp(-\lambda_1 t^{\gamma_1}) + (1 - c) \exp(-\lambda_2 t^{\gamma_2}) \end{align*} \]

\[ \begin{align*} \vec\beta &= (0, 0.5, 0.5)^\top \\ \vec x_{ki} &= (1, b_{ki}, s_{ki})^\top \\ b_{ki} &\sim \text{Bern}(0.5) \\ s_{ki} &\sim N(0, 1) \\ (c, \lambda_1, \gamma_1, \lambda_2, \gamma_2, \log\sigma) &= (0.5, 0.25, 1.5, 0.25, 0.5, 1.4) \\ (m, n_k) &= (200, 2) \end{align*} \]

Baseline

[Figure omitted]

Comparison

Adaptive Gaussian quadrature with 30 nodes

using rstpm2::stpm2 (Clements and Liu 2019).

Laplace approximation

using TMB (Thygesen et al. 2017).

Own GVA implementation

using automatic differentiation with the CppAD library (Bell 2019).

Six degrees of freedom for the baseline, modeled with a natural cubic spline.

One Example

set.seed(17288048)
dat <- get_sim_data(link = "PH") # simulate one data set

# fit the model with each method and record the computation time
rbind(
  AGQ = system.time(sfit <- stpm2_wrap (dat                 )),
  Lap = system.time(lfit <- laplace_app(dat, do_hess = FALSE)), 
  GVA = system.time(gfit <- GV_app     (dat, do_hess = FALSE)))[, 1:3]
##     user.self sys.self elapsed
## AGQ     0.899    0.005   0.904
## Lap     0.708    0.016   0.724
## GVA     0.177    0.000   0.178

One Example

# compare the estimated parameters
rbind(AGQ = sfit$coef, Laplace = lfit$coef, GVA = gfit$coef)
##         (Intercept)     x treatment nsx(log(y), df = 5)1
## AGQ           -10.6 0.579     0.699                 7.69
## Laplace       -10.3 0.550     0.667                 7.57
## GVA           -10.3 0.548     0.663                 7.56
##         nsx(log(y), df = 5)2 nsx(log(y), df = 5)3
## AGQ                     8.64                 7.97
## Laplace                 8.48                 7.76
## GVA                     8.46                 7.75
##         nsx(log(y), df = 5)4 nsx(log(y), df = 5)5 logtheta
## AGQ                     14.8                 7.74    0.656
## Laplace                 14.6                 7.49    0.563
## GVA                     14.5                 7.48    0.556

Simulation Study (Bias)

[Figures omitted]

Simulation Study (MLE Difference)

[Figures omitted]

Lower Bound

[Figures omitted]

Estimated Variational Distribution

[Figure omitted]

Skew-normal Variational Approximations

Skew-normal Distribution

Properties

Closed under linear transformations.

Closed-form moment generating function.

Convenient for the log-log link and the generalized PH model.

Properties

Entropy term is

\[ \begin{align*} -\int \log \left(2\phi(\vec u;\vec\mu_k, \mat\Lambda_k) \Phi\left(\vec\rho_k^\top(\vec u - \vec\mu_k)\right)\right) q(\vec u; \vec\mu_k, \mat\Lambda_k, \vec\rho_k)\der\vec u \hspace{-60pt}& \\ &= \frac 12\log\lvert\mat\Lambda_k\rvert + \frac K2 \log 2\pi + \frac K2 - \log 2 - \psi(\vec\rho_k^\top\mat\Lambda_k\vec\rho_k) \\ \psi(\sigma^2) &= 2\int \phi(z;\sigma^2)\Phi(z)\log\Phi(z)\der z \end{align*} \]

This is shown by Ormerod (2011), who also coins the term SNVA.
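
\(\psi\) is easy to evaluate numerically. A minimal sketch:

# psi(sigma^2) = 2 int phi(z; sigma^2) Phi(z) log Phi(z) dz by one-dimensional
# numerical integration
psi <- function(sigma_sq)
  2 * integrate(
    function(z)
      dnorm(z, sd = sqrt(sigma_sq)) * pnorm(z) * pnorm(z, log.p = TRUE),
    -Inf, Inf)$value

psi(1.5) # example evaluation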

Estimated Variational Distribution

[Figure omitted]

Simulation Study (Bias)

[Figures omitted]

Simulation Study (MLE Difference)

[Figures omitted]

Lower Bound

[Figures omitted]

Computation Time

Times in seconds:

            log-log   -logit   probit
AGQ           0.941    1.061    3.251
Laplace       0.539    0.549    1.111
GVA           0.094    0.162    0.282
SNVA          1.963    2.502    2.923
SNVA (CP)     0.408    0.689    0.915

Future work

Multivariate Random Effects

Compare the \(\bigO{mb^K}\) complexity of AGQ with the estimation of \(\bigO{mK^2}\) variational parameters

if there are no restrictions on the \(\vec\mu_k\)s, \(\mat\Lambda_k\)s, and \(\vec\rho_k\)s.

The Laplace approximation requires estimation of an \(mK\)-dimensional mode

and \(m\) inversions of \(K\times K\) matrices.

Other variational distributions can be considered

such as the multivariate (skew-)\(t\) distribution, the skew-Laplace distribution, the four-parameter skewed distributions shown in Arnold et al. (2002), and other generalizations.

Joint Models

Observe a biomarker \(g(t)\) at times \(o_{k1}, \dots, o_{kl_k}\). Assume

\[ \begin{align*} g(o_{kj};\vec u_k) &= \bar g(o_{kj};\vec u_k) + \epsilon_{kj} & \epsilon_{kj} &\sim N(0, \zeta^2) \\ \bar g(s;\vec u_k) &= \vec d(s)^\top\vec\gamma + \vec r(s)^\top\vec u_k & \vec U_k &\sim N(\vec 0, \mat\Sigma) \end{align*} \]

and a generalized PH survival sub-model

\[ \lambda\Cond{s}{\vec x_{ki}, \vec z_{ki}, \vec u_k} = \lambda_0(s ;\vec x_{ki}) \exp\left(\alpha \bar g(s;\vec u_k)\right) \]

Summary

Variational approximations are a broad class of approximations.

The Gaussian variational approximation seems to perform similarly to a Laplace approximation.

The skew-normal variational approximation seems fast and precise.

Thank You!

The presentation is at rpubs.com/boennecd/MEB-Thursday-19.

The markdown and code are at github.com/boennecd/Talks.

References are on the next slide.

References

Arnold, Barry C., Robert J. Beaver, A. Azzalini, N. Balakrishnan, A. Bhaumik, D. K. Dey, C. M. Cuadras, and J. M. Sarabia. 2002. “Skewed Multivariate Models Related to Hidden Truncation and/or Selective Reporting.” Test 11 (1): 7–54. doi:10.1007/BF02595728.

Bell, B. 2019. CppAD: A Package for C++ Algorithmic Differentiation. http://www.coin-or.org/CppAD.

Clements, Mark, and Xing-Rong Liu. 2019. Rstpm2: Smooth Survival Models, Including Generalized Survival Models. https://CRAN.R-project.org/package=rstpm2.

Liu, Qing, and Donald A. Pierce. 1994. “A Note on Gauss-Hermite Quadrature.” Biometrika 81 (3). [Oxford University Press, Biometrika Trust]: 624–29. http://www.jstor.org/stable/2337136.

Ormerod, J. T. 2011. “Skew-Normal Variational Approximations for Bayesian Inference.” Unpublished Article.

Ormerod, J. T., and M. P. Wand. 2012. “Gaussian Variational Approximate Inference for Generalized Linear Mixed Models.” Journal of Computational and Graphical Statistics 21 (1). Taylor & Francis: 2–17. doi:10.1198/jcgs.2011.09118.

Pinheiro, José C., and Douglas M. Bates. 1995. “Approximations to the Log-Likelihood Function in the Nonlinear Mixed-Effects Model.” Journal of Computational and Graphical Statistics 4 (1). American Statistical Association, Taylor & Francis, Ltd., Institute of Mathematical Statistics, Interface Foundation of America: 12–35. http://www.jstor.org/stable/1390625.

Royston, Patrick, and Mahesh K. B. Parmar. 2002. “Flexible Parametric Proportional-Hazards and Proportional-Odds Models for Censored Survival Data, with Application to Prognostic Modelling and Estimation of Treatment Effects.” Statistics in Medicine 21 (15): 2175–97. doi:10.1002/sim.1203.

Thygesen, Uffe H., Christoffer M. Albertsen, Casper W. Berg, Kasper Kristensen, and Anders Nielsen. 2017. “Validation of Ecological State Space Models Using the Laplace Approximation.” Environmental and Ecological Statistics, 1–23. doi:10.1007/s10651-017-0372-4.

Yue, Xubo, and Raed Kontar. 2019. “Variational Inference of Joint Models Using Multivariate Gaussian Convolution Processes.”