Gen Bayes Code Examples: COBAL-EBEB, 2024

Nick Polson and Vadim Sokolov

Plan for Today

  • Generative model using a quantile neural network
  • Hierarchical Bayesian model for classification
  • We will use PyTorch and JAX

Quantile Neural Network for Synthetic Data

Code
import numpy as np
import torch
import matplotlib.pyplot as plt
import scipy.stats 
# Synthetic heteroscedastic data: y = sin(pi x) / (pi x) + eps, where the
# noise standard deviation exp(1 - x) / 10 grows as x decreases.
n = 10000
np.random.seed(8)
x = np.random.uniform(-1, 1, n)
x = np.sort(x)  # sorted so the fitted curves can be drawn with plt.plot
# np.random.normal's second argument is the standard deviation (scale).
eps = np.random.normal(0, np.exp(1 - x) / 10)
# np.sinc(x) computes sin(pi x) / (pi x) exactly as before for x != 0, but
# returns 1 (instead of nan from 0/0) at x == 0.
mu = np.sinc(x)
y = mu + eps
def truef(tau, xs=None):
    """Return the true ``tau``-quantile of y given x for the synthetic data.

    The data are generated as y = sin(pi x) / (pi x) + eps with
    eps ~ Normal(0, sd) and sd = exp(1 - x) / 10 (a standard deviation), so
    the conditional quantile is ``sinc(x) + sd * Phi^{-1}(tau)``.

    Args:
        tau: quantile level(s) in (0, 1).
        xs: inputs at which to evaluate. Defaults to the module-level ``x``
            for backward compatibility. Accepts anything ``torch.as_tensor``
            accepts (tensor, NumPy array, list).

    Returns:
        A tensor broadcast from ``xs`` and ``tau``; with scalar ``tau`` it has
        the shape of ``xs``.
    """
    if xs is None:
        xs = x
    xs = torch.as_tensor(xs)
    if not torch.is_floating_point(xs):
        xs = xs.to(torch.get_default_dtype())
    # The generator draws eps with *standard deviation* exp(1 - x) / 10, so
    # the quantile offset scales with sd itself, not with its square root.
    sd = torch.exp(1 - xs) / 10
    z = torch.as_tensor(scipy.stats.norm.ppf(tau), dtype=xs.dtype)
    # torch.sinc(x) == sin(pi x) / (pi x), and equals 1 at x == 0.
    return torch.sinc(xs) + sd * z

Convert Data to PyTorch Tensors

Code
# Torch copies of the data: x becomes an (n, 1) feature column for nn.Linear,
# y stays a flat length-n vector.
x = torch.as_tensor(x, dtype=torch.float32).reshape(-1, 1)
y = torch.as_tensor(y, dtype=torch.float32)
plt.scatter(x, y, s=1)

Define the Model

Code
import torch.nn as nn
class QuantNet(nn.Module):
    """Network mapping an input x and a quantile level tau to two outputs.

    ``tau`` is embedded with the cosine basis ``cos(pi * i * tau)`` for
    ``i = 0 .. nh-1`` followed by a linear layer (function phi). ``x`` is
    embedded by a linear layer (function psi). Their elementwise product goes
    through a small MLP (function g) with a 2-unit output. The ``train``
    function in this file fits column 0 with squared error and column 1 with
    the tau pinball loss.

    Args:
        xsz: number of input features.
        nh: number of cosine basis functions used to embed ``tau``.
        hsz: width of the x and tau embeddings.
        hsz1: width of the hidden layers of g.
    """

    def __init__(self, xsz=1, nh=64, hsz=32, hsz1=32):
        super().__init__()
        self.nh = nh
        self.fcx = nn.Linear(xsz, hsz)
        self.fctau = nn.Linear(self.nh, hsz)
        self.fcxtau = nn.Linear(hsz, hsz1)
        self.fcxtau1 = nn.Linear(hsz1, hsz1)
        self.fc = nn.Linear(hsz1, 2)

    def forward(self, x, tau):
        """Evaluate the network.

        Args:
            x: input tensor, ``(n, xsz)`` (or ``(xsz,)`` for a single point).
            tau: quantile level(s). One of:
                - a Python float or a 0-d tensor (one level for all rows);
                - an ``(n, 1)`` tensor;
                - an ``(n,)`` tensor giving one level per row.

        Returns:
            Tensor whose last dimension has size 2.
        """
        # Build tau and the basis on x's device/dtype so the model also works
        # off the CPU and with non-float32 inputs.
        tau = torch.as_tensor(tau, dtype=x.dtype, device=x.device)
        if tau.dim() == 1 and tau.numel() > 1:
            # One level per row: (n,) -> (n, 1) so it broadcasts against the
            # (nh,) frequencies as (n, nh).
            tau = tau.unsqueeze(-1)
        freqs = torch.arange(self.nh, dtype=x.dtype, device=x.device) * torch.pi
        tau = torch.cos(freqs * tau)
        tau = torch.relu(self.fctau(tau))  # function phi from paper
        x = torch.relu(self.fcx(x))  # function psi
        x = torch.relu(self.fcxtau(x * tau))  # first layer of function g
        x = torch.relu(self.fcxtau1(x))  # second layer of function g
        return self.fc(x)  # third layer of function g

Model Estimation (a.k.a. Training)

Code
def train(model, x, y, optimizer, epochs, log_every=2000):
    """Fit a model(x, tau) network, drawing one quantile level per step.

    Each step draws tau ~ U(0, 1) and minimizes
    ``0.1 * mean((y - f[:, 0])^2) + mean(pinball_tau(y - f[:, 1]))``.

    Args:
        model: callable ``model(x, tau)`` returning a tensor with at least
            two columns.
        x: model inputs.
        y: targets; reshaped to a column ``(n, 1)``.
        optimizer: torch optimizer over the model's parameters.
        epochs: number of full-batch gradient steps (0 is allowed).
        log_every: print the loss every ``log_every`` steps; 0 or None
            disables the periodic print.

    Returns:
        NumPy array of length ``epochs`` holding the loss at each step.
    """
    lv = np.zeros(epochs)
    target = y.view(-1, 1)  # loop-invariant reshape
    loss = None
    t = -1
    for t in range(epochs):
        tau = torch.rand(1).item()
        f = model(x, tau)
        e = target - f
        loss = 0.1 * torch.mean(torch.square(e[:, 0]))
        # Pinball (check) loss: tau * e for e >= 0, (tau - 1) * e for e < 0.
        loss = loss + torch.mean(torch.maximum(tau * e[:, 1], (tau - 1) * e[:, 1]))
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        lv[t] = loss.item()
        if log_every and t % log_every == 0:
            print(f"Epoch {t}: Loss = {loss.item():>7f}")
    if loss is not None:
        # Always report the last step; skipped when epochs == 0, where the
        # original code raised UnboundLocalError.
        print(f"Epoch {t}: Loss = {loss.item():>7f}")
    return lv

Create Model

Code
torch.manual_seed(8) # seed PyTorch's global RNG so the weight initialisation and the tau draws during training are reproducible
def init_weights(m):
    """Re-initialise the weight matrix of an ``nn.Linear`` with Xavier-uniform.

    Meant to be used as ``model.apply(init_weights)``. Modules of any other
    type, and the Linear biases, are left unchanged.
    """
    if not isinstance(m, nn.Linear):
        return
    nn.init.xavier_uniform_(m.weight)

# Build the network, re-initialise its Linear weights (apply returns the
# module itself), and attach an RMSprop optimiser.
model = QuantNet().apply(init_weights)
optimizer = torch.optim.RMSprop(params=model.parameters())
print(model)
QuantNet(
  (fcx): Linear(in_features=1, out_features=32, bias=True)
  (fctau): Linear(in_features=64, out_features=32, bias=True)
  (fcxtau): Linear(in_features=32, out_features=32, bias=True)
  (fcxtau1): Linear(in_features=32, out_features=32, bias=True)
  (fc): Linear(in_features=32, out_features=2, bias=True)
)

Run the Estimation (Training)

Code
# 1000 full-batch steps; the per-step losses are plotted as a training curve.
lv = train(model, x, y, optimizer, epochs=1000)
plt.plot(lv)
Epoch 0: Loss = 0.382400
Epoch 999: Loss = 0.135881

Make Predictions and Plot

Code
# Fitted (solid) vs. true (dashed) conditional quantiles over the data.
plt.scatter(x, y, s=1, label='Data');
quantile_specs = (
    (0.05, 'g-', 'g--', '5% Percentile', 'True 5% Percentile'),
    (0.95, 'r-', 'r--', '95% Percentile', 'True 95% Percentile'),
    (0.5, 'b-', 'b--', '50% Percentile', 'True 50% Percentile'),
)
for level, fit_fmt, true_fmt, fit_label, true_label in quantile_specs:
    fitted = model(x, level).detach().numpy()[:, 1]  # column 1 = quantile head
    plt.plot(x, fitted, fit_fmt, linewidth=2, label=fit_label);
    plt.plot(x, truef(level), true_fmt, linewidth=2, label=true_label);
plt.legend(loc='lower right')

Predict at a specific location

Code
# Sample the predictive distribution at one input by pushing many uniform
# quantile levels through the network (inverse-CDF sampling).
xn = 0.5
print(np.exp(1-xn)/10)  # true noise standard deviation at xn (see data generation)
ns = 5000
tau = torch.rand(ns)
with torch.no_grad():
    # One batched forward pass instead of ns single-row passes: row i pairs
    # the input xn with quantile level tau[i] (tau shaped (ns, 1) so it
    # broadcasts per row inside the network).
    xs = torch.full((ns, 1), xn, dtype=torch.float32)
    yhat = model(xs, tau.view(-1, 1))[:, 1].numpy().astype(np.float64)
# Print the empirical spread. In the original cell np.std(yhat) was not the
# last expression, so its value was silently discarded.
print(np.std(yhat))
plt.hist(yhat,bins=50);
0.1648721270700128

Fixed Quantile Neural Network

Code
class FixedQuantNet(nn.Module):
    """One-hidden-layer network predicting the mean and a fixed set of quantiles.

    Output column 0 is fitted to the conditional mean (weighted squared
    error). Column ``i + 1`` is fitted to the ``tau[i]``-quantile (pinball
    loss).

    Args:
        xsz: number of input features.
        tau: sequence of quantile levels in (0, 1); defaults to ``[0.5]``.
    """

    def __init__(self, xsz=1, tau=None):
        super().__init__()
        self.nh = 64  # unused by this model; kept for code that reads .nh
        hsz = 64
        self.fcx = nn.Linear(xsz, hsz)
        # Copy the levels, so neither a shared default list nor the caller's
        # list can be mutated through self.tau (the old default list was
        # shared by every instance).
        self.tau = [0.5] if tau is None else list(tau)
        self.nq = len(self.tau)
        self.fc = nn.Linear(hsz, self.nq + 1)

    def forward(self, x):
        """Return a tensor whose last dimension is ``nq + 1``: [mean, q_1, ..., q_nq]."""
        x = torch.relu(self.fcx(x))
        return self.fc(x)

    def fit(self, x, y, optimizer, epochs, log_every=2000):
        """Full-batch training on mean + pinball losses.

        Args:
            x: model inputs.
            y: targets; reshaped to a column ``(n, 1)``.
            optimizer: torch optimizer over this model's parameters.
            epochs: number of gradient steps (0 is allowed).
            log_every: print the loss every ``log_every`` steps; 0 or None
                disables the periodic print.

        Returns:
            NumPy array of length ``epochs`` holding the loss at each step.
        """
        lv = np.zeros(epochs)
        target = y.view(-1, 1)
        loss = None
        t = -1
        for t in range(epochs):
            e = target - self(x)
            loss = 0.1 * torch.mean(torch.square(e[:, 0]))
            for i, q in enumerate(self.tau):
                r = e[:, i + 1]
                # Pinball (check) loss for quantile level q.
                loss = loss + torch.mean(torch.maximum(q * r, (q - 1) * r))
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            lv[t] = loss.item()
            if log_every and t % log_every == 0:
                print(f"Epoch {t}: Loss = {loss.item():>7f}")
        if loss is not None:
            print(f"Epoch {t}: Loss = {loss.item():>7f}")
        return lv

    def train(self, x=True, y=None, optimizer=None, epochs=None, *, mode=None):
        """Fit the model, or switch train/eval mode like ``nn.Module.train``.

        ``train(x, y, optimizer, epochs)`` runs :meth:`fit`, which was the
        original behaviour of this method. The forms ``model.train()``,
        ``model.train(False)`` and ``model.train(mode=...)`` are forwarded to
        ``nn.Module.train``. ``nn.Module.eval()`` calls ``self.train(False)``,
        so without this forwarding ``eval()`` raised TypeError.
        """
        if mode is not None:
            return super().train(mode)
        if y is None:
            return super().train(x)
        return self.fit(x, y, optimizer, epochs)

Train the Model

Code
# One network for the mean plus the 5%, 50% and 95% quantiles.
modelf = FixedQuantNet(tau=[0.05, 0.5, 0.95]).apply(init_weights)
optimizer = torch.optim.RMSprop(params=modelf.parameters())
lv = modelf.train(x, y, optimizer, epochs=200)
yhat = modelf(x).detach().numpy()  # columns: mean, q05, q50, q95
plt.plot(lv)
Epoch 0: Loss = 1.013725
Epoch 199: Loss = 0.219158

Plots

Code
# yhat columns: 0 = mean, then one column per level in tau=[0.05, 0.5, 0.95].
plt.scatter(x, y, s=1, label='Data');
for col, level, fit_fmt, true_fmt, fit_label, true_label in (
    (1, 0.05, 'g-', 'g--', '5% Percentile', 'True 5% Percentile'),
    (3, 0.95, 'r-', 'r--', '95% Percentile', 'True 95% Percentile'),
    (2, 0.5, 'b-', 'b--', '50% Percentile', 'True 50% Percentile'),
):
    plt.plot(x, yhat[:, col], fit_fmt, linewidth=2, label=fit_label);
    plt.plot(x, truef(level), true_fmt, linewidth=2, label=true_label);
plt.legend(loc='lower right')

Hierarchical Bayesian Model

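We consider the following model: \[\begin{align*} \tau &\sim \mathrm{Gamma}(0.5, 0.5) \\ \lambda_d &\sim \mathrm{Gamma}(0.5, 0.5) \\ \beta_d &\sim \mathcal{N}(0, 20^2) \\ y_n &\sim \mathrm{Bernoulli}\left(\sigma\left((\tau \lambda \odot \beta)^T x_n\right)\right), \end{align*}\] where \(\tau\) is a scalar global coefficient scale, \(\lambda\) is a vector of local scales, \(\beta\) is the vector of unscaled coefficients, and \(\sigma\) is the logistic function. The Gamma distributions are parameterized by shape and rate, and \(\mathcal{N}(0, 20^2)\) denotes a normal distribution with standard deviation 20.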

Code
import tensorflow_probability.substrates.jax as tfp
import numpy as np
import jax.numpy as jnp
tfd = tfp.distributions
import jax
import matplotlib.pyplot as plt

Apply to Iris Data

Code
import pandas as pd
iris = pd.read_csv("data/iris.csv")
print(iris.shape)
print(iris.head())
# Binary response (True for Setosa) as a plain NumPy array, so no pandas
# objects leak into the JAX/TFP code below.
y = (iris['variety'] == "Setosa").to_numpy()
x = iris["sepal.length"].values
plt.scatter(x, y)
# Design matrix [1, sepal.length]; the intercept column is sized from the
# data rather than a hard-coded row count.
x = np.c_[np.ones(len(x)), x]
print(x.shape)
(150, 5)
   sepal.length  sepal.width  petal.length  petal.width variety
0           5.1          3.5           1.4          0.2  Setosa
1           4.9          3.0           1.4          0.2  Setosa
2           4.7          3.2           1.3          0.2  Setosa
3           4.6          3.1           1.5          0.2  Setosa
4           5.0          3.6           1.4          0.2  Setosa
(150, 2)

Log Density Function

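For Hamiltonian Monte Carlo (HMC), we only need to evaluate the joint log-density pointwise (up to an additive constant). The posterior's normalizing constant is never required.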

Code
def joint_log_prob(x, y, tau, lamb, beta):
    """Joint log density of (tau, lamb, beta, y) given the design matrix x.

    Priors:
        tau ~ Gamma(0.5, 0.5)
        lamb_d ~ Gamma(0.5, 0.5)
        beta_d ~ Normal(0, 20)  (second argument passed as tfd.Normal's scale)

    Likelihood:
        y_n ~ Bernoulli(logits = x_n^T (tau * lamb * beta))

    Returns:
        The scalar sum of the prior and likelihood log densities.
    """
    gamma_prior = tfd.Gamma(0.5, 0.5)
    log_prior = (gamma_prior.log_prob(tau)
                 + gamma_prior.log_prob(lamb).sum()
                 + tfd.Normal(0., 20).log_prob(beta).sum())
    effective_coef = tau * lamb * beta
    log_lik = tfd.Bernoulli(logits=x @ effective_coef).log_prob(y).sum()
    return log_prior + log_lik

# Arbitrary values in [0, 1) for tau, lambda and beta, used to sanity-check
# that the joint log density evaluates. lambda and beta get one entry per
# column of the design matrix (consistent with unconstrained_joint_log_prob)
# instead of a hard-coded 2.
ncoef = x.shape[-1]
tau = np.random.rand(1)[0]
lamb = np.random.rand(ncoef)
beta = np.random.rand(ncoef)
joint_log_prob(x, y, tau, lamb, beta)
Array(-126.7892, dtype=float32)

Change of Variables

\[ z\triangleq T^{-1}(\theta),\qquad \pi(z) = \pi(\theta) \left| \frac{\partial T}{\partial z} (z) \right|, \]

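Taking the logarithm of both sides gives \[ \log \pi(z) = \log \pi(\theta) + \log \left| \frac{\partial T }{\partial z}(z) \right|, \qquad \theta = T(z). \] We use \(T(z)=e^z\), which maps the real line onto the positive reals. Since \(\frac{\partial T}{\partial z}(z) = e^z\), we get \(\log|\frac{\partial T}{\partial z}(z)| = z\).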

Change of Variables in Code

Code
def unconstrained_joint_log_prob(x, y, theta):
    """Joint log density expressed in unconstrained coordinates, for HMC.

    ``theta`` is the concatenation
    [log tau (1 entry), log lambda (x.shape[-1] entries), beta (the rest)].
    tau and lambda are recovered with exp. The log-Jacobian of that map,
    log|d exp(z)/dz| = z, is added for every transformed entry.
    """
    n_coef = x.shape[-1]
    log_tau, log_lamb, beta = jnp.split(theta, [1, 1 + n_coef])
    log_tau = log_tau.reshape([])  # length-1 slice -> scalar
    log_det_jacobian = log_tau + log_lamb.sum()
    return joint_log_prob(x, y, jnp.exp(log_tau), jnp.exp(log_lamb), beta) + log_det_jacobian

Let’s check our function

Code
from functools import partial

# Bind the data so the sampler sees a function of theta only.
target_log_prob = partial(unconstrained_joint_log_prob, x, y)
# Starting point: the values drawn above, placed directly in the
# unconstrained coordinates [log tau, log lambda, beta].
theta = np.concatenate([[tau], lamb, beta])
target_log_prob(theta)
Array(-530.62317, dtype=float32)

Automatic Differentiation

Code
# A single call returns both the log density and its gradient w.r.t. theta.
target_log_prob_and_grad = jax.value_and_grad(target_log_prob)
tlp, tlp_grad = target_log_prob_and_grad(theta)
print(tlp)
tlp_grad  # shown as the cell output
-530.62317
Array([ -511.1795  ,   -95.032646,  -416.80698 ,  -185.95103 ,
       -1782.1393  ], dtype=float32)

Hamiltonian Monte Carlo

Code
def leapfrog_step(target_log_prob_and_grad, step_size, i, leapfrog_state):
    """Advance Hamiltonian dynamics by one leapfrog (kick-drift-kick) step.

    Once the first two arguments are bound, the signature fits a
    ``jax.lax.fori_loop`` body; the loop index ``i`` is not used.

    Args:
        target_log_prob_and_grad: function z -> (log p(z), grad log p(z)).
        step_size: integration step size.
        i: loop counter (ignored).
        leapfrog_state: tuple (position z, momentum m, log p(z), grad log p(z)).

    Returns:
        The updated (z, m, log p(z), grad log p(z)) tuple.
    """
    position, momentum, _, grad = leapfrog_state
    half_step = 0.5 * step_size
    momentum = momentum + half_step * grad          # half kick
    position = position + step_size * momentum      # full drift
    new_tlp, new_grad = target_log_prob_and_grad(position)
    momentum = momentum + half_step * new_grad      # half kick
    return position, momentum, new_tlp, new_grad

def hmc_step(target_log_prob_and_grad, num_leapfrog_steps, step_size, z, seed):
    """One Metropolis-adjusted HMC transition.

    Steps:
    1. Draw a standard-normal momentum.
    2. Integrate the dynamics with ``num_leapfrog_steps`` leapfrog steps.
    3. Accept the end point when log(U) < H_old - H_new, with
       H = 0.5 * |m|^2 - log p(z) and U ~ Uniform(0, 1).

    Returns:
        ``(z_next, info)``, where info is a dict with keys ``"z"``,
        ``"is_accepted"`` and ``"log_accept_ratio"``. This (carry, output)
        pair is the shape ``jax.lax.scan`` expects from its body.
    """
    momentum_key, accept_key = jax.random.split(seed)
    tlp, tlp_grad = target_log_prob_and_grad(z)
    momentum = jax.random.normal(momentum_key, z.shape)

    def hamiltonian(log_prob, mom):
        return 0.5 * jnp.square(mom).sum() - log_prob

    integrate = partial(leapfrog_step, target_log_prob_and_grad, step_size)
    proposal, final_momentum, proposal_tlp, _ = jax.lax.fori_loop(
        0, num_leapfrog_steps, integrate, (z, momentum, tlp, tlp_grad))

    log_accept_ratio = hamiltonian(tlp, momentum) - hamiltonian(proposal_tlp, final_momentum)
    log_u = jnp.log(jax.random.uniform(accept_key, []))
    is_accepted = log_u < log_accept_ratio
    next_z = jnp.where(is_accepted, proposal, z)  # keep the old state on rejection
    info = {"z": next_z,
            "is_accepted": is_accepted,
            "log_accept_ratio": log_accept_ratio}
    return next_z, info

def hmc(target_log_prob_and_grad, num_leapfrog_steps, step_size, num_steps, z,
        seed):
    """Run ``num_steps`` HMC transitions starting from ``z``.

    Returns:
        The per-step dicts from ``hmc_step``, stacked by ``jax.lax.scan``
        along a new leading axis of length ``num_steps``. For example,
        ``out["z"]`` has shape ``(num_steps, dim)``.
    """
    step_keys = jax.random.split(seed, num_steps)  # one PRNG key per transition
    transition = partial(hmc_step, target_log_prob_and_grad,
                         num_leapfrog_steps, step_size)
    _final_z, trace = jax.lax.scan(transition, z, step_keys)
    return trace

def scan(f, state, xs):
    """Pure-Python reference implementation of ``jax.lax.scan``.

    Runs ``state, y = f(state, x)`` for each ``x`` in ``xs`` and stacks the
    per-step outputs along a new leading axis. As with ``lax.scan``, each
    ``y`` may be an array or a pytree (e.g. the dict returned by
    ``hmc_step``); every leaf is stacked separately.

    Returns:
        ``(final_state, stacked_outputs)``.

    Raises:
        ValueError: if ``xs`` is empty, because there is no output from which
            to infer the stacked structure.
    """
    outputs = []
    for x in xs:
        state, y = f(state, x)
        outputs.append(y)
    if not outputs:
        raise ValueError("scan requires at least one element in xs")
    # Stack leaf-by-leaf so dict/tuple outputs work, not just plain arrays.
    stacked = jax.tree_util.tree_map(lambda *leaves: jnp.stack(leaves), *outputs)
    return state, stacked

HMC

Code
from jax import random

# Sampler settings: each proposal takes 30 leapfrog steps of size 0.008.
num_leapfrog_steps = 30
step_size = 0.008
num_samples = 10000
seed = random.PRNGKey(92)
hmc_output = hmc(target_log_prob_and_grad, num_leapfrog_steps, step_size,
                 num_samples, z=theta, seed=seed)

Inspect the results

Code
# Split the draws back into [log tau, log lambda, beta] (unconstrained space).
ndims = x.shape[-1]
thetap = hmc_output['z']
taup, lambp, betap = jnp.split(thetap, [1, 1 + ndims], axis=1)
skip = 2000  # discard the first draws as burn-in
# The likelihood uses the effective coefficients tau * lambda * beta
# (joint_log_prob computes logits = x @ (tau * lamb * beta)). taup and lambp
# are on the log scale, so exponentiate before forming the regression
# slope and intercept; beta alone is only the unscaled coefficient.
coefp = jnp.exp(taup) * jnp.exp(lambp) * betap
slope = coefp[skip:, 1]       # column 1 of x is sepal.length
intercept = coefp[skip:, 0]   # column 0 of x is the constant 1
plt.hist(slope, bins=50);
plt.hist(intercept, bins=50);

Inspect the results

Code
# Posterior 5%, 95% and 50% quantiles: slope first, then intercept.
for draws in (slope, intercept):
    print(np.quantile(draws, [0.05, 0.95, 0.5]))
[-14.40216732  -1.37361898  -5.5752933 ]
[17.57117882 38.81280651 28.22325706]

Summary

  • We have seen how to implement a quantile neural network in PyTorch
  • We used two models: one with fixed quantiles and one with quantile as input
  • The fixed-quantile model is easier to implement and train
  • The quantile-as-input model is more flexible and can be used for more complex problems
  • Both models can be used for quantile regression
  • Both models perform well on synthetic data