In the effort to improve the computational efficiency of the BAST model, this vignette summarizes various implementations and their speed.
First, we run a speed comparison on the remotely sensed chlorophyll data from the Aral Sea; the setup of this application is described in Section 4.3 of the following paper:
The code for the paper’s implementation can be found in this GitHub repository. The majority of the code for this example is taken from this repository.
To begin, we’ll read in the processed data (information on this can be found in the AralSea.R file in the GitHub repository above):
# Download the preprocessed Aral Sea data and load its objects into the current
# environment.
# NOTE(review): Y, coords and bnd used below are not defined anywhere else in
# this vignette, so they presumably come from this .RData file -- confirm.
aralDataUrl = "https://github.com/ztluostat/BAST/raw/main/aral_data.RData"
load(url(aralDataUrl))
n = length(Y)

Here is a plot of what the data looks like:
ggplot() +
geom_boundary(bnd) +
geom_point(aes(x = lon, y = lat, col = Y), data = as.data.frame(coords)) +
scale_color_gradientn(colours = rainbow(5), name = 'Y') +
labs(x = 'Scaled Lon.', y = 'Scaled Lat.') +
ggtitle('Observed Data')

The BASTION package does not include a function for constructing a constrained Delaunay triangulation graph, so the graph will be constructed using the constrainedDentri function from the paper repository and used for all implementations.
# Build the constrained Delaunay triangulation graph over the n observations.
aral_graph = constrainedDentri(n, gen2dMesh(coords, bnd)) # Get full graph
# seq_len() yields an empty id vector for an empty graph, whereas 1:0 would
# produce c(1, 0) and fail the attribute assignment.
E(aral_graph)$eid = seq_len(ecount(aral_graph)) # Assign edge ids
V(aral_graph)$vid = seq_len(vcount(aral_graph)) # Assign vertex ids
# plot spatial graph
plotGraph(coords, aral_graph) +
geom_boundary(bnd) +
labs(x = 'Scaled Lon.', y = 'Scaled Lat.') +
ggtitle('Spatial Graph')

The following is the setup for the implementation from the original paper. In the interest of timing, the iteration numbers have been cut down; these changes are flagged in the comments below.
# Starting tree for every weak learner: a spanning tree of the full spatial
# graph. The 'weight' edge attribute is removed from both the graph and the
# tree before they are handed to the samplers.
mstgraph = mst(aral_graph) # initial spanning tree
graph0 = delete_edge_attr(aral_graph, 'weight')
mstgraph0 = delete_edge_attr(mstgraph, 'weight')
M = 30 # number of weak learners
k_max = 5 # maximum number of clusters per weak learner
mu = list() # initial values of mu
mstgraph_lst = list() # initial spanning trees
cluster = matrix(1, nrow = n, ncol = M) # initial cluster memberships
# Every weak learner starts identically: one cluster (all memberships are 1),
# a single mean of 0, and the same spanning tree.
for(m in 1:M) {
mu[[m]] = c(0)
mstgraph_lst[[m]] = mstgraph0
}
# Bundle the starting state under the element names the fitting functions read
init_val = list()
init_val[['trees']] = mstgraph_lst
init_val[['mu']] = mu
init_val[['cluster']] = cluster
init_val[['sigmasq_y']] = 1
# standardize Y
std_res = standardize(Y)
Y_std = std_res$x
std_par = std_res$std_par
# find lambda_s
# lambda_s is the (1 - q) chi-square quantile with nu degrees of freedom,
# scaled by var(Y_std) / nu
nu = 3; q = 0.9
quant = qchisq(1-q, nu)
lambda_s = quant * var(Y_std) / nu
# Named vector of hyper-parameters; the samplers look entries up by name
hyperpar = c()
hyperpar['sigmasq_mu'] = (0.5/(2*sqrt(M)))^2
hyperpar['lambda_s'] = lambda_s
hyperpar['nu'] = nu
hyperpar['lambda_k'] = 4
hyperpar['M'] = M
hyperpar['k_max'] = k_max
# MCMC parameters
# number of posterior samples = (MCMC - BURNIN) / THIN
# (= 1000 with the MCMC and BURNIN values below and THIN = 5)
MCMC = 10000 # MCMC iterations !! Down from 30000 originally !!
BURNIN = 5000 # burnin period length !! Down from 15000 originally !!
THIN = 5 # thinning intervals

Here is an implementation written using the BASTION package (based on the one from the paper):
# Fit the BAST model by MCMC using the BASTION graph-move helpers
# (graphBirth, graphDeath, graphChange, graphHyper).
#
# Arguments:
#   Y          Numeric response, combined elementwise with the n rows of the
#              fitted-value matrix, where n = vcount(graph).
#   graph      igraph graph; passed to graphDeath(), graphChange() and
#              graphHyper().
#   init_vals  List with elements 'trees' (one graph per weak learner), 'mu'
#              (list of per-learner cluster means), 'cluster' (n x M matrix of
#              cluster memberships) and 'sigmasq_y'.
#   hyperpars  Named numeric vector with entries 'M', 'sigmasq_mu',
#              'lambda_s', 'nu', 'lambda_k' and 'k_max'.
#   MCMC       Total number of iterations.
#   BURNIN     Number of leading iterations that are never stored.
#   THIN       After burn-in, every THIN-th iteration is stored.
#   seed       Optional RNG seed; when non-NULL, set.seed(seed) is called first.
#
# Returns a list with elements 'mu_out' (list of stored mu lists),
# 'sigmasq_y_out', 'cluster_out' (integer array, samples x n x M) and
# 'log_post_out'. The sampled trees are not stored.
BASTIONfit = function(Y, graph, init_vals, hyperpars, MCMC, BURNIN, THIN, seed = NULL) {
  if (!is.null(seed)) {
    set.seed(seed)
  }

  n = vcount(graph)

  # hyper-parameters (names are kept: `hyper` is forwarded to evalLogPost())
  M = hyperpars['M'] # number of trees
  sigmasq_mu = hyperpars['sigmasq_mu']
  lambda_s = hyperpars['lambda_s']
  nu = hyperpars['nu']
  lambda_k = hyperpars['lambda_k']
  k_max = hyperpars['k_max']
  hyper = c(sigmasq_mu, lambda_s, nu, lambda_k)

  # initial values
  mstgraph_lst = init_vals[['trees']]
  mu = init_vals[['mu']]
  cluster = init_vals[['cluster']] # n*M matrix
  sigmasq_y = init_vals[['sigmasq_y']]

  k = as.numeric(apply(cluster, 2, max)) # number of clusters per learner
  csize = list() # cluster sizes per learner
  g = matrix(0, nrow = n, ncol = M) # n*M matrix of fitted mu's
  for (m in seq_len(M)) {
    cluster_m = cluster[, m]
    g[, m] = mu[[m]][cluster_m]
    csize[[m]] = Rfast::Table(cluster_m)
  }

  # Log-likelihood ratio for a birth-type proposal described by the two vertex
  # id sets returned by the move helper. The arithmetic order is kept exactly
  # as in the inline version so results are bit-for-bit unchanged.
  split_log_lik = function(e, ids_old, ids_new, sigmasq_y) {
    sigma_ratio = sigmasq_y / sigmasq_mu
    size_old = length(ids_old)
    size_new = length(ids_new)
    sum_old = sum(e[ids_old])
    sum_new = sum(e[ids_new])
    logdetdiff = -0.5 * (log(size_old + sigma_ratio) +
                           log(size_new + sigma_ratio) -
                           log(size_old + size_new + sigma_ratio) -
                           log(sigma_ratio))
    quaddiff = 0.5 * (sum_old^2 / (size_old + sigma_ratio) +
                        sum_new^2 / (size_new + sigma_ratio) -
                        (sum_old + sum_new)^2 / (size_old + size_new + sigma_ratio)
    ) / sigmasq_y
    logdetdiff + quaddiff
  }

  # Log-likelihood ratio for a death-type proposal; the formula works with
  # length(ids_new) - length(ids_old) as the size of the remaining part.
  merge_log_lik = function(e, ids_old, ids_new, sigmasq_y) {
    sigma_ratio = sigmasq_y / sigmasq_mu
    size_old = length(ids_old)
    size_new = length(ids_new)
    sum_old = sum(e[ids_old])
    sum_new = sum(e[ids_new])
    logdetdiff = -0.5 * (log(size_new + sigma_ratio) -
                           log(size_old + sigma_ratio) -
                           log(size_new - size_old + sigma_ratio) +
                           log(sigma_ratio))
    quaddiff = 0.5 * (sum_new^2 / (size_new + sigma_ratio) -
                        sum_old^2 / (size_old + sigma_ratio) -
                        (sum_new - sum_old)^2 /
                        (size_new - size_old + sigma_ratio)
    ) / sigmasq_y
    logdetdiff + quaddiff
  }

  ################# MCMC ####################

  ## MCMC results (preallocated; trees are intentionally not stored)
  n_save = (MCMC - BURNIN) %/% THIN
  mu_out = vector("list", n_save)
  sigmasq_y_out = numeric(n_save)
  cluster_out = array(0, dim = c(n_save, n, M))
  log_post_out = numeric(n_save)

  ## MCMC iteration
  for (iter in seq_len(MCMC)) {
    for (m in seq_len(M)) {
      k_m = k[m]
      mstgraph_m = mstgraph_lst[[m]]
      cluster_m = cluster[, m]
      csize_m = csize[[m]]

      # Partial residual: response minus the fit of all other learners.
      # drop = FALSE keeps a matrix when only one other column remains (M == 2),
      # which Rfast::rowsums requires.
      e_m = Y - Rfast::rowsums(g[, -m, drop = FALSE])

      # Move probabilities: birth, death, change, hyper
      if (k_m == 1) {
        rb = 0.9
        rd = 0
        rc = 0
        rhy = 0.1
      } else if (k_m == min(k_max, n)) {
        rb = 0
        rd = 0.6
        rc = 0.3
        rhy = 0.1
      } else {
        rb = 0.3
        rd = 0.3
        rc = 0.3
        rhy = 0.1
      }

      move = sample(4, 1, prob = c(rb, rd, rc, rhy))

      if (move == 1) { ## Birth move
        # split an existing cluster (chosen with probability prop. to size - 1)
        clust_split = sample.int(k_m, 1, prob = (csize_m - 1))
        proposed_output = graphBirth(mstgraph_m, cluster_m, clust_split)

        # log-prior ratio
        log_A = log(lambda_k) - log(k_m + 1)

        # log-proposal ratio: death probability in the proposed state
        rd_new = if (k_m == min(k_max, n) - 1) 0.6 else 0.3
        log_P = log(rd_new) - log(rb)

        # log-likelihood ratio
        log_L = split_log_lik(e_m,
                              proposed_output$old_clust_ids,
                              proposed_output$new_clust_ids,
                              sigmasq_y)

        # acceptance probability
        acc_prob = exp(min(0, log_A + log_P + log_L))
        if (runif(1) < acc_prob) {
          # accept
          mstgraph_lst[[m]] = proposed_output$graph
          csize[[m]] = components(proposed_output$graph)$csize
          cluster[, m] = proposed_output$membership
          k[m] = k[m] + 1
        }
      } else if (move == 2) { ## Death move
        # merge two existing clusters (c1, c2) -> c2
        proposed_output = graphDeath(mstgraph_m, cluster_m, graph)

        # log-prior ratio
        log_A = -log(lambda_k) + log(k_m)

        # log-proposal ratio: birth probability in the proposed state
        rb_new = if (k_m == 2) 0.9 else 0.3
        log_P = -(log(rd) - log(rb_new))

        # log-likelihood ratio
        log_L = merge_log_lik(e_m,
                              proposed_output$old_clust_ids,
                              proposed_output$new_clust_ids,
                              sigmasq_y)

        # acceptance probability
        acc_prob = exp(min(0, log_A + log_P + log_L))
        if (runif(1) < acc_prob) {
          # accept
          mstgraph_lst[[m]] = proposed_output$graph
          csize[[m]] = components(proposed_output$graph)$csize
          cluster[, m] = proposed_output$membership
          k[m] = k[m] - 1
        }
      } else if (move == 3) { ## Change move
        # A death step followed by a birth step, proposed jointly; the number
        # of clusters is unchanged, so only the likelihood ratio enters.
        proposed_output = graphChange(mstgraph_m, cluster_m, graph)

        log_L_death = merge_log_lik(e_m,
                                    proposed_output$old_dclust_ids,
                                    proposed_output$new_dclust_ids,
                                    sigmasq_y)
        log_L_birth = split_log_lik(e_m,
                                    proposed_output$old_bclust_ids,
                                    proposed_output$new_bclust_ids,
                                    sigmasq_y)
        log_L = log_L_birth + log_L_death

        # acceptance probability
        acc_prob = exp(min(0, log_L))
        if (runif(1) < acc_prob) {
          # accept
          mstgraph_lst[[m]] = proposed_output$graph
          csize[[m]] = components(proposed_output$graph)$csize
          cluster[, m] = proposed_output$membership
        }
      } else if (move == 4) { ## Hyper move
        # update MST; always accepted
        # NOTE(review): csize[[m]] is not refreshed here. This is only correct
        # if graphHyper() returns the same cluster labelling it was given --
        # confirm against graphHyper().
        proposed_output = graphHyper(graph, cluster_m)
        mstgraph_lst[[m]] = proposed_output$graph
        cluster[, m] = proposed_output$membership
      }

      # update mu_m from its Gaussian full conditional
      k_m = k[m]
      cluster_m = cluster[, m]
      csize_m = csize[[m]]
      Qinv_diag = 1 / (csize_m/sigmasq_y + 1/sigmasq_mu)
      b = Qinv_diag * Rfast::group.sum(e_m, cluster_m) / sigmasq_y
      mu[[m]] = rnorm(k_m, b, sqrt(Qinv_diag))
      g[, m] = mu[[m]][cluster_m]
    }

    # update sigmasq_y
    # e_m is left over from the m == M pass, so Y - e_m is the summed fit of
    # learners 1..M-1 and Y_hat is the total fit over all M learners.
    Y_hat = g[, M] + Y - e_m
    rate = 0.5*(nu*lambda_s + sum((Y - Y_hat)^2))
    sigmasq_y = 1/rgamma(1, shape = (n+nu)/2, rate = rate)

    ## save result
    if (iter > BURNIN && (iter - BURNIN) %% THIN == 0) {
      idx = (iter - BURNIN) / THIN
      mu_out[[idx]] = mu
      sigmasq_y_out[idx] = sigmasq_y
      cluster_out[idx, , ] = cluster
      log_post_out[idx] = evalLogPost(mu, g, sigmasq_y, k, Y, hyper)
    }
  }

  mode(cluster_out) = 'integer' # to save memory

  list('mu_out' = mu_out,
       'sigmasq_y_out' = sigmasq_y_out,
       'cluster_out' = cluster_out,
       'log_post_out' = log_post_out)
}

Let’s take a look at the speed difference. Note that the code below is by default not run at the time of this vignette’s execution as it is extremely slow. It instead loads the results precomputed on a system with an AMD Ryzen 7 3700X processor and 32GB of RAM.
# Time each implementation 5 times on identical data, initial values,
# hyper-parameters, iteration settings and seed.
# Note: fitBAST and BASTIONfit are given graph0, while BASTIONfit_C is given
# aral_graph.
aral_speed_full = microbenchmark(
bastfit = fitBAST(Y_std, graph0, init_val, hyperpar, MCMC, BURNIN, THIN, seed = 12345),
bastionfit = BASTIONfit(Y_std, graph0, init_val, hyperpar, MCMC, BURNIN, THIN, seed = 12345),
bastionfit_C = BASTIONfit_C(Y_std, aral_graph, init_val, hyperpar, MCMC, BURNIN, THIN, seed = 12345),
times = 5
)
save(aral_speed_full, file = "aral_speed_results_full.Rda")

datapath = paste0(getwd(), "/aral_speed_results_full.Rda")
# Read the precomputed benchmark into a private environment rather than using
# attach()/detach(): nothing is left on the search path if xtable() fails, and
# an object named aral_speed_full in the workspace cannot mask the precomputed
# results.
speed_table2 = local({
  precomputed = new.env()
  load(datapath, envir = precomputed)
  xtable(summary(precomputed$aral_speed_full), caption = "Speed comparison (seconds) between the different native R implementations and C++.")
})
print(speed_table2, type = "html", comment = FALSE)

| | expr | min | lq | mean | median | uq | max | neval |
|---|---|---|---|---|---|---|---|---|
| 1 | bastfit | 463.18 | 535.07 | 570.30 | 561.74 | 628.88 | 662.62 | 5.00 |
| 2 | bastionfit | 511.21 | 609.50 | 639.30 | 669.46 | 678.88 | 727.43 | 5.00 |
| 3 | bastionfit_C | 94.34 | 94.51 | 97.33 | 96.60 | 97.19 | 104.02 | 5.00 |
Unsurprisingly, the speed with an R implementation leaves much to be desired, with even the fastest fitting time coming in at over 7 minutes (463 seconds) on a relatively small data set. Further, the original code from the paper is more performant at the cost of being more specific to this particular setup. We see a huge improvement from moving to C++ due to the power of the Boost Graph Library and generally lower function call overhead compared to R.
As a sanity check, let’s see the predictive field generated by the different models to see that they are mostly equivalent. Note that despite setting the same seed, the results will still be slightly different due to having a different sequence of RNG function calls. However, each function individually will return the same results given the same seed, including the one implemented in C++ which uses R’s random number generation and not Boost’s.
The code block below is not executed at the time of knitting in the interest of vignette compilation speed.
# Fit each implementation once with the same inputs and seed, and cache the
# results on disk: the C++ fit in one file, the two R fits in another.
bastionfit_C = BASTIONfit_C(Y_std, aral_graph, init_val, hyperpar, MCMC, BURNIN, THIN, seed = 12345)
save(bastionfit_C, file = "aral_fit_C.Rda")
bastfit = fitBAST(Y_std, graph0, init_val, hyperpar, MCMC, BURNIN, THIN, seed = 12345)
bastionfit = BASTIONfit(Y_std, graph0, init_val, hyperpar, MCMC, BURNIN, THIN, seed = 12345)
save(bastfit, bastionfit, file = "aral_fits_no_c.Rda")

Next, we’ll form the prediction field (a grid in this case due to the nature of our data).
# Reload the cached fits (objects bastionfit_C, bastfit and bastionfit)
load("aral_fit_C.Rda")
load("aral_fits_no_c.Rda")
# Mesh built from the observed locations and the boundary; read as a global by
# prediction_scheme()
mesh = gen2dMesh(coords, bnd)
# Posterior-mean prediction on the grid for one fitted model.
#
# Reads coords, coords_grid, mesh, std_par and mean_Y from the enclosing
# (global) environment.
#
# fitted_model: a fit accepted by predictBAST().
# Returns a numeric vector with one predicted value per row of the matrix
# returned by predictBAST().
prediction_scheme = function(fitted_model) {
  # All posterior predictive draws, still on the standardized scale
  draws_std = predictBAST(
    fitted_model,
    coords,
    coords_grid,
    method = 'soft-mesh',
    mesh = mesh,
    weighting = 'uniform',
    return_type = 'all',
    seed = 12345
  )

  # Undo the standardization column by column
  draws = apply(draws_std, 2, unstandardize, std_par = std_par)

  # Posterior mean per row, shifted back to the original scale
  rowMeans(draws) + mean_Y
}
# Predict for each of the 3 different models
Y_grid_BAST = prediction_scheme(bastfit)
Y_grid_BASTION = prediction_scheme(bastionfit)
Y_grid_BASTION_C = prediction_scheme(bastionfit_C)

Finally, we’ll plot the predictive field for each model as well as a reminder of what the observed data looked like:
# One predictive-field panel per implementation, all drawn on the same grid
plot_bast = plotField(coords_grid, Y_grid_BAST, title = 'Predictive Field - BAST') +
xlab('Scaled Long.') + ylab('Scaled Lat.')
plot_bastion = plotField(coords_grid, Y_grid_BASTION, title = 'Predictive Field - BASTION (R)') +
xlab('Scaled Long.') + ylab('Scaled Lat.')
plot_bastion_c = plotField(coords_grid, Y_grid_BASTION_C, title = 'Predictive Field - BASTION (C++)') +
xlab('Scaled Long.') + ylab('Scaled Lat.')
# Reference panel: the observed data, built the same way as the 'Observed Data'
# plot near the top of the vignette
plot_true = ggplot() +
geom_boundary(bnd) +
geom_point(aes(x = lon, y = lat, col = Y), data = as.data.frame(coords)) +
scale_color_gradientn(colours = rainbow(5), name = 'Y') +
labs(x = 'Scaled Lon.', y = 'Scaled Lat.') +
ggtitle('Observed Data')
grid.arrange(plot_true, plot_bast, plot_bastion, plot_bastion_c)

As we can see, the predictive fields for the different implementations are virtually identical.