A matrix-variate Gaussian is a distribution over matrices, defined by a mean matrix together with a row covariance and a column covariance. Supposedly it arises in various situations in multivariate analysis that I don't fully understand, but what I do understand is that it crops up in settings like multiple correlated sensors observed over time. It is also (I think) an elegant way to think about things like the dual view of probabilistic factor analysis (Lawrence, 2015).
Here is the matrix-normal log-likelihood:
\DeclareMathOperator{\Tr}{Tr} \DeclareMathOperator{\vecop}{vec} \newcommand{\trp}{{^\top}} % transpose \newcommand{\inv}{^{-1}} \newcommand{\mb}{\mathbf{b}} \newcommand{\M}{\mathbf{M}} \newcommand{\C}{\mathbf{C}} \newcommand{\L}{\mathbf{L}} \newcommand{\R}{\mathbf{R}} \newcommand{\U}{\mathbf{U}} \newcommand{\V}{\mathbf{V}} \newcommand{\X}{\mathbf{X}} \newcommand{\Z}{\mathbf{Z}}
\X \sim \mathcal{MN}_{m \times n}(\M,\U,\V)\\ \log p(\X\mid \M,\U, \V) = -\tfrac{1}{2}\left( mn \log 2\pi + n \log|\U| + m \log|\V| + \Tr\left[\V\inv(\X-\M)\trp\U\inv(\X-\M)\right] \right)
The most straightforward implementation takes a Cholesky factor of each covariance, then exploits it to get the log-determinant in linear time (from the log of the Cholesky diagonal) and the product of the inverse with X via a triangular solve in quadratic time. I think this is the standard way to implement the MVN as well (at least, that's what I read in the TensorFlow MVN code).
Here is an R implementation:
dmatnorm <- function(X, M, U, V){
  # check dimensions
  stopifnot(all(dim(X) == dim(M)))
  stopifnot(all(nrow(X) == nrow(U), nrow(X) == ncol(U)))
  stopifnot(all(ncol(X) == nrow(V), ncol(X) == ncol(V)))
  # sizes and Cholesky factors
  rowsize <- nrow(X)
  colsize <- ncol(X)
  U_chol <- chol(U)
  V_chol <- chol(V)
  # triangular solves give U^-1 (X - M) and V^-1 (X - M)',
  # and the log-determinants come from the Cholesky diagonals
  Xresid <- X - M
  P <- backsolve(U_chol, backsolve(U_chol, Xresid, transpose=TRUE))
  Q <- backsolve(V_chol, backsolve(V_chol, t(Xresid), transpose=TRUE))
  logdet_row <- 2 * sum(log(diag(U_chol)))
  logdet_col <- 2 * sum(log(diag(V_chol)))
  # assemble the log-likelihood
  log2pi <- 1.8378770664093453
  denominator <- - rowsize * colsize * log2pi - colsize * logdet_row - rowsize * logdet_col
  numerator <- - sum(diag(Q %*% P))
  return(0.5 * (numerator + denominator))
}
We can sanity-check against dmvnorm by setting one of the covariances to the identity:
library(mvtnorm)
library(testthat)
m <- 3
n <- 5
# row cov
M <- matrix(rep(rnorm(m), n), m, n)
U <- rWishart(1, m+2, diag(m))[,,1]
V <- diag(n)
X <- matrix(rnorm(m*n), m, n)
expect_equal(sum(dmvnorm(t(X), M[,1], U, log=TRUE)), dmatnorm(X, M, U, V))
# column cov
M <- t(matrix(rep(rnorm(m), n), m, n))
U <- diag(n)
V <- rWishart(1, m+2, diag(m))[,,1]
X <- t(matrix(rnorm(m*n), m, n))
expect_equal(sum(dmvnorm(X, M[1,], V, log=TRUE)), dmatnorm(X, M, U, V))
The other thing that is nice to have for dealing with these distributions is that the standard reparameterization trick works here: if \Z is a matrix of standard normal variates of the appropriate size and \X = \M + \L\Z\R, then \X \sim \mathcal{MN}(\M, \L\L\trp, \R\trp\R). So with \L the lower Cholesky factor of \U and \R the upper Cholesky factor of \V, we get a draw from \mathcal{MN}(\M,\U,\V). We can use this to make a simple rmatnorm function:
rmatnorm <- function(M, U, V){
  # chol() returns upper-triangular factors, so the row factor goes in transposed
  U_chol <- chol(U)
  V_chol <- chol(V)
  rowsize <- nrow(U)
  colsize <- ncol(V)
  return(M + t(U_chol) %*% matrix(rnorm(rowsize * colsize), rowsize, colsize) %*% V_chol)
}
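A rough Monte Carlo sanity check of rmatnorm (just a sketch, with some throwaway matrices Msim, Usim, Vsim) uses the fact that for \X \sim \mathcal{MN}(\M,\U,\V) we have \mathrm{E}[(\X-\M)(\X-\M)\trp] = \Tr[\V]\,\U:
Msim <- matrix(0, 3, 4)
Usim <- rWishart(1, 5, diag(3))[,,1]
Vsim <- rWishart(1, 6, diag(4))[,,1]
draws <- replicate(2000, {D <- rmatnorm(Msim, Usim, Vsim) - Msim; D %*% t(D)})
apply(draws, c(1, 2), mean) / sum(diag(Vsim))  # should be close to Usim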
Next, we would like analytic gradients. Here is an implementation, with a test against finite differences. The gradient for the mean matches. The gradient for the covariances is very odd: finite differences return an upper-triangular matrix whose upper triangle matches the analytic result exactly, whose diagonal is off by a factor of 2, and whose lower triangle is zero. I cannot quite tell whether this is a quirk of asking grad() to deal with a function that expects some structure in its inputs (PSD covariance matrices) or something deeper. One plausible culprit is that R's chol() only reads the upper triangle of its argument, so perturbing a lower-triangle entry has no effect at all, while perturbing an upper off-diagonal entry effectively moves both symmetric positions at once.
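A tiny demonstration of that insensitivity, reusing the X, M, U, V from the check above:
U_perturbed <- U
U_perturbed[2, 1] <- U_perturbed[2, 1] + 0.1   # touch only the lower triangle
dmatnorm(X, M, U, V) == dmatnorm(X, M, U_perturbed, V)   # TRUE: chol() never saw the change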
library(plyr)
library(numDeriv)
matrix_trace <- function(X){
  return(sum(diag(X)))
}
mats_to_vec <- function(M, U, V){
  # flatten the three matrices (column-major) into one parameter vector
  return(c(as.vector(M), as.vector(U), as.vector(V)))
}
vec_to_mats <- function(params){
  # note: relies on the global m and n for the matrix sizes
  indices <- cumsum(c(1, m*n, m*m, n*n))
  from <- indices[1:(length(indices)-1)]
  to <- indices[-1] - 1
  split_params <- alply(rbind(from, to), 2, function(x) params[x[1]:x[2]])
  M <- matrix(split_params[[1]], m, n)
  U <- matrix(split_params[[2]], m, m)
  V <- matrix(split_params[[3]], n, n)
  return(list(M, U, V))
}
dmatnorm_flattened <- function(params, X_){
  # I wish I had python unpacking here
  l <- vec_to_mats(params)
  return(dmatnorm(X_, l[[1]], l[[2]], l[[3]]))
}
numeric_matnorm_grad <- function(X, M, U, V){
  params <- mats_to_vec(M, U, V)
  findiff_grad <- grad(func=dmatnorm_flattened, x=params, X_=X,
                       method.args=list(eps=1e-6, d=1e-6,
                                        zero.tol=sqrt(.Machine$double.eps/7e-7),
                                        r=6, v=2))
  return(vec_to_mats(findiff_grad))
}
analytic_matnorm_grad <- function(X, M, U, V){
  nrow_X <- nrow(U)
  ncol_X <- nrow(V)
  grad_U <- -ncol_X * solve(U) + solve(U) %*% (X-M) %*% solve(V) %*% t(X-M) %*% solve(U)
  grad_V <- -nrow_X * solve(V) + solve(V) %*% t(X-M) %*% solve(U) %*% (X-M) %*% solve(V)
  grad_M <- solve(U) %*% (X-M) %*% solve(V)
  return(list(grad_M, grad_U, grad_V))
}
m <- 3
n <- 5
M <- matrix(rnorm(m*n), m, n)
U <- rWishart(1, m+2, diag(m))[,,1]
V <- rWishart(1, n+2, diag(n))[,,1]
X <- matrix(rnorm(m*n), m, n)
an <- analytic_matnorm_grad(X,M,U,V)
findiff <- numeric_matnorm_grad(X,M,U,V)
print(an[[1]])
## [,1] [,2] [,3] [,4] [,5]
## [1,] -0.1814540 0.3466090 0.04246733 -0.3632650 -0.03153692
## [2,] 0.1119122 -0.1721125 -0.18823972 0.3516859 -0.09505373
## [3,] 0.2463708 -0.8468918 -0.17035955 0.7714730 0.04783549
print(findiff[[1]])
## [,1] [,2] [,3] [,4] [,5]
## [1,] -0.1814540 0.3466090 0.0424673 -0.3632649 -0.03153724
## [2,] 0.1119122 -0.1721126 -0.1882398 0.3516841 -0.09505402
## [3,] 0.2463707 -0.8468918 -0.1703596 0.7714728 0.04783553
print(an[[2]])
## [,1] [,2] [,3]
## [1,] -1.0736735 0.6377906 -0.1844745
## [2,] 0.6377906 -0.8956581 0.1445865
## [3,] -0.1844745 0.1445865 0.3612363
print(findiff[[2]])
## [,1] [,2] [,3]
## [1,] -0.5368368 0.6377906 -0.1844745
## [2,] 0.0000000 -0.4478290 0.1445863
## [3,] 0.0000000 0.0000000 0.1806182
print(an[[3]])
## [,1] [,2] [,3] [,4] [,5]
## [1,] -1.04075182 -0.1068757 -0.0792679 0.8766246 -0.04982503
## [2,] -0.10687572 0.7523713 0.1318751 -1.0309108 -0.14070128
## [3,] -0.07926790 0.1318751 -0.7484300 0.3342265 -0.40577267
## [4,] 0.87662460 -1.0309108 0.3342265 -0.3128125 0.49347717
## [5,] -0.04982503 -0.1407013 -0.4057727 0.4934772 -0.48172792
print(findiff[[3]])
## [,1] [,2] [,3] [,4] [,5]
## [1,] -0.5203759 -0.1068757 -0.07926793 0.8766246 -0.04982483
## [2,] 0.0000000 0.3761857 0.13187532 -1.0309108 -0.14070121
## [3,] 0.0000000 0.0000000 -0.37421502 0.3342264 -0.40577266
## [4,] 0.0000000 0.0000000 0.00000000 -0.1564063 0.49347717
## [5,] 0.0000000 0.0000000 0.00000000 0.0000000 -0.24086398
An alternative formulation of the matrix-variate normal is the following: if \X \sim \mathcal{MN}_{m \times n}(\M,\U,\V), then \vecop[\X] \sim \mathcal{N}(\vecop[\M],\V\otimes\U). From this definition and the MVN density one can derive the density above, using standard properties of the \vecop operator and the Kronecker product. The explicit matrix-normal density is nice because we never have to create or invert that Kronecker-factored covariance, which can be very large (and is the thing I have seen in models people have posted to the users list).
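We can check this identity directly (a quick sketch, reusing the X, M, U, V from the gradient test above):
expect_equal(dmatnorm(X, M, U, V),
             dmvnorm(as.vector(X), as.vector(M), kronecker(V, U), log=TRUE))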
Unfortunately, sometimes we want more than two factors: for example, data which is volumetric (3D), spatial and observed over time (3D), volumetric and observed over time (4D), or even volumetric, observed over time, and with subject-level grouping structure (5D).
In these higher-dimensional cases, we can play exactly the same trick (do a Cholesky once, then a solve and a determinant), but the solve needs to work without ever creating the whole matrix out of its Kronecker factors. This is doable: Narayanan Sundaram at Intel Labs has an implementation (in Python/TensorFlow) that we will release soon under an Apache 2.0 license.
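The determinant piece, for instance, never needs the full matrix, since the log-determinant of a Kronecker product factorizes. A small sketch with three arbitrary factors of sizes 2, 3, and 4:
A <- rWishart(1, 4, diag(2))[,,1]
B <- rWishart(1, 5, diag(3))[,,1]
C <- rWishart(1, 6, diag(4))[,,1]
lhs <- as.numeric(determinant(kronecker(kronecker(A, B), C))$modulus)
rhs <- 3*4 * as.numeric(determinant(A)$modulus) +
  2*4 * as.numeric(determinant(B)$modulus) +
  2*3 * as.numeric(determinant(C)$modulus)
expect_equal(lhs, rhs)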
The other thing that gets complicated in the flattened array-normal is the reparameterization trick. Again, we want the product without materializing the Cholesky factor of the whole covariance. To do this, we can use a clever algorithm from Fernandes et al. (doi:10.1145/278298.278303). They give it for multiplication from the right, and I also wrote the trivial extension for multiplying from the left. Here is naive R code for comparison, the algorithm implemented in R, and the same thing in Stan:
kronRmult_naive <- function(X, Q){
  # X is a p x q matrix.
  # Q is a list of square matrices such that the product of their sizes is q.
  # The return is X %*% (Q[[1]] kron Q[[2]] kron Q[[3]]).
  Qkron <- kronecker(kronecker(Q[[1]], Q[[2]]), Q[[3]])
  return(X %*% Qkron)
}
kronLmult_naive <- function(X, Q){
  # X is a p x q matrix.
  # Q is a list of square matrices such that the product of their sizes is p.
  # The return is (Q[[1]] kron Q[[2]] kron Q[[3]]) %*% X.
  Qkron <- kronecker(kronecker(Q[[1]], Q[[2]]), Q[[3]])
  return(Qkron %*% X)
}
kronRmult <- function(X, Q){
  # Efficient Kronecker multiplication.
  # Algorithm from page 394 of Fernandes et al. 1998,
  # JACM 45(3): 381--414 (doi:10.1145/278298.278303).
  # X is a p x q matrix.
  # Q is a list of square matrices such that the product of their sizes is q.
  # The return is X %*% (Q[[1]] kron Q[[2]] kron Q[[3]] etc.)
  n <- laply(Q, nrow)
  N <- length(n)
  nleft <- prod(n[1:(N-1)])
  nright <- 1
  indices <- rev(1:length(Q))
  out <- X
  for (i in indices){
    Z <- matrix(nrow=nrow(X), ncol=n[i])
    base <- 0
    jump <- n[i] * nright
    for (k in 1:nleft){
      for (j in 1:nright){
        index <- base + j
        for (l in 1:n[i]){
          Z[,l] <- out[,index]
          index <- index + nright
        }
        Z <- Z %*% Q[[i]]
        index <- base + j
        for (l in 1:n[i]){
          out[,index] <- Z[,l]
          index <- index + nright
        }
      }
      base <- base + jump
    }
    nleft <- nleft / n[i-1]
    nright <- nright * n[i]
  }
  return(out)
}
kronLmult <- function(X, Q){
  # Efficient Kronecker multiplication.
  # Algorithm from page 394 of Fernandes et al. 1998,
  # JACM 45(3): 381--414 (doi:10.1145/278298.278303), modified for left multiplication.
  # X is a p x q matrix.
  # Q is a list of square matrices such that the product of their sizes is p.
  # The return is (Q[[1]] kron Q[[2]] kron Q[[3]] etc.) %*% X
  n <- laply(Q, nrow)
  N <- length(n)
  nleft <- prod(n[1:(N-1)])
  nright <- 1
  indices <- rev(1:length(Q))
  out <- X
  for (i in indices){
    Z <- matrix(nrow=n[i], ncol=ncol(X))
    base <- 0
    jump <- n[i] * nright
    for (k in 1:nleft){
      for (j in 1:nright){
        index <- base + j
        for (l in 1:n[i]){
          Z[l,] <- out[index,]
          index <- index + nright
        }
        Z <- Q[[i]] %*% Z
        index <- base + j
        for (l in 1:n[i]){
          out[index,] <- Z[l,]
          index <- index + nright
        }
      }
      base <- base + jump
    }
    nleft <- nleft / n[i-1]
    nright <- nright * n[i]
  }
  return(out)
}
data{
  int dimX;
  int dimY;
  int dimZ;
  int dimT;
  int ncol_X;
  matrix[dimX, dimX] cholX;
  matrix[dimY, dimY] cholY;
  matrix[dimZ, dimZ] cholZ;
  matrix[dimT, ncol_X] Xin;
}
model{}
generated quantities{
  // computes Xin * (cholZ kron cholY kron cholX)
  int nleft;
  int nright;
  int base;
  int jump;
  int index;
  matrix[dimT, ncol_X] X;
  matrix[dimT, dimX] tmpX;
  matrix[dimT, dimY] tmpY;
  matrix[dimT, dimZ] tmpZ;
  nleft = dimZ * dimY;
  nright = 1;
  // copy here because we accumulate into X and I don't
  // know what happens if you modify data in stan
  for (i in 1:dimT){
    for (j in 1:ncol_X){
      X[i, j] = Xin[i, j];
    }
  }
  // accumulate cholX
  base = 0;
  jump = dimX * nright;
  for (k in 1:nleft){
    for (j in 1:nright){
      index = base + j;
      for (l in 1:dimX){
        for (m in 1:dimT){
          tmpX[m, l] = X[m, index];
        }
        index = index + nright;
      }
      tmpX = tmpX * cholX;
      index = base + j;
      for (l in 1:dimX){
        for (m in 1:dimT){
          X[m, index] = tmpX[m, l];
        }
        index = index + nright;
      }
    }
    base = base + jump;
  }
  nleft = nleft / dimY;
  nright = nright * dimX;
  // accumulate cholY
  base = 0;
  jump = dimY * nright;
  for (k in 1:nleft){
    for (j in 1:nright){
      index = base + j;
      for (l in 1:dimY){
        for (m in 1:dimT){
          tmpY[m, l] = X[m, index];
        }
        index = index + nright;
      }
      tmpY = tmpY * cholY;
      index = base + j;
      for (l in 1:dimY){
        for (m in 1:dimT){
          X[m, index] = tmpY[m, l];
        }
        index = index + nright;
      }
    }
    base = base + jump;
  }
  nleft = nleft / dimZ;
  nright = nright * dimY;
  // accumulate cholZ
  base = 0;
  jump = dimZ * nright;
  for (k in 1:nleft){
    for (j in 1:nright){
      index = base + j;
      for (l in 1:dimZ){
        for (m in 1:dimT){
          tmpZ[m, l] = X[m, index];
        }
        index = index + nright;
      }
      tmpZ = tmpZ * cholZ;
      index = base + j;
      for (l in 1:dimZ){
        for (m in 1:dimT){
          X[m, index] = tmpZ[m, l];
        }
        index = index + nright;
      }
    }
  }
}
data{
  int dimX;
  int dimY;
  int dimZ;
  int dimT;
  int nrow_X;
  matrix[dimX, dimX] cholX;
  matrix[dimY, dimY] cholY;
  matrix[dimZ, dimZ] cholZ;
  matrix[nrow_X, dimT] Xin;
}
model{}
generated quantities{
  // the same algorithm, modified for multiplication from the left:
  // computes (cholZ kron cholY kron cholX) * Xin
  int nleft;
  int nright;
  int base;
  int jump;
  int index;
  matrix[nrow_X, dimT] X;
  matrix[dimX, dimT] tmpX;
  matrix[dimY, dimT] tmpY;
  matrix[dimZ, dimT] tmpZ;
  nleft = dimZ * dimY;
  nright = 1;
  // copy here because we accumulate into X and I don't
  // know what happens if you modify data in stan
  for (i in 1:nrow_X){
    X[i] = Xin[i];
  }
  // accumulate cholX
  base = 0;
  jump = dimX * nright;
  for (k in 1:nleft){
    for (j in 1:nright){
      index = base + j;
      for (l in 1:dimX){
        tmpX[l] = X[index];
        index = index + nright;
      }
      tmpX = cholX * tmpX;
      index = base + j;
      for (l in 1:dimX){
        X[index] = tmpX[l];
        index = index + nright;
      }
    }
    base = base + jump;
  }
  nleft = nleft / dimY;
  nright = nright * dimX;
  // accumulate cholY
  base = 0;
  jump = dimY * nright;
  for (k in 1:nleft){
    for (j in 1:nright){
      index = base + j;
      for (l in 1:dimY){
        tmpY[l] = X[index];
        index = index + nright;
      }
      tmpY = cholY * tmpY;
      index = base + j;
      for (l in 1:dimY){
        X[index] = tmpY[l];
        index = index + nright;
      }
    }
    base = base + jump;
  }
  nleft = nleft / dimZ;
  nright = nright * dimY;
  // accumulate cholZ
  base = 0;
  jump = dimZ * nright;
  for (k in 1:nleft){
    for (j in 1:nright){
      index = base + j;
      for (l in 1:dimZ){
        tmpZ[l] = X[index];
        index = index + nright;
      }
      tmpZ = cholZ * tmpZ;
      index = base + j;
      for (l in 1:dimZ){
        X[index] = tmpZ[l];
        index = index + nright;
      }
    }
  }
}
A small wrapper around each Stan program (assuming the two programs above have been compiled with stan_model() into kronRmult_stanObj and kronLmult_stanObj):
kronRmult_stan <- function(X, Q){
  dims <- laply(Q, dim)[,1]
  # cholX is the innermost (rightmost) Kronecker factor, so it gets Q[[3]]
  standat <- list(dimX = dims[3], dimY = dims[2], dimZ = dims[1],
                  dimT = dim(X)[1], ncol_X = prod(dims),
                  cholX = Q[[3]], cholY = Q[[2]], cholZ = Q[[1]], Xin = X)
  stano <- sampling(kronRmult_stanObj, data=standat, iter=1, chains=1, algorithm="Fixed_param")
  return(extract(stano, "X")$X[1,,])
}
kronLmult_stan <- function(X, Q){
  dims <- laply(Q, dim)[,1]
  standat <- list(dimX = dims[3], dimY = dims[2], dimZ = dims[1],
                  dimT = dim(X)[2], nrow_X = prod(dims),
                  cholX = Q[[3]], cholY = Q[[2]], cholZ = Q[[1]], Xin = X)
  stano <- sampling(kronLmult_stanObj, data=standat, iter=1, chains=1, algorithm="Fixed_param")
  return(extract(stano, "X")$X[1,,])
}
Tests:
library(rstan)
## Loading required package: ggplot2
## Loading required package: StanHeaders
## rstan (Version 2.12.1, packaged: 2016-09-11 13:07:50 UTC, GitRev: 85f7a56811da)
## For execution on a local, multicore CPU with excess RAM we recommend calling
## rstan_options(auto_write = TRUE)
## options(mc.cores = parallel::detectCores())
dimt <- 10
dimx <- 5
dimy <- 5
dimz <- 5
Q <- llply(c(dimx, dimy, dimz), function(x) rWishart(1, df=x+2, Sigma=diag(x))[,,1])
X <- matrix(rnorm(dimx*dimy*dimz*dimt), nrow=dimt)
naive <- kronRmult_naive(X, Q)
fast <- kronRmult(X, Q)
stanvers <- kronRmult_stan(X, Q)
##
## SAMPLING FOR MODEL '5856c0cd77b5248d77b893e6428d44e0' NOW (CHAIN 1).
##
## Chain 1, Iteration: 1 / 1 [100%] (Sampling)
## Elapsed Time: 4e-06 seconds (Warm-up)
## 0.001411 seconds (Sampling)
## 0.001415 seconds (Total)
expect_equal(naive, fast)
expect_equal(naive, stanvers)
naive <- kronLmult_naive(t(X), Q)
fast <- kronLmult(t(X), Q)
stanvers <- kronLmult_stan(t(X), Q)
##
## SAMPLING FOR MODEL '718d53202c4e7f0845f5e5955e2a3e1a' NOW (CHAIN 1).
##
## Chain 1, Iteration: 1 / 1 [100%] (Sampling)
## Elapsed Time: 6e-06 seconds (Warm-up)
## 0.001946 seconds (Sampling)
## 0.001952 seconds (Total)
expect_equal(naive, fast)
expect_equal(naive, stanvers)
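To close the loop with the reparameterization trick, a draw with Kronecker-structured covariance Q[[1]] kron Q[[2]] kron Q[[3]] is then just a left Kronecker multiplication of a standard-normal matrix by the lower Cholesky factors (a quick sketch using the objects above):
chols <- llply(Q, function(S) t(chol(S)))   # lower-triangular Cholesky factors
Z <- matrix(rnorm(dimx*dimy*dimz*dimt), nrow=dimx*dimy*dimz)
draw <- kronLmult(Z, chols)
# each column of draw has covariance Q[[1]] kron Q[[2]] kron Q[[3]]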