先日GelmanらがStanにおけるLOOCV, WAICについて、以下の論文を発表しました。
論文: Efficient implementation of leave-one-out cross-validation and WAIC for evaluating fitted Bayesian models
Blog記事: New papers on LOO/WAIC and Stan
この論文に関連して、今回のお題であるLOOなるパッケージが合わせて発表されています。
本パッケージはWAICやLOOCVの値をstanfitから算出できる点が売りのパッケージのようです。
インストールは
install.packages(“loo”)
でいけるようですが、Githubから
library(“devtools”)
install_github(“jgabry/loo”) でもインストール可能です。
今回の投稿は以前@berobero11さんがブログで公開されていた、
“WAICとWBICを事後確率から計算してみる”で算出されたWAICの値と、looを使って計算したWAICの値を比較してみたという内容です。 データの作成、WAICの算出についてのコードは許可をいただき引用させて頂きました。
まずはパッケージを読み込みます。
library("loo")
## This is loo version 0.1.2.9000
library("rstan")
## Loading required package: Rcpp
## Loading required package: inline
##
## Attaching package: 'inline'
##
## 以下のオブジェクトは 'package:Rcpp' からマスクされています:
##
## registerPlugin
##
## rstan (Version 2.6.0, packaged: 2015-02-06 21:02:34 UTC, GitRev: 198082f07a60)
ブログ記事にそってデータを作成していきます。 コードは次の通りです。
N <- 100
a_true <- 0.4
mean1 <- 0
mean2 <- 3
sd1 <- 1
sd2 <- 1
set.seed(1)
Y <- c(rnorm((1-a_true)*N, mean1, sd1), rnorm(a_true*N, mean2, sd2))
data <- list(N=N, Y=Y)
続けてブログ記事にそって2つのモデルを書いていきます。いくらパッケージを使ったからといってgenerated quantities部分は省略できるわけではありません。また、パッケージの仕様なのか私の環境なのかはわかりませんが、元コードの“log_likelihood” を論文に載っている“log_lik”にしないと、後に出てくる対数尤度の抜き出しが動きませんでした。
model1 ="
data {
int<lower=1> N;
vector[N] Y;
}
parameters {
real<lower=0, upper=1> a;
real<lower=-50, upper=50> mu;
}
model {
a ~ uniform(0, 1);
mu ~ uniform(-50, 50);
for(n in 1:N)
increment_log_prob(
log_sum_exp(
log(1-a) + normal_log(Y[n], 0, 1),
log(a) + normal_log(Y[n], mu, 1)
)
);
}
generated quantities {
vector[N] log_lik;
for (n in 1:N)
log_lik[n] <- log_sum_exp(
log(1-a) + normal_log(Y[n], 0, 1),
log(a) + normal_log(Y[n], mu, 1)
);
}
"
stanfit1 <- stan(model_code = model1, data=data, iter=10000, thin=10, chains=3)
##
## TRANSLATING MODEL 'model1' FROM Stan CODE TO C++ CODE NOW.
## COMPILING THE C++ CODE FOR MODEL 'model1' NOW.
##
## SAMPLING FOR MODEL 'model1' NOW (CHAIN 1).
##
## Iteration: 1 / 10000 [ 0%] (Warmup)
## Iteration: 1000 / 10000 [ 10%] (Warmup)
## Iteration: 2000 / 10000 [ 20%] (Warmup)
## Iteration: 3000 / 10000 [ 30%] (Warmup)
## Iteration: 4000 / 10000 [ 40%] (Warmup)
## Iteration: 5000 / 10000 [ 50%] (Warmup)
## Iteration: 5001 / 10000 [ 50%] (Sampling)
## Iteration: 6000 / 10000 [ 60%] (Sampling)
## Iteration: 7000 / 10000 [ 70%] (Sampling)
## Iteration: 8000 / 10000 [ 80%] (Sampling)
## Iteration: 9000 / 10000 [ 90%] (Sampling)
## Iteration: 10000 / 10000 [100%] (Sampling)
## # Elapsed Time: 1.07286 seconds (Warm-up)
## # 0.432254 seconds (Sampling)
## # 1.50512 seconds (Total)
##
##
## SAMPLING FOR MODEL 'model1' NOW (CHAIN 2).
##
## Iteration: 1 / 10000 [ 0%] (Warmup)
## Iteration: 1000 / 10000 [ 10%] (Warmup)
## Iteration: 2000 / 10000 [ 20%] (Warmup)
## Iteration: 3000 / 10000 [ 30%] (Warmup)
## Iteration: 4000 / 10000 [ 40%] (Warmup)
## Iteration: 5000 / 10000 [ 50%] (Warmup)
## Iteration: 5001 / 10000 [ 50%] (Sampling)
## Iteration: 6000 / 10000 [ 60%] (Sampling)
## Iteration: 7000 / 10000 [ 70%] (Sampling)
## Iteration: 8000 / 10000 [ 80%] (Sampling)
## Iteration: 9000 / 10000 [ 90%] (Sampling)
## Iteration: 10000 / 10000 [100%] (Sampling)
## # Elapsed Time: 0.450727 seconds (Warm-up)
## # 0.415922 seconds (Sampling)
## # 0.866649 seconds (Total)
##
##
## SAMPLING FOR MODEL 'model1' NOW (CHAIN 3).
##
## Iteration: 1 / 10000 [ 0%] (Warmup)
## Iteration: 1000 / 10000 [ 10%] (Warmup)
## Iteration: 2000 / 10000 [ 20%] (Warmup)
## Iteration: 3000 / 10000 [ 30%] (Warmup)
## Iteration: 4000 / 10000 [ 40%] (Warmup)
## Iteration: 5000 / 10000 [ 50%] (Warmup)
## Iteration: 5001 / 10000 [ 50%] (Sampling)
## Iteration: 6000 / 10000 [ 60%] (Sampling)
## Iteration: 7000 / 10000 [ 70%] (Sampling)
## Iteration: 8000 / 10000 [ 80%] (Sampling)
## Iteration: 9000 / 10000 [ 90%] (Sampling)
## Iteration: 10000 / 10000 [100%] (Sampling)
## # Elapsed Time: 0.39501 seconds (Warm-up)
## # 0.432937 seconds (Sampling)
## # 0.827947 seconds (Total)
model2 ="
data {
int<lower=1> N;
vector[N] Y;
}
parameters {
real mu;
real<lower=0> s;
}
model {
mu ~ normal(0, 100);
s ~ uniform(0, 1000);
Y ~ normal(mu, s);
}
generated quantities {
vector[N] log_lik;
for (n in 1:N)
log_lik[n] <- normal_log(Y[n], mu, s);
}
"
stanfit2 <- stan(model_code = model2, data=data, iter=10000, thin=10, chains=3)
##
## TRANSLATING MODEL 'model2' FROM Stan CODE TO C++ CODE NOW.
## COMPILING THE C++ CODE FOR MODEL 'model2' NOW.
##
## SAMPLING FOR MODEL 'model2' NOW (CHAIN 1).
##
## Iteration: 1 / 10000 [ 0%] (Warmup)
## Iteration: 1000 / 10000 [ 10%] (Warmup)
## Iteration: 2000 / 10000 [ 20%] (Warmup)
## Iteration: 3000 / 10000 [ 30%] (Warmup)
## Iteration: 4000 / 10000 [ 40%] (Warmup)
## Iteration: 5000 / 10000 [ 50%] (Warmup)
## Iteration: 5001 / 10000 [ 50%] (Sampling)
## Iteration: 6000 / 10000 [ 60%] (Sampling)
## Iteration: 7000 / 10000 [ 70%] (Sampling)
## Iteration: 8000 / 10000 [ 80%] (Sampling)
## Iteration: 9000 / 10000 [ 90%] (Sampling)
## Iteration: 10000 / 10000 [100%] (Sampling)
## # Elapsed Time: 0.045082 seconds (Warm-up)
## # 0.049479 seconds (Sampling)
## # 0.094561 seconds (Total)
##
##
## SAMPLING FOR MODEL 'model2' NOW (CHAIN 2).
##
## Iteration: 1 / 10000 [ 0%] (Warmup)
## Iteration: 1000 / 10000 [ 10%] (Warmup)
## Iteration: 2000 / 10000 [ 20%] (Warmup)
## Iteration: 3000 / 10000 [ 30%] (Warmup)
## Iteration: 4000 / 10000 [ 40%] (Warmup)
## Iteration: 5000 / 10000 [ 50%] (Warmup)
## Iteration: 5001 / 10000 [ 50%] (Sampling)
## Iteration: 6000 / 10000 [ 60%] (Sampling)
## Iteration: 7000 / 10000 [ 70%] (Sampling)
## Iteration: 8000 / 10000 [ 80%] (Sampling)
## Iteration: 9000 / 10000 [ 90%] (Sampling)
## Iteration: 10000 / 10000 [100%] (Sampling)
## # Elapsed Time: 0.04503 seconds (Warm-up)
## # 0.045867 seconds (Sampling)
## # 0.090897 seconds (Total)
##
##
## SAMPLING FOR MODEL 'model2' NOW (CHAIN 3).
##
## Iteration: 1 / 10000 [ 0%] (Warmup)
## Iteration: 1000 / 10000 [ 10%] (Warmup)
## Iteration: 2000 / 10000 [ 20%] (Warmup)
## Iteration: 3000 / 10000 [ 30%] (Warmup)
## Iteration: 4000 / 10000 [ 40%] (Warmup)
## Iteration: 5000 / 10000 [ 50%] (Warmup)
## Iteration: 5001 / 10000 [ 50%] (Sampling)
## Iteration: 6000 / 10000 [ 60%] (Sampling)
## Iteration: 7000 / 10000 [ 70%] (Sampling)
## Iteration: 8000 / 10000 [ 80%] (Sampling)
## Iteration: 9000 / 10000 [ 90%] (Sampling)
## Iteration: 10000 / 10000 [100%] (Sampling)
## # Elapsed Time: 0.048658 seconds (Warm-up)
## # 0.046637 seconds (Sampling)
## # 0.095295 seconds (Total)
looによりWAICを算出します。 stanfit内の1データごとの対数尤度をextract_log_lik関数で抜き出せるようになっているようです。抜き出した対数尤度のmatrixをwaicに代入して完成です。
log_lik1 <- extract_log_lik(stanfit1)
waic1 <- waic(log_lik1)
print(waic1 , digits = 4)
## Computed from 1500 by 100 log-likelihood matrix
##
## Estimate SE
## elpd_waic -191.3595 5.4811
## p_waic 2.0126 0.2751
## waic 382.7190 10.9622
log_lik2 <- extract_log_lik(stanfit2)
waic2 <- waic(log_lik2)
print(waic2 , digits = 4)
## Computed from 1500 by 100 log-likelihood matrix
##
## Estimate SE
## elpd_waic -197.9614 5.5406
## p_waic 1.5154 0.2270
## waic 395.9229 11.0812
特別お手軽に計算できるというわけではありませんが、WAICのSEも出してくれるのはいい感じですね。
@berobero11さんのコードで算出されるWAICは渡辺先生の定義されたもの、 LOOパッケージで計算される値はBDA3で定義されている値です。 BDA3で定義された値は渡辺先生の定義された値に2Nをかけたものなので、 LOOで算出した値を2Nで割った値を比較に使います。
waic1$waic/(2*N)
## [1] 1.913595
waic2$waic/(2*N)
## [1] 1.979614
ほぼ同じ値が得られていますね!
compare(waic1, waic2)
## elpd_diff SE
## -6.6 3.5
値比較できるよって書いてあったのでどんな感じか期待してましたが、 普通にelpdの差分とSEが算出されて終了。なんというか、普通ですね…。
とりあえず今回はここまでです。