library(MASS)
library(ggplot2)
library(urca)
library(seasonal)
# Raw data ----
# Monthly dataset holding the columns reer, ctot, tnt, mre and reser.
df <- readxl::read_excel("C:/Users/PC/Desktop/PFE/Data/Variables/df.xlsx")

# Raw (unadjusted) monthly series from January 2004, missing values kept.
# NOTE: vars_list / var_names are rebuilt later from the adjusted series.
vars_list <- lapply(
  setNames(nm = c("reer", "ctot", "tnt", "mre", "reser")),
  function(col) ts(df[[col]], start = c(2004, 1), frequency = 12)
)
var_names <- names(vars_list)

# Seasonal adjustment ----
# Per series: monthly ts (from Jan-2004) -> trim missing ends -> log ->
# X-13ARIMA-SEATS via seas() -> final() seasonally adjusted series.
#
# The ts is built BEFORE na.omit(): na.omit.ts trims leading/trailing NAs
# while moving the start date accordingly, and stops on internal NAs.
# Dropping NAs from the raw vector first would relabel a series that starts
# late as if it began in Jan-2004 and silently collapse internal gaps.
# NOTE(review): seas() receives already-logged data; its default
# transform.function = "auto" may still choose a log transform. Consider
# transform.function = "none" -- confirm against the X-13 output.

reer_ts    <- na.omit(ts(df$reer,  start = c(2004, 1), frequency = 12))
reer_log   <- log(reer_ts)
reer_sa    <- seas(reer_log)
reer_clean <- final(reer_sa)

ctot_ts    <- na.omit(ts(df$ctot,  start = c(2004, 1), frequency = 12))
ctot_log   <- log(ctot_ts)
ctot_sa    <- seas(ctot_log)
## Model used in SEATS is different: (0 1 1)
ctot_clean <- final(ctot_sa)

tnt_ts    <- na.omit(ts(df$tnt,    start = c(2004, 1), frequency = 12))
tnt_log   <- log(tnt_ts)
tnt_sa    <- seas(tnt_log)
## Model used in SEATS is different: (0 1 1)(1 0 0)
tnt_clean <- final(tnt_sa)

mre_ts    <- na.omit(ts(df$mre,    start = c(2004, 1), frequency = 12))
mre_log   <- log(mre_ts)
mre_sa    <- seas(mre_log)
mre_clean <- final(mre_sa)

reser_ts    <- na.omit(ts(df$reser, start = c(2004, 1), frequency = 12))
reser_log   <- log(reser_ts)
reser_sa    <- seas(reser_log)
reser_clean <- final(reser_sa)



# Seasonally adjusted log series used from here on (replaces the raw list).
vars_list <- list(
  reer  = reer_clean, ctot = ctot_clean, tnt = tnt_clean,
  mre   = mre_clean,  reser = reser_clean
)
var_names <- names(vars_list)

# Fix the RNG so the MCMC run is reproducible.
set.seed(42)

# Gibbs sampler controls: total sweeps, burn-in sweeps discarded, thinning.
N_ITER <- 100000L
N_BURN <- 30000L
THIN   <- 10L

# Estimation sample ----
p_lag  <- 2L   # lag order of the VAR on the gaps
n      <- 5L   # number of series
vnames <- c("REER", "CTOT", "TNT", "RESER", "MRE")

# Align the adjusted series by calendar date on their common window.
# Binding raw vectors instead would recycle or shift series whose lengths
# or start dates differ.
Y0_ts <- ts.intersect(
  REER  = reer_clean,
  CTOT  = ctot_clean,
  TNT   = tnt_clean,
  RESER = reser_clean,
  MRE   = mre_clean
)
stopifnot(!is.null(Y0_ts), frequency(Y0_ts) == 12,
          identical(colnames(Y0_ts), vnames))

keep_rows <- complete.cases(Y0_ts)
Y0 <- matrix(as.numeric(Y0_ts), nrow = nrow(Y0_ts),
             dimnames = list(NULL, colnames(Y0_ts)))[keep_rows, , drop = FALSE]
stopifnot(nrow(Y0) > 60)

# Dates of the retained rows, read from the ts index rather than assumed
# to be a gap-free run starting in January 2004.
obs_year  <- as.integer(floor(as.numeric(time(Y0_ts)) + 1e-8))[keep_rows]
obs_month <- as.integer(cycle(Y0_ts))[keep_rows]
dates <- as.Date(sprintf("%04d-%02d-01", obs_year, obs_month))

# The series are already in logs (log() is applied before seas() above);
# do not log them again here.

# Standardize for stability; keep moments to map results back later.
Y_mean <- colMeans(Y0)
Y_sd   <- apply(Y0, 2, sd)

Y <- scale(Y0)
Y <- as.matrix(Y)
colnames(Y) <- colnames(Y0)

TT <- nrow(Y)

cat(sprintf("Sample size T = %d | %s -> %s\n",
            TT, format(min(dates), "%Y-%m"), format(max(dates), "%Y-%m")))
## Sample size T = 264 | 2004-01 -> 2025-12
# Priors and fixed variances ----
# Measurement-noise variances: per-series noise s.d. (standardized units), squared.
Rvec <- c(REER = 0.20, CTOT = 0.25, TNT = 0.25, RESER = 0.30, MRE = 0.30)^2

# Innovation variances of the trend levels (one per series).
q_levels <- c(REER = 1e-4, CTOT = 1e-4, TNT = 1e-4, RESER = 7e-5, MRE = 7e-5)

# Innovation variance shared by all drift states.
q_slope <- 1e-5

# Prior for the long-run BEER equation on standardized data:
# mild shrinkage towards zero rather than ultra-loose priors.
mu_prior <- rep(0, 5)                                       # intercept + 4 slopes
V_prior  <- diag(c(0.10, 0.25, 0.25, 0.25, 0.25)^2)

# Normal prior for rho, the persistence of the drift equation.
rho_prior_mean <- 0.85
rho_prior_sd   <- 0.10
# Upper-triangular Cholesky factor R (t(R) %*% R ~= M) that tolerates
# matrices that are only numerically positive semi-definite.
#
# M is symmetrised first. Then chol() is tried with diagonal jitter
# eps * 10^k, k = 0..8. If every attempt fails, the eigenvalues are floored
# at eps * 1e4 and the reconstructed matrix is factorised.
#
# Args:
#   M   : square numeric matrix.
#   eps : base jitter size.
# Returns an upper-triangular matrix of the same dimension as M.
safe_chol <- function(M, eps = 1e-10) {
  M <- (M + t(M)) / 2
  for (jitter in eps * 10^(0:8)) {
    R <- tryCatch(chol(M + diag(jitter, nrow(M))), error = function(e) NULL)
    if (!is.null(R)) return(R)
  }
  ev <- eigen(M, symmetric = TRUE)
  vals <- pmax(ev$values, eps * 1e4)
  # nrow = is required: diag(x) with a length-1 x builds an identity of
  # size x instead of a 1 x 1 matrix.
  chol(ev$vectors %*% diag(vals, nrow = length(vals)) %*% t(ev$vectors))
}

# Draw one inverse-Wishart matrix with scale S and nu degrees of freedom,
# obtained by inverting a Wishart(nu, S^-1) draw. S is symmetrised and
# given a 1e-8 ridge before inversion.
riwish <- function(S, nu) {
  k <- nrow(S)
  scale_mat <- (S + t(S)) / 2 + diag(1e-8, k)
  wishart_draw <- rWishart(1L, df = nu, Sigma = solve(scale_mat))[, , 1L]
  solve(wishart_draw)
}

# Transition matrix of the local-linear-trend state vector.
#
# States interleave, per series i, the trend level (position 2i - 1) and
# its drift (position 2i):
#   level_{t+1} = level_t + drift_t
#   drift_{t+1} = rho_i * drift_t
#
# Args:
#   rho_vec : drift persistence, one value per series. Its length sets the
#             number of series, so no global is read.
# Returns a (2 * length(rho_vec)) square matrix.
build_G_llt <- function(rho_vec) {
  n_series <- length(rho_vec)
  n_st <- 2L * n_series
  ix_mu <- 2L * seq_len(n_series) - 1L
  ix_dr <- ix_mu + 1L

  G <- matrix(0, n_st, n_st)
  G[cbind(ix_mu, ix_mu)] <- 1
  G[cbind(ix_mu, ix_dr)] <- 1
  G[cbind(ix_dr, ix_dr)] <- rho_vec
  G
}

# Diagonal state-noise covariance for the interleaved (level, drift) states.
#
# Args:
#   qL_vec : level innovation variances, one per series. Its length sets
#            the number of series.
#   qS     : drift innovation variance; a scalar shared by all series, or
#            one value per series.
# Returns a (2 * length(qL_vec)) square diagonal matrix with diagonal
# (qL_1, qS, qL_2, qS, ...).
build_Q_llt <- function(qL_vec, qS) {
  # Row 1 = levels, row 2 = drifts; column-major flattening interleaves them.
  diag_vals <- as.vector(rbind(qL_vec, qS))
  diag(diag_vals, nrow = length(diag_vals))
}

# Observation matrix that picks each series' trend level out of the
# interleaved (level, drift) state vector: y_i,t = state[2i - 1] + noise.
#
# Args:
#   n_series : number of observed series (defaults to the global n).
# Returns an n_series x (2 * n_series) 0/1 matrix.
build_H_obs <- function(n_series = n) {
  rows <- seq_len(n_series)
  H <- matrix(0, n_series, 2L * n_series)
  H[cbind(rows, 2L * rows - 1L)] <- 1
  H
}
# Forward-filtering, backward-sampling draw of the local-linear-trend states
# for all series jointly.
#
# State vector (length 2n) interleaves, per series i, the trend level
# (odd positions) and its drift (even positions):
#   level_{t+1} = level_t + drift_t,  drift_{t+1} = rho_i * drift_t  (G)
#   plus state noise with covariance Q,
#   y_{i,t} = level_{i,t} + noise with variance Rvec[i]              (H, R)
#
# Args:
#   Y_obs   : TT x n observation matrix (columns assumed in vnames order).
#   rho_vec : drift persistence per series.
#   Q       : 2n x 2n state-noise covariance.
#   Rvec    : measurement-noise variances per series.
# Returns a list:
#   mu  : TT x n sampled trend levels (colnames = vnames),
#   g   : TT x n gaps, Y_obs - mu,
#   chi : 2n x TT sampled state path (levels and drifts).
# Reads the globals n, TT and vnames.
ffbs_llt <- function(Y_obs, rho_vec, Q, Rvec) {
  n_st <- 2L * n
  ix_mu <- seq(1, n_st, by = 2)
  ix_dr <- seq(2, n_st, by = 2)  # drift positions (not used further below)
  
  G <- build_G_llt(rho_vec)
  H <- build_H_obs()
  R <- diag(Rvec)
  
  # Filtered means / covariances for every t
  a_f <- matrix(0, n_st, TT)
  P_f <- array(0, c(n_st, n_st, TT))
  
  # Prior for the t = 1 state: levels at the sample means, drifts at zero,
  # covariance 5 * I.
  a_p <- rep(0, n_st)
  P_p <- diag(n_st) * 5
  
  for (i in 1:n) a_p[ix_mu[i]] <- mean(Y_obs[, i])
  
  # Forward filter
  for (t in 1:TT) {
    # Innovation, its covariance and the Kalman gain
    v_t <- as.numeric(Y_obs[t, ] - H %*% a_p)
    S_t <- H %*% P_p %*% t(H) + R
    K_t <- P_p %*% t(H) %*% solve(S_t)
    
    # Measurement update
    a_f[, t]  <- as.numeric(a_p + K_t %*% v_t)
    P_f[, , t] <- (diag(n_st) - K_t %*% H) %*% P_p
    
    # One-step-ahead prediction for t + 1
    if (t < TT) {
      a_p <- G %*% a_f[, t]
      P_p <- G %*% P_f[, , t] %*% t(G) + Q
    }
  }
  
  # Backward sampling
  # Draw the last state from its filtered distribution ...
  chi <- matrix(0, n_st, TT)
  chi[, TT] <- as.numeric(a_f[, TT] + t(safe_chol(P_f[, , TT])) %*% rnorm(n_st))
  
  # ... then each earlier state conditional on the already-drawn chi[, t + 1].
  for (t in (TT - 1):1) {
    Pp1  <- G %*% P_f[, , t] %*% t(G) + Q
    # Tiny ridge keeps the solve() well defined if Pp1 is near-singular
    J    <- P_f[, , t] %*% t(G) %*% solve(Pp1 + diag(1e-10, n_st))
    mu_s <- a_f[, t] + J %*% (chi[, t + 1] - G %*% a_f[, t])
    P_s  <- P_f[, , t] - J %*% Pp1 %*% t(J)
    P_s  <- (P_s + t(P_s)) / 2
    chi[, t] <- as.numeric(mu_s + t(safe_chol(P_s)) %*% rnorm(n_st))
  }
  
  # Extract the trend levels (odd state positions) as a TT x n matrix
  mu_mat <- matrix(NA_real_, TT, n, dimnames = list(NULL, vnames))
  for (i in 1:n) mu_mat[, i] <- chi[ix_mu[i], ]
  
  g_mat <- Y_obs - mu_mat
  
  list(mu = mu_mat, g = g_mat, chi = chi)
}
# Conditional draw for the long-run BEER regression
#   REER_t = alpha + b1*CTOT_t + b2*TNT_t + b3*RESER_t + b4*MRE_t + e_t,
#   e_t ~ N(0, sig2),
# on the sampled trends. Priors: the full coefficient vector
# (alpha, b1..b4) ~ N(prior_mean, prior_var) and sig2 ~ IG(2, 0.5).
#
# Steps: sig2 is drawn from its inverse-gamma conditional using residuals
# evaluated at beta_cur. The coefficients are then drawn from their normal
# conditional given that sig2.
#
# Args:
#   mu_mat     : trend matrix with columns REER, CTOT, TNT, RESER, MRE
#                (looked up by name).
#   beta_cur   : coefficient vector (intercept first) at which the sig2
#                residuals are evaluated. The default, prior_mean, keeps
#                the original behaviour. Pass the current (alpha, betas)
#                draw to get the proper Gibbs conditional p(sig2 | b, mu).
#   prior_mean : prior mean of the coefficients (default: global mu_prior).
#   prior_var  : prior covariance of the coefficients (default: global V_prior).
# Returns list(alpha = intercept draw, betas = 4 slope draws in
# CTOT/TNT/RESER/MRE order, sig2 = residual-variance draw).
draw_beer <- function(mu_mat, beta_cur = prior_mean,
                      prior_mean = mu_prior, prior_var = V_prior) {
  y <- as.numeric(mu_mat[, "REER"])
  X <- cbind(
    1,
    mu_mat[, "CTOT"],
    mu_mat[, "TNT"],
    mu_mat[, "RESER"],
    mu_mat[, "MRE"]
  )
  stopifnot(length(beta_cur) == ncol(X), length(prior_mean) == ncol(X))

  prior_prec <- solve(prior_var)

  resid <- y - X %*% beta_cur
  s2 <- 1 / rgamma(1, shape = 2 + length(y) / 2, rate = 0.5 + sum(resid^2) / 2)

  V_post <- solve(prior_prec + crossprod(X) / s2)
  m_post <- V_post %*% (prior_prec %*% prior_mean + crossprod(X, y) / s2)

  b <- mvrnorm(1, mu = as.numeric(m_post), Sigma = V_post)

  list(alpha = b[1], betas = b[-1], sig2 = s2)
}
# One draw from N(mu, sigma^2) truncated to [lower, upper] by inverse CDF.
#
# Probabilities are taken from whichever tail the bounds sit in, so they
# keep precision when mu lies far outside the interval. If the mass inside
# the bounds underflows to zero, the result falls back to the nearest bound.
draw_trunc_norm <- function(mu, sigma, lower, upper) {
  clamp <- function(x) min(upper, max(lower, x))
  if (mu >= (lower + upper) / 2) {
    p_a <- pnorm(lower, mu, sigma)
    p_b <- pnorm(upper, mu, sigma)
    if (!(p_b > p_a)) return(clamp(mu))
    return(clamp(qnorm(runif(1, p_a, p_b), mu, sigma)))
  }
  p_a <- pnorm(upper, mu, sigma, lower.tail = FALSE)
  p_b <- pnorm(lower, mu, sigma, lower.tail = FALSE)
  if (!(p_b > p_a)) return(clamp(mu))
  clamp(qnorm(runif(1, p_a, p_b), mu, sigma, lower.tail = FALSE))
}

# Conditional draw of the drift persistence rho_i for every series, given
# the sampled drift path d_i,t (even rows of chi):
#   d_{t+1} = rho_i * d_t + e_t,  e_t ~ N(0, q_dr[i]),
# with prior rho_i ~ N(prior_mean, prior_sd^2) truncated to `bounds`.
#
# The data term is divided by the drift innovation variance q_dr. Without
# that scaling, and with q_slope = 1e-5, the likelihood is about 1e5 times
# too weak and rho is drawn essentially from its prior. The truncated-normal
# draw replaces min/max clamping, which put point masses on the bounds.
#
# Args:
#   chi        : 2n x T state path from ffbs_llt() (drifts at even rows).
#   rho_cur    : current rho per series; its length sets n. Series with 10
#                or fewer usable transitions keep their current value.
#   q_dr       : drift innovation variance, scalar or per series
#                (default: global q_slope, the drift variance in Q).
#   prior_mean, prior_sd : prior for rho (defaults: globals
#                rho_prior_mean, rho_prior_sd).
#   bounds     : truncation interval for rho.
# Returns the updated rho vector, with the names of rho_cur.
draw_rho <- function(chi, rho_cur, q_dr = q_slope,
                     prior_mean = rho_prior_mean, prior_sd = rho_prior_sd,
                     bounds = c(0.02, 0.98)) {
  n_series <- length(rho_cur)
  ix_dr <- 2L * seq_len(n_series)
  q_dr <- rep_len(q_dr, n_series)
  stopifnot(all(q_dr > 0), prior_sd > 0, bounds[1] < bounds[2])
  rho_new <- rho_cur

  for (i in seq_len(n_series)) {
    d <- as.numeric(chi[ix_dr[i], ])
    x <- d[-length(d)]
    y <- d[-1]

    ok <- is.finite(x) & is.finite(y)
    x <- x[ok]
    y <- y[ok]

    if (length(x) > 10) {
      prior_prec <- 1 / prior_sd^2
      v_post <- 1 / (prior_prec + sum(x^2) / q_dr[i])
      m_post <- v_post * (prior_mean * prior_prec + sum(x * y) / q_dr[i])
      rho_new[i] <- draw_trunc_norm(m_post, sqrt(v_post), bounds[1], bounds[2])
    }
  }

  rho_new
}
# Ridge-regularised VAR(p) fit on the gap series plus an inverse-Wishart
# draw of the residual covariance.
#
# Regressors are stacked lag blocks [G_{t-1}, ..., G_{t-p}] (no intercept).
# B is the ridge point estimate, not a posterior draw. Sigma ~ IW(S0 + E'E,
# nu0 + T_eff), where S0 = diag of the gap variances (floored at 1e-4) and
# nu0 = number of series + 5.
#
# Args:
#   G_mat : T x k matrix of gaps (one column per series).
#   p     : lag order (default: global p_lag).
#   ridge : ridge added to Z'Z.
# Returns list(B = (k*p) x k coefficients, Sigma = k x k draw,
# resid = (T - p) x k residuals).
draw_var_gaps <- function(G_mat, p = p_lag, ridge = 1e-4) {
  G_mat <- as.matrix(G_mat)
  T_g <- nrow(G_mat)
  k_g <- ncol(G_mat)
  stopifnot(p >= 1L, T_g > p + 1L)

  rows_y <- seq.int(p + 1L, T_g)
  Yg <- G_mat[rows_y, , drop = FALSE]
  # Lag blocks in order lag 1, lag 2, ..., lag p
  Zg <- do.call(cbind, lapply(seq_len(p), function(lag) {
    G_mat[rows_y - lag, , drop = FALSE]
  }))

  Bhat <- solve(crossprod(Zg) + diag(ridge, ncol(Zg)), crossprod(Zg, Yg))
  E <- Yg - Zg %*% Bhat

  # nrow = keeps a k x k matrix even when k == 1
  S0 <- diag(pmax(apply(Yg, 2, var), 1e-4), nrow = k_g)
  nu0 <- k_g + 5L
  Sigma <- riwish(S0 + crossprod(E), nu0 + nrow(Yg))

  list(B = Bhat, Sigma = Sigma, resid = E)
}
# Sampler initialisation ----
# Drift persistence starts at 0.85 (the prior mean) for every series.
rho_cur <- setNames(rep(0.85, n), vnames)

# State-noise covariance; it is not updated inside the sampler.
Q_cur <- build_Q_llt(q_levels, q_slope)

# One smoothing draw to initialise trends (mu) and gaps (g).
ks0    <- ffbs_llt(Y, rho_cur, Q_cur, Rvec)
mu_cur <- ks0[["mu"]]
g_cur  <- ks0[["g"]]

cat(sprintf("\nInit: sd(mu_REER)=%.4f | sd(g_REER)=%.4f\n",
            sd(mu_cur[, "REER"]), sd(g_cur[, "REER"])))
## 
## Init: sd(mu_REER)=0.8234 | sd(g_REER)=0.3558

# Initial BEER coefficients and residual variance.
be0       <- draw_beer(mu_cur)
alpha_cur <- be0[["alpha"]]
betas_cur <- be0[["betas"]]
sig2_cur  <- be0[["sig2"]]

# Initial gap-VAR parameters.
vd0       <- draw_var_gaps(g_cur)
B_cur     <- vd0[["B"]]
Sigma_cur <- vd0[["Sigma"]]

# Storage for retained (post burn-in, thinned) draws.
n_keep <- floor((N_ITER - N_BURN) / THIN)

mu_store <- array(NA_real_, dim = c(n_keep, TT, n),
                  dimnames = list(NULL, NULL, vnames))
alpha_store <- numeric(n_keep)
beta_store <- matrix(NA_real_, nrow = n_keep, ncol = 4,
                     dimnames = list(NULL, c("b_ctot", "b_tnt", "b_reser", "b_mre")))
sig2_store <- numeric(n_keep)

store_idx <- 0L
t0 <- proc.time()

cat(sprintf("Gibbs: %d iter | burn %d | thin %d | keep %d\n",
            N_ITER, N_BURN, THIN, n_keep))
## Gibbs: 100000 iter | burn 30000 | thin 10 | keep 7000
# Gibbs sampler ----
# Each sweep draws, in order:
#   (1) trend/drift states via ffbs_llt(),
#   (2) BEER regression of the REER trend on the fundamentals' trends,
#   (3) gap-VAR parameters,
#   (4) drift persistence rho.
# After N_BURN sweeps, every THIN-th draw is stored.
# NOTE(review): B_cur / Sigma_cur from step 3 are neither stored nor used by
# steps 1, 2 or 4. Step 3 only consumes random numbers and time here.
# NOTE(review): draw_beer() receives only the trends, so its sig2 draw does
# not condition on the current (alpha_cur, betas_cur).
for (iter in 1:N_ITER) {
  
  # 1) Trend states
  ks <- ffbs_llt(Y, rho_cur, Q_cur, Rvec)
  mu_cur <- ks$mu
  g_cur  <- ks$g
  
  # 2) Long-run BEER
  be <- draw_beer(mu_cur)
  alpha_cur <- be$alpha
  betas_cur <- be$betas
  sig2_cur  <- be$sig2
  
  # 3) Optional short-run VAR on gaps
  vd <- draw_var_gaps(g_cur)
  B_cur <- vd$B
  Sigma_cur <- vd$Sigma
  
  # 4) Update rho from the drift rows of the full state path
  rho_cur <- draw_rho(ks$chi, rho_cur)
  
  # Store (post burn-in, every THIN-th sweep)
  if (iter > N_BURN && ((iter - N_BURN) %% THIN == 0L)) {
    store_idx <- store_idx + 1L
    mu_store[store_idx, , ] <- mu_cur
    alpha_store[store_idx]  <- alpha_cur
    beta_store[store_idx, ] <- betas_cur
    sig2_store[store_idx]   <- sig2_cur
  }
  
  # Progress report roughly every 10% of the run
  if (iter %% max(1L, floor(N_ITER / 10L)) == 0L) {
    cat(sprintf(
      "  %3d%% | a=%.3f | b=[%.3f %.3f %.3f %.3f] | rho_REER=%.3f | sd(mu)=%.4f | %.0fs\n",
      round(100 * iter / N_ITER),
      alpha_cur, betas_cur[1], betas_cur[2], betas_cur[3], betas_cur[4],
      rho_cur[1], sd(mu_cur[, "REER"]),
      (proc.time() - t0)["elapsed"]
    ))
  }
}
##    10% | a=-0.093 | b=[0.302 0.758 0.701 -0.901] | rho_REER=0.809 | sd(mu)=0.8931 | 822s
##    20% | a=0.057 | b=[0.347 0.799 0.242 -0.676] | rho_REER=0.980 | sd(mu)=0.7316 | 1653s
##    30% | a=0.007 | b=[0.428 0.863 0.547 -0.839] | rho_REER=0.782 | sd(mu)=0.9117 | 2409s
##    40% | a=0.004 | b=[0.288 0.598 0.385 -0.554] | rho_REER=0.841 | sd(mu)=0.6971 | 3160s
##    50% | a=-0.003 | b=[0.247 0.808 0.469 -0.734] | rho_REER=0.797 | sd(mu)=0.8136 | 3945s
##    60% | a=-0.067 | b=[0.247 0.682 0.536 -0.840] | rho_REER=0.980 | sd(mu)=0.9035 | 4714s
##    70% | a=0.058 | b=[0.424 0.757 0.517 -0.851] | rho_REER=0.980 | sd(mu)=0.7928 | 5546s
##    80% | a=0.051 | b=[0.283 0.668 0.619 -0.675] | rho_REER=0.835 | sd(mu)=0.7848 | 6408s
##    90% | a=-0.044 | b=[0.232 0.730 0.413 -0.729] | rho_REER=0.924 | sd(mu)=0.8175 | 7216s
##   100% | a=-0.024 | b=[0.437 0.674 0.421 -0.750] | rho_REER=0.972 | sd(mu)=0.9089 | 8098s
# BEER-implied equilibrium REER ----
# Posterior-mean fundamentals' trends combined with posterior-mean BEER
# coefficients. Every average uses only the filled draws 1..store_idx.
# An interrupted run is therefore not biased by the zero-initialised
# alpha_store slots or the NA-initialised beta_store rows.
stopifnot(store_idx > 0L)
kept <- seq_len(store_idx)

# Posterior mean of each trend on the standardized scale (TT x n);
# drop = FALSE keeps the array shape even with a single stored draw.
mu_mean <- apply(mu_store[kept, , , drop = FALSE], c(2, 3), mean)

ctot_lr  <- mu_mean[, "CTOT"]
tnt_lr   <- mu_mean[, "TNT"]
reser_lr <- mu_mean[, "RESER"]
mre_lr   <- mu_mean[, "MRE"]

a_hat <- mean(alpha_store[kept])

b_ctot  <- mean(beta_store[kept, "b_ctot"])
b_tnt   <- mean(beta_store[kept, "b_tnt"])
b_reser <- mean(beta_store[kept, "b_reser"])
b_mre   <- mean(beta_store[kept, "b_mre"])

# BEER on the standardized scale
LREER_std <-
  a_hat +
  b_ctot * ctot_lr +
  b_tnt * tnt_lr +
  b_reser * reser_lr +
  b_mre * mre_lr

# Back to the original (log) REER scale
LREER <-
  LREER_std * Y_sd["REER"] +
  Y_mean["REER"]

reer_log <- Y0[, "REER"]

# Misalignment in percent: observed REER relative to its equilibrium level
mis_pct <-
  100 *
  (exp(reer_log) - exp(LREER)) /
  exp(LREER)

summary(mis_pct)
##     Min.  1st Qu.   Median     Mean  3rd Qu.     Max. 
## -2.90623 -0.89611 -0.11925  0.01028  0.46771  6.08495
# Equilibrium REER as the posterior REER trend ----
# The sampled REER trend, mapped back to the log scale, with a 5%-95%
# posterior band.
stopifnot(store_idx > 0L)
kept <- seq_len(store_idx)

# store_idx x TT matrix of REER trend draws; matrix() keeps the 2-D shape
# even when only one draw has been stored.
REER_trend_scaled <- matrix(mu_store[kept, , "REER"], nrow = store_idx, ncol = TT)
REER_trend_log_draws <- REER_trend_scaled * Y_sd["REER"] + Y_mean["REER"]

LREER <- colMeans(REER_trend_log_draws)
# Both quantiles in one pass: row 1 = 5%, row 2 = 95%
LREER_band <- apply(REER_trend_log_draws, 2, quantile, probs = c(0.05, 0.95))
LREER_lo <- LREER_band[1, ]
LREER_hi <- LREER_band[2, ]

reer_log <- Y0[, "REER"]

# Misalignment in percent if REER is in logs
mis_pct <- 100 * (exp(reer_log) - exp(LREER)) / exp(LREER)

cat("\nPosterior means:\n")
## 
## Posterior means:
cat(sprintf("alpha = %.4f\n", mean(alpha_store[kept])))
## alpha = -0.0003
cat(sprintf("b_ctot = %.4f\n", mean(beta_store[kept, "b_ctot"])))
## b_ctot = 0.3288
cat(sprintf("b_tnt  = %.4f\n", mean(beta_store[kept, "b_tnt"])))
## b_tnt  = 0.7421
cat(sprintf("b_reser= %.4f\n", mean(beta_store[kept, "b_reser"])))
## b_reser= 0.4581
cat(sprintf("b_mre  = %.4f\n", mean(beta_store[kept, "b_mre"])))
## b_mre  = -0.7661
cat(sprintf("sd(LREER) = %.4f\n", sd(LREER)))
## sd(LREER) = 0.0208
cat(sprintf("Final misalignment = %.2f%%\n", tail(mis_pct, 1)))
## Final misalignment = -1.11%
# MCMC diagnostics: trace plots of the stored draws ----
kept <- seq_len(store_idx)

# Save the previous layout so it can be restored afterwards.
old_par <- par(mfrow = c(2, 3))
plot(alpha_store[kept], type = "l", main = "Trace alpha")
plot(beta_store[kept, "b_ctot"], type = "l", main = "Trace b_ctot")
plot(beta_store[kept, "b_tnt"],  type = "l", main = "Trace b_tnt")
plot(beta_store[kept, "b_reser"], type = "l", main = "Trace b_reser")
plot(beta_store[kept, "b_mre"],   type = "l", main = "Trace b_mre")
plot(cumsum(alpha_store[kept]) / seq_along(alpha_store[kept]),
     type = "l", main = "Cumulative mean alpha")

# Restore the caller's layout instead of forcing a 1 x 1 grid.
par(old_par)

# Unit-root check on the gap between observed log REER and its trend ----
# NOTE(review): LREER is itself estimated from the data, so the standard
# tau1 critical values printed below are only indicative for this residual.
ect <- reer_log - LREER
cat("\nADF on ECT:\n")
## 
## ADF on ECT:
adf_ect <- ur.df(ect, type = "none", lags = 4, selectlags = "AIC")
print(summary(adf_ect))
## 
## ############################################### 
## # Augmented Dickey-Fuller Test Unit Root Test # 
## ############################################### 
## 
## Test regression none 
## 
## 
## Call:
## lm(formula = z.diff ~ z.lag.1 - 1 + z.diff.lag)
## 
## Residuals:
##       Min        1Q    Median        3Q       Max 
## -0.030384 -0.002311 -0.000135  0.001965  0.048908 
## 
## Coefficients:
##             Estimate Std. Error t value Pr(>|t|)    
## z.lag.1     -0.16144    0.03951  -4.086 5.87e-05 ***
## z.diff.lag1 -0.31650    0.06556  -4.828 2.38e-06 ***
## z.diff.lag2  0.08460    0.06844   1.236  0.21755    
## z.diff.lag3  0.19635    0.06171   3.182  0.00164 ** 
## ---
## Signif. codes:  0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1
## 
## Residual standard error: 0.005231 on 255 degrees of freedom
## Multiple R-squared:  0.2377, Adjusted R-squared:  0.2258 
## F-statistic: 19.88 on 4 and 255 DF,  p-value: 2.908e-14
## 
## 
## Value of test-statistic is: -4.0864 
## 
## Critical values for test statistics: 
##       1pct  5pct 10pct
## tau1 -2.58 -1.95 -1.62
# Plots ----
df_plot <- data.frame(
  date = dates, reer = reer_log, lreer = LREER,
  lo = LREER_lo, hi = LREER_hi, mis_pct = mis_pct
)

line_colours <- c("REER observé" = "#2c3e50",
                  "LREER (TVE-VAR)" = "#e74c3c")

# Observed log REER against the equilibrium trend and its posterior band.
p1 <- ggplot(df_plot, aes(date)) +
  geom_ribbon(aes(ymin = lo, ymax = hi), fill = "#2980b9", alpha = 0.15) +
  geom_line(aes(y = reer, colour = "REER observé"), linewidth = 0.8) +
  geom_line(aes(y = lreer, colour = "LREER (TVE-VAR)"),
            linetype = "dashed", linewidth = 1) +
  scale_colour_manual(values = line_colours) +
  labs(title = "TVE-VAR Maroc — REER et taux de change d'équilibre",
       subtitle = "Version stabilisée : séries standardisées + priors réguliers",
       y = "log(REER)", x = NULL, colour = NULL) +
  theme_minimal(base_size = 12) +
  theme(legend.position = "bottom")

# Monthly misalignment bars, coloured by sign, with +/-5% guide lines.
p2 <- ggplot(df_plot, aes(date, mis_pct)) +
  geom_col(aes(fill = mis_pct > 0), width = 28, show.legend = FALSE) +
  geom_hline(yintercept = 0, linewidth = 0.6) +
  geom_hline(yintercept = c(-5, 5), linetype = "dotted",
             colour = "grey50", linewidth = 0.5) +
  scale_fill_manual(values = c("TRUE" = "#c0392b", "FALSE" = "#27ae60")) +
  labs(title = "Mésalignement du Dirham Marocain (%)",
       subtitle = "Rouge = surévaluation | Vert = sous-évaluation",
       y = "REER - LREER (%)", x = NULL) +
  theme_minimal(base_size = 12)

print(p1)

print(p2)