# Stan で新しい値に対する予測区間を出す
######################################## パッケージの読み込み
library(rstan)
## Warning: package 'rstan' was built under R version 3.2.2
## Loading required package: Rcpp
## Warning: package 'Rcpp' was built under R version 3.2.2
## Loading required package: inline
##
## Attaching package: 'inline'
##
## The following object is masked from 'package:Rcpp':
##
## registerPlugin
##
## rstan (Version 2.7.0-1, packaged: 2015-07-17 18:12:01 UTC, GitRev: 05c3d0058b6a)
## For execution on a local, multicore CPU with excess RAM we recommend calling
## rstan_options(auto_write = TRUE)
## options(mc.cores = parallel::detectCores())
library(ggplot2)
## Warning: package 'ggplot2' was built under R version 3.2.2
library(dplyr)
##
## Attaching package: 'dplyr'
##
## The following objects are masked from 'package:stats':
##
## filter, lag
##
## The following objects are masked from 'package:base':
##
## intersect, setdiff, setequal, union
# Simulated training data ----
# y is linear in x (intercept 10, slope 0.8) plus Gaussian noise with sd 7.
# The fixed seed makes every draw below reproducible.
set.seed(seed = 314)
N <- 1000
x <- rnorm(n = N, mean = 50, sd = 10)
y <- 10 + 0.8 * x + rnorm(n = N, mean = 0, sd = 7)
ggplot(data = data.frame(x = x, y = y), mapping = aes(x = x, y = y)) +
  geom_point() +
  theme_bw()

# New x values at which predictions are wanted ----
x_new <- seq(from = 10, to = 90, by = 10)
print(x_new)
## [1] 10 20 30 40 50 60 70 80 90
# Stan code: estimate the parameters and predict at new x values ----
# Linear regression y[i] ~ normal(alpha + beta * x[i], s), plus posterior
# predictive draws y_pred at the new inputs x_new.
#
# y_pred is produced in `generated quantities` with normal_rng() instead of
# being declared as a parameter with a sampling statement in `model`. Both
# yield draws from the same posterior predictive distribution. The generated-
# quantities form adds no dimensions to the space NUTS has to explore, does not
# contribute to lp__, and gives an independent predictive draw per iteration.
# Comments inside the Stan code use `//` because `#` comments are deprecated in
# Stan. Assignment uses `<-` to stay compatible with the rstan 2.7 this script
# was run with (Stan >= 2.10 also accepts `=`).
predict_for_new_value.stan <- '
data {
  int<lower=0> N;
  real x[N];
  real y[N];
  int<lower=0> N_new;
  real x_new[N_new];
}
parameters {
  real alpha;
  real beta;
  real<lower=0> s;
}
model {
  for (i in 1:N)
    y[i] ~ normal(alpha + beta * x[i], s);
  alpha ~ normal(0, 100);
  beta ~ normal(0, 100);
  s ~ inv_gamma(0.001, 0.001);
}
generated quantities {
  real y_pred[N_new];  // posterior predictive draws at x_new
  for (i in 1:N_new)
    y_pred[i] <- normal_rng(alpha + beta * x_new[i], s);
}
'
# Fit the model with Stan: estimate the parameters and the predictions ----
datastan <- list(N = N, x = x, y = y, N_new = length(x_new), x_new = x_new)
war <- 4000   # warmup iterations per chain
ite <- 11000  # total iterations per chain, warmup included
see <- 12345  # RNG seed, for a reproducible fit
dig <- 3      # digits used when printing the fit summary
cha <- 2      # number of Markov chains
# All stan() arguments are written in full: the original `chain = cha` only
# reached `chains` through R's partial argument matching.
fit <- stan(model_code = predict_for_new_value.stan,
            data = datastan,
            iter = ite,
            seed = see,
            warmup = war,
            chains = cha)
## COMPILING THE C++ CODE FOR MODEL 'predict_for_new_value.stan' NOW.
##
## SAMPLING FOR MODEL 'predict_for_new_value.stan' NOW (CHAIN 1).
##
## Chain 1, Iteration: 1 / 11000 [ 0%] (Warmup)
## Chain 1, Iteration: 1100 / 11000 [ 10%] (Warmup)
## Chain 1, Iteration: 2200 / 11000 [ 20%] (Warmup)
## Chain 1, Iteration: 3300 / 11000 [ 30%] (Warmup)
## Chain 1, Iteration: 4001 / 11000 [ 36%] (Sampling)
## Chain 1, Iteration: 5100 / 11000 [ 46%] (Sampling)
## Chain 1, Iteration: 6200 / 11000 [ 56%] (Sampling)
## Chain 1, Iteration: 7300 / 11000 [ 66%] (Sampling)
## Chain 1, Iteration: 8400 / 11000 [ 76%] (Sampling)
## Chain 1, Iteration: 9500 / 11000 [ 86%] (Sampling)
## Chain 1, Iteration: 10600 / 11000 [ 96%] (Sampling)
## Chain 1, Iteration: 11000 / 11000 [100%] (Sampling)
## # Elapsed Time: 39.86 seconds (Warm-up)
## # 39.385 seconds (Sampling)
## # 79.245 seconds (Total)
##
##
## SAMPLING FOR MODEL 'predict_for_new_value.stan' NOW (CHAIN 2).
##
## Chain 2, Iteration: 1 / 11000 [ 0%] (Warmup)
## Chain 2, Iteration: 1100 / 11000 [ 10%] (Warmup)
## Chain 2, Iteration: 2200 / 11000 [ 20%] (Warmup)
## Chain 2, Iteration: 3300 / 11000 [ 30%] (Warmup)
## Chain 2, Iteration: 4001 / 11000 [ 36%] (Sampling)
## Chain 2, Iteration: 5100 / 11000 [ 46%] (Sampling)
## Chain 2, Iteration: 6200 / 11000 [ 56%] (Sampling)
## Chain 2, Iteration: 7300 / 11000 [ 66%] (Sampling)
## Chain 2, Iteration: 8400 / 11000 [ 76%] (Sampling)
## Chain 2, Iteration: 9500 / 11000 [ 86%] (Sampling)
## Chain 2, Iteration: 10600 / 11000 [ 96%] (Sampling)
## Chain 2, Iteration: 11000 / 11000 [100%] (Sampling)
## # Elapsed Time: 37.692 seconds (Warm-up)
## # 41.314 seconds (Sampling)
## # 79.006 seconds (Total)
# traceplot(fit, ask = FALSE)  # uncomment to inspect the chains visually
# Summary of the posterior (parameters, predictions, lp__). The argument is
# spelled `digits_summary`; the original `digit =` only reached it through
# R's partial argument matching.
print(fit, digits_summary = dig)
## Inference for Stan model: predict_for_new_value.stan.
## 2 chains, each with iter=11000; warmup=4000; thin=1;
## post-warmup draws per chain=7000, total post-warmup draws=14000.
##
## mean se_mean sd 2.5% 25% 50% 75%
## alpha 8.766 0.011 1.158 6.517 7.984 8.766 9.539
## beta 0.827 0.000 0.023 0.781 0.811 0.827 0.842
## s 7.027 0.001 0.158 6.726 6.920 7.025 7.133
## y_pred[1] 17.033 0.060 7.141 3.059 12.180 17.059 21.869
## y_pred[2] 25.271 0.060 7.139 11.256 20.476 25.225 30.096
## y_pred[3] 33.544 0.060 7.042 19.833 28.824 33.513 38.285
## y_pred[4] 41.809 0.060 7.101 27.824 37.000 41.840 46.603
## y_pred[5] 50.187 0.059 6.967 36.452 45.508 50.210 54.927
## y_pred[6] 58.286 0.061 7.173 43.950 53.530 58.264 63.147
## y_pred[7] 66.686 0.059 7.040 52.869 61.969 66.630 71.433
## y_pred[8] 74.836 0.060 7.054 60.811 70.085 74.839 79.580
## y_pred[9] 83.168 0.060 7.142 69.011 78.388 83.162 87.986
## lp__ -2471.575 0.034 2.513 -2477.454 -2473.018 -2471.205 -2469.751
## 97.5% n_eff Rhat
## alpha 11.059 10995 1
## beta 0.872 10968 1
## s 7.347 14000 1
## y_pred[1] 30.989 14000 1
## y_pred[2] 39.384 14000 1
## y_pred[3] 47.323 14000 1
## y_pred[4] 55.576 14000 1
## y_pred[5] 63.784 14000 1
## y_pred[6] 72.288 14000 1
## y_pred[7] 80.571 14000 1
## y_pred[8] 88.495 14000 1
## y_pred[9] 97.284 14000 1
## lp__ -2467.741 5517 1
##
## Samples were drawn using NUTS(diag_e) at Tue Sep 15 13:52:31 2015.
## For each parameter, n_eff is a crude measure of effective sample size,
## and Rhat is the potential scale reduction factor on split chains (at
## convergence, Rhat=1).
# Extract the posterior predictive draws ----
# One column per element of x_new; rows are the post-warmup draws returned
# by rstan::extract().
y_pred <- data.frame(rstan::extract(fit)$y_pred)
# Derive the labels from x_new instead of hard-coding nine values, so they
# stay correct (and `colnames<-` does not fail) whenever x_new changes.
colnames(y_pred) <- paste0("x_", x_new)
# Plot the predictive distributions ----
# (tidyr >= 1.0.0 supersedes gather() with pivot_longer().)
data <- tidyr::gather(y_pred, x, y, seq_len(ncol(y_pred)))
# Order the x axis by x_new rather than alphabetically (e.g. "x_100" would
# otherwise sort before "x_20").
data$x <- factor(data$x, levels = colnames(y_pred))
ggplot(data, aes(x = x, y = y)) + geom_violin()

# Predictive mean and central 95% interval for each new x ----
summarise(
  group_by(data, x),
  mean = mean(y),
  lower = quantile(y, probs = 0.025),
  upper = quantile(y, probs = 0.975)
)
## Source: local data frame [9 x 4]
##
## x mean lower upper
## 1 x_10 17.03271 3.058869 30.98868
## 2 x_20 25.27052 11.255884 39.38365
## 3 x_30 33.54408 19.833171 47.32305
## 4 x_40 41.80897 27.824020 55.57608
## 5 x_50 50.18654 36.451935 63.78367
## 6 x_60 58.28620 43.950372 72.28828
## 7 x_70 66.68622 52.868889 80.57103
## 8 x_80 74.83564 60.811189 88.49461
## 9 x_90 83.16821 69.010776 97.28356
# (Reference) prediction intervals from lm() ----
# Classical least-squares prediction intervals at x_new, for comparison.
predict(
  object = lm(formula = y ~ x),
  newdata = data.frame(x = x_new),
  interval = "prediction"
)
## fit lwr upr
## 1 17.03153 3.131567 30.93149
## 2 25.29911 11.449794 39.14843
## 3 33.56670 19.753380 47.38002
## 4 41.83428 28.042209 55.62636
## 5 50.10187 36.316215 63.88753
## 6 58.36946 44.575376 72.16354
## 7 66.63704 52.819719 80.45437
## 8 74.90463 61.049319 88.75994
## 9 83.17222 69.264297 97.08013
# Redraw x with sd = 1 instead of the sd = 10 draw used at the top, i.e.
# the previous setting was: x <- rnorm(N, mean = 50, sd = 10)
# No set.seed() here, so the draws continue from the current RNG state.
x <- rnorm(n = N, mean = 50, sd = 1)