The softmax function

The softmax function is the normalised exponential function. It transforms a vector \({\bf x} = (x_1,\dots,x_n)^{\top}\in{\mathbb R}^n\) into a vector in \((0,1)^n\) whose entries sum to one: \[ \operatorname{softmax}({\bf x}) = \dfrac{\exp({\bf x})}{\sum_{j=1}^n \exp(x_j) }, \] where \(\exp\) is applied componentwise. This function appears in many applications in Statistics and Machine Learning, for example in multinomial logistic regression and in the output layer of neural-network classifiers.

The softmax function is closely related to the LogSumExp function, since \[ \operatorname{softmax}({\bf x}) = \exp\left[ {\bf x} - \operatorname{LSE}({\bf x}) \right], \] where the scalar \(\operatorname{LSE}({\bf x})\) is subtracted from each entry of \({\bf x}\) and \(\operatorname{LSE}\) is the LogSumExp function: \[ \operatorname{LSE}({\bf x}) = \log\left[\sum_{j=1}^n \exp(x_j) \right]. \]
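A minimal base-R sketch of the identity above. The helper names `softmax_naive`, `lse_shift` and `softmax_lse` are hypothetical. Unlike the recursive formula used in the examples that follow, `lse_shift` computes LogSumExp with the standard max-shift trick, \(\operatorname{LSE}({\bf x}) = m + \log\sum_{j} \exp(x_j - m)\) with \(m = \max_j x_j\). The direct formula breaks down once some \(x_j\) exceeds roughly 710, because `exp()` then returns `Inf` and the ratio becomes `Inf/Inf`.

```r
# Direct translation of the definition: exp(x) / sum(exp(x))
softmax_naive <- function(x) exp(x) / sum(exp(x))

# LogSumExp via the max-shift trick (an assumption, not the recursive formula):
# LSE(x) = m + log(sum(exp(x - m))), with m = max(x)
lse_shift <- function(x) {
  m <- max(x)
  m + log(sum(exp(x - m)))
}

# softmax(x) = exp(x - LSE(x)); the scalar LSE(x) is recycled over x
softmax_lse <- function(x) exp(x - lse_shift(x))

x <- c(-1, 2, 1, -3)
all.equal(softmax_naive(x), softmax_lse(x))   # TRUE

y <- c(1000, 1001)
softmax_naive(y)   # NaN NaN, since exp(1000) is Inf
softmax_lse(y)     # 0.2689414 0.7310586
```

Both versions agree on moderate inputs; only the LSE-based form survives large ones.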

Examples

The following R code presents some illustrative examples of computing the softmax function, using the recursive formula for the LogSumExp function described in the entry on the LogSumExp function.
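For reference, the recursion implemented in the code below can be written as follows. With the entries sorted in decreasing order, \(x_{(1)} \geq \dots \geq x_{(n)}\), it repeatedly applies the identity \(\log(e^a + e^b) = \max(a,b) + \log\left(1 + e^{-|a-b|}\right)\), so that \(L_k = \log\sum_{j=1}^{k} \exp(x_{(j)})\): \[ L_1 = x_{(1)}, \qquad L_{k+1} = \max\left(x_{(k+1)}, L_k\right) + \log\left(1 + e^{-|x_{(k+1)} - L_k|}\right), \quad k = 1,\dots,n-1. \] Hence \(L_n = \operatorname{LSE}({\bf x})\) and \(\operatorname{softmax}({\bf x}) = \exp({\bf x} - L_n)\). In R, the term \(\log(1 + e^{-|\cdot|})\) is evaluated with `log1p()`, which stays accurate when its argument is close to zero.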

# No additional packages are required: the code below uses base R only
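#--------------------------------------------------------------------
# softmax function: Recursive formula
#--------------------------------------------------------------------
softmax <- function(par){
  n.par <- length(par)
  # Sort in decreasing order so that Lk starts at the largest entry
  par1 <- sort(par, decreasing = TRUE)
  Lk <- par1[1]
  # Recursive LogSumExp: L_{k+1} = max(x_{k+1}, L_k) + log(1 + exp(-|x_{k+1} - L_k|))
  # seq_len() avoids looping over 1:0 when length(par) == 1
  for (k in seq_len(n.par - 1)) {
    Lk <- max(par1[k+1], Lk) + log1p(exp(-abs(par1[k+1] - Lk)))
  }
  # softmax(x) = exp(x - LSE(x))
  val <- exp(par - Lk)
  return(val)
}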

# Example 1: a short vector
vec <- c(-1,2,1,-3)
sm <- softmax(vec)
print(sm)
## [1] 0.034952901 0.702047789 0.258268948 0.004730361
print(sum(sm))
## [1] 1
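# Example 2: the vector from Example 1 repeated five times,
# so each entry is one fifth of the corresponding value in Example 1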
vec <- c(-1,2,1,-3,-1,2,1,-3,-1,2,1,-3,-1,2,1,-3,-1,2,1,-3)
sm <- softmax(vec)
print(sm)
##  [1] 0.0069905803 0.1404095579 0.0516537897 0.0009460722 0.0069905803
##  [6] 0.1404095579 0.0516537897 0.0009460722 0.0069905803 0.1404095579
## [11] 0.0516537897 0.0009460722 0.0069905803 0.1404095579 0.0516537897
## [16] 0.0009460722 0.0069905803 0.1404095579 0.0516537897 0.0009460722
print(sum(sm))
## [1] 1
# Example 3: 30 draws from the standard normal distribution
set.seed(123)
vec <- rnorm(30)
sm <- softmax(vec)
print(sm)
##  [1] 0.012597569 0.017528043 0.104866472 0.023676617 0.025110027 0.122614406
##  [7] 0.034984172 0.006227147 0.011102016 0.014130245 0.075043003 0.031620111
## [13] 0.032942084 0.024647174 0.012656088 0.131748219 0.036300445 0.003087502
## [19] 0.044493158 0.013752053 0.007584873 0.017743240 0.007908792 0.010644979
## [25] 0.011809925 0.004084855 0.050996988 0.025722154 0.007069874 0.077307769
print(sum(sm))
## [1] 1
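As a further check of the recursive implementation, here is a sketch that copies it as the hypothetical `softmax_rec` (with `seq_len()` in the loop, so that a length-1 input works and the block runs on its own). It confirms two properties that follow from the definition: adding the same constant to every entry leaves the softmax unchanged, and shifted inputs around 1000 do not overflow. The comparison function `softmax_ref`, based on the max-shift trick, is an assumption added for this check and is not part of the original code.

```r
# Copy of the recursive softmax, using seq_len() so that length-1 input works
softmax_rec <- function(par) {
  n.par <- length(par)
  par1 <- sort(par, decreasing = TRUE)
  Lk <- par1[1]
  for (k in seq_len(n.par - 1)) {
    Lk <- max(par1[k + 1], Lk) + log1p(exp(-abs(par1[k + 1] - Lk)))
  }
  exp(par - Lk)
}

# Hypothetical reference implementation using the max-shift trick
softmax_ref <- function(x) {
  z <- exp(x - max(x))
  z / sum(z)
}

vec <- c(-1, 2, 1, -3)

# Shift invariance: softmax(x + c) = softmax(x), even for c = 1000
all.equal(softmax_rec(vec + 1000), softmax_rec(vec))   # TRUE

# Agreement with the max-shift reference on random input
set.seed(1)
z <- rnorm(50, sd = 10)
all.equal(softmax_rec(z), softmax_ref(z))              # TRUE

# A single entry receives probability 1
softmax_rec(5)                                         # 1
```

Sorting is not required for correctness of the recursion, but starting from the largest entry means every exponent passed to `exp()` is non-positive.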