TLDR

  • General trend: good fit on both train (r = 0.97) and test (r = 0.99)
  • By condition trend: also good fit on both train (r = 0.88) and test (r = 0.98)
  • Issue: the order of dishabituation magnitudes across violation types is off, despite the near-ceiling fit (r = .98)
    • Possible reason: the grid is too narrow to cover the range of embedding distances
    • Possible solution: when run on scaled embeddings (scaled to 0–1), the order is correct and the fit remains high (train r = 0.87; test r = 0.98)

Methods

The general approach for this simulation run is as follows:

  • We first ran an iterative grid search on a training dataset (100 subjects). Most parameter settings in this grid search yielded a relatively good fit (M = 0.94; SD = 0.02).
    • For each parameter setting, we randomly sampled 10 stimulus pairs from the embedding pool and ran 400 rounds on each pair.
    • We fit the average number of samples on each trial, by stimulus sequence type, to the average looking time (LT) on each trial, by stimulus sequence type, in the human behavioral data.
    • We also scaled the model output using this averaged summary of the data.
  • We selected the best-fitting parameter setting (mu_-1_v_1_a_1_b_35_ep_0.001_eig_1e-05) and ran a by-condition experiment:
    • In the by-condition experiment, the model randomly selected 20 stimulus pairs in each condition (e.g., 20 pairs for animacy violation, 20 pairs for identity violation, etc.). For each pair, the model ran 400 rounds of simulation (i.e., to average across different sampled grids).
    • We fit the average number of samples on each trial, by stimulus sequence type AND violation type, to the average LT on each trial, by stimulus sequence type AND violation type, in the human behavioral data.
library(tidyverse)
library(here)
library(ggthemes)

# Train / test splits of the human looking-time data.
test_d <- read_csv(here("data/test_d.csv"))
train_d <- read_csv(here("data/train_d.csv"))

# By-condition simulation output averaged into one mean sample count per
# (sequence, trial), pooling over every other column in the CSV.
# `stim_squence` is the column name as spelled in the CSV. `stimulus_id` is
# 0-based, so add 1 to line up with the human `trial_number`.
# `.groups = "drop"` returns an ungrouped tibble so the grouping does not leak
# into the joins and mutates that consume `sim_d` later.
sim_d <- read_csv(here("data/summarized_results_detailed.csv")) %>%
  rename(stimuli_sequence = stim_squence) %>%
  mutate(trial_number = stimulus_id + 1) %>%
  group_by(stimuli_sequence, trial_number) %>%
  summarise(n_sample = mean(n_sample), .groups = "drop")

Data prepping

tidy up train / test datasets

# Label each block with its stimulus sequence (e.g. "BBBD") from its length
# and whether it is a deviant block. Lengths other than 2/4/6, or a missing
# block_type, yield NA because the case_when has no default branch.
label_stimuli_sequence <- function(d) {
  d %>%
    mutate(
      stimuli_sequence = case_when(
        total_trial_number == 2 & block_type == "deviant_block" ~ "BD",
        total_trial_number == 4 & block_type == "deviant_block" ~ "BBBD",
        total_trial_number == 6 & block_type == "deviant_block" ~ "BBBBBD",
        total_trial_number == 2 & block_type != "deviant_block" ~ "BB",
        total_trial_number == 4 & block_type != "deviant_block" ~ "BBBB",
        total_trial_number == 6 & block_type != "deviant_block" ~ "BBBBBB"
      )
    )
}

# Trial-level human data for each split, reduced to the columns used below.
tidy_train_d <- train_d %>%
  label_stimuli_sequence() %>%
  select(subject, total_rt, trial_number, trial_type, stimuli_sequence)

tidy_test_d <- test_d %>%
  label_stimuli_sequence() %>%
  select(subject, total_rt, trial_number, trial_type, stimuli_sequence)


# Mean human looking time per (trial, trial type, sequence) for each split.
# `.groups = "drop"` keeps the result ungrouped instead of leaving a
# residual (trial_number, trial_type) grouping behind.
tidy_train_summary_d <- tidy_train_d %>%
  group_by(trial_number, trial_type, stimuli_sequence) %>%
  summarise(total_rt = mean(total_rt), .groups = "drop")

tidy_test_summary_d <- tidy_test_d %>%
  group_by(trial_number, trial_type, stimuli_sequence) %>%
  summarise(total_rt = mean(total_rt), .groups = "drop")

Scaling by train and test separately

# Moment-matching statistics for mapping simulated sample counts onto a human
# LT scale. Computed over `sim` left-joined to `human_summary` with
# na.rm = TRUE, so sim rows without a human match still count toward
# mean_sim / sd_sim but not toward mean_rt / sd_rt.
# Returns a one-row tibble: mean_sim, sd_sim, mean_rt, sd_rt.
rt_scaling_info <- function(sim, human_summary,
                            by = c("trial_number", "stimuli_sequence")) {
  sim %>%
    left_join(human_summary, by = by) %>%
    ungroup() %>%
    summarise(
      mean_sim = mean(n_sample, na.rm = TRUE),
      sd_sim = sd(n_sample, na.rm = TRUE),
      mean_rt = mean(total_rt, na.rm = TRUE),
      sd_rt = sd(total_rt, na.rm = TRUE)
    )
}

# Apply an `rt_scaling_info()` result to `sim`: the rescaled series has
# mean mean_rt and SD sd_rt. Then attach the human LT it is compared against.
apply_rt_scaling <- function(sim, info, human_summary,
                             by = c("trial_number", "stimuli_sequence")) {
  multiply_const <- info$sd_rt / info$sd_sim
  sim %>%
    mutate(
      scaled_stim_sample_n = info$mean_rt + (n_sample - info$mean_sim) * multiply_const
    ) %>%
    select(stimuli_sequence, trial_number, scaled_stim_sample_n) %>%
    left_join(human_summary, by = by)
}

# Each split is scaled with its own human statistics.
scaled_train_sim_res_info <- rt_scaling_info(sim_d, tidy_train_summary_d)
scaled_train_res <- apply_rt_scaling(sim_d, scaled_train_sim_res_info, tidy_train_summary_d)

scaled_test_sim_res_info <- rt_scaling_info(sim_d, tidy_test_summary_d)
scaled_test_res <- apply_rt_scaling(sim_d, scaled_test_sim_res_info, tidy_test_summary_d)

Cross-checking with the previous experiment (same parameter setting; that experiment's results have tails)

data cleaning

# Raw per-run output of the earlier grid-search experiment, with every 201st
# row excluded.
# NOTE(review): why every 201st row is dropped is not visible here; confirm
# against the script that wrote model_res_old.csv.
model_d <- read_csv(here("data/model_res_old.csv")) %>%
  filter(row_number() %% 201 != 0)

# Average the six per-trial sample counts (stim1..stim6) over runs for every
# parameter setting x model sequence. Then reshape to one row per trial, keyed
# by a compact `param_info` string such as
# "mu_-1_v_1_a_1_b_35_ep_0.001_eig_1e-05".
# pivot_longer strips the "stim" prefix and converts the remainder to a number,
# so trial_number is 1..6.
all_model_sim_res <- model_d %>%
  group_by(mu_prior, v_prior, alpha_prior, beta_prior, epsilon, world_EIG, stim_sequence) %>%
  summarise(across(all_of(paste0("stim", 1:6)), ~ mean(.x, na.rm = TRUE))) %>%
  pivot_longer(
    cols = all_of(paste0("stim", 1:6)),
    names_to = "trial_number",
    names_prefix = "stim",
    names_transform = list(trial_number = as.numeric),
    values_to = "sample_n"
  ) %>%
  mutate(
    param_info = paste(
      "mu", mu_prior, "v", v_prior, "a", alpha_prior, "b", beta_prior,
      "ep", epsilon, "eig", world_EIG,
      sep = "_"
    )
  ) %>%
  ungroup() %>%
  select(stim_sequence, trial_number, sample_n, param_info)

# Cut one human sequence type out of the six-trial model runs.
# Each human sequence type is the first nchar(seq_type) trials of one model
# sequence: BB/BBBB/BBBBBB <- BBBBBB, BD <- BDBBBB, BBBD <- BBBDBB,
# BBBBBD <- BBBBBD.
#
# all_model_sim_res: long table with `stim_sequence` and `trial_number` columns.
# seq_type: a single string, one of "BB", "BBBB", "BBBBBB", "BD", "BBBD",
#   "BBBBBD".
# Returns the matching rows, with stim_sequence relabelled to seq_type.
# Errors on any other seq_type instead of failing on an undefined variable.
get_exact_sequence <- function(all_model_sim_res, seq_type) {
  stopifnot(is.character(seq_type), length(seq_type) == 1L)

  source_sequence <- switch(
    seq_type,
    BB = , BBBB = , BBBBBB = "BBBBBB",
    BD = "BDBBBB",
    BBBD = "BBBDBB",
    BBBBBD = "BBBBBD",
    stop("Unknown seq_type: ", seq_type, call. = FALSE)
  )

  all_model_sim_res %>%
    filter(stim_sequence == source_sequence) %>%
    filter(trial_number <= nchar(seq_type)) %>%
    mutate(stim_sequence = seq_type)
}

# Parameter settings that have a full set of model rows.
# NOTE(review): 24 presumably = 4 model sequences x 6 trials; confirm.
complete_sim_res <- all_model_sim_res %>%
  group_by(param_info) %>%
  count() %>%
  filter(n == 24)

# Restrict to complete settings once, outside the loop. Then cut each human
# sequence type out of the six-trial runs and stack the pieces.
chopped_sim_res <- local({
  complete_runs <- all_model_sim_res %>%
    filter(param_info %in% complete_sim_res$param_info)

  lapply(
    c("BB", "BD", "BBBB", "BBBD", "BBBBBB", "BBBBBD"),
    function(x) get_exact_sequence(complete_runs, x)
  ) %>%
    bind_rows()
})

scale the old simulation data

# Rescale each grid-search parameter setting to the training-split human LT.
# Moment matching is done within each param_info; the result holds one nested
# table per setting.
scaled_sim_res <- chopped_sim_res %>%
  rename(stimuli_sequence = stim_sequence) %>%
  left_join(tidy_train_summary_d %>% ungroup(), by = c("trial_number", "stimuli_sequence")) %>%
  group_by(param_info) %>%
  summarise(
    mean_sim = mean(sample_n, na.rm = TRUE),
    sd_sim = sd(sample_n, na.rm = TRUE),
    mean_rt = mean(total_rt, na.rm = TRUE),
    sd_rt = sd(total_rt, na.rm = TRUE)
  ) %>%
  left_join(chopped_sim_res %>% rename(stimuli_sequence = stim_sequence), by = "param_info") %>%
  left_join(tidy_train_summary_d, by = c("trial_number", "stimuli_sequence")) %>%
  mutate(multiply_const = sd_rt / sd_sim) %>%
  mutate(scaled_stim_sample_n = mean_rt + (sample_n - mean_sim) * multiply_const) %>%
  select(param_info, stimuli_sequence, sample_n, trial_number, scaled_stim_sample_n, total_rt) %>%
  group_by(param_info) %>%
  nest()

# Per-setting Pearson r between human LT and rescaled samples. map_dbl()
# guarantees one double per setting, unlike unlist(map(...)).
# NOTE(review): cor() yields NA for a setting if any of its rows has no
# matching human total_rt.
scaled_sim_res$r <- map_dbl(scaled_sim_res$data, function(x) {
  cor(x$total_rt, x$scaled_stim_sample_n, method = "pearson")
})

comparing with the test data

# Same per-parameter moment matching as for the training split, but against
# the test-split human LT.
test_scaled_sim_res <- chopped_sim_res %>%
  rename(stimuli_sequence = stim_sequence) %>%
  left_join(tidy_test_summary_d, by = c("trial_number", "stimuli_sequence")) %>%
  group_by(param_info) %>%
  summarise(
    mean_sim = mean(sample_n, na.rm = TRUE),
    sd_sim = sd(sample_n, na.rm = TRUE),
    mean_rt = mean(total_rt, na.rm = TRUE),
    sd_rt = sd(total_rt, na.rm = TRUE)
  ) %>%
  left_join(chopped_sim_res %>% rename(stimuli_sequence = stim_sequence), by = "param_info") %>%
  left_join(tidy_test_summary_d, by = c("trial_number", "stimuli_sequence")) %>%
  mutate(multiply_const = sd_rt / sd_sim) %>%
  mutate(scaled_stim_sample_n = mean_rt + (sample_n - mean_sim) * multiply_const) %>%
  select(param_info, stimuli_sequence, trial_number, scaled_stim_sample_n, total_rt) %>%
  group_by(param_info) %>%
  nest()

# Per-setting Pearson r on the test split (one double per setting).
# NOTE(review): NA for a setting if any of its rows lacks a human total_rt.
test_scaled_sim_res$r <- map_dbl(test_scaled_sim_res$data, function(x) {
  cor(x$total_rt, x$scaled_stim_sample_n, method = "pearson")
})

old sim vs new sim

# Raw (unscaled) sample counts per trial: the best-r setting from the grid
# search ("param_search_exp") next to the by-condition run in sim_d
# ("condition_exp"), one panel per sequence type.
bind_rows(
  scaled_sim_res %>%
    arrange(-r) %>%
    head(1) %>%
    unnest(data) %>%
    ungroup() %>%
    select(stimuli_sequence, sample_n, trial_number) %>%
    mutate(sim_type = "param_search_exp"),
  sim_d %>%
    select(stimuli_sequence, sample_n = n_sample, trial_number) %>%
    mutate(sim_type = "condition_exp")
) %>%
  ggplot(aes(x = trial_number, y = sample_n, group = sim_type, color = sim_type)) +
  stat_summary(fun.data = "mean_cl_boot", position = position_dodge(width = .4)) +
  stat_summary(geom = "line", position = position_dodge(width = .4)) +
  # scale_y_log10() +
  facet_wrap(~ stimuli_sequence) +
  theme_few()

General Test

correlations

# Pearson r between human mean LT and rescaled model samples, per split.
train_r <- with(scaled_train_res, cor(total_rt, scaled_stim_sample_n, method = "pearson"))
test_r <- with(scaled_test_res, cor(total_rt, scaled_stim_sample_n, method = "pearson"))

train_r
## [1] 0.9703497
test_r
## [1] 0.9861476

Visualizing general trend

# Rescaled model samples vs human LT per trial, one row of panels per split.
bind_rows(test = scaled_test_res, train = scaled_train_res, .id = "data_type") %>%
  pivot_longer(
    cols = c(scaled_stim_sample_n, total_rt),
    names_to = "value_type",
    values_to = "value"
  ) %>%
  ggplot(aes(x = trial_number, y = value, group = value_type, color = value_type)) +
  stat_summary(fun.data = "mean_cl_boot", position = position_dodge(width = .4)) +
  stat_summary(geom = "line", position = position_dodge(width = .4)) +
  # scale_y_log10() +
  facet_grid(data_type ~ stimuli_sequence) +
  theme_few()

By condition test

ON TEST DATA

clean up data

# By-condition simulation summary: mean samples per (sequence, trial,
# violation type). Every row of an all-background sequence (BB, BBBB, BBBBBB)
# is then relabelled "all_background" to match the human labels.
# NOTE(review): the relabel happens after averaging. If the CSV records several
# violation types for background sequences, each (sequence, trial) ends up
# with several "all_background" rows, and joins against the human summary
# then repeat the matching human value. Confirm this is intended.
condition_sim_d <- read_csv(here("data/summarized_results_detailed.csv")) %>%
  rename(stimuli_sequence = stim_squence) %>%
  mutate(trial_number = stimulus_id + 1) %>%
  group_by(stimuli_sequence, trial_number, violation_type) %>%
  summarise(n_sample = mean(n_sample)) %>%
  mutate(violation_type = if_else(
    stimuli_sequence %in% c("BB", "BBBB", "BBBBBB"),
    "all_background",
    violation_type
  ))

retidy the human data (initially no condition info)

# Human LT for the test-split subjects, averaged per (violation type, trial,
# sequence). Read from full_human_d, which carries violation_type. The "null"
# violation type is renamed "all_background" to match the simulation labels.
tidy_condition_test_summary_d <- read_csv(here("data/full_human_d.csv")) %>%
  filter(subject %in% unique(tidy_test_d$subject)) %>%
  mutate(
    .is_deviant = block_type == "deviant_block",
    stimuli_sequence = case_when(
      total_trial_number == 2 & .is_deviant ~ "BD",
      total_trial_number == 4 & .is_deviant ~ "BBBD",
      total_trial_number == 6 & .is_deviant ~ "BBBBBD",
      total_trial_number == 2 & !.is_deviant ~ "BB",
      total_trial_number == 4 & !.is_deviant ~ "BBBB",
      total_trial_number == 6 & !.is_deviant ~ "BBBBBB"
    )
  ) %>%
  select(subject, total_rt, violation_type, trial_number, trial_type, stimuli_sequence) %>%
  group_by(violation_type, trial_number, stimuli_sequence) %>%
  summarise(total_rt = mean(total_rt)) %>%
  mutate(violation_type = if_else(violation_type == "null", "all_background", violation_type))

rescale by condition

# Moment-matching statistics: simulated samples vs test-split human LT,
# joined on sequence, trial and violation type.
scaled_condition_test_sim_res_info <- condition_sim_d %>%
  left_join(
    tidy_condition_test_summary_d,
    by = c("trial_number", "stimuli_sequence", "violation_type")
  ) %>%
  ungroup() %>%
  summarise(
    mean_sim = mean(n_sample, na.rm = TRUE),
    sd_sim = sd(n_sample, na.rm = TRUE),
    mean_rt = mean(total_rt, na.rm = TRUE),
    sd_rt = sd(total_rt, na.rm = TRUE)
  )

# Rescale samples to the human mean/SD and attach the human LT.
scaled_condition_test_res <- local({
  info <- scaled_condition_test_sim_res_info
  slope <- info$sd_rt / info$sd_sim
  condition_sim_d %>%
    mutate(scaled_stim_sample_n = info$mean_rt + (n_sample - info$mean_sim) * slope) %>%
    select(stimuli_sequence, trial_number, violation_type, scaled_stim_sample_n) %>%
    left_join(
      tidy_condition_test_summary_d,
      by = c("trial_number", "stimuli_sequence", "violation_type")
    )
})

visualization

# Test split, all-background sequences: model vs human per trial.
scaled_condition_test_res %>%
  filter(violation_type == "all_background") %>%
  pivot_longer(
    cols = c(scaled_stim_sample_n, total_rt),
    names_to = "value_type",
    values_to = "value"
  ) %>%
  ggplot(aes(x = trial_number, y = value, group = value_type, color = value_type)) +
  stat_summary(fun.data = "mean_cl_boot", position = position_dodge(width = .4)) +
  stat_summary(geom = "line", position = position_dodge(width = .4)) +
  # scale_y_log10() +
  facet_grid(violation_type ~ stimuli_sequence) +
  theme_few()

# Test split, each violation type: model vs human per trial.
scaled_condition_test_res %>%
  filter(violation_type != "all_background") %>%
  pivot_longer(
    cols = c(scaled_stim_sample_n, total_rt),
    names_to = "value_type",
    values_to = "value"
  ) %>%
  ggplot(aes(x = trial_number, y = value, group = value_type, color = value_type)) +
  stat_summary(fun.data = "mean_cl_boot", position = position_dodge(width = .4)) +
  stat_summary(geom = "line", position = position_dodge(width = .4)) +
  # scale_y_log10() +
  facet_grid(violation_type ~ stimuli_sequence) +
  theme_few()

correlation

# Pearson r: test-split human LT vs rescaled by-condition samples.
with(scaled_condition_test_res, cor(total_rt, scaled_stim_sample_n, method = "pearson"))
## [1] 0.9820528

aggregated visualization

# Deviant trials only (the final trial of BD, BBBD, BBBBBD): rescaled model
# samples vs human LT by violation type, test split.
scaled_condition_test_res %>%
  semi_join(
    tibble(stimuli_sequence = c("BD", "BBBD", "BBBBBD"), trial_number = c(2, 4, 6)),
    by = c("stimuli_sequence", "trial_number")
  ) %>%
  mutate(trial_type = "deviant") %>%
  pivot_longer(
    cols = c(scaled_stim_sample_n, total_rt),
    names_to = "value_type",
    values_to = "value"
  ) %>%
  ggplot(aes(x = violation_type, y = value, color = value_type)) +
  stat_summary(fun.data = "mean_cl_boot", position = position_dodge(width = .2)) +
  theme_few()

ON TRAIN DATA

# Human LT for the training-split subjects, averaged per (violation type,
# trial, sequence). The "null" violation type becomes "all_background" to
# match the simulation labels.
tidy_condition_train_summary_d <- read_csv(here("data/full_human_d.csv")) %>%
  filter(subject %in% unique(tidy_train_d$subject)) %>%
  mutate(
    .is_deviant = block_type == "deviant_block",
    stimuli_sequence = case_when(
      total_trial_number == 2 & .is_deviant ~ "BD",
      total_trial_number == 4 & .is_deviant ~ "BBBD",
      total_trial_number == 6 & .is_deviant ~ "BBBBBD",
      total_trial_number == 2 & !.is_deviant ~ "BB",
      total_trial_number == 4 & !.is_deviant ~ "BBBB",
      total_trial_number == 6 & !.is_deviant ~ "BBBBBB"
    )
  ) %>%
  select(subject, total_rt, violation_type, trial_number, trial_type, stimuli_sequence) %>%
  group_by(violation_type, trial_number, stimuli_sequence) %>%
  summarise(total_rt = mean(total_rt)) %>%
  mutate(violation_type = if_else(violation_type == "null", "all_background", violation_type))

# Moment-matching statistics: simulated samples vs training-split human LT,
# joined on sequence, trial and violation type.
scaled_condition_train_sim_res_info <- condition_sim_d %>%
  left_join(
    tidy_condition_train_summary_d,
    by = c("trial_number", "stimuli_sequence", "violation_type")
  ) %>%
  ungroup() %>%
  summarise(
    mean_sim = mean(n_sample, na.rm = TRUE),
    sd_sim = sd(n_sample, na.rm = TRUE),
    mean_rt = mean(total_rt, na.rm = TRUE),
    sd_rt = sd(total_rt, na.rm = TRUE)
  )

# Rescale samples to the human mean/SD and attach the human LT.
scaled_condition_train_res <- local({
  info <- scaled_condition_train_sim_res_info
  slope <- info$sd_rt / info$sd_sim
  condition_sim_d %>%
    mutate(scaled_stim_sample_n = info$mean_rt + (n_sample - info$mean_sim) * slope) %>%
    select(stimuli_sequence, trial_number, violation_type, scaled_stim_sample_n) %>%
    left_join(
      tidy_condition_train_summary_d,
      by = c("trial_number", "stimuli_sequence", "violation_type")
    )
})

visualization

# Training split, all-background sequences: model vs human per trial.
scaled_condition_train_res %>%
  filter(violation_type == "all_background") %>%
  pivot_longer(
    cols = c(scaled_stim_sample_n, total_rt),
    names_to = "value_type",
    values_to = "value"
  ) %>%
  ggplot(aes(x = trial_number, y = value, group = value_type, color = value_type)) +
  stat_summary(fun.data = "mean_cl_boot", position = position_dodge(width = .4)) +
  stat_summary(geom = "line", position = position_dodge(width = .4)) +
  # scale_y_log10() +
  facet_grid(violation_type ~ stimuli_sequence) +
  theme_few()

# Training split, each violation type: model vs human per trial.
scaled_condition_train_res %>%
  filter(violation_type != "all_background") %>%
  pivot_longer(
    cols = c(scaled_stim_sample_n, total_rt),
    names_to = "value_type",
    values_to = "value"
  ) %>%
  ggplot(aes(x = trial_number, y = value, group = value_type, color = value_type)) +
  stat_summary(fun.data = "mean_cl_boot", position = position_dodge(width = .4)) +
  stat_summary(geom = "line", position = position_dodge(width = .4)) +
  # scale_y_log10() +
  facet_grid(violation_type ~ stimuli_sequence) +
  theme_few()

correlation

# Pearson r: training-split human LT vs rescaled by-condition samples.
with(scaled_condition_train_res, cor(total_rt, scaled_stim_sample_n, method = "pearson"))
## [1] 0.8797483

aggregated visualization

# Deviant trials only (the final trial of BD, BBBD, BBBBBD): rescaled model
# samples vs human LT by violation type, training split.
scaled_condition_train_res %>%
  semi_join(
    tibble(stimuli_sequence = c("BD", "BBBD", "BBBBBD"), trial_number = c(2, 4, 6)),
    by = c("stimuli_sequence", "trial_number")
  ) %>%
  mutate(trial_type = "deviant") %>%
  pivot_longer(
    cols = c(scaled_stim_sample_n, total_rt),
    names_to = "value_type",
    values_to = "value"
  ) %>%
  ggplot(aes(x = violation_type, y = value, color = value_type)) +
  stat_summary(fun.data = "mean_cl_boot", position = position_dodge(width = .2)) +
  theme_few()

Scaling as possible solution

Embedding distance vs dihab

We noticed that the positive relationship between embedding distance and the raw number of samples on the deviant trial only exists at the lower end of the x-axis, so we hypothesized that our grid may be too narrow. We therefore decided to scale the embeddings to the 0–1 range.

# Raw deviant-trial sample counts against |b_val - d_val| (the embedding
# distance between background and deviant) for the epsilon = 0.001 runs,
# one panel per deviant sequence.
read_csv(here("data/summarized_results_detailed_new.csv")) %>%
  # tolerant comparison instead of floating-point ==
  filter(near(epsilon, 0.001)) %>%
  filter(stim_squence %in% c("BD", "BBBD", "BBBBBD")) %>%
  # the deviant is the last stimulus of each sequence; stimulus_id is 0-based
  filter(stimulus_id == nchar(stim_squence) - 1) %>%
  mutate(abs_magnitude = abs(b_val - d_val)) %>%
  distinct(violation_type, abs_magnitude, stim_squence, n_sample, epsilon) %>%
  # Plot the raw counts directly. The previous summarise(mean_n_sample =
  # n_sample) returned several rows per group (deprecated since dplyr 1.1)
  # and did not compute a mean.
  ggplot(aes(x = abs_magnitude, y = n_sample)) +
  geom_point(alpha = .05) +
  facet_grid(~ stim_squence) +
  geom_smooth() +
  theme_few()

Scaled version

(Note that this is a very rough attempt – we scale the embeddings after PCA, but probably should have done so before.)

distance

# Simulation results from the run on 0-1 scaled embeddings.
se_d <- read_csv(here("data/summarized_results_scaled.csv"))

# Distribution of background-deviant embedding distance by violation type,
# one panel per deviant sequence.
se_d %>%
  filter(stim_squence %in% c("BD", "BBBD", "BBBBBD")) %>%
  distinct(violation_type, abs_magnitude = abs(b_val - d_val), stim_squence) %>%
  ggplot(aes(x = violation_type, y = abs_magnitude)) +
  stat_summary(fun.data = "mean_cl_boot") +
  geom_jitter(position = position_jitter(width = .2), alpha = .3) +
  theme_few() +
  facet_wrap(~ stim_squence)

### clean up data

# Mean samples per (sequence, trial, violation type) for the scaled-embedding
# run. All-background sequences are relabelled "all_background".
# NOTE(review): the relabel happens after averaging. If background sequences
# carry several violation types in the CSV, each (sequence, trial) gets
# several "all_background" rows. Confirm this is intended.
se_condition_sim_d <- se_d %>%
  rename(stimuli_sequence = stim_squence) %>%
  mutate(trial_number = stimulus_id + 1) %>%
  group_by(stimuli_sequence, trial_number, violation_type) %>%
  summarise(n_sample = mean(n_sample)) %>%
  mutate(violation_type = if_else(
    stimuli_sequence %in% c("BB", "BBBB", "BBBBBB"),
    "all_background",
    violation_type
  ))

FIT TEST DATA

# Moment-matching statistics: scaled-embedding samples vs test-split human LT.
scaled_se_condition_test_sim_res_info <- se_condition_sim_d %>%
  left_join(
    tidy_condition_test_summary_d,
    by = c("trial_number", "stimuli_sequence", "violation_type")
  ) %>%
  ungroup() %>%
  summarise(
    mean_sim = mean(n_sample, na.rm = TRUE),
    sd_sim = sd(n_sample, na.rm = TRUE),
    mean_rt = mean(total_rt, na.rm = TRUE),
    sd_rt = sd(total_rt, na.rm = TRUE)
  )

# Rescale samples to the human mean/SD and attach the human LT.
scaled_se_condition_test_res <- local({
  info <- scaled_se_condition_test_sim_res_info
  slope <- info$sd_rt / info$sd_sim
  se_condition_sim_d %>%
    mutate(scaled_stim_sample_n = info$mean_rt + (n_sample - info$mean_sim) * slope) %>%
    select(stimuli_sequence, trial_number, violation_type, scaled_stim_sample_n) %>%
    left_join(
      tidy_condition_test_summary_d,
      by = c("trial_number", "stimuli_sequence", "violation_type")
    )
})

visualization

# Scaled embeddings, test split, all-background sequences: model vs human.
scaled_se_condition_test_res %>%
  filter(violation_type == "all_background") %>%
  pivot_longer(
    cols = c(scaled_stim_sample_n, total_rt),
    names_to = "value_type",
    values_to = "value"
  ) %>%
  ggplot(aes(x = trial_number, y = value, group = value_type, color = value_type)) +
  stat_summary(fun.data = "mean_cl_boot", position = position_dodge(width = .4)) +
  stat_summary(geom = "line", position = position_dodge(width = .4)) +
  # scale_y_log10() +
  facet_grid(violation_type ~ stimuli_sequence) +
  theme_few()

# Scaled embeddings, test split, each violation type: model vs human.
scaled_se_condition_test_res %>%
  filter(violation_type != "all_background") %>%
  pivot_longer(
    cols = c(scaled_stim_sample_n, total_rt),
    names_to = "value_type",
    values_to = "value"
  ) %>%
  ggplot(aes(x = trial_number, y = value, group = value_type, color = value_type)) +
  stat_summary(fun.data = "mean_cl_boot", position = position_dodge(width = .4)) +
  stat_summary(geom = "line", position = position_dodge(width = .4)) +
  # scale_y_log10() +
  facet_grid(violation_type ~ stimuli_sequence) +
  theme_few()

correlation

# Pearson r: test-split human LT vs rescaled scaled-embedding samples.
with(scaled_se_condition_test_res, cor(total_rt, scaled_stim_sample_n, method = "pearson"))
## [1] 0.9777929

aggregated visualization

# Deviant trials only (the final trial of BD, BBBD, BBBBBD): scaled-embedding
# model vs human LT by violation type, test split.
scaled_se_condition_test_res %>%
  semi_join(
    tibble(stimuli_sequence = c("BD", "BBBD", "BBBBBD"), trial_number = c(2, 4, 6)),
    by = c("stimuli_sequence", "trial_number")
  ) %>%
  mutate(trial_type = "deviant") %>%
  pivot_longer(
    cols = c(scaled_stim_sample_n, total_rt),
    names_to = "value_type",
    values_to = "value"
  ) %>%
  ggplot(aes(x = violation_type, y = value, color = value_type)) +
  stat_summary(fun.data = "mean_cl_boot", position = position_dodge(width = .2)) +
  theme_few()

FIT TRAIN DATA

# Moment-matching statistics: scaled-embedding samples vs training-split
# human LT.
scaled_se_condition_train_sim_res_info <- se_condition_sim_d %>%
  left_join(
    tidy_condition_train_summary_d,
    by = c("trial_number", "stimuli_sequence", "violation_type")
  ) %>%
  ungroup() %>%
  summarise(
    mean_sim = mean(n_sample, na.rm = TRUE),
    sd_sim = sd(n_sample, na.rm = TRUE),
    mean_rt = mean(total_rt, na.rm = TRUE),
    sd_rt = sd(total_rt, na.rm = TRUE)
  )

# Rescale samples to the human mean/SD and attach the human LT.
scaled_se_condition_train_res <- local({
  info <- scaled_se_condition_train_sim_res_info
  slope <- info$sd_rt / info$sd_sim
  se_condition_sim_d %>%
    mutate(scaled_stim_sample_n = info$mean_rt + (n_sample - info$mean_sim) * slope) %>%
    select(stimuli_sequence, trial_number, violation_type, scaled_stim_sample_n) %>%
    left_join(
      tidy_condition_train_summary_d,
      by = c("trial_number", "stimuli_sequence", "violation_type")
    )
})

visualization

# Scaled embeddings, training split, all-background sequences: model vs human.
scaled_se_condition_train_res %>%
  filter(violation_type == "all_background") %>%
  pivot_longer(
    cols = c(scaled_stim_sample_n, total_rt),
    names_to = "value_type",
    values_to = "value"
  ) %>%
  ggplot(aes(x = trial_number, y = value, group = value_type, color = value_type)) +
  stat_summary(fun.data = "mean_cl_boot", position = position_dodge(width = .4)) +
  stat_summary(geom = "line", position = position_dodge(width = .4)) +
  # scale_y_log10() +
  facet_grid(violation_type ~ stimuli_sequence) +
  theme_few()

# Scaled embeddings, training split, each violation type: model vs human.
scaled_se_condition_train_res %>%
  filter(violation_type != "all_background") %>%
  pivot_longer(
    cols = c(scaled_stim_sample_n, total_rt),
    names_to = "value_type",
    values_to = "value"
  ) %>%
  ggplot(aes(x = trial_number, y = value, group = value_type, color = value_type)) +
  stat_summary(fun.data = "mean_cl_boot", position = position_dodge(width = .4)) +
  stat_summary(geom = "line", position = position_dodge(width = .4)) +
  # scale_y_log10() +
  facet_grid(violation_type ~ stimuli_sequence) +
  theme_few()

correlation

# Pearson r: training-split human LT vs rescaled scaled-embedding samples.
with(scaled_se_condition_train_res, cor(total_rt, scaled_stim_sample_n, method = "pearson"))
## [1] 0.8728853

aggregated visualization

# Deviant trials only (the final trial of BD, BBBD, BBBBBD): scaled-embedding
# model vs human LT by violation type, training split.
scaled_se_condition_train_res %>%
  semi_join(
    tibble(stimuli_sequence = c("BD", "BBBD", "BBBBBD"), trial_number = c(2, 4, 6)),
    by = c("stimuli_sequence", "trial_number")
  ) %>%
  mutate(trial_type = "deviant") %>%
  pivot_longer(
    cols = c(scaled_stim_sample_n, total_rt),
    names_to = "value_type",
    values_to = "value"
  ) %>%
  ggplot(aes(x = violation_type, y = value, color = value_type)) +
  stat_summary(fun.data = "mean_cl_boot", position = position_dodge(width = .2)) +
  theme_few()