library(ggplot2)
library(data.table)

Load and visualize weekly transition probabilities

# Load predicted weekly transition probabilities. Columns used below:
# from_state, to_state, days_since_last_contact, pred_prob, pred_se and
# baseline_prob (presumably one row per transition and contact lag — confirm
# against the CSV).
weekly_state_state_preds <- fread("~/Downloads/CGM Data/weekly_pred_probs.csv")

states <- c("green", "yellow", "red")

# Colour lookup shared by every layer that maps `color`: the to_state values
# and the outline colours below. Yellow is drawn as orange for contrast.
# Every key is named so the mapping does not depend on how ggplot2 pairs
# unnamed `values` with `breaks`/limits (that pairing changed across releases).
state_colors <- c(
  green = "green",
  yellow = "orange",
  red = "red",
  orange = "orange"
)

# Border colour for every panel of a given from_state.
outline <- data.frame(
  from_state = states,
  outline_color = c("green", "orange", "red")
)
# Rectangle covering the full panel extent; merge() with no shared columns is
# a cross join, so each from_state gets its own copy of the rectangle.
square <- data.frame(
  x = c(-Inf, Inf, Inf, -Inf),
  y = c(-Inf, -Inf, Inf, Inf)
)

# Predicted probability (+/- 2 SE) against days since last contact, one panel
# per from_state/to_state pair, with the baseline probability dashed.
ggplot(
  weekly_state_state_preds,
  aes(
    x = days_since_last_contact,
    y = pred_prob,
    ymin = pred_prob - 2 * pred_se,
    ymax = pred_prob + 2 * pred_se,
    color = to_state,
    group = to_state
  )
) +
  geom_line() +
  geom_point() +
  geom_errorbar(width = 0.1) +
  facet_wrap(from_state ~ to_state, scales = "free_y", shrink = TRUE) +
  scale_color_manual(values = state_colors) +
  theme_minimal() +
  # fill = NA is a fixed setting rather than a mapping, so no fill scale is
  # needed to keep the border hollow.
  geom_polygon(
    data = merge(outline, square),
    aes(x = x, y = y, color = outline_color),
    fill = NA,
    inherit.aes = FALSE
  ) +
  geom_vline(xintercept = 0, linetype = "dashed") +
  theme(legend.position = "none") +
  geom_line(
    aes(x = days_since_last_contact, y = baseline_prob, group = to_state),
    linetype = "dashed",
    inherit.aes = FALSE
  )

Initialize patients and constants

# Starting cohort: how many patients begin in each state
# (g = green, y = yellow, r = red), and the resulting total.
PATIENTS_STARTING_IN_EACH_STATE <- list(g = 40, y = 20, r = 40)
NUM_PATIENTS <- Reduce(`+`, PATIENTS_STARTING_IN_EACH_STATE)

# Number of weekly time steps (matrix columns) in every simulated trajectory.
WEEKS_TO_SIMULATE <- 52

# Slice of trans_probs' third dimension holding the baseline probabilities.
TRANS_PROB_INDEX_TO_USE_AS_BASELINE <- 6

Store transition probabilities in a three-dimensional array indexed as [from_state, to_state, k], where index k holds the probabilities k - 1 weeks after the last contact and the 6th index holds the baseline. Each [from_state, , k] slice can be used directly as sampling probabilities without normalization.

# Transition probabilities indexed [from_state, to_state, k]. States follow
# the order of `states` (1 = green, 2 = yellow, 3 = red, the coding used by
# the simulations), and k - 1 is the number of weeks since the last contact;
# the 6th slice is the baseline.
#
# array() fills column-major, so from_state must vary fastest: it is the last
# sort key and days_since_last_contact the first. States are ranked by their
# position in `states` rather than as strings, because plain string sorting
# gives green, red, yellow and would silently swap the yellow and red rows
# and columns relative to the 1/2/3 coding above.
trans_probs <- array(
  weekly_state_state_preds[
    order(
      days_since_last_contact,
      match(to_state, states),
      match(from_state, states)
    ),
    pred_prob
  ],
  dim = c(3, 3, 6),
  dimnames = list(from_state = states, to_state = states, NULL)
)

# example, get probabilities for baseline patient in green state
trans_probs[1, , 6]
##     green    yellow       red 
## 0.6860465 0.1395349 0.1744186 

Simulate one policy with interventions every four weeks and another with no interventions

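Store patient states in a matrix with one row per patient and one column per weekly time step. Contacts are stored in a matrix of the same shape, with a 1 in each week a patient is contacted (0 otherwise); the time since the last contact is derived from it.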

# Policy 1: contact every patient once every CONTACT_INTERVAL_WEEKS weeks.
# Rows are patients and columns are weekly time steps. States are coded
# 1 = green, 2 = yellow, 3 = red; patient_contacts[p, t] is 1 when patient p
# is contacted in week t and 0 otherwise.
CONTACT_INTERVAL_WEEKS <- 4

patient_states <- matrix(
  data = NA, nrow = NUM_PATIENTS, ncol = WEEKS_TO_SIMULATE
)
patient_contacts <- matrix(
  data = 0, nrow = NUM_PATIENTS, ncol = WEEKS_TO_SIMULATE
)

# initialize states
patient_states[, 1] <- c(
  rep(1, PATIENTS_STARTING_IN_EACH_STATE[["g"]]),
  rep(2, PATIENTS_STARTING_IN_EACH_STATE[["y"]]),
  rep(3, PATIENTS_STARTING_IN_EACH_STATE[["r"]])
)

for (t in seq_len(WEEKS_TO_SIMULATE - 1)) {
  for (p in seq_len(NUM_PATIENTS)) {
    current_state <- patient_states[p, t]

    # Column of the most recent contact (the last 1 in the row; columns after
    # t are still 0), or NA if the patient has never been contacted. The "+ 1"
    # converts a position in the reversed row back to a column index.
    last_contact_week <- WEEKS_TO_SIMULATE + 1 -
      match(1, rev(patient_contacts[p, ]))
    # Whole weeks elapsed since that contact (1 in the week after a contact).
    time_since_last_contact <- t - last_contact_week

    # Contact patients who were never contacted or are due again; this gives
    # contacts in weeks 1, 1 + CONTACT_INTERVAL_WEEKS, 1 + 2 * ..., etc.
    if (is.na(time_since_last_contact) ||
        time_since_last_contact >= CONTACT_INTERVAL_WEEKS) {
      time_since_last_contact <- 0
      patient_contacts[p, t] <- 1
    }

    # trans_probs[, , k] holds probabilities k - 1 weeks after contact; use the
    # baseline slice once the lag reaches the baseline index.
    if (time_since_last_contact >= TRANS_PROB_INDEX_TO_USE_AS_BASELINE) {
      current_trans_probs <-
        trans_probs[current_state, , TRANS_PROB_INDEX_TO_USE_AS_BASELINE]
    } else {
      current_trans_probs <-
        trans_probs[current_state, , time_since_last_contact + 1]
    }

    # Sample next step from transition probabilities
    patient_states[p, t + 1] <- sample.int(3, 1, prob = current_trans_probs)
  }
}

# Reshape the patient-by-week matrices into long tables. Flattening a matrix
# walks it column by column, so col()/row() supply the matching week and
# patient ids for every flattened element.
state_dt <- data.table(
  state = as.double(patient_states),
  time = as.vector(col(patient_states)),
  patient = as.vector(row(patient_states))
)
set(
  state_dt,
  j = "state_name",
  value = c("green", "yellow", "red")[state_dt$state]
)

contact_dt <- data.table(
  contact = as.double(patient_contacts),
  time = as.vector(col(patient_contacts)),
  patient = as.vector(row(patient_contacts))
)

# One row per (patient, week) carrying both the state and the contact flag.
state_contact_dt <- merge(state_dt, contact_dt, by = c("patient", "time"))

# Share of patient-weeks spent in each state plus bookkeeping counts, shown
# as a long (variable, value) table.
melt(state_contact_dt[, list(
  p_time_in_green = mean(as.numeric(state_name == "green")),
  p_time_in_yellow = mean(as.numeric(state_name == "yellow")),
  p_time_in_red = mean(as.numeric(state_name == "red")),
  num_contacts = sum(contact),
  num_patients = uniqueN(patient),
  num_time_steps = uniqueN(time),
  num_obs = .N
)])
##            variable    value
## 1:  p_time_in_green    0.405
## 2: p_time_in_yellow    0.370
## 3:    p_time_in_red    0.225
## 4:     num_contacts 1300.000
## 5:     num_patients  100.000
## 6:   num_time_steps   52.000
## 7:          num_obs 5200.000
# Policy 2: no interventions. patient_contacts is never set, so every patient
# has no contact on record and every step uses the baseline probabilities.
# States are coded 1 = green, 2 = yellow, 3 = red.
patient_states <- matrix(
  data = NA, nrow = NUM_PATIENTS, ncol = WEEKS_TO_SIMULATE
)
patient_contacts <- matrix(
  data = 0, nrow = NUM_PATIENTS, ncol = WEEKS_TO_SIMULATE
)

# initialize states
patient_states[, 1] <- c(
  rep(1, PATIENTS_STARTING_IN_EACH_STATE[["g"]]),
  rep(2, PATIENTS_STARTING_IN_EACH_STATE[["y"]]),
  rep(3, PATIENTS_STARTING_IN_EACH_STATE[["r"]])
)

for (t in seq_len(WEEKS_TO_SIMULATE - 1)) {
  for (p in seq_len(NUM_PATIENTS)) {
    current_state <- patient_states[p, t]

    # Column of the most recent contact (NA if never contacted); the "+ 1"
    # converts a position in the reversed row back to a column index.
    last_contact_week <- WEEKS_TO_SIMULATE + 1 -
      match(1, rev(patient_contacts[p, ]))
    # Whole weeks elapsed since that contact (NA when there was none).
    time_since_last_contact <- t - last_contact_week

    # No contacts

    if (is.na(time_since_last_contact) ||
        time_since_last_contact >= TRANS_PROB_INDEX_TO_USE_AS_BASELINE) {
      current_trans_probs <-
        trans_probs[current_state, , TRANS_PROB_INDEX_TO_USE_AS_BASELINE]
    } else {
      current_trans_probs <-
        trans_probs[current_state, , time_since_last_contact + 1]
    }

    # Sample next step from transition probabilities
    patient_states[p, t + 1] <- sample.int(3, 1, prob = current_trans_probs)
  }
}

# Long-format views of the simulated matrices. A matrix flattens column by
# column, so col() gives each element's week and row() its patient.
state_dt <- data.table(
  state = as.double(patient_states),
  time = as.vector(col(patient_states)),
  patient = as.vector(row(patient_states))
)
set(
  state_dt,
  j = "state_name",
  value = c("green", "yellow", "red")[state_dt$state]
)

contact_dt <- data.table(
  contact = as.double(patient_contacts),
  time = as.vector(col(patient_contacts)),
  patient = as.vector(row(patient_contacts))
)

# Join states and contact flags on (patient, week).
state_contact_dt <- merge(state_dt, contact_dt, by = c("patient", "time"))

# Fraction of patient-weeks in each state and simple counts, as a long table.
melt(state_contact_dt[, list(
  p_time_in_green = mean(as.numeric(state_name == "green")),
  p_time_in_yellow = mean(as.numeric(state_name == "yellow")),
  p_time_in_red = mean(as.numeric(state_name == "red")),
  num_contacts = sum(contact),
  num_patients = uniqueN(patient),
  num_time_steps = uniqueN(time),
  num_obs = .N
)])
##            variable        value
## 1:  p_time_in_green    0.3096154
## 2: p_time_in_yellow    0.4017308
## 3:    p_time_in_red    0.2886538
## 4:     num_contacts    0.0000000
## 5:     num_patients  100.0000000
## 6:   num_time_steps   52.0000000
## 7:          num_obs 5200.0000000