The multi-armed bandit problem is a hypothetical problem in which a person chooses which slot machines to play in a casino. Each machine yields a reward generated by some unknown process when played. On each turn, the person must choose a machine to play based on the information gathered previously. The objective is to maximize the sum of rewards earned through a sequence of decisions.
The problem is interesting because it involves a fundamental trade-off between exploiting choices that have worked well previously and exploring choices that might be optimal but have appeared inferior because of randomness. This “exploration” vs. “exploitation” trade-off shows up in many real-world problems, such as choosing which products to buy, advertisement selection, medical treatment selection, and financial portfolio design.
Below I will demonstrate two algorithms that try to solve the multi-armed bandit problem: upper confidence bound (UCB) and Thompson sampling (also known as the probability matching strategy, a more general term). I will compare their performance and strategy in three scenarios: standard rewards, standard but more volatile rewards, and somewhat chaotic rewards.
library(ggplot2)
library(reshape2)
This data represents a standard, ideal situation: normally distributed rewards, well separated from each other.
mean_reward = c(5, 7.5, 10, 12.5, 15, 17.5, 20, 22.5, 25, 26)
reward_dist = c(function(n) rnorm(n = n, mean = mean_reward[1], sd = 2.5),
function(n) rnorm(n = n, mean = mean_reward[2], sd = 2.5),
function(n) rnorm(n = n, mean = mean_reward[3], sd = 2.5),
function(n) rnorm(n = n, mean = mean_reward[4], sd = 2.5),
function(n) rnorm(n = n, mean = mean_reward[5], sd = 2.5),
function(n) rnorm(n = n, mean = mean_reward[6], sd = 2.5),
function(n) rnorm(n = n, mean = mean_reward[7], sd = 2.5),
function(n) rnorm(n = n, mean = mean_reward[8], sd = 2.5),
function(n) rnorm(n = n, mean = mean_reward[9], sd = 2.5),
function(n) rnorm(n = n, mean = mean_reward[10], sd = 2.5))
#prepare simulation data
dataset = matrix(nrow = 10000, ncol = 10)
for(i in 1:10){
dataset[, i] = reward_dist[[i]](n = 10000)
}
colnames(dataset) <- 1:10
dataset_p = melt(dataset)[, 2:3]
colnames(dataset_p) <- c("Bandit", "Reward")
dataset_p$Bandit = as.factor(dataset_p$Bandit)
#plot the distributions of rewards from bandits
ggplot(dataset_p, aes(x = Reward, col = Bandit, fill = Bandit)) +
geom_density(alpha = 0.3) +
labs(title = "Reward from different bandits")
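The UCB algorithm plays, on every round, the bandit with the largest upper confidence bound on its mean reward. In the variant implemented below, the bound for bandit \(i\) at round \(n\) is
\[
\bar{x}_i + \sqrt{\frac{2\ln\!\left(1 + n\ln^2 n\right)}{n_i}},
\]
where \(\bar{x}_i\) is the average reward observed from bandit \(i\) and \(n_i\) is the number of times it has been selected so far. Bandits that have not been played yet receive an infinite bound, so each one is tried at least once.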
UCB <- function(N = 1000, reward_data){
d = ncol(reward_data)
bandit_selected = integer(0)
numbers_of_selections = integer(d)
sums_of_rewards = integer(d)
total_reward = 0
for (n in 1:N) {
max_upper_bound = 0
for (i in 1:d) {
if (numbers_of_selections[i] > 0){
average_reward = sums_of_rewards[i] / numbers_of_selections[i]
delta_i = sqrt(2 * log(1 + n * log(n)^2) / numbers_of_selections[i])
upper_bound = average_reward + delta_i
} else {
upper_bound = Inf #unplayed bandits get an infinite bound so each is tried at least once
}
if (upper_bound > max_upper_bound){
max_upper_bound = upper_bound
bandit = i
}
}
bandit_selected = append(bandit_selected, bandit)
numbers_of_selections[bandit] = numbers_of_selections[bandit] + 1
reward = reward_data[n, bandit]
sums_of_rewards[bandit] = sums_of_rewards[bandit] + reward
total_reward = total_reward + reward
}
return(list(total_reward = total_reward, bandit_selected = bandit_selected, numbers_of_selections = numbers_of_selections, sums_of_rewards = sums_of_rewards))
}
UCB(N = 1000, reward_data = dataset)
## $total_reward
## 1
## 25825.95
##
## $bandit_selected
## [1] 1 2 3 4 5 6 7 8 9 10 9 9 9 8 9 10 9 8 8 10 9 9 9
## [24] 9 9 9 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10
## [47] 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10
## [70] 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10
## [93] 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10
## [116] 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10
## [139] 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10
## [162] 10 10 10 10 10 10 10 10 10 10 10 10 8 10 10 10 10 10 10 10 10 10 10
## [185] 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10
## [208] 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10
## [231] 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10
## [254] 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10
## [277] 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10
## [300] 10 10 10 10 10 10 9 9 9 9 7 10 10 10 10 10 10 10 10 10 10 10 10
## [323] 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10
## [346] 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10
## [369] 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 9 9 10 10 10
## [392] 10 9 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10
## [415] 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10
## [438] 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10
## [461] 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10
## [484] 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10
## [507] 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10
## [530] 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10
## [553] 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10
## [576] 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10
## [599] 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10
## [622] 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10
## [645] 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10
## [668] 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10
## [691] 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10
## [714] 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10
## [737] 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10
## [760] 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10
## [783] 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10
## [806] 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10
## [829] 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10
## [852] 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10
## [875] 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10
## [898] 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10
## [921] 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10
## [944] 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10
## [967] 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10
## [990] 10 10 10 10 10 10 10 10 10 10 10
##
## $numbers_of_selections
## [1] 1 1 1 1 1 1 2 5 19 968
##
## $sums_of_rewards
## [1] 7.215002 9.194578 9.895383 15.301631 16.559387
## [6] 12.758873 41.153498 117.463396 474.474029 25121.937130
rnormgamma <- function(n, mu, lambda, alpha, beta){
#draw (tau, x) from a normal-gamma distribution:
#tau ~ Gamma(alpha, beta), x | tau ~ Normal with mean mu and variance 1 / (lambda * tau)
if(length(n) > 1)
n <- length(n)
tau <- rgamma(n, alpha, beta)
x <- rnorm(n, mu, sqrt(1 / (lambda * tau))) #rnorm expects a standard deviation, so take the square root of the variance
data.frame(tau = tau, x = x)
}
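The helper above draws from a normal-gamma distribution: \(\tau \sim \mathrm{Gamma}(\alpha, \beta)\) and \(x \mid \tau \sim \mathrm{N}\!\left(\mu, (\lambda\tau)^{-1}\right)\). The T.samp function below places a normal-gamma prior with parameters \((\mu_0, v, \alpha, \beta)\) on each bandit’s mean and precision and assumes normally distributed rewards. After observing \(n_i\) rewards \(x_{i1}, \dots, x_{in_i}\) with sample mean \(\bar{x}_i\) from bandit \(i\), the conjugate posterior is again normal-gamma with parameters
\[
\mu_0' = \frac{v\mu_0 + n_i\bar{x}_i}{v + n_i}, \qquad
v' = v + n_i, \qquad
\alpha' = \alpha + \frac{n_i}{2}, \qquad
\beta' = \beta + \frac{1}{2}\sum_{j=1}^{n_i}\left(x_{ij} - \bar{x}_i\right)^2 + \frac{n_i v}{v + n_i}\cdot\frac{\left(\bar{x}_i - \mu_0\right)^2}{2}.
\]
On each round the algorithm draws one plausible value of the mean from every bandit’s posterior (or from the prior, if the bandit has not been played yet) and plays the bandit with the largest draw.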
T.samp <- function(N = 500, reward_data, mu0 = 0, v = 1, alpha = 2, beta = 6){
d = ncol(reward_data)
bandit_selected = integer(0)
numbers_of_selections = integer(d)
sums_of_rewards = integer(d)
total_reward = 0
reward_history = vector("list", d)
for (n in 1:N){
max_random = -Inf
for (i in 1:d){
if(numbers_of_selections[i] >= 1){
#sample one plausible mean from the normal-gamma posterior given the rewards seen from bandit i
rand = rnormgamma(1,
(v * mu0 + numbers_of_selections[i] * mean(reward_history[[i]])) / (v + numbers_of_selections[i]),
v + numbers_of_selections[i],
alpha + numbers_of_selections[i] / 2,
beta + sum((reward_history[[i]] - mean(reward_history[[i]])) ^ 2) / 2 + ((numbers_of_selections[i] * v) / (v + numbers_of_selections[i])) * (mean(reward_history[[i]]) - mu0) ^ 2 / 2)$x
}else {
#no observations yet: sample from the prior
rand = rnormgamma(1, mu0, v, alpha, beta)$x
}
if(rand > max_random){
max_random = rand
bandit = i
}
}
bandit_selected = append(bandit_selected, bandit)
numbers_of_selections[bandit] = numbers_of_selections[bandit] + 1
reward = reward_data[n, bandit]
sums_of_rewards[bandit] = sums_of_rewards[bandit] + reward
total_reward = total_reward + reward
reward_history[[bandit]] = append(reward_history[[bandit]], reward)
}
return(list(total_reward = total_reward, bandit_selected = bandit_selected, numbers_of_selections = numbers_of_selections, sums_of_rewards = sums_of_rewards))
}
T.samp(N = 1000, reward_data = dataset, mu0 = 40)
## $total_reward
## 10
## 24401.63
##
## $bandit_selected
## [1] 10 6 7 4 9 1 6 1 9 4 9 9 7 1 1 1 2 10 10 2 7 6 6
## [24] 5 2 5 2 3 7 4 9 3 3 3 8 2 1 4 7 6 8 10 9 5 2 8
## [47] 6 5 8 3 5 7 3 10 9 4 10 8 8 8 10 7 10 2 9 7 1 8 10
## [70] 9 6 6 5 3 9 8 2 10 2 5 10 4 10 7 10 9 5 10 10 4 2 3
## [93] 9 10 10 10 8 10 10 9 1 10 10 5 10 8 4 10 10 10 9 10 10 3 10
## [116] 1 7 2 10 10 9 10 10 10 10 10 6 1 4 10 9 10 10 10 7 10 10 10
## [139] 10 10 9 3 5 10 10 3 9 10 10 9 10 5 10 8 10 10 10 9 10 9 10
## [162] 10 10 10 10 10 9 10 4 10 5 10 7 10 9 10 10 10 10 9 10 8 10 10
## [185] 10 10 10 8 9 10 9 9 2 10 10 1 10 10 10 8 4 10 10 10 10 10 10
## [208] 10 8 10 10 10 10 6 10 10 10 10 10 10 10 10 4 9 10 10 10 10 10 10
## [231] 10 9 10 10 10 10 10 10 8 6 10 10 10 10 10 10 10 7 10 9 10 10 10
## [254] 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10
## [277] 10 10 10 10 10 10 10 10 10 10 6 10 10 10 10 10 10 9 10 10 10 10 10
## [300] 10 10 10 10 10 10 10 10 10 10 10 10 8 10 10 10 10 7 10 10 10 10 10
## [323] 10 10 10 10 10 1 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10
## [346] 10 10 10 10 10 10 10 10 10 10 10 4 10 9 10 10 9 10 10 10 10 1 9
## [369] 9 10 10 10 10 10 10 10 10 10 10 10 9 2 10 9 10 10 10 8 10 10 10
## [392] 10 5 8 10 10 10 10 10 10 10 10 10 7 10 10 9 10 9 10 5 10 8 9
## [415] 10 9 10 10 10 9 10 10 10 10 10 10 10 10 10 10 6 10 10 10 10 10 10
## [438] 10 10 9 10 10 10 10 5 10 10 9 10 10 9 10 10 10 10 10 10 10 10 10
## [461] 3 10 10 10 10 10 10 10 9 10 10 5 10 10 9 10 10 10 10 10 10 10 10
## [484] 10 9 10 10 10 4 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 9 9
## [507] 10 10 10 10 9 10 10 10 10 9 9 10 10 10 10 10 10 10 10 10 10 10 10
## [530] 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10
## [553] 10 10 7 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10
## [576] 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10
## [599] 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10
## [622] 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10
## [645] 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 6 10 10 10 10 10 10 10
## [668] 10 10 10 10 10 10 10 2 10 10 10 10 10 10 10 10 3 10 10 10 10 10 10
## [691] 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10
## [714] 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10
## [737] 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10
## [760] 10 10 10 10 10 10 10 10 10 4 10 10 10 10 10 10 10 10 10 8 10 10 10
## [783] 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10
## [806] 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10
## [829] 10 10 10 10 10 10 10 10 10 3 10 10 10 10 10 10 10 10 10 10 10 10 10
## [852] 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10
## [875] 10 10 10 10 10 1 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10
## [898] 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10
## [921] 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10
## [944] 8 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10
## [967] 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 3 10 10 10 10 10 10
## [990] 10 10 10 10 10 10 10 10 10 10 10
##
## $numbers_of_selections
## [1] 14 14 15 15 16 14 16 23 53 820
##
## $sums_of_rewards
## [1] 65.92387 92.60770 132.74036 191.73349 234.35009
## [6] 239.16131 321.85440 538.71035 1347.56022 21236.98349
We can see that the UCB algorithm quickly found that the \(10^{th}\) bandit yields the most reward. On the other hand, Thompson sampling tried the worse bandits many more times before settling on the best one.
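Since both algorithms are stochastic, a single run can be misleading. Below is a minimal sketch for averaging the total reward over several independent runs; the compare_runs function and the number of repetitions are my own additions, assuming the UCB, T.samp and reward_dist objects defined above.
#compare average total reward over repeated independent runs (illustrative sketch)
compare_runs <- function(n_runs = 20, N = 1000){
  totals = matrix(NA, nrow = n_runs, ncol = 2, dimnames = list(NULL, c("UCB", "Thompson")))
  for(r in 1:n_runs){
    #draw a fresh reward matrix for each run
    sim = sapply(reward_dist, function(f) f(n = N))
    totals[r, "UCB"] = UCB(N = N, reward_data = sim)$total_reward
    totals[r, "Thompson"] = T.samp(N = N, reward_data = sim, mu0 = 40)$total_reward
  }
  colMeans(totals)
}
#compare_runs() #not run here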
This data represents an ideal but more unstable situation: normally distributed rewards with much larger variance, and thus not well separated from each other.
mean_reward = c(5, 7.5, 10, 12.5, 15, 17.5, 20, 22.5, 25, 26)
reward_dist = c(function(n) rnorm(n = n, mean = mean_reward[1], sd = 20),
function(n) rnorm(n = n, mean = mean_reward[2], sd = 20),
function(n) rnorm(n = n, mean = mean_reward[3], sd = 20),
function(n) rnorm(n = n, mean = mean_reward[4], sd = 20),
function(n) rnorm(n = n, mean = mean_reward[5], sd = 20),
function(n) rnorm(n = n, mean = mean_reward[6], sd = 20),
function(n) rnorm(n = n, mean = mean_reward[7], sd = 20),
function(n) rnorm(n = n, mean = mean_reward[8], sd = 20),
function(n) rnorm(n = n, mean = mean_reward[9], sd = 20),
function(n) rnorm(n = n, mean = mean_reward[10], sd = 20))
#prepare simulation data
dataset = matrix(nrow = 10000, ncol = 10)
for(i in 1:10){
dataset[, i] = reward_dist[[i]](n = 10000)
}
colnames(dataset) <- 1:10
dataset_p = melt(dataset)[, 2:3]
colnames(dataset_p) <- c("Bandit", "Reward")
dataset_p$Bandit = as.factor(dataset_p$Bandit)
#plot the distributions of rewards from bandits
ggplot(dataset_p, aes(x = Reward, col = Bandit, fill = Bandit)) +
geom_density(alpha = 0.3) +
labs(title = "Reward from different bandits")
UCB(N = 1000, reward_data = dataset)
## $total_reward
## 1
## 25381.76
##
## $bandit_selected
## [1] 1 2 3 4 5 6 7 8 9 10 6 6 6 10 6 6 6 6 6 6 6 4 4
## [24] 4 4 4 4 10 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9
## [47] 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9
## [70] 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9
## [93] 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9
## [116] 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9
## [139] 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9
## [162] 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9
## [185] 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9
## [208] 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9
## [231] 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9
## [254] 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9
## [277] 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9
## [300] 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9
## [323] 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9
## [346] 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9
## [369] 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9
## [392] 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9
## [415] 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9
## [438] 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9
## [461] 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9
## [484] 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9
## [507] 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9
## [530] 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9
## [553] 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9
## [576] 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9
## [599] 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9
## [622] 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9
## [645] 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9
## [668] 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9
## [691] 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9
## [714] 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9
## [737] 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9
## [760] 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9
## [783] 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9
## [806] 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9
## [829] 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9
## [852] 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9
## [875] 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9
## [898] 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9
## [921] 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9
## [944] 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9
## [967] 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9
## [990] 9 9 9 9 9 9 9 9 9 9 9
##
## $numbers_of_selections
## [1] 1 1 1 7 1 11 1 1 973 3
##
## $sums_of_rewards
## [1] 11.724481 -4.163511 11.378243 142.563325 -1.399896
## [6] 236.377343 12.359089 -33.607918 24946.752619 59.775404
T.samp(N = 1000, reward_data = dataset, mu0 = 40)
## $total_reward
## 7
## 25415.94
##
## $bandit_selected
## [1] 7 10 7 5 3 4 5 9 8 3 3 9 6 5 8 3 1 4 4 5 1 6 6
## [24] 6 1 6 4 2 2 9 4 2 1 2 9 5 9 8 1 8 3 4 10 6 2 9
## [47] 10 7 3 8 1 8 10 7 10 3 2 10 10 10 3 10 10 10 1 7 10 10 10
## [70] 10 5 7 4 10 4 5 10 9 9 7 9 10 1 10 6 2 7 10 10 10 10 10
## [93] 10 10 6 10 10 10 10 10 9 1 10 10 10 10 6 8 8 10 10 10 10 10 6
## [116] 4 7 3 10 5 10 10 10 9 3 10 9 8 2 9 10 8 10 10 8 10 10 10
## [139] 10 10 10 10 10 10 6 10 4 10 10 10 10 10 10 5 10 10 10 6 10 10 10
## [162] 10 10 10 7 10 10 10 2 10 10 10 10 10 10 10 10 10 10 10 7 10 10 9
## [185] 10 10 10 10 10 1 10 3 10 4 10 10 10 10 10 4 10 10 10 10 10 10 10
## [208] 10 10 10 1 10 2 10 10 10 10 10 10 10 10 10 9 10 10 10 10 10 10 10
## [231] 10 10 10 10 10 5 10 10 10 2 10 10 10 7 10 10 10 10 10 10 10 10 10
## [254] 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 5 10 10 10 10 10 8 10
## [277] 10 10 10 10 10 10 10 10 10 10 10 10 10 10 7 10 10 10 10 10 10 10 2
## [300] 10 10 6 10 10 10 10 10 10 10 9 10 10 10 10 10 10 10 10 10 10 10 10
## [323] 10 10 10 10 10 10 10 10 10 10 10 7 10 10 10 10 10 10 10 10 8 10 10
## [346] 10 10 10 10 5 10 10 10 10 1 10 10 10 10 10 10 10 10 10 10 10 10 10
## [369] 10 10 10 10 7 10 2 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10
## [392] 10 10 8 10 10 10 10 10 3 10 10 10 10 10 10 10 10 10 10 10 10 10 10
## [415] 10 10 10 10 10 10 10 10 10 10 10 8 10 10 10 10 10 10 10 10 10 10 10
## [438] 10 10 10 10 10 4 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10
## [461] 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10
## [484] 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10
## [507] 10 10 10 10 10 10 10 10 10 10 10 10 2 10 10 10 10 10 10 10 10 10 10
## [530] 10 10 10 10 10 5 10 10 10 10 10 10 10 10 10 10 1 10 10 10 10 10 10
## [553] 10 10 10 10 10 10 10 10 10 10 8 10 10 10 10 10 10 10 10 10 10 10 10
## [576] 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10
## [599] 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10
## [622] 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10
## [645] 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10
## [668] 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10
## [691] 10 10 10 10 10 10 10 10 10 3 10 10 10 10 10 10 10 10 10 10 10 10 10
## [714] 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10
## [737] 10 10 10 10 6 10 10 10 10 10 10 10 10 10 5 10 10 10 10 10 10 10 10
## [760] 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10
## [783] 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10
## [806] 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10
## [829] 10 10 5 10 10 10 10 10 10 10 10 10 10 10 10 10 4 10 10 10 10 10 10
## [852] 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10
## [875] 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10
## [898] 10 10 10 10 10 10 10 6 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10
## [921] 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10
## [944] 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10
## [967] 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10
## [990] 10 10 10 10 10 10 9 10 10 10 10
##
## $numbers_of_selections
## [1] 13 14 13 14 15 15 15 16 17 868
##
## $sums_of_rewards
## [1] 207.17618 89.21395 15.40864 152.29464 230.05097
## [6] 233.19867 238.58975 337.62746 404.47312 23507.90494
When the fluctuation of the rewards is greater, the UCB algorithm is more susceptible to getting “stuck” at a suboptimal choice and never finding the optimal bandit. Thompson sampling is generally more robust and tends to find the optimal bandit even in these noisier situations.
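One way to quantify getting “stuck” is the regret: the reward we would expect from always playing the best bandit minus the reward actually collected. A rough sketch is below; the regret function is my own addition, and best_mean uses the known simulation means, which would not be available in a real problem.
#expected regret relative to always playing the best bandit (illustrative sketch)
regret <- function(result, best_mean = max(mean_reward)){
  best_mean * length(result$bandit_selected) - result$total_reward
}
#e.g. regret(UCB(N = 1000, reward_data = dataset))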
This data represents a more chaotic (and possibly more realistic) situation: rewards with different distributions and different variances.
mean_reward = c(5, 7.5, 10, 12.5, 15, 17.5, 20, 22.5, 25, 26)
reward_dist = c(function(n) rnorm(n = n, mean = mean_reward[1], sd = 20),
function(n) rgamma(n = n, shape = mean_reward[2] / 2, rate = 0.5),
function(n) rpois(n = n, lambda = mean_reward[3]),
function(n) runif(n = n, min = mean_reward[4] - 20, max = mean_reward[4] + 20),
function(n) rlnorm(n = n, meanlog = log(mean_reward[5]) - 0.25, sdlog = 0.5),
function(n) rnorm(n = n, mean = mean_reward[6], sd = 20),
function(n) rexp(n = n, rate = 1 / mean_reward[7]),
function(n) rbinom(n = n, size = mean_reward[8] / 0.5, prob = 0.5),
function(n) rnorm(n = n, mean = mean_reward[9], sd = 20),
function(n) rnorm(n = n, mean = mean_reward[10], sd = 20))
#prepare simulation data
dataset = matrix(nrow = 10000, ncol = 10)
for(i in 1:10){
dataset[, i] = reward_dist[[i]](n = 10000)
}
colnames(dataset) <- 1:10
dataset_p = melt(dataset)[, 2:3]
colnames(dataset_p) <- c("Bandit", "Reward")
dataset_p$Bandit = as.factor(dataset_p$Bandit)
#plot the distributions of rewards from bandits
ggplot(dataset_p, aes(x = Reward, col = Bandit, fill = Bandit)) +
geom_density(alpha = 0.3) +
labs(title = "Reward from different bandits")
UCB(N = 1000, reward_data = dataset)
## $total_reward
## 1
## 22182.78
##
## $bandit_selected
## [1] 1 2 3 4 5 6 7 8 9 10 9 1 1 7 7 9 9 9 9 9 1 8 8
## [24] 7 7 7 7 7 7 7 7 7 7 7 7 7 7 7 7 7 7 7 7 7 7 7
## [47] 7 7 7 7 7 7 7 7 7 7 7 7 7 7 7 7 7 7 7 7 7 7 7
## [70] 7 7 7 7 7 7 7 7 7 7 7 7 7 7 7 7 7 7 7 7 7 7 7
## [93] 7 7 7 7 7 7 7 7 7 7 7 7 8 7 8 8 8 8 8 8 8 8 8
## [116] 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8
## [139] 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8
## [162] 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8
## [185] 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8
## [208] 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8
## [231] 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8
## [254] 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8
## [277] 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8
## [300] 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8
## [323] 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8
## [346] 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8
## [369] 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8
## [392] 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8
## [415] 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8
## [438] 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8
## [461] 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8
## [484] 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8
## [507] 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8
## [530] 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8
## [553] 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8
## [576] 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8
## [599] 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8
## [622] 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8
## [645] 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8
## [668] 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8
## [691] 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8
## [714] 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8
## [737] 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8
## [760] 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8
## [783] 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8
## [806] 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8
## [829] 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8
## [852] 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8
## [875] 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8
## [898] 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8
## [921] 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8
## [944] 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8
## [967] 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8
## [990] 8 8 8 8 8 8 8 8 8 8 8
##
## $numbers_of_selections
## [1] 4 1 1 1 1 1 85 898 7 1
##
## $sums_of_rewards
## [1] 75.1713518 7.9493833 7.0000000 0.4473063 7.2531766
## [6] -27.4872651 1838.1946655 20127.0000000 133.1729419 14.0827289
T.samp(N = 1000, reward_data = dataset, mu0 = 40)
## $total_reward
## 3
## 23016.19
##
## $bandit_selected
## [1] 3 9 5 10 8 5 9 6 3 2 6 1 2 8 5 3 5 3 4 10 2 8 1
## [24] 5 5 2 4 2 9 6 8 4 1 7 7 5 7 7 3 1 4 6 7 3 6 8
## [47] 7 1 5 4 1 1 9 6 7 9 2 10 10 7 10 3 10 6 8 1 4 7 8
## [70] 3 4 8 3 9 9 1 9 9 10 5 10 6 8 9 1 6 2 8 9 5 7 8
## [93] 10 10 8 4 8 7 8 9 10 8 7 9 9 7 10 2 4 10 2 8 8 3 8
## [116] 10 9 10 7 2 8 8 10 8 10 8 8 8 8 8 8 8 3 8 8 8 4 9
## [139] 8 8 4 1 8 8 8 8 6 2 8 6 8 8 8 10 6 4 8 1 9 7 9
## [162] 8 8 10 8 3 8 8 8 8 8 5 4 5 8 8 8 8 8 8 8 8 8 8
## [185] 8 8 8 8 8 8 8 8 8 10 8 10 8 8 8 8 8 8 7 2 7 7 7
## [208] 9 8 10 8 8 1 10 8 10 8 8 10 8 8 8 8 8 8 2 8 8 8 8
## [231] 8 8 8 8 9 8 10 8 8 10 8 8 8 8 8 8 8 8 8 8 8 8 8
## [254] 8 8 8 7 8 8 8 8 8 8 8 8 8 8 8 8 2 7 8 8 8 8 8
## [277] 4 8 8 8 8 8 8 1 8 8 8 8 8 8 3 8 8 8 8 8 8 8 5
## [300] 6 8 8 8 6 8 8 8 8 8 1 8 8 8 8 8 8 8 8 8 8 8 8
## [323] 8 8 8 8 9 5 8 8 8 8 8 8 9 8 8 8 8 8 8 8 8 8 8
## [346] 8 8 8 8 3 8 8 8 8 8 8 8 8 8 3 10 10 10 10 8 8 8 8
## [369] 8 8 8 8 8 10 8 7 10 10 8 10 10 8 8 8 10 8 7 2 7 7 7
## [392] 7 7 7 7 7 7 7 7 7 7 7 7 7 7 7 7 7 7 4 7 7 7 7
## [415] 7 7 7 7 7 7 8 7 8 10 10 5 8 10 10 10 10 10 10 10 10 10 10
## [438] 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10
## [461] 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10
## [484] 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10
## [507] 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 3 10 10 10 10
## [530] 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10
## [553] 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10
## [576] 10 10 6 10 10 10 10 4 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10
## [599] 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 3 10 10
## [622] 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10
## [645] 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10
## [668] 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10
## [691] 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10
## [714] 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10
## [737] 10 10 10 10 6 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10
## [760] 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10
## [783] 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10
## [806] 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10
## [829] 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10
## [852] 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10
## [875] 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10
## [898] 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10
## [921] 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10
## [944] 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10
## [967] 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 3
## [990] 10 10 10 10 10 10 10 10 10 10 2
##
## $numbers_of_selections
## [1] 15 16 18 16 15 16 55 222 22 605
##
## $sums_of_rewards
## [1] 36.13793 124.81952 178.00000 235.31664 193.73812
## [6] 209.71535 1202.61210 4996.00000 445.10455 15394.74151
The performance of the two algorithms is similar to what we observed in the previous condition.
A major reason why the Thompson sampling algorithm tries all bandits several times before settling on the one it considers best is that I chose a prior distribution with a relatively high mean. With a prior with a larger mean, the algorithm favors “exploration” over “exploitation” at the beginning, and only when it is very confident that it has found the best choice does it value “exploitation” over “exploration”. If we decrease the mean of the prior, “exploitation” is valued more highly and the algorithm stops exploring sooner. By changing the prior distribution, one can adjust the relative importance of “exploration” versus “exploitation” to suit the specific problem at hand. This is another testament to how flexible the Thompson sampling algorithm is.
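As a quick illustration of this tuning knob, one could rerun the sampler with different prior means and count how many rounds are spent away from the best bandit. A sketch using the T.samp function and dataset above; the mu0 values are arbitrary.
#compare the amount of exploration under different prior means (illustrative sketch)
for(m in c(0, 20, 40)){
  res = T.samp(N = 1000, reward_data = dataset, mu0 = m)
  cat("mu0 =", m, ": rounds not spent on bandit 10 =", 1000 - res$numbers_of_selections[10], "\n")
}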
From the demonstrations above, we have seen how Thompson sampling can be more robust than methods designed around optimal, standard situations such as UCB. However, with the approach we have implemented, one must use conjugate priors of the likelihood in order to update the posterior distributions in a straightforward way. What if appropriate conjugate priors do not exist, or one wishes to construct hierarchical models that include other variables that might affect the rewards gained? In such cases, the posterior distribution of the reward cannot be estimated in the usual way. Luckily, researchers have developed more flexible methods for estimating the posterior. In this appendix, I will show one of the most powerful and commonly used techniques, Markov chain Monte Carlo (MCMC).
Markov chain Monte Carlo algorithms attempt to “sample” from the posterior distribution by constructing a Markov chain. With sufficient samples and other control measures (e.g. thinning and burn-in), the samples drawn from the Markov chain approximate the posterior. These algorithms can sample from the joint distributions of all kinds of models with great flexibility and accuracy.
As good as it sounds, MCMC has a major drawback: it is slow relative to other methods. That said, the flexibility and robustness MCMC algorithms offer are still rarely matched by other algorithms. My personal advice would be to use other methods when the problem at hand is relatively standard, and to use MCMC when one wishes to build sophisticated models that include other explanatory variables.
We use the same model as introduced before: a normal-gamma prior with a normal likelihood.
MCMC can be conducted in many different ways. I will give example code using OpenBUGS, a software package designed to perform Gibbs sampling (one type of MCMC). The R package R2OpenBUGS allows us to call OpenBUGS from R.
#load package R2OpenBUGS
library(R2OpenBUGS)
#self-defined function that returns samples from the posterior
MCMC <- function(reward, n.iter = 1000, n.chains = 1, n.burnin = 100, n.thin = 10, ...){
n = length(reward)
est_mu = mean(reward)
if(n == 1){
#with only one observation, skip MCMC and return a wide approximate posterior around it
return(rnorm(n.iter - n.burnin, est_mu, 100))
}
my.data <- list("reward", "n", "est_mu")
model = function(){
for(i in 1:n){
reward[i] ~ dnorm(mu, tau)
}
mu ~ dnorm(est_mu, 0.001)
tau ~ dgamma(0.001, 0.001)
}
my.model.file <- "model_bandit.odc"
write.model(model, con = my.model.file)
params <- c("mu")
inits <- function(){
list(mu = 5, tau = 1)
}
out <- bugs(data = my.data, inits = inits, parameters.to.save = params, model.file = my.model.file, codaPkg = T, n.iter = n.iter, n.chains = n.chains, n.burnin = n.burnin, n.thin = n.thin, save.history = F, ...)
bugs_out <- read.bugs(out, quiet = T)
return(bugs_out[[1]][, 2])
}
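Before plugging the sampler into the bandit loop, it can be sanity-checked on its own. A small sketch follows (not run; it requires a working OpenBUGS installation, and the reward vector here is made up).
#standalone check of the MCMC helper (not run)
fake_rewards <- rnorm(30, mean = 20, sd = 5)
post_mu <- MCMC(reward = fake_rewards, n.iter = 2000, n.burnin = 200)
hist(as.vector(post_mu), main = "Posterior samples of the mean reward")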
The function below is the same as the T.samp function above, except that it estimates the posterior distribution of the mean reward using MCMC.
T.samp.mcmc <- function(N = 500, reward_data, mcmc.iter = 1000, mcmc.burnin = 100, show.iter = FALSE, ...){
d = ncol(reward_data)
bandit_selected = integer(0)
numbers_of_selections = integer(d)
sums_of_rewards = integer(d)
total_reward = 0
posterior_dist = matrix(9999, nrow = mcmc.iter - mcmc.burnin, ncol = d) #large initial values ensure every bandit is sampled early on
reward_history = vector("list", d)
for (n in 1:N){
max_random = -Inf
for (i in 1:d){
rand = sample(posterior_dist[, i], 1)
if(rand > max_random){
max_random = rand
bandit = i
}
}
bandit_selected = append(bandit_selected, bandit)
numbers_of_selections[bandit] = numbers_of_selections[bandit] + 1
reward = reward_data[n, bandit]
sums_of_rewards[bandit] = sums_of_rewards[bandit] + reward
total_reward = total_reward + reward
reward_history[[bandit]] = append(reward_history[[bandit]], reward)
posterior_dist[, bandit] = as.vector(MCMC(reward = reward_history[[bandit]], n.iter = mcmc.iter, n.burnin = mcmc.burnin, ...))
if(show.iter == TRUE)
cat("iteration:", n, "\n") #print progress
}
return(list(total_reward = total_reward, bandit_selected = bandit_selected, numbers_of_selections = numbers_of_selections, sums_of_rewards = sums_of_rewards))
}
T.samp.mcmc(N = 200, reward_data = dataset)
P.S. The code using MCMC is not run here due to its slow speed; you are welcome to try it yourself.