library(tidyverse)
library(testthat)
library(tictoc)
source("~/Documents/JMP/code/rust3_probtransitionmatrix.R")
# 2 Define a step function step()
# 2 Define a step function `step()` ----------------------------------------------
# I need to be able to simulate trajectories given the approximated probability
# transition matrix and any policy. Create a function `step` that simulates a
# bus stepping up a month given an action a (keep or replace the engine).
step <- function(x, a) {
  # Simulate one month of mileage evolution for a single bus.
  #
  # Args:
  #   x: current mileage state, used as a row index into the transition matrix.
  #   a: action; 0 selects `transitiond0` (keep the engine), any other value
  #      selects `transitiond1` (replace the engine).
  # Returns:
  #   The next state x', a single index drawn from P(x' | x, a).
  #
  # NOTE(review): `transitiond0` / `transitiond1` are not defined in this file;
  # presumably they come from the script sourced at the top -- confirm.
  P <- if (a == 0) transitiond0 else transitiond1
  # Row x of P is the distribution over next states. Indexing the row gives
  # the same probabilities as multiplying a one-hot row vector by P, without
  # the full matrix product or a hard-coded number of states.
  p <- P[x, ]
  sample(seq_along(p), size = 1, prob = p)
}
# Sanity check for keeping the engine (a = 0) at x = 5: x' should be 5, 6, or 7
# with probabilities .36, .63, and .01.
map_dbl(seq_len(300), function(draw) step(5, 0))
## [1] 6 6 6 6 5 6 6 5 6 6 5 6 5 5 5 6 5 6 6 5 5 5 7 5 6 6 6 6 5 7 5 5 6 5 5 5 5
## [38] 5 6 7 5 5 6 5 6 6 6 6 6 6 6 6 6 5 5 5 6 6 6 6 5 6 6 5 5 5 5 5 6 6 6 5 6 6
## [75] 6 6 5 6 6 6 5 6 5 6 6 5 5 5 5 6 6 6 5 6 6 6 5 6 6 6 5 6 6 5 6 6 5 6 5 5 6
## [112] 5 5 5 6 6 6 5 6 6 6 6 5 6 5 6 6 5 6 5 5 6 6 5 6 6 6 6 6 6 6 6 6 6 5 6 6 6
## [149] 6 5 6 6 6 5 5 6 6 6 5 6 6 6 5 6 6 6 6 6 6 5 6 6 6 5 6 5 6 5 6 6 6 6 6 6 5
## [186] 6 6 6 6 5 5 6 5 5 6 6 6 6 6 6 5 6 6 6 5 7 6 5 6 5 6 5 6 6 5 6 6 5 6 6 6 5
## [223] 6 5 6 6 5 6 5 6 5 6 6 6 5 6 6 6 5 5 6 5 5 5 6 6 6 5 5 6 6 5 5 5 6 5 5 6 5
## [260] 5 5 6 5 5 6 6 5 5 5 6 6 6 6 6 6 5 6 6 6 5 5 6 6 5 6 6 6 5 6 6 6 5 5 6 6 5
## [297] 5 6 6 6
# Sanity check for replacing the engine (a = 1) at x = 5: x' should be 1, 2,
# or 3 with probabilities .36, .63, and .01.
map_dbl(seq_len(300), function(draw) step(5, 1))
## [1] 2 1 2 1 2 2 2 1 2 2 2 1 1 1 2 1 1 2 2 2 2 2 1 1 1 2 1 2 2 1 2 2 2 2 2 2 1
## [38] 2 2 2 2 2 1 2 2 2 1 1 3 2 2 2 2 2 2 2 2 2 2 1 2 1 1 1 1 2 2 1 2 2 2 1 1 2
## [75] 2 2 2 2 2 1 2 2 2 1 1 2 1 1 1 2 2 2 3 1 1 2 2 2 2 2 2 2 2 2 2 2 2 1 2 1 2
## [112] 2 2 2 2 2 1 1 1 1 1 1 2 2 2 2 2 2 2 2 2 2 1 2 2 2 2 2 2 1 2 2 2 1 1 2 2 1
## [149] 2 2 2 2 2 2 1 1 2 2 2 1 2 2 1 2 1 2 1 2 2 1 2 1 2 1 2 1 2 1 2 1 2 2 1 2 1
## [186] 1 2 2 1 2 2 2 2 2 2 2 1 2 2 2 2 2 1 2 2 1 2 2 2 2 1 1 2 2 2 2 1 2 1 3 1 2
## [223] 2 1 2 2 2 2 1 3 1 1 2 1 2 1 2 2 2 2 1 1 1 1 1 1 2 2 1 1 1 1 2 2 2 2 2 1 1
## [260] 1 2 2 2 2 2 1 1 2 1 2 2 2 2 2 2 2 2 1 1 2 1 3 2 1 2 2 2 2 2 2 1 1 2 1 2 2
## [297] 2 1 2 2
# 3 Bellman helper functions
# 3 Bellman helper functions -------------------------------------------------------
# 3.1: alphas --> R(s) `reward()` ----------------------------------------------
reward <- function(alpha) {
  # Reward R(s) for each of the 90 states: the alpha-weighted sum of
  # unit-variance Gaussian basis functions centred on states 1..90,
  # evaluated at s.
  #
  # Args:
  #   alpha: numeric basis weights (one per centre; R recycles other lengths).
  # Returns:
  #   Numeric vector of length 90 with R(1), ..., R(90).
  centers <- 1:90
  reward_at <- function(s) {
    basis <- dnorm(s, centers, sd = 1)
    sum(alpha * basis)
  }
  vapply(centers, reward_at, numeric(1))
}
# Test 1: with alpha = (1, 0, ..., 0) only the basis centred on state 1 is
# active, so R(s) starts at dnorm(1, mean = 1, sd = 1) and shrinks as s grows.
reward(replace(numeric(90), 1, 1))
## [1] 3.989423e-01 2.419707e-01 5.399097e-02 4.431848e-03 1.338302e-04
## [6] 1.486720e-06 6.075883e-09 9.134720e-12 5.052271e-15 1.027977e-18
## [11] 7.694599e-23 2.118819e-27 2.146384e-32 7.998828e-38 1.096607e-43
## [16] 5.530710e-50 1.026163e-56 7.004182e-64 1.758750e-71 1.624636e-79
## [21] 5.520948e-88 6.902029e-97 3.174282e-106 5.370560e-116 3.342714e-126
## [26] 7.653930e-137 6.447260e-148 1.997889e-159 2.277577e-171 9.551695e-184
## [31] 1.473646e-196 8.363952e-210 1.746366e-223 1.341420e-237 3.790526e-252
## [36] 3.940396e-267 1.506905e-282 2.120007e-298 1.097221e-314 0.000000e+00
## [41] 0.000000e+00 0.000000e+00 0.000000e+00 0.000000e+00 0.000000e+00
## [46] 0.000000e+00 0.000000e+00 0.000000e+00 0.000000e+00 0.000000e+00
## [51] 0.000000e+00 0.000000e+00 0.000000e+00 0.000000e+00 0.000000e+00
## [56] 0.000000e+00 0.000000e+00 0.000000e+00 0.000000e+00 0.000000e+00
## [61] 0.000000e+00 0.000000e+00 0.000000e+00 0.000000e+00 0.000000e+00
## [66] 0.000000e+00 0.000000e+00 0.000000e+00 0.000000e+00 0.000000e+00
## [71] 0.000000e+00 0.000000e+00 0.000000e+00 0.000000e+00 0.000000e+00
## [76] 0.000000e+00 0.000000e+00 0.000000e+00 0.000000e+00 0.000000e+00
## [81] 0.000000e+00 0.000000e+00 0.000000e+00 0.000000e+00 0.000000e+00
## [86] 0.000000e+00 0.000000e+00 0.000000e+00 0.000000e+00 0.000000e+00
# Test 2: alpha = (1, 0, ..., 0, 1) should give a reward that is symmetric
# around the middle of the state space.
reward(replace(numeric(90), c(1, 90), 1))
## [1] 3.989423e-01 2.419707e-01 5.399097e-02 4.431848e-03 1.338302e-04
## [6] 1.486720e-06 6.075883e-09 9.134720e-12 5.052271e-15 1.027977e-18
## [11] 7.694599e-23 2.118819e-27 2.146384e-32 7.998828e-38 1.096607e-43
## [16] 5.530710e-50 1.026163e-56 7.004182e-64 1.758750e-71 1.624636e-79
## [21] 5.520948e-88 6.902029e-97 3.174282e-106 5.370560e-116 3.342714e-126
## [26] 7.653930e-137 6.447260e-148 1.997889e-159 2.277577e-171 9.551695e-184
## [31] 1.473646e-196 8.363952e-210 1.746366e-223 1.341420e-237 3.790526e-252
## [36] 3.940396e-267 1.506905e-282 2.120007e-298 1.097221e-314 0.000000e+00
## [41] 0.000000e+00 0.000000e+00 0.000000e+00 0.000000e+00 0.000000e+00
## [46] 0.000000e+00 0.000000e+00 0.000000e+00 0.000000e+00 0.000000e+00
## [51] 0.000000e+00 1.097221e-314 2.120007e-298 1.506905e-282 3.940396e-267
## [56] 3.790526e-252 1.341420e-237 1.746366e-223 8.363952e-210 1.473646e-196
## [61] 9.551695e-184 2.277577e-171 1.997889e-159 6.447260e-148 7.653930e-137
## [66] 3.342714e-126 5.370560e-116 3.174282e-106 6.902029e-97 5.520948e-88
## [71] 1.624636e-79 1.758750e-71 7.004182e-64 1.026163e-56 5.530710e-50
## [76] 1.096607e-43 7.998828e-38 2.146384e-32 2.118819e-27 7.694599e-23
## [81] 1.027977e-18 5.052271e-15 9.134720e-12 6.075883e-09 1.486720e-06
## [86] 1.338302e-04 4.431848e-03 5.399097e-02 2.419707e-01 3.989423e-01
# 3.2: Value iteration helper function `V_update()` --------------------------
# Take a guess for V(s, a) and update it according to the Bellman Equation:
V_update <- function(V, R, beta, P0, P1) {
  # One Bellman backup of the action-value function.
  #
  # Args:
  #   V:    90 x 2 matrix; column 1 = value of keeping, column 2 = value of
  #         replacing.
  #   R:    state rewards R(s).
  #   beta: discount factor applied to the continuation value.
  #   P0:   transition matrix P(s' | s, keep).
  #   P1:   transition matrix P(s' | s, replace).
  # Returns:
  #   Updated 90 x 2 matrix with V(s, a) = R(s) + beta * P(s' | s, a) %*% V*(s').

  # V*(s'): value of acting optimally from s' onwards, shared by both actions.
  v_star <- pmax(V[, 1], V[, 2])

  # Keep the engine now, then behave optimally.
  q_keep <- matrix(data = R + beta * P0 %*% v_star, nrow = 90, ncol = 1)
  # Replace the engine now, then behave optimally.
  q_replace <- matrix(data = R + beta * P1 %*% v_star, nrow = 90, ncol = 1)

  cbind(q_keep, q_replace)
}
# Test: one Bellman backup from the initial guess -----------------------------
# alpha must have one weight per basis function (90). A longer vector would be
# silently recycled against the 90 basis values inside `reward()`.
V_update(
  V = matrix(data = c(1, rep(0, 180 - 1)), nrow = 90, ncol = 2),
  R = reward(c(1, 1, 1, rep(0, 90 - 3))),
  beta = .99,
  P0 = transitiond0,
  P1 = transitiond1
)
## [,1] [,2]
## [1,] 1.047522e+00 1.0475217
## [2,] 8.828837e-01 1.2355014
## [3,] 6.949040e-01 1.0475217
## [4,] 3.003935e-01 0.6530112
## [5,] 5.855665e-02 0.4111743
## [6,] 4.567165e-03 0.3571849
## [7,] 1.353230e-04 0.3527530
## [8,] 1.492805e-06 0.3526192
## [9,] 6.085023e-09 0.3526177
## [10,] 9.139774e-12 0.3526177
## [11,] 5.053299e-15 0.3526177
## [12,] 1.028054e-18 0.3526177
## [13,] 7.694811e-23 0.3526177
## [14,] 2.118841e-27 0.3526177
## [15,] 2.146392e-32 0.3526177
## [16,] 7.998839e-38 0.3526177
## [17,] 1.096607e-43 0.3526177
## [18,] 5.530711e-50 0.3526177
## [19,] 1.026163e-56 0.3526177
## [20,] 7.004182e-64 0.3526177
## [21,] 1.758750e-71 0.3526177
## [22,] 1.624636e-79 0.3526177
## [23,] 5.520948e-88 0.3526177
## [24,] 6.902029e-97 0.3526177
## [25,] 3.174282e-106 0.3526177
## [26,] 5.370560e-116 0.3526177
## [27,] 3.342714e-126 0.3526177
## [28,] 7.653930e-137 0.3526177
## [29,] 6.447260e-148 0.3526177
## [30,] 1.997889e-159 0.3526177
## [31,] 2.277577e-171 0.3526177
## [32,] 9.551695e-184 0.3526177
## [33,] 1.473646e-196 0.3526177
## [34,] 8.363952e-210 0.3526177
## [35,] 1.746366e-223 0.3526177
## [36,] 1.341420e-237 0.3526177
## [37,] 3.790526e-252 0.3526177
## [38,] 3.940396e-267 0.3526177
## [39,] 1.506905e-282 0.3526177
## [40,] 2.120007e-298 0.3526177
## [41,] 1.097221e-314 0.3526177
## [42,] 0.000000e+00 0.3526177
## [43,] 0.000000e+00 0.3526177
## [44,] 0.000000e+00 0.3526177
## [45,] 0.000000e+00 0.3526177
## [46,] 0.000000e+00 0.3526177
## [47,] 0.000000e+00 0.3526177
## [48,] 0.000000e+00 0.3526177
## [49,] 0.000000e+00 0.3526177
## [50,] 0.000000e+00 0.3526177
## [51,] 0.000000e+00 0.3526177
## [52,] 0.000000e+00 0.3526177
## [53,] 0.000000e+00 0.3526177
## [54,] 0.000000e+00 0.3526177
## [55,] 0.000000e+00 0.3526177
## [56,] 0.000000e+00 0.3526177
## [57,] 0.000000e+00 0.3526177
## [58,] 0.000000e+00 0.3526177
## [59,] 0.000000e+00 0.3526177
## [60,] 0.000000e+00 0.3526177
## [61,] 0.000000e+00 0.3526177
## [62,] 0.000000e+00 0.3526177
## [63,] 0.000000e+00 0.3526177
## [64,] 0.000000e+00 0.3526177
## [65,] 0.000000e+00 0.3526177
## [66,] 0.000000e+00 0.3526177
## [67,] 0.000000e+00 0.3526177
## [68,] 0.000000e+00 0.3526177
## [69,] 0.000000e+00 0.3526177
## [70,] 0.000000e+00 0.3526177
## [71,] 0.000000e+00 0.3526177
## [72,] 0.000000e+00 0.3526177
## [73,] 0.000000e+00 0.3526177
## [74,] 0.000000e+00 0.3526177
## [75,] 0.000000e+00 0.3526177
## [76,] 0.000000e+00 0.3526177
## [77,] 0.000000e+00 0.3526177
## [78,] 0.000000e+00 0.3526177
## [79,] 0.000000e+00 0.3526177
## [80,] 0.000000e+00 0.3526177
## [81,] 0.000000e+00 0.3526177
## [82,] 0.000000e+00 0.3526177
## [83,] 0.000000e+00 0.3526177
## [84,] 0.000000e+00 0.3526177
## [85,] 0.000000e+00 0.3526177
## [86,] 0.000000e+00 0.3526177
## [87,] 0.000000e+00 0.3526177
## [88,] 0.000000e+00 0.3526177
## [89,] 0.000000e+00 0.3526177
## [90,] 0.000000e+00 0.3526177
# 3.3: R(s) --> V(s, a) `value_iteration()` ----------------------------------
# Take a reward function and calculate the corresponding value function using
# value iteration.
value_iteration <- function(alpha, beta, P0, P1, sound = "off") {
  # Compute the action-value function V(s, a) implied by reward weights alpha,
  # by repeating Bellman backups until the largest change falls below .01.
  #
  # Args:
  #   alpha: basis weights passed to `reward()`.
  #   beta:  discount factor.
  #   P0:    transition matrix when keeping the engine.
  #   P1:    transition matrix when replacing the engine.
  #   sound: "on" prints the iteration counter after each non-converged sweep.
  # Returns:
  #   90 x 2 matrix (column 1 = keep, column 2 = replace). If the iteration
  #   limit is hit, the last iterate is returned with a warning instead of the
  #   silent NULL that would otherwise fall out of the loop.
  max_iter <- 10000
  R <- reward(alpha)
  V <- matrix(data = c(1, rep(0, 180 - 1)), nrow = 90, ncol = 2) # initial guess
  for (i in seq_len(max_iter)) {
    V_next <- V_update(V, R, beta, P0, P1)
    # Converged once no entry moved by .01 or more.
    if (max(abs(V_next - V)) < .01) {
      return(V_next)
    }
    if (sound == "on") {
      print(i)
    }
    V <- V_next
  }
  warning(
    "value_iteration() did not converge within ", max_iter,
    " iterations; returning the last iterate.",
    call. = FALSE
  )
  V
}
# Test: value iteration with all reward weight on the first basis function ----
V <- value_iteration(c(1, rep(0, 89)), .99, transitiond0, transitiond1, "off")
# In which states is keeping at least as valuable as replacing?
# Expected: only state 1. With alpha = (1, 0, ..., 0) the reward sits at the
# lowest-mileage states, so replacing is preferred everywhere else.
V[, 2] <= V[, 1]
## [1] TRUE FALSE FALSE FALSE FALSE FALSE FALSE FALSE FALSE FALSE FALSE FALSE
## [13] FALSE FALSE FALSE FALSE FALSE FALSE FALSE FALSE FALSE FALSE FALSE FALSE
## [25] FALSE FALSE FALSE FALSE FALSE FALSE FALSE FALSE FALSE FALSE FALSE FALSE
## [37] FALSE FALSE FALSE FALSE FALSE FALSE FALSE FALSE FALSE FALSE FALSE FALSE
## [49] FALSE FALSE FALSE FALSE FALSE FALSE FALSE FALSE FALSE FALSE FALSE FALSE
## [61] FALSE FALSE FALSE FALSE FALSE FALSE FALSE FALSE FALSE FALSE FALSE FALSE
## [73] FALSE FALSE FALSE FALSE FALSE FALSE FALSE FALSE FALSE FALSE FALSE FALSE
## [85] FALSE FALSE FALSE FALSE FALSE FALSE
# 4 Generate R{rand} and use it to get V{rand}
# 4 Generate R{rand} and use it to get V{rand} ---------------------------------
# Draw a reproducible random 0/1 weight per basis function, then derive the
# implied reward and value function.
set.seed(1234)
alpha_rand <- sample(0:1, size = 90, replace = TRUE)
R_rand <- reward(alpha_rand)
V_rand <- value_iteration(
  alpha_rand,
  beta = .99,
  P0 = transitiond0,
  P1 = transitiond1
)
# 5 Use V{rand} to get Boltzmann action probabilities
# 5 Use V{rand} to get Boltzmann action probabilities ---------
# Boltzmann: take action a over a' in state s with probability:
# exp(beta * V(s, a)) / [exp(beta * V(s, a)) + exp(beta * V(s, a'))]
boltzmann_prob_keep <- function(V, beta) {
  # Boltzmann (softmax) probability of keeping the engine in each state.
  #
  # Args:
  #   V:    matrix with column 1 = value of keeping, column 2 = value of
  #         replacing.
  #   beta: scale applied to the values inside the exponent.
  # Returns:
  #   Numeric vector with P(keep | s) for every row of V.
  #
  # exp(b * Vk) / (exp(b * Vk) + exp(b * Vr)) equals the logistic function of
  # b * (Vk - Vr). Writing it that way avoids exp() overflowing to Inf, and
  # hence Inf / Inf = NaN, when the values are large.
  plogis(beta * (V[, 1] - V[, 2]))
}
# Test ---------
# P(keep) should exceed one half exactly in the states where keeping has the
# higher value.
all((boltzmann_prob_keep(V_rand, .99) > .5) == (V_rand[, 1] > V_rand[, 2]))
## [1] TRUE
# 6 Generate one trajectory using V_rand and Boltzmann
# probabilities
# 6 Generate one trajectory using V_rand and boltzmann probabilities ---------
generate_traj <- function(V, beta) {
  # Simulate one bus trajectory of 104 decisions under a Boltzmann policy.
  #
  # Args:
  #   V:    action-value matrix (column 1 = keep, column 2 = replace).
  #   beta: scale passed to `boltzmann_prob_keep()`.
  # Returns:
  #   list(s = states, a = actions): `s` has 105 entries (the initial state
  #   plus one per step) and `a` has 104 entries (0 = keep, 1 = replace).
  n_steps <- 104
  action_probs <- boltzmann_prob_keep(V, beta)
  # `s` holds one more entry than `a` because the last action still produces
  # a successor state.
  traj <- list(s = rep(NA, n_steps + 1), a = rep(NA, n_steps))
  # Initial state: a low mileage bucket, always >= 1.
  traj$s[1] <- ceiling(.5 * rexp(n = 1, rate = 1))
  for (i in seq_len(n_steps)) {
    # The policy must condition on the *current* state s[i]. Indexing with
    # s[1] would reuse the initial state's probability at every step.
    p_keep <- action_probs[traj$s[i]]
    traj$a[i] <- sample(0:1, size = 1, prob = c(p_keep, 1 - p_keep))
    # Advance the state according to the chosen action.
    traj$s[i + 1] <- step(traj$s[i], traj$a[i])
  }
  traj
}
# Test: simulate a single trajectory under V_rand --------------------
# NOTE(review): `beta` is not assigned anywhere in this file; presumably it is
# defined by the sourced script -- confirm.
generate_traj(V = V_rand, beta = beta)
## $s
## [1] 1 2 2 3 2 3 4 2 2 2 1 1 2 2 2 2 4 5 2 3 1 2 1 1 2 3 4 2 2 2 3 2 3 4 5 6 7
## [38] 2 3 1 1 1 2 3 4 2 2 1 1 2 2 3 3 3 3 3 4 5 6 2 3 2 3 3 4 5 2 2 1 2 1 2 3 1
## [75] 2 3 2 2 1 1 2 3 2 3 2 2 2 2 1 2 3 4 5 5 6 2 1 1 2 1 1 1 2 2 1
##
## $a
## [1] 0 1 0 1 0 0 1 1 1 1 1 1 0 1 0 0 0 1 0 1 0 1 1 0 0 0 1 0 0 0 1 0 0 0 0 0 1
## [38] 0 1 1 0 1 0 0 1 1 1 1 0 0 0 0 0 0 0 0 0 0 1 0 1 0 0 0 0 1 1 1 1 1 0 0 1 1
## [75] 0 1 0 1 0 0 0 1 0 1 0 1 1 1 1 0 0 0 0 0 1 1 0 1 1 0 1 1 0 1
# 7 Generate 104 trajectories
# 7 Generate 104 trajectories ------------------
generate_104_trajs <- function(V, beta) {
  # Simulate 104 independent trajectories under the Boltzmann policy for V.
  # Returns an unnamed list of 104 `generate_traj()` results.
  simulate_one <- function(...) {
    generate_traj(V, beta)
  }
  lapply(1:104, simulate_one)
}
# Test: simulate the full set of trajectories and show the first six -------
traj_rand <- generate_104_trajs(V = V_rand, beta = beta)
head(traj_rand, n = 6L)
## [[1]]
## [[1]]$s
## [1] 1 1 2 3 3 3 3 2 2 2 1 1 2 2 2 2 2 3 2 2 3 4 4 2 1 2 1 1 1 1 1 2 2 2 3 3 4
## [38] 5 2 3 1 2 2 2 3 3 3 3 1 2 2 3 3 4 5 6 6 7 8 8 2 2 2 3 2 2 2 1 1 1 2 2 3 4
## [75] 5 6 1 1 1 2 3 1 2 2 3 4 2 1 1 2 3 4 5 2 2 3 3 4 5 6 7 7 8 8 8
##
## [[1]]$a
## [1] 1 0 0 1 0 0 1 1 1 1 1 1 1 1 0 0 0 1 1 0 0 0 1 1 0 1 1 0 0 1 1 1 0 0 0 0 0
## [38] 1 0 1 1 1 1 0 0 0 0 1 0 1 0 0 0 0 0 0 0 0 0 1 0 1 0 1 1 0 1 1 1 0 0 0 0 0
## [75] 0 1 0 0 1 0 1 1 1 0 0 1 1 1 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0
##
##
## [[2]]
## [[2]]$s
## [1] 1 2 2 3 4 1 1 2 3 4 1 2 3 4 4 4 5 5 1 1 1 2 1 1 2 2 3 3 4 4 5 2 3 4 2 2 2
## [38] 2 2 2 2 2 3 4 4 5 2 2 3 1 2 2 3 1 2 2 1 1 2 3 4 4 5 5 6 7 2 2 3 2 3 4 4 2
## [75] 2 2 3 3 2 2 2 3 4 1 2 3 4 2 1 1 2 1 2 3 2 3 4 4 2 2 3 3 3 2 2
##
## [[2]]$a
## [1] 0 0 0 0 1 1 1 0 0 1 1 0 0 0 0 0 0 1 1 0 0 1 1 0 1 0 0 0 0 0 1 0 0 1 1 0 1
## [38] 1 1 0 0 0 0 0 0 1 1 0 1 1 1 0 1 1 0 1 1 0 0 0 0 0 0 0 0 1 1 0 1 0 0 0 1 0
## [75] 1 0 0 1 0 1 0 0 1 0 0 0 1 1 1 1 1 1 0 1 0 0 0 1 1 0 0 0 1 1
##
##
## [[3]]
## [[3]]$s
## [1] 2 2 2 2 2 3 3 2 3 1 1 1 2 3 2 2 2 1 1 1 1 2 2 3 4 2 2 1 2 2 3 2 2 3 3 2 2
## [38] 2 2 2 2 2 2 3 4 5 2 2 1 2 2 3 4 5 5 6 7 7 1 1 2 2 3 2 2 1 1 1 2 1 2 1 2 3
## [75] 3 3 4 1 2 2 3 4 2 3 1 2 3 1 2 3 1 1 2 2 2 3 2 3 4 5 5 1 2 3 2
##
## [[3]]$a
## [1] 1 1 1 0 0 0 1 0 1 0 1 0 0 1 0 0 1 0 0 1 0 0 0 0 1 1 1 1 1 0 1 1 0 0 1 1 1
## [38] 1 1 0 0 0 0 0 0 1 1 1 0 0 0 0 0 0 0 0 0 1 0 1 1 0 1 0 1 1 1 1 1 0 1 0 0 0
## [75] 0 0 1 1 1 0 0 1 0 1 0 0 1 1 0 1 1 1 1 1 0 1 0 0 0 0 1 1 0 1
##
##
## [[4]]
## [[4]]$s
## [1] 1 2 2 3 4 5 6 6 6 7 8 2 1 2 2 1 2 3 1 2 1 2 3 4 5 5 1 1 2 2 2 3 2 2 1 1 1
## [38] 1 2 1 2 1 2 3 3 1 2 3 2 3 1 2 2 3 3 1 1 1 2 2 3 2 3 4 2 2 3 4 4 2 2 1 2 2
## [75] 1 2 1 1 1 2 2 2 3 2 2 1 1 2 2 2 2 3 1 2 3 3 1 2 3 4 2 1 1 2 2
##
## [[4]]$a
## [1] 0 1 0 0 0 0 0 0 0 0 1 1 0 0 1 1 0 1 0 1 1 1 0 0 0 1 0 0 0 1 0 1 0 1 1 1 1
## [38] 1 1 1 1 0 0 0 1 0 0 1 0 1 0 1 0 0 1 1 1 1 0 0 1 0 0 1 1 0 0 0 1 0 1 1 1 1
## [75] 1 1 1 0 1 1 1 0 1 1 1 0 1 1 1 1 0 1 1 0 0 1 1 0 0 1 1 1 0 1
##
##
## [[5]]
## [[5]]$s
## [1] 1 2 1 2 1 2 2 1 2 3 2 2 2 1 2 2 2 3 2 2 2 2 2 3 3 4 5 6 1 2 3 2 2 3 4 1 1
## [38] 1 2 3 3 1 2 1 2 3 4 1 2 2 2 2 2 3 4 2 3 3 4 1 2 3 1 2 2 2 2 2 2 3 3 1 1 2
## [75] 2 3 2 2 2 1 1 1 1 1 2 1 2 2 3 4 1 1 2 1 1 2 2 2 2 2 2 2 2 2 2
##
## [[5]]$a
## [1] 0 1 1 1 0 1 1 1 0 1 1 1 1 0 0 1 0 1 1 0 1 1 0 0 0 0 0 1 1 0 1 1 0 0 1 0 0
## [38] 1 0 0 1 0 1 0 0 0 1 1 0 0 1 1 0 0 1 0 0 0 1 1 0 1 0 0 1 0 1 1 0 0 1 0 0 1
## [75] 0 1 1 1 1 0 1 0 1 1 1 0 1 0 0 1 1 0 1 0 1 1 1 1 1 1 0 1 0 0
##
##
## [[6]]
## [[6]]$s
## [1] 1 2 2 2 3 3 4 5 2 3 1 2 2 3 3 4 5 6 7 7 7 1 2 2 2 2 1 1 1 1 2 3 2 2 2 3 3
## [38] 2 1 1 2 3 2 3 4 2 2 2 2 2 2 1 2 2 3 4 2 2 1 1 1 2 3 2 2 3 2 1 1 2 4 5 2 3
## [75] 1 2 3 1 2 3 3 4 5 5 6 7 2 2 1 2 2 2 2 1 2 3 1 1 3 2 3 1 1 2 3
##
## [[6]]$a
## [1] 0 1 1 0 0 0 0 1 0 1 1 1 0 0 0 0 0 0 0 0 1 1 1 0 1 1 1 0 0 0 0 1 1 1 0 0 1
## [38] 1 0 0 0 1 0 0 1 1 0 1 1 0 1 1 1 0 0 1 1 1 0 0 1 0 1 0 0 1 1 0 0 0 0 1 0 1
## [75] 1 0 1 0 0 0 0 0 0 0 0 1 1 1 1 1 0 1 1 1 0 1 1 1 1 0 1 1 1 0
# 8 Calculate visiting weights for each state in a trajectory
# 8 Calculate visiting weights for each state in a trajectory ------------------
# In the IRL algorithm, I'll need to calculate weights for each state given
# a certain trajectory that count up the number of times that state is visited,
# discounted by beta^t. I follow Ng and Russell (2000) by using 90 evenly
# spaced basis functions, so being in bucket 8 also contributes somewhat
# to the weight of bucket 9 and 7.
weight_state_i <- function(trajectory, gaus_center, discount = beta) {
  # Discounted visiting weight of one Gaussian basis function along a
  # trajectory: sum_t discount^t * dnorm(s_t, mean = gaus_center, sd = 1).
  #
  # Args:
  #   trajectory:  vector of visited states s_0, s_1, ...
  #   gaus_center: centre of the basis function being weighted.
  #   discount:    per-period discount factor. The default is the variable
  #                `beta` found in the calling environment chain, which is
  #                what this function previously read implicitly. Pass it
  #                explicitly to avoid depending on that global.
  # Returns:
  #   A single number; 0 for an empty trajectory.

  # If no numeric `beta` exists, the default resolves to base::beta (a
  # function). Fail with a clear message rather than an arithmetic error.
  stopifnot(is.numeric(discount))
  density <- dnorm(trajectory, mean = gaus_center, sd = 1)
  # Exponents 0, 1, ..., n - 1. seq_along() also handles the empty case.
  discount_seq <- discount^(seq_along(trajectory) - 1)
  sum(density * discount_seq)
}
# A trajectory that sits mostly in state 1 should weight centre 1 above
# centre 2, and centre 2 above centre 3. Each comparison needs its own
# expectation: the second argument of expect_true() is `info`, not another
# condition, so a second comparison passed there is never asserted.
expect_true(
  weight_state_i(c(1, 1, 2), 1) > weight_state_i(c(1, 1, 2), 2)
)
expect_true(
  weight_state_i(c(1, 1, 2), 2) > weight_state_i(c(1, 1, 2), 3)
)
# Weight of the state-1 basis function: first simulated trajectory, then the
# first trajectory in `rust`.
weight_state_i(trajectory = traj_rand[[1]]$s, gaus_center = 1)
## [1] 12.22604
weight_state_i(trajectory = rust[[1]]$s, gaus_center = 1)
## [1] 0.9366354
weights <- function(trajectories_list) {
  # Average discounted visiting weight of each of the 90 basis functions over
  # a set of trajectories. States visited later in a trajectory count for
  # less, because `weight_state_i()` discounts them by beta^t.
  #
  # Args:
  #   trajectories_list: list of trajectories, each a list with a state
  #                      vector named "s".
  # Returns:
  #   Numeric vector of length 90.

  # Flatten one level, then keep only the state vectors (elements named "s").
  all_elements <- flatten(trajectories_list)
  state_paths <- all_elements[names(all_elements) == 's']

  mean_weight_at <- function(i) {
    per_trajectory <- map_dbl(state_paths, weight_state_i, i)
    mean(per_trajectory)
  }
  map_dbl(1:90, mean_weight_at)
}
# Test: plot the visiting weight of every basis function for `rust` ------
ggplot(tibble(w = weights(rust)), aes(x = 1:90, y = w)) +
  geom_point() +
  geom_line()

# 9 Use an LP solver to max difference in visiting weights between
# data and Traj{rand}
# 9 Use an LP solver to max difference in visiting weights between data and Traj{rand} ---------------
# Standard LP form: maximize c'x
# s/t Ax <= b
# x >= 0
# |alpha_i| <= 1: substitute gamma_i = alpha_i + 1.
# Then the constraint becomes gamma_i \in [0, 2].
#
# x = (gamma_1, gamma_2, ..., gamma_90)
# c: weights(rust) - weights(traj_rand)
# and when we iterate, c_new = c_old + weights(rust) - weights(traj_rand)
# constraints: gamma_i^+ <= 2 (>= 0 is taken for granted with lpSolve::lp)
# A = diag(90)
# b: rep(2, 90)
alpha_calculator <- function(objective) {
  # Solve the LP  max objective' gamma  s.t.  0 <= gamma_i <= 2  and map the
  # solution back to alpha_i = gamma_i - 1, so that each alpha_i is in [-1, 1].
  #
  # Args:
  #   objective: 90 objective coefficients, one per basis function.
  # Returns:
  #   Numeric vector of 90 alpha values.
  lp_fit <- lpSolve::lp(
    direction = "max",
    objective.in = objective,
    const.mat = diag(90),                     # one upper-bound row per gamma_i
    const.dir = rep("<=", 90),
    const.rhs = matrix(rep(2, 90), nrow = 90) # gamma_i <= 2
  )
  gamma <- lp_fit$solution
  # Undo the shift gamma_i = alpha_i + 1.
  gamma - rep(1, 90)
}
# Test: alpha implied by the weight gap between the data and traj_rand -------
alpha_calculator(
  objective = matrix(weights(rust) - weights(traj_rand), nrow = 90)
)
## [1] -1 -1 -1 -1 -1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1
## [26] 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1
## [51] 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1
## [76] 1 1 1 1 1 1 1 1 1 -1 -1 -1 -1 -1 -1
# 10 Iterating the process: IRL function
# 10 Iterating the process: IRL function ----------------------------------------
# Use the new LP alpha estimate to find R(s), then use value iteration to find
# V(s, a), then use Boltzmann to generate new trajectories, and add the difference
# in weights between the new trajectories and the Rust data to the LP solver
# objective function. Look for convergence in alpha's for the IRL Reward function
# estimate.
IRL <- function(dataset, beta, P0, P1, min_its) {
  # Inverse reinforcement learning loop: estimate the basis weights alpha of a
  # reward function under which the observed trajectories look better than
  # trajectories simulated from the current reward estimate.
  #
  # Args:
  #   dataset: list of observed trajectories (each with a state vector `s`).
  #   beta:    used both as the discount factor in value iteration and as the
  #            Boltzmann scale when simulating trajectories.
  #   P0, P1:  transition matrices for keeping / replacing the engine, used in
  #            value iteration.
  #   min_its: convergence is only accepted after more than this many
  #            iterations.
  # Returns:
  #   Numeric vector of 90 alpha values. If the change in alpha never drops
  #   below .1 within 100 iterations, the last estimate is returned with a
  #   warning (instead of a silent NULL).
  #
  # Side effects: resets the global RNG seed to 1234 and prints the iteration
  # counter.
  # NOTE(review): the simulation goes through `step()`, which reads the global
  # transition matrices rather than P0 / P1 -- confirm these agree.
  max_its <- 100

  # Init: value function of a random 0/1 reward.
  set.seed(1234)
  V_rand <- value_iteration(
    alpha = sample(0:1, size = 90, replace = TRUE), beta = beta,
    P0 = P0, P1 = P1
  )
  # Generate 104 trajectories under V_rand.
  traj_rand <- generate_104_trajs(V_rand, beta)

  # The data weights do not change across iterations, so compute them once.
  data_weights <- weights(dataset)
  # Objective: gap in visiting weights between the data and simulated trajs.
  objective <- data_weights - weights(traj_rand)
  # Use the LP solver to get the first alpha estimate.
  alpha <- alpha_calculator(matrix(objective, nrow = 90))

  for (i in seq_len(max_its)) {
    V <- value_iteration(alpha = alpha, beta = beta, P0 = P0, P1 = P1)
    traj <- generate_104_trajs(V, beta)
    # Accumulate the weight gap of every simulated policy so far.
    objective <- objective + data_weights - weights(traj)
    alpha1 <- alpha_calculator(matrix(objective, nrow = 90))
    if (i > min_its && max(abs(alpha1 - alpha)) < .1) {
      return(alpha1)
    }
    alpha <- alpha1
    print(i)
  }
  warning(
    "IRL() did not converge within ", max_its,
    " iterations; returning the last alpha estimate.",
    call. = FALSE
  )
  alpha
}
# 11 Running IRL on Rust
# 11 Running IRL on Rust ------------------------------------------------------
# Keep the estimated basis weights separately from the reward they imply.
# `value_iteration()` expects alpha and applies `reward()` itself, so passing
# it the reward vector would apply the basis smoothing twice.
IRL_alpha_estimate <- IRL(rust, .99, transitiond0, transitiond1, 10)
IRL_R_estimate <- reward(IRL_alpha_estimate)
## [1] 1
## [1] 2
## [1] 3
## [1] 4
## [1] 5
## [1] 6
## [1] 7
## [1] 8
## [1] 9
## [1] 10
# Estimated reward R(s) by state.
tibble(
  s = 1:90,
  x = IRL_R_estimate
) %>%
  ggplot(aes(x = s, y = x)) +
  geom_line()

# Value of keeping the engine under the estimated reward.
V_IRL_estimate <- value_iteration(
  IRL_alpha_estimate, .99, transitiond0, transitiond1
)
tibble(
  x = 1:90,
  y = V_IRL_estimate[, 1]
) %>%
  ggplot(aes(x = x, y = y)) +
  geom_line() +
  geom_point()

# Implied Boltzmann probability of keeping the engine in each state.
pk <- boltzmann_prob_keep(V_IRL_estimate, .99)
tibble(
  x = 1:90,
  y = pk
) %>%
  ggplot(aes(x = x, y = y)) +
  geom_line() +
  geom_point()
