{rstan} で 「状態空間時系列分析入門」 を再現したい。やっていること / 数式はテキストを参照。
リポジトリ: https://github.com/sinhrks/stan-statespace
# devtools provides install_github(), used by the (commented-out) one-time
# installation commands below.
library(devtools)
# devtools::install_github('hoxo-m/pforeach')
# devtools::install_github('sinhrks/ggfortify')
# rstan: compiles and samples the Stan model.
library(rstan)
# pforeach: drives the per-chain stan() calls further down.
library(pforeach)
library(ggplot2)
# Use the black-and-white theme for every plot; the font family is presumably
# chosen so that the Japanese plot titles render -- confirm on non-macOS hosts.
ggplot2::theme_set(theme_bw(base_family="HiraKakuProN-W3"))
# ggfortify: supplies autoplot() for the ts objects plotted at the end.
library(ggfortify)
# Check whether the model has converged.
#
# Args:
#   stanfit: an object whose summary() result carries a `summary` matrix
#            with an 'Rhat' column.
#
# Returns:
#   TRUE only when every Rhat value is below 1.1. FALSE otherwise, including
#   when any Rhat is NA/NaN, so the result is always a proper TRUE/FALSE
#   rather than NA.
is.converged <- function(stanfit) {
  rhat <- summary(stanfit)$summary[, 'Rhat']
  # all() yields NA when a missing Rhat is mixed with passing values;
  # isTRUE() collapses that case to FALSE.
  isTRUE(all(rhat < 1.1))
}
# Check whether a value is approximately equal to the expected one.
#
# Args:
#   result:    numeric value(s) obtained from the fit.
#   expected:  numeric reference value(s).
#   tolerance: maximum allowed absolute difference (default 0.001).
#
# Returns:
#   TRUE when every |result - expected| is within `tolerance`. Otherwise the
#   result is printed and FALSE is returned. Missing or empty input is
#   treated as "not fitted" instead of raising an error inside `if`.
is.almost.fitted <- function(result, expected, tolerance = 0.001) {
  deviation <- abs(result - expected)
  fitted <- length(deviation) > 0 &&
    !anyNA(deviation) &&
    all(deviation <= tolerance)
  if (!fitted) {
    print(paste('Result is ', result))
  }
  fitted
}
# Quarterly series beginning in 1950 Q1, taken from the first column of the
# data file (its first line is skipped).
ukinflation <- ts(
  read.table('../data/UKinflation.txt', skip = 1)[[1]],
  start = c(1950, 1),
  frequency = 4
)
# Pulse intervention variable: zero everywhere except 1975 Q2 and 1979 Q3.
ukpulse <- numeric(length(ukinflation))
ukpulse[c(4 * (1975 - 1950) + 2, 4 * (1979 - 1950) + 3)] <- 1
ukpulse <- ts(
  ukpulse,
  start = start(ukinflation),
  frequency = frequency(ukinflation)
)
# Echo the Stan model source so it appears in the rendered document.
model_file <- '../models/fig07_04.stan'
cat(readLines(model_file), sep = '\n')
// Local level model with a stochastic quarterly seasonal component and a
// regression (intervention) term. Uses `=` assignment and `//` comments:
// the older `<-` and `#` forms are deprecated and rejected by current Stan.
data {
  // Number of observations
  int<lower=1> n;
  // Observed series
  vector[n] y;
  // Intervention regressor (the calling R script passes a 0/1 pulse series)
  vector[n] w;
}
parameters {
  // Stochastic level
  vector[n] mu;
  // Stochastic seasonal component
  vector[n] seasonal;
  // Deterministic coefficient of the intervention regressor
  real lambda;
  // Standard deviation of the level disturbance
  real<lower=0> sigma_level;
  // Standard deviation of the seasonal disturbance
  real<lower=0> sigma_seas;
  // Standard deviation of the observation (irregular) disturbance
  real<lower=0> sigma_irreg;
}
transformed parameters {
  // Level plus intervention effect; the seasonal term is added separately
  // in the observation equation below.
  vector[n] yhat;
  for (t in 1:n) {
    yhat[t] = mu[t] + lambda * w[t];
  }
}
model {
  // Equation 7.1
  // frequency = 4: each seasonal value is centred on minus the sum of the
  // previous three.
  for (t in 4:n) {
    seasonal[t] ~ normal(-seasonal[t-3] - seasonal[t-2] - seasonal[t-1], sigma_seas);
  }
  // Random-walk level
  for (t in 2:n) {
    mu[t] ~ normal(mu[t-1], sigma_level);
  }
  // Observation equation
  for (t in 1:n) {
    y[t] ~ normal(yhat[t] + seasonal[t], sigma_irreg);
  }
}
# Series handed to the model: observations and the pulse regressor.
y <- ukinflation
w <- ukpulse
# Named list matching the model's data block (n, y, w); as.vector() strips
# the ts attributes so plain numeric vectors are passed.
standata <- list(
  n = length(y),
  w = as.vector(w),
  y = as.vector(y)
)
# chains = 0: build the model object only; it is reused through `fit =` in
# the sampling calls later on.
stan_fit <- stan(chains = 0, file = model_file)
##
## TRANSLATING MODEL 'fig07_04' FROM Stan CODE TO C++ CODE NOW.
## COMPILING THE C++ CODE FOR MODEL 'fig07_04' NOW.
# Run four single-chain fits (seeds 1 to 4) that reuse the already compiled
# model, then combine them into one stanfit object via sflist2stanfit.
fit <- pforeach(chain_seed = 1:4, .final = sflist2stanfit)({
  stan(
    fit = stan_fit,
    data = standata,
    chains = 1,
    iter = 2000,
    seed = chain_seed
  )
})
# Stop here unless the combined fit passes the convergence check.
stopifnot(is.converged(fit))
# Posterior mean of one parameter, pooled over all chains.
# The argument is spelled `pars` in full rather than relying on partial
# matching of `par`.
posterior_mean <- function(pars) {
  get_posterior_mean(fit, pars = pars)[, 'mean-all chains']
}
yhat <- posterior_mean('yhat')
mu <- posterior_mean('mu')
seasonal <- posterior_mean('seasonal')
lambda <- posterior_mean('lambda')
sigma_irreg <- posterior_mean('sigma_irreg')
sigma_level <- posterior_mean('sigma_level')
sigma_seas <- posterior_mean('sigma_seas')
# Compare the squared disturbance standard deviations with reference
# variances (presumably the textbook estimates -- confirm).
# NOTE(review): the default tolerance of 0.001 is much larger than these
# ~1e-5 reference values, so the checks only catch gross misfits.
stopifnot(is.almost.fitted(sigma_irreg^2, 2.1990e-5))
stopifnot(is.almost.fitted(sigma_level^2, 1.8595e-5))
stopifnot(is.almost.fitted(sigma_seas^2, 0.0110e-5))
# Figure 7.7.1: observed series overlaid with the fitted level plus pulse
# intervention effect (blue).
# The English title used to be assigned and then immediately overwritten; it
# is kept here for reference only:
#   'Figure 7.7.1. Local level (including pulse interventions)
#    for UK inflation time series data.'
title <- paste('図 7.7.1 英国インフレーション時系列データに',
               '対するローカル・レベル(含むパルス干渉変数)', sep = '\n')
# Give the posterior means the same time index as the observations.
yhat <- ts(yhat, start = start(y), frequency = frequency(y))
p <- autoplot(y)
p <- autoplot(yhat, p = p, ts.colour = 'blue')
p + ggtitle(title)
# Figure 7.7.2: posterior mean of the seasonal component.
# The English title used to be assigned and then immediately overwritten; it
# is kept here for reference only:
#   'Figure 7.7.2. Local seasonal for UK inflation time series data.'
title <- '図 7.7.2 ローカル季節'
# Give the posterior means the same time index as the observations.
seasonal <- ts(seasonal, start = start(y), frequency = frequency(y))
autoplot(seasonal, ts.colour = 'blue') + ggtitle(title)
# Figure 7.7.3: irregular component.
# The English title used to be assigned and then immediately overwritten; it
# is kept here for reference only:
#   'Figure 7.7.3. Irregular for UK inflation time series data.'
title <- '図 7.7.3 不規則要素'
# The model's observation mean is yhat + seasonal, so the seasonal component
# has to be subtracted as well; y - yhat alone still contains it.
autoplot(y - yhat - seasonal, ts.linetype = 'dashed') + ggtitle(title)