#Bayesian analysis--------------------------------------------------------
# Draw ONE random vector from a three-component Dirichlet(a, b, c).
# NOTE: despite the "d" prefix this is a sampler, not a density function.
#   a, b, c : positive concentration (shape) parameters
#   returns : a 1 x 3 matrix of proportions
# Method: independent Gamma(shape, rate = 1) variates divided by their total.
ddirichlet <- function(a, b, c) {
  # one gamma variate per component, generated in the order a, b, c
  gammas <- rgamma(3, base::c(a, b, c), 1)
  total <- gammas[1] + gammas[2] + gammas[3]
  matrix(gammas / total, nrow = 1)
}

# Chain length: 2250 iterations.

# Observed counts (they add up to n below)
y1 <- 89
y2 <- 642
y3 <- 195
y4 <- 657
n <- 1583

# Prior parameters passed (added to the counts) to the Dirichlet draw
a1 <- 1
a2 <- 1
a3 <- 1

# Storage for the latent counts; the first element is the starting value and
# the remaining 2249 slots are filled in by the sampler.
Z1 <- c(500, numeric(2249))
Z2 <- c(500, numeric(2249))

# One row per iteration: col1 = pa, col2 = pb, col3 = po.
# Row 1 holds the starting values.
ptheta <- matrix(0, nrow = 2250, ncol = 3)
ptheta[1, ] <- c(.33, .33, .33)

#Gibbs sampling
# Data-augmentation sampler. Each iteration
#   (1) draws the latent counts Z1, Z2 from binomials whose success
#       probabilities depend on the previous row of ptheta, then
#   (2) draws a new row of ptheta from the Dirichlet whose parameters are the
#       completed counts m1, m2, m3 plus the prior parameters a1, a2, a3.
# One pass over the rows is sufficient: the former outer `for (j in 1:2)` loop
# merely restarted the identical chain from the same starting values and
# overwrote the first pass, so it has been removed.
for (i in seq_len(nrow(ptheta))[-1]) {
  theta_prev <- ptheta[i - 1, ]  # (pa, pb, po) from the previous iteration

  # success probability p^2 / (p^2 + 2*p*po) for component 1 and component 2
  # (presumably the homozygote share within the y2 / y3 groups)
  Z1[i] <- rbinom(1, y2,
                  theta_prev[1]^2 /
                    (theta_prev[1]^2 + 2 * theta_prev[1] * theta_prev[3]))
  Z2[i] <- rbinom(1, y3,
                  theta_prev[2]^2 /
                    (theta_prev[2]^2 + 2 * theta_prev[2] * theta_prev[3]))

  # completed counts; note m1 + m2 + m3 = 2 * (y1 + y2 + y3 + y4) for any Z
  m1 <- y1 + y2 + Z1[i]
  m2 <- y1 + y3 + Z2[i]
  m3 <- y2 + y3 - Z1[i] - Z2[i] + 2 * y4

  # despite its name, ddirichlet() returns one random Dirichlet draw
  ptheta[i, ] <- ddirichlet(m1 + a1, m2 + a2, m3 + a3)
}

# Posterior summaries computed on iterations 1125..2250 only; the earlier
# (roughly first half) iterations are discarded as burn-in.
post_rows <- 1125:2250

# Posterior means of pa, pb, po (outputs shown are from a recorded run)
mean(ptheta[post_rows, 1])
## [1] 0.2658627
mean(ptheta[post_rows, 2])
## [1] 0.09416281
mean(ptheta[post_rows, 3])
## [1] 0.6399745

# Spread of the retained draws (posterior standard deviations)
sd(ptheta[post_rows, 1])
## [1] 0.0084216
sd(ptheta[post_rows, 2])
## [1] 0.005509095
sd(ptheta[post_rows, 3])
## [1] 0.00907402
#rao blackwellization
# Instead of averaging the Dirichlet draws themselves, average their
# conditional means given the latent counts drawn in the same iteration:
#   E[p_k | Z] = (m_k + a_k) / (m1 + m2 + m3 + a1 + a2 + a3)
# where m1 + m2 + m3 = 2 * (y1 + y2 + y3 + y4) whatever Z1 and Z2 are.
#
# Fixes relative to the earlier version of this section:
#   * Z1 / Z2 are restricted to the same post-burn-in rows as ptheta2; before,
#     the whole vectors were used, including the arbitrary starting value 500.
#   * the pO numerator now includes its prior parameter a3, so the three
#     estimates sum to one.
#   * the prior parameters a1, a2, a3 are used instead of hard-coded 1 and 3.
# The outputs previously recorded under the mean()/sd() calls were produced
# before these fixes and are no longer valid, so they are not reproduced here.
rb_rows <- 1125:2250
ptheta2 <- ptheta[rb_rows, ]
Z1_kept <- Z1[rb_rows]
Z2_kept <- Z2[rb_rows]

rb_denom <- 2 * (y1 + y2 + y3 + y4) + a1 + a2 + a3
rb_pa <- (y1 + y2 + Z1_kept + a1) / rb_denom
rb_pb <- (y1 + y3 + Z2_kept + a2) / rb_denom
rb_po <- (y2 + y3 - Z1_kept - Z2_kept + 2 * y4 + a3) / rb_denom

#pa
# solid: density of the raw draws; dotted red: Rao-Blackwellised values
plot(density(ptheta2[, 1]), ylim = c(0, 130), main = "pA")
lines(density(rb_pa), lty = "dotted", col = 2)

mean(rb_pa)
sd(rb_pa)

#pb
plot(density(ptheta2[, 2]), ylim = c(0, 350), main = "pB")
lines(density(rb_pb), lty = "dotted", col = 2)

mean(rb_pb)
sd(rb_pb)

#po
plot(density(ptheta2[, 3]), ylim = c(0, 130), main = "pO")
lines(density(rb_po), lty = "dotted", col = 2)

mean(rb_po)
sd(rb_po)