R Advent Calendar 2012の29日目の記事です。世間のアドベントっぷりからすっかり乗り遅れてしまい不覚にもクリスマス後の29日の担当となってしまいました。。
今日の話題は、最近巷で話題のモンテカルロサンプラーであるStanをRから並列で動かせたらいいじゃね?という話です。Stanは従来のWinBUGSとかJAGSとかで記述に使うBUGS言語っぽいので記述したものを一度C++コードに変換してから実行するため高速であるわけですが、C++コードへの変換時のオーバーヘッドが大きそうなのは容易に想像が付きそうですね。だから、この変換は1度きりとしてchainの数だけ並列処理できたらハッピーなわけですよ。だからやってみました。
ストラディバリウス(ストラテジ)としてはRからStanを動かすための既存パッケージRStanをあえて使わずにStanプログラムそのものをRのSystem関数で呼び出す感じになってます。何故かというと、ぶっちゃけこのネタのそもそものスタートがCentOS5.4のgcc4.1.2でRStanがビルドできなかったからです。stanそのもののビルドはgcc44で可能だったのですが、R CMD INSTALL 時にgccじゃなくてgcc44を利用する方法が分からなくてしかたなく生Stanじゃ!で、どうせなら並列化もしちゃえといった次第です。
2012年12月29日現在、stanは1.1.0がリリースされていますが今回はあえて旧版の1.0.3を利用しました。理由はforループの処理速度が1.1.0から遅くなってしまっているからです。
また、RStan1.1.0のリリースにともなってドキュメントもちょっとだけ充実してrstanを並列で動かそうぜ的な記述も出現しちゃったりしてます。
# CentOS 5.4だとデフォのgccが4.1.2だからビルドできない
# だから4.4系のgcc44をyumでインストールする
system("yum instal gcc44")
# stanのソースをダウンロードしてビルド
system("wget http://stan.googlecode.com/files/stan-src-1.0.3.tgz")
system("tar -xzf stan-src-1.0.3.tgz")
system("cd stan-src-1.0.3")
# makefileの編集
makefile <- scan("makefile", what = "character", sep = "\n")
# コンパイラの指定をg++からg++44へ変更する
makefile[grep("^CC =", makefile)] <- "CC = g++44"
write.table(makefile, file = "makefile", row.names = FALSE, col.names = FALSE,
quote = FALSE)
system("make bin/libstan.a")
system("make bin/stanc")
library(coda, quietly = TRUE)
library(parallel, quietly = TRUE)
library(dlm, quietly = TRUE)
hansen <- function(file, data, init = "random", iter = 10000, warmup = 5000,
chains = 4, thin = 1, mc.cores = detectCores(logical = FALSE), stan.home = "~/stan-src-1.0.3",
seeds = 111:114) {
# file--> *.stan file data-->
# listかテキストファイル名(かdata.frameかvector) init-->
# listかテキストファイル名
file.sep <- as.character(ifelse(Sys.info()["sysname"] == "Windows", "\\",
"/"))
prg <- strsplit(file, "\\.")[[1]][1]
cur.dir <- getwd()
setwd(stan.home)
system(paste("make ", cur.dir, file.sep, prg, sep = ""))
setwd(cur.dir)
cmd <- ""
if (length(seeds) < chains) {
seeds <- c(apply(data.frame(n = (1:ceiling(chains/length(seeds)))),
1, function(x) {
x * seeds
}))[1:chains]
}
if (chains < mc.cores) {
mc.cores <- chains
}
if (warmup > iter) {
warmup <- as.integer(iter/2)
}
if (!("hansen.stan" %in% dir())) {
stop(paste(file, " is not found.", sep = ""))
}
if (is.data.frame(data)) {
for (i in 1:ncol(data)) {
assign(x = colnames(data)[i], value = data[, i])
}
n <- as.integer(nrow(data))
dump(c(colnames(data), "n"), "data.R")
} else if (is.list(data)) {
for (i in 1:length(data)) {
assign(x = names(data)[i], value = data[[i]])
}
dump(names(data), "data.R")
} else if (is.vector(data)) {
# この場合dataっていう名前がヤヴァい気もする
dump(c("data", "n"), "data.R")
} else if (exists(data)) {
# ファイル名指定の場合はファイルコピーな
file.copy(from = data, to = "data.R")
} else {
stop(paste(data, "must be data.frame or list.", sep = ""))
}
if (is.list(init)) {
# listの場合はdumpしてファイル名を渡す
for (i in 1:length(init)) {
assign(x = names(init)[i], value = init[[i]])
}
dump(names(init), "init.R")
cmd <- paste(cmd, ".", file.sep, prg, " --data=data.R --iter=", iter,
" --warmup=", warmup, " --thin=", thin, " --init=init.R", sep = "")
} else if (exists(init)) {
# ファイル名の場合はそのまま渡しちゃえ
cmd <- paste(cmd, ".", file.sep, prg, " --data=data.R --iter=", iter,
" --warmup=", warmup, " --thin=", thin, " --init=", init, sep = "")
} else if (init == 0 | init %in% c("random", "0")) {
# init無いよ
cmd <- paste(cmd, ".", file.sep, prg, " --data=data.R --iter=", iter,
" --warmup=", warmup, " --thin=", thin, sep = "")
} else {
# init無いことにしちゃえ
cmd <- paste(cmd, ".", file.sep, prg, " --data=data.R --iter=", iter,
" --warmup=", warmup, " --thin=", thin, sep = "")
}
# 内部関数定義
run <- function(chain) {
# seedとchainIdとcsv
cmd2 <- paste(cmd, " --seed=", seeds[chain], " --chain_id=", chain,
" --samples=", prg, ".chain", chain, ".csv", sep = "")
print(cmd2)
system(cmd2)
s <- read.csv(paste(prg, ".chain", chain, ".csv", sep = ""), comment.char = "#")
return(mcmc(s, start = warmup + 1, thin = thin))
}
samples <- as.mcmc.list(mclapply(1:chains, run, mc.cores = mc.cores))
file.remove(dir()[grep(paste("^", prg, ".chain\\d\\.csv$", sep = ""), dir())])
file.remove(paste(prg, c(".cpp", ".d", ".o", ""), sep = ""))
if ("data.R" %in% dir()) {
file.remove("data.R")
}
if ("init.R" %in% dir()) {
file.remove("init.R")
}
return(samples)
}
今回の例で使用するのは「1年毎のナイル川の流量データ」であるNile。
これを状態空間モデルの一種であるdlm(Dynamic Linear Model)でモデリングしてみる
dlmの定義式は以下の通り
\( y_t = F_t\theta_t + v_t, v_t \sim N(0,V_t) \)
\( \theta_t=G_t\theta_{t-1} + w_t, w_t \sim N(0,W_t) \)
data(Nile)
plot(Nile, main = "ナイル川とは川川")
nile.fit <- dlmMLE(y = Nile, parm = c(1, 1), build = function(x) {
dlmModPoly(1, dV = x[1], dW = x[2])
})
sprintf("最尤法で求めると、V=%.2f, W=%.2f", sqrt(nile.fit$par[1]),
sqrt(nile.fit$par[2]))
## [1] "最尤法で求めると、V=122.88, W=38.32"
data(Nile)
data.list <- list(Nile = as.vector(Nile), N = as.integer(length(Nile)), F = 1,
G = 1)
# stanファイル準備 sink関数でファイル出力する方法(Knit
# HTMLした時)がわからんかった
stan.str <- "data {\nint<lower=0> N;\nreal Nile[N];\nreal F;\nreal G;\n}\nparameters {\nreal theta[N];\nreal<lower=0> sigma_w;\nreal<lower=0> sigma_v;\n}\nmodel {\ntheta[1] ~ normal(Nile[1], sigma_w);\nNile[1] ~ normal(F * theta[1], sigma_v);\nfor (i in 2:N) {\ntheta[i] ~ normal(G * theta[i - 1], sigma_w);\nNile[i] ~ normal(F * theta[i], sigma_v);\n}\nsigma_w ~ uniform(0.0, 1.0e+4);\nsigma_v ~ uniform(0.0, 1.0e+4);\n}"
write.table(stan.str, file = "Nile.stan", col.names = FALSE, row.names = FALSE,
quote = FALSE)
# stanファイル書き出しここまで
# mcmcする(ついでに時間を計測)
system.time({
nile.stan <- hansen(file = "Nile.stan", data = data.list, iter = 20000,
thin = 3, stan.home = "~/downloads/stan-src-1.0.3")
})
## user system elapsed
## 79.139 0.708 46.842
# samplingをplot
plot(mcmc.list(lapply(nile.stan, function(x) {
x[, c("sigma_v", "sigma_w")]
})))
stan.vw <- apply(as.data.frame(lapply(nile.stan[, c("sigma_v", "sigma_w")],
function(x) {
apply(x, 2, median)
})), 1, mean)
sprintf("MCMCで求めると、V=%.2f, W=%.2f", stan.vw[1], stan.vw[2])
## [1] "MCMCで求めると、V=122.37, W=40.82"
x <- lapply(1:length(Nile), function(i) {
nile.stan[, paste("theta", i, sep = ".")]
})
plot(Nile, type = "o", las = 1)
lines(tsp(Nile)[1]:tsp(Nile)[2], sapply(1:length(Nile), function(i) {
mean(unlist(x[i]))
}), col = "red")
リファレンス