ガンマ分布の最尤推定法について調べた時のメモです。

基本的には下記の2つの文献を元にしています。

ガンマ分布

ガンマ分布は、次の確率密度関数をもつ確率分布です。

\[ \begin{align} p(x|a,b) = \frac{x^{a-1}}{\Gamma(a) b^a} \exp(-\frac{x}{b}) & & (x > 0) \end{align} \]

パラメータは次の2つです。

\(n\)個のデータ \(D\) が与えられたとき、ガンマ分布の対数尤度関数は次のようになります。

\[ \begin{align} \log p(D|a,b) &= (a-1)\sum_i \log x_i - n \log \Gamma(a) - n a \log b - \frac{1}{b} \sum_i x_i \\ &= n(a-1)\overline{\log x} - n \log\Gamma(a) - na \log b - \frac{n\bar{x}}{b} \end{align} \] 最尤推定の問題は次のように表されます。

「データ \(D\) が与えられたとき、対数尤度関数が最大となる \(a, b\) を求めよ」

基本アルゴリズム

対数尤度関数を \(b\) で偏微分します。

\[ \begin{align} \frac{\partial \log p}{\partial b} &= - \frac{na}{b} - \frac{-n \bar{x}}{b^2}\\ &= - \frac{na}{b} + \frac{n \bar{x}}{b^2} \end{align} \]

これが \(0\) となる \(b\) は次のようになります。

\[ \begin{align} \frac{na}{b} - \frac{n \bar{x}}{b^2} &= 0 \\ nab &= n\bar{x} \\ b &= \frac{\bar{x}}{a} \end{align} \] これより、基本的な最尤推定アルゴリズムは、次のようになります。

  1. \(a\)\(b\) の初期値を決める
  2. 尤度の変化が十分小さくなるまで以下を繰り返す
    1. \(a \leftarrow \mathrm{argmax}_a \log(D | a, b)\)
    2. \(b \leftarrow \bar{x} / a\)

したがって、最尤推定の問題は、\(\hat{b} = \frac{\bar{x}}{a}\) とおいたときの対数尤度関数

\[ \begin{align} \log p(D|a,\hat{b}) &= n(a-1)\overline{\log x} - n \log\Gamma(a) - na \log \bar{x} + na \log a - na \end{align} \]

を最大にする \(a\) をどうやって求めるかという問題に帰着されます。

一次近似下限最大化法

この問題を解くための方法の1つに一次近似下限最大化法があります。

まず、対数尤度関数の \(na\log a\) という項に注目します。

\[ \begin{align} f(a) &= a\log a \\ f'(a) &= \log a + 1 \\ f''(a) &= 1/a \end{align} \]

\(a > 0\) なので \(f(x)\) は下に凸な関数になります。したがって、一次近似(接線)は \(f(a)\) の下限となります。

\[ \begin{align} f(a) &\geq (\log a_0 + 1)(a - a_0) + a_0 \log a_0 \\ &= a \log a_0 - a_0 \log a_0 + a - a_0 + a_0 \log a_0 \\ &= a \log a_0 + a - a_0 \\ \end{align} \]

これを使うと対数尤度関数の下限が次のように表されます。

\[ \begin{align} \log p(D|a,\hat{b}) &\geq n(a-1)\overline{\log x} - n \log\Gamma(a) - na \log \bar{x} + n(a \log a_0 + a - a_0) + na_0 \log a_0 - na \\ &= n(a-1)\overline{\log x} - n \log\Gamma(a) - na \log \bar{x} + na \log a_0 + na - na_0 + na_0 \log a_0 - na \\ &= n(a-1)\overline{\log x} - n \log\Gamma(a) - na \log \bar{x} + na \log a_0 - na_0 + na_0 \log a_0 \\ \end{align} \]

この下限の微分が \(0\) となる \(a\) は次のようになります。

\[ \begin{align} 0 &= n \overline{\log x} - n \Psi(a) - n\log \bar{x} + n \log a_0 \\ 0 &= \overline{\log x} - \Psi(a) - \log \bar{x} + \log a_0 \\ \Psi(a) &= \overline{\log x} - \log \bar{x} + \log a_0 \\ a &= \Psi^{-1} (\overline{\log x} - \log \bar{x} + \log a_0) \\ \end{align} \]

ただし、\(\Psi(a)\) はディガンマ関数です。

基本アルゴリズム中の \(a\) の更新式としてこの式を使うのが一次近似下限最大化法です。

局所近似最大化法

一次近似下限最大化法に対して、より高速な方法が Minka (2002) で提案されています。

まず、対数尤度関数を次のような局所近似式で表します。

\[ \log p(D|a,\hat{b}) \simeq c_0 + c_1 a + c_2 \log a \]

ただし、

\[ \begin{align} c_1 &= p'(a_0) - \frac{c_2}{a_0} \\ c_2 &= -a_0^2 p''(a_0) \end{align} \]

この局所近似式の詳細については Minka (2000) を参照して下さい。

この局所近似式の一階微分と二階微分は次のようになります。

\[ \begin{align} g(a) &= c_0 + c_1 a + c_2 \log(a) \\ g'(a) &= c_1 + \frac{c_2}{a} \\ g''(a) &= - \frac{c_2}{a^2} \end{align} \]

ここで、対数尤度関数は上に凸なので、\(p''(a) \leq 0\) となります。 したがって、\(c_2 \geq 0\) であるので、\(g(a)\) も上に凸であることがわかります。 \(g(a)\) を最大化する \(a\)\(g'(a) = 0\) となる \(a\) なので、

\[ \begin{align} c_1 + \frac{c_2}{a} &= 0 \\ a &= - \frac{c_2}{c_1} \\ \end{align} \]

として求められます。\(c_1\)\(c_2\) を展開すると、

\[ \begin{align} a &= - \frac{c_2}{c_1} \\ \frac{1}{a} &= -\frac{p'(a_0) - \frac{c_2}{a_0}}{c_2} \\ &= -\frac{p'(a_0)}{c_2} + \frac{1}{a_0} \\ &= -\frac{p'(a_0)}{-a_0^2 \ p''(a_0)} + \frac{1}{a_0} \\ &= \frac{1}{a_0} + \frac{p'(a_0)}{a_0^2 \ p''(a_0)} \\ \end{align} \]

となります。さらに \(p'(a)\)\(p''(a)\) は、

\[ \begin{align} p'(a) &= n\overline{\log x} -n \Psi(a) - n \log \bar{x} +n\log a \\ p''(a) &= -n \Psi'(a) + \frac{n}{a} \\ \end{align} \]

であるため、これを代入して

\[ \begin{align} \frac{1}{a} &= \frac{1}{a_0} + \frac{p'(a_0)}{a_0^2 \ p''(a_0)} \\ &= \frac{1}{a_0} + \frac{n\overline{\log x} -n \Psi(a_0) - n \log \bar{x} +n\log a_0}{a_0^2 ( -n \Psi'(a_0) + \frac{n}{a_0})} \\ &= \frac{1}{a_0} + \frac{\overline{\log x} - \Psi(a_0) - \log \bar{x} + \log a_0}{a_0^2(\frac{1}{a_0} - \Psi'(a_0))} \end{align} \]

となります。

基本アルゴリズム中の \(a\) の更新式としてこの式を使うのが局所近似最大化法です。

R による実装

R による実装は次のようになります。

library(distr)

estimate_gamma <- function(values, method = c("FOApprox", "LocalApprox")) {
  method <- match.arg(method)
  
  a <- 0.5 / (log(mean(values)) - mean(log(values)))
  b <- mean(values) / a
  
  epsilon <- Inf
  log_lik_old <- sum(log(dgamma(values, a, b)))
  while (epsilon > 1e-5) {
    mv <- mean(values)
    if (method == "FOApprox") {
      a <- igamma( mean(log(values)) - log(mean(values)) + log(a))
    } else {
      a <- 1/ (1/a + (mean(log(values)) - digamma(a) - log(mv) + log(a)) / (a^2 * (1/a - trigamma(a))))
    }
    b <- mv / a
    
    log_lik <- sum(log(dgamma(values, a, b)))
    epsilon <- abs(log_lik - log_lik_old)
    log_lik_old <- log_lik
  }
  c(a = a, b = b)
}

関数 estimate_gamma()method 引数を変えることで、一次近似下限最大化法 ("FOApprox") と局所近似最大化法 ("LocalApprox") を切り替えることができます。

\(a = 3, b = 2\) のガンマ分布からデータを生成して、パラメータの推定を行ってみます。

set.seed(314)
N <- 500
values <- rgamma(N, 3, scale = 2)

estimate_gamma(values, method = "FOApprox")
#>        a        b 
#> 2.939870 2.093131

estimate_gamma(values, method = "LocalApprox")
#>        a        b 
#> 2.939866 2.093133

うまく推定できたようです。

速度比較をやってみます。

library(microbenchmark)

microbenchmark(
  times = 30,
  FOApp = estimate_gamma(values, method = "FOApprox"),
  Local = estimate_gamma(values, method = "LocalApprox")
)
#> Unit: microseconds
#>   expr      min        lq       mean     median        uq       max neval
#>  FOApp 17523.28 18330.726 22440.4561 19970.5615 21480.818 89015.202    30
#>  Local   705.81   742.898   826.8821   779.8145   855.093  1473.443    30

局所近似最大化法は、一次近似下限最大化法よりも 25倍程度速いようです。

まとめ

ガンマ分布の最尤推定アルゴリズムとして、

の2つを導出しました。

また、Rによる実装を行い、後者が前者より 25倍高速であることを示しました。

後者が高速な理由として、Minka (2000) では、局所近似の方がよい近似であり、更新幅が大きいということが説明されています。

上図において、青線が目的関数、緑線が二次近似、赤線が局所近似ですが、赤線の方が緑線よりも目的関数に近い値をとっており、青丸から更新する際に赤線の最大値に更新する方が緑線の最大値に更新するよりも目的関数の最大値に速く到達するということです。

実際に使用する場合には \(a\)\(b\) の初期値を決める必要があります。 これについては、Minka (2002) に次の初期値が良いと書かれており、上記の実装ではこれを使用しています。

\[ \begin{align} a_0 &= \frac{0.5}{ \log \bar{x} - \overline{\log x}} \\ b_0 &= \frac{\bar{x}}{a_0} \end{align} \]

(注: イェンセンの不等式より \(\log \bar{x} \geq \overline{\log x}\))

参考文献