# Draw ONE random vector from a 3-component Dirichlet(a, b, c) distribution.
#
# Note: despite the "d" prefix (R's naming convention for densities), this
# function does not evaluate a density. It returns a random draw, like an r*
# function. It uses the gamma construction: if X_k ~ Gamma(shape = alpha_k,
# rate = 1) independently, then X / sum(X) ~ Dirichlet(alpha).
#
# Args:
#   a, b, c: single positive, finite shape parameters. (`c` is bound here as a
#            number; the call c(...) below still resolves to base::c because
#            R skips non-function bindings when looking up a function.)
# Returns:
#   A 1 x 3 numeric matrix whose entries sum to 1 (up to rounding). It can be
#   assigned directly into a matrix row, e.g. ptheta[i, ] <- ddirichlet(...).
ddirichlet <- function(a, b, c) {
  alpha <- c(a, b, c)
  valid <- is.numeric(alpha) && length(alpha) == 3L &&
    all(is.finite(alpha) & alpha > 0)
  if (!valid) {
    stop("ddirichlet(): a, b and c must each be a single positive, finite number",
         call. = FALSE)
  }
  x <- rgamma(3L, shape = alpha, rate = 1)
  matrix(x / sum(x), nrow = 1L)
}

# Gibbs sampler for ABO blood-group allele frequencies ----
#
# Phenotype counts. The labels follow from how each count enters the
# updates below:
#   y1: type AB, y2: type A, y3: type B, y4: type O.
# ptheta[i, ] is the draw of (pa, pb, po) at iteration i.
# Z1[i] is the latent number of AA genotypes among the y2 type-A people.
# Z2[i] is the latent number of BB genotypes among the y3 type-B people.
# a1, a2, a3 are the Dirichlet prior parameters; (1, 1, 1) is the uniform
# prior on the simplex.

set.seed(1)  # make the posterior summaries reproducible

n_iter <- 2250  # total Gibbs iterations, including burn-in

y1 <- 89
y2 <- 642
y3 <- 195
y4 <- 657

a1 <- 1
a2 <- 1
a3 <- 1

# Iteration 1 is the starting state and is determined by ptheta[1, ] alone.
# The latent counts are never sampled there, so they are left as NA rather
# than given arbitrary (possibly impossible) values.
Z1 <- rep(NA_real_, n_iter)
Z2 <- rep(NA_real_, n_iter)
ptheta <- matrix(0, nrow = n_iter, ncol = 3)  # col1 = pa, col2 = pb, col3 = po
ptheta[1, ] <- rep(1 / 3, 3)                  # uniform start, sums to 1

for (i in seq_len(n_iter)[-1L]) {
  pa <- ptheta[i - 1, 1]
  pb <- ptheta[i - 1, 2]
  po <- ptheta[i - 1, 3]

  # Given phenotype A, the genotype is AA with probability
  # pa^2 / (pa^2 + 2 pa po) = pa / (pa + 2 po); likewise for B.
  Z1[i] <- rbinom(1, y2, pa / (pa + 2 * po))
  Z2[i] <- rbinom(1, y3, pb / (pb + 2 * po))

  # Complete-data allele counts (they total 2 * (y1 + y2 + y3 + y4)):
  m1 <- y1 + y2 + Z1[i]                   # A: 1 per AB, 1 per AO, 2 per AA
  m2 <- y1 + y3 + Z2[i]                   # B: 1 per AB, 1 per BO, 2 per BB
  m3 <- y2 + y3 - Z1[i] - Z2[i] + 2 * y4  # O: 1 per AO/BO, 2 per OO

  # Conjugate update: (pa, pb, po) | counts ~ Dirichlet(m + a)
  ptheta[i, ] <- ddirichlet(m1 + a1, m2 + a2, m3 + a3)
}


# Posterior summaries ----
#
# Discard the first half of the chain as burn-in and summarise the rest.
# With 2250 iterations, the first 1125 draws are burn-in, so draws 1126:2250
# are kept.
n_draws <- nrow(ptheta)
burn_in <- n_draws %/% 2
# A logical mask stays correct even when burn_in is 0.
ptheta2 <- ptheta[seq_len(n_draws) > burn_in, , drop = FALSE]
colnames(ptheta2) <- c("pa", "pb", "po")

# Posterior means of the allele frequencies
colMeans(ptheta2)

# Kernel density estimates of each marginal posterior
for (j in seq_len(ncol(ptheta2))) {
  plot(density(ptheta2[, j]),
       main = paste("Posterior density of", colnames(ptheta2)[j]))
}

# Posterior standard deviations: the spread of the posterior draws. These are
# NOT Monte Carlo standard errors of the means above.
apply(ptheta2, 2, sd)

# Values recorded from an earlier, unseeded run that kept draws 1125:2250
# (results vary slightly from run to run):
#   posterior means: pa 0.266241,    pb 0.0940565,   po 0.6397025
#   posterior SDs:   pa 0.008463806, pb 0.005112975, po 0.009078872