Comparing ordinary least squares and gradient descent for linear model fitting
Canonical multivariate ordinary least squares (OLS) has dimensions as columns and samples as rows. When the response is also multivariate (i.e., when the model produces a matrix rather than a vector of numbers), we assume that the dimensions of the response are independent of each other (if we want to, we can estimate the variance-covariance matrix that relates the columns of the response matrix, but that’s not a byproduct of OLS). Let’s simulate some data to enforce a ground truth and verify that gradient descent arrives there.
Simulate gene expression
Since we are dealing with counts, we will model them with Poisson emissions. We will start by simulating a matrix of mRNA counts for 1000 cells for 5 genes. Since we are hoping to fit models for proteins related to each gene later on, we will specify the “expression groups” of each cell, then generate counts based on those groupings.
# let's set this up as a Poisson generating model, with one parameter
# that one parameter ("lambda") controls both the mean and the variance
intercept <- c(gene1=0.1, gene2=0, gene3=0.1, gene4=0, gene5=0.1)
specific <- c(gene1=4.9, gene2=1, gene3=1.9, gene4=3, gene5=9.9)

# the "high expressor" cells will have counts with a mean of intercept + specific
# the "low expressor" cells will have counts with a mean of just the intercept
coefs <- cbind(intercept, specific)
exprgroups <- rbind(low=c(1, 0), high=c(1, 1))
(lambda <- exprgroups %*% t(coefs))
# so, if a cell is a "high expressor," its counts for gene5 will have a mean specified by ["high", "gene5"], i.e. 10
# and if a cell is a "low expressor", its counts for gene1 will have a mean specified by ["low", "gene1"], i.e. 0.1
# notice that the lambda parameter is completely specified by an intercept term and a state-specific weight term.

# We will simulate this many cells:
ncells <- 1000

# And we will order them sequentially for reasons that will soon be obvious:
cells <- seq_len(ncells)

# The above allows us to simulate a multifactorial generating model by design:
highexp <- rbind(gene1 = as.integer(cells <= 500),    # first 500 cells
                 gene2 = as.integer(cells <= ncells), # i.e., always true
                 gene3 = as.integer(cells > 400),     # last 600 cells
                 gene4 = as.integer(cells <= ncells), # i.e., always true
                 gene5 = as.integer(cells > 300))     # last 700 cells

# Bolt on the intercept term for every cell because that's what we have to do
design_ary <- array(1, dim=list(nrow(highexp), ncol(highexp), 2))
design_ary[, , 2] <- highexp
dimnames(design_ary) <- list(rownames(highexp), NULL, c("intercept", "highexp"))

# Get the coefficients for each gene for each cell:
genes <- rownames(highexp)
names(genes) <- genes
lambdas <- do.call(cbind, lapply(genes, function(x) design_ary[x, , ] %*% coefs[x, ]))
colnames(lambdas) <- genes

# Now we multiply through to get simulated counts.
counts <- do.call(cbind, lapply(genes, function(x) rpois(n=nrow(lambdas), lambda=lambdas[, x])))
Check for normality
These counts look like so (divided into “high expressor” and “low expressor” categories):
The Poisson distribution is a great way to simulate discrete events (counts of molecules, horse kicks, etc.), but it is most certainly not normally distributed: the variance depends upon the mean. In fact, a key feature of the Poisson distribution is that the mean equals the variance. That breaks a lot of assumptions baked into linear model fitting, particularly homoskedasticity (i.e., “the variance does not depend upon the mean”). This is super obvious for the low-expressing cells in the figure above. Fortunately, there is a relatively easy solution to this problem.
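As a quick sanity check of that claim, we can confirm that the sample variance of Poisson draws tracks the sample mean (a small sketch of our own, reusing the `lambda` matrix simulated above):

```r
# for each simulated rate, the sample variance of Poisson draws tracks the mean
sapply(as.vector(lambda), function(l) {
  draws <- rpois(1e5, lambda = l)
  c(lambda = l, mean = mean(draws), variance = var(draws))
})
```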
Log-normalize to decouple mean and variance for Poisson random variables
Normalize the counts from each cell, with a target of 10K fragments:
X <- round(sweep(counts, 1, rowSums(counts) + 1, `/`) * 10000)
Take logs, as the raw counts are Poisson and have a mean-variance dependency. log1p adds a pseudo-count of 1 so that the result is always non-negative.
X <- log1p(X)
Plot it again. It looks a bit goofy but at least it resembles a bell curve within each group.
It’s hardly perfect, but the groups at least have a more symmetric distribution, regardless of their means.
Add an intercept column.
This is the easiest way to set the mean in model fitting. We did it above for the Poisson model of mRNA counts, too.
X <- cbind(intercept=1, X)
Simulate protein abundance
Let’s assume protein abundance follows linear combinations of gene abundance. Since we are enforcing a particular generating model, we can see how well we recover it for each protein by ordinary least squares or by gradient descent. In some cases a protein is constitutively expressed at a low level, in which case its intercept coefficient is nonzero. We specify a coefficient for each gene:
B <- cbind(protein1=c(intercept=0.001, gene1=0.25, gene2=0, gene3=0.5,  gene4=0, gene5=0),
           protein2=c(intercept=0,     gene1=0.5,  gene2=0, gene3=0,    gene4=0, gene5=0.25),
           protein3=c(intercept=0,     gene1=0,    gene2=0, gene3=0.25, gene4=0, gene5=0.5))
Observe that for each protein $j$, the (log-scale) abundance is a linear combination of the columns of $X$: $y_j = X \beta_j$, where $\beta_j$ is the $j$th column of $B$.
This is exactly equivalent to saying $Y = XB$, and thus we will later estimate $B$ as $\hat{B} = (X^\top X)^{-1} X^\top Y$.
Now we have what we need to generate a matrix of protein abundance. We assume ADT counts are Poisson, so we will reverse our variance-stabilizing transformation to yield Y0, the “raw” counts.
Y0 <- round(exp(X %*% B))
For regression, we will want to use the variance stabilized version, of course.
Y <- log1p(Y0)
How does this look?
heatmap(t(Y))
Fit coefficients by ordinary least squares
The Gauss-Markov theorem establishes that the best linear unbiased estimator for the coefficients of a linear system is the ordinary least squares (or OLS for short) solution. Specifically, if
$y = X\beta + \epsilon$, with $E[\epsilon] = 0$ and $\mathrm{Var}(\epsilon) = \sigma^2 I$ (uncorrelated, homoskedastic errors),
then for an N by K matrix $X$ and K by 1 matrix $\beta$,
$\hat{\beta} = (X^\top X)^{-1} X^\top y$ minimizes the sum of squared errors $\lVert y - X\beta \rVert^2$,
and $\hat{\beta}$ is the best linear unbiased estimator of $\beta$.
(It turns out that we can use the QR decomposition for a more stable inverse.)
If $Y$ is an N by M matrix, we must solve for each of its M columns separately, populating a K by M coefficient matrix $\hat{B}$ one column at a time. Nevertheless, as long as the mild Gauss-Markov conditions are satisfied, $\hat{B}$ will minimize the mean squared error, so it should also be the solution found by gradient descent if the MSE is the cost function.
A visual explainer of OLS can be found at setosa.io. The OLS estimate is also the maximum likelihood estimator (MLE) for a linear regression, and a useful benchmark all around. Let’s make it a function, and let’s use the more numerically stable QR decomposition to handle inversion.
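A minimal sketch of such a function (the name `ols_qr` is our own placeholder):

```r
# OLS via the QR decomposition: with X = QR, the normal equations
# t(X) %*% X %*% b == t(X) %*% y reduce to the triangular system
# R %*% b == t(Q) %*% y, which qr.coef() solves without forming an explicit inverse
ols_qr <- function(y, X) {
  qr.coef(qr(X), y)
}
```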
Linear regression can be solved exactly via the QR decomposition, so typically we would not bother with gradient descent. However, the minute you step outside the classical assumptions (homoskedastic, uncorrelated errors and the like), this starts to break down. One way or another we need to start from a “guess” or initial value and proceed towards a minimum-error estimate of the true parameters. The cost function can be the negative log of the joint probability of seeing the data given the proposed model parameters (the negative log-likelihood, whose minimizer is the maximum likelihood estimate), or it can be the error sum of squares, $\sum_{i=1}^{N} (y_i - \hat{y}_i)^2$.
This process requires a step size $\alpha$ and an update informed by the gradient of the cost. More generally, the vector of partial derivatives with respect to each coefficient forms the gradient $\nabla J(\beta)$, where $J$ is the MSE. The goal is to proceed as quickly as possible in the direction that minimizes the cost. For OLS, the gradient is $\nabla J(\beta) \propto \tfrac{1}{N} X^\top (X\beta - y)$, and we update $\beta$ so that $\beta_{t+1} = \beta_t - \alpha \nabla J(\beta_t)$.
Since the OLS estimate is also the maximum likelihood estimator (MLE) for a linear regression, we can benchmark our gradient descent function against the closed-form fit and against what we’d expect from glm(y ~ 0 + X). We compose gradient descent from smaller functions.
# estimate of y from X and b
yhat <- function(X, b) X %*% b

# error of the estimate from X and b
eps <- function(y, X, b) yhat(X, b) - y

# mean squared error (a function of eps)
mse <- function(y1, y0) mean((y1 - y0)**2)

# root mean squared error for the estimate
rmsb <- function(b1, b0) sqrt(mse(b1, b0))

# update vector for a given starting guess `b`
delta <- function(y, X, b) t(t(X) %*% eps(y, X, b))[1, ] / length(y)

# update function: step along this vector with stride length alpha
update <- function(y, X, b, alpha) return(b - (delta(y, X, b) * alpha))

# log an iteration with its update
logupdate <- function(y, X, b, b0) c(b, mse=mse(yhat(X, b), y), rmsb=rmsb(b, b0))

# gradient descent function, applying all of the above bits
# note that we are not scaling the features, which would make it faster
ols_gd <- function(y, X, alpha=1e-2, iter=1000, tol=1e-5, verbose=FALSE) {
  b0 <- rep(0, ncol(X))
  for (i in seq_len(iter)) {
    b <- update(y, X, b0, alpha)
    results <- logupdate(y, X, b, b0)
    names(results)[seq_along(b)] <- names(b)
    if (i == 1) updates <- t(as.matrix(results, nrow=1))
    if (i > 1) updates <- rbind(updates, logupdate(y, X, b, b0))
    if (verbose) message("Iteration ", i, ", ",
                         "MSE: ", updates[i, "mse"], ", ",
                         "RMS(b): ", updates[i, "rmsb"])
    if (updates[i, "rmsb"] < tol & verbose) message("Converged after ", i, " iterations.")
    if (updates[i, "rmsb"] < tol) break()
    b0 <- b
  }
  updates <- data.frame(updates)
  updates$iter <- seq_len(nrow(updates))
  class(updates) <- "updates"
  attr(b, "updates") <- updates
  class(b) <- "results"
  return(b)
}

# test it
invisible(apply(Y, 2, ols_gd, X))
It will also be handy to automate plotting for the resulting update logs.
# add plot function for updates
plot.updates <- function(upd, ...) {
  # note: col=j relies on the loop index `j` in the calling environment (see below)
  with(upd, plot(iter, mse, type="b", col=j, lwd=3, ...))
}

# wrap it for results
plot.results <- function(res, ...) plot(attr(res, "updates"), ...)
We can compare this with the brute-force result of the stabilized OLS solution, which uses the QR decomposition to avoid singularities when inverting `t(X) %*% X`, and see whether gradient descent lands on about the same answer.
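Assuming the `ols_qr` sketch above, the QR-based coefficients used in the comparison below can be populated per protein:

```r
# closed-form (QR-based) coefficients, one column per protein
B_qr <- apply(Y, 2, ols_qr, X)
rownames(B_qr) <- colnames(X)
```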
# set up plots
par(mfrow=c(2, 3))

# plot convergence for each protein
for (j in seq_len(ncol(Y))) plot(ols_gd(Y[, j], X), main=colnames(Y)[j])

# run per column
B_gd <- apply(Y, 2, ols_gd, X)
rownames(B_gd) <- colnames(X)

# plot B, B_qr, and B_gd
responses <- colnames(Y)
estimates <- c("B", "B_qr", "B_gd")
for (j in seq_along(responses)) {
  protein <- responses[j]
  allests <- cbind(B, B_gd, B_qr)
  minest <- min(allests)
  maxest <- max(allests)
  plot(c(), main=protein,
       xlim=c(0, ncol(X) - 1), xlab="gene",
       ylim=c(minest, maxest), ylab="coef")
  for (k in seq_along(estimates)) {
    method <- estimates[k]
    est <- get(method)[, protein]
    points(cbind(x=seq_along(colnames(X)) - 1, y=est), pch=k, col=j)
  }
}
We do indeed. So why bother with gradient descent?
Maximum likelihood estimation
Suppose our objective is not to predict Y, but rather the probability that Y = 1. This immediately breaks the assumptions that prop up the Gauss-Markov theorem, and it comes up regularly when Y is an indicator for, e.g., membership in a group. Ordinary least squares is useless here: fitted values are not constrained to lie in [0, 1], and on the log-odds scale the residuals can be infinite! A general solution to this problem is maximum likelihood.
(Trivia: training a one-layer neural net with a sigmoid output (softmax, for more than two classes) and cross-entropy loss is equivalent to logistic regression.)
Logistic regression
The canonical link for regression of Pr(y == 1) on X is logit(p), defined as log(p / (1 - p)). Several good reasons exist for this choice (the link and its inverse are sketched in code after this list):
- logit(p), the log-odds of seeing Y = 1 given p, ranges from -Inf to +Inf
- expit(z), the inverse of logit, ranges from 0 to 1 and thus maps the linear predictor XB to p
- the first and second derivatives of the resulting likelihood are easy to calculate
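A minimal sketch of the two maps (the helper names `logit` and `expit` are ours):

```r
# the logit link and its inverse (expit, a.k.a. the logistic function)
logit <- function(p) log(p / (1 - p))    # (0, 1) -> (-Inf, +Inf)
expit <- function(z) 1 / (1 + exp(-z))   # (-Inf, +Inf) -> (0, 1)
```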
The (log) likelihood function
What is likelihood? It’s the joint probability of seeing the data we saw, given the model we propose. For a logistic regression, that model is a Bernoulli distribution.
If $y_i \sim \mathrm{Bernoulli}(p_i)$ with $p_i = \mathrm{expit}(x_i^\top \beta)$, then the likelihood is $L(\beta) = \prod_{i=1}^{N} p_i^{y_i} (1 - p_i)^{1 - y_i}$.
This function is easier to compute as the log-likelihood, $\ell(\beta) = \sum_{i=1}^{N} \left[ y_i \log p_i + (1 - y_i) \log(1 - p_i) \right]$.
For $\ell(\beta)$ with this canonical link, the vector of first partial derivatives (the score) is $U = X^\top (y - p)$, where $\eta = X\beta$ and $p = \mathrm{expit}(\eta)$, and the matrix of second partial derivatives is $-X^\top W X$, where $W = \mathrm{diag}\left(p_i (1 - p_i)\right)$. Note that the inverse of $X^\top W X$ ends up scaling updates inversely proportional to the variances $p_i (1 - p_i)$.
With these two pieces, we can start iterating towards a numeric solution.
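Concretely, each Fisher scoring step (matching the update inside the fitting function below) is

$$\beta_{t+1} \;=\; \beta_t + \left(X^\top W_t X\right)^{-1} X^\top \left(y - p_t\right),$$

where $p_t$ and $W_t$ are recomputed from $\beta_t$ at every iteration.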
Our very own binary outcome
Let’s make this concrete by generating a binary outcome from the protein abundances in our previous example. Suppose the odds of a cell being malignant, given its (log-scale) protein abundances, are $\exp(Z\gamma)$, where $Z$ is a design matrix (an intercept column plus the columns of $Y$) and $\gamma$ is a coefficient vector; equivalently, $\Pr(\text{malignant}) = \mathrm{expit}(Z\gamma)$.
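One way to simulate such an outcome is sketched below; the specific $\gamma$ values here are illustrative assumptions only (a zero intercept, a positive weight on protein2, and negative weights on protein1 and protein3), not necessarily the values used for the results discussed later.

```r
# illustrative simulation of the binary outcome; the gamma values are assumptions
gamma <- c(intercept = 0, protein1 = -1, protein2 = 2, protein3 = -1)
Z <- cbind(intercept = 1, Y)                      # design: intercept + log1p'd ADT counts
eta <- Z %*% gamma                                # linear predictor (log-odds)
y <- rbinom(n = nrow(Z), size = 1, prob = 1 / (1 + exp(-eta)))  # 1 = malignant
```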
Now let’s solve for the coefficients by a gradient-based iteration, specifically Fisher scoring. For the canonical logit link, Fisher scoring produces results identical to Newton-Raphson iteration, with the bonus feature of providing the covariance matrix of the coefficient estimates at convergence: it is the inverse of t(X) %*% W %*% X, and thus easily obtained with solve().
Logistic regression via gradient descent
lr_gd <- function(y, X, iter=1000, tol=1e-5) {
  y <- as.matrix(y)
  pi_i <- function(eta) return(exp(eta) / (1 + exp(eta)))  # expit / inverse logit

  # starting params
  b0 <- matrix(0, ncol=1, nrow=ncol(X))
  p0 <- pi_i(X %*% b0)
  W0 <- diag(as.vector(p0 * (1 - p0)))
  I0 <- t(X) %*% W0 %*% X
  U0 <- t(X) %*% (y - p0)
  I <- I0
  U <- U0

  # iterate
  for (i in seq_len(iter)) {
    b <- b0 + (solve(I) %*% U)  # Fisher scoring step: b0 + information^-1 %*% score
    if (all(abs(b - b0) < tol)) break
    p <- pi_i(X %*% b)
    W <- diag(as.vector(p * (1 - p)))
    I <- t(X) %*% W %*% X
    U <- t(X) %*% (y - p)
    b0 <- b
  }

  attr(b, "fitted") <- pi_i(X %*% b)
  attr(b, "vcov") <- solve(I)   # inverse information = covariance of the estimates
  attr(b, "iter") <- i
  class(b) <- "results"
  return(b)
}
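A usage sketch, assuming the simulated y and design Z from above (the inspection calls are our own):

```r
# fit and inspect (assumes the y and Z sketched earlier)
fit <- lr_gd(y, Z)
data.frame(term     = colnames(Z),
           estimate = as.vector(fit),
           se       = sqrt(diag(attr(fit, "vcov"))))
round(cov2cor(attr(fit, "vcov")), 2)  # correlations among the coefficient estimates
```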
This is as it should be: the coefficient for protein2 is positive and, in the estimated covariance matrix, anticorrelated with those for protein1 and protein3, which are negative. None of the other terms are correlated with the intercept term, which in our generating model is 0. The rather large standard errors at convergence demonstrate something else of note: once we step away from linear models, we need a larger sample size to get the same tight bounds on our estimates that we would expect from OLS.
Fun stuff
For certain values of fun. Given $X$ (gene expression), $Y$ (protein abundance), and $y$ (malignant or not), can you model $y$ directly from $X$? Remember, the two layers have different activation and gradient functions. Any ideas?