library(tidyverse)
## ── Attaching packages ─────────────────────────────────────── tidyverse 1.3.0 ──
## ✓ ggplot2 3.3.2     ✓ purrr   0.3.4
## ✓ tibble  3.0.4     ✓ dplyr   1.0.2
## ✓ tidyr   1.1.1     ✓ stringr 1.4.0
## ✓ readr   1.3.1     ✓ forcats 0.4.0
## ── Conflicts ────────────────────────────────────────── tidyverse_conflicts() ──
## x dplyr::filter() masks stats::filter()
## x dplyr::lag()    masks stats::lag()
library(permute)

The goal is to understand the relationship between information-theoretic measures computed over sequences.

Computational reproducibility exercise with Poli, Serino, Mars, Hunnius (2020). Infants tailor their attention to maximize learning. Science Advances.

Each sequence is a series of target locations in a grid with four quadrants.

Sequences

We created the sequences in MATLAB. First, 16 sequences were sampled pseudo-randomly, with the probabilities specified above as the only constraint. Then, the sequences were concatenated. To check that the target location could be predicted only by relying on cue-target conditional probabilities, we fed the result of the sampling into a machine learning random forest classifier. If the classifier was able to reliably predict the target location with no information about the cue-target conditional probabilities (e.g., it successfully predicted the target location at trial N only based on target location at trial N-1), then the entire process was repeated and new sequences were sampled.
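A minimal sketch of the sampling step only, in base R rather than MATLAB. The probabilities below are hypothetical placeholders, because the values "specified above" in the paper are not reproduced in this notebook, and the random forest check is left out entirely.

```r
# Hypothetical target-location probabilities (assumption, not the paper's values).
location_probs <- c(0.6, 0.2, 0.1, 0.1)

set.seed(1)  # arbitrary seed so the sketch is repeatable

# Draw a 12-trial sequence of quadrant indices (1..4) with those probabilities.
sampled_seq <- sample(1:4, size = 12, replace = TRUE, prob = location_probs)
sampled_seq

# Empirical frequency of each quadrant in the sampled sequence.
sampled_freq <- table(factor(sampled_seq, levels = 1:4)) / length(sampled_seq)
sampled_freq
```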

#sample_seq <- rep(1:4, 4)[shuffle(1:15)]
sample_seq <- c(1,1,3,4,1,1,4,1,2,1,1,4)

Learning model

The model is presented with a set of events \(x\). An event is, for example, the target appearing in the upper left corner of the screen. The events followed each other until the sequence ended (or the infant looked away). The last event of a sequence, which also coincides with the length of the sequence, is named \(j\), and the sequence can thus be denoted by \(X^j = \{x_1, \dots, x_j\}\). The first goal of the model is to estimate the probability with which a certain event \(x\) will occur. Given that the target can appear in one of four possible locations \(k\), the distribution of probabilities can be parameterized by the random vector \(p = [p_1, \dots, p_K]\), where \(p_k\) is the probability of the target appearing in the \(k\)th location. In our specific case, the target locations are four, and thus, \(p = [p_1, p_2, p_3, p_4]\). The ideal learner treats \(p_{1:4}\) as parameters that must be estimated trial by trial given \(X^j\). In other words, given the past events up until the current trial, the ideal learner will estimate the probabilities with which the target will appear in any of the four possible target locations. At the very beginning of each sequence, the ideal learner expects the target to appear in one of the four target locations with equal probability. This is expressed here as a prior Dirichlet distribution

\[ P(p | \alpha) = \mathrm{Dir}(p; \alpha_k) \tag{1} \]

where all elements of \(\alpha\) are equal to one, \(\alpha = [1,1,1,1]\). In this case, the parameter \(\alpha\) determines prior expectations. If there is an imbalance between the values of \(\alpha\) (e.g., \(\alpha = [100,1,1,1]\)), this means that the model is biased into thinking that the target is more likely to appear in the one location (\(p_1\) in the example) rather than the others.
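To make the role of the prior concrete, here is a small base R sketch. The helper names `dirichlet_mean` and `rdirichlet_sketch` are made up for this notebook; the sampler uses the standard construction of a Dirichlet draw as normalized independent Gamma draws.

```r
# Prior mean of a Dirichlet: each alpha_k divided by the sum of all alphas.
dirichlet_mean <- function(alpha) alpha / sum(alpha)

# Draw from a Dirichlet by normalizing independent Gamma(alpha_k, 1) draws.
# (Hypothetical helper; base R has no built-in Dirichlet sampler.)
rdirichlet_sketch <- function(n_draws, alpha) {
  g <- matrix(rgamma(n_draws * length(alpha), shape = alpha),
              nrow = n_draws, byrow = TRUE)  # column k uses alpha[k]
  g / rowSums(g)
}

dirichlet_mean(c(1, 1, 1, 1))    # weak uniform prior: 0.25 for every location
dirichlet_mean(c(100, 1, 1, 1))  # strongly biased toward location 1

set.seed(1)  # arbitrary seed
draws <- rdirichlet_sketch(5, c(1, 1, 1, 1))
draws  # each row is one plausible probability vector p under the prior
```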

The Dirichlet-multinomial is the multivariate generalization of the beta-binomial; with only two categories the two are the same thing.

Play with beta-binomial:

x = seq(0,1,.01)
# shape parameters are the number of counts
plot(x, dbeta(x, shape1 = 200,shape2 = 400))

plot(x, dbeta(x, shape1= 20, shape2 = 400))
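The link between the two can be sketched directly: under a Dirichlet with pseudo-counts \(\alpha\), the marginal distribution of a single \(p_k\) is a Beta with shape parameters \(\alpha_k\) and \(\sum \alpha - \alpha_k\). The pseudo-counts below are illustrative (one observation of location 1 on top of a uniform prior).

```r
alpha <- c(2, 1, 1, 1)  # illustrative pseudo-counts
k <- 1                  # location whose marginal we look at

shape1 <- alpha[k]               # counts for location k
shape2 <- sum(alpha) - alpha[k]  # counts for all other locations combined

x <- seq(0, 1, 0.01)
marginal_density <- dbeta(x, shape1 = shape1, shape2 = shape2)
# plot(x, marginal_density, type = "l")  # uncomment to view the marginal

# The Beta mean matches the Dirichlet mean for that location.
marginal_mean <- shape1 / (shape1 + shape2)
marginal_mean  # 0.4
```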

prior <- c(1,1,1,1)

Conversely, when the numbers are equal, the ideal learner has no biases toward any location. Moreover, high numbers indicate that the model has strong expectations, while low numbers indicate that the model will quickly change its expectations when presented with new evidence. Thus, specifying \(\alpha = [1,1,1,1]\), we are defining a weak uniform prior distribution. In other words, the model has no bias toward any location but is ready to change these expectations if presented with contradicting evidence.

At every trial, the prior distribution is updated given the observation of the new event \(x\) from the set \(X^j\). The posterior distribution of such update is given by

\[ P(p | X^j, \alpha) = \mathrm{Dir}(p^j; n^j_k + \alpha_k) \tag{2} \]

where \(n^j_k\) refers to the number of outcomes of type \(k\) observed up until the trial \(j\). As a practical example, imagine that, at trial 1, the model observes a target in the location 1 (i.e., [1, 0, 0, 0]). The values of \(\alpha\) will be updated with the evidence accumulated, thus moving from [1, 1, 1, 1] to [2, 1, 1, 1]. This implies that now it is slightly more likely to see the target in location 1 than in any of the other locations. Specifically, the probability of the target appearing in any location can be computed from the posterior distribution \(P(p | X^j, \alpha)\) in the following fashion

\[ p(x_j = k | X^{j-1}, \alpha) = \frac{n^{j-1}_k + 1}{(j-1) + K} \tag{3} \]

In words, how likely the target is to appear in a certain corner is given by the total number of times it appeared in that corner, plus one (the value of \(\alpha\)), divided by the total number of observations, plus 4 (the sum of the values of \(\alpha\)). This updating rule implies that as evidence accumulates, new evidence will weigh less. Given that our sequences are stationary (i.e., the most likely location does not change within the same sequence), this assumption is justified for the current task.

At every trial \(j\), the posterior Dirichlet distribution of trial \(j-1\) becomes the new prior distribution. The new prior is updated using (2) and the probability estimates are computed using (3). When infants look away and a new sequence is played, the prior is set back to (1). This means that we assume that when infants start looking at a new sequence, they consider it as independent of the previous ones. Previous research in adults demonstrated the suitability of this assumption (24).

n <- 12
model <- tibble(trials = 0:n, 
                bin1 = 0, 
                bin2 = 0, 
                bin3 = 0, 
                bin4 = 0)
# set the prior 
model[1,2:5] <- as.list(prior)

for (i in 1:length(sample_seq)) {
  model[model$trials == i, 2:5] <- model[model$trials == i-1, 2:5]
  model[model$trials == i, sample_seq[i]+1] <-  
    model[model$trials == i-1, sample_seq[i]+1] + 1
}

# i = 1 
model
## # A tibble: 13 x 5
##    trials  bin1  bin2  bin3  bin4
##     <int> <dbl> <dbl> <dbl> <dbl>
##  1      0     1     1     1     1
##  2      1     2     1     1     1
##  3      2     3     1     1     1
##  4      3     3     1     2     1
##  5      4     3     1     2     2
##  6      5     4     1     2     2
##  7      6     5     1     2     2
##  8      7     5     1     2     3
##  9      8     6     1     2     3
## 10      9     6     2     2     3
## 11     10     7     2     2     3
## 12     11     8     2     2     3
## 13     12     8     2     2     4
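As a cross-check on the counts table, the same update can be written without a loop in base R: one-hot encode the observations, take cumulative sums, and add the prior. This is just an alternative sketch, and the variable names are chosen for this example.

```r
sample_seq <- c(1, 1, 3, 4, 1, 1, 4, 1, 2, 1, 1, 4)
prior <- c(1, 1, 1, 1)
K <- length(prior)

# One-hot encode each observation (rows = trials, columns = locations).
one_hot <- outer(sample_seq, seq_len(K), "==") * 1

# Cumulative counts per location, with a leading row of zeros for trial 0.
cum_counts <- rbind(0, apply(one_hot, 2, cumsum))

# Posterior pseudo-counts: add the prior to every row.
posterior_counts <- sweep(cum_counts, 2, prior, "+")

posterior_counts[2, ]                       # after trial 1:  2 1 1 1
posterior_counts[nrow(posterior_counts), ]  # after trial 12: 8 2 2 4
```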

Imagine you were computing a simple maximum-likelihood (ML) estimate of the probability. Then you’d just count the number of times you saw a particular outcome and normalize:

\[ p(x_j = k|X^{j-1}) = \frac{n^{j-1}_k}{j-1} \] Now all we are going to do is add a prior \(\alpha\) such that every bin has \(\alpha\) added to it:

\[ p(x_j = k|X^{j-1},\alpha) = \frac{n^{j-1}_k + \alpha}{(j-1) + K\alpha} \] So when \(\alpha = 1\), we get the Poli et al. formula:

\[ p(x_j = k|X^{j-1},\alpha) = \frac{n^{j-1}_k + 1}{(j-1) + K} \]

So that means all we need to do is get our counts (which we did above), and then we normalize to get probabilities.
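A quick numeric comparison of the two estimates, using illustrative counts after three observations (locations 1, 1, 3). One thing the sketch makes visible: the ML estimate assigns zero probability to unseen locations, which would make their surprisal infinite, while the smoothed estimate does not.

```r
counts <- c(2, 0, 1, 0)  # illustrative counts after j - 1 = 3 observations
K <- length(counts)
alpha <- 1

# Plain ML estimate: counts divided by the number of observations.
p_ml <- counts / sum(counts)

# Smoothed estimate: add alpha to every bin before normalizing.
p_smoothed <- (counts + alpha) / (sum(counts) + K * alpha)

p_ml        # 2/3, 0, 1/3, 0
p_smoothed  # 3/7, 1/7, 2/7, 1/7
-log2(p_ml[2])        # Inf: an unseen location is "impossible" under ML
-log2(p_smoothed[2])  # finite under the smoothed estimate
```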

library(magrittr)
## 
## Attaching package: 'magrittr'
## The following object is masked from 'package:purrr':
## 
##     set_names
## The following object is masked from 'package:tidyr':
## 
##     extract
model %<>%
  rowwise() %>%
  mutate(total = sum(bin1 + bin2 + bin3 + bin4),
         p1 = bin1 / total,
         p2 = bin2 / total,
         p3 = bin3 / total,
         p4 = bin4 / total)

model$observed <- c(NA, sample_seq)
model
## # A tibble: 13 x 11
## # Rowwise: 
##    trials  bin1  bin2  bin3  bin4 total    p1     p2    p3    p4 observed
##     <int> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl>  <dbl> <dbl> <dbl>    <dbl>
##  1      0     1     1     1     1     4 0.25  0.25   0.25  0.25        NA
##  2      1     2     1     1     1     5 0.4   0.2    0.2   0.2          1
##  3      2     3     1     1     1     6 0.5   0.167  0.167 0.167        1
##  4      3     3     1     2     1     7 0.429 0.143  0.286 0.143        3
##  5      4     3     1     2     2     8 0.375 0.125  0.25  0.25         4
##  6      5     4     1     2     2     9 0.444 0.111  0.222 0.222        1
##  7      6     5     1     2     2    10 0.5   0.1    0.2   0.2          1
##  8      7     5     1     2     3    11 0.455 0.0909 0.182 0.273        4
##  9      8     6     1     2     3    12 0.5   0.0833 0.167 0.25         1
## 10      9     6     2     2     3    13 0.462 0.154  0.154 0.231        2
## 11     10     7     2     2     3    14 0.5   0.143  0.143 0.214        1
## 12     11     8     2     2     3    15 0.533 0.133  0.133 0.2          1
## 13     12     8     2     2     4    16 0.5   0.125  0.125 0.25         4

Measures

Surprisal

\[ I(x_j = k) = -\log_2 p(x_j = k | X^{j-1}, \alpha) \]

# Surprisal of the event at trial i, given the probabilities estimated up to trial i - 1
model$surprisal <- 0
for (i in 1:length(sample_seq)){
  curr_bin = sample_seq[i]
  curr_bin_column = paste0("p", curr_bin)
  prev_probability = model %>%
    filter(trials == i-1) %>% 
    select(all_of(curr_bin_column)) %>%  # all_of() avoids the ambiguous-selection message
    pull()
  
  curr_surprisal = -log2(prev_probability)
  
  model[model$trials == i, ]$surprisal <- curr_surprisal
  
}

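# Same computation, indexed by row instead of by trial: row i holds trial i - 1,
# and row i - 1 holds the probabilities estimated before that trial's event.
# Row 1 (trial 0) has no observation, so its surprisal stays NA.
model$surprisal <- NA_real_
for (i in 1:(n+1)) {
  
  p_lasttrial <- c(model$p1[i-1], model$p2[i-1], model$p3[i-1], model$p4[i-1])
  model$surprisal[i] <- -log2( p_lasttrial[model$observed[i]]  )
}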
model
## # A tibble: 13 x 12
## # Rowwise: 
##    trials  bin1  bin2  bin3  bin4 total    p1     p2    p3    p4 observed
##     <int> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl>  <dbl> <dbl> <dbl>    <dbl>
##  1      0     1     1     1     1     4 0.25  0.25   0.25  0.25        NA
##  2      1     2     1     1     1     5 0.4   0.2    0.2   0.2          1
##  3      2     3     1     1     1     6 0.5   0.167  0.167 0.167        1
##  4      3     3     1     2     1     7 0.429 0.143  0.286 0.143        3
##  5      4     3     1     2     2     8 0.375 0.125  0.25  0.25         4
##  6      5     4     1     2     2     9 0.444 0.111  0.222 0.222        1
##  7      6     5     1     2     2    10 0.5   0.1    0.2   0.2          1
##  8      7     5     1     2     3    11 0.455 0.0909 0.182 0.273        4
##  9      8     6     1     2     3    12 0.5   0.0833 0.167 0.25         1
## 10      9     6     2     2     3    13 0.462 0.154  0.154 0.231        2
## 11     10     7     2     2     3    14 0.5   0.143  0.143 0.214        1
## 12     11     8     2     2     3    15 0.533 0.133  0.133 0.2          1
## 13     12     8     2     2     4    16 0.5   0.125  0.125 0.25         4
## # … with 1 more variable: surprisal <dbl>
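The printed tibble truncates the surprisal column, so here is a self-contained base R sketch that recomputes it from the sequence alone and shows the values. It carries the pseudo-counts along trial by trial instead of looking them up in `model`.

```r
sample_seq <- c(1, 1, 3, 4, 1, 1, 4, 1, 2, 1, 1, 4)
prior <- c(1, 1, 1, 1)

surprisal <- numeric(length(sample_seq))
counts <- prior
for (j in seq_along(sample_seq)) {
  # p(x_j = k | X^{j-1}, alpha): probabilities before seeing event j
  p_before <- counts / sum(counts)
  surprisal[j] <- -log2(p_before[sample_seq[j]])
  # Update the pseudo-counts with the observed location.
  counts[sample_seq[j]] <- counts[sample_seq[j]] + 1
}

round(surprisal, 3)
surprisal[1]  # 2: under the uniform prior each location has probability 1/4
```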

Predictability

Note that, different from surprise, here, predictability is estimated considering also the event \(j\), and not just up to \(j-1\). This formula was applied when relating predictability to infants’ looking away and looking time, as they have the information relative to trial \(j\) when they decide whether to look away and when they look at the target of trial \(j\). However, saccadic latencies do not depend on \(X^j\) but rather on \(X^{j-1}\), as when planning a saccade toward the target of trial \(j\), the target has not appeared yet. Hence, a formula slightly different from (5) was used when relating predictability to saccadic latencies, in which \(X^j\) was replaced by \(X^{j-1}\).

\[ -H(p^{j}) = \sum_{k=1}^{K} p(x_j = k | X^{j}, \alpha) \log_2 p(x_j = k | X^{j}, \alpha) \]

model$predictability <- NA_real_
for (i in 1:n){  # trials run from 0 to n, so there is no trial n + 1
  current_trial_predict = model %>% 
    filter(trials == i) %>% 
    select(p1,p2,p3,p4) %>% 
    pivot_longer(p1:p4, names_to = "prob") %>% 
    mutate(log2_value = log2(value), 
           product = value * log2_value) %>% 
    summarise(sum(product)) %>% 
    pull()

  model[model$trials == i, ]$predictability <- current_trial_predict
  
  
}

model
## # A tibble: 13 x 13
## # Rowwise: 
##    trials  bin1  bin2  bin3  bin4 total    p1     p2    p3    p4 observed
##     <int> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl>  <dbl> <dbl> <dbl>    <dbl>
##  1      0     1     1     1     1     4 0.25  0.25   0.25  0.25        NA
##  2      1     2     1     1     1     5 0.4   0.2    0.2   0.2          1
##  3      2     3     1     1     1     6 0.5   0.167  0.167 0.167        1
##  4      3     3     1     2     1     7 0.429 0.143  0.286 0.143        3
##  5      4     3     1     2     2     8 0.375 0.125  0.25  0.25         4
##  6      5     4     1     2     2     9 0.444 0.111  0.222 0.222        1
##  7      6     5     1     2     2    10 0.5   0.1    0.2   0.2          1
##  8      7     5     1     2     3    11 0.455 0.0909 0.182 0.273        4
##  9      8     6     1     2     3    12 0.5   0.0833 0.167 0.25         1
## 10      9     6     2     2     3    13 0.462 0.154  0.154 0.231        2
## 11     10     7     2     2     3    14 0.5   0.143  0.143 0.214        1
## 12     11     8     2     2     3    15 0.533 0.133  0.133 0.2          1
## 13     12     8     2     2     4    16 0.5   0.125  0.125 0.25         4
## # … with 2 more variables: surprisal <dbl>, predictability <dbl>
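Predictability as defined here is just the negative entropy of the current probability estimate, so it can be sketched as a one-line base R function. The name `neg_entropy` is made up for this example; the two inputs are the trial 0 and trial 12 rows of the table.

```r
# Negative Shannon entropy in bits: sum over k of p_k * log2(p_k).
neg_entropy <- function(p) {
  p <- p[p > 0]  # treat 0 * log2(0) as 0
  sum(p * log2(p))
}

neg_entropy(rep(0.25, 4))                # uniform estimate: -2 (least predictable)
neg_entropy(c(0.5, 0.125, 0.125, 0.25))  # estimate after trial 12: -1.75
```

Values closer to zero mean the estimate is more peaked, i.e. the next location is more predictable.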

Learning progress

\[ D_{KL}(p^j || p^{j-1}) = \sum_{k=1}^{K} p(x_j = k | X^{j}, \alpha) \log_2 \frac{p(x_j = k | X^{j}, \alpha)}{p(x_j = k | X^{j-1}, \alpha)} \]

where \(p^j\) is the estimate of the parameters \(p_{1:K}\) at trial \(j\), while \(p^{j-1}\) is the estimate of the parameters \(p_{1:K}\) that was performed on the previous trial \(j-1\). Learning progress has been defined as the reduction in the error of an agent’s prediction (15). \(D_{KL}\) is the divergence between a weighted average of prediction error at trial \(j\) and a weighted average of prediction error at trial \(j-1\), and hence, it is a suitable way to model learning progress in this task.

model$learning_progress <- NA_real_
for (i in 1:n){  # trials run from 0 to n, so there is no trial n + 1

  #i = 1 
  
  previous_trial_prob <- model %>% 
    filter(trials == i-1) %>% 
    select(p1:p4) %>% 
    rename(bin1 = p1, bin2 = p2, bin3 = p3, bin4 = p4) %>% 
    pivot_longer(bin1:bin4, names_to = "bin", values_to = "prev_prob")
  
  
  
  current_trial_prob <- model %>% 
    filter(trials == i) %>% 
    select(p1:p4) %>% 
    rename(bin1 = p1, bin2 = p2, bin3 = p3, bin4 = p4) %>% 
    pivot_longer(bin1:bin4, names_to = "bin", values_to = "curr_prob")
  
  trial_bin = left_join(previous_trial_prob, 
                        current_trial_prob, 
                        by = "bin")
  
  d_bin = trial_bin %>% 
    mutate(d_bin = curr_prob * log2(curr_prob/prev_prob)) %>% 
    summarise(sum(d_bin)) %>% 
    pull()

  
  model[model$trials == i, ]$learning_progress <- d_bin
  
}

model
## # A tibble: 13 x 14
## # Rowwise: 
##    trials  bin1  bin2  bin3  bin4 total    p1     p2    p3    p4 observed
##     <int> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl>  <dbl> <dbl> <dbl>    <dbl>
##  1      0     1     1     1     1     4 0.25  0.25   0.25  0.25        NA
##  2      1     2     1     1     1     5 0.4   0.2    0.2   0.2          1
##  3      2     3     1     1     1     6 0.5   0.167  0.167 0.167        1
##  4      3     3     1     2     1     7 0.429 0.143  0.286 0.143        3
##  5      4     3     1     2     2     8 0.375 0.125  0.25  0.25         4
##  6      5     4     1     2     2     9 0.444 0.111  0.222 0.222        1
##  7      6     5     1     2     2    10 0.5   0.1    0.2   0.2          1
##  8      7     5     1     2     3    11 0.455 0.0909 0.182 0.273        4
##  9      8     6     1     2     3    12 0.5   0.0833 0.167 0.25         1
## 10      9     6     2     2     3    13 0.462 0.154  0.154 0.231        2
## 11     10     7     2     2     3    14 0.5   0.143  0.143 0.214        1
## 12     11     8     2     2     3    15 0.533 0.133  0.133 0.2          1
## 13     12     8     2     2     4    16 0.5   0.125  0.125 0.25         4
## # … with 3 more variables: surprisal <dbl>, predictability <dbl>,
## #   learning_progress <dbl>
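The learning-progress measure is a KL divergence between the current and the previous probability estimate, which fits in a small base R function. The name `kl_divergence` is made up here, and the sketch assumes all probabilities are strictly positive, which holds whenever every prior pseudo-count is positive.

```r
# KL divergence in bits from the previous estimate to the current one.
# Assumes every entry of p_curr and p_prev is strictly positive.
kl_divergence <- function(p_curr, p_prev) {
  sum(p_curr * log2(p_curr / p_prev))
}

p_prev <- c(1, 1, 1, 1) / 4  # estimate at trial 0 (uniform prior)
p_curr <- c(2, 1, 1, 1) / 5  # estimate after a target at location 1

lp_trial1 <- kl_divergence(p_curr, p_prev)
lp_trial1
kl_divergence(p_prev, p_prev)  # 0: no change in the estimate, no learning progress
```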

plot

df.plot <- model %>% 
  select(trials, surprisal, predictability, learning_progress) %>% 
  pivot_longer(cols = c("surprisal", "predictability", "learning_progress"), 
               names_to = "measure", 
               values_to = "value") %>% 
  filter(trials != 0)

surprise_plot <- df.plot %>% 
  filter(measure == "surprisal") %>% 
  ggplot(aes(x = trials, y = value)) + 
  geom_point() + 
  geom_line() + 
  ylab("surprise") + 
   scale_x_continuous(breaks =seq(1,12,1))

predictability_plot <- df.plot %>% 
  filter(measure == "predictability") %>% 
  ggplot(aes(x = trials, y = value)) + 
  geom_point() + 
  geom_line() + 
  ylab("predictability") + 
  scale_x_continuous(breaks =seq(1,12,1))

learning_progress_plot <- df.plot %>% 
  filter(measure == "learning_progress") %>% 
  ggplot(aes(x = trials, y = value)) + 
  geom_point() + 
  geom_line() + 
  ylab("learning_progress") + 
  scale_x_continuous(breaks =seq(1,12,1))

surprise_plot

predictability_plot

learning_progress_plot

library(patchwork)
surprise_plot+predictability_plot+learning_progress_plot + plot_layout(ncol = 1)

try pokebaby

sample_seq <- c(1,1,1,1,1,1,2,1)
n <- 8
model <- tibble(trials = 0:n, 
                bin1 = 0, 
                bin2 = 0, 
                bin3 = 0)
# set the prior 
prior <- c(1,1,1)
model[1,2:4] <- as.list(prior)

for (i in 1:length(sample_seq)) {
  model[model$trials == i, 2:4] <- model[model$trials == i-1, 2:4]
  model[model$trials == i, sample_seq[i]+1] <-  
    model[model$trials == i-1, sample_seq[i]+1] + 1
}

# i = 1 
model
## # A tibble: 9 x 4
##   trials  bin1  bin2  bin3
##    <int> <dbl> <dbl> <dbl>
## 1      0     1     1     1
## 2      1     2     1     1
## 3      2     3     1     1
## 4      3     4     1     1
## 5      4     5     1     1
## 6      5     6     1     1
## 7      6     7     1     1
## 8      7     7     2     1
## 9      8     8     2     1

probability

model %<>%
  rowwise() %>%
  mutate(total = sum(bin1 + bin2 + bin3),
         p1 = bin1 / total,
         p2 = bin2 / total,
         p3 = bin3 / total)

model$observed <- c(NA, sample_seq)
model
## # A tibble: 9 x 9
## # Rowwise: 
##   trials  bin1  bin2  bin3 total    p1    p2     p3 observed
##    <int> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl>  <dbl>    <dbl>
## 1      0     1     1     1     3 0.333 0.333 0.333        NA
## 2      1     2     1     1     4 0.5   0.25  0.25          1
## 3      2     3     1     1     5 0.6   0.2   0.2           1
## 4      3     4     1     1     6 0.667 0.167 0.167         1
## 5      4     5     1     1     7 0.714 0.143 0.143         1
## 6      5     6     1     1     8 0.75  0.125 0.125         1
## 7      6     7     1     1     9 0.778 0.111 0.111         1
## 8      7     7     2     1    10 0.7   0.2   0.1           2
## 9      8     8     2     1    11 0.727 0.182 0.0909        1

surprisal

model$surprisal <- NA_real_
for (i in 1:length(sample_seq)){
  curr_bin = sample_seq[i]
  curr_bin_column = paste0("p", curr_bin)
  prev_probability = model %>%
    filter(trials == i-1) %>% 
    select(all_of(curr_bin_column)) %>% 
    pull()
  
  curr_surprisal = -log2(prev_probability)
  
  model[model$trials == i, ]$surprisal <- curr_surprisal
  
}

predictability

model$predictability <- NA_real_
for (i in 1:n){  # trials run from 0 to n, so there is no trial n + 1
  current_trial_predict = model %>% 
    filter(trials == i) %>% 
    select(p1,p2,p3) %>% 
    pivot_longer(p1:p3, names_to = "prob") %>% 
    mutate(log2_value = log2(value), 
           product = value * log2_value) %>% 
    summarise(sum(product)) %>% 
    pull()

  model[model$trials == i, ]$predictability <- current_trial_predict
  
  
}

model
## # A tibble: 9 x 11
## # Rowwise: 
##   trials  bin1  bin2  bin3 total    p1    p2     p3 observed surprisal
##    <int> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl>  <dbl>    <dbl>     <dbl>
## 1      0     1     1     1     3 0.333 0.333 0.333        NA    NA    
## 2      1     2     1     1     4 0.5   0.25  0.25          1     1.58 
## 3      2     3     1     1     5 0.6   0.2   0.2           1     1    
## 4      3     4     1     1     6 0.667 0.167 0.167         1     0.737
## 5      4     5     1     1     7 0.714 0.143 0.143         1     0.585
## 6      5     6     1     1     8 0.75  0.125 0.125         1     0.485
## 7      6     7     1     1     9 0.778 0.111 0.111         1     0.415
## 8      7     7     2     1    10 0.7   0.2   0.1           2     3.17 
## 9      8     8     2     1    11 0.727 0.182 0.0909        1     0.515
## # … with 1 more variable: predictability <dbl>

learning progress

model$learning_progress <- NA_real_
for (i in 1:n){  # trials run from 0 to n, so there is no trial n + 1

  #i = 1 
  
  previous_trial_prob <- model %>% 
    filter(trials == i-1) %>% 
    select(p1:p3) %>% 
    rename(bin1 = p1, bin2 = p2, bin3 = p3) %>% 
    pivot_longer(bin1:bin3, names_to = "bin", values_to = "prev_prob")
  
  
  
  current_trial_prob <- model %>% 
    filter(trials == i) %>% 
    select(p1:p3) %>% 
    rename(bin1 = p1, bin2 = p2, bin3 = p3) %>% 
    pivot_longer(bin1:bin3, names_to = "bin", values_to = "curr_prob")
  
  trial_bin = left_join(previous_trial_prob, 
                        current_trial_prob, 
                        by = "bin")
  
  d_bin = trial_bin %>% 
    mutate(d_bin = curr_prob * log2(curr_prob/prev_prob)) %>% 
    summarise(sum(d_bin)) %>% 
    pull()

  
  model[model$trials == i, ]$learning_progress <- d_bin
  
}

model
## # A tibble: 9 x 12
## # Rowwise: 
##   trials  bin1  bin2  bin3 total    p1    p2     p3 observed surprisal
##    <int> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl>  <dbl>    <dbl>     <dbl>
## 1      0     1     1     1     3 0.333 0.333 0.333        NA    NA    
## 2      1     2     1     1     4 0.5   0.25  0.25          1     1.58 
## 3      2     3     1     1     5 0.6   0.2   0.2           1     1    
## 4      3     4     1     1     6 0.667 0.167 0.167         1     0.737
## 5      4     5     1     1     7 0.714 0.143 0.143         1     0.585
## 6      5     6     1     1     8 0.75  0.125 0.125         1     0.485
## 7      6     7     1     1     9 0.778 0.111 0.111         1     0.415
## 8      7     7     2     1    10 0.7   0.2   0.1           2     3.17 
## 9      8     8     2     1    11 0.727 0.182 0.0909        1     0.515
## # … with 2 more variables: predictability <dbl>, learning_progress <dbl>

plot

df.plot <- model %>% 
  select(trials, surprisal, predictability, learning_progress) %>% 
  pivot_longer(cols = c("surprisal", "predictability", "learning_progress"), 
               names_to = "measure", 
               values_to = "value") %>% 
  filter(trials != 0)

surprise_plot <- df.plot %>% 
  filter(measure == "surprisal") %>% 
  ggplot(aes(x = trials, y = value)) + 
  geom_point() + 
  geom_line() + 
  ylab("surprise") + 
   scale_x_continuous(breaks =seq(1,12,1))

predictability_plot <- df.plot %>% 
  filter(measure == "predictability") %>% 
  ggplot(aes(x = trials, y = value)) + 
  geom_point() + 
  geom_line() + 
  ylab("predictability") + 
  scale_x_continuous(breaks =seq(1,12,1))

learning_progress_plot <- df.plot %>% 
  filter(measure == "learning_progress") %>% 
  ggplot(aes(x = trials, y = value)) + 
  geom_point() + 
  geom_line() + 
  ylab("learning_progress") + 
  scale_x_continuous(breaks =seq(1,12,1))



library(patchwork)
surprise_plot+predictability_plot+learning_progress_plot + plot_layout(ncol = 1)

functionalize

poli_model_pokebaby <- function (num_trial, seq){
  # num_trial: number of trials in the sequence
  # seq: observed bin (1, 2, or 3) on each trial
  # Uses only its arguments, not the globals `n` and `sample_seq`.
  
  model <- tibble(trials = 0:num_trial, 
                  bin1 = 0, 
                  bin2 = 0, 
                  bin3 = 0)
  
  prior <- c(1,1,1)
  model[1,2:4] <- as.list(prior)

  # updating prior 
  for (i in 1:length(seq)) {
    model[model$trials == i, 2:4] <- model[model$trials == i-1, 2:4]
    model[model$trials == i, seq[i]+1] <- 
      model[model$trials == i-1, seq[i]+1] + 1
  }
  
  # calculating probability 
  model %<>%
    rowwise() %>%
    mutate(total = sum(bin1 + bin2 + bin3),
           p1 = bin1 / total,
           p2 = bin2 / total,
           p3 = bin3 / total)

  model$observed <- c(NA, seq)
  
  # calculating surprisal 
  model$surprisal <- NA_real_
  for (i in 1:length(seq)){
    curr_bin = seq[i]
    curr_bin_column = paste0("p", curr_bin)
    prev_probability = model %>%
      filter(trials == i-1) %>% 
      select(all_of(curr_bin_column)) %>% 
      pull()
    
    curr_surprisal = -log2(prev_probability)
    
    model[model$trials == i, ]$surprisal <- curr_surprisal
  }
  
  # calculating predictability
  model$predictability <- NA_real_
  for (i in 1:num_trial){
    current_trial_predict = model %>% 
      filter(trials == i) %>% 
      select(p1,p2,p3) %>% 
      pivot_longer(p1:p3, names_to = "prob") %>% 
      mutate(log2_value = log2(value), 
             product = value * log2_value) %>% 
      summarise(sum(product)) %>% 
      pull()

    model[model$trials == i, ]$predictability <- current_trial_predict
  }
  
  # calculating learning progress
  model$learning_progress <- NA_real_
  for (i in 1:num_trial){
  
    previous_trial_prob <- model %>% 
      filter(trials == i-1) %>% 
      select(p1:p3) %>% 
      rename(bin1 = p1, bin2 = p2, bin3 = p3) %>% 
      pivot_longer(bin1:bin3, names_to = "bin", values_to = "prev_prob")
    
    current_trial_prob <- model %>% 
      filter(trials == i) %>% 
      select(p1:p3) %>% 
      rename(bin1 = p1, bin2 = p2, bin3 = p3) %>% 
      pivot_longer(bin1:bin3, names_to = "bin", values_to = "curr_prob")
    
    trial_bin = left_join(previous_trial_prob, 
                          current_trial_prob, 
                          by = "bin")
    
    d_bin = trial_bin %>% 
      mutate(d_bin = curr_prob * log2(curr_prob/prev_prob)) %>% 
      summarise(sum(d_bin)) %>% 
      pull()
    
    model[model$trials == i, ]$learning_progress <- d_bin
    
  }
  
  return(model)
  
}

poli_model_pokebaby(8, c(1,1,1,1,1,1,2,1))
## # A tibble: 9 x 12
## # Rowwise: 
##   trials  bin1  bin2  bin3 total    p1    p2     p3 observed surprisal
##    <int> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl>  <dbl>    <dbl>     <dbl>
## 1      0     1     1     1     3 0.333 0.333 0.333        NA    NA    
## 2      1     2     1     1     4 0.5   0.25  0.25          1     1.58 
## 3      2     3     1     1     5 0.6   0.2   0.2           1     1    
## 4      3     4     1     1     6 0.667 0.167 0.167         1     0.737
## 5      4     5     1     1     7 0.714 0.143 0.143         1     0.585
## 6      5     6     1     1     8 0.75  0.125 0.125         1     0.485
## 7      6     7     1     1     9 0.778 0.111 0.111         1     0.415
## 8      7     7     2     1    10 0.7   0.2   0.1           2     3.17 
## 9      8     8     2     1    11 0.727 0.182 0.0909        1     0.515
## # … with 2 more variables: predictability <dbl>, learning_progress <dbl>
poli_model_plot <- function(model){
  df.plot <- model %>% 
  select(trials, surprisal, predictability, learning_progress) %>% 
  pivot_longer(cols = c("surprisal", "predictability", "learning_progress"), 
               names_to = "measure", 
               values_to = "value") %>% 
  filter(trials != 0)

surprise_plot <- df.plot %>% 
  filter(measure == "surprisal") %>% 
  ggplot(aes(x = trials, y = value)) + 
  geom_point() + 
  geom_line() + 
  ylab("surprise") + 
   scale_x_continuous(breaks =seq(1,12,1))

predictability_plot <- df.plot %>% 
  filter(measure == "predictability") %>% 
  ggplot(aes(x = trials, y = value)) + 
  geom_point() + 
  geom_line() + 
  ylab("predictability") + 
  scale_x_continuous(breaks =seq(1,12,1))

learning_progress_plot <- df.plot %>% 
  filter(measure == "learning_progress") %>% 
  ggplot(aes(x = trials, y = value)) + 
  geom_point() + 
  geom_line() + 
  ylab("learning_progress") + 
  scale_x_continuous(breaks =seq(1,12,1))

  surprise_plot+predictability_plot+learning_progress_plot + plot_layout(ncol = 1)

  
}

poli_model_plot(model)
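The pokebaby function is fixed to three bins. Below is a rough base R sketch of how the same three measures could be computed for any number of bins in a single pass over the sequence. The function name `ideal_learner_measures` and its arguments are made up for this example, and it returns a plain data frame rather than the tibble used above.

```r
# Hypothetical generalization: obs is the observed bin per trial (values in 1..K),
# alpha is the prior pseudo-count added to every bin.
ideal_learner_measures <- function(obs, K, alpha = 1) {
  stopifnot(all(obs %in% seq_len(K)))
  counts <- rep(alpha, K)
  out <- data.frame(trial = seq_along(obs),
                    observed = obs,
                    surprisal = NA_real_,
                    predictability = NA_real_,
                    learning_progress = NA_real_)
  for (j in seq_along(obs)) {
    p_prev <- counts / sum(counts)          # estimate before event j
    counts[obs[j]] <- counts[obs[j]] + 1    # posterior update
    p_curr <- counts / sum(counts)          # estimate after event j
    out$surprisal[j] <- -log2(p_prev[obs[j]])
    out$predictability[j] <- sum(p_curr * log2(p_curr))
    out$learning_progress[j] <- sum(p_curr * log2(p_curr / p_prev))
  }
  out
}

res <- ideal_learner_measures(c(1, 1, 1, 1, 1, 1, 2, 1), K = 3)
round(res$surprisal, 3)  # should agree with the surprisal column printed above
```

Trial 0 is not included in the output, since none of the three measures is defined before the first event.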