library(tidyverse)
library(testthat)
library(tictoc)
source("~/Documents/JMP/code/rust3_probtransitionmatrix.R")

1 Transform data into trajectories

# 1 Transform data into trajectories ---------------------------------------------

# Split the panel by bus; each trajectory keeps its mileage states `s` and
# integer replacement decisions `a`.
rust <- rust %>%
  group_split(bus) %>%
  map(function(bus_rows) {
    list(s = bus_rows$mileage, a = as.integer(bus_rows$replace))
  })

head(rust)
## [[1]]
## [[1]]$s
##  [1]  1  2  2  3  4  4  5  6  7  7  8  9  9 10 11 11 12 12 13 14 14 15 16 16 17
## [26] 17 18 19 19 20 20 21 21 22 23 24 25 25 26 27 28 28 29 29 30 30 30 31 31
## 
## [[1]]$a
##  [1]  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0
## [26]  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0 NA
## 
## 
## [[2]]
## [[2]]$s
##  [1]  1  2  2  3  4  5  5  6  7  7  8  9  9 10 11 11 12 12 13 13 14 14 15 16 16
## [26] 17 18 19 19 20 21 21 21 22 23 24 24 25 26 26 27 28 29 29 31 31 32 32 33
## 
## [[2]]$a
##  [1]  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0
## [26]  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0 NA
## 
## 
## [[3]]
## [[3]]$s
##  [1]  1  1  2  3  3  4  5  6  7  7  7  8  9 10 10 11 11 12 13 13 14 14 15 15 16
## [26] 17 17 18 19 19 20 21 22 22 23 23 24 24 25 26 26 26 27 28 28 28 29 29 30
## 
## [[3]]$a
##  [1]  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0
## [26]  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0 NA
## 
## 
## [[4]]
## [[4]]$s
##  [1]  1  1  2  2  3  4  4  5  5  6  6  6  7  8  8  9 10 10 11 11 12 12 13 13 14
## [26] 14 15 15 16 17 17 18 18 20 20 21 22 22 23 23 24 25 26 27 27 27 28 28 29
## 
## [[4]]$a
##  [1]  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0
## [26]  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0 NA
## 
## 
## [[5]]
## [[5]]$s
##   [1]  1  2  3  4  5  5  6  7  8  9 10 11 12 13 14 14 15 15 16 16 17 17 18 19 19
##  [26] 20 20 21 21 22 23 23 24 24 25 25 26 26 27 28 28 29 30 31  1  1  2  3  3  3
##  [51]  4  5  5  6  7  7  8  8  8  9  9 10 10 11 11 12 12 13 13 14 14 15 15 16 16
##  [76] 17 17 18 18 18 18 19 19 20 20 20 21 21 21 22 22 23 23 24 24 24 24 25 25 26
## [101] 26 27 27 27 28 28 28 29 29 29 30 30 30 31 31 32 32
## 
## [[5]]$a
##   [1]  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0
##  [26]  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  1  0  0  0  0  0  0
##  [51]  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0
##  [76]  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0
## [101]  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0 NA
## 
## 
## [[6]]
## [[6]]$s
##   [1]  1  1  2  3  4  5  6  7  8  9 10 11 12 13 14 15 16 16 17 18 19 20 21 21 22
##  [26] 23 24 25 25 26 27 27 28 29 30 30 31 32 33 33 34 35 35 36 37 37 37 38 38 38
##  [51] 38 39 39 40 41 41 42 42 42 42 43 43 43 44 44 45 45 46 47 47 47 48 48 49 49
##  [76] 50 50 51 51 52 52 53 53 53 54 54 54 55 55 55 56 56 57 57 57 58 58 59 59 59
## [101] 60 60 60 61 61 61 62 62 63 63 64 64 64 65 65 65 66
## 
## [[6]]$a
##   [1]  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0
##  [26]  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0
##  [51]  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0
##  [76]  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0
## [101]  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0 NA
beta <- 0.99  # discount factor (read as a global by weight_state_i())

2 Define a step function step()

# 2 Define a step function `step()` ----------------------------------------------

# I need to be able to simulate trajectories given the approximated probability
# transition matrix and any policy. Create a function `step` that simulates a 
# bus stepping up a month given an action a (keep or replace the engine).

step <- function(x, a, P0 = transitiond0, P1 = transitiond1) {
  # Simulate one month of bus operation.
  #
  # x:      current mileage state (a row index of the transition matrices).
  # a:      action; 0 = keep the engine (uses P0), anything else = replace
  #         (uses P1).
  # P0, P1: square transition matrices whose row x is P(x' | x, a). They
  #         default to the globals `transitiond0` / `transitiond1`.
  # Returns x' sampled from P(x' | x, a).
  P <- if (a == 0) P0 else P1
  # Row x of P is exactly (one-hot vector for x) %*% P, without building the
  # one-hot vector; the number of states follows ncol(P) instead of 90.
  sample(seq_len(ncol(P)), size = 1, prob = P[x, ])
}

# Keeping the engine when x = 5: x' will be 5, 6, or 7 with probabilities 
# .36, .63, and .01.
map_dbl(seq_len(300), ~ step(5, 0))
##   [1] 6 6 6 6 5 6 6 5 6 6 5 6 5 5 5 6 5 6 6 5 5 5 7 5 6 6 6 6 5 7 5 5 6 5 5 5 5
##  [38] 5 6 7 5 5 6 5 6 6 6 6 6 6 6 6 6 5 5 5 6 6 6 6 5 6 6 5 5 5 5 5 6 6 6 5 6 6
##  [75] 6 6 5 6 6 6 5 6 5 6 6 5 5 5 5 6 6 6 5 6 6 6 5 6 6 6 5 6 6 5 6 6 5 6 5 5 6
## [112] 5 5 5 6 6 6 5 6 6 6 6 5 6 5 6 6 5 6 5 5 6 6 5 6 6 6 6 6 6 6 6 6 6 5 6 6 6
## [149] 6 5 6 6 6 5 5 6 6 6 5 6 6 6 5 6 6 6 6 6 6 5 6 6 6 5 6 5 6 5 6 6 6 6 6 6 5
## [186] 6 6 6 6 5 5 6 5 5 6 6 6 6 6 6 5 6 6 6 5 7 6 5 6 5 6 5 6 6 5 6 6 5 6 6 6 5
## [223] 6 5 6 6 5 6 5 6 5 6 6 6 5 6 6 6 5 5 6 5 5 5 6 6 6 5 5 6 6 5 5 5 6 5 5 6 5
## [260] 5 5 6 5 5 6 6 5 5 5 6 6 6 6 6 6 5 6 6 6 5 5 6 6 5 6 6 6 5 6 6 6 5 5 6 6 5
## [297] 5 6 6 6
# Replace the engine when x = 5: mileage resets, so x' will be 1, 2, or 3 with
# probabilities .36, .63, and .01.
map_dbl(seq_len(300), ~ step(5, 1))
##   [1] 2 1 2 1 2 2 2 1 2 2 2 1 1 1 2 1 1 2 2 2 2 2 1 1 1 2 1 2 2 1 2 2 2 2 2 2 1
##  [38] 2 2 2 2 2 1 2 2 2 1 1 3 2 2 2 2 2 2 2 2 2 2 1 2 1 1 1 1 2 2 1 2 2 2 1 1 2
##  [75] 2 2 2 2 2 1 2 2 2 1 1 2 1 1 1 2 2 2 3 1 1 2 2 2 2 2 2 2 2 2 2 2 2 1 2 1 2
## [112] 2 2 2 2 2 1 1 1 1 1 1 2 2 2 2 2 2 2 2 2 2 1 2 2 2 2 2 2 1 2 2 2 1 1 2 2 1
## [149] 2 2 2 2 2 2 1 1 2 2 2 1 2 2 1 2 1 2 1 2 2 1 2 1 2 1 2 1 2 1 2 1 2 2 1 2 1
## [186] 1 2 2 1 2 2 2 2 2 2 2 1 2 2 2 2 2 1 2 2 1 2 2 2 2 1 1 2 2 2 2 1 2 1 3 1 2
## [223] 2 1 2 2 2 2 1 3 1 1 2 1 2 1 2 2 2 2 1 1 1 1 1 1 2 2 1 1 1 1 2 2 2 2 2 1 1
## [260] 1 2 2 2 2 2 1 1 2 1 2 2 2 2 2 2 2 2 1 1 2 1 3 2 1 2 2 2 2 2 2 1 1 2 1 2 2
## [297] 2 1 2 2

3 Bellman helper functions

# 3 Bellman helper functions -------------------------------------------------------

#   3.1: alphas --> R(s) `reward()` ----------------------------------------------

reward <- function(alpha, n_states = 90, sd = 1) {
  # Map basis-function weights alpha to a state reward R(s).
  #
  # R(s) = sum_k alpha[k] * dnorm(s, mean = k, sd = sd) for s = 1..n_states,
  # i.e. one Gaussian basis function centred on every state of the grid.
  #
  # alpha:    weights, one per basis function (recycled by `*` if shorter or
  #           longer than n_states, as before).
  # n_states: size of the mileage grid (default 90, as before).
  # sd:       bandwidth of each Gaussian basis function (default 1).
  # Returns a numeric vector of length n_states.
  centers <- seq_len(n_states)
  vapply(
    centers,
    function(s) sum(alpha * dnorm(s, mean = centers, sd = sd)),
    numeric(1)
  )
}

#      Test: if alpha = c(1, 0, 0, ..., 0), R(s) should equal ---------------------
#  dnorm(1, mean = 1, sd = 1) at s = 1 and shrink monotonically as s grows.

reward(c(1, numeric(89)))
##  [1]  3.989423e-01  2.419707e-01  5.399097e-02  4.431848e-03  1.338302e-04
##  [6]  1.486720e-06  6.075883e-09  9.134720e-12  5.052271e-15  1.027977e-18
## [11]  7.694599e-23  2.118819e-27  2.146384e-32  7.998828e-38  1.096607e-43
## [16]  5.530710e-50  1.026163e-56  7.004182e-64  1.758750e-71  1.624636e-79
## [21]  5.520948e-88  6.902029e-97 3.174282e-106 5.370560e-116 3.342714e-126
## [26] 7.653930e-137 6.447260e-148 1.997889e-159 2.277577e-171 9.551695e-184
## [31] 1.473646e-196 8.363952e-210 1.746366e-223 1.341420e-237 3.790526e-252
## [36] 3.940396e-267 1.506905e-282 2.120007e-298 1.097221e-314  0.000000e+00
## [41]  0.000000e+00  0.000000e+00  0.000000e+00  0.000000e+00  0.000000e+00
## [46]  0.000000e+00  0.000000e+00  0.000000e+00  0.000000e+00  0.000000e+00
## [51]  0.000000e+00  0.000000e+00  0.000000e+00  0.000000e+00  0.000000e+00
## [56]  0.000000e+00  0.000000e+00  0.000000e+00  0.000000e+00  0.000000e+00
## [61]  0.000000e+00  0.000000e+00  0.000000e+00  0.000000e+00  0.000000e+00
## [66]  0.000000e+00  0.000000e+00  0.000000e+00  0.000000e+00  0.000000e+00
## [71]  0.000000e+00  0.000000e+00  0.000000e+00  0.000000e+00  0.000000e+00
## [76]  0.000000e+00  0.000000e+00  0.000000e+00  0.000000e+00  0.000000e+00
## [81]  0.000000e+00  0.000000e+00  0.000000e+00  0.000000e+00  0.000000e+00
## [86]  0.000000e+00  0.000000e+00  0.000000e+00  0.000000e+00  0.000000e+00
# Test 2: alpha = c(1, 0, ..., 0, 1) should give an R(s) that is symmetric
# about the middle of the grid.

reward(c(1, numeric(88), 1))
##  [1]  3.989423e-01  2.419707e-01  5.399097e-02  4.431848e-03  1.338302e-04
##  [6]  1.486720e-06  6.075883e-09  9.134720e-12  5.052271e-15  1.027977e-18
## [11]  7.694599e-23  2.118819e-27  2.146384e-32  7.998828e-38  1.096607e-43
## [16]  5.530710e-50  1.026163e-56  7.004182e-64  1.758750e-71  1.624636e-79
## [21]  5.520948e-88  6.902029e-97 3.174282e-106 5.370560e-116 3.342714e-126
## [26] 7.653930e-137 6.447260e-148 1.997889e-159 2.277577e-171 9.551695e-184
## [31] 1.473646e-196 8.363952e-210 1.746366e-223 1.341420e-237 3.790526e-252
## [36] 3.940396e-267 1.506905e-282 2.120007e-298 1.097221e-314  0.000000e+00
## [41]  0.000000e+00  0.000000e+00  0.000000e+00  0.000000e+00  0.000000e+00
## [46]  0.000000e+00  0.000000e+00  0.000000e+00  0.000000e+00  0.000000e+00
## [51]  0.000000e+00 1.097221e-314 2.120007e-298 1.506905e-282 3.940396e-267
## [56] 3.790526e-252 1.341420e-237 1.746366e-223 8.363952e-210 1.473646e-196
## [61] 9.551695e-184 2.277577e-171 1.997889e-159 6.447260e-148 7.653930e-137
## [66] 3.342714e-126 5.370560e-116 3.174282e-106  6.902029e-97  5.520948e-88
## [71]  1.624636e-79  1.758750e-71  7.004182e-64  1.026163e-56  5.530710e-50
## [76]  1.096607e-43  7.998828e-38  2.146384e-32  2.118819e-27  7.694599e-23
## [81]  1.027977e-18  5.052271e-15  9.134720e-12  6.075883e-09  1.486720e-06
## [86]  1.338302e-04  4.431848e-03  5.399097e-02  2.419707e-01  3.989423e-01
#   3.2: Value iteration helper function `V_update()` --------------------------
# Take a guess for V(s, a) and apply one Bellman-equation update to it.

V_update <- function(V, R, beta, P0, P1) {
  # One Bellman (value-iteration) update of the action-value matrix.
  #
  # V:      n x 2 matrix; column 1 = V(s, keep), column 2 = V(s, replace).
  # R:      length-n state reward R(s).
  # beta:   discount factor.
  # P0, P1: n x n transition matrices P(s' | s, a) for a = 0 (keep) and
  #         a = 1 (replace).
  # Returns the updated n x 2 matrix
  #   V(s, a) = R(s) + beta * sum_s' P(s' | s, a) * max_a' V(s', a').
  # n is taken from nrow(P0) rather than hard-coded to 90.
  n_states <- nrow(P0)

  # V*(s') = max_a V(s', a) is shared by both actions: compute it once.
  V_star <- pmax(V[, 1], V[, 2])

  # as.vector() drops any dimnames from the matrix products, so the result
  # is a plain, unnamed n x 2 matrix.
  keep_value <- as.vector(R + beta * P0 %*% V_star)
  replace_value <- as.vector(R + beta * P1 %*% V_star)

  matrix(c(keep_value, replace_value), nrow = n_states, ncol = 2)
}

#      Test: ---------------------------------------------

# alpha needs one entry per state (90); a 180-entry vector only worked
# through silent recycling inside reward().
V_update(
  V = matrix(data = c(1, rep(0, 180 - 1)), nrow = 90, ncol = 2),
  R = reward(c(1, 1, 1, rep(0, 90 - 3))),
  beta = .99,
  P0 = transitiond0,
  P1 = transitiond1
)
##                [,1]      [,2]
##  [1,]  1.047522e+00 1.0475217
##  [2,]  8.828837e-01 1.2355014
##  [3,]  6.949040e-01 1.0475217
##  [4,]  3.003935e-01 0.6530112
##  [5,]  5.855665e-02 0.4111743
##  [6,]  4.567165e-03 0.3571849
##  [7,]  1.353230e-04 0.3527530
##  [8,]  1.492805e-06 0.3526192
##  [9,]  6.085023e-09 0.3526177
## [10,]  9.139774e-12 0.3526177
## [11,]  5.053299e-15 0.3526177
## [12,]  1.028054e-18 0.3526177
## [13,]  7.694811e-23 0.3526177
## [14,]  2.118841e-27 0.3526177
## [15,]  2.146392e-32 0.3526177
## [16,]  7.998839e-38 0.3526177
## [17,]  1.096607e-43 0.3526177
## [18,]  5.530711e-50 0.3526177
## [19,]  1.026163e-56 0.3526177
## [20,]  7.004182e-64 0.3526177
## [21,]  1.758750e-71 0.3526177
## [22,]  1.624636e-79 0.3526177
## [23,]  5.520948e-88 0.3526177
## [24,]  6.902029e-97 0.3526177
## [25,] 3.174282e-106 0.3526177
## [26,] 5.370560e-116 0.3526177
## [27,] 3.342714e-126 0.3526177
## [28,] 7.653930e-137 0.3526177
## [29,] 6.447260e-148 0.3526177
## [30,] 1.997889e-159 0.3526177
## [31,] 2.277577e-171 0.3526177
## [32,] 9.551695e-184 0.3526177
## [33,] 1.473646e-196 0.3526177
## [34,] 8.363952e-210 0.3526177
## [35,] 1.746366e-223 0.3526177
## [36,] 1.341420e-237 0.3526177
## [37,] 3.790526e-252 0.3526177
## [38,] 3.940396e-267 0.3526177
## [39,] 1.506905e-282 0.3526177
## [40,] 2.120007e-298 0.3526177
## [41,] 1.097221e-314 0.3526177
## [42,]  0.000000e+00 0.3526177
## [43,]  0.000000e+00 0.3526177
## [44,]  0.000000e+00 0.3526177
## [45,]  0.000000e+00 0.3526177
## [46,]  0.000000e+00 0.3526177
## [47,]  0.000000e+00 0.3526177
## [48,]  0.000000e+00 0.3526177
## [49,]  0.000000e+00 0.3526177
## [50,]  0.000000e+00 0.3526177
## [51,]  0.000000e+00 0.3526177
## [52,]  0.000000e+00 0.3526177
## [53,]  0.000000e+00 0.3526177
## [54,]  0.000000e+00 0.3526177
## [55,]  0.000000e+00 0.3526177
## [56,]  0.000000e+00 0.3526177
## [57,]  0.000000e+00 0.3526177
## [58,]  0.000000e+00 0.3526177
## [59,]  0.000000e+00 0.3526177
## [60,]  0.000000e+00 0.3526177
## [61,]  0.000000e+00 0.3526177
## [62,]  0.000000e+00 0.3526177
## [63,]  0.000000e+00 0.3526177
## [64,]  0.000000e+00 0.3526177
## [65,]  0.000000e+00 0.3526177
## [66,]  0.000000e+00 0.3526177
## [67,]  0.000000e+00 0.3526177
## [68,]  0.000000e+00 0.3526177
## [69,]  0.000000e+00 0.3526177
## [70,]  0.000000e+00 0.3526177
## [71,]  0.000000e+00 0.3526177
## [72,]  0.000000e+00 0.3526177
## [73,]  0.000000e+00 0.3526177
## [74,]  0.000000e+00 0.3526177
## [75,]  0.000000e+00 0.3526177
## [76,]  0.000000e+00 0.3526177
## [77,]  0.000000e+00 0.3526177
## [78,]  0.000000e+00 0.3526177
## [79,]  0.000000e+00 0.3526177
## [80,]  0.000000e+00 0.3526177
## [81,]  0.000000e+00 0.3526177
## [82,]  0.000000e+00 0.3526177
## [83,]  0.000000e+00 0.3526177
## [84,]  0.000000e+00 0.3526177
## [85,]  0.000000e+00 0.3526177
## [86,]  0.000000e+00 0.3526177
## [87,]  0.000000e+00 0.3526177
## [88,]  0.000000e+00 0.3526177
## [89,]  0.000000e+00 0.3526177
## [90,]  0.000000e+00 0.3526177
#   3.3: R(s) --> V(s, a) `value_iteration()` ----------------------------------
# Take a reward function and calculate the corresponding value function using 
# value iteration.

value_iteration <- function(alpha, beta, P0, P1, sound = "off",
                            tol = .01, max_iter = 10000) {
  # Compute V(s, a) for the reward implied by alpha by repeating the Bellman
  # update (V_update()) until no entry changes by `tol` or more.
  #
  # alpha:    basis-function weights, turned into R(s) by reward().
  # beta:     discount factor.
  # P0, P1:   transition matrices for keep (a = 0) and replace (a = 1).
  # sound:    "on" prints the iteration counter.
  # tol:      sup-norm convergence threshold (default .01, as before).
  # max_iter: iteration cap (default 10000, as before).
  # Returns the n x 2 matrix V (column 1 = keep, column 2 = replace). If the
  # cap is reached without convergence, it warns and returns the last iterate
  # instead of falling off the loop and returning NULL.
  R <- reward(alpha)
  n_states <- nrow(P0)
  # Initial guess: 1 in V(1, keep), 0 everywhere else.
  V <- matrix(c(1, rep(0, 2 * n_states - 1)), nrow = n_states, ncol = 2)

  for (i in seq_len(max_iter)) {
    V_next <- V_update(V, R, beta, P0, P1)

    if (max(abs(V_next - V)) < tol) {
      return(V_next)
    }

    if (sound == "on") {
      print(i)
    }
    V <- V_next
  }

  warning(
    "value_iteration() did not converge within ", max_iter, " iterations",
    call. = FALSE
  )
  V
}

#      Test -------------------------------------------------------------------------

V <- value_iteration(
  alpha = c(1, numeric(89)),
  beta = 0.99,
  P0 = transitiond0,
  P1 = transitiond1,
  sound = "off"
)

# In which states is the value of keeping >= the value of replacing?
# Only in state 1: alpha = (1, 0, 0, ..., 0) rewards only the lowest mileage,
# so in every other state it is better to replace the engine.

V[, 2] <= V[, 1]
##  [1]  TRUE FALSE FALSE FALSE FALSE FALSE FALSE FALSE FALSE FALSE FALSE FALSE
## [13] FALSE FALSE FALSE FALSE FALSE FALSE FALSE FALSE FALSE FALSE FALSE FALSE
## [25] FALSE FALSE FALSE FALSE FALSE FALSE FALSE FALSE FALSE FALSE FALSE FALSE
## [37] FALSE FALSE FALSE FALSE FALSE FALSE FALSE FALSE FALSE FALSE FALSE FALSE
## [49] FALSE FALSE FALSE FALSE FALSE FALSE FALSE FALSE FALSE FALSE FALSE FALSE
## [61] FALSE FALSE FALSE FALSE FALSE FALSE FALSE FALSE FALSE FALSE FALSE FALSE
## [73] FALSE FALSE FALSE FALSE FALSE FALSE FALSE FALSE FALSE FALSE FALSE FALSE
## [85] FALSE FALSE FALSE FALSE FALSE FALSE

4 Generate R{rand} and use it to get V{rand}

# 4 Generate R{rand} and use it to get V{rand} ---------------------------------

set.seed(1234)
# Random 0/1 basis weights give a random reward R_rand and its value function.
alpha_rand <- sample(0:1, size = 90, replace = TRUE)
R_rand <- reward(alpha_rand)
V_rand <- value_iteration(
  alpha_rand,
  beta = .99,
  P0 = transitiond0,
  P1 = transitiond1
)

5 Use V{rand} to get Boltzmann action probabilities

# 5 Use V{rand} to get Boltzmann action probabilities ---------
# Boltzmann: take action a over a' in state s with probability:
#   exp(beta * V(s, a)) / [exp(beta * V(s, a)) + exp(beta * V(s, a'))]

boltzmann_prob_keep <- function(V, beta) {
  # Probability of keeping the engine (action 0, column 1 of V) under the
  # Boltzmann rule
  #   exp(beta * V[, 1]) / (exp(beta * V[, 1]) + exp(beta * V[, 2])).
  # That ratio equals the logistic function of beta * (V[, 1] - V[, 2]).
  # plogis() evaluates it without forming exp(beta * V). Forming that term
  # overflows to Inf (giving NaN) once beta * V exceeds about 709, and
  # underflows to 0/0 for large negative values.
  # Returns one probability per row (state) of V.
  plogis(beta * (V[, 1] - V[, 2]))
}

#  Test ---------

# Is P(keep) > .5 in exactly the states where V_keep > V_replace?
all((V_rand[, 1] > V_rand[, 2]) == (boltzmann_prob_keep(V_rand, .99) > .5))
## [1] TRUE

6 Generate one trajectory using V_rand and boltzmann probabilities

# 6 Generate one trajectory using V_rand and boltzmann probabilities ---------

generate_traj <- function(V, beta, n_periods = 104) {
  # Simulate one bus trajectory under the Boltzmann policy implied by V.
  #
  # V:         n x 2 action-value matrix (column 1 = keep, column 2 = replace).
  # beta:      parameter passed to boltzmann_prob_keep().
  # n_periods: number of decisions to simulate (default 104, as before).
  # Returns list(s, a). `s` holds n_periods + 1 states: the initial state
  # plus the state reached after each action. `a` holds n_periods actions
  # (0 = keep, 1 = replace).
  action_probs <- boltzmann_prob_keep(V, beta)
  # Preallocate both vectors at full length so the loop never grows them.
  traj <- list(
    s = rep(NA_real_, n_periods + 1),
    a = rep(NA_integer_, n_periods)
  )

  # Initial state: ceiling(.5 * Exp(1)), which is 1 with probability
  # 1 - exp(-2) (about 0.86).
  traj$s[1] <- ceiling(.5 * rexp(n = 1, rate = 1))

  for (i in seq_len(n_periods)) {
    # Draw the action from the keep-probability of the *current* state s[i].
    # Indexing with s[1] would always reuse the initial state's probability.
    p_keep <- action_probs[traj$s[i]]
    traj$a[i] <- sample(0:1, size = 1, prob = c(p_keep, 1 - p_keep))
    # Advance the mileage state given the chosen action.
    traj$s[i + 1] <- step(traj$s[i], traj$a[i])
  }

  traj
}

#   Test --------------------

generate_traj(V = V_rand, beta = beta)
## $s
##   [1] 1 2 2 3 2 3 4 2 2 2 1 1 2 2 2 2 4 5 2 3 1 2 1 1 2 3 4 2 2 2 3 2 3 4 5 6 7
##  [38] 2 3 1 1 1 2 3 4 2 2 1 1 2 2 3 3 3 3 3 4 5 6 2 3 2 3 3 4 5 2 2 1 2 1 2 3 1
##  [75] 2 3 2 2 1 1 2 3 2 3 2 2 2 2 1 2 3 4 5 5 6 2 1 1 2 1 1 1 2 2 1
## 
## $a
##   [1] 0 1 0 1 0 0 1 1 1 1 1 1 0 1 0 0 0 1 0 1 0 1 1 0 0 0 1 0 0 0 1 0 0 0 0 0 1
##  [38] 0 1 1 0 1 0 0 1 1 1 1 0 0 0 0 0 0 0 0 0 0 1 0 1 0 0 0 0 1 1 1 1 1 0 0 1 1
##  [75] 0 1 0 1 0 0 0 1 0 1 0 1 1 1 1 0 0 0 0 0 1 1 0 1 1 0 1 1 0 1

7 Generate 104 trajectories

# 7 Generate 104 trajectories ------------------

generate_104_trajs <- function(V, beta) {
  # Simulate 104 trajectories with generate_traj(), all under the same value
  # function V and parameter beta. Returns an unnamed list of trajectories.
  lapply(seq_len(104), function(k) generate_traj(V, beta))
}

#   Test -------

traj_rand <- generate_104_trajs(V = V_rand, beta = beta)
head(traj_rand)
## [[1]]
## [[1]]$s
##   [1] 1 1 2 3 3 3 3 2 2 2 1 1 2 2 2 2 2 3 2 2 3 4 4 2 1 2 1 1 1 1 1 2 2 2 3 3 4
##  [38] 5 2 3 1 2 2 2 3 3 3 3 1 2 2 3 3 4 5 6 6 7 8 8 2 2 2 3 2 2 2 1 1 1 2 2 3 4
##  [75] 5 6 1 1 1 2 3 1 2 2 3 4 2 1 1 2 3 4 5 2 2 3 3 4 5 6 7 7 8 8 8
## 
## [[1]]$a
##   [1] 1 0 0 1 0 0 1 1 1 1 1 1 1 1 0 0 0 1 1 0 0 0 1 1 0 1 1 0 0 1 1 1 0 0 0 0 0
##  [38] 1 0 1 1 1 1 0 0 0 0 1 0 1 0 0 0 0 0 0 0 0 0 1 0 1 0 1 1 0 1 1 1 0 0 0 0 0
##  [75] 0 1 0 0 1 0 1 1 1 0 0 1 1 1 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0
## 
## 
## [[2]]
## [[2]]$s
##   [1] 1 2 2 3 4 1 1 2 3 4 1 2 3 4 4 4 5 5 1 1 1 2 1 1 2 2 3 3 4 4 5 2 3 4 2 2 2
##  [38] 2 2 2 2 2 3 4 4 5 2 2 3 1 2 2 3 1 2 2 1 1 2 3 4 4 5 5 6 7 2 2 3 2 3 4 4 2
##  [75] 2 2 3 3 2 2 2 3 4 1 2 3 4 2 1 1 2 1 2 3 2 3 4 4 2 2 3 3 3 2 2
## 
## [[2]]$a
##   [1] 0 0 0 0 1 1 1 0 0 1 1 0 0 0 0 0 0 1 1 0 0 1 1 0 1 0 0 0 0 0 1 0 0 1 1 0 1
##  [38] 1 1 0 0 0 0 0 0 1 1 0 1 1 1 0 1 1 0 1 1 0 0 0 0 0 0 0 0 1 1 0 1 0 0 0 1 0
##  [75] 1 0 0 1 0 1 0 0 1 0 0 0 1 1 1 1 1 1 0 1 0 0 0 1 1 0 0 0 1 1
## 
## 
## [[3]]
## [[3]]$s
##   [1] 2 2 2 2 2 3 3 2 3 1 1 1 2 3 2 2 2 1 1 1 1 2 2 3 4 2 2 1 2 2 3 2 2 3 3 2 2
##  [38] 2 2 2 2 2 2 3 4 5 2 2 1 2 2 3 4 5 5 6 7 7 1 1 2 2 3 2 2 1 1 1 2 1 2 1 2 3
##  [75] 3 3 4 1 2 2 3 4 2 3 1 2 3 1 2 3 1 1 2 2 2 3 2 3 4 5 5 1 2 3 2
## 
## [[3]]$a
##   [1] 1 1 1 0 0 0 1 0 1 0 1 0 0 1 0 0 1 0 0 1 0 0 0 0 1 1 1 1 1 0 1 1 0 0 1 1 1
##  [38] 1 1 0 0 0 0 0 0 1 1 1 0 0 0 0 0 0 0 0 0 1 0 1 1 0 1 0 1 1 1 1 1 0 1 0 0 0
##  [75] 0 0 1 1 1 0 0 1 0 1 0 0 1 1 0 1 1 1 1 1 0 1 0 0 0 0 1 1 0 1
## 
## 
## [[4]]
## [[4]]$s
##   [1] 1 2 2 3 4 5 6 6 6 7 8 2 1 2 2 1 2 3 1 2 1 2 3 4 5 5 1 1 2 2 2 3 2 2 1 1 1
##  [38] 1 2 1 2 1 2 3 3 1 2 3 2 3 1 2 2 3 3 1 1 1 2 2 3 2 3 4 2 2 3 4 4 2 2 1 2 2
##  [75] 1 2 1 1 1 2 2 2 3 2 2 1 1 2 2 2 2 3 1 2 3 3 1 2 3 4 2 1 1 2 2
## 
## [[4]]$a
##   [1] 0 1 0 0 0 0 0 0 0 0 1 1 0 0 1 1 0 1 0 1 1 1 0 0 0 1 0 0 0 1 0 1 0 1 1 1 1
##  [38] 1 1 1 1 0 0 0 1 0 0 1 0 1 0 1 0 0 1 1 1 1 0 0 1 0 0 1 1 0 0 0 1 0 1 1 1 1
##  [75] 1 1 1 0 1 1 1 0 1 1 1 0 1 1 1 1 0 1 1 0 0 1 1 0 0 1 1 1 0 1
## 
## 
## [[5]]
## [[5]]$s
##   [1] 1 2 1 2 1 2 2 1 2 3 2 2 2 1 2 2 2 3 2 2 2 2 2 3 3 4 5 6 1 2 3 2 2 3 4 1 1
##  [38] 1 2 3 3 1 2 1 2 3 4 1 2 2 2 2 2 3 4 2 3 3 4 1 2 3 1 2 2 2 2 2 2 3 3 1 1 2
##  [75] 2 3 2 2 2 1 1 1 1 1 2 1 2 2 3 4 1 1 2 1 1 2 2 2 2 2 2 2 2 2 2
## 
## [[5]]$a
##   [1] 0 1 1 1 0 1 1 1 0 1 1 1 1 0 0 1 0 1 1 0 1 1 0 0 0 0 0 1 1 0 1 1 0 0 1 0 0
##  [38] 1 0 0 1 0 1 0 0 0 1 1 0 0 1 1 0 0 1 0 0 0 1 1 0 1 0 0 1 0 1 1 0 0 1 0 0 1
##  [75] 0 1 1 1 1 0 1 0 1 1 1 0 1 0 0 1 1 0 1 0 1 1 1 1 1 1 0 1 0 0
## 
## 
## [[6]]
## [[6]]$s
##   [1] 1 2 2 2 3 3 4 5 2 3 1 2 2 3 3 4 5 6 7 7 7 1 2 2 2 2 1 1 1 1 2 3 2 2 2 3 3
##  [38] 2 1 1 2 3 2 3 4 2 2 2 2 2 2 1 2 2 3 4 2 2 1 1 1 2 3 2 2 3 2 1 1 2 4 5 2 3
##  [75] 1 2 3 1 2 3 3 4 5 5 6 7 2 2 1 2 2 2 2 1 2 3 1 1 3 2 3 1 1 2 3
## 
## [[6]]$a
##   [1] 0 1 1 0 0 0 0 1 0 1 1 1 0 0 0 0 0 0 0 0 1 1 1 0 1 1 1 0 0 0 0 1 1 1 0 0 1
##  [38] 1 0 0 0 1 0 0 1 1 0 1 1 0 1 1 1 0 0 1 1 1 0 0 1 0 1 0 0 1 1 0 0 0 0 1 0 1
##  [75] 1 0 1 0 0 0 0 0 0 0 0 1 1 1 1 1 0 1 1 1 0 1 1 1 1 0 1 1 1 0

8 Calculate visiting weights for each state in a trajectory

# 8 Calculate visiting weights for each state in a trajectory ------------------

# In the IRL algorithm, I need a weight for each state given a trajectory:
# the number of times that state is visited, with a visit at time t
# discounted by beta^t. Following Ng and Russell (2000), I use 90 evenly
# spaced Gaussian basis functions (sd = 1), so a visit to state 8 also
# contributes somewhat to the weights of states 7 and 9.

weight_state_i <- function(trajectory, gaus_center, discount = NULL) {
  # Discounted, Gaussian-smoothed visit count for one basis function.
  #
  # trajectory:  numeric vector of visited states s_0, s_1, ..., s_T.
  # gaus_center: centre of the Gaussian basis function (sd = 1).
  # discount:    discount factor. NULL (the default) falls back to the global
  #              `beta`, which this function previously read implicitly.
  # Returns sum_t discount^t * dnorm(s_t, mean = gaus_center, sd = 1).
  if (is.null(discount)) {
    discount <- beta
  }
  density <- dnorm(trajectory, mean = gaus_center, sd = 1)
  # seq_along() gives exponents 0..T and stays empty for an empty trajectory.
  discount_seq <- discount^(seq_along(trajectory) - 1)
  sum(density * discount_seq)
}

# Visits at s = 1, 1, 2 should weigh basis function 1 more than 2, and 2 more
# than 3. expect_true() checks only its first argument (the second is `info`),
# so each ordering needs its own expectation.
expect_true(weight_state_i(c(1, 1, 2), 1) > weight_state_i(c(1, 1, 2), 2))
expect_true(weight_state_i(c(1, 1, 2), 2) > weight_state_i(c(1, 1, 2), 3))

weight_state_i(trajectory = traj_rand[[1]]$s, gaus_center = 1)
## [1] 12.22604
weight_state_i(trajectory = rust[[1]]$s, gaus_center = 1)
## [1] 0.9366354
weights <- function(trajectories_list, n_states = 90) {
  # Average discounted visiting weight of every basis function.
  #
  # For each state i = 1..n_states, this applies weight_state_i(s, i) to each
  # trajectory's state sequence `s` and averages over trajectories. States
  # visited often, and early, get larger weights, because later visits are
  # discounted by beta^t.
  #
  # trajectories_list: list of trajectories, each a list with an element `s`.
  # n_states:          number of basis functions (default 90, as before).
  # Returns a numeric vector of length n_states.

  # Extract each trajectory's `s` directly. This replaces the superseded
  # purrr::flatten() + keep(names == "s") pipeline.
  tr_states <- map(trajectories_list, "s")

  map_dbl(seq_len(n_states), function(i) {
    mean(map_dbl(tr_states, weight_state_i, i))
  })
}

#   Test ------

weights(rust) %>%
  tibble(w = .) %>%
  ggplot(aes(x = 1:90, y = w)) +
  geom_point() +
  geom_line()

9 Use an LP solver to max difference in visiting weights between data and Traj{rand}

# 9 Use an LP solver to max difference in visiting weights between data and Traj{rand} ---------------

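# Standard LP form: maximize c'x
#                   s/t Ax <= b
#                       x >= 0
# |alpha_i| <= 1: use gamma_i where gamma_i = alpha_i + 1
# Then constraint: gamma_i \in [0, 2].
#
# x = (gamma_1, gamma_2, ..., gamma_90)
# c: weights(rust) - weights(traj_rand)
#   and when we iterate, c_new = c_old + weights(rust) - weights(traj_new),
#   where traj_new are trajectories simulated under the latest alpha
# constraints: gamma_i <= 2 (gamma_i >= 0 is implied by lpSolve::lp)
# A = diag(90)
# b: rep(2, 90)
#
# Note: with only these box constraints the LP separates by coordinate, so the
# optimum sets gamma_i = 2 (alpha_i = 1) wherever c_i > 0 and gamma_i = 0
# (alpha_i = -1) wherever c_i < 0.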

alpha_calculator <- function(objective) {
  # Solve the box-constrained LP and map gamma back to alpha.
  #
  # objective: the LP cost vector c (a vector or n x 1 matrix), one entry per
  #            basis function; n is taken from its length instead of 90.
  # Maximises c' gamma subject to gamma_i <= 2 (lpSolve::lp() itself imposes
  # gamma >= 0) and returns alpha = gamma - 1, so every alpha_i is in [-1, 1].
  # Stops if lpSolve reports a status other than 0 (success), rather than
  # returning whatever solution vector it leaves behind.
  n_states <- length(objective)
  fit <- lpSolve::lp(
    direction = "max",
    objective.in = as.vector(objective),
    const.mat = diag(n_states),
    const.dir = rep("<=", n_states),
    const.rhs = rep(2, n_states)
  )
  if (fit$status != 0) {
    stop("lpSolve::lp() failed with status ", fit$status, call. = FALSE)
  }
  fit$solution - 1
}

#   Test -------------------------------------------------

(weights(rust) - weights(traj_rand)) %>%
  matrix(nrow = 90) %>%
  alpha_calculator()
##  [1] -1 -1 -1 -1 -1  1  1  1  1  1  1  1  1  1  1  1  1  1  1  1  1  1  1  1  1
## [26]  1  1  1  1  1  1  1  1  1  1  1  1  1  1  1  1  1  1  1  1  1  1  1  1  1
## [51]  1  1  1  1  1  1  1  1  1  1  1  1  1  1  1  1  1  1  1  1  1  1  1  1  1
## [76]  1  1  1  1  1  1  1  1  1 -1 -1 -1 -1 -1 -1

10 Iterating the process: IRL function

# 10 Iterating the process: IRL function ----------------------------------------

# Use the new LP alpha estimate to find R(s), then use value iteration to find
# V(s, a), then use Boltzmann to generate new trajectories, and add the difference
# in weights between the new trajectories and the Rust data to the LP solver
# objective function. Look for convergence in alpha's for the IRL Reward function
# estimate.

IRL <- function(dataset, beta, P0, P1, min_its, max_its = 100) {
  # Iterative IRL estimate of the reward weights alpha.
  #
  # dataset: list of observed trajectories (each a list with element `s`).
  # beta:    discount factor for value_iteration(); also passed to the
  #          Boltzmann policy when simulating trajectories.
  # P0, P1:  transition matrices for keep (a = 0) and replace (a = 1); they
  #          are passed through to every value_iteration() call.
  # min_its: iterations that must pass before convergence may be declared.
  # max_its: iteration cap (default 100, as before).
  #
  # Each iteration goes alpha -> V (value iteration) -> 104 simulated
  # trajectories. It then adds weights(dataset) - weights(simulated) to the LP
  # objective, which yields a new alpha. The function returns alpha once it
  # moves by less than .1 in every coordinate after more than min_its
  # iterations. If max_its is reached first, it warns and returns the latest
  # alpha instead of NULL.
  #
  # NOTE: set.seed() resets the caller's RNG stream on every call; kept for
  # reproducibility of existing results.
  set.seed(1234)

  # Initial guess: value function of a random 0/1 reward.
  V_rand <- value_iteration(
    alpha = sample(0:1, size = 90, replace = TRUE),
    beta = beta, P0 = P0, P1 = P1
  )
  traj_rand <- generate_104_trajs(V_rand, beta)

  # The data's visiting weights never change between iterations.
  w_data <- weights(dataset)

  # LP objective: difference in visiting weights, data minus simulation.
  objective <- w_data - weights(traj_rand)
  alpha <- alpha_calculator(matrix(objective, nrow = 90))

  for (i in seq_len(max_its)) {
    V <- value_iteration(alpha = alpha, beta = beta, P0 = P0, P1 = P1)
    traj <- generate_104_trajs(V, beta)

    objective <- objective + w_data - weights(traj)
    alpha_next <- alpha_calculator(matrix(objective, nrow = 90))

    if (i > min_its && max(abs(alpha_next - alpha)) < .1) {
      return(alpha_next)
    }

    alpha <- alpha_next
    print(i)
  }

  warning("IRL() did not converge within ", max_its, " iterations",
          call. = FALSE)
  alpha
}

11 Running IRL on Rust

# 11 Running IRL on Rust ------------------------------------------------------

# Keep the estimated alpha as well as R(s): value_iteration() expects alpha
# and applies reward() itself.
IRL_alpha_estimate <- IRL(rust, .99, transitiond0, transitiond1, 10)
IRL_R_estimate <- reward(IRL_alpha_estimate)
## [1] 1
## [1] 2
## [1] 3
## [1] 4
## [1] 5
## [1] 6
## [1] 7
## [1] 8
## [1] 9
## [1] 10
tibble(
  s = 1:90,
  x = IRL_R_estimate
) %>%
  ggplot(aes(x = s, y = x)) +
  geom_line()

# Pass alpha, not IRL_R_estimate: value_iteration() calls reward() on its
# first argument, so passing R(s) would smooth the reward twice.
V_IRL_estimate <- value_iteration(
  IRL_alpha_estimate, .99, transitiond0, transitiond1
)

tibble(
  x = 1:90,
  y = V_IRL_estimate[, 1]
) %>%
  ggplot(aes(x = x, y = y)) +
  geom_line() +
  geom_point()

pk <- boltzmann_prob_keep(V_IRL_estimate, .99)

tibble(
  x = 1:90,
  y = pk
) %>%
  ggplot(aes(x = x, y = y)) +
  geom_line() +
  geom_point()