RからStanを並列で走らせてみた

R Advent Calendar 2012の29日目の記事です。世間のアドベントっぷりからすっかり乗り遅れてしまい不覚にもクリスマス後の29日の担当となってしまいました。。

今日の話題は、最近巷で話題のモンテカルロサンプラーであるStanをRから並列で動かせたらいいじゃね?という話です。Stanは従来のWinBUGSとかJAGSとかで記述に使うBUGS言語っぽいので記述したものを一度C++コードに変換してから実行するため高速であるわけですが、C++コードへの変換時のオーバーヘッドが大きそうなのは容易に想像が付きそうですね。だから、この変換は1度きりとしてchainの数だけ並列処理できたらハッピーなわけですよ。だからやってみました。

ストラディバリウス(ストラテジ)としてはRからStanを動かすための既存パッケージRStanをあえて使わずにStanプログラムそのものをRのSystem関数で呼び出す感じになってます。何故かというと、ぶっちゃけこのネタのそもそものスタートがCentOS5.4のgcc4.1.2でRStanがビルドできなかったからです。stanそのもののビルドはgcc44で可能だったのですが、R CMD INSTALL 時にgccじゃなくてgcc44を利用する方法が分からなくてしかたなく生Stanじゃ!で、どうせなら並列化もしちゃえといった次第です。

2012年12月29日現在、stanは1.1.0がリリースされていますが今回はあえて旧版の1.0.3を利用しました。理由はforループの処理速度が1.1.0から遅くなってしまっているからです。

また、RStan1.1.0のリリースにともなってドキュメントもちょっとだけ充実してrstanを並列で動かそうぜ的な記述も出現しちゃったりしてます。

環境構築とか(ここでは実際に実行はしない)

# CentOS 5.4だとデフォのgccが4.1.2だからビルドできない
# だから4.4系のgcc44をyumでインストールする
system("yum instal gcc44")
# stanのソースをダウンロードしてビルド
system("wget http://stan.googlecode.com/files/stan-src-1.0.3.tgz")
system("tar -xzf stan-src-1.0.3.tgz")
system("cd stan-src-1.0.3")
# makefileの編集
makefile <- scan("makefile", what = "character", sep = "\n")
# コンパイラの指定をg++からg++44へ変更する
makefile[grep("^CC =", makefile)] <- "CC = g++44"
write.table(makefile, file = "makefile", row.names = FALSE, col.names = FALSE, 
    quote = FALSE)
system("make bin/libstan.a")
system("make bin/stanc")

初期設定

library(coda, quietly = TRUE)
library(parallel, quietly = TRUE)
library(dlm, quietly = TRUE)

RからStanを使うための関数定義(Win機で動作検証してないけどたぶん動かない)

hansen <- function(file, data, init = "random", iter = 10000, warmup = 5000, 
    chains = 4, thin = 1, mc.cores = detectCores(logical = FALSE), stan.home = "~/stan-src-1.0.3", 
    seeds = 111:114) {
    # file--> *.stan file data-->
    # listかテキストファイル名(かdata.frameかvector) init-->
    # listかテキストファイル名
    file.sep <- as.character(ifelse(Sys.info()["sysname"] == "Windows", "\\", 
        "/"))
    prg <- strsplit(file, "\\.")[[1]][1]
    cur.dir <- getwd()
    setwd(stan.home)
    system(paste("make ", cur.dir, file.sep, prg, sep = ""))
    setwd(cur.dir)
    cmd <- ""
    if (length(seeds) < chains) {
        seeds <- c(apply(data.frame(n = (1:ceiling(chains/length(seeds)))), 
            1, function(x) {
                x * seeds
            }))[1:chains]
    }
    if (chains < mc.cores) {
        mc.cores <- chains
    }
    if (warmup > iter) {
        warmup <- as.integer(iter/2)
    }
    if (!("hansen.stan" %in% dir())) {
        stop(paste(file, " is not found.", sep = ""))
    }
    if (is.data.frame(data)) {
        for (i in 1:ncol(data)) {
            assign(x = colnames(data)[i], value = data[, i])
        }
        n <- as.integer(nrow(data))
        dump(c(colnames(data), "n"), "data.R")
    } else if (is.list(data)) {
        for (i in 1:length(data)) {
            assign(x = names(data)[i], value = data[[i]])
        }
        dump(names(data), "data.R")
    } else if (is.vector(data)) {
        # この場合dataっていう名前がヤヴァい気もする
        dump(c("data", "n"), "data.R")
    } else if (exists(data)) {
        # ファイル名指定の場合はファイルコピーな
        file.copy(from = data, to = "data.R")
    } else {
        stop(paste(data, "must be data.frame or list.", sep = ""))
    }
    if (is.list(init)) {
        # listの場合はdumpしてファイル名を渡す
        for (i in 1:length(init)) {
            assign(x = names(init)[i], value = init[[i]])
        }
        dump(names(init), "init.R")
        cmd <- paste(cmd, ".", file.sep, prg, " --data=data.R --iter=", iter, 
            " --warmup=", warmup, " --thin=", thin, " --init=init.R", sep = "")
    } else if (exists(init)) {
        # ファイル名の場合はそのまま渡しちゃえ
        cmd <- paste(cmd, ".", file.sep, prg, " --data=data.R --iter=", iter, 
            " --warmup=", warmup, " --thin=", thin, " --init=", init, sep = "")
    } else if (init == 0 | init %in% c("random", "0")) {
        # init無いよ
        cmd <- paste(cmd, ".", file.sep, prg, " --data=data.R --iter=", iter, 
            " --warmup=", warmup, " --thin=", thin, sep = "")
    } else {
        # init無いことにしちゃえ
        cmd <- paste(cmd, ".", file.sep, prg, " --data=data.R --iter=", iter, 
            " --warmup=", warmup, " --thin=", thin, sep = "")
    }
    # 内部関数定義
    run <- function(chain) {
        # seedとchainIdとcsv
        cmd2 <- paste(cmd, " --seed=", seeds[chain], " --chain_id=", chain, 
            " --samples=", prg, ".chain", chain, ".csv", sep = "")
        print(cmd2)
        system(cmd2)
        s <- read.csv(paste(prg, ".chain", chain, ".csv", sep = ""), comment.char = "#")
        return(mcmc(s, start = warmup + 1, thin = thin))
    }
    samples <- as.mcmc.list(mclapply(1:chains, run, mc.cores = mc.cores))
    file.remove(dir()[grep(paste("^", prg, ".chain\\d\\.csv$", sep = ""), dir())])
    file.remove(paste(prg, c(".cpp", ".d", ".o", ""), sep = ""))
    if ("data.R" %in% dir()) {
        file.remove("data.R")
    }
    if ("init.R" %in% dir()) {
        file.remove("init.R")
    }
    return(samples)
}

データセットについて

今回の例で使用するのは「1年毎のナイル川の流量データ」であるNile。 これを状態空間モデルの一種であるdlm(Dynamic Linear Model)でモデリングしてみる dlmの定義式は以下の通り
\( y_t = F_t\theta_t + v_t, v_t \sim N(0,V_t) \)
\( \theta_t=G_t\theta_{t-1} + w_t, w_t \sim N(0,W_t) \)

data(Nile)
plot(Nile, main = "ナイル川とは川川")

plot of chunk nile

nile.fit <- dlmMLE(y = Nile, parm = c(1, 1), build = function(x) {
    dlmModPoly(1, dV = x[1], dW = x[2])
})
sprintf("最尤法で求めると、V=%.2f, W=%.2f", sqrt(nile.fit$par[1]), 
    sqrt(nile.fit$par[2]))
## [1] "最尤法で求めると、V=122.88, W=38.32"

Rからstan使う

data(Nile)
data.list <- list(Nile = as.vector(Nile), N = as.integer(length(Nile)), F = 1, 
    G = 1)
# stanファイル準備 sink関数でファイル出力する方法(Knit
# HTMLした時)がわからんかった
stan.str <- "data {\nint<lower=0> N;\nreal Nile[N];\nreal F;\nreal G;\n}\nparameters {\nreal theta[N];\nreal<lower=0> sigma_w;\nreal<lower=0> sigma_v;\n}\nmodel {\ntheta[1] ~ normal(Nile[1], sigma_w);\nNile[1] ~ normal(F * theta[1], sigma_v);\nfor (i in 2:N) {\ntheta[i] ~ normal(G * theta[i - 1], sigma_w);\nNile[i] ~ normal(F * theta[i], sigma_v);\n}\nsigma_w ~ uniform(0.0, 1.0e+4);\nsigma_v ~ uniform(0.0, 1.0e+4);\n}"
write.table(stan.str, file = "Nile.stan", col.names = FALSE, row.names = FALSE, 
    quote = FALSE)
# stanファイル書き出しここまで

# mcmcする(ついでに時間を計測)
system.time({
    nile.stan <- hansen(file = "Nile.stan", data = data.list, iter = 20000, 
        thin = 3, stan.home = "~/downloads/stan-src-1.0.3")
})
##    user  system elapsed 
##  79.139   0.708  46.842
# samplingをplot
plot(mcmc.list(lapply(nile.stan, function(x) {
    x[, c("sigma_v", "sigma_w")]
})))

plot of chunk stan_hansen

stan.vw <- apply(as.data.frame(lapply(nile.stan[, c("sigma_v", "sigma_w")], 
    function(x) {
        apply(x, 2, median)
    })), 1, mean)
sprintf("MCMCで求めると、V=%.2f, W=%.2f", stan.vw[1], stan.vw[2])
## [1] "MCMCで求めると、V=122.37, W=40.82"

Nile観測データと推定結果のプロット

x <- lapply(1:length(Nile), function(i) {
    nile.stan[, paste("theta", i, sep = ".")]
})
plot(Nile, type = "o", las = 1)
lines(tsp(Nile)[1]:tsp(Nile)[2], sapply(1:length(Nile), function(i) {
    mean(unlist(x[i]))
}), col = "red")

plot of chunk plt_nile

リファレンス