The Wasserstein prior for the normal linear regression model

Consider the normal linear regression model, \[ y_i = {\bf x}_i^{\top} \beta + \epsilon_i, \,\,\,\,\, i = 1,\dots,n, \] where \({\bf x}_i\in{\mathbb R}^p\) is a vector of covariates, \(\beta\in{\mathbb R}^p\) is a vector of regression coefficients, and \(\epsilon_i \stackrel{i.i.d.}{\sim} N(0,\sigma^2)\) denote the errors. Let \({\bf X} = ({\bf x}_1,\dots,{\bf x}_n)^{\top}\) denote the design matrix and \({\bf y} = (y_1,\dots,y_n)^{\top}\) the vector of response variables.

Li and Rubio (2022) calculated the Wasserstein prior for \((\beta,\sigma)\) in the normal linear regression model, \[ \pi_W(\beta,\sigma) \propto 1. \] They showed that the corresponding posterior distribution is proper if \(n>p+1\).
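The condition \(n>p+1\) can be motivated with a short sketch. It assumes that \({\bf X}\) has full column rank \(p\), and it introduces two symbols that are not in the original text: \(\hat\beta = ({\bf X}^{\top}{\bf X})^{-1}{\bf X}^{\top}{\bf y}\), the least-squares estimate, and \(RSS = \|{\bf y}-{\bf X}\hat\beta\|^2 > 0\), the residual sum of squares. Under the flat prior the posterior is proportional to the likelihood, \[ \pi(\beta,\sigma \mid {\bf y}) \propto \sigma^{-n} \exp\left\{ -\frac{1}{2\sigma^2}\left[ RSS + (\beta-\hat\beta)^{\top}{\bf X}^{\top}{\bf X}(\beta-\hat\beta) \right] \right\}. \] Integrating out \(\beta\) (a Gaussian integral that contributes a factor \(\sigma^{p}\)) gives \[ \pi(\sigma \mid {\bf y}) \propto \sigma^{-(n-p)} \exp\left\{ -\frac{RSS}{2\sigma^2} \right\}, \] which is integrable over \((0,\infty)\) exactly when \(n-p>1\). Equivalently, \(\sigma^2 \mid {\bf y}\) follows an inverse-gamma distribution with shape \((n-p-1)/2\) and scale \(RSS/2\), and \(\beta \mid \sigma, {\bf y} \sim N_p(\hat\beta, \sigma^2({\bf X}^{\top}{\bf X})^{-1})\). When the posterior is sampled on the unconstrained scale \(\tau = \log\sigma\), the change of variables contributes the Jacobian \(e^{\tau}\), so that \[ \log \pi(\beta,\tau \mid {\bf y}) = \mathrm{const} - n\tau - \frac{1}{2e^{2\tau}}\|{\bf y}-{\bf X}\beta\|^2 + \tau. \]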

The following R code shows an example of the use of the Wasserstein prior in a linear regression model.

Marketing data

# Required packages
library(tidyverse)
## ── Attaching packages ─────────────────────────────────────── tidyverse 1.3.0 ──
## ✓ ggplot2 3.3.3     ✓ purrr   0.3.4
## ✓ tibble  3.1.6     ✓ dplyr   1.0.7
## ✓ tidyr   1.1.4     ✓ stringr 1.4.0
## ✓ readr   1.4.0     ✓ forcats 0.5.0
## ── Conflicts ────────────────────────────────────────── tidyverse_conflicts() ──
## x dplyr::filter() masks stats::filter()
## x dplyr::lag()    masks stats::lag()
library(Rtwalk)
## Package: Rtwalk, Version: 1.8.0
## The R Implementation of 't-walk' MCMC Sampler
## http://www.cimat.mx/~jac/twalk/
## For citations, please use:
## Christen JA and Fox C (2010). A general purpose sampling algorithm for
## continuous distributions (the t-walk). Bayesian Analysis, 5(2),
## pp. 263-282. <URL:
## http://ba.stat.cmu.edu/journal/2010/vol05/issue02/christen.pdf>.
library(TeachingDemos)

# Marketing data
data("marketing", package = "datarium")
head(marketing, 4)
##   youtube facebook newspaper sales
## 1  276.12    45.36     83.04 26.52
## 2   53.40    47.16     54.12 12.48
## 3   20.64    55.08     83.16 11.16
## 4  181.80    49.56     70.20 22.20
model <- lm(sales ~ youtube + facebook + newspaper, data = marketing)
summary(model)
## 
## Call:
## lm(formula = sales ~ youtube + facebook + newspaper, data = marketing)
## 
## Residuals:
##      Min       1Q   Median       3Q      Max 
## -10.5932  -1.0690   0.2902   1.4272   3.3951 
## 
## Coefficients:
##              Estimate Std. Error t value Pr(>|t|)    
## (Intercept)  3.526667   0.374290   9.422   <2e-16 ***
## youtube      0.045765   0.001395  32.809   <2e-16 ***
## facebook     0.188530   0.008611  21.893   <2e-16 ***
## newspaper   -0.001037   0.005871  -0.177     0.86    
## ---
## Signif. codes:  0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1
## 
## Residual standard error: 2.023 on 196 degrees of freedom
## Multiple R-squared:  0.8972, Adjusted R-squared:  0.8956 
## F-statistic: 570.3 on 3 and 196 DF,  p-value: < 2.2e-16
###################################################################################
# Log posterior with Wasserstein prior (flat prior)
###################################################################################
y <- marketing$sales
X <- cbind(1,marketing$youtube, marketing$facebook, marketing$newspaper)

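# Negative log-posterior in the parameterisation (beta, log(sigma)).
# The flat prior on (beta, sigma) becomes proportional to sigma on the
# log(sigma) scale (Jacobian of sigma = exp(.)), hence the extra term.
lpw <- function(par){
  beta <- par[1:ncol(X)]; sigma <- exp(par[ncol(X)+1])
  return(-sum(dnorm(y - X%*%beta, 0, sigma, log = TRUE)) - par[ncol(X)+1])
}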

# Number of iterations 
NMH = 110000

# Support function 
Support <- function(x) {TRUE}

# Function to generate the initial points in the sampler
# (the last coordinate is log(sigma), so the MLE of sigma is log-transformed)
X0 <- function() { c(coef(model), log(summary(model)$sigma)) + runif(ncol(X)+1, -0.1, 0.1) }

# Posterior samples
set.seed(1234)
outw <- Runtwalk( dim=5,  Tr=NMH,  Obj=lpw, Supp=Support, x0=X0(), xp0=X0(),PlotLogPost = FALSE) 
## This is the twalk for R.
## Evaluating the objective density at initial values.
## Opening 12 X 110001 matrix to save output.  Sampling (no graphics mode).
# Burn-in and thinning
burn = 10000
thin = 100
ind = seq(burn,NMH,thin)

betap1 = outw$output[ , 1][ind]
betap2 = outw$output[ , 2][ind]
betap3 = outw$output[ , 3][ind]
betap4 = outw$output[ , 4][ind]
sigmap = exp(outw$output[ , 5][ind])

# Some histograms and summaries
# The red line shows the corresponding MLE
hist(betap1, probability = TRUE)
abline(v=coef(model)[1],col="red",lwd=2)
box()

hist(betap2, probability = TRUE)
abline(v=coef(model)[2],col="red",lwd=2)
box()

hist(betap3, probability = TRUE)
abline(v=coef(model)[3],col="red",lwd=2)
box()

hist(betap4, probability = TRUE)
abline(v=coef(model)[4],col="red",lwd=2)
box()

hist(sigmap, probability = TRUE)
abline(v=summary(model)$sigma,col="red",lwd=2)
box()

# Posterior means
apply(cbind(betap1,betap2,betap3,betap4,sigmap),2,mean)
##       betap1       betap2       betap3       betap4       sigmap 
##  3.500562886  0.045839778  0.188988478 -0.001057659  2.037425894
# Posterior medians
apply(cbind(betap1,betap2,betap3,betap4,sigmap),2,median)
##       betap1       betap2       betap3       betap4       sigmap 
##  3.493727397  0.045852163  0.189223885 -0.001189728  2.037161717
# Highest posterior density (HPD) intervals (emp.hpd default level: 95%)
apply(cbind(betap1,betap2,betap3,betap4,sigmap),2,emp.hpd)
##        betap1     betap2    betap3      betap4   sigmap
## [1,] 2.757585 0.04333058 0.1717954 -0.01406317 1.844173
## [2,] 4.222351 0.04879899 0.2064420  0.01020764 2.240368
# Traceplots
par(mfrow = c(3,2))
plot(betap1, type = "l", lwd = 2)
plot(betap2, type = "l", lwd = 2)
plot(betap3, type = "l", lwd = 2)
plot(betap4, type = "l", lwd = 2)
plot(sigmap, type = "l", lwd = 2)
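
Under the flat prior the posterior can also be sampled directly, without MCMC, which gives a cross-check on the t-walk output. The sketch below is not part of the original analysis. It uses simulated data, and all names (`X_sim`, `y_sim`, `beta_true`, `M`) and numerical values are illustrative assumptions. It relies on the standard decomposition in which \(\sigma^2 \mid {\bf y}\) is inverse-gamma with shape \((n-p-1)/2\) and scale \(RSS/2\), and \(\beta \mid \sigma, {\bf y}\) is normal with mean \(\hat\beta\) and covariance \(\sigma^2({\bf X}^{\top}{\bf X})^{-1}\).

```r
# Direct sampling from the posterior under pi(beta, sigma) proportional to 1.
# Simulated data; all values below are illustrative assumptions.
set.seed(1)
n <- 200; p <- 3
X_sim <- cbind(1, matrix(rnorm(n * (p - 1)), n, p - 1))  # design with intercept
beta_true <- c(3, 0.5, -1)                               # hypothetical coefficients
y_sim <- as.vector(X_sim %*% beta_true + rnorm(n, sd = 2))

# Least-squares quantities
fit      <- lm.fit(X_sim, y_sim)
beta_hat <- as.vector(fit$coefficients)
rss      <- sum(fit$residuals^2)
XtX_inv  <- chol2inv(chol(crossprod(X_sim)))

# Step 1: sigma^2 | y ~ Inverse-Gamma(shape = (n - p - 1)/2, scale = rss/2),
# which requires n > p + 1
M <- 5000
sigma2 <- 1 / rgamma(M, shape = (n - p - 1) / 2, rate = rss / 2)

# Step 2: beta | sigma, y ~ N(beta_hat, sigma^2 * (X'X)^{-1})
L  <- t(chol(XtX_inv))                         # lower factor: L %*% t(L) = XtX_inv
Z  <- matrix(rnorm(M * p), p, M)               # standard normal draws, one column per sample
LZ <- sweep(L %*% Z, 2, sqrt(sigma2), `*`)     # scale column j by sigma_j
beta_draws <- t(beta_hat + LZ)                 # M x p matrix of posterior draws

# Posterior means of the coefficients and of sigma
colMeans(beta_draws)
mean(sqrt(sigma2))
```

Replacing `X_sim` and `y_sim` with the `X` and `y` defined for the marketing data should give summaries close to those from the t-walk chain, up to Monte Carlo error.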

Li, W., and F. J. Rubio. 2022. “On a Prior Based on the Wasserstein Information Matrix.” Submitted.