Now let's consider a grid world. The grid looks like this:
# A 4 x 4 matrix of zeros; it serves as the template for the value matrix
# and the reward grid.
grid <- matrix(numeric(16), nrow = 4)
print(grid)
## [,1] [,2] [,3] [,4]
## [1,] 0 0 0 0
## [2,] 0 0 0 0
## [3,] 0 0 0 0
## [4,] 0 0 0 0
The states grid[1,1] and grid[4,4] are terminal. In each non-terminal state there are four actions \(A=\{up,down,left,right\}\); an action that would take the agent off the grid leaves the state unchanged. The reward is 0 in terminal states and -1 everywhere else. The environment is stochastic: whichever action is chosen, the agent moves up, down, left or right with equal probability, i.e. \(\Pr\{s',-1 \mid s,a\}=1/4\) for each of the four moves, where \(s\notin\{(1,1),(4,4)\}\).
First, define the terminal states, the discount factor \(\gamma\), and the reward grid:
# Terminal states are keyed as paste0(row, col), e.g. "11" for grid[1,1].
# This keying is only unambiguous while both grid dimensions are below 10.
terminal <- c("11", "44")
V <- grid
gamma <- 0.5  # discount factor
# Reward grid: 0 in terminal states, -1 everywhere else. Built inside local()
# so no helper variables (previously the loop indices `r` and `c`) leak into
# the global environment.
rewardGrid <- local({
  reward <- grid
  # outer() yields a matrix whose [i, j] entry is paste0(i, j).
  state_keys <- outer(seq_len(nrow(grid)), seq_len(ncol(grid)), paste0)
  reward[!(state_keys %in% terminal)] <- -1
  reward
})
print("the reward grid:")
## [1] "the reward grid:"
print(rewardGrid)
## [,1] [,2] [,3] [,4]
## [1,] 0 -1 -1 -1
## [2,] -1 -1 -1 -1
## [3,] -1 -1 -1 -1
## [4,] -1 -1 -1 0
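Based on the state-value matrix, the policy is greedy with respect to the value of the neighbouring state \(s'_a\) reached by action \(a\), with ties split evenly: \(\pi(a|s)=\mathbf{1}\left[a\in\arg\max_{a'}V(s'_{a'})\right]/Z\), where \(Z=\left|\arg\max_{a'}V(s'_{a'})\right|\) is the number of maximising actions, so that \(\sum_a\pi(a|s)=1\).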
# Greedy action probabilities for the state (r, c).
#
# V: state-value matrix; its own dimensions define the grid boundary.
# r, c: row and column of the state (assumed non-terminal; not checked).
# Returns list(r = action probabilities, v = neighbour values), both of
# length 4 and ordered up, down, left, right. Every action whose neighbour
# value equals the maximum gets probability 1 / (number of such actions).
policy <- function(V, r, c) {
  i <- r
  j <- c
  # A move that would leave the grid keeps the agent in place, so clamp the
  # target coordinates to the bounds of V (not of any global matrix).
  nb_row <- c(max(i - 1, 1), min(i + 1, nrow(V)), i, i)
  nb_col <- c(j, j, max(j - 1, 1), min(j + 1, ncol(V)))
  v <- V[cbind(nb_row, nb_col)]
  # Split probability evenly among all maximising actions (ties included).
  best <- v == max(v)
  list(r = best / sum(best), v = v)
}
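Based on the policy, the value function is updated with the Bellman expectation equation (iterative policy evaluation): \(V^{(t+1)}(s)=\sum_a{\pi(a|s)\sum_{s'}{p(s',-1|s,a)\left[-1 + \gamma V^{(t)}(s')\right]}}\). Because \(p(s',-1|s,a)=1/4\) does not depend on \(a\) and \(\sum_a\pi(a|s)=1\), the policy weights cancel and the update reduces to \(V^{(t+1)}(s)=\frac{1}{4}\sum_{s'}\left[-1+\gamma V^{(t)}(s')\right]\).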
# One synchronous sweep of the value update over every non-terminal state.
#
# policy: function(V, r, c) returning list(r = action probabilities,
#         v = values of the four neighbouring states).
# V: current state-value matrix.
# terminal_states: paste0(row, col) keys of states whose value stays 0;
#         defaults to the global `terminal`.
# discount: discount factor; defaults to the global `gamma`.
# Returns a new matrix with the same shape as V; terminal entries are 0.
calV <- function(policy, V, terminal_states = terminal, discount = gamma) {
  newV <- matrix(0, nrow = nrow(V), ncol = ncol(V))
  for (i in seq_len(nrow(V))) {
    for (j in seq_len(ncol(V))) {
      if (paste0(i, j) %in% terminal_states) {
        next  # terminal states keep value 0
      }
      res <- policy(V, i, j)
      # Each of the four moves happens with probability 1/4 and earns -1.
      backup <- sum(-1 + discount * res$v)
      # The backup does not depend on the action, so when the policy
      # probabilities sum to 1 this weighting leaves it unchanged.
      newV[i, j] <- sum(res$r * 1/4 * backup)
    }
  }
  newV
}
Now update the state-value function iteratively:
# Start from the initial estimate V; at most iter.max sweeps are run below.
newV <- V
iter.max <- 1e2
print(newV)
## [,1] [,2] [,3] [,4]
## [1,] 0 0 0 0
## [2,] 0 0 0 0
## [3,] 0 0 0 0
## [4,] 0 0 0 0
# Repeat the sweep until the total absolute change between consecutive value
# matrices falls below epsilon, or iter.max sweeps have been run. On exit,
# `iter` is the sweep that converged (or iter.max + 1 if none did).
epsilon <- 1e-3
converged <- FALSE
iter <- 1
while (!converged && iter <= iter.max) {
  oldV <- newV
  newV <- calV(policy, newV)
  print(paste("iter:", iter))
  print(newV)
  # Sum of per-state absolute changes. Unlike comparing sum(oldV) with
  # sum(newV), this cannot report convergence when increases in some states
  # cancel decreases in others.
  if (sum(abs(newV - oldV)) < epsilon) {
    converged <- TRUE
  } else {
    iter <- iter + 1
  }
}
## [1] "iter: 1"
## [,1] [,2] [,3] [,4]
## [1,] 0 -1 -1 -1
## [2,] -1 -1 -1 -1
## [3,] -1 -1 -1 -1
## [4,] -1 -1 -1 0
## [1] "iter: 2"
## [,1] [,2] [,3] [,4]
## [1,] 0.000 -1.375 -1.500 -1.500
## [2,] -1.375 -1.500 -1.500 -1.500
## [3,] -1.500 -1.500 -1.500 -1.375
## [4,] -1.500 -1.500 -1.375 0.000
## [1] "iter: 3"
## [,1] [,2] [,3] [,4]
## [1,] 0.000000 -1.546875 -1.734375 -1.750000
## [2,] -1.546875 -1.718750 -1.750000 -1.734375
## [3,] -1.734375 -1.750000 -1.718750 -1.546875
## [4,] -1.750000 -1.734375 -1.546875 0.000000
## [1] "iter: 4"
## [,1] [,2] [,3] [,4]
## [1,] 0.000000 -1.625000 -1.847656 -1.871094
## [2,] -1.625000 -1.824219 -1.863281 -1.847656
## [3,] -1.847656 -1.863281 -1.824219 -1.625000
## [4,] -1.871094 -1.847656 -1.625000 0.000000
## [1] "iter: 5"
## [,1] [,2] [,3] [,4]
## [1,] 0.000000 -1.662109 -1.900879 -1.929688
## [2,] -1.662109 -1.872070 -1.917969 -1.900879
## [3,] -1.900879 -1.917969 -1.872070 -1.662109
## [4,] -1.929688 -1.900879 -1.662109 0.000000
## [1] "iter: 6"
## [,1] [,2] [,3] [,4]
## [1,] 0.000000 -1.679382 -1.926331 -1.957642
## [2,] -1.679382 -1.895020 -1.943237 -1.926331
## [3,] -1.926331 -1.943237 -1.895020 -1.679382
## [4,] -1.957642 -1.926331 -1.679382 0.000000
## [1] "iter: 7"
## [,1] [,2] [,3] [,4]
## [1,] 0.000000 -1.687592 -1.938324 -1.970993
## [2,] -1.687592 -1.905655 -1.955338 -1.938324
## [3,] -1.938324 -1.955338 -1.905655 -1.687592
## [4,] -1.970993 -1.938324 -1.687592 0.000000
## [1] "iter: 8"
## [,1] [,2] [,3] [,4]
## [1,] 0.000000 -1.691446 -1.944031 -1.977329
## [2,] -1.691446 -1.910732 -1.960995 -1.944031
## [3,] -1.944031 -1.960995 -1.910732 -1.691446
## [4,] -1.977329 -1.944031 -1.691446 0.000000
## [1] "iter: 9"
## [,1] [,2] [,3] [,4]
## [1,] 0.000000 -1.693276 -1.946725 -1.980340
## [2,] -1.693276 -1.913110 -1.963691 -1.946725
## [3,] -1.946725 -1.963691 -1.913110 -1.693276
## [4,] -1.980340 -1.946725 -1.693276 0.000000
## [1] "iter: 10"
## [,1] [,2] [,3] [,4]
## [1,] 0.000000 -1.694139 -1.948004 -1.981766
## [2,] -1.694139 -1.914242 -1.964959 -1.948004
## [3,] -1.948004 -1.964959 -1.914242 -1.694139
## [4,] -1.981766 -1.948004 -1.694139 0.000000
## [1] "iter: 11"
## [,1] [,2] [,3] [,4]
## [1,] 0.000000 -1.694548 -1.948609 -1.982443
## [2,] -1.694548 -1.914774 -1.965561 -1.948609
## [3,] -1.948609 -1.965561 -1.914774 -1.694548
## [4,] -1.982443 -1.948609 -1.694548 0.000000
## [1] "iter: 12"
## [,1] [,2] [,3] [,4]
## [1,] 0.000000 -1.694741 -1.948895 -1.982763
## [2,] -1.694741 -1.915027 -1.965846 -1.948895
## [3,] -1.948895 -1.965846 -1.915027 -1.694741
## [4,] -1.982763 -1.948895 -1.694741 0.000000
## [1] "iter: 13"
## [,1] [,2] [,3] [,4]
## [1,] 0.000000 -1.694833 -1.949031 -1.982914
## [2,] -1.694833 -1.915147 -1.965981 -1.949031
## [3,] -1.949031 -1.965981 -1.915147 -1.694833
## [4,] -1.982914 -1.949031 -1.694833 0.000000
## [1] "iter: 14"
## [,1] [,2] [,3] [,4]
## [1,] 0.000000 -1.694876 -1.949095 -1.982986
## [2,] -1.694876 -1.915203 -1.966044 -1.949095
## [3,] -1.949095 -1.966044 -1.915203 -1.694876
## [4,] -1.982986 -1.949095 -1.694876 0.000000