先日GelmanらがStanにおけるLOOCV, WAICについて、以下の論文を発表しました。
論文: Efficient implementation of leave-one-out cross-validation and WAIC for evaluating fitted Bayesian models
Blog記事: New papers on LOO/WAIC and Stan

この論文に関連して、今回のお題であるLOOなるパッケージが合わせて発表されています。
本パッケージはWAICやLOOCVの値をstanfitから算出できる点が売りのパッケージのようです。

インストールは
install.packages(“loo”)
でいけるようですが、Githubから
library(“devtools”)
install_github(“jgabry/loo”) でもインストール可能です。

今回の投稿は以前@berobero11さんがブログで公開されていた
“WAICとWBICを事後確率から計算してみる”で算出されたWAICの値と、looを使って計算したWAICの値を比較してみたという内容です。 データの作成、WAICの算出についてのコードは許可をいただき引用させて頂きました。

パッケージ読み込み

まずはパッケージを読み込みます。

library("loo")
## This is loo version 0.1.2.9000
library("rstan")
## Loading required package: Rcpp
## Loading required package: inline
## 
## Attaching package: 'inline'
## 
##  以下のオブジェクトは 'package:Rcpp' からマスクされています: 
## 
##      registerPlugin 
## 
## rstan (Version 2.6.0, packaged: 2015-02-06 21:02:34 UTC, GitRev: 198082f07a60)

データ作成

ブログ記事にそってデータを作成していきます。 コードは次の通りです。

N <- 100
a_true <- 0.4
mean1 <- 0
mean2 <- 3
sd1 <- 1
sd2 <- 1
set.seed(1)
Y <- c(rnorm((1-a_true)*N, mean1, sd1), rnorm(a_true*N, mean2, sd2))
data <- list(N=N, Y=Y)

stanコード

続けてブログ記事にそって2つのモデルを書いていきます。いくらパッケージを使ったからといってgenerated quantities部分は省略できるわけではありません。また、パッケージの仕様なのか私の環境なのかはわかりませんが、元コードの“log_likelihood” を論文に載っている“log_lik”にしないと、後に出てくる対数尤度の抜き出しが動きませんでした。

Model1

model1 ="
data {
int<lower=1> N;
vector[N] Y;
}

parameters {
real<lower=0, upper=1> a;
real<lower=-50, upper=50> mu;
}

model {
a ~ uniform(0, 1);
mu ~ uniform(-50, 50);

for(n in 1:N)
increment_log_prob(
log_sum_exp(
log(1-a) + normal_log(Y[n], 0, 1),
log(a) + normal_log(Y[n], mu, 1)
)
);
}

generated quantities {
vector[N] log_lik;
for (n in 1:N)
log_lik[n] <- log_sum_exp(
log(1-a) + normal_log(Y[n], 0, 1),
log(a) + normal_log(Y[n], mu, 1)
);
}
"

stanfit1 <- stan(model_code = model1, data=data, iter=10000, thin=10, chains=3)
## 
## TRANSLATING MODEL 'model1' FROM Stan CODE TO C++ CODE NOW.
## COMPILING THE C++ CODE FOR MODEL 'model1' NOW.
## 
## SAMPLING FOR MODEL 'model1' NOW (CHAIN 1).
## 
## Iteration:    1 / 10000 [  0%]  (Warmup)
## Iteration: 1000 / 10000 [ 10%]  (Warmup)
## Iteration: 2000 / 10000 [ 20%]  (Warmup)
## Iteration: 3000 / 10000 [ 30%]  (Warmup)
## Iteration: 4000 / 10000 [ 40%]  (Warmup)
## Iteration: 5000 / 10000 [ 50%]  (Warmup)
## Iteration: 5001 / 10000 [ 50%]  (Sampling)
## Iteration: 6000 / 10000 [ 60%]  (Sampling)
## Iteration: 7000 / 10000 [ 70%]  (Sampling)
## Iteration: 8000 / 10000 [ 80%]  (Sampling)
## Iteration: 9000 / 10000 [ 90%]  (Sampling)
## Iteration: 10000 / 10000 [100%]  (Sampling)
## #  Elapsed Time: 1.07286 seconds (Warm-up)
## #                0.432254 seconds (Sampling)
## #                1.50512 seconds (Total)
## 
## 
## SAMPLING FOR MODEL 'model1' NOW (CHAIN 2).
## 
## Iteration:    1 / 10000 [  0%]  (Warmup)
## Iteration: 1000 / 10000 [ 10%]  (Warmup)
## Iteration: 2000 / 10000 [ 20%]  (Warmup)
## Iteration: 3000 / 10000 [ 30%]  (Warmup)
## Iteration: 4000 / 10000 [ 40%]  (Warmup)
## Iteration: 5000 / 10000 [ 50%]  (Warmup)
## Iteration: 5001 / 10000 [ 50%]  (Sampling)
## Iteration: 6000 / 10000 [ 60%]  (Sampling)
## Iteration: 7000 / 10000 [ 70%]  (Sampling)
## Iteration: 8000 / 10000 [ 80%]  (Sampling)
## Iteration: 9000 / 10000 [ 90%]  (Sampling)
## Iteration: 10000 / 10000 [100%]  (Sampling)
## #  Elapsed Time: 0.450727 seconds (Warm-up)
## #                0.415922 seconds (Sampling)
## #                0.866649 seconds (Total)
## 
## 
## SAMPLING FOR MODEL 'model1' NOW (CHAIN 3).
## 
## Iteration:    1 / 10000 [  0%]  (Warmup)
## Iteration: 1000 / 10000 [ 10%]  (Warmup)
## Iteration: 2000 / 10000 [ 20%]  (Warmup)
## Iteration: 3000 / 10000 [ 30%]  (Warmup)
## Iteration: 4000 / 10000 [ 40%]  (Warmup)
## Iteration: 5000 / 10000 [ 50%]  (Warmup)
## Iteration: 5001 / 10000 [ 50%]  (Sampling)
## Iteration: 6000 / 10000 [ 60%]  (Sampling)
## Iteration: 7000 / 10000 [ 70%]  (Sampling)
## Iteration: 8000 / 10000 [ 80%]  (Sampling)
## Iteration: 9000 / 10000 [ 90%]  (Sampling)
## Iteration: 10000 / 10000 [100%]  (Sampling)
## #  Elapsed Time: 0.39501 seconds (Warm-up)
## #                0.432937 seconds (Sampling)
## #                0.827947 seconds (Total)

Model2

model2 ="
data {
int<lower=1> N;
vector[N] Y;
}

parameters {
real mu;
real<lower=0> s;
}

model {
mu ~ normal(0, 100);
s ~ uniform(0, 1000);
Y ~ normal(mu, s);
}

generated quantities {
vector[N] log_lik;
for (n in 1:N)
log_lik[n] <- normal_log(Y[n], mu, s);
}
"
stanfit2 <- stan(model_code = model2, data=data, iter=10000, thin=10, chains=3)
## 
## TRANSLATING MODEL 'model2' FROM Stan CODE TO C++ CODE NOW.
## COMPILING THE C++ CODE FOR MODEL 'model2' NOW.
## 
## SAMPLING FOR MODEL 'model2' NOW (CHAIN 1).
## 
## Iteration:    1 / 10000 [  0%]  (Warmup)
## Iteration: 1000 / 10000 [ 10%]  (Warmup)
## Iteration: 2000 / 10000 [ 20%]  (Warmup)
## Iteration: 3000 / 10000 [ 30%]  (Warmup)
## Iteration: 4000 / 10000 [ 40%]  (Warmup)
## Iteration: 5000 / 10000 [ 50%]  (Warmup)
## Iteration: 5001 / 10000 [ 50%]  (Sampling)
## Iteration: 6000 / 10000 [ 60%]  (Sampling)
## Iteration: 7000 / 10000 [ 70%]  (Sampling)
## Iteration: 8000 / 10000 [ 80%]  (Sampling)
## Iteration: 9000 / 10000 [ 90%]  (Sampling)
## Iteration: 10000 / 10000 [100%]  (Sampling)
## #  Elapsed Time: 0.045082 seconds (Warm-up)
## #                0.049479 seconds (Sampling)
## #                0.094561 seconds (Total)
## 
## 
## SAMPLING FOR MODEL 'model2' NOW (CHAIN 2).
## 
## Iteration:    1 / 10000 [  0%]  (Warmup)
## Iteration: 1000 / 10000 [ 10%]  (Warmup)
## Iteration: 2000 / 10000 [ 20%]  (Warmup)
## Iteration: 3000 / 10000 [ 30%]  (Warmup)
## Iteration: 4000 / 10000 [ 40%]  (Warmup)
## Iteration: 5000 / 10000 [ 50%]  (Warmup)
## Iteration: 5001 / 10000 [ 50%]  (Sampling)
## Iteration: 6000 / 10000 [ 60%]  (Sampling)
## Iteration: 7000 / 10000 [ 70%]  (Sampling)
## Iteration: 8000 / 10000 [ 80%]  (Sampling)
## Iteration: 9000 / 10000 [ 90%]  (Sampling)
## Iteration: 10000 / 10000 [100%]  (Sampling)
## #  Elapsed Time: 0.04503 seconds (Warm-up)
## #                0.045867 seconds (Sampling)
## #                0.090897 seconds (Total)
## 
## 
## SAMPLING FOR MODEL 'model2' NOW (CHAIN 3).
## 
## Iteration:    1 / 10000 [  0%]  (Warmup)
## Iteration: 1000 / 10000 [ 10%]  (Warmup)
## Iteration: 2000 / 10000 [ 20%]  (Warmup)
## Iteration: 3000 / 10000 [ 30%]  (Warmup)
## Iteration: 4000 / 10000 [ 40%]  (Warmup)
## Iteration: 5000 / 10000 [ 50%]  (Warmup)
## Iteration: 5001 / 10000 [ 50%]  (Sampling)
## Iteration: 6000 / 10000 [ 60%]  (Sampling)
## Iteration: 7000 / 10000 [ 70%]  (Sampling)
## Iteration: 8000 / 10000 [ 80%]  (Sampling)
## Iteration: 9000 / 10000 [ 90%]  (Sampling)
## Iteration: 10000 / 10000 [100%]  (Sampling)
## #  Elapsed Time: 0.048658 seconds (Warm-up)
## #                0.046637 seconds (Sampling)
## #                0.095295 seconds (Total)

looによるWAIC算出

looによりWAICを算出します。 stanfit内の1データごとの対数尤度をextract_log_lik関数で抜き出せるようになっているようです。抜き出した対数尤度のmatrixをwaicに代入して完成です。

モデル1のWAIC

log_lik1 <- extract_log_lik(stanfit1)
waic1 <- waic(log_lik1)
print(waic1 , digits = 4)
## Computed from 1500 by 100 log-likelihood matrix
## 
##            Estimate      SE
## elpd_waic -191.3595  5.4811
## p_waic       2.0126  0.2751
## waic       382.7190 10.9622

モデル2のWAIC

log_lik2 <- extract_log_lik(stanfit2)
waic2 <- waic(log_lik2)
print(waic2 , digits = 4)
## Computed from 1500 by 100 log-likelihood matrix
## 
##            Estimate      SE
## elpd_waic -197.9614  5.5406
## p_waic       1.5154  0.2270
## waic       395.9229 11.0812

特別お手軽に計算できるというわけではありませんが、WAICのSEも出してくれるのはいい感じですね。

WAICの値について

@berobero11さんのコードで算出されるWAICは渡辺先生の定義されたもの、 LOOパッケージで計算される値はBDA3で定義されている値です。 BDA3で定義された値は渡辺先生の定義された値に2Nをかけたものなので、 LOOで算出した値を2Nで割った値を比較に使います。

元記事におけるmodel1のWAIC: 1.914

waic1$waic/(2*N)
## [1] 1.913595

元記事におけるmodel2のWAIC: 1.981

waic2$waic/(2*N)
## [1] 1.979614

ほぼ同じ値が得られていますね!

2つのモデルのWAIC比較

compare(waic1, waic2)
## elpd_diff        SE 
##      -6.6       3.5

値比較できるよって書いてあったのでどんな感じか期待してましたが、 普通にelpdの差分とSEが算出されて終了。なんというか、普通ですね…。
とりあえず今回はここまでです。