library(tidyverse)
## ── Attaching packages ─────────────────────────────────────── tidyverse 1.3.0 ──
## ✓ ggplot2 3.3.5     ✓ purrr   0.3.4
## ✓ tibble  3.1.0     ✓ dplyr   1.0.5
## ✓ tidyr   1.1.1     ✓ stringr 1.4.0
## ✓ readr   1.3.1     ✓ forcats 0.4.0
## ── Conflicts ────────────────────────────────────────── tidyverse_conflicts() ──
## x dplyr::filter() masks stats::filter()
## x dplyr::lag()    masks stats::lag()
library(rwebppl)
## using webppl version: v0.9.15 /Users/caoanjie/Library/R/3.6/library/rwebppl/js/webppl

# Basic R implementation ----

# Build a logical stimulus sequence of length `total_number_trial`.
# Every trial is TRUE except those at `deviant_position` (one index or a
# vector of indices), which are set to FALSE. As with ordinary R subset
# assignment, a position beyond the end extends the vector (gap filled by NA).
generate_stimuli <- function(total_number_trial,
                             deviant_position) {
  trials <- rep(TRUE, total_number_trial)
  trials[deviant_position] <- FALSE
  trials
}

# Eight trials with a single deviant (FALSE) at trial 5.
stimuli <- generate_stimuli(total_number_trial = 8, deviant_position = 5)
print(stimuli)
## [1]  TRUE  TRUE  TRUE  TRUE FALSE  TRUE  TRUE  TRUE

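# Track every per-trial measurement ----
# (predictive probability, surprisal, entropy, KL divergence, and expected
# information gain)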

# Predictive probability of each observed stimulus under sequentially updated
# pseudo-counts (alpha counts TRUE outcomes, beta counts FALSE outcomes).
#
# For trial i, using only trials 1..(i - 1):
#   stimuli[i] TRUE : (prior_alpha + #TRUE before i)  / (prior_alpha + prior_beta + i - 1)
#   stimuli[i] FALSE: (prior_beta  + #FALSE before i) / (prior_alpha + prior_beta + i - 1)
#
# stimuli      logical vector of observations; must not contain NA.
# prior_alpha  initial pseudo-count for TRUE.
# prior_beta   initial pseudo-count for FALSE.
# Returns a numeric vector the same length as `stimuli` (numeric(0) when
# `stimuli` is empty).
get_pp <- function(stimuli, prior_alpha, prior_beta) {
  stopifnot(!anyNA(stimuli))

  trial_index <- seq_along(stimuli)
  # Counts observed strictly before each trial (exclude the current one).
  n_true_before <- cumsum(stimuli) - stimuli
  n_false_before <- (trial_index - 1) - n_true_before
  total_before <- prior_alpha + prior_beta + trial_index - 1

  # Pick the count matching the outcome actually observed on each trial.
  matched_count <- ifelse(stimuli,
                          prior_alpha + n_true_before,
                          prior_beta + n_false_before)
  matched_count / total_before
}

# Surprisal (in nats) of each probability in `pp`: -log(p), elementwise.
# Vectorised, so the result is always a numeric vector the same length as
# `pp` (including numeric(0) for empty input).
get_surprisal <- function(pp) {
  -log(pp)
}

# Elementwise entropy (in nats) of a two-outcome distribution with
# probabilities pp and 1 - pp:
#   H(p) = -(p * log(p) + (1 - p) * log(1 - p))
# H is symmetric in p and 1 - p, so it does not matter which outcome `pp`
# refers to. Uses the standard convention 0 * log(0) = 0, so pp of exactly
# 0 or 1 gives 0 rather than NaN. Returns numeric(0) for empty input.
get_entropy <- function(pp) {
  p_log_p <- function(p) {
    out <- p * log(p)
    # 0 * log(0) evaluates to 0 * -Inf = NaN; define it as 0.
    out[!is.na(p) & p == 0] <- 0
    out
  }
  -(p_log_p(pp) + p_log_p(1 - pp))
}


# Stirling-based approximation of the Beta function B(x, y):
#   B(x, y) ~ sqrt(2 * pi) * x^(x - 1/2) * y^(y - 1/2) / (x + y)^(x + y - 1/2)
# The powers are combined in log space and exponentiated once. Evaluating
# them directly overflows for moderately large arguments: at x = y = 100 both
# numerator and denominator become Inf, giving NaN.
# Elementwise in x and y.
get_beta_function_approximation <- function(x, y) {
  log_ratio <- (x - 1/2) * log(x) +
    (y - 1/2) * log(y) -
    (x + y - 1/2) * log(x + y)
  sqrt(2 * pi) * exp(log_ratio)
}
  
# Two-term asymptotic approximation of the digamma function:
#   psi(x) ~ log(x) - 1 / (2 * x)
# Elementwise in x.
get_diagamma_function_approximation <- function(x) {
  half_inverse <- 1 / (2 * x)
  log(x) - half_inverse
}

# Approximate KL divergence between two Beta distributions, via the closed form
#   KL(Beta(a1, b1) || Beta(a2, b2)) = log B(a2, b2) - log B(a1, b1)
#     + (a1 - a2) psi(a1) + (b1 - b2) psi(b1) + (a2 - a1 + b2 - b1) psi(a1 + b1)
# with (a1, b1) = (old_alpha, old_beta) and (a2, b2) = (new_alpha, new_beta),
# i.e. KL(old || new).
# NOTE(review): information-gain measures are often defined as
# KL(new || old); confirm this direction is the intended one.
#
# log B uses Stirling's approximation and psi uses
# get_diagamma_function_approximation(). The log-Beta difference is computed
# directly in log space: forming B(new) / B(old) from non-log values overflows
# (Inf / Inf = NaN) once pseudo-counts reach the low hundreds. Stirling's
# sqrt(2 * pi) factor cancels in the difference and is omitted.
# All arithmetic is elementwise, so the arguments may be vectors.
calculate_kl <- function(new_alpha, new_beta, old_alpha, old_beta) {
  log_beta_stirling <- function(x, y) {
    (x - 1/2) * log(x) + (y - 1/2) * log(y) - (x + y - 1/2) * log(x + y)
  }
  psi <- get_diagamma_function_approximation

  log_beta_diff <- log_beta_stirling(new_alpha, new_beta) -
    log_beta_stirling(old_alpha, old_beta)

  log_beta_diff +
    (old_alpha - new_alpha) * psi(old_alpha) +
    (old_beta - new_beta) * psi(old_beta) +
    (new_alpha - old_alpha + new_beta - old_beta) * psi(old_alpha + old_beta)
}
  

# Per-trial KL divergence between the pseudo-counts just before and just
# after each observation in `stimuli`, as computed by calculate_kl().
# A TRUE observation adds 1 to alpha; a FALSE observation adds 1 to beta.
#
# stimuli      logical vector of observations; must not contain NA.
# prior_alpha  initial pseudo-count for TRUE.
# prior_beta   initial pseudo-count for FALSE.
# Returns a numeric vector the same length as `stimuli` (empty input gives an
# empty result instead of an error).
#
# Vectorised with cumulative counts; relies on calculate_kl() operating
# elementwise on its arguments.
get_kl <- function(stimuli, prior_alpha, prior_beta) {
  stopifnot(!anyNA(stimuli))

  # Pseudo-counts after observing trial i ...
  new_alpha <- prior_alpha + cumsum(stimuli)
  new_beta <- prior_beta + cumsum(!stimuli)
  # ... and immediately before it.
  old_alpha <- new_alpha - stimuli
  old_beta <- new_beta - !stimuli

  calculate_kl(new_alpha, new_beta, old_alpha, old_beta)
}


# Information gained from a single observation `obs` (scalar TRUE/FALSE),
# given current pseudo-counts (current_alpha for TRUE, current_beta for FALSE).
#
# mode = "kl":        calculate_kl() between the counts after the update
#                     (+1 to alpha for TRUE, +1 to beta for FALSE) and the
#                     current counts.
# mode = "surprisal": -log of the current predictive probability of `obs`.
#
# `mode` is validated up front, so an unsupported value raises a clear
# "'arg' should be one of" error instead of failing on an unassigned result.
get_ig <- function(obs, mode, current_alpha, current_beta) {
  mode <- match.arg(mode, c("kl", "surprisal"))

  if (obs) {
    new_alpha <- current_alpha + 1
    new_beta <- current_beta
    outcome_count <- current_alpha
  } else {
    new_alpha <- current_alpha
    new_beta <- current_beta + 1
    outcome_count <- current_beta
  }

  switch(mode,
    kl = calculate_kl(new_alpha, new_beta, current_alpha, current_beta),
    surprisal = -log(outcome_count / (current_alpha + current_beta))
  )
}


# Expected information gain before each trial of `stimuli`.
#
# Before trial i, the gain of each possible outcome (get_ig() with `mode`) is
# weighted by that outcome's current predictive probability. The pseudo-counts
# are then updated with the outcome actually observed on trial i.
#
# stimuli      logical vector of observations.
# prior_alpha  initial pseudo-count for TRUE.
# prior_beta   initial pseudo-count for FALSE.
# mode         "kl" or "surprisal", passed through to get_ig().
# Returns a numeric vector the same length as `stimuli` (numeric(0) when
# `stimuli` is empty).
get_eig <- function(stimuli, prior_alpha, prior_beta, mode) {
  mode <- match.arg(mode, c("kl", "surprisal"))

  eig_list <- numeric(length(stimuli))
  current_alpha <- prior_alpha
  current_beta <- prior_beta

  for (i in seq_along(stimuli)) {
    total <- current_alpha + current_beta
    pp_true <- current_alpha / total
    pp_false <- current_beta / total
    ig_val_true <- get_ig(TRUE, mode, current_alpha, current_beta)
    ig_val_false <- get_ig(FALSE, mode, current_alpha, current_beta)

    eig_list[i] <- pp_true * ig_val_true + pp_false * ig_val_false

    # Update with the observed outcome only after computing the expectation.
    if (stimuli[i]) {
      current_alpha <- current_alpha + 1
    } else {
      current_beta <- current_beta + 1
    }
  }

  eig_list
}



# Inspect each per-trial measurement for `stimuli` with prior pseudo-counts
# alpha = 1 and beta = 1. The `##` lines record the printed output.
get_pp(stimuli, prior_alpha = 1, prior_beta = 1)
## [1] 0.5000000 0.6666667 0.7500000 0.8000000 0.1666667 0.7142857 0.7500000
## [8] 0.7777778
get_surprisal(get_pp(stimuli, prior_alpha = 1, prior_beta = 1))
## [1] 0.6931472 0.4054651 0.2876821 0.2231436 1.7917595 0.3364722 0.2876821
## [8] 0.2513144
get_entropy(get_pp(stimuli, prior_alpha = 1, prior_beta = 1))
## [1] 0.6931472 0.6365142 0.5623351 0.5004024 0.4505612 0.5982696 0.5623351
## [8] 0.5297062
get_kl(stimuli, prior_alpha = 1, prior_beta = 1)
## [1] 0.27605800 0.09010885 0.04440794 0.02637742 0.45440802 0.02985455 0.02165695
## [8] 0.01643356
get_eig(stimuli, prior_alpha = 1, prior_beta = 1, mode = "kl")
## [1] 0.27605800 0.18212818 0.13594965 0.10849238 0.09028118 0.07582432 0.06624445
## [8] 0.05882217
get_eig(stimuli, prior_alpha = 1, prior_beta = 1, mode = "surprisal")
## [1] 0.6931472 0.6365142 0.5623351 0.5004024 0.4505612 0.5982696 0.5623351
## [8] 0.5297062
# Collect every per-trial measurement into one tibble, one row per trial.
#
# Columns: stimuli (the input), pps (get_pp), surprisal (get_surprisal of pps),
# entropy (get_entropy of pps), kls (get_kl), eigs_kl and eigs_surprisal
# (get_eig with mode "kl" and "surprisal").
get_measurement <- function(stimuli, prior_alpha, prior_beta) {
  predictive <- get_pp(stimuli, prior_alpha, prior_beta)
  kl_values <- get_kl(stimuli, prior_alpha, prior_beta)
  eig_kl_values <- get_eig(stimuli, prior_alpha, prior_beta, mode = "kl")
  eig_surprisal_values <- get_eig(stimuli, prior_alpha, prior_beta,
                                  mode = "surprisal")

  tibble(
    stimuli = stimuli,
    pps = predictive,
    surprisal = get_surprisal(predictive),
    entropy = get_entropy(predictive),
    kls = kl_values,
    eigs_kl = eig_kl_values,
    eigs_surprisal = eig_surprisal_values
  )
}

# Plot each per-trial measurement except the raw predictive probability (pps),
# one facet per measurement, coloured by the stimulus value (TRUE for standard
# trials, FALSE for the deviant).
get_measurement(stimuli, prior_alpha = 1, prior_beta = 1) %>%
  mutate(stimulus_id = row_number()) %>%
  # Select the measurement columns by name rather than by position (3:7), so
  # the plot does not silently change if get_measurement() gains or reorders
  # columns.
  pivot_longer(cols = c(surprisal, entropy, kls, eigs_kl, eigs_surprisal),
               names_to = "measurement_type",
               values_to = "values_type") %>%
  ggplot(aes(x = stimulus_id, y = values_type, color = stimuli)) +
  geom_point() +
  facet_wrap(~measurement_type)