Abstract
わかりやすいパターン認識5章の勉強メモ。require(UsingR)
require(ggplot2)
require(tidyr)
require(dplyr)
require(plotly)
require(coefplot)
問題設定:\(c\)個の見分けの付かないコインが箱の中にたくさん入っている。これらのコインはそれぞれ表が出る確率が異なる。箱の中からコインを取り出し投げ、表裏を確認してからまた箱に戻すことを繰り返す。\(n\)回試行を繰り返す。箱の中のコインの割合\((\pi_1, \pi_2, \pi_3)\)と各種コインの表が出る確率\(\theta_1,\theta_2,\theta_3\)を最尤推定せよ。
真のパラメータを\((\pi_1, \pi_2, \pi_3)=(0.1,0.4, 0.5)\), \(\theta_1=0.8,\theta_2=0.6,\theta_3=0.3\)としてEMアルゴリズムで最尤推定をする。
# 問題設定
# set.seed(10)
n <- 10000
piv <- c(0.1, 0.4, 0.5)
theta <- c(0.8, 0.6, 0.3)
ni <- as.vector( rmultinom(1,n, piv) ) # 各コインが出る回数
n11 <- as.vector( rmultinom(1, ni[1], c(theta[1], 1 - theta[1]) ) )[1]
n21 <- as.vector( rmultinom(1, ni[2], c(theta[2], 1 - theta[2]) ) )[1]
n31 <- as.vector( rmultinom(1, ni[3], c(theta[3], 1 - theta[3]) ) )[1]
ni1 <- c(n11,n21,n31) # コインiが表を出す回数
r1 <- sum(ni1)
r2 <- n - r1
r <- c(r1, r2)
# EMアルゴリズムの初期値設定
tmp <- runif(3)
piinit <- tmp / sum(tmp)
thetainit <- runif(3)
Estep <- function(pi0, theta0){
gamma_ik <- array(1:6, dim=c(3,2))
for (i in 1:3) {
gamma_ik[i,1] <- pi0[i] * theta0[i] / sum(pi0 * theta0) # P(omega_i | head)
gamma_ik[i,2] <- pi0[i] * (1 - theta0[i] ) / sum(pi0 * (1 - theta0)) # P(omega_i | tail)
}
return(gamma_ik)
}
Mstep <- function(gamma_ik){
pinew <- c(1:3)
for (i in 1:3){
pinew[i] <- 1 / n * sum(r * gamma_ik[i, 1:2])
}
thetanew <- c(1:3)
for (i in 1:3){
thetanew[i] <- r[1] * gamma_ik[i, 1] / sum( r * gamma_ik[i,1:2] )
}
return(list(pinew=pinew, thetanew=thetanew))
}
niter <- 50
logPv <- 1:niter
dfpi <- data.frame(matrix(rep(NA,3),nrow=1))[-1,]
names(dfpi) <- c("iter", "coin", "pi")
dflogL <- data.frame(matrix(rep(NA,2),nrow=1))[-1,]
names(dflogL) <- c("iter", "logL")
piold <- piinit
thetaold <- thetainit
for (iter in 1:niter){
gamma_ik <- Estep(piold, thetaold)
Mresult <- Mstep(gamma_ik)
pinew <- Mresult$pinew
thetanew <- Mresult$thetanew
piold <- pinew
thetanew <- thetaold
# calc of likelihood
Phead <- sum(piold * thetaold)
Ptail <- sum(piold * (1 - thetaold))
logL <- r[1] * log(Phead) + r[2] * log(Ptail)
# update data.frame
dflogL <- rbind(dflogL, c(iter, logL))
dfpi <- rbind(dfpi, c(iter, 1, pinew[1]))
dfpi <- rbind(dfpi, c(iter, 2, pinew[2]))
dfpi <- rbind(dfpi, c(iter, 3, pinew[3]))
}
names(dfpi) <- c("iter", "coin", "pi")
dfpi$coin <- as.factor(dfpi$coin)
names(dflogL) <- c("iter", "logL")
logLanal <- sum(log(r/n) *r)
ggplot(dflogL) + geom_line(aes(x=iter, y=logL)) + geom_hline(yintercept = logLanal, linetype="dashed")
ggplot(dfpi) + geom_line(aes(x=iter, y=pi, group=coin, colour=coin)) +
geom_hline(yintercept = piv[1], linetype="dashed") + geom_hline(yintercept = piv[2], linetype="dashed" ) + geom_hline(yintercept = piv[3], linetype="dashed")
logLを最大にする解は無数にあるので、実行する度に収束する\(\pi_1,\pi_2,\pi_3\)は異なるが、毎回正しいlogLに収束する。なぜならばこの問題の対数尤度は\(\pi\)に関して上凸だからである。