MCMC

Ray, Isaac

The Bayesian Additive Regression Spanning Trees (BAST) model is a novel ensemble model for non-parametric regression, intended for data (especially spatial data) that lies on a complex or constrained domain, or on a space of irregular shape embedded in Euclidean space. At the core of the BAST model is a novel weak learner: a random spanning tree manifold partition model. The model is built on four possible moves/graph operations: birth, death, change, and hyper. The BASTION package implements these four core graph operations, along with some utility functions that make them easier to use.

The purpose of this vignette is to:

  1. Provide a motivating example problem using synthetic data
  2. Demonstrate how the core functions of the BASTION package work individually on the data
  3. Give a simple example of how the core functions can be combined to predict responses at test data points
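Throughout this vignette we’ll assume the following packages are attached (a typical setup chunk for a vignette like this; ggplot2 and viridis provide the plotting and color scales, ggraph the graph plotting, and igraph the graph objects):

library(BASTION)
library(igraph)
library(ggplot2)
library(ggraph)
library(viridis)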

The Data & Motivation

We’ll start by posing a situation. Suppose we are observing spatial points around a black hole, but have no data from within the black hole (our surveyors got too close). Below is a plot of all the locations at which we gathered data (stored in data_hole):

ggplot(data_hole) + geom_point(aes(x, y))

At each data point, our space surveyors observed the value of the function we are interested in (stored in response). Below is a plot of the response data we observed at each point:

ggplot(data_hole) + 
  geom_point(aes(x, y, color = response)) +
  scale_color_viridis()

Unfortunately, all of our space surveyors have been sucked into the black hole. But we still need a way to predict the value of our function of interest at new data points. Before discussing how we will achieve this, let’s discuss the core functions of BASTION.

The constructGraph function

Before we can do any graph-based procedures on our data, we need to have a graph! The BASTION package provides the function constructGraph to do this for us. constructGraph takes two-dimensional numeric data (in our case, data_hole) and a parameter k, and runs the k-nearest-neighbors algorithm to generate a weighted igraph graph. Below is a plot of what our graph looks like:

hole_graph = constructGraph(data_hole, 5)
ggraph(hole_graph, layout = data_hole) + 
  geom_edge_link0(edge_alpha = 0.1) + 
  geom_node_point(aes(color = response)) +
  scale_color_viridis()

An important consideration is that the k-nearest-neighbors algorithm is not guaranteed to generate a connected graph, especially when there is little data and the chosen k is small. Other methods of constructing a graph, such as constrained Delaunay triangulation, may be a better choice in such instances. The other functions of the BASTION package do not require that the original graph come from constructGraph: a weighted igraph graph generated by other means will still work, so long as it is connected and each point shares an edge with other ‘nearby’ points (we will use Euclidean distance for simplicity).
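As a rough illustration of “generated by other means”, here is one minimal alternative (a sketch only; the 0.5 distance threshold is arbitrary and would need tuning to keep the graph connected): connect every pair of points closer than the threshold.

# A fixed-radius neighbor graph as an alternative construction
dist_mat = as.matrix(dist(data_hole))              # pairwise Euclidean distances
adj = dist_mat * (dist_mat < 0.5)                  # keep only edges shorter than 0.5
alt_graph = graph_from_adjacency_matrix(adj, mode = "undirected", weighted = TRUE)
is_connected(alt_graph)                            # must be TRUE for what follows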

Initializing Clusters

Now that we have a graph, let’s separate it into clusters. We are working on the principle that points which are ‘close’ together in space likely also have response values which are ‘close’ together. However, as we will see later, this assumption need not be true once we start performing different graph operations. With this in mind, the constructClusters function will form a partition of our graph into nclust different clusters (10 is chosen arbitrarily) without needing to consider our response data.

constructClusters creates the partition by first using Prim’s algorithm to find a minimum spanning tree on the input graph hole_graph using its edge weights (which were generated by constructGraph). A property of spanning trees is that removing nclust - 1 edges from the spanning tree leaves nclust disconnected trees, which together form a spanning forest. We then assign all the vertices in the same disconnected tree to the same cluster. The final spanning forest and cluster membership vector are both available in the list output by constructClusters.
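The idea is easy to replicate by hand with igraph; here is a sketch of the concept (not the package internals):

# Sketch: build an MST, then delete nclust - 1 random edges to leave a forest
mst_graph = mst(hole_graph)                        # minimum spanning tree on edge weights
drop_ids = sample(ecount(mst_graph), 10 - 1)       # nclust - 1 random edges
forest_sketch = delete_edges(mst_graph, E(mst_graph)[drop_ids])
components(forest_sketch)$membership               # cluster label for each vertex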

Below is a demonstration of constructClusters on our hole_graph and a plot of which vertices were assigned to each cluster.

clustered_output = constructClusters(hole_graph, nclust = 10)
ggraph(clustered_output$spanning_forest, layout = data_hole) +
  geom_edge_link0(edge_alpha = 0.1) +
  geom_node_point(aes(colour = as.factor(clustered_output$membership))) +
  scale_colour_brewer(palette = "Paired") + 
  labs(colour = "Cluster")

Graph Operations

Now that we have a graph with clusters, we can demonstrate each of the four graph operations that form the core of the BASTION package.

Birth Move

First, we will consider a cluster birth move. The graphBirth function, when provided with a clustered spanning forest graph and its corresponding membership (spanning_forest and membership from the constructClusters output), picks a random edge from the cluster indexed by clust and deletes it. Deleting this edge splits that cluster’s tree in two, creating an additional cluster; we say that a cluster is ‘born’.
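Conceptually, a birth move needs nothing more than the forest edges internal to the chosen cluster; here is a sketch of the idea (not the package internals):

# Sketch: delete one random edge inside cluster 1, then relabel by components
forest = clustered_output$spanning_forest
memb = clustered_output$membership
ep = ends(forest, E(forest), names = FALSE)        # endpoint pairs of every forest edge
inside = which(memb[ep[, 1]] == 1)                 # forest edges lie within a single cluster
victim = inside[sample(length(inside), 1)]
birth_sketch = delete_edges(forest, E(forest)[victim])
components(birth_sketch)$membership                # one more component than before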

Note that clust is an optional parameter; cluster 1 was chosen for demonstration since in this example it happened to be the largest cluster. In practice clust is generally not specified, in which case a cluster is chosen at random from the existing clusters with at least one edge.

birth_output = graphBirth(clustered_output$spanning_forest, clustered_output$membership, clust = 1)
ggraph(birth_output$graph, layout = data_hole) +
  geom_edge_link0(edge_alpha = 0.1) +
  geom_node_point(aes(colour = as.factor(birth_output$membership))) +
  scale_colour_brewer(palette = "Paired") + 
  labs(colour = "Cluster")

As we can see from the plot above, compared with the original spanning forest, the original cluster 1 has been split into clusters 1 and 3.

Death Move

Next, we will consider the opposite of a cluster birth move: a cluster death move. The graphDeath function, when provided with a clustered spanning forest graph, its corresponding membership, and the original full graph (hole_graph in this example), randomly selects an edge from the original graph that connects two clusters and adds it back to the forest. In doing so, the two connected clusters become a single cluster and we say that a cluster has ‘died’.
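Note that the candidate edges must come from the original graph: by construction, the spanning forest has no edges left between clusters. A sketch of the idea (not the package internals):

# Sketch: add back one random between-cluster edge from the original graph
memb = birth_output$membership
ep = ends(hole_graph, E(hole_graph), names = FALSE)  # endpoints of the original edges
between = which(memb[ep[, 1]] != memb[ep[, 2]])      # edges bridging two clusters
pick = ep[between[sample(length(between), 1)], ]     # one such edge's endpoints
death_sketch = add_edges(birth_output$graph, pick)
components(death_sketch)$membership                  # one fewer component

The actual move is a single call: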

death_output = graphDeath(birth_output$graph, birth_output$membership, hole_graph)
ggraph(death_output$graph, layout = data_hole) +
  geom_edge_link0(edge_alpha = 0.1) +
  geom_node_point(aes(colour = as.factor(death_output$membership))) +
  scale_colour_brewer(palette = "Paired") + 
  labs(colour = "Cluster")

As we can see from the plot above, compared with the plot from the birth move, the cluster that was born has since been killed (such is life for a cluster) and we are back to the spanning forest we started with.

Change Move

A cluster change move is simply a death move followed by a birth move. The graphChange function, when provided with a clustered spanning forest graph, its corresponding membership, and the original full graph, kills a random cluster and then births a new one. The motivation is to encourage faster cluster movement and to discourage the graph from settling into a local ‘optimum’ that a single birth or death move alone would be unlikely to escape.

change_output = graphChange(death_output$graph, death_output$membership, hole_graph)
ggraph(change_output$graph, layout = data_hole) +
  geom_edge_link0(edge_alpha = 0.1) +
  geom_node_point(aes(colour = as.factor(change_output$membership))) +
  scale_colour_brewer(palette = "Paired") + 
  labs(colour = "Cluster")

As we can see from the plot above, compared with the plot from the death move, the previous cluster 5 was killed by merging into cluster 2, and cluster 3 then birthed a new cluster 10.

Hyper Move

The final core move is a hyper cluster move, performed by the graphHyper function. In a hyper cluster move, every edge from the original graph that connects vertices belonging to the same cluster is randomly assigned a new weight from a \(\textrm{Unif}(0,0.5)\) distribution. Every edge from the original graph that connects vertices belonging to different clusters is randomly assigned a new weight from a \(\textrm{Unif}(0.5,1)\) distribution. Then, a new minimum spanning tree across the entire original graph is created using Prim’s algorithm. Finally, all the edges which connect different clusters are removed.

The end result is a graph that has the same cluster membership for each vertex, but a different set of active edges comprising each tree in the spanning forest. This is beneficial because it changes the available edges that could be deleted in a birth move.
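To make the reweighting concrete, here is a rough sketch (illustrative only, not the package internals; the low within-cluster weights mean Prim’s algorithm prefers to keep clusters intact; we work on a copy so hole_graph keeps its original weights):

# Sketch of the hyper move's reweighting and re-spanning
g = hole_graph                                     # copy; leave hole_graph untouched
memb = change_output$membership
ep = ends(g, E(g), names = FALSE)
within = memb[ep[, 1]] == memb[ep[, 2]]            # TRUE for within-cluster edges
E(g)$weight = ifelse(within,
                     runif(ecount(g), 0, 0.5),     # cheap: within a cluster
                     runif(ecount(g), 0.5, 1))     # expensive: between clusters
new_tree = mst(g)  # the new forest is this tree minus its between-cluster edges

In practice, graphHyper does all of this in one call: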

hyper_output = graphHyper(hole_graph, change_output$membership)
ggraph(hyper_output$graph, layout = data_hole) +
  geom_edge_link0(edge_alpha = 0.1) +
  geom_node_point(aes(colour = as.factor(hyper_output$membership))) +
  scale_colour_brewer(palette = "Paired") + 
  labs(colour = "Cluster")

As we can see from the plot above, compared with the plot from the change move, every vertex keeps its cluster (the membership is unchanged), but the set of edges connecting the vertices within each cluster is different.

Simulation Study

We now have all the necessary building blocks to formulate a method for predicting the value of our response at new data points. We will use the response data that our spaghettified space surveyors gathered to train our model, then use that model for prediction.

So, let’s put it all together in a very simple estimation procedure. Suppose we want to approximate the response as a piecewise constant function over the domain. Our estimate for the constant on each cluster will be the mean of the response at the cluster’s member vertices, and the per-cluster error will be the sum of squared differences between each response and our estimate. We’ll assess the total error of our graph by summing the error of each cluster, and penalize the number of clusters by adding a scaled exponential function of the cluster count to the total error.
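In symbols, with \(k\) clusters, cluster means \(\bar{y}_c\), membership labels \(m_i\), and \(n\) observations, the total penalized error we will minimize is

\[
L = \sum_{c=1}^{k} \sum_{i \,:\, m_i = c} \left(y_i - \bar{y}_c\right)^2 + \frac{e^{k}}{n}.
\]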

Note that the choices we are making here are practically arbitrary; there are many valid choices of loss and penalty functions, as well as different ways to approximate the response using the clusters. Different choices will lead to different performance characteristics and theoretical guarantees; these were chosen because they are simple to demonstrate. The performance and theory of these different choices are an area of ongoing research.

estimation_procedure = function(membership, response) {
  # Use the labels actually present, since cluster labels are not guaranteed
  # to be the contiguous sequence 1..k (e.g. for the test memberships later)
  labels = sort(unique(membership))
  k = length(labels)
  estimates = rep(0, k)
  errors = rep(0, k)
  point_estimates = rep(0, length(membership))
  for(i in 1:k) {
    in_cluster = membership == labels[i]
    estimates[i] = mean(response[in_cluster])                 # piecewise-constant estimate
    errors[i] = sum((response[in_cluster] - estimates[i])^2)  # per-cluster SSE
    point_estimates[in_cluster] = estimates[i]
  }
  # Total loss: summed SSE plus a scaled exponential penalty on cluster count
  total_penalty = sum(errors) + (exp(k)/length(membership))
  return(list(total_penalty = total_penalty, 
              errors = errors, 
              estimates = estimates, 
              point_estimates = point_estimates))
}

Now, we’ll repeatedly apply the four graph operations to our original hole_graph with different probabilities. We can again choose these probabilities arbitrarily; let’s use a 2/5 chance of a birth move, 2/5 of a death move, and 1/10 each of a change or hyper move. In each iteration, we will select a move with these probabilities, evaluate the resulting total penalty, and accept the new graph if the penalty does not increase.

Note that since the loss depends only on the membership, a hyper move does not affect the total penalty and is therefore always accepted. As above, this way of deciding whether to accept a new graph is neither the only way nor necessarily the best; it is chosen for simplicity.

iterations = 2000
current_graph = clustered_output$spanning_forest
current_membership = clustered_output$membership
current_loss = estimation_procedure(current_membership, response)$total_penalty
for(i in 1:iterations) {
  # 1 = birth, 2 = death, 3 = change, 4 = hyper
  move = sample(1:4, 1, prob = c(0.4, 0.4, 0.1, 0.1))
  if(move == 4) {
    # Hyper moves never change the membership or penalty, so always accept
    proposed_output = graphHyper(hole_graph, current_membership)
    current_graph = proposed_output$graph
    current_membership = proposed_output$membership
  } else {
    proposed_output = switch(move,
      graphBirth(current_graph, current_membership),
      graphDeath(current_graph, current_membership, hole_graph),
      graphChange(current_graph, current_membership, hole_graph))
    proposed_estimation = estimation_procedure(proposed_output$membership, response)
    # Greedy acceptance: keep the proposal only if the penalty does not increase
    if(proposed_estimation$total_penalty <= current_loss) {
      current_graph = proposed_output$graph
      current_membership = proposed_output$membership
      current_loss = proposed_estimation$total_penalty
    }
  }
}
final_graph = current_graph
final_membership = current_membership
final_estimation_output = estimation_procedure(current_membership, response)
final_penalty = final_estimation_output$total_penalty
final_errors = final_estimation_output$errors

Let’s see how our estimation did. As a reminder, here is our original graph:

ggraph(hole_graph, layout = data_hole) + 
  geom_edge_link0(edge_alpha = 0.1) + 
  geom_node_point(aes(color = response)) +
  scale_color_viridis() +
  labs(color = "Response")

And here is our model, the piecewise constant approximation:

final_point_estimates = final_estimation_output$point_estimates
ggraph(final_graph, layout = data_hole) + 
  geom_edge_link0(edge_alpha = 0.1) + 
  geom_node_point(aes(color = final_point_estimates)) +
  scale_color_viridis() +
  labs(color = "Response Estimate")

Predicting on Test Data

Now that we have a model, let’s predict at all the points our space surveyors missed. Our test data points are stored in test_data_hole. To predict at a new point, we find the closest vertex in our model (by Euclidean distance), assign the new point to that vertex’s cluster, and predict with that cluster’s fitted training mean. Below is some code to find the closest vertex for each new point:

# Euclidean distance from new_point to every row of existing_points,
# returning the index of the nearest row
closest_point = function(new_point, existing_points) {
  return(unname(which.min(sqrt(colSums((new_point - t(existing_points))^2)))))
}

# For each test point: find its nearest training vertex, inherit that vertex's
# cluster, and predict with the cluster's fitted (training) estimate
nearest_vertex = rep(0, nrow(test_data_hole))
for(i in 1:nrow(test_data_hole)) {
  nearest_vertex[i] = closest_point(as.numeric(test_data_hole[i,]), data_hole)
}
test_membership = final_membership[nearest_vertex]
test_pred_response = final_point_estimates[nearest_vertex]

Here is what the true response function applied to our testing data looks like:

ggplot(test_data_hole) + 
  geom_point(aes(x, y, color = test_response)) +
  scale_color_viridis(limits = c(-300, 500)) +
  labs(color = "Response")

And here is our model’s predicted response at the testing data:

ggplot(test_data_hole) + 
  geom_point(aes(x, y, color = test_pred_response)) +
  scale_color_viridis(limits = c(-300, 500)) +
  labs(color = "Predicted Response")

Not bad for such a simple procedure! Here is a histogram of the prediction errors from this method:

test_error = test_response - test_pred_response
ggplot() + 
  geom_histogram(aes(x = test_error), bins = 20) +
  labs(x = "Prediction Error")

And a plot of where these errors occur (dotted overlays mark where the discontinuities in the true response function lie):

ggplot(test_data_hole) + 
  geom_point(aes(x, y, color = abs(test_error))) +
  geom_abline(intercept = 0, slope = 1, linetype = "dotted") +
  geom_abline(intercept = 0, slope = -1, linetype = "dotted") +
  scale_color_viridis() +
  labs(color = "Absolute Prediction Error")

As we can see, even with a very simple combination of the four core graph operations we can recover a lot of information about the true response function.