このページのEMアルゴリズムについての実装は、以下の書籍にある実行例3.4を参考にPythonとJuliaで書き直したものです。ですので、問題設定や学習に利用した初期値はその記載内容に準拠しています。データは、書籍の手順に従ってR言語の乱数生成で得たデータをem.csvとして保存したものを利用しました。値の正確性については、一応R、Python、Juliaそれぞれによる計算結果が一致するところまで確認しています。なお、書籍の記載内容の原文ママになってしまうのは問題なので、Rによる実装はこのページには記載しておりません。
warningsの処理options(warn=-1)
options(digits=16)
import warnings
warnings.simplefilter('ignore')
.csv化# 定数
N <- 1000
pi0 <- 0.6
pi1 <- 1 - pi0
mu0 <- 3
mu1 <- 0
sig0 <- 0.5
sig1 <- 3
attr <- sample(0:1, N, replace = T, prob = c(pi0, pi1))
x <- rep(0, N)
x[which(attr == 0)] <- rnorm(length(which(attr == 0)), mu0, sig0)
x[which(attr == 1)] <- rnorm(length(which(attr == 1)), mu1, sig1)
data <- as.data.frame(x)
write.csv(data, "em.csv", row.names=FALSE)
Pythonによる実装# インポート
import numpy as np
import pandas as pd
from scipy.stats import norm
# 実験のために設定する初期値
N = 1000
pi0 = 0.5
pi1 = 1 - pi0
mu0 = 5
mu1 = -5
sig0 = 1.0
sig1 = 5
# データの読み込み
df = pd.read_csv("em.csv", dtype={"x": np.float64})
x = df.values
# 10回だけ実験
for ite in range(10):
piN0 = pi0 * norm.pdf(x, mu0, sig0)
piN1 = pi1 * norm.pdf(x, mu1, sig1)
qn0 = piN0 / (piN0 + piN1)
qn1 = piN1 / (piN0 + piN1)
pi0 = np.sum(qn0) / N
pi1 = np.sum(qn1) / N
mu0 = qn0.T @ x / (N * pi0)
mu1 = qn1.T @ x / (N * pi1)
sig0 = np.sqrt(np.sum(qn0 * (x - mu0) * (x - mu0)) / (N * pi0))
sig1 = np.sqrt(np.sum(qn1 * (x - mu1) * (x - mu1)) / (N * pi1))
print(pi0, pi1, mu0, mu1, sig0, sig1)
## 0.630393622123659 0.3696063778763408 [[2.98208769]] [[-0.17487378]] 0.48942260441041313 3.1314749321867934
Juliaによる実装やたらとglobalを使っている理由については、以下の記事を参照してほしいです。Rstudio上でJuliaを利用する場合Julia1.5以前のスコープの仕様のままなのでエラーになってしまうのであわてて付け加えたものになっています。あとで機会を見て直したいと思います。
# インポート
using CSV
using DataFrames
using Distributions
using LinearAlgebra
# 実験のために設定する初期値
N = 1000
## 1000
pi0 = 0.5
## 0.5
pi1 = 1 - pi0
## 0.5
mu0 = 5
## 5
mu1 = -5
## -5
sig0 = 1.0
## 1.0
sig1 = 5
## 5
# データの読み込み
df = CSV.read("em.csv", DataFrame)
## 1000×1 DataFrame
## Row │ x
## │ Float64
## ──────┼───────────
## 1 │ -0.516415
## 2 │ -1.35378
## 3 │ 3.2327
## 4 │ 2.8465
## 5 │ -5.45566
## 6 │ 3.01359
## 7 │ -1.24304
## 8 │ 3.56958
## ⋮ │ ⋮
## 994 │ 3.51346
## 995 │ -5.91
## 996 │ -3.54468
## 997 │ -3.10714
## 998 │ 2.40271
## 999 │ 2.48333
## 1000 │ 2.42434
## 985 rows omitted
x = df[:, :x];
for ite in 1:10
piN0 = pi0 * pdf.(Normal(mu0, sig0), x)
piN1 = pi1 * pdf.(Normal(mu1, sig1), x)
qn0 = piN0 ./ (piN0 .+ piN1)
qn1 = piN1 ./ (piN0 .+ piN1)
global pi0 = sum(qn0) / N
global pi1 = sum(qn1) / N
global mu0 = dot(qn0, x) / (N * pi0)
global mu1 = dot(qn1, x) / (N * pi1)
global sig0 = sqrt(sum(qn0 .* (x .- mu0).^2) / (N * pi0))
global sig1 = sqrt(sum(qn1 .* (x .- mu1).^2) / (N * pi1))
end
print(pi0, " ", pi1, " ", mu0, " ", mu1, " ", sig0, " ", sig1)
## 0.6303936221236591 0.3696063778763409 2.9820876873258224 -0.17487378059631653 0.489422604410413 3.131474932186793