library(tidyverse)
## ── Attaching packages ─────────────────────────────────────── tidyverse 1.3.0 ──
## ✓ ggplot2 3.3.3     ✓ purrr   0.3.4
## ✓ tibble  3.1.0     ✓ dplyr   1.0.5
## ✓ tidyr   1.1.1     ✓ stringr 1.4.0
## ✓ readr   1.3.1     ✓ forcats 0.4.0
## ── Conflicts ────────────────────────────────────────── tidyverse_conflicts() ──
## x dplyr::filter() masks stats::filter()
## x dplyr::lag()    masks stats::lag()
library(patchwork)
library(here)
## here() starts at /Users/caoanjie/Desktop/projects/looking_time/adult_analysis

helper function

stimuli generation

# Default stimulus parameters used by the example blocks below.
# Length of every binary feature vector.
TOTAL_FEATURE_N <- 3
# Number of 1s in a "complex" and in a "simple" stimulus, respectively.
COMPLEX_FEATURE_N <- 1
SIMPLE_FEATURE_N <- 1
# Proportion of features that differ between the background stimulus and a
# "similar" / "dissimilar" deviant.
SIMILAR_RATIO <- .2
DISSIMILAR_RATIO <- .8

# Build one random binary stimulus.
#
# total_feature_n:   length of the returned vector.
# stimuli_feature_n: how many entries are 1; the remaining entries are 0.
#
# Returns the 0/1 vector in a random order.
generate_stimuli_vector <- function(total_feature_n, stimuli_feature_n){

  n_absent <- total_feature_n - stimuli_feature_n
  # zeros first, then ones, before shuffling
  unshuffled <- rep(c(0, 1), times = c(n_absent, stimuli_feature_n))
  sample(unshuffled)

}

# Generate a deviant stimulus from an original one by swapping features.
#
# original_stimuli: 0/1 feature vector.
# similarity:       "similar" selects similar_ratio; any other value selects
#                   dissimilar_ratio.
# similar_ratio:    proportion of the original's features that are changed for
#                   a similar stimulus.
# dissimilar_ratio: proportion of the original's features that are changed for
#                   a dissimilar stimulus.
#
# Returns a vector of the same length in which n of the original 1s are set to
# 0 and n of the original 0s are set to 1, where n is the chosen ratio times
# the number of 1s, truncated to a whole number. Errors if there are fewer 0s
# than n.
generate_stimuli_similarity <- function(original_stimuli, similarity, 
                                        similar_ratio, 
                                        dissimilar_ratio){

  # similarity is a single value, so a plain if/else is enough
  change_ratio <- if (similarity == "similar") similar_ratio else dissimilar_ratio

  # first figure out where the 1s and the 0s are
  feature_pos <- which(original_stimuli %in% c(1))
  non_feature_pos <- which(original_stimuli %in% c(0))

  # Truncate to a whole count (sample() does the same to a fractional size).
  # The small tolerance keeps products such as 0.29 * 100 (28.999...) from
  # losing a feature to floating-point error.
  n_change <- floor(change_ratio * length(feature_pos) +
                      sqrt(.Machine$double.eps))

  # Draw by index: sample(x, n) treats a length-one numeric x as 1:x, which
  # would pick the wrong positions when there is only one candidate.
  pick_positions <- function(positions, n) {
    positions[sample.int(length(positions), n)]
  }

  # change 1 to 0
  new_stim <- replace(original_stimuli,
                      pick_positions(feature_pos, n_change),
                      0)

  # change 0 to 1
  new_stim <- replace(new_stim,
                      pick_positions(non_feature_pos, n_change),
                      1)

  new_stim

}

## Create a block: a list of stimuli made of a repeated background stimulus
## with a deviant stimulus placed at the requested trials.
##
## total_feature_n:   length of each feature vector.
## simple_feature_n:  number of 1s in the background when complexity is not
##                    "complex".
## complex_feature_n: number of 1s in the background when complexity is
##                    "complex".
## similar_ratio, dissimilar_ratio: passed to generate_stimuli_similarity().
## block_length:      number of trials in the block.
## deviant_pos:       trial indices (1-based) that hold the deviant stimulus.
## complexity:        "complex" or anything else (treated as simple).
## similarity:        passed to generate_stimuli_similarity().
##
## Returns a list of block_length feature vectors.
generate_block_sequence <- function(total_feature_n, 
                             simple_feature_n, 
                             complex_feature_n,
                             similar_ratio, 
                             dissimilar_ratio, 
                             block_length, 
                             deviant_pos, 
                             complexity, 
                             similarity){

  # A deviant outside 1..block_length would silently lengthen the block and
  # pad it with NULL trials, so reject it up front.
  stopifnot(all(deviant_pos >= 1 & deviant_pos <= block_length))

  # the number of 1 in the feature vector
  feature_n <- if (complexity == "complex") complex_feature_n else simple_feature_n

  background_stim <- generate_stimuli_vector(total_feature_n, feature_n)
  deviant_stim <- generate_stimuli_similarity(background_stim, 
                                              similarity, 
                                              similar_ratio, 
                                              dissimilar_ratio)

  block_list <- replicate(block_length, background_stim, simplify = FALSE)

  # the same deviant stimulus is used at every deviant position
  block_list[deviant_pos] <- rep(list(deviant_stim), length(deviant_pos))

  block_list

}


# Example 8-trial blocks, one per complexity x similarity condition, each with
# a single deviant at trial 7. Arguments are passed by name so a value cannot
# silently land in the wrong parameter.
example_block_complex_similar <- generate_block_sequence(
  total_feature_n = TOTAL_FEATURE_N,
  simple_feature_n = SIMPLE_FEATURE_N,
  complex_feature_n = COMPLEX_FEATURE_N,
  similar_ratio = SIMILAR_RATIO,
  dissimilar_ratio = DISSIMILAR_RATIO,
  block_length = 8,
  deviant_pos = c(7),
  complexity = "complex",
  similarity = "similar"
)
example_block_complex_dissimilar <- generate_block_sequence(
  total_feature_n = TOTAL_FEATURE_N,
  simple_feature_n = SIMPLE_FEATURE_N,
  complex_feature_n = COMPLEX_FEATURE_N,
  similar_ratio = SIMILAR_RATIO,
  dissimilar_ratio = DISSIMILAR_RATIO,
  block_length = 8,
  deviant_pos = c(7),
  complexity = "complex",
  similarity = "dissimilar"
)
example_block_simple_similar <- generate_block_sequence(
  total_feature_n = TOTAL_FEATURE_N,
  simple_feature_n = SIMPLE_FEATURE_N,
  complex_feature_n = COMPLEX_FEATURE_N,
  similar_ratio = SIMILAR_RATIO,
  dissimilar_ratio = DISSIMILAR_RATIO,
  block_length = 8,
  deviant_pos = c(7),
  complexity = "simple",
  similarity = "similar"
)
example_block_simple_dissimilar <- generate_block_sequence(
  total_feature_n = TOTAL_FEATURE_N,
  simple_feature_n = SIMPLE_FEATURE_N,
  complex_feature_n = COMPLEX_FEATURE_N,
  similar_ratio = SIMILAR_RATIO,
  dissimilar_ratio = DISSIMILAR_RATIO,
  block_length = 8,
  deviant_pos = c(7),
  complexity = "simple",
  similarity = "dissimilar"
)

# Number of features sampled per observation; the commented-out example further
# down passes it to make_observation() as `feature_sample`.
NUM_FEATURE_SAMPLE_PER_OBSERVATION <- 3

# Generate one observation of a stimulus by sampling some of its features.
#
# stimuli:        0/1 feature vector.
# feature_sample: number of present features (1s) to include in the
#                 observation; must not exceed the number of 1s in `stimuli`.
#
# Returns a vector as long as `stimuli` that is 1 at the sampled feature
# positions and 0 everywhere else.
make_observation <- function(stimuli, feature_sample){

  # figure out the feature position
  feature_pos <- which(stimuli %in% c(1))

  # Draw by index: sample(x, n) treats a length-one numeric x as 1:x, so a
  # stimulus with a single feature could be observed at the wrong position.
  sampled_feature <- feature_pos[sample.int(length(feature_pos),
                                            feature_sample)]

  observation <- rep(0, length(stimuli))
  observation[sampled_feature] <- 1

  observation

}

# Shorter demo block: five "complex" trials with a "similar" deviant at
# trial 4. This replaces the 8-trial example of the same name defined earlier.
example_block_complex_similar <- generate_block_sequence(
  total_feature_n = TOTAL_FEATURE_N,
  simple_feature_n = SIMPLE_FEATURE_N,
  complex_feature_n = COMPLEX_FEATURE_N,
  similar_ratio = SIMILAR_RATIO,
  dissimilar_ratio = DISSIMILAR_RATIO,
  block_length = 5,
  deviant_pos = 4,
  complexity = "complex",
  similarity = "similar"
)

# print the block
example_block_complex_similar
## [[1]]
## [1] 0 0 1
## 
## [[2]]
## [1] 0 0 1
## 
## [[3]]
## [1] 0 0 1
## 
## [[4]]
## [1] 0 0 1
## 
## [[5]]
## [1] 0 0 1
#example_observation <- lapply(example_block_complex_similar, make_observation, 
                              #NUM_FEATURE_SAMPLE_PER_OBSERVATION)
#example_observation

naive bayes model

# Calculate the running beta counts for a given block.
#
# block:         non-empty list of 0/1 feature vectors, one per trial.
# feature_prior: starting counts for every feature; element 1 is the count for
#                value 0 and element 2 the count for value 1.
#
# Returns a list of length(block) + 1 snapshots. Snapshot 1 is the prior and
# snapshot t + 1 holds the counts after trial t. Each snapshot is a list with
# one count vector per feature.
get_beta_count <- function(block, feature_prior = c(3,1)){
  stopifnot(length(block) > 0)

  prior <- replicate(length(block[[1]]), feature_prior, simplify = FALSE)

  # add one to the count slot that matches the observed value (0 -> 1, 1 -> 2)
  add_observation <- function(counts, value) {
    counts[value + 1] <- counts[value + 1] + 1
    counts
  }

  beta_count <- vector("list", length(block) + 1)
  beta_count[[1]] <- prior
  for (trial in seq_along(block)) {
    # Map always returns a list, one updated count vector per feature
    beta_count[[trial + 1]] <- Map(add_observation,
                                   beta_count[[trial]],
                                   block[[trial]])
  }

  beta_count
}
  


# Turn beta counts into feature probabilities.
#
# block_beta: list of snapshots, each a list of per-feature count vectors
#             (the structure returned by get_beta_count()).
#
# Returns the same nested structure with every count vector divided by its sum.
get_probability <- function(block_beta){

  normalise <- function(counts) counts / sum(counts)

  lapply(block_beta, function(snapshot) lapply(snapshot, normalise))

}

# Surprise (in bits) of each trial's stimulus under the probabilities held
# before that trial.
#
# block_probability: list of probability snapshots; snapshot t is the belief
#                    before trial t, and the last snapshot (after the final
#                    trial) is not used.
# block_sequence:    list of 0/1 feature vectors, one per trial.
#
# Returns one value per trial: the sum over features of -log2 of the
# probability assigned to the observed feature value.
get_surprise <- function(block_probability, block_sequence){

  # Drop the last snapshot explicitly. The previous `1:length(x)-1` parsed as
  # `(1:n) - 1` and only worked because index 0 is silently ignored.
  belief_before_trial <- utils::head(block_probability, -1)

  trial_surprise <- function(probs, stimulus) {
    # element 1 is P(value = 0), element 2 is P(value = 1)
    p_observed <- mapply(function(p, value) p[value + 1], probs, stimulus)
    sum(-log2(p_observed))
  }

  mapply(trial_surprise, belief_before_trial, block_sequence)

}

# KL divergence between two Bernoulli distributions? See https://math.stackexchange.com/questions/2604566/kl-divergence-between-two-multivariate-bernoulli-distribution#:~:text=The%20KL%20divergence%20between%20two%20such%20distributions%20is%20DK,z)q(z).

# Learning progress per trial: the KL divergence (in bits) from the previous
# probability snapshot to the current one, summed over features.
#
# block_probability: list of probability snapshots, each a list with one
#                    two-element probability vector per feature.
#
# Returns a list with one value per consecutive pair of snapshots, i.e.
# length(block_probability) - 1 entries, or an empty list when there are fewer
# than two snapshots.
get_learning_progress <- function(block_probability){

  # KL(curr || prev) for one feature's two-outcome distribution
  feature_kl <- function(curr, prev) {
    curr[[1]] * log2(curr[[1]] / prev[[1]]) +
      curr[[2]] * log2(curr[[2]] / prev[[2]])
  }

  # seq_len() keeps a single snapshot from producing the backwards range 2:1
  n_update <- max(length(block_probability) - 1, 0)

  lapply(seq_len(n_update), function(step) {
    prev_prob <- block_probability[[step]]
    curr_prob <- block_probability[[step + 1]]
    sum(mapply(feature_kl, curr_prob, prev_prob))
  })
}

# Run the model on the example block: counts -> probabilities -> surprise.
example_beta <- get_beta_count(block = example_block_complex_similar)
example_prob <- get_probability(block_beta = example_beta)
example_surprise <- get_surprise(
  block_probability = example_prob,
  block_sequence = example_block_complex_similar
)

# Plot surprise against trial number as points joined by a line.
#
# surprise: per-trial surprise values, e.g. the result of get_surprise().
#
# Returns a ggplot object.
get_plot_surprise <- function(surprise){
  tibble("surprise" = surprise) %>% 
    mutate(trial = row_number()) %>% 
    unnest(surprise) %>% 
    # y must name the `surprise` column; the earlier `surpise` typo made the
    # plot fail when rendered
    ggplot(aes(x = trial, y = surprise)) + 
    geom_point() + 
    geom_line()
  
}

KEY FUNCTIONS

# Create a block: a list of stimuli made of a repeated background stimulus
# with a deviant stimulus placed at the requested trials.
#
# complexity:        "complex" uses complex_feature_n; any other value uses
#                    simple_feature_n.
# similarity:        passed to generate_stimuli_similarity().
# total_feature_n:   length of each feature vector.
# simple_feature_n:  number of 1s in a simple background stimulus.
# complex_feature_n: number of 1s in a complex background stimulus.
# similar_ratio, dissimilar_ratio: passed to generate_stimuli_similarity().
# block_length:      number of trials in the block.
# deviant_pos:       trial indices (1-based) that hold the deviant stimulus;
#                    NULL or an empty vector gives a block with no deviant.
#
# Returns a list of block_length feature vectors.
get_block_sequence <- function(
  complexity, 
  similarity, 
  total_feature_n, 
  simple_feature_n, 
  complex_feature_n, 
  similar_ratio, 
  dissimilar_ratio, 
  block_length, 
  deviant_pos = c(3, 5)){

  # A deviant outside 1..block_length would silently lengthen the block and
  # pad it with NULL trials, so reject it up front.
  stopifnot(all(deviant_pos >= 1 & deviant_pos <= block_length))

  # the number of 1 in the feature vector
  feature_n <- if (complexity == "complex") complex_feature_n else simple_feature_n

  background_stim <- generate_stimuli_vector(total_feature_n, feature_n)
  deviant_stim <- generate_stimuli_similarity(background_stim, 
                                              similarity, 
                                              similar_ratio, 
                                              dissimilar_ratio)

  block_list <- replicate(block_length, background_stim, simplify = FALSE)

  if (length(deviant_pos) > 0){
    # the same deviant stimulus is used at every deviant position
    block_list[deviant_pos] <- rep(list(deviant_stim), length(deviant_pos))
  }

  block_list

}

# Demo call with no deviant trials (deviant_pos is NULL). The complexity value
# "complexity" is not "complex", so simple_feature_n sets the number of 1s.
get_block_sequence(
  complexity = "complexity",
  similarity = "similar",
  total_feature_n = 10,
  simple_feature_n = 3,
  complex_feature_n = 5,
  similar_ratio = 0.2,
  dissimilar_ratio = 0.8,
  block_length = 7,
  deviant_pos = NULL
)
## [[1]]
##  [1] 0 0 1 0 1 1 0 0 0 0
## 
## [[2]]
##  [1] 0 0 1 0 1 1 0 0 0 0
## 
## [[3]]
##  [1] 0 0 1 0 1 1 0 0 0 0
## 
## [[4]]
##  [1] 0 0 1 0 1 1 0 0 0 0
## 
## [[5]]
##  [1] 0 0 1 0 1 1 0 0 0 0
## 
## [[6]]
##  [1] 0 0 1 0 1 1 0 0 0 0
## 
## [[7]]
##  [1] 0 0 1 0 1 1 0 0 0 0

Trying different parameters; this is not very robust yet.

# Candidate values for each simulation parameter; crossed into a grid below.
complexity <- c("complex", "simple")
similarity <- c("similar", "dissimilar")
# length of each feature vector
feature <- 200
# number of 1s in a simple / complex stimulus
simple_feature <- 50
complex_feature <- 100
# proportion of features changed for a similar / dissimilar deviant
similar_ratio <- c(0.1, 0.3)
dissimilar_ratio <- c(0.5, 0.8)
# number of trials per block
block_length <- 7



# Every combination of the parameter values, keeping only the rows where the
# complex stimulus has more features than the simple one and the similar ratio
# is below the dissimilar ratio.
df_parameter <- crossing(
  complexity, similarity, feature,
  simple_feature, complex_feature,
  similar_ratio, dissimilar_ratio
) %>%
  filter(
    complex_feature > simple_feature,
    similar_ratio < dissimilar_ratio
  )

# Simulate one block per parameter row and run the model on it. The resulting
# list-columns hold the stimulus sequence, beta counts, probabilities,
# per-trial surprise and per-trial learning progress.
df_feature_sim <- df_parameter %>%
  mutate(
    # one block per row; block_length is not a column of df_parameter and is
    # taken from the variable of that name defined above
    sequence = pmap(
      df_parameter,
      function(complexity, similarity, feature,
               simple_feature, complex_feature,
               similar_ratio, dissimilar_ratio) {
        get_block_sequence(complexity, similarity,
                           feature,
                           simple_feature, complex_feature,
                           similar_ratio, dissimilar_ratio,
                           block_length)
      }
    ),
    beta = map(sequence, get_beta_count),
    probability = map(beta, get_probability),
    surprise = map2(probability, sequence, get_surprise),
    learning_progress = map(probability, get_learning_progress),
    complex_simple_diff = complex_feature - simple_feature
  )

look at similarity / dissimilarity ratio

# Long-format table for plotting: one row per parameter combination, trial and
# measurement type ("surprise" or "learning_progress").
df_visualization <- df_feature_sim %>% 
  # learning_progress is a list of single values while surprise is a numeric
  # vector, so the pair is unnested twice to flatten both
  unnest(c(learning_progress, surprise)) %>% 
  unnest(c(learning_progress, surprise)) %>% 
  # number the trials within each parameter combination
  group_by(similarity, complexity, feature, 
           simple_feature, complex_feature, 
           similar_ratio, dissimilar_ratio) %>% 
  mutate(trial = row_number(), 
         block_type = paste(similarity, complexity), 
         similarity_parameter = paste("similar ratio: ", similar_ratio, "dissimilar ratio: ", dissimilar_ratio)) %>% 
  # drop the grouping so later steps do not inherit it by accident
  ungroup() %>% 
  pivot_longer(cols = c(surprise, learning_progress), 
               names_to = "measurement_type", 
               values_to = "value") 


# Surprise per trial, coloured by block type, one panel per combination of
# similar and dissimilar ratio.
ggplot(
  data = filter(df_visualization, measurement_type == "surprise"),
  mapping = aes(x = trial, y = value, color = block_type)
) +
  geom_point() +
  geom_line() +
  facet_wrap(~similarity_parameter) +
  labs(y = "surprise")

# Learning progress (KL divergence) per trial, coloured by block type, one
# panel per combination of similar and dissimilar ratio.
ggplot(
  data = filter(df_visualization, measurement_type == "learning_progress"),
  mapping = aes(x = trial, y = value, color = block_type)
) +
  geom_point() +
  geom_line() +
  facet_wrap(~similarity_parameter) +
  labs(y = "KL divergence")