In this lesson, we will learn how to adapt the logistic regression formula for situations in which our response variable has more than two potential classes. In other words, we will see how to use logistic regression for multi-class classification problems.
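Assume that \(Y\) is a categorical variable with levels \(A\), \(B\), and \(C\), and that we have \(m\) predictors \(X^{(1)}\), \(X^{(2)}\), …, \(X^{(m)}\). Define the following quantities:
\(p_A = P\left[Y = A\right]\), \(p_B = P\left[Y = B\right]\), and \(p_C = P\left[Y = C\right]\), where each probability is conditional on the observed values of the predictors.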
A multinomial logistic regression model has the following form:
\[\ln\left[\frac{p_A}{p_C} \right] = \beta_{0,1} + \beta_{1,1}\cdot X^{(1)} + \beta_{2,1}\cdot X^{(2)} + ... + \beta_{m,1}\cdot X^{(m)}\]
\[\ln\left[\frac{p_B}{p_C} \right] = \beta_{0,2} + \beta_{1,2}\cdot X^{(1)} + \beta_{2,2}\cdot X^{(2)} + ... + \beta_{m,2}\cdot X^{(m)}\]
Please note the following comments regarding this model:
Since \(p_A + p_B + p_C = 1\), if we know \(\ln\left[{p_A}/{p_C} \right]\) and \(\ln\left[{p_B}/{p_C} \right]\), then we can calculate all three probabilities.
In general, if our response variable has \(k\) classes, then our multinomial logistic regression model will need to consist of \(k-1\) equations, one for each class other than the reference class.
The expression \(p_A/p_C\) is referred to as the relative odds ratio of A with respect to C. We will denote this expression by \(\mathrm{Odds}\left [A:C \right]\).
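To make the first comment concrete, here is a short derivation sketch. The symbols \(\eta_A\) and \(\eta_B\) are shorthand introduced only for this sketch; they stand for the right-hand sides of the two model equations, so that \(\ln\left[p_A/p_C\right] = \eta_A\) and \(\ln\left[p_B/p_C\right] = \eta_B\). Exponentiating gives \(p_A = p_C \cdot e^{\eta_A}\) and \(p_B = p_C \cdot e^{\eta_B}\). Substituting these into \(p_A + p_B + p_C = 1\) gives \(p_C \left(1 + e^{\eta_A} + e^{\eta_B}\right) = 1\), and therefore:
\[p_C = \frac{1}{1 + e^{\eta_A} + e^{\eta_B}}, \qquad p_A = \frac{e^{\eta_A}}{1 + e^{\eta_A} + e^{\eta_B}}, \qquad p_B = \frac{e^{\eta_B}}{1 + e^{\eta_A} + e^{\eta_B}}\]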
We will now provide an example of performing multinomial logistic regression in R. In our example, we will build a model that attempts to detect the presence of two types of diabetes based on measurements taken from a blood test. Before we load the dataset, let’s load some packages. Note that the nnet package is required for multinomial logistic regression, since it provides the multinom function.
library(nnet)
library(ggplot2)
library(gridExtra)
library(caret)
Our dataset contains the following variables. Our response variable will be CC.
CC - Categorical variable. 1 = Overt diabetes, 2 = Chemical diabetes, 3 = Non-diabetic
RW - Relative weight
IR - Insulin response
SSPG - Steady state plasma glucose
df <- read.table("data/diabetes.txt", sep="\t", header=TRUE)
summary(df)
RW IR SSPG CC
Min. :0.7100 Min. : 10.0 Min. : 29.0 Min. :1.000
1st Qu.:0.8800 1st Qu.:118.0 1st Qu.:100.0 1st Qu.:2.000
Median :0.9800 Median :156.0 Median :159.0 Median :3.000
Mean :0.9773 Mean :186.1 Mean :184.2 Mean :2.297
3rd Qu.:1.0800 3rd Qu.:221.0 3rd Qu.:257.0 3rd Qu.:3.000
Max. :1.2000 Max. :748.0 Max. :480.0 Max. :3.000
We need to convert CC to a factor variable.
df$CC <- factor(df$CC, levels= c("3", "2", "1"))
summary(df)
RW IR SSPG CC
Min. :0.7100 Min. : 10.0 Min. : 29.0 3:76
1st Qu.:0.8800 1st Qu.:118.0 1st Qu.:100.0 2:36
Median :0.9800 Median :156.0 Median :159.0 1:33
Mean :0.9773 Mean :186.1 Mean :184.2
3rd Qu.:1.0800 3rd Qu.:221.0 3rd Qu.:257.0
Max. :1.2000 Max. :748.0 Max. :480.0
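The order of the levels matters here: multinom treats the first level of the factor as the reference class, which is why class 3 (non-diabetic) is listed first. The following is a minimal sketch of how level order can be inspected and changed. The vector cc_raw is a made-up toy example, not the lesson's data.

```r
# Toy response vector (hypothetical values, not the lesson's data).
cc_raw <- c(1, 2, 3, 3, 2, 1, 3)

# Listing "3" first makes non-diabetic the first (reference) level.
cc <- factor(cc_raw, levels = c("3", "2", "1"))
levels(cc)

# relevel() moves a chosen level to the front without changing the data.
cc_alt <- relevel(cc, ref = "1")
levels(cc_alt)
```

If we wanted a different reference class, applying relevel before fitting the model would be one way to get it.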
Let’s calculate the proportion of individuals in the training set that are non-diabetic.
mean(df$CC == '3')
[1] 0.5241379
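This proportion is a useful baseline: a model that always predicts the most common class would be correct about 52% of the time. As a small sketch, we can compute the proportions of all three classes at once. To keep the example self-contained, the factor below is rebuilt from the class counts shown in the summary output (76, 36, and 33) rather than read from the data file.

```r
# Class counts copied from the summary(df) output above: 76, 36, and 33.
cc <- factor(rep(c("3", "2", "1"), times = c(76, 36, 33)),
             levels = c("3", "2", "1"))

# Proportion of each class among the 145 observations.
class_props <- prop.table(table(cc))
round(class_props, 4)

# Accuracy of a rule that always predicts the most common class.
baseline_acc <- max(class_props)
baseline_acc
```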
We will generate boxplots to explore the relationships between CC and the other three variables.
p1 <- ggplot(df, aes(x=CC, y=RW, fill=CC)) + geom_boxplot()
p2 <- ggplot(df, aes(x=CC, y=IR, fill=CC)) + geom_boxplot()
p3 <- ggplot(df, aes(x=CC, y=SSPG, fill=CC)) + geom_boxplot()
grid.arrange(p1, p2, p3, ncol=3)
We will now create our multinomial logistic regression model using the multinom function from the nnet package.
mod <- multinom(CC ~ RW + IR + SSPG, df)
# weights: 15 (8 variable)
initial value 159.298782
iter 10 value 69.027793
iter 20 value 68.418245
iter 30 value 68.414665
final value 68.414644
converged
summary(mod)
Call:
multinom(formula = CC ~ RW + IR + SSPG, data = df)
Coefficients:
(Intercept) RW IR SSPG
2 -7.615261 3.472572 0.003586749 0.01641449
1 -1.845230 -5.867196 -0.013353688 0.04550552
Std. Errors:
(Intercept) RW IR SSPG
2 2.335615 2.446151 0.002349168 0.004981886
1 3.463507 3.866580 0.005019289 0.009241721
Residual Deviance: 136.8293
AIC: 152.8293
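The summary reports coefficients and standard errors but no test statistics. One common follow-up, which is an addition to this lesson rather than part of it, is to compute Wald z statistics and approximate two-sided p-values by dividing each coefficient by its standard error. The sketch below copies the printed values into matrices so that it runs on its own; the names coef_hat and se_hat are chosen here for illustration. With the fitted model available, the same matrices could be taken from summary(mod)$coefficients and summary(mod)$standard.errors.

```r
# Coefficients and standard errors copied from the summary(mod) output above.
coef_hat <- matrix(c(-7.615261,  3.472572,  0.003586749, 0.01641449,
                     -1.845230, -5.867196, -0.013353688, 0.04550552),
                   nrow = 2, byrow = TRUE,
                   dimnames = list(c("2", "1"),
                                   c("(Intercept)", "RW", "IR", "SSPG")))
se_hat <- matrix(c(2.335615, 2.446151, 0.002349168, 0.004981886,
                   3.463507, 3.866580, 0.005019289, 0.009241721),
                 nrow = 2, byrow = TRUE, dimnames = dimnames(coef_hat))

# Wald z statistic: each estimate divided by its standard error.
z <- coef_hat / se_hat
round(z, 3)

# Approximate two-sided p-values from the standard normal distribution.
p_values <- 2 * pnorm(abs(z), lower.tail = FALSE)
round(p_values, 4)
```

These p-values rely on a large-sample normal approximation, so they should be read as rough guidance rather than exact results.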
We can use the predict function to generate predictions. If we would like predict to return estimated probabilities, then we need to set type="probs".
nd <- data.frame(
RW = c(0.87, 0.91, 0.85, 1.15),
IR = c(120, 620, 150, 150),
SSPG = c(150, 160, 230, 230) )
predict(mod, nd, type="probs")
3 2 1
1 0.7351087 0.1340857 0.1308056224
2 0.4024971 0.5973904 0.0001124671
3 0.1467817 0.1034117 0.7498066056
4 0.2580303 0.5152335 0.2267361597
If we would like predict to return predicted classes instead, then we need to set type="class".
predict(mod, nd, type="class")
[1] 3 2 1 2
Levels: 3 2 1
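The predicted class for each observation is simply the class with the largest estimated probability. The sketch below checks this using the probability matrix printed above, copied in by hand so that the example runs without the fitted model. The names probs and pred_class are chosen here for illustration.

```r
# Estimated probabilities copied from the predict(mod, nd, type="probs") output.
probs <- matrix(c(0.7351087, 0.1340857, 0.1308056224,
                  0.4024971, 0.5973904, 0.0001124671,
                  0.1467817, 0.1034117, 0.7498066056,
                  0.2580303, 0.5152335, 0.2267361597),
                nrow = 4, byrow = TRUE,
                dimnames = list(as.character(1:4), c("3", "2", "1")))

# For each row, pick the column (class) with the largest probability.
pred_class <- colnames(probs)[max.col(probs, ties.method = "first")]
pred_class
```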
As practice, let’s use the formulas provided by our model to confirm the calculations that predict has given us. We will do this for a single observation, the first row of nd.
x <- c(0.87, 120, 150)
x
[1] 0.87 120.00 150.00
We will extract our coefficients from the model.
cf <- summary(mod)$coefficients
cf
(Intercept) RW IR SSPG
2 -7.615261 3.472572 0.003586749 0.01641449
1 -1.845230 -5.867196 -0.013353688 0.04550552
We now calculate the log-odds.
logodds23 <- cf[1,1] + sum(x * cf[1,2:4])
logodds13 <- cf[2,1] + sum(x * cf[2,2:4])
c(logodds13, logodds23)
[1] -1.726306 -1.701539
Next we find the relative odds ratios.
odds13 <- exp(logodds13)
odds23 <- exp(logodds23)
c(odds13, odds23)
[1] 0.1779405 0.1824026
Finally, we calculate our probabilities.
p1 <- odds13 / (odds13 + odds23 + 1)
p2 <- odds23 / (odds13 + odds23 + 1)
p3 <- 1 - p1 - p2
c(p3, p2, p1)
[1] 0.7351087 0.1340857 0.1308056
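The same calculation can be carried out for all four observations at once using matrix multiplication. This is a sketch under the assumption that we work from the rounded coefficients printed above, so the results agree with predict only up to rounding. The names coef_hat, X, log_odds, odds, denom, and probs_manual are chosen here for illustration.

```r
# Coefficients copied from the summary(mod) output (rows: classes 2 and 1).
coef_hat <- matrix(c(-7.615261,  3.472572,  0.003586749, 0.01641449,
                     -1.845230, -5.867196, -0.013353688, 0.04550552),
                   nrow = 2, byrow = TRUE,
                   dimnames = list(c("2", "1"),
                                   c("(Intercept)", "RW", "IR", "SSPG")))

# The same new observations that were passed to predict().
nd <- data.frame(RW   = c(0.87, 0.91, 0.85, 1.15),
                 IR   = c(120, 620, 150, 150),
                 SSPG = c(150, 160, 230, 230))

# Design matrix with a leading column of ones for the intercept.
X <- cbind(1, as.matrix(nd))

# Log-odds of classes 2 and 1 relative to class 3, one row per observation.
log_odds <- X %*% t(coef_hat)

# Relative odds, then probabilities; the reference class has odds equal to 1.
odds  <- exp(log_odds)
denom <- 1 + rowSums(odds)
probs_manual <- cbind("3" = 1 / denom, odds / denom)
round(probs_manual, 4)
```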
We will evaluate our model by calculating its accuracy on the training set. To do this, we must first generate predictions based on the training set.
training_pred <- predict(mod, df, type="class")
set.seed(1)
s <- sample(1:145, 10)
training_pred[s]
[1] 1 1 3 3 1 3 3 2 3 3
Levels: 3 2 1
df$CC[s]
[1] 3 1 3 3 3 2 3 2 3 3
Levels: 3 2 1
We will now calculate the model’s training accuracy.
accuracy <- mean(training_pred == df$CC)
accuracy
[1] 0.8275862
Assume that we have trained a classification model with a response variable Y. Let C refer to a particular class for Y. We say that:
An observation is a True Positive for Class C if our model predicts that the observation is of class C, and the observed class is actually C.
An observation is a False Positive for Class C if our model predicts that the observation is of class C, and the observed class is NOT C.
An observation is a True Negative for Class C if our model predicts that the observation is NOT of class C, and the observed class is NOT C.
An observation is a False Negative for Class C if our model predicts that the observation is NOT of class C, and the observed class is actually C.
We can define the following metrics to measure the model’s performance on individual classes.
Sensitivity = \(\frac{TP}{TP + FN} = \frac{TP}{\textrm{Number of Actual Cs} }\) (Also called: True Positive Rate, Recall, and Probability of Detection)
Specificity = \(\frac{TN}{TN + FP} = \frac{TN}{\textrm{Number of Actual non-Cs} }\) (Also called True Negative Rate)
Positive Predictive Value = \(\frac{TP}{TP + FP} = \frac{TP}{\textrm{Number of Predicted Cs} }\) (Also called Precision)
Negative Predictive Value = \(\frac{TN}{TN + FN} = \frac{TN}{\textrm{Number of Predicted non-Cs} }\)
Prevalence = \(\frac{TP + FN}{TP + FP + TN + FN} = \frac{\textrm{Number of Actual Cs} }{\textrm{Size of Total Population} }\)
Detection Rate = \(\frac{TP}{TP + FP + TN + FN} = \frac{TP }{\textrm{Size of Total Population} }\)
Detection Prevalence = \(\frac{TP + FP}{TP + FP + TN + FN} = \frac{\textrm{Number of Predicted Cs} }{\textrm{Size of Total Population} }\)
Balanced Accuracy = \((\textrm{Sensitivity} + \textrm{Specificity})/2 = \frac{1}{2}\left(\frac{TP}{TP + FN} + \frac{TN}{TN + FP}\right)\)
A confusion matrix is a useful tool for evaluating the performance of a classification model. We can create confusion matrices in R using the confusionMatrix function from the caret package. The first argument contains the predicted classes and the second contains the observed classes.
confusionMatrix(training_pred, df$CC)
Confusion Matrix and Statistics
          Reference
Prediction  3  2  1
         3 69 12  3
         2  5 24  3
         1  2  0 27
Overall Statistics
Accuracy : 0.8276
95% CI : (0.7561, 0.8852)
No Information Rate : 0.5241
P-Value [Acc > NIR] : 1.864e-14
Kappa : 0.7107
Mcnemar's Test P-Value : 0.1077
Statistics by Class:
Class: 3 Class: 2 Class: 1
Sensitivity 0.9079 0.6667 0.8182
Specificity 0.7826 0.9266 0.9821
Pos Pred Value 0.8214 0.7500 0.9310
Neg Pred Value 0.8852 0.8938 0.9483
Prevalence 0.5241 0.2483 0.2276
Detection Rate 0.4759 0.1655 0.1862
Detection Prevalence 0.5793 0.2207 0.2000
Balanced Accuracy 0.8453 0.7966 0.9002
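As a check on the definitions given earlier, several of the per-class statistics can be reproduced by hand from the confusion matrix. The sketch below copies the counts from the output above and treats each class in turn as class C, with the other two classes combined as "not C". The names cm, TP, FP, FN, TN, and stats are chosen here for illustration.

```r
# Confusion matrix copied from the confusionMatrix() output above
# (rows are predicted classes, columns are observed classes).
cm <- matrix(c(69, 12,  3,
                5, 24,  3,
                2,  0, 27),
             nrow = 3, byrow = TRUE,
             dimnames = list(Prediction = c("3", "2", "1"),
                             Reference  = c("3", "2", "1")))

n  <- sum(cm)
TP <- diag(cm)            # predicted C and actually C
FP <- rowSums(cm) - TP    # predicted C but actually not C
FN <- colSums(cm) - TP    # actually C but predicted not C
TN <- n - TP - FP - FN    # neither predicted C nor actually C

stats <- rbind(
  "Sensitivity"       = TP / (TP + FN),
  "Specificity"       = TN / (TN + FP),
  "Pos Pred Value"    = TP / (TP + FP),
  "Neg Pred Value"    = TN / (TN + FN),
  "Balanced Accuracy" = (TP / (TP + FN) + TN / (TN + FP)) / 2
)
round(stats, 4)

# Overall accuracy is the share of observations on the diagonal.
sum(TP) / n
```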