```
library(magrittr)
library(assertthat)
library(ggplot2)
library(tidyr)
The constant \(k\) is taken to be one.
position <- function(U1) {
  ## Map a uniform draw in [0, 1] to one of the four node positions.
  ## Bins are (0, 0.25], (0.25, 0.5], (0.5, 0.75], (0.75, 1], with 0 itself
  ## placed in the first bin. Values outside [0, 1] give NA.
  ## Returns an integer vector of bin codes (1..4).
  quarter_breaks <- seq(from = 0, to = 1, by = 0.25)
  bin_index <- cut(x = U1,
                   breaks = quarter_breaks,
                   include.lowest = TRUE,
                   labels = FALSE)
  bin_index
}
prob_plus <- function(k, neighbors) {
  ## Probability that the updated node becomes +1, given the values of its
  ## neighbours and a scalar constant k:
  ##   exp(k * s) / (exp(k * s) + exp(-k * s)) = 1 / (1 + exp(-2 * k * s)),
  ## where s = sum(neighbors).
  ## plogis() evaluates the right-hand side directly. The product-of-
  ## exponentials form overflowed to Inf / Inf = NaN once k * s was large
  ## (e.g. k = 1000 with two +1 neighbours).
  ## Returns a probability in [0, 1]; an empty `neighbors` gives 0.5.
  plogis(2 * k * sum(neighbors))
}
update_v <- function(v, k, U1, U2) {
  ## Perform one update of the 4-node state `v` (values +1/-1) arranged on a
  ## ring: U1 selects which node to update (via position()), and U2 decides
  ## its new value by comparison with prob_plus() of that node's two ring
  ## neighbours.
  ## Args:
  ##   v  : numeric vector of length 4 (current state).
  ##   k  : scalar constant passed to prob_plus().
  ##   U1 : scalar in [0, 1] choosing the node.
  ##   U2 : scalar in [0, 1] choosing the new value.
  ## Returns `v` with one element replaced by +1 or -1.
  ## Check lengths first, so the range checks below always see scalars and
  ## vector inputs get a clear error message.
  assert_that(length(v) == 4)
  assert_that(length(k) == 1)
  assert_that(length(U1) == 1)
  assert_that(length(U2) == 1)
  assert_that(U1 >= 0, U1 <= 1)
  assert_that(U2 >= 0, U2 <= 1)
  n_nodes <- length(v)
  ## Choose target position (1..4)
  pos <- position(U1)
  ## Ring neighbours: previous and next node, wrapping 1 <-> 4
  left <- (pos - 2) %% n_nodes + 1
  right <- pos %% n_nodes + 1
  neighbors <- c(v[left], v[right])
  ## Update target position based on neighbors. A plain if/else is used
  ## because this is a scalar decision; fail loudly rather than storing NA
  ## if the probability is not a number.
  p_up <- prob_plus(k = k, neighbors = neighbors)
  assert_that(!is.na(p_up),
              msg = "prob_plus() returned NA/NaN; check k and neighbors")
  v[pos] <- if (U2 < p_up) +1 else -1
  v
}
## Extreme initial states: all four nodes +1 (v_max) or all four -1 (v_min)
v_max <- rep(+1, times = 4)
v_min <- rep(-1, times = 4)
## Single update of v_max with k = 1 and U1 = U2 = 0.5
update_v(v_max, k = 1, 0.5, 0.5)
## [1] 1 1 1 1
## Uniform draws reused by every chain run below:
## U1 picks the node to update, U2 decides its new value
n <- 100
U1_seq <- runif(n = n)
U2_seq <- runif(n = n)
run_chain <- function(init, k, U1_seq, U2_seq) {
  ## Run the chain from `init`, driven by the uniform streams U1_seq and
  ## U2_seq, by repeatedly applying update_v().
  ## Returns a length(U1_seq) x length(init) numeric matrix:
  ##   - row 1 is `init` itself;
  ##   - row i (i > 1) is row i - 1 updated with U1_seq[i] and U2_seq[i].
  ## Consequently U1_seq[1] and U2_seq[1] are never used; this keeps the row
  ## count equal to length(U1_seq), which callers rely on for alignment.
  assert_that(length(init) == 4)
  assert_that(length(U1_seq) == length(U2_seq))
  ## Without this, out[1, ] below fails with an opaque
  ## "subscript out of bounds" error.
  assert_that(length(U1_seq) >= 1,
              msg = "U1_seq and U2_seq must contain at least one value")
  n_steps <- length(U1_seq)
  ## Preallocate the output; states are numeric (+1/-1)
  out <- matrix(NA_real_, nrow = n_steps, ncol = length(init))
  out[1, ] <- init
  for (i in seq_len(n_steps)[-1]) {
    out[i, ] <- update_v(out[i - 1, ], k, U1_seq[i], U2_seq[i])
  }
  out
}
With a smaller \(k\), convergence appears to be faster. There are \(2^4 = 16\) possible states, but each state is summarized by the sum of its 4 nodes for a cleaner visualization.
## Chains started from both extreme states for several k. Every chain reuses
## the same U1_seq / U2_seq draws; each state is summarised by the sum of its
## 4 nodes.
data1 <- data.frame(i = seq_along(U1_seq),
                    max_chain_k0.01 = rowSums(run_chain(v_max, 0.01, U1_seq, U2_seq)),
                    min_chain_k0.01 = rowSums(run_chain(v_min, 0.01, U1_seq, U2_seq)),
                    max_chain_k0.5 = rowSums(run_chain(v_max, 0.5, U1_seq, U2_seq)),
                    min_chain_k0.5 = rowSums(run_chain(v_min, 0.5, U1_seq, U2_seq)),
                    max_chain_k1 = rowSums(run_chain(v_max, 1, U1_seq, U2_seq)),
                    min_chain_k1 = rowSums(run_chain(v_min, 1, U1_seq, U2_seq)),
                    max_chain_k2 = rowSums(run_chain(v_max, 2, U1_seq, U2_seq)),
                    min_chain_k2 = rowSums(run_chain(v_min, 2, U1_seq, U2_seq)))
## Long format (columns i, chain, value); pivot_longer() replaces the
## superseded gather()
data1_long <- pivot_longer(data1, cols = -i,
                           names_to = "chain", values_to = "value")
## Recover k from column names such as "max_chain_k0.5"
data1_long$k <- as.numeric(gsub(".*_chain_k", "", data1_long$chain))
## NOTE(review): ggplot2 >= 3.4 prefers `linewidth` over `size` for lines;
## kept as `size` because the installed ggplot2 version is unknown here.
ggplot(data = data1_long,
       mapping = aes(x = i, y = value, group = chain, color = chain)) +
  geom_line(alpha = 1/3, size = 2) +
  facet_grid(k ~ .) +
  theme_bw() + theme(legend.key = element_blank())
\(k = 0.01\) is used so that convergence is fast enough to visualize. No matter when the chains start, they converge to the same distribution.
## Try different starting index
get_sum <- function(init, k, start_i, ran_data) {
  ## Run a chain from `init` that starts at row `start_i` of `ran_data`
  ## (which must have columns U1_seq and U2_seq), and return the sum of the
  ## nodes at each step. The result is left-padded with NA so it has
  ## nrow(ran_data) entries aligned with the rows of `ran_data`.
  n <- nrow(ran_data)
  ## start_i == n is valid: the chain is then just `init` (a single row).
  ## The original `start_i < nrow(ran_data)` check wrongly rejected it.
  assert_that(start_i >= 1, start_i <= n)
  rows <- seq.int(start_i, n)
  ## [[ ]] yields a plain vector for both data.frames and tibbles
  chain <- run_chain(init,
                     k,
                     ran_data[["U1_seq"]][rows],
                     ran_data[["U2_seq"]][rows])
  c(rep(NA_real_, start_i - 1), rowSums(chain))
}
n <- 100
## A fresh set of uniform draws shared by all chains below
ran_data <- data.frame(i = seq_len(n),
                       U1_seq = runif(n = n),
                       U2_seq = runif(n = n))
## Chains from both extreme states (k = 0.01), started at different indices;
## entries before each chain's start are NA
data2 <- data.frame(i = seq_len(nrow(ran_data)),
                    max_chain_i1 = get_sum(v_max, k = 0.01, start_i = 1, ran_data = ran_data),
                    min_chain_i1 = get_sum(v_min, k = 0.01, start_i = 1, ran_data = ran_data),
                    max_chain_i20 = get_sum(v_max, k = 0.01, start_i = 20, ran_data = ran_data),
                    min_chain_i20 = get_sum(v_min, k = 0.01, start_i = 20, ran_data = ran_data),
                    max_chain_i30 = get_sum(v_max, k = 0.01, start_i = 30, ran_data = ran_data),
                    min_chain_i30 = get_sum(v_min, k = 0.01, start_i = 30, ran_data = ran_data),
                    max_chain_i40 = get_sum(v_max, k = 0.01, start_i = 40, ran_data = ran_data),
                    min_chain_i40 = get_sum(v_min, k = 0.01, start_i = 40, ran_data = ran_data))
## Long format (columns i, chain, value); pivot_longer() replaces the
## superseded gather()
data2_long <- pivot_longer(data2, cols = -i,
                           names_to = "chain", values_to = "value")
## Recover the starting index from column names such as "max_chain_i20"
data2_long$start_i <- as.numeric(gsub(".*_chain_i", "", data2_long$chain))
## Chains started later are NA-padded before their start (see get_sum);
## na.rm = TRUE drops those rows silently instead of emitting a
## "Removed ... rows containing missing values" warning.
ggplot(data = data2_long,
       mapping = aes(x = i, y = value, group = chain, color = chain)) +
  geom_line(alpha = 1/3, size = 2, na.rm = TRUE) +
  theme_bw() + theme(legend.key = element_blank())
## Same chains, one facet per starting index
ggplot(data = data2_long,
       mapping = aes(x = i, y = value, group = chain, color = chain)) +
  geom_line(alpha = 1/3, size = 2, na.rm = TRUE) +
  facet_grid(start_i ~ .) +
  theme_bw() + theme(legend.key = element_blank())