Speed Comparisons

Ray, Isaac

In an effort to improve the computational efficiency of the BAST model, this vignette compares the speed of several implementations.

First, a speed comparison on the remotely sensed chlorophyll data from the Aral Sea; the setup of this application is described in Section 4.3 of the following paper:

Luo, Z. T., Sang, H., & Mallick, B. (2021). BAST: Bayesian Additive Regression Spanning Trees for Complex Constrained Domain. Advances in Neural Information Processing Systems 34 (NeurIPS 2021).

The code for the paper’s implementation can be found in the GitHub repository at https://github.com/ztluostat/BAST; the majority of the code for this example is taken from that repository.
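Before reading in the data, we load the packages used throughout this vignette. The list below is inferred from the functions called later: BASTION is assumed to export the fitting and plotting helpers (geom_boundary, plotGraph, plotField, BASTIONfit_C, and so on), while paper-specific functions such as constrainedDentri and fitBAST are assumed to be sourced from the repository scripts.

library(BASTION)         # assumed: fitting and plotting helpers used below
library(igraph)          # mst(), delete_edge_attr(), components(), ...
library(ggplot2)         # base plotting layers
library(Rfast)           # rowsums(), group.sum(), Table()
library(microbenchmark)  # timing comparisons
library(xtable)          # formatting the benchmark summary table
library(gridExtra)       # grid.arrange() for the final figure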

To begin, we’ll read in the processed data (details on its preparation can be found in the AralSea.R file in the repository above):

aralDataUrl = "https://github.com/ztluostat/BAST/raw/main/aral_data.RData"
load(url(aralDataUrl))
n = length(Y)
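A quick look at what was loaded (the object names come from the .RData file, which presumably also provides coords_grid and mean_Y, used later for prediction):

str(Y)       # chlorophyll response at the observation locations
str(coords)  # matrix of scaled lon/lat coordinates
str(bnd)     # boundary of the constrained domain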

Here is a plot of what the data looks like:

ggplot() + 
  geom_boundary(bnd) +
  geom_point(aes(x = lon, y = lat, col = Y), data = as.data.frame(coords)) +
  scale_color_gradientn(colours = rainbow(5), name = 'Y') +
  labs(x = 'Scaled Lon.', y = 'Scaled Lat.') + 
  ggtitle('Observed Data')

The BASTION package does not include a function for constructing a constrained Delaunay triangulation graph, so the graph is constructed once using the constrainedDentri function from the paper’s repository and reused for all implementations.

aral_graph = constrainedDentri(n, gen2dMesh(coords, bnd)) # Get full graph
E(aral_graph)$eid = c(1:ecount(aral_graph)) # Assign edge ids
V(aral_graph)$vid = c(1:vcount(aral_graph)) # Assign vertex ids
# plot spatial graph
plotGraph(coords, aral_graph) + 
  geom_boundary(bnd) +
  labs(x = 'Scaled Lon.', y = 'Scaled Lat.') + 
  ggtitle('Spatial Graph')

The following is the implementation from the original paper. To keep the timings manageable, the iteration counts have been reduced from the paper’s settings; these changes are noted in the comments below.

mstgraph = mst(aral_graph)  # initial spanning tree
graph0 = delete_edge_attr(aral_graph, 'weight')
mstgraph0 = delete_edge_attr(mstgraph, 'weight')

M = 30      # number of weak learners
k_max = 5   # maximum number of clusters per weak learner
mu = list() # initial values of mu
mstgraph_lst = list()  # initial spanning trees
cluster = matrix(1, nrow = n, ncol = M)  # initial cluster memberships
for(m in 1:M) {
  mu[[m]] = c(0)
  mstgraph_lst[[m]] = mstgraph0
}

init_val = list()
init_val[['trees']] = mstgraph_lst
init_val[['mu']] = mu
init_val[['cluster']] = cluster
init_val[['sigmasq_y']] = 1

# standardize Y
std_res = standardize(Y)
Y_std = std_res$x
std_par = std_res$std_par

# find lambda_s: calibrated so that P(sigmasq_y < var(Y_std)) = q
# under the scaled inverse chi-squared prior on sigmasq_y
nu = 3; q = 0.9
quant = qchisq(1-q, nu)
lambda_s = quant * var(Y_std) / nu

hyperpar = c()
hyperpar['sigmasq_mu'] = (0.5/(2*sqrt(M)))^2
hyperpar['lambda_s'] = lambda_s
hyperpar['nu'] = nu
hyperpar['lambda_k'] = 4
hyperpar['M'] = M
hyperpar['k_max'] = k_max

# MCMC parameters
# number of posterior samples = (MCMC - BURNIN) / THIN = (10000 - 5000) / 5 = 1000
MCMC = 10000    # MCMC iterations !! Down from 30000 originally !!
BURNIN = 5000   # burn-in period length !! Down from 15000 originally !!
THIN = 5        # thinning interval

Here is an implementation written using the BASTION package, based on the one from the paper:

BASTIONfit = function(Y, graph, init_vals, hyperpars, MCMC, BURNIN, THIN, seed = NULL) {
  if(!is.null(seed)) {
    set.seed(seed)
  }
  n = vcount(graph)
  # hyper-parameter
  M = hyperpars['M']  # number of trees
  sigmasq_mu = hyperpars['sigmasq_mu']
  lambda_s = hyperpars['lambda_s']
  nu = hyperpars['nu']
  lambda_k = hyperpars['lambda_k']
  k_max = hyperpars['k_max']
  hyper = c(sigmasq_mu, lambda_s, nu, lambda_k)
  
  # initial values
  mstgraph_lst = init_vals[['trees']]
  mu = init_vals[['mu']]
  cluster = init_vals[['cluster']]  # n*M matrix
  sigmasq_y = init_vals[['sigmasq_y']]
  k = as.numeric(apply(cluster, 2, max))  # number of clusters
  csize = list() # cluster size
  g = matrix(0, nrow = n, ncol = M)  # n*M matrix of fitted mu's
  for(m in 1:M) {
    cluster_m = cluster[, m]
    g[, m] = mu[[m]][cluster_m]
    csize[[m]] = Rfast::Table(cluster_m)
  }
  
  ################# MCMC ####################
  
  ## MCMC results
  mu_out = list()
  sigmasq_y_out = numeric((MCMC-BURNIN)/THIN)
  cluster_out = array(0, dim = c((MCMC-BURNIN)/THIN, n, M))
  tree_out = list()
  log_post_out = numeric((MCMC-BURNIN)/THIN)
  
  ## MCMC iteration
  for(iter in 1:MCMC) {
    for(m in 1:M) {
      k_m = k[m]
      mstgraph_m = mstgraph_lst[[m]]
      cluster_m = cluster[, m]
      csize_m = csize[[m]]
      
      e_m = Y - Rfast::rowsums(g[, -m])
      
      # move probabilities: birth (rb), death (rd), change (rc), hyper (rhy)
      if(k_m == 1) {
        rb = 0.9
        rd = 0
        rc = 0
        rhy = 0.1
      } else if(k_m == min(k_max, n)) {
        rb = 0
        rd = 0.6
        rc = 0.3
        rhy = 0.1
      } else {
        rb = 0.3
        rd = 0.3
        rc = 0.3
        rhy = 0.1
      }
      move = sample(4, 1, prob = c(rb, rd, rc, rhy))
      
      if(move == 1) { ## Birth move
        # split an existing cluster
        clust_split = sample.int(k_m, 1, prob = (csize_m - 1))
        proposed_output = graphBirth(mstgraph_m, cluster_m, clust_split)
        vid_new = proposed_output$new_clust_ids
        vid_old = proposed_output$old_clust_ids
        
        # compute log-prior ratio
        log_A = log(lambda_k) - log(k_m + 1)
        # compute log-proposal ratio
        if(k_m == min(k_max, n)-1) {
          rd_new = 0.6
        } else {
          rd_new = 0.3
        }
        log_P = log(rd_new) - log(rb)
        # compute log-likelihood ratio
        ##
        sigma_ratio = sigmasq_y / sigmasq_mu
        csize_old = length(vid_old)
        csize_new = length(vid_new)
        sum_e_old = sum(e_m[vid_old])
        sum_e_new = sum(e_m[vid_new])
        logdetdiff = -0.5 * (log(csize_old + sigma_ratio) + 
                             log(csize_new + sigma_ratio) - 
                             log(csize_old + csize_new + sigma_ratio) - 
                             log(sigma_ratio))
        quaddiff = 0.5 * ( sum_e_old^2 / (csize_old + sigma_ratio) + 
                           sum_e_new^2 / (csize_new + sigma_ratio) -
                           (sum_e_old + sum_e_new)^2 / (csize_old + csize_new + sigma_ratio)
                          ) / sigmasq_y
        ##
        log_L = logdetdiff + quaddiff
        
        #acceptance probability
        acc_prob = min(0, log_A + log_P + log_L)
        acc_prob = exp(acc_prob)
        if(runif(1) < acc_prob){
          # accept
          mstgraph_lst[[m]] = proposed_output$graph
          csize[[m]] = components(proposed_output$graph)$csize
          cluster[, m] = proposed_output$membership
          k[m] = k[m] + 1
        }
      }
      
      if(move == 2) { ## Death move
        # merge two existing clusters (c1, c2) -> c2
        proposed_output = graphDeath(mstgraph_m, cluster_m, graph)
        vid_old = proposed_output$old_clust_ids
        vid_new = proposed_output$new_clust_ids
        
        # compute log-prior ratio
        log_A = -log(lambda_k) + log(k_m)
        # compute log-proposal ratio
        if(k_m == 2) {
          rb_new = 0.9
        }else {
          rb_new = 0.3
        }
        log_P = -(log(rd) - log(rb_new))
        # compute log-likelihood ratio
        sigma_ratio = sigmasq_y / sigmasq_mu
        csize_old = length(vid_old)
        csize_new = length(vid_new)
        sum_e_old = sum(e_m[vid_old])
        sum_e_new = sum(e_m[vid_new])
        logdetdiff = -0.5 * (log(csize_new+sigma_ratio) - 
                             log(csize_old+sigma_ratio) - 
                             log(csize_new-csize_old+sigma_ratio) + 
                             log(sigma_ratio))
        quaddiff = 0.5 * (  sum_e_new^2/(csize_new+sigma_ratio) - 
                            sum_e_old^2/(csize_old+sigma_ratio) - 
                            (sum_e_new-sum_e_old)^2/
                              (csize_new-csize_old+sigma_ratio)
                          ) / sigmasq_y
        log_L = logdetdiff + quaddiff
        
        # acceptance probability
        acc_prob = min(0, log_A + log_P + log_L)
        acc_prob = exp(acc_prob)
        if(runif(1) < acc_prob){
          # accept
          mstgraph_lst[[m]] = proposed_output$graph
          csize[[m]] = components(proposed_output$graph)$csize
          cluster[, m] = proposed_output$membership
          k[m] = k[m] - 1
        }
      }
      
      if(move == 3) { ## change move
        # a change move merges two clusters and then splits one
        # (a death move followed by a birth move)
        
        proposed_output = graphChange(mstgraph_m, cluster_m, graph)
        
        vid_old_d = proposed_output$old_dclust_ids
        vid_new_d = proposed_output$new_dclust_ids
        
        vid_old_b = proposed_output$old_bclust_ids
        vid_new_b = proposed_output$new_bclust_ids
        
        
        # compute log-likelihood ratio
        # First do it for the death move
        sigma_ratio = sigmasq_y / sigmasq_mu
        csize_old_d = length(vid_old_d)
        csize_new_d = length(vid_new_d)
        sum_e_old_d = sum(e_m[vid_old_d])
        sum_e_new_d = sum(e_m[vid_new_d])
        logdetdiff_d = -0.5 * (log(csize_new_d+sigma_ratio) - 
                             log(csize_old_d+sigma_ratio) - 
                             log(csize_new_d-csize_old_d+sigma_ratio) + 
                             log(sigma_ratio))
        quaddiff_d = 0.5 * (  sum_e_new_d^2/(csize_new_d+sigma_ratio) - 
                            sum_e_old_d^2/(csize_old_d+sigma_ratio) - 
                            (sum_e_new_d-sum_e_old_d)^2/(csize_new_d-csize_old_d+sigma_ratio)
                          ) / sigmasq_y
        log_L_death = logdetdiff_d + quaddiff_d
        # Now do it for the birth move
        csize_old_b = length(vid_old_b)
        csize_new_b = length(vid_new_b)
        sum_e_old_b = sum(e_m[vid_old_b])
        sum_e_new_b = sum(e_m[vid_new_b])
        logdetdiff_b = -0.5 * (log(csize_old_b + sigma_ratio) + 
                             log(csize_new_b + sigma_ratio) - 
                             log(csize_old_b + csize_new_b + sigma_ratio) - 
                             log(sigma_ratio))
        quaddiff_b = 0.5 * ( sum_e_old_b^2 / (csize_old_b + sigma_ratio) + 
                           sum_e_new_b^2 / (csize_new_b + sigma_ratio) -
                           (sum_e_old_b + sum_e_new_b)^2 / (csize_old_b + csize_new_b + sigma_ratio)
                          ) / sigmasq_y
        log_L_birth = logdetdiff_b + quaddiff_b
        # Add them together
        log_L = log_L_birth + log_L_death
        
        # acceptance probability
        acc_prob = min(0, log_L)
        acc_prob = exp(acc_prob)
        if(runif(1) < acc_prob){
          # accept
          mstgraph_lst[[m]] = proposed_output$graph
          csize[[m]] = components(proposed_output$graph)$csize
          cluster[, m] = proposed_output$membership
        }
      }
      
      if(move == 4) {
        # update MST
        proposed_output = graphHyper(graph, cluster_m)
        mstgraph_lst[[m]] = proposed_output$graph
        cluster[, m] = proposed_output$membership
      }
      
      # update mu_m: conjugate Gibbs draw for the cluster means,
      # mu_c | rest ~ N(Qinv_c * sum(e_m in cluster c) / sigmasq_y, Qinv_c)
      k_m = k[m]
      cluster_m = cluster[, m]
      csize_m = csize[[m]]
      Qinv_diag = 1 / (csize_m/sigmasq_y + 1/sigmasq_mu)
      b = Qinv_diag * Rfast::group.sum(e_m, cluster_m) / sigmasq_y
      mu[[m]] = rnorm(k_m, b, sqrt(Qinv_diag))
      
      g[, m] = mu[[m]][cluster_m]
    }
    
    # update sigmasq_y
    # e_m still holds Y - rowsums(g[, -M]) from the last tree update,
    # so g[, M] + Y - e_m equals rowsums(g), the full fit
    Y_hat = g[, M] + Y - e_m
    rate = 0.5*(nu*lambda_s + sum((Y - Y_hat)^2))
    sigmasq_y = 1/rgamma(1, shape = (n+nu)/2, rate = rate)
    
    
    ## save result
    if(iter > BURNIN & (iter - BURNIN) %% THIN == 0) {
      mu_out[[(iter-BURNIN)/THIN]] = mu
      sigmasq_y_out[(iter-BURNIN)/THIN] = sigmasq_y
      #tree_out[[(iter-BURNIN)/THIN]] = mstgraph_lst
      cluster_out[(iter-BURNIN)/THIN, , ] = cluster
      
      log_post_out[(iter-BURNIN)/THIN] = evalLogPost(mu, g, sigmasq_y, k, Y, hyper)
    }
    
    # cat(iter, ', ', sep = "")
  }
  
  mode(cluster_out) = 'integer'  # to save memory
  return(list('mu_out' = mu_out,
              'sigmasq_y_out' = sigmasq_y_out,
              'cluster_out' = cluster_out, 
              'log_post_out' = log_post_out))
}
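For reference, here is a sketch of where the likelihood ratios above come from. Integrating the cluster means out of the Gaussian likelihood, each cluster $c$ contributes

$$
-\frac{1}{2}\log\frac{n_c + r}{r} + \frac{S_c^2}{2\sigma_y^2\,(n_c + r)}
$$

to the marginal log-likelihood, up to terms that do not depend on the partition, where $r = \sigma_y^2 / \sigma_\mu^2$ (sigma_ratio in the code), $n_c$ is the size of cluster $c$, and $S_c$ is the sum of the residuals e_m over cluster $c$. Splitting a cluster of size $n_o + n_n$ into pieces of sizes $n_o$ and $n_n$ therefore changes the marginal log-likelihood by

$$
\log L = -\frac{1}{2}\Big[\log(n_o + r) + \log(n_n + r) - \log(n_o + n_n + r) - \log r\Big]
+ \frac{1}{2\sigma_y^2}\left[\frac{S_o^2}{n_o + r} + \frac{S_n^2}{n_n + r} - \frac{(S_o + S_n)^2}{n_o + n_n + r}\right],
$$

which is exactly logdetdiff + quaddiff in the birth move. The death move evaluates the same quantity for a merge (so the expression appears with its sign flipped), and the change move sums one death term and one birth term.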

Let’s take a look at the speed difference. Note that the benchmarking code below is not run when this vignette is knit, as it is extremely slow; instead, we load results precomputed on a system with an AMD Ryzen 7 3700X processor and 32 GB of RAM.

aral_speed_full = microbenchmark(
  bastfit = fitBAST(Y_std, graph0, init_val, hyperpar, MCMC, BURNIN, THIN, seed = 12345),
  bastionfit = BASTIONfit(Y_std, graph0, init_val, hyperpar, MCMC, BURNIN, THIN, seed = 12345),
  bastionfit_C = BASTIONfit_C(Y_std, aral_graph, init_val, hyperpar, MCMC, BURNIN, THIN, seed = 12345),
  times = 5
)

save(aral_speed_full, file = "aral_speed_results_full.Rda")  # (not run when knitting)

# load the precomputed benchmark results and summarize them
datapath = paste0(getwd(), "/aral_speed_results_full.Rda")
attach(what = datapath, name = "precomputed")
speed_table2 = xtable(summary(aral_speed_full), caption = "Speed comparison (seconds) between the different native R implementations and C++.")
detach(precomputed)
print(speed_table2, type = "html", comment = FALSE)
Speed comparison (seconds) between the different native R implementations and C++.

expr           min      lq      mean    median   uq      max     neval
bastfit        463.18   535.07  570.30  561.74   628.88  662.62  5
bastionfit     511.21   609.50  639.30  669.46   678.88  727.43  5
bastionfit_C    94.34    94.51   97.33   96.60    97.19  104.02  5

Unsurprisingly, the speed of a native R implementation leaves much to be desired, with even the fastest fitting time coming in at nearly eight minutes on a relatively small data set. The original code from the paper is somewhat more performant than the BASTION R implementation, at the cost of being more specific to this particular setup. Moving to C++ yields a huge improvement, thanks to the Boost Graph Library and the generally lower function call overhead compared to R.
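If you want to see where the R implementation spends its time before reaching for C++, base R’s sampling profiler gives a quick picture. A minimal sketch (not run when knitting; the shortened iteration counts are arbitrary, chosen only so that profiling finishes quickly):

Rprof("bastion_profile.out")
profile_fit = BASTIONfit(Y_std, graph0, init_val, hyperpar,
                         MCMC = 200, BURNIN = 100, THIN = 5, seed = 12345)
Rprof(NULL)
# functions ranked by self time
head(summaryRprof("bastion_profile.out")$by.self)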

As a sanity check, let’s compare the predictive fields generated by the different models to confirm that they are essentially equivalent. Note that despite setting the same seed, the results will still differ slightly between implementations because each makes a different sequence of RNG calls. However, each function individually returns the same results given the same seed, including the C++ implementation, which uses R’s random number generation rather than Boost’s.
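To check that per-function reproducibility directly, here is a short sketch (not run when knitting; the iteration counts are arbitrary, chosen for speed rather than quality of fit):

fit_a = BASTIONfit(Y_std, graph0, init_val, hyperpar,
                   MCMC = 200, BURNIN = 100, THIN = 5, seed = 42)
fit_b = BASTIONfit(Y_std, graph0, init_val, hyperpar,
                   MCMC = 200, BURNIN = 100, THIN = 5, seed = 42)
identical(fit_a$sigmasq_y_out, fit_b$sigmasq_y_out)  # expected: TRUE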

The refitting code below is likewise not executed at knitting time in the interest of vignette compilation speed.

bastionfit_C = BASTIONfit_C(Y_std, aral_graph, init_val, hyperpar, MCMC, BURNIN, THIN, seed = 12345)
save(bastionfit_C, file = "aral_fit_C.Rda")
bastfit = fitBAST(Y_std, graph0, init_val, hyperpar, MCMC, BURNIN, THIN, seed = 12345)
bastionfit = BASTIONfit(Y_std, graph0, init_val, hyperpar, MCMC, BURNIN, THIN, seed = 12345)
save(bastfit, bastionfit, file = "aral_fits_no_c.Rda")

Next, we’ll form the prediction field (a grid in this case due to the nature of our data).

load("aral_fit_C.Rda")
load("aral_fits_no_c.Rda")

mesh = gen2dMesh(coords, bnd)

prediction_scheme = function(fitted_model) {
  Y_grid_all = predictBAST(fitted_model, 
                           coords, 
                           coords_grid, 
                           method = 'soft-mesh', 
                           mesh = mesh, 
                           weighting = 'uniform', 
                           return_type = 'all', 
                           seed = 12345)
  Y_grid_all = apply(Y_grid_all, 2, unstandardize, std_par = std_par)

  # use posterior mean as predictor
  Y_grid = rowMeans(Y_grid_all)
  
  # back to original scale (mean_Y is the response mean, presumably loaded
  # with aral_data.RData)
  Y_grid = Y_grid + mean_Y
  return(Y_grid)
}

# Predict for each of the 3 different models
Y_grid_BAST = prediction_scheme(bastfit)
Y_grid_BASTION = prediction_scheme(bastionfit)
Y_grid_BASTION_C = prediction_scheme(bastionfit_C)

Finally, we’ll plot the predictive field for each model as well as a reminder of what the true data looked like:

plot_bast = plotField(coords_grid, Y_grid_BAST, title = 'Predictive Field - BAST') +
  xlab('Scaled Lon.') + ylab('Scaled Lat.')

plot_bastion = plotField(coords_grid, Y_grid_BASTION, title = 'Predictive Field - BASTION (R)') +
  xlab('Scaled Lon.') + ylab('Scaled Lat.')

plot_bastion_c = plotField(coords_grid, Y_grid_BASTION_C, title = 'Predictive Field - BASTION (C++)') +
  xlab('Scaled Lon.') + ylab('Scaled Lat.')

plot_true = ggplot() + 
  geom_boundary(bnd) +
  geom_point(aes(x = lon, y = lat, col = Y), data = as.data.frame(coords)) +
  scale_color_gradientn(colours = rainbow(5), name = 'Y') +
  labs(x = 'Scaled Lon.', y = 'Scaled Lat.') + 
  ggtitle('Observed Data')

grid.arrange(plot_true, plot_bast, plot_bastion, plot_bastion_c)

As we can see, the predictive fields for the different implementations are virtually identical.
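Beyond the eyeball test, a quick numeric check of the agreement between the fields computed above (a sketch):

# pairwise correlations of the three predictive fields (expected to be near 1)
cor(cbind(BAST = Y_grid_BAST, BASTION = Y_grid_BASTION, BASTION_C = Y_grid_BASTION_C))
# worst-case pointwise difference between the two BASTION fits
max(abs(Y_grid_BASTION - Y_grid_BASTION_C))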