Lesson 4.2 - Multinomial Logistic Regression

Robbie Beane

Multinomial Logistic Regression

In this lesson, we will learn how to adapt the logistic regression formula for situations in which our response variable has more than two potential classes. In other words, we will see how to use logistic regression for multi-class classification problems.

Assume that \(Y\) is a categorical variable with levels \(A\), \(B\), and \(C\), and that we have \(m\) predictors \(X^{(1)}\), \(X^{(2)}\), …, \(X^{(m)}\). Define the following quantities:

\[p_A = P\left[ Y = A ~|~ X \right]\]
\[p_B = P\left[ Y = B ~|~ X \right]\]
\[p_C = P\left[ Y = C ~|~ X \right]\]


A multinomial logistic regression model has the following form:

\[\ln\left[\frac{p_A}{p_C} \right] = \beta_{0,1} + \beta_{1,1}\cdot X^{(1)} + \beta_{2,1}\cdot X^{(2)} + ... + \beta_{m,1}\cdot X^{(m)}\]

\[\ln\left[\frac{p_B}{p_C} \right] = \beta_{0,2} + \beta_{1,2}\cdot X^{(1)} + \beta_{2,2}\cdot X^{(2)} + ... + \beta_{m,2}\cdot X^{(m)}\]

Please note the following comments regarding this model:

  1. Since \(p_A + p_B + p_C = 1\), if we know \(\ln\left[{p_A}/{p_C} \right]\) and \(\ln\left[{p_B}/{p_C} \right]\), then we can calculate all three probabilities.

  2. In general, if our response variable has \(k\) classes, then our multinomial logistic regression model will need to consist of \(k-1\) equations.

  3. The expression \(p_A/p_C\) is referred to as the relative odds ratio of A with respect to C. We will denote this expression by \(\mathrm{Odds}\left [A:C \right]\).
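Solving the model equations for the probabilities shows why two equations are enough. Write \(\eta_1\) and \(\eta_2\) (shorthand introduced only for this derivation) for the right-hand sides of the first and second equations, so that \(p_A = e^{\eta_1} p_C\) and \(p_B = e^{\eta_2} p_C\). Substituting these into \(p_A + p_B + p_C = 1\) gives

\[p_C = \frac{1}{1 + e^{\eta_1} + e^{\eta_2}}, \qquad p_A = \frac{e^{\eta_1}}{1 + e^{\eta_1} + e^{\eta_2}}, \qquad p_B = \frac{e^{\eta_2}}{1 + e^{\eta_1} + e^{\eta_2}}\]

With \(k\) classes the pattern is the same: the baseline class has probability \(1 / \left(1 + \sum_{j=1}^{k-1} e^{\eta_j}\right)\), and each other class \(j\) has \(e^{\eta_j}\) times that probability.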

Example: Diagnosing Diabetes

We will now provide an example of performing multinomial logistic regression in R. In our example, we will build a model that attempts to detect the presence of two types of diabetes based on measurements taken from a blood test. Before we load the dataset, let's load some packages. The nnet package provides the multinom function that we will use to fit the model; ggplot2 and gridExtra are used for plotting, and caret for the confusion matrix.

Load Packages and Data

library(nnet)
library(ggplot2)
library(gridExtra)
library(caret)

Loading the Dataset

Our dataset contains four variables: RW, IR, SSPG, and CC. Our response variable will be CC.

df <- read.table("data/diabetes.txt", sep="\t", header=TRUE)
summary(df)
##        RW               IR             SSPG             CC       
##  Min.   :0.7100   Min.   : 10.0   Min.   : 29.0   Min.   :1.000  
##  1st Qu.:0.8800   1st Qu.:118.0   1st Qu.:100.0   1st Qu.:2.000  
##  Median :0.9800   Median :156.0   Median :159.0   Median :3.000  
##  Mean   :0.9773   Mean   :186.1   Mean   :184.2   Mean   :2.297  
##  3rd Qu.:1.0800   3rd Qu.:221.0   3rd Qu.:257.0   3rd Qu.:3.000  
##  Max.   :1.2000   Max.   :748.0   Max.   :480.0   Max.   :3.000

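We need to convert CC to a factor variable. We list the levels in the order 3, 2, 1 because multinom uses the first level of the response as the baseline class; this makes class 3 the denominator of every relative odds ratio in our model.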

df$CC <- factor(df$CC, levels= c("3", "2", "1"))
summary(df)
##        RW               IR             SSPG       CC    
##  Min.   :0.7100   Min.   : 10.0   Min.   : 29.0   3:76  
##  1st Qu.:0.8800   1st Qu.:118.0   1st Qu.:100.0   2:36  
##  Median :0.9800   Median :156.0   Median :159.0   1:33  
##  Mean   :0.9773   Mean   :186.1   Mean   :184.2         
##  3rd Qu.:1.0800   3rd Qu.:221.0   3rd Qu.:257.0         
##  Max.   :1.2000   Max.   :748.0   Max.   :480.0
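Putting a chosen class first, so that it serves as the baseline, can also be done with base R's relevel function. The sketch below uses a small hypothetical vector of labels (not the diabetes data) to compare the default level order, the explicit order used above, and relevel.

```r
# Hypothetical class labels, stored as numbers like CC in the raw file
cc_raw <- c(1, 3, 2, 3, 3, 1)

# Default: levels are sorted, so "1" would be the first (baseline) level
f_default <- factor(cc_raw)
print(levels(f_default))

# Explicit level order, as in the lesson: "3" comes first
f_lesson <- factor(cc_raw, levels = c("3", "2", "1"))
print(levels(f_lesson))

# relevel() moves the chosen level to the front and keeps the others in their current order
f_relevel <- relevel(factor(cc_raw), ref = "3")
print(levels(f_relevel))
```

Only the first level determines the baseline; the order of the remaining levels (2, 1 versus 1, 2) just changes the order of the rows in the coefficient table.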

Let’s calculate the proportion of individuals in the training set that are non-diabetic (class 3).

mean(df$CC == '3')
## [1] 0.5241379
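The same calculation can be done for every class from the class counts. As a sketch, the counts below are copied by hand from the summary output above; prop.table converts them to proportions. The largest proportion is the accuracy a trivial model would achieve by always predicting the most common class, which is a useful reference point when judging the fitted model.

```r
# Class counts copied from summary(df) after converting CC to a factor
counts <- c("3" = 76, "2" = 36, "1" = 33)

# Proportion of observations in each class
props <- prop.table(counts)
print(round(props, 4))

# Accuracy of a baseline that always predicts the most common class
baseline_accuracy <- max(props)
print(baseline_accuracy)
```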

Exploratory Plots

We will generate boxplots to explore the relationships between CC and the other three variables.

p1 <- ggplot(df, aes(x=CC, y=RW, fill=CC)) + geom_boxplot()  
p2 <- ggplot(df, aes(x=CC, y=IR, fill=CC)) + geom_boxplot()  
p3 <- ggplot(df, aes(x=CC, y=SSPG, fill=CC)) + geom_boxplot()

grid.arrange(p1, p2, p3, ncol=3)

Create Model

We will now create our multinomial logistic regression model using the multinom function from the nnet package.

mod <- multinom(CC ~ RW + IR + SSPG, df)
## # weights:  15 (8 variable)
## initial  value 159.298782 
## iter  10 value 69.027793
## iter  20 value 68.418245
## iter  30 value 68.414665
## final  value 68.414644 
## converged
summary(mod)
## Call:
## multinom(formula = CC ~ RW + IR + SSPG, data = df)
## 
## Coefficients:
##   (Intercept)        RW           IR       SSPG
## 2   -7.615261  3.472572  0.003586749 0.01641449
## 1   -1.845230 -5.867196 -0.013353688 0.04550552
## 
## Std. Errors:
##   (Intercept)       RW          IR        SSPG
## 2    2.335615 2.446151 0.002349168 0.004981886
## 1    3.463507 3.866580 0.005019289 0.009241721
## 
## Residual Deviance: 136.8293 
## AIC: 152.8293
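Note that summary(mod) reports coefficients and standard errors but no p-values. One common approach, which is not part of nnet's output, is an approximate Wald test: divide each coefficient by its standard error and compare the result to a standard normal distribution. The sketch below uses values copied from the output above, so it runs without refitting the model. It also checks how the last two lines of the summary relate to the fitting log: the residual deviance is twice the final value (the minimized negative log-likelihood), and the AIC adds a penalty of 2 for each of the 8 estimated coefficients, matching the "(8 variable)" note in the log.

```r
# Coefficients and standard errors copied from summary(mod)
coefs <- matrix(c(-7.615261,  3.472572,  0.003586749, 0.01641449,
                  -1.845230, -5.867196, -0.013353688, 0.04550552),
                nrow = 2, byrow = TRUE,
                dimnames = list(c("2", "1"), c("(Intercept)", "RW", "IR", "SSPG")))
ses <- matrix(c(2.335615, 2.446151, 0.002349168, 0.004981886,
                3.463507, 3.866580, 0.005019289, 0.009241721),
              nrow = 2, byrow = TRUE, dimnames = dimnames(coefs))

# Wald z statistics and two-sided p-values (large-sample normal approximation)
z <- coefs / ses
p_values <- 2 * pnorm(-abs(z))
print(round(z, 3))
print(signif(p_values, 3))

# Residual deviance is twice the final value from the fitting log
resid_dev <- 2 * 68.414644
# AIC adds 2 for each estimated coefficient (2 equations x 4 coefficients)
n_params <- length(coefs)
aic <- resid_dev + 2 * n_params
print(c(deviance = resid_dev, n_params = n_params, AIC = aic))
```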

Generate Predictions

We can use the predict function to generate predictions. If we would like predict to return estimated probabilities, then we need to set type="probs".

nd <- data.frame(
  RW = c(0.87, 0.91, 0.85, 1.15),
  IR = c(120, 620, 150, 150),
  SSPG = c(150, 160, 230, 230)  )

predict(mod, nd, type="probs")
##           3         2            1
## 1 0.7351087 0.1340857 0.1308056224
## 2 0.4024971 0.5973904 0.0001124671
## 3 0.1467817 0.1034117 0.7498066056
## 4 0.2580303 0.5152335 0.2267361597

If we would like predict to return predicted classes instead, then we need to set type="class". Each observation is assigned the class with the highest estimated probability.

predict(mod, nd, type="class")
## [1] 3 2 1 2
## Levels: 3 2 1
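As a check, the predicted classes can be recovered from the probability matrix by taking the largest entry in each row. This minimal sketch uses the probabilities printed above, copied by hand, so it does not need the fitted model.

```r
# Estimated probabilities copied from predict(mod, nd, type = "probs")
probs <- matrix(c(0.7351087, 0.1340857, 0.1308056224,
                  0.4024971, 0.5973904, 0.0001124671,
                  0.1467817, 0.1034117, 0.7498066056,
                  0.2580303, 0.5152335, 0.2267361597),
                nrow = 4, byrow = TRUE,
                dimnames = list(NULL, c("3", "2", "1")))

# Each row is a probability distribution over the three classes
print(rowSums(probs))

# Pick the column with the largest probability in each row
pred_class <- factor(colnames(probs)[apply(probs, 1, which.max)],
                     levels = colnames(probs))
print(pred_class)
```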

Generating Predictions Using Formulas

As practice, let’s use the formulas provided by our model to confirm the calculations that predict has given us. We will do this for the first observation in nd.

x <- c(0.87, 120, 150)
x
## [1]   0.87 120.00 150.00

We will extract our coefficients from the model.

cf <- summary(mod)$coefficients
cf
##   (Intercept)        RW           IR       SSPG
## 2   -7.615261  3.472572  0.003586749 0.01641449
## 1   -1.845230 -5.867196 -0.013353688 0.04550552

We now calculate the log-odds.

logodds23 <- cf[1,1]  + sum(x * cf[1,2:4])
logodds13 <- cf[2,1]  + sum(x * cf[2,2:4])

c(logodds13, logodds23)
## [1] -1.726306 -1.701539
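Because both equations use class 3 as the baseline, the log relative odds of the two diabetic classes against each other follow by subtraction: \(\ln\left[p_1/p_2\right] = \ln\left[p_1/p_3\right] - \ln\left[p_2/p_3\right]\). A quick numeric check, using the log-odds printed above and the first row of the earlier predict output (both copied by hand):

```r
# Log relative odds printed above (class 1 vs 3, class 2 vs 3)
logodds13 <- -1.726306
logodds23 <- -1.701539

# Log relative odds of class 1 vs class 2, obtained by subtraction
logodds12 <- logodds13 - logodds23

# The same quantity computed from the predicted probabilities for observation 1
p1_hat <- 0.1308056224
p2_hat <- 0.1340857
print(c(from_equations = logodds12, from_probs = log(p1_hat / p2_hat)))
```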

Next we find the relative odds ratios.

odds13 <- exp(logodds13)
odds23 <- exp(logodds23)

c(odds13, odds23)
## [1] 0.1779405 0.1824026
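Exponentiating a single coefficient gives its effect on a relative odds ratio: holding the other predictors fixed, a one-unit increase in a predictor multiplies \(\mathrm{Odds}\left[2:3\right]\) (or \(\mathrm{Odds}\left[1:3\right]\)) by the exponential of the corresponding coefficient. The sketch below applies this to the SSPG coefficients copied from the model output; the 10-unit step and the helper eta23 are illustrative choices, not part of the lesson.

```r
# SSPG coefficients copied from summary(mod): class 2 vs 3 and class 1 vs 3
b_sspg <- c("2 vs 3" = 0.01641449, "1 vs 3" = 0.04550552)

# Multiplicative change in the relative odds for a 1-unit increase in SSPG
print(exp(b_sspg))

# Multiplicative change for a 10-unit increase (exp(10 * b) = exp(b)^10)
print(exp(10 * b_sspg))

# Hypothetical helper: log relative odds of class 2 vs 3 from the fitted equation
eta23 <- function(rw, ir, sspg) {
  -7.615261 + 3.472572 * rw + 0.003586749 * ir + 0.01641449 * sspg
}

# Raising SSPG from 150 to 160 (RW and IR fixed) scales Odds[2:3] by exp(10 * b)
ratio <- exp(eta23(0.87, 120, 160)) / exp(eta23(0.87, 120, 150))
print(ratio)
```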

Finally, we calculate our probabilities.

p1 <- odds13 / (odds13 + odds23 + 1)
p2 <- odds23 / (odds13 + odds23 + 1)
p3 <- 1 - p1 - p2
c(p3, p2, p1)
## [1] 0.7351087 0.1340857 0.1308056
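The same steps can be applied to every row of nd at once with a matrix product. This sketch reuses the printed coefficients, so its results agree with predict only up to rounding. It prepends a column of ones for the intercept, computes both log relative odds for each row, and then converts them to probabilities with class 3 as the baseline.

```r
# Coefficients copied from summary(mod)$coefficients (rows: classes 2 and 1)
cf <- matrix(c(-7.615261,  3.472572,  0.003586749, 0.01641449,
               -1.845230, -5.867196, -0.013353688, 0.04550552),
             nrow = 2, byrow = TRUE,
             dimnames = list(c("2", "1"), c("(Intercept)", "RW", "IR", "SSPG")))

# The four new observations used earlier
nd <- data.frame(RW   = c(0.87, 0.91, 0.85, 1.15),
                 IR   = c(120, 620, 150, 150),
                 SSPG = c(150, 160, 230, 230))

# Design matrix with an intercept column, in the same column order as cf
X <- cbind(1, as.matrix(nd[, c("RW", "IR", "SSPG")]))

# Log relative odds: one column per non-baseline class (2 vs 3, 1 vs 3)
logodds <- X %*% t(cf)
odds <- exp(logodds)

# Baseline probability first, then the other classes (each row divided by its own denominator)
denom <- 1 + rowSums(odds)
probs <- cbind("3" = 1 / denom, odds / denom)
print(round(probs, 7))
```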

Evaluating the Model

We will evaluate our model by calculating its accuracy on the training set. To do this, we must first generate predictions based on the training set.

training_pred <- predict(mod, df, type="class")

# Compare predicted and actual classes for 10 randomly chosen observations
set.seed(1)
s <- sample(1:nrow(df), 10)

training_pred[s]
##  [1] 1 1 3 3 1 3 3 2 3 3
## Levels: 3 2 1
df$CC[s]
##  [1] 3 1 3 3 3 2 3 2 3 3
## Levels: 3 2 1

We will now calculate the model’s training accuracy.

accuracy <- mean(training_pred == df$CC)
accuracy
## [1] 0.8275862
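Training accuracy is the proportion of observations whose predicted class matches the actual class. As a small-scale illustration, the sketch below applies the same calculation to the ten sampled observations printed above (values copied by hand) and cross-tabulates them with table, which is how a confusion matrix is built.

```r
lv <- c("3", "2", "1")

# Predicted and actual classes for the ten sampled observations, copied from the output above
pred_s   <- factor(c(1, 1, 3, 3, 1, 3, 3, 2, 3, 3), levels = lv)
actual_s <- factor(c(3, 1, 3, 3, 3, 2, 3, 2, 3, 3), levels = lv)

# Accuracy: proportion of matching predictions
acc_s <- mean(pred_s == actual_s)
print(acc_s)

# Cross-tabulation of predicted (rows) against actual (columns) classes
tab <- table(Prediction = pred_s, Reference = actual_s)
print(tab)
```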

Additional Classification Metrics

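Assume that we have trained a classification model with a response variable \(Y\). Let \(C\) refer to a particular class for \(Y\). For a given observation, we say that:

  1. It is a true positive (TP) for \(C\) if it belongs to class \(C\) and the model predicts \(C\).

  2. It is a false positive (FP) for \(C\) if it does not belong to class \(C\) but the model predicts \(C\).

  3. It is a false negative (FN) for \(C\) if it belongs to class \(C\) but the model predicts a different class.

  4. It is a true negative (TN) for \(C\) if it does not belong to class \(C\) and the model does not predict \(C\).

We can define the following metrics to measure the model’s performance on individual classes:

  1. Sensitivity: \(\mathrm{TP} / (\mathrm{TP} + \mathrm{FN})\), the proportion of observations in class \(C\) that the model assigns to \(C\).

  2. Specificity: \(\mathrm{TN} / (\mathrm{TN} + \mathrm{FP})\), the proportion of observations outside class \(C\) that the model does not assign to \(C\).

  3. Positive predictive value (precision): \(\mathrm{TP} / (\mathrm{TP} + \mathrm{FP})\), the proportion of predictions of \(C\) that are correct.

  4. Negative predictive value: \(\mathrm{TN} / (\mathrm{TN} + \mathrm{FN})\), the proportion of predictions other than \(C\) that are correct about not being \(C\).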

Confusion Matrix

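A confusion matrix is a useful tool for evaluating the performance of a classification model: it cross-tabulates the predicted classes (rows) against the actual classes (columns). We can create confusion matrices in R using the confusionMatrix function from the caret package, passing the predictions first and the actual classes second.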

confusionMatrix(training_pred, df$CC)
## Confusion Matrix and Statistics
## 
##           Reference
## Prediction  3  2  1
##          3 69 12  3
##          2  5 24  3
##          1  2  0 27
## 
## Overall Statistics
##                                           
##                Accuracy : 0.8276          
##                  95% CI : (0.7561, 0.8852)
##     No Information Rate : 0.5241          
##     P-Value [Acc > NIR] : 1.864e-14       
##                                           
##                   Kappa : 0.7107          
##                                           
##  Mcnemar's Test P-Value : 0.1077          
## 
## Statistics by Class:
## 
##                      Class: 3 Class: 2 Class: 1
## Sensitivity            0.9079   0.6667   0.8182
## Specificity            0.7826   0.9266   0.9821
## Pos Pred Value         0.8214   0.7500   0.9310
## Neg Pred Value         0.8852   0.8938   0.9483
## Prevalence             0.5241   0.2483   0.2276
## Detection Rate         0.4759   0.1655   0.1862
## Detection Prevalence   0.5793   0.2207   0.2000
## Balanced Accuracy      0.8453   0.7966   0.9002
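The per-class statistics can be recomputed from the counts in the confusion matrix by treating each class in turn as the positive class and lumping the other two together. The sketch below copies the counts from the output above and uses base R only; it reproduces the reported values but is not caret's own implementation.

```r
# Confusion matrix counts copied from the output above (rows = Prediction, columns = Reference)
lv <- c("3", "2", "1")
cm <- matrix(c(69, 12,  3,
                5, 24,  3,
                2,  0, 27),
             nrow = 3, byrow = TRUE,
             dimnames = list(Prediction = lv, Reference = lv))

n  <- sum(cm)
tp <- diag(cm)              # predicted C and actually C
fp <- rowSums(cm) - tp      # predicted C but actually another class
fn <- colSums(cm) - tp      # actually C but predicted another class
tn <- n - tp - fp - fn      # neither predicted C nor actually C

metrics <- rbind(
  Sensitivity         = tp / (tp + fn),
  Specificity         = tn / (tn + fp),
  "Pos Pred Value"    = tp / (tp + fp),
  "Neg Pred Value"    = tn / (tn + fn),
  "Balanced Accuracy" = (tp / (tp + fn) + tn / (tn + fp)) / 2
)
print(round(metrics, 4))

# Overall accuracy and the No Information Rate (share of the largest reference class)
accuracy <- sum(tp) / n
nir <- max(colSums(cm)) / n
print(c(Accuracy = accuracy, NIR = nir))
```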