import pandas as pd
import numpy as np
import time
def _standardize_columns(block):
    """Centre each column of *block* and scale it by its population std."""
    # An empty sub-matrix has no column statistics; return it untouched
    # instead of dividing NaN by NaN.
    if block.shape[0] == 0:
        return block.copy()
    return (block - block.mean(axis=0)) / np.sqrt(block.var(axis=0))


def G_x_matrix(n,n1,n2,m):
    """Simulate an ``n x m`` count matrix and standardise two row subsets.

    Each column gets its own success probability drawn from
    Uniform(0.2, 0.8) and is then filled with ``n`` draws from
    Binomial(2, prob) (0/1/2 counts -- presumably allele counts, given the
    naming).  The first ``n1`` rows and the last ``n2`` rows are standardised
    separately, column by column, using the mean and population variance of
    that subset only.

    Parameters
    ----------
    n : int
        Number of rows simulated.
    n1 : int
        Number of leading rows forming the first sub-matrix.
    n2 : int
        Number of trailing rows forming the second sub-matrix.
    m : int
        Number of columns.

    Returns
    -------
    (x1, x2) : tuple of ndarray
        Standardised sub-matrices; for ``0 <= n1, n2 <= n`` their shapes are
        ``(n1, m)`` and ``(n2, m)``.  If ``n1 + n2 > n`` the two subsets share
        rows of the underlying matrix.

    Notes
    -----
    A column that is constant inside a subset has zero variance and
    standardises to NaN/inf; no filtering is done here.
    """
    G = np.zeros((n, m))
    # Keep the per-column draw order (probability, then counts) so that a
    # seeded run reproduces the same matrix as before.
    for kk in range(m):
        prob = np.random.uniform(0.2, 0.8)
        G[:, kk] = np.random.binomial(2, prob, n)

    # ``G[-n2:]`` would return *all* rows when n2 == 0 because -0 == 0, so the
    # tail is addressed from the left.  ``max`` keeps the old clamping when
    # n2 > n.
    tail_start = max(n - n2, 0)
    x1 = _standardize_columns(G[:n1, :])
    x2 = _standardize_columns(G[tail_start:, :])
    return x1, x2

def beta_ge(m,v1t,v2t,rgt):
    """Draw two correlated length-``m`` normal vectors.

    Two independent standard-normal vectors are drawn (first the shared one,
    then the private one).  ``beta1`` is the shared vector scaled by
    ``sqrt(v1t)``; ``beta2`` mixes the shared and private vectors so that, by
    construction, each entry has variance ``v2t`` and covariance ``rgt`` with
    the matching entry of ``beta1``.

    Parameters
    ----------
    m : int
        Length of each vector.
    v1t, v2t : float
        Per-entry variances of ``beta1`` and ``beta2``.
    rgt : float
        Per-entry covariance between ``beta1`` and ``beta2``.

    Returns
    -------
    (beta1, beta2) : tuple of ndarray, each of shape ``(m,)``

    Notes
    -----
    ``v2t - rgt**2 / v1t`` must be non-negative, otherwise the square root
    yields NaN.
    """
    shared = np.random.randn(m)
    private = np.random.randn(m)

    sd1 = np.sqrt(v1t)
    loading = rgt / sd1                          # weight of the shared part in beta2
    residual_sd = np.sqrt(v2t - rgt**2 / v1t)    # sd left for the private part

    beta1 = sd1 * shared
    beta2 = loading * shared + residual_sd * private
    return beta1, beta2
  
def epsilon_gen(sample_size,n1,ve1,n2,N,ve2,re):
    """Draw residual matrices for two groups that share ``N`` columns.

    ``epsilon1`` is ``sqrt(ve1)`` times a standard-normal matrix of shape
    ``(sample_size, n1)``.  The first ``N`` columns of ``epsilon2`` are built
    from the *last* ``N`` columns of that same matrix plus fresh noise, so
    that by construction each has variance ``ve2`` and covariance ``re`` with
    its partner column in ``epsilon1``.  The remaining ``n2 - N`` columns of
    ``epsilon2`` are independent with variance ``ve2``.

    Parameters
    ----------
    sample_size : int
        Number of rows (independent replicates) to draw.
    n1, n2 : int
        Number of columns of ``epsilon1`` and ``epsilon2``.
    ve1, ve2 : float
        Residual variances of the two groups.
    N : int
        Number of paired (overlapping) columns; ``0 <= N <= min(n1, n2)``.
    re : float
        Residual covariance within each pair.

    Returns
    -------
    (epsilon1, epsilon2) : tuple of ndarray
        Shapes ``(sample_size, n1)`` and ``(sample_size, n2)``.

    Raises
    ------
    ValueError
        If ``N`` is negative or larger than ``min(n1, n2)``.
    """
    if not 0 <= N <= min(n1, n2):
        raise ValueError(
            'N must satisfy 0 <= N <= min(n1, n2); got N={}, n1={}, n2={}'.format(N, n1, n2)
        )

    x = np.random.randn(sample_size, n1)
    epsilon1 = np.sqrt(ve1) * x

    epsilon2 = np.zeros((sample_size, n2))
    # ``x[:, -N:]`` selects every column when N == 0 (because -0 == 0) and then
    # fails to broadcast; indexing from the left gives an empty slice instead.
    shared = x[:, n1 - N:]
    epsilon2[:, :N] = re/np.sqrt(ve1)*shared + np.sqrt(ve2 - re**2/ve1)*np.random.randn(sample_size, N)
    # Unpaired columns: plain independent noise with variance ve2.
    epsilon2[:, N:] = np.sqrt(ve2)*np.random.randn(sample_size, n2 - N)
    return epsilon1, epsilon2
# Problem sizes: n rows are simulated in total; the first cohort uses the top
# n1 rows and the second cohort the bottom n2 rows.  With these values
# n1 + n2 - N == n, so the last N rows of cohort 1 and the first N rows of
# cohort 2 are the same rows of the simulated matrix.  m is the column count.
n=9000
n1=5000
n2=5000
N=1000
m=1200
print('n={},n1={},n2={},N={},m={}'.format(n,n1,n2,N,m))
## n=9000,n1=5000,n2=5000,N=1000,m=1200
x1,x2=G_x_matrix(n,n1,n2,m)

# Overlap indicator: C[i, j] = 1 pairs row i among the last N rows of cohort 1
# with row j among the first N rows of cohort 2; every other entry is 0.
C=np.zeros((n1,n2))
C[-N:,:N]=np.eye(N)

# Gram (relatedness-style) matrices of the standardised data, scaled by 1/m.
KA=x1@x2.T/m
K1=x1@x1.T/m
K2=x2@x2.T/m
# Same matrix as x1@x2.T@C.T/m, but reuses KA instead of recomputing the
# n1 x m x n2 product.  C is 0/1 with at most one 1 per row, so multiplying
# by C.T only copies columns of KA and the entries are unchanged.
KC=KA@C.T

# Variance of the K1 / K2 components and covariance carried by KA.
v1, v2, rg = 0.4, 0.5, 0.12
# Target correlation implied by those three values.
rho = rg / np.sqrt(v1 * v2)

# Residual variances of the two cohorts and the residual covariance that
# applies only to the paired rows marked in C.
ve1, ve2, re = 0.4, 0.5, 0.12

# Covariance blocks built from the components above:
# V11 and V22 are the within-cohort blocks, V12 the cross-cohort block.
V11 = ve1 * np.eye(n1) + v1 * K1
V22 = ve2 * np.eye(n2) + v2 * K2
V12 = rg * KA + re * C

# The eigenvalues are only ever used through their sums:
#   sum(eigc)    == trace(KC)
#   sum(eig1**2) == trace(K1 @ K1)   (likewise for eig2 / K2)
# K1 and K2 are symmetric (x @ x.T / m), so the symmetric solver applies: it
# returns real eigenvalues and skips the eigenvectors np.linalg.eig would
# also compute and discard.
eig1=np.linalg.eigvalsh(K1)
eig2=np.linalg.eigvalsh(K2)
# KC is not symmetric in general, so the general solver is kept (values only);
# its output may have complex dtype.
eigc=np.linalg.eigvals(KC)
# trace(KA @ KA.T) is the sum of squared entries of KA; computing it this way
# avoids forming an n1 x n1 matrix product just to read its diagonal.
trKAt=np.sum(KA*KA)

# Weight matrices of the three quadratic-form estimators
#   u = y1 K y2'   (cross term),   x = y1 A y1',   y = y2 B y2'
# with the same numerators and denominators used in simulate_two_traint.
K=(N*KA-eigc.sum()*C)/(N*trKAt-eigc.sum()**2)
A=(K1-np.eye(n1))/(np.sum(eig1**2)-n1)
B=(K2-np.eye(n2))/(np.sum(eig2**2)-n2)

# (Co)variances of the quadratic forms u = y1 K y2', x = y1 A y1',
# y = y2 B y2', written as traces.  The expressions are the Gaussian
# fourth-moment (Isserlis/Wick) identities for zero-mean y1, y2 with
# Cov(y1) = V11, Cov(y2) = V22 and Cov(y1, y2) = V12 -- that distributional
# assumption is not checked here.
Vu=(V11@K@V22@K.T).trace()+(V12.T@K@V12.T@K).trace()
Vx=((V11@A@V11@A).trace())*2
Vy=((V22@B@V22@B).trace())*2
# Cov(x, y) only involves the cross-covariance V12: both terms equal
# tr(A V12 B V12').  The second term previously started with V11, which is
# not part of this covariance and only produced a square product because
# n1 == n2.
Vxy=(V12@B@V12.T@A.T).trace()+(V12.T@A@V12@B).trace()
Vux=(V11@A@V12@K.T).trace()+(V11@K@V12.T@A).trace()
Vuy=(V12@B@V22@K.T).trace()+(V12.T@K@V22@B).trace()

# Means plugged in for the three estimators: u -> rg, x -> v1, y -> v2.
muu=rg
mux=v1
muy=v2
# Second-order Taylor approximation of E[sqrt(x*y)] around (mux, muy).
muv=np.sqrt(mux*muy)-np.sqrt(muy)*Vx/(8*mux**(3/2))+(mux*muy)**(-1/2)*Vxy/4-np.sqrt(mux)*Vy/(8*muy**(3/2))
# First-order (delta-method) moments of v = sqrt(x*y), using
#   dv/dx = sqrt(muy/mux)/2   and   dv/dy = sqrt(mux/muy)/2.
Vuv=(np.sqrt(muy/mux)*Vux+np.sqrt(mux/muy)*Vuy)/2
# Var(v) = (dv/dx)^2 Vx + (dv/dy)^2 Vy + 2 (dv/dx)(dv/dy) Vxy.
# The coefficients were previously mux/muy on Vx and muv/mux on Vy, which
# disagrees with the derivatives used for Vuv on the line above.
Vv=Vx*(muy/mux)/4+Vy*(mux/muy)/4+Vxy/2

# Delta-method variance of the ratio u/v.
Lmse=(muu/muv)**2*(Vu/(muu**2)-2*Vuv/(muu*muv)+Vv/(muv**2))
print('rho:',rho,'MSE of  LMM:',Lmse)
## Recorded output of an earlier run, made before the Vv coefficients were
## corrected; the MSE value also changes from run to run with the random data.
## rho: 0.2683281572999748 MSE of  LMM: (0.0017157577600018627+0j)
def simulate_two_traint(m,v1t,v2t,rgt,sample_size,n1,ve1,n2,N,ve2,re,KA,eigc,C,trKAt,K1,eig1,K2,eig2):
    """Simulate one pair of responses and return the estimated correlation.

    Effect vectors come from ``beta_ge`` and residuals from ``epsilon_gen``
    (in that order).  The responses are ``x1 @ beta1 + epsilon1`` and
    ``x2 @ beta2 + epsilon2``; note that ``x1`` and ``x2`` are NOT parameters,
    they are read from module scope.  Three quadratic forms in the responses
    give the cross-term and the two variance estimates, and their ratio
    ``rghat / sqrt(vg1hat * vg2hat)`` is returned.

    Parameters
    ----------
    m, v1t, v2t, rgt
        Forwarded to ``beta_ge``.
    sample_size, n1, ve1, n2, N, ve2, re
        Forwarded to ``epsilon_gen``; ``n1``/``n2``/``N`` also size the
        identity matrices and scale the cross-term weights.
    KA, C, K1, K2 : ndarray
        Matrices used in the quadratic forms.
    eigc, eig1, eig2 : array-like
        Used only through ``sum(eigc)``, ``sum(eig1**2)`` and
        ``sum(eig2**2)``.
    trKAt : scalar
        Enters the cross-term denominator.

    Returns
    -------
    ndarray of shape ``(sample_size, sample_size)``
        The element-wise ratio; with ``sample_size == 1`` this is a 1x1 array.
    """
    beta1, beta2 = beta_ge(m, v1t, v2t, rgt)
    epsilon1, epsilon2 = epsilon_gen(sample_size, n1, ve1, n2, N, ve2, re)

    # Rows are replicates: the (n,) signal broadcasts over the residual rows.
    y1 = x1 @ beta1 + epsilon1
    y2 = x2 @ beta2 + epsilon2

    # Cross-term estimate.
    eigc_total = sum(eigc)
    cross_weight = N * KA - eigc_total * C
    cross_scale = N * trKAt - eigc_total ** 2
    rghat = (y1 @ cross_weight @ y2.T) / cross_scale

    # Variance estimates for each response.
    weight1 = K1 - np.eye(n1)
    scale1 = sum(eig1 ** 2) - n1
    vg1hat = (y1 @ weight1 @ y1.T) / scale1

    weight2 = K2 - np.eye(n2)
    scale2 = sum(eig2 ** 2) - n2
    vg2hat = (y2 @ weight2 @ y2.T) / scale2

    return rghat / np.sqrt(vg1hat * vg2hat)
# Per-column versions of the variance/covariance parameters (divided by m).
v1t, v2t, rgt = v1 / m, v2 / m, rg / m
sample_size = 1

# 50 independent replicates of the estimator.
rho_list = [
    simulate_two_traint(
        m, v1t, v2t, rgt, sample_size, n1, ve1, n2, N, ve2, re,
        KA, eigc, C, trKAt, K1, eig1, K2, eig2,
    )
    for _ in range(50)
]

t = np.array(rho_list)
print(
    'rho:', rho, '\t', 'MSE of  LMM:', Lmse.real, '\n',
    'rho_hat:', np.mean(rho_list).real, '\t', 'Var :', np.var(t),
)
## rho: 0.2683281572999748   MSE of  LMM: 0.0017157577600018627 
##  rho_hat: 0.2734869621910111      Var : 0.001540141148528704

# Agreement between the analytic MSE and the empirical variance, reported as
# whichever of the two ratios is >= 1.
theory_over_empirical = Lmse.real / np.var(rho_list)
empirical_over_theory = np.var(rho_list) / Lmse.real
print(max(theory_over_empirical, empirical_over_theory))
## 1.1140263096281307