{rstan} で「状態空間時系列分析入門」を再現したい。やっていることや数式はテキストを参照。

リポジトリ: https://github.com/sinhrks/stan-statespace

必要パッケージのインストール/読み込み

library(devtools)
# devtools::install_github('hoxo-m/pforeach')
# devtools::install_github('sinhrks/ggfortify')
library(rstan)
library(pforeach)
library(ggplot2)
ggplot2::theme_set(theme_bw(base_family="HiraKakuProN-W3"))
library(ggfortify)

関数定義

# モデルが収束しているか確認
is.converged <- function(stanfit) {
  summarized <- summary(stanfit)  
  all(summarized$summary[, 'Rhat'] < 1.1)
}

# 値がだいたい近いか確認
is.almost.fitted <- function(result, expected, tolerance = 0.001) {
  # Check that a fitted value is close to a reference value.
  #
  # result:    value(s) obtained from the fit.
  # expected:  reference value(s) to compare against.
  # tolerance: maximum allowed absolute difference (inclusive).
  #
  # Returns a single TRUE/FALSE. On failure the obtained value is printed to
  # help diagnose the mismatch. NA/NaN or empty input counts as "not fitted"
  # instead of raising an error inside `if`; vector input is accepted and
  # every element must be within tolerance.
  deviation <- abs(result - expected)
  fitted <- length(deviation) > 0 && !anyNA(deviation) &&
    all(deviation <= tolerance)
  if (!fitted) {
    print(paste('Result is ', result))
  }
  fitted
}

データの読み込み

# Monthly UK drivers KSI series (UKdriversKSI.txt) from January 1969,
# analysed on the log scale.
ukdrivers <- read.table('../data/UKdriversKSI.txt', skip = 1)
ukdrivers <- ts(ukdrivers[[1]], start = c(1969, 1), frequency = 12)
ukdrivers <- log(ukdrivers)

# Log petrol price on the same monthly grid. The first column is taken
# explicitly (as for ukdrivers) so the result is a univariate ts rather than a
# one-column ts matrix built from the whole data.frame.
ukpetrol <- read.table('../data/logUKpetrolprice.txt', skip = 1)
ukpetrol <- ts(ukpetrol[[1]], start = start(ukdrivers),
               frequency = frequency(ukdrivers))

# Seat belt law dummy: 0 before February 1983, 1 from February 1983 onwards.
# Derived from the start and length of ukdrivers instead of hard-coded counts,
# so it always has exactly as many observations as ukdrivers.
ukseats <- local({
  first <- start(ukdrivers)
  # Months from the first observation up to and including January 1983.
  n_before <- (1983 - first[1]) * frequency(ukdrivers) + (2 - first[2])
  as.numeric(seq_along(ukdrivers) > n_before)
})
ukseats <- ts(ukseats, start = start(ukdrivers), frequency = frequency(ukdrivers))

モデル定義

# Stan program used for Figures 7.2-7.4; print its source for reference.
# readLines() already returns a character vector, so no paste() is needed.
model_file <- '../models/fig07_02.stan'
cat(readLines(model_file), sep = '\n')
data {
  int<lower=1> n;
  vector[n] y;  // observed series (log drivers KSI in the R script)
  vector[n] x;  // regressor with fixed coefficient beta (log petrol price)
  vector[n] w;  // regressor with fixed coefficient lambda (seat belt law dummy)
}
parameters {
  // Stochastic level, bounded to within 3 sd of the mean of y
  vector<lower=mean(y)-3*sd(y), upper=mean(y)+3*sd(y)>[n] mu;
  // Stochastic seasonal component
  vector[n] seasonal;
  // Fixed regression coefficient on x
  real beta;
  // Fixed coefficient on w
  real lambda;
  // Level disturbance sd
  real<lower=0> sigma_level;
  // Seasonal disturbance sd
  real<lower=0> sigma_seas;
  // Irregular (observation) disturbance sd
  real<lower=0> sigma_irreg;
}
transformed parameters {
  // Level plus regression effects. The seasonal term is NOT included here;
  // it is added in the observation equation below.
  vector[n] yhat;
  yhat = mu + beta * x + lambda * w;
}
model {
  // Eq. 7.1

  // Seasonal (frequency = 12): each seasonal effect is centred on minus the
  // sum of the previous 11, so 12 consecutive effects sum to zero up to a
  // disturbance with sd sigma_seas.
  for (t in 12:n) {
    seasonal[t] ~ normal(-sum(seasonal[(t - 11):(t - 1)]), sigma_seas);
  }
  // Level: random walk.
  for (t in 2:n) {
    mu[t] ~ normal(mu[t - 1], sigma_level);
  }
  // Observation equation.
  y ~ normal(yhat + seasonal, sigma_irreg);

  // beta ~ normal(0, 2);
  // lambda ~ normal(mean(y) / mean(x), 2);
  sigma_level ~ inv_gamma(0.001, 0.001);
  sigma_seas ~ inv_gamma(0.001, 0.001);
  sigma_irreg ~ inv_gamma(0.001, 0.001);
}
y <- ukdrivers
x <- ukpetrol
w <- ukseats

# Data for the Stan model. The model declares y, x and w as vector[n], so all
# three series must have the same length.
stopifnot(length(x) == length(y), length(w) == length(y))
standata <- list(
  n = length(y),
  y = as.vector(y),
  x = as.vector(x),
  w = as.vector(w)
)

# Initial values for one chain. beta and lambda start at the OLS coefficients
# of y on the regressors the model actually uses (x and w), instead of a time
# trend slope and mean(y) / mean(x).
# NOTE(review): `init` is not passed to stan() in this script, so these values
# are currently unused; add `init = init` to the stan() call to use them.
lmresult <- lm(y ~ x + w, data = as.data.frame(standata[c('y', 'x', 'w')]))
init <- list(list(mu = rep(mean(y), length(y)), seasonal = rep(0, length(y)),
                  beta = coefficients(lmresult)[['x']],
                  lambda = coefficients(lmresult)[['w']],
                  sigma_level = sd(y) / 2, sigma_irreg = 0.001))

# Compile the model once without sampling (chains = 0). The runs below reuse
# the compiled model via `fit = stan_fit` instead of recompiling it.
stan_fit <- stan(file = model_file, chains = 0)
## 
## TRANSLATING MODEL 'fig07_02' FROM Stan CODE TO C++ CODE NOW.
## COMPILING THE C++ CODE FOR MODEL 'fig07_02' NOW.

# Four single-chain runs through pforeach (chain i uses seed i), merged into a
# single stanfit object by sflist2stanfit().
fit <- pforeach(i = seq_len(4), .final = sflist2stanfit)({
  stan(
    fit = stan_fit,
    data = standata,
    chains = 1,
    iter = 8000,
    seed = i
  )
})
# Abort unless every Rhat reported for the fit is below 1.1.
stopifnot(is.converged(fit))

# Posterior means averaged over all chains. The argument name `pars` is
# spelled out in full rather than relying on partial matching of `par`.
yhat <- get_posterior_mean(fit, pars = 'yhat')[, 'mean-all chains']
mu <- get_posterior_mean(fit, pars = 'mu')[, 'mean-all chains']
seasonal <- get_posterior_mean(fit, pars = 'seasonal')[, 'mean-all chains']
beta <- get_posterior_mean(fit, pars = 'beta')[, 'mean-all chains']
lambda <- get_posterior_mean(fit, pars = 'lambda')[, 'mean-all chains']
sigma_irreg <- get_posterior_mean(fit, pars = 'sigma_irreg')[, 'mean-all chains']
sigma_level <- get_posterior_mean(fit, pars = 'sigma_level')[, 'mean-all chains']
sigma_seas <- get_posterior_mean(fit, pars = 'sigma_seas')[, 'mean-all chains']

# Compare the squared posterior-mean standard deviations with reference
# variances (presumably the textbook's estimates for this model).
# NOTE(review): (E[sigma])^2 is not the posterior mean of the variance. Also,
# with is.almost.fitted()'s default absolute tolerance of 0.001, the
# sigma_seas^2 check passes for any value up to about 0.00116, so it is weak.
stopifnot(is.almost.fitted(sigma_irreg^2, 0.00378629))
stopifnot(is.almost.fitted(sigma_level^2, 0.000267632))
stopifnot(is.almost.fitted(sigma_seas^2, 0.0000011622))
# Figure 7.2: observed series with the fitted level plus regression effects
# (yhat = mu + beta * x + lambda * w) overlaid in blue. The English caption is
# assigned first and then replaced by the Japanese caption that is plotted.
title <- paste('Figure 7.2. Stochastic level plus variables',
               'log petrol price and seat belt law.', sep = '\n')
title <- paste('図 7.2 確率的レベルプラス対数石油価格と',
               'シートベルト法', sep = '\n')

yhat <- ts(yhat, start = start(y), frequency = frequency(y))
p <- autoplot(yhat, p = autoplot(y), ts.colour = 'blue')
p + ggtitle(title)

# Figure 7.3: posterior mean of the stochastic seasonal component, placed on
# the same time axis as y. The Japanese caption overrides the English one.
title <- 'Figure 7.3. Stochastic seasonal.'
title <- '図 7.3 確率的季節要素'
seasonal <- ts(seasonal, frequency = frequency(y), start = start(y))
autoplot(seasonal, ts.colour = 'blue') +
  ggtitle(title)

# Figure 7.4: irregular component. The model's observation equation is
# y ~ normal(yhat + seasonal, sigma_irreg), where yhat = mu + beta * x +
# lambda * w does not include the seasonal term. The irregular is therefore
# y - yhat - seasonal; plotting y - yhat would leave the seasonal pattern in.
title <- 'Figure 7.4. Irregular component for stochastic level and seasonal model.'
title <- '図 7.4 確率的レベルと季節モデルに対する不規則要素'
autoplot(y - yhat - seasonal, ts.linetype = 'dashed') + ggtitle(title)