In this report, \(D\) denotes the observed part of the tree (the extant species), \(+_i\) the missing part (the extinct species), and \(D^+\) the complete phylogenetic tree.
The EM algorithm consists of two steps. First, we calculate the conditional expectation:
\[ Q(\theta|\theta^*) = E_{\theta^*} [\log P(D^+|\theta) | D] \]
and then we perform the maximization:
\[ \theta^{**} = \arg\max_{\theta} Q(\theta | \theta^*) \]
Because the conditional expectation is very hard (if not impossible) to calculate exactly, we approximate it by sampling complete phylogenies with a Monte Carlo approach. These simulations should be sampled from the true density
\[ f_{\theta^*} (+_i | D) \]
but instead we sample them from
\[ g_{\theta^*}(+_i | D) \]
To correct for this, we re-weight the approximation of the expectation by importance sampling:
\[w_i = \frac{f_{\theta^*} (+_i | D)}{g_{\theta^*}(+_i | D)} = \frac{f_i}{g_i}\]
Thus, the re-weighted Monte Carlo approximation has the form
\[ E_{\theta^*} [\log P(D^+|\theta) | D] \approx \frac{1}{N} \sum^{N}_{i=1} \log P(D_i^+ | \theta) \frac{f_i}{g_i}\]
\(f_i\) is the density function (likelihood) of the complete phylogenetic tree.
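As a toy illustration of the re-weighted approximation (not the phylogenetic model itself), the sketch below assumes a hypothetical target density \(f\) = Exp(2) and a proposal \(g\) = Exp(1). It uses the weights \(f_i/g_i\) to estimate an expectation under \(f\) and then to maximise a weighted log-likelihood, mimicking the two EM steps. All names (`rate_f`, `rate_g`, `q_hat`) are illustrative assumptions.

```r
set.seed(1)
rate_f <- 2    # hypothetical "true" density f: Exp(2)
rate_g <- 1    # hypothetical proposal density g: Exp(1)
N <- 1e5

x <- rexp(N, rate = rate_g)   # draws come from g, not from f
log_w <- dexp(x, rate_f, log = TRUE) - dexp(x, rate_g, log = TRUE)
w <- exp(log_w)               # importance weights f_i / g_i

# Re-weighted Monte Carlo estimate of E_f[X]; the exact value is 1 / rate_f
est_mean <- mean(x * w)

# Weighted analogue of Q(theta): (1/N) * sum_i w_i * log p(x_i | theta)
q_hat <- function(theta) mean(w * dexp(x, rate = theta, log = TRUE))

# M-step: maximise the weighted log-likelihood over theta
theta_hat <- optimize(q_hat, interval = c(0.01, 10), maximum = TRUE)$maximum

print(c(est_mean = est_mean, theta_hat = theta_hat))
```

In this toy case the weights are bounded (at most 2), so the estimate is stable; with a poorly matched proposal a few weights can dominate the sum.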
To calculate \(g_i\) we have two ways so far. The first follows the diagram above: we multiply the probabilities of each outcome, as done in the piece of code below.
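if (t_spe < cwt){
  # t_spe < cwt: draw an extinction waiting time and shift it by cbt + t_spe
  t_ext = rexp(1,mu0)
  t_ext = cbt + t_spe + t_ext
  if (t_ext < ct){
    # t_ext < ct: add the new species to the tree and store its log-probability
    up = update_tree(wt=wt,t_spe = (cbt + t_spe), t_ext = t_ext, E = E, n = n)
    E = up$E
    n = up$n
    wt = up$wt
    fake = FALSE
    prob[i] = dexp(t_ext,rate=mu0,log=TRUE) + dexp(t_spe,rate=s,log = TRUE)
  }else{
    # t_ext >= ct: the proposal is flagged as fake and the index is stepped back
    prob[i] = pexp(q = ct,rate = mu0,lower.tail = F,log.p = TRUE) + dexp(t_spe,rate=s, log = TRUE)
    fake = TRUE
    i = i-1
  }
}else{
  # t_spe >= cwt: log-probability of no speciation before cwt
  fake = FALSE
  prob[i] = pexp(q = cwt,rate = s,lower.tail = F,log.p = TRUE)
}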
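The log-probabilities in the code above rely on two base R facts: `dexp(x, rate, log = TRUE)` returns the log-density \(\log(\text{rate}) - \text{rate}\cdot x\), and `pexp(q, rate, lower.tail = FALSE, log.p = TRUE)` returns the log of the survival probability, \(-\text{rate}\cdot q\). A small check, with arbitrary values chosen only for illustration:

```r
rate <- 0.4   # illustrative rate
x <- 1.7      # illustrative waiting time

log_dens <- dexp(x, rate = rate, log = TRUE)
log_surv <- pexp(q = x, rate = rate, lower.tail = FALSE, log.p = TRUE)

# Both agree with the closed-form expressions
stopifnot(isTRUE(all.equal(log_dens, log(rate) - rate * x)))
stopifnot(isTRUE(all.equal(log_surv, -rate * x)))

# Adding log-probabilities corresponds to multiplying the probabilities
log_joint <- log_dens + log_surv
stopifnot(isTRUE(all.equal(exp(log_joint), dexp(x, rate) * exp(-rate * x))))
```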
Note that we are multiplying many densities, so the resulting products are very small values. To avoid underflow to zero, we use the following identity:
\[ w_i = \frac{f_i}{g_i} = e^{\log(f_i) - \log(g_i)} = e^{\mathrm{loglik} - \sum_j \log(g_{i,j})} \]
That is why the code above calculates the log of the densities. With this in mind, we calculate the (log) weight in the following way:
f_n = -llik(b=pars,n=n,E=E,t=wt)
logweight = f_n-sum(prob)
return(list(wt=wt,E=E,n=n,weight=logweight,L=L,g=prob,f_n=f_n))
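To see why the weights are kept on the log scale, here is a minimal sketch with made-up densities: the product of many small densities underflows to zero in double precision, while the sum of their logs stays usable. The values of `dens` and `loglik` below are invented for illustration only.

```r
set.seed(42)
# 2000 hypothetical per-event densities, each fairly small
dens <- runif(2000, min = 0.01, max = 0.5)

direct_product <- prod(dens)   # underflows to 0 in double precision
log_g <- sum(log(dens))        # finite, large negative number

# Hypothetical log-likelihood of the complete tree
loglik <- log_g - 3

log_weight <- loglik - log_g   # log(f/g), computed without forming f or g
weight <- exp(log_weight)      # equals exp(-3) up to rounding

print(c(direct_product = direct_product, log_g = log_g, weight = weight))
```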
Another way to calculate \(g\) is to consider the probability of not having a new species on the interval \(\Delta t_i\):
\[ \int_0^{\Delta t_i} s_{\lambda_i}e^{-s_{\lambda_i} t}\left[1-e^{-\mu (r_i-t)}\right] dt = s_{\lambda_i}\int_0^{\Delta t_i}\left[e^{-s_{\lambda_i} t}-e^{ t(\mu-s_{\lambda_i})-\mu r_i}\right] dt = s_{\lambda_i}\int_0^{\Delta t_i}e^{-s_{\lambda_i} t}dt - s_{\lambda_i} e^{-\mu r_i} \int_0^{\Delta t_i}e^{ t(\mu-s_{\lambda_i})} dt \]
\[= 1-e^{-s_{\lambda_i} \Delta t_i} -\frac{s_{\lambda_i} e^{-\mu r_i}}{\mu - s_{\lambda_i}}\left[e^{\Delta t_i(\mu-s_{\lambda_i})}-1\right] \]
This is implemented in the following function:
convol <- function(wt,lambda,mu,remt){
  out = 1-exp(-lambda*wt)-lambda*exp(-mu*remt)/(mu-lambda)*(exp(wt*(mu-lambda))-1)
  return(out)
}
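As a sanity check on the closed form, the sketch below compares it against numerical integration of the original integrand using base R's `integrate`. The function is restated under the hypothetical name `convol_closed` so that the block is self-contained, and the parameter values are arbitrary choices for illustration.

```r
# Closed form, restated from the text
convol_closed <- function(wt, lambda, mu, remt) {
  1 - exp(-lambda * wt) -
    lambda * exp(-mu * remt) / (mu - lambda) * (exp(wt * (mu - lambda)) - 1)
}

# Integrand: s * exp(-s t) * (1 - exp(-mu (r - t)))
integrand <- function(t, lambda, mu, remt) {
  lambda * exp(-lambda * t) * (1 - exp(-mu * (remt - t)))
}

# Illustrative values with wt <= remt
wt <- 0.7; lambda <- 1.3; mu <- 0.4; remt <- 2.5

numeric_val <- integrate(integrand, lower = 0, upper = wt,
                         lambda = lambda, mu = mu, remt = remt)$value
closed_val <- convol_closed(wt, lambda, mu, remt)

print(c(numeric = numeric_val, closed = closed_val))
stopifnot(abs(numeric_val - closed_val) < 1e-6)
```

Note that the closed form divides by `mu - lambda`, so it cannot be evaluated as written when the two rates are equal.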
We add this option to the code through the flag `v2`:
if (t_spe < cwt){
  t_ext = rexp(1,mu0)
  t_ext = cbt + t_spe + t_ext
  if (t_ext < ct){
    up = update_tree(wt=wt,t_spe = (cbt + t_spe), t_ext = t_ext, E = E, n = n)
    E = up$E
    n = up$n
    wt = up$wt
    fake = FALSE
    prob[i] = dexp(t_ext,rate=mu0,log=TRUE) + dexp(t_spe,rate=s,log = TRUE)
  }else{
    prob[i] = pexp(q = ct,rate = mu0,lower.tail = F,log.p = TRUE) + dexp(t_spe,rate=s,log = TRUE)
    # with v2, the log-probability is replaced by the closed form from convol()
    if(v2){ prob[i] = log(convol(wt = t_spe,lambda = s,mu = mu,remt = ct-cbt))}
    fake = TRUE
    i = i-1
  }
}else{
  fake = FALSE
  prob[i] = pexp(q = cwt,rate = s,lower.tail = F,log.p = TRUE)
  # with v2, the log-probability is replaced by the closed form from convol()
  if(v2){prob[i] = log(convol(wt = t_spe,lambda = s,mu = mu,remt = ct-cbt))}
}
To get a better idea of how this implementation behaves, we make some simple observations.
We start by simulating a complete tree:
s <- sim_phyl()
pars = c(1.2,0.4,70)
subplex(par = pars,fn = llik,n = s$n, E = s$E, t = s$wt)
## $par
## [1] 0.5808922 0.1227656 38.6221800
##
## $value
## [1] 209.8768
##
## $counts
## [1] 598
##
## $convergence
## [1] 1
##
## $message
## [1] "limit of machine precision reached"
##
## $hessian
## NULL
Then we drop the extinct species:
s2 <- drop.fossil(s$newick)
We transform the phylo format into branching times, number of species, and topology vectors:
s2 <- phylo2p(s2)
Now we simulate a complete tree (extinct + extant species) based on the observed tree (extant species only):
s3 <- rec_tree(wt=s2$wt)
and we look at the (log) weight calculated for this tree:
s3$weight
## [1] -133.4394
It seems very small.
Since this is a random process, we now repeat it for many iterations to get an idea of the variability of the weight:
l = proc.time()
w=0
nLTT=0
n_it = 100000
st=sim_srt(wt=s2$wt,pars=c(0.8,0.0175,0.1),n_trees = n_it,parallel = TRUE)
for(i in 1:n_it){
  st3 = st[[i]]
  w[i] = st3$weight
  rec = p2phylo(st3)
  nLTT[i] = ltt_stat(s$newick, rec)
}
print(proc.time()-l)
## user system elapsed
## 16950.836 76.196 19992.056
su=summary(w)
su
##    Min. 1st Qu.  Median    Mean 3rd Qu.    Max.
##    -Inf -194.90 -175.60    -Inf -156.60  -90.27
boxplot(w)
q3 = w>su[5]
boxplot(w[q3])
These values seem very small for log-weights, so we centre them by subtracting the maximum to avoid numerical issues:
maxw = max(w)
w = w - max(w)
boxplot(w)
## Warning in bplt(at[i], wid = width[i], stats = z$stats[, i], out = z$out[z
## $group == : Outlier (-Inf) in boxplot 1 is not drawn
points(rep(1,length(w)),w)
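The centring step can be illustrated on a small made-up vector of log-weights, including a `-Inf` entry like the ones reported by `summary(w)` above. The numbers are invented; only the mechanics are the point.

```r
# Hypothetical log-weights; -Inf corresponds to a weight of exactly zero
log_w <- c(-194.9, -175.6, -Inf, -156.6, -90.27)

raw_w <- exp(log_w)             # tiny values, hard to compare on the raw scale

centred <- log_w - max(log_w)   # the largest log-weight becomes 0
rel_w <- exp(centred)           # relative weights in [0, 1]

print(rel_w)
best <- which(rel_w == 1)       # index of the entry with the largest weight
print(best)
```

After centring, the largest weight is exactly 1, which is what makes a lookup such as `which(w == 1)` possible.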
Now we would like to check whether trees with larger weights are really "better trees". To do that, we look at the nLTT statistic:
qplot(w,nLTT)
qplot(w[q3],nLTT[q3])
Now let's check the actual weights, that is, the exponential of the centred log-weights:
w = exp(w)
qplot(w,nLTT)
phyl = st[[which(w==1)]]
pars = c(1.2,0.4,70)
p = proc.time()
subplex(par = pars, fn = llik_st, setoftrees = list(phyl), impsam = F)$par
## [1] 0.41878142 0.06194581 45.45618744
proc.time()-p
## user system elapsed
## 0.024 0.000 0.023
ltt_stat(p2phylo(phyl),s$newick)
## [1] 35.76605
p = proc.time()
subplex(par = pars, fn = llik_st, setoftrees = st, impsam = F)$par
## [1] 0.5894714 0.1370194 36.8428328
proc.time()-p
## user system elapsed
## 8744.432 28.244 8776.124
p=proc.time()
subplex(par = pars, fn = llik_st, setoftrees = st, impsam = T)$par
## [1] 0.44339385 0.06755712 41.77933148
proc.time() - p
## user system elapsed
## 6833.556 31.768 6867.890
print(which(w > 0.1))
## [1] 45174 89449
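Only two of the 100000 sampled trees have a relative weight above 0.1. One common way to summarise this kind of concentration is the effective sample size of the weights, \((\sum_i w_i)^2 / \sum_i w_i^2\). It is not part of the analysis in this report and is added here only as a sketch; the helper name `ess` and the log-weights below are invented for illustration.

```r
# Effective sample size computed from log-weights
ess <- function(log_w) {
  w <- exp(log_w - max(log_w))   # centre first to avoid underflow
  sum(w)^2 / sum(w^2)
}

# Equal weights: the ESS equals the number of samples
ess_equal <- ess(rep(-150, 1000))

# One dominant weight: the ESS is close to 1
ess_dominated <- ess(c(-90, rep(-150, 999)))

print(c(ess_equal = ess_equal, ess_dominated = ess_dominated))
```

An ESS close to 1 would indicate that the weighted estimate is effectively driven by a single tree.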
tr1 = which(w > 0.1)[[1]]
print(tr1)
## [1] 45174
phyl2 = st[[tr1]]
p=proc.time()
subplex(par = pars, fn = llik_st, setoftrees = list(phyl2), impsam = F)$par
## [1] 0.5023233 0.0772496 36.2343960
proc.time() - p
## user system elapsed
## 0.020 0.000 0.019
tr2 = which(w > 0.1)[[2]]
print(tr2)
## [1] 89449
phyl2 = st[[tr2]]
p=proc.time()
subplex(par = pars, fn = llik_st, setoftrees = list(phyl2), impsam = F)$par
## [1] 0.41878142 0.06194581 45.45618744
proc.time() - p
## user system elapsed
## 0.020 0.000 0.019
tr = lapply(st, function(x) exp(x$weight-maxw) > 0.1)
phyl2 = st[unlist(tr)]
#phyl2
p=proc.time()
subplex(par = pars, fn = llik_st, setoftrees = phyl2, impsam = F)$par
## [1] 0.45654104 0.06993942 40.24698027
proc.time() - p
## user system elapsed
## 0.044 0.000 0.045
p=proc.time()
subplex(par = pars, fn = llik_st, setoftrees = phyl2, impsam = T)$par
## [1] 0.44301440 0.06729018 41.85387676
proc.time() - p
## user system elapsed
## 0.040 0.004 0.047
tr = lapply(st, function(x) exp(x$weight-maxw) > 0.01)
phyl2 = st[unlist(tr)]
#phyl2
p=proc.time()
subplex(par = pars, fn = llik_st, setoftrees = phyl2, impsam = F)$par
## [1] 0.4520046 0.0707827 40.3960017
proc.time() - p
## user system elapsed
## 0.052 0.000 0.050
p=proc.time()
subplex(par = pars, fn = llik_st, setoftrees = phyl2, impsam = T)$par
## [1] 0.44304223 0.06745399 41.81773025
proc.time() - p
## user system elapsed
## 0.064 0.000 0.065
p=proc.time()
pars = c(1.2,0.4,70)
subplex(par = pars, fn = llik_st, setoftrees = list(s), impsam = F)$par
## [1] 0.5808922 0.1227656 38.6221800
proc.time() - p
## user system elapsed
## 0.020 0.000 0.022
plot(s$newick)
p=proc.time()
pars = c(1.2,0.4,70)
subplex(par = pars,fn = llik,n = s$n, E = s$E, t = s$wt)
## $par
## [1] 0.5808922 0.1227656 38.6221800
##
## $value
## [1] 209.8768
##
## $counts
## [1] 598
##
## $convergence
## [1] 1
##
## $message
## [1] "limit of machine precision reached"
##
## $hessian
## NULL
proc.time() - p
## user system elapsed
## 0.012 0.004 0.017