Below are the derivatives of the Q() function with respect to the parameters of the multivariate normal. These are used for the EM algorithm gradient, which is used for the convergence criteria.
\(\frac{\partial{Q}}{\partial{\mu'}}=n[-\Sigma^{-1}(\mu'-x^*)]\)
\(\frac{\partial{Q}}{\partial\Sigma'}=-\frac{n}{2}[\Sigma'^{-1}(\Sigma'-A)\Sigma'^{-1}]\)
Where,
\(A=s^*-x^*\mu'^T+\mu'x^*+\mu'\mu'^T\)
The gradient of the log-likelihood will be
\(\frac{\partial{Q}}{\partial{\theta'}}|_{\theta'=\theta}=\nabla
l(\theta)\)
data = read.table('trivariatenormal.txt', h=T)
Expectation = function(mu, sig, data){
data = as.matrix(data)
n = nrow(data)
p = ncol(data)
xbar.star = matrix(0,nrow = p, ncol = 1)
s.star = matrix(0, nrow = p, ncol = p)
for(i in 1:n){
obs = which(!is.na(data[i,]))
miss = which(is.na(data[i,]))
mu.o = mu[obs] ## regular vector, no need for ","
mu.m = mu[miss]
sig.oo = matrix(sig[obs, obs], nrow =length(obs), ncol=length(obs))
sig.mm = matrix(sig[miss, miss], nrow=length(miss), ncol=length(miss))
sig.mo = matrix(sig[miss, obs], nrow=length(miss), ncol=length(obs))
sig.om = matrix(sig[obs, miss], nrow=length(obs), ncol=length(miss))
sig.ooI = solve(sig.oo)
yobs = as.matrix(data[i,obs]) ## observed data at current row
ymiss = as.matrix(data[i, miss]) ########### data type issue?
ym.star = mu.m + sig.mo %*% solve(sig.oo) %*% (yobs-mu.o)
e.yobs = yobs %*% t(yobs)
e.yoym = yobs %*% t(ym.star)
e.ymyo = ym.star %*% t(yobs)
e.ymym = sig.mm - sig.mo %*% sig.ooI %*% sig.om + ym.star %*% t(ym.star)
s.star.t = matrix(0, nrow = p, ncol = p)
xbar.star.t = matrix(0,nrow = p, ncol = 1)
s.star.t[obs,obs] = e.yobs
xbar.star[obs] = yobs ## Fully observed row
if(length(miss)>0){ ## There exists missing data in the current row.
mu.o = mu[obs] ## regular vector, no need for ","
mu.m = mu[miss]
sig.oo = matrix(sig[obs, obs], nrow =length(obs), ncol=length(obs))
sig.mm = matrix(sig[miss, miss], nrow=length(miss), ncol=length(miss))
sig.mo = matrix(sig[miss, obs], nrow=length(miss), ncol=length(obs))
sig.om = matrix(sig[obs, miss], nrow=length(obs), ncol=length(miss))
sig.ooI = solve(sig.oo)
s.star.t[miss,miss]= e.ymym
s.star.t[obs,miss] = e.yoym
s.star.t[miss,obs] = e.ymyo
ym.star = mu.m + sig.mo %*% sig.ooI %*% (yobs-mu.o)
xbar.star.t[miss] = ym.star
}
## combine xbar[obs] xbar[miss]
xbar.star = xbar.star + xbar.star.t
s.star = s.star + s.star.t
}
xbar.star = xbar.star/n
s.star = s.star/n
list(xbar.star=xbar.star, s.star=s.star)
}
mu = c(0,0,0)
sig = diag(3)
grad = function(mu, sig, data){
xbar.star = Expectation(mu,sig,data)$xbar.star
s.star = Expectation(mu,sig,data)$s.star
data = as.matrix(data)
n = nrow(data)
mu = as.matrix(mu)
sig = as.matrix(sig)
A = s.star - xbar.star %*% t(mu) + mu %*% t(xbar.star) + mu %*% t(mu)
dQdm = n*sum(diag(-solve(sig)%*%(mu-xbar.star))) ## sum(diag()) accounts for the trace in differential
dQds = -n/2 *sum(diag((solve(sig) %*%(sig-A)%*%solve(sig))))
logl = c(dQdm, dQds)
grad.norm=norm(logl, type='2')
list(logl=logl, grad.norm=grad.norm)
}
EM = function(mu, sig, data, maxit, tolgrad=1e-6){
data = as.matrix(data)
n = nrow(data)
p = ncol(data)
header = paste0("iteration", " mu1"," mu3"," sig[1,1]"," sig[1,3]"," grad norm")
print(header)
print(sprintf('%2.0f %8.8f %8.8f %8.8f %8.8f --',0, mu[1],mu[3],sig[1,1],sig[1,3]))
for(it in 1:maxit){
## E-step
xbar.star = Expectation(mu,sig,data)$xbar.star
s.star = Expectation(mu,sig,data)$s.star
grad.norm = grad(mu,sig,data)$grad.norm
logl = grad(mu,sig,data)$logl
## M-Step
A = s.star - xbar.star %*% t(mu) + mu %*% t(xbar.star) + mu %*% t(mu)
mu.hat = xbar.star
sig.hat = A
if (it == 1 | it == 2 | it== 3 | it == 19| it == 20 | it == 21) {
print(sprintf('%2.0f %8.8f %8.8f %8.8f %8.8f %.1e', it, mu.hat[1],mu.hat[3],sig.hat[1,1],sig.hat[1,3], grad.norm))
}
## Convergence criteria
if(grad.norm<tolgrad){
print("Converged")
break
}
## Update
mu = mu.hat
sig = sig.hat
}
print(mu)
print(sig)
}
EM(mu,sig, data, maxit=200)
## [1] "iteration mu1 mu3 sig[1,1] sig[1,3] grad norm"
## [1] " 0 0.00000000 0.00000000 1.00000000 0.00000000 --"
## [1] " 1 -0.00729167 0.14854167 1.77024375 5.09936667 1.9e+03"
## [1] " 2 -0.00729167 0.14854167 1.98061142 7.65996883 1.4e+01"
## [1] " 3 -0.00729167 0.14854167 2.01117656 8.22875621 2.2e+01"
## [1] "19 -0.00729167 0.14854167 2.02142393 8.47127107 2.2e-05"
## [1] "20 -0.00729167 0.14854167 2.02142394 8.47127116 8.5e-06"
## [1] "21 -0.00729167 0.14854167 2.02142394 8.47127120 3.3e-06"
## [1] "Converged"
## [,1]
## [1,] -0.007291667
## [2,] 0.047708333
## [3,] 0.148541667
## [,1] [,2] [,3]
## [1,] 2.021424 3.266909 8.471271
## [2,] 3.266909 8.732438 25.965679
## [3,] 8.471271 25.965679 82.528182