Welcome to this R demo session! Here, I will demonstrate how to use R to conduct discriminant analysis.
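In this Rmarkdown file, the following discriminant analysis methods will be described: linear discriminant analysis (LDA), quadratic discriminant analysis (QDA), mixture discriminant analysis (MDA), flexible discriminant analysis (FDA), and regularized discriminant analysis (RDA).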
We’ll revisit the iris data set, introduced in the Rmarkdown file for the topic of MANOVA, to predict iris species based on the predictor variables Sepal.Length, Sepal.Width, Petal.Length, and Petal.Width. Let’s first examine the data as if this were the first time we are working with it.
# load the data
dta_iris <- iris
# examine the structure of the data
str(dta_iris) # four continuous variables and one categorical variable (Species)
## 'data.frame': 150 obs. of 5 variables:
## $ Sepal.Length: num 5.1 4.9 4.7 4.6 5 5.4 4.6 5 4.4 4.9 ...
## $ Sepal.Width : num 3.5 3 3.2 3.1 3.6 3.9 3.4 3.4 2.9 3.1 ...
## $ Petal.Length: num 1.4 1.4 1.3 1.5 1.4 1.7 1.4 1.5 1.4 1.5 ...
## $ Petal.Width : num 0.2 0.2 0.2 0.2 0.2 0.4 0.3 0.2 0.2 0.1 ...
## $ Species : Factor w/ 3 levels "setosa","versicolor",..: 1 1 1 1 1 1 1 1 1 1 ...
# the dataset contains 5 variables and 150 total observations (the sample size is pretty good)
# examine the first six rows of the data
head(dta_iris)
## Sepal.Length Sepal.Width Petal.Length Petal.Width Species
## 1 5.1 3.5 1.4 0.2 setosa
## 2 4.9 3.0 1.4 0.2 setosa
## 3 4.7 3.2 1.3 0.2 setosa
## 4 4.6 3.1 1.5 0.2 setosa
## 5 5.0 3.6 1.4 0.2 setosa
## 6 5.4 3.9 1.7 0.4 setosa
As mentioned in the lecture, a key assumption of linear discriminant analysis is that the groups share a common variance-covariance matrix. Before the main analysis, let’s standardize our continuous predictors so that each has a mean of 0 and a standard deviation of 1. This puts all predictors on the same scale, which makes the discriminant coefficients comparable across predictors. Note, however, that standardizing does not by itself make the within-group covariance matrices equal, so that assumption still has to be checked.
We can quickly standardize the predictors in R by using the scale() function:
# scale each predictor variable (i.e., the first 4 columns)
dta_iris[1:4] <- scale(dta_iris[1:4])
We can quickly verify that each predictor variable now has a mean of 0 and a standard deviation of 1:
library(tidyverse)
## ── Attaching core tidyverse packages ──────────────────────── tidyverse 2.0.0 ──
## ✔ dplyr 1.1.4 ✔ readr 2.1.5
## ✔ forcats 1.0.0 ✔ stringr 1.5.1
## ✔ ggplot2 3.5.1 ✔ tibble 3.2.1
## ✔ lubridate 1.9.3 ✔ tidyr 1.3.1
## ✔ purrr 1.0.2
## ── Conflicts ────────────────────────────────────────── tidyverse_conflicts() ──
## ✖ dplyr::filter() masks stats::filter()
## ✖ dplyr::lag() masks stats::lag()
## ℹ Use the conflicted package (<http://conflicted.r-lib.org/>) to force all conflicts to become errors
library(dplyr)
library(rstatix) # the package we used earlier for MANOVA
##
## Attaching package: 'rstatix'
##
## The following object is masked from 'package:stats':
##
## filter
# Examine the mean and sd of the variables
dta_iris %>%
get_summary_stats(Sepal.Length, Petal.Length, Sepal.Width, Petal.Width, type = "mean_sd")
## # A tibble: 4 × 4
## variable n mean sd
## <fct> <dbl> <dbl> <dbl>
## 1 Sepal.Length 150 0 1
## 2 Petal.Length 150 0 1
## 3 Sepal.Width 150 0 1
## 4 Petal.Width 150 0 1
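Standardizing rescales each predictor as a whole, so it leaves the relative spread of that predictor across species unchanged. The sketch below (on the built-in iris data, so it runs on its own) confirms that scale() simply computes (x - mean) / sd, and that the ratio of within-species standard deviations of Petal.Width is identical before and after scaling. A formal check of equal covariance matrices, such as Box’s M test (available as box_m() in rstatix), is still needed.

```r
# Standardize Petal.Width by hand: scale() computes (x - mean(x)) / sd(x)
x <- iris$Petal.Width
z <- (x - mean(x)) / sd(x)
stopifnot(isTRUE(all.equal(z, as.numeric(scale(x)))))

# Within-species standard deviations before and after standardizing
sd_raw <- sapply(split(x, iris$Species), sd)
sd_std <- sapply(split(z, iris$Species), sd)
round(sd_std, 3)

# Spreads relative to setosa are identical: scaling does not make the groups more alike
round(sd_raw / sd_raw[["setosa"]], 3)
round(sd_std / sd_std[["setosa"]], 3)
```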
There is one last step to prepare our data for analysis: separating the data into two subsets, a training set and a testing set. We will use the training set to build our predictive model and then use the testing set to evaluate the accuracy of that model. For convenience, we will use a 70/30 split, randomly assigning about 70% of the observations to the training set and the remaining 30% to the testing set.
# make this example reproducible
set.seed(123)
# randomly flag ~70% of the rows for training and ~30% for testing
# (named train_rows to avoid masking base R's sample() function)
train_rows <- sample(c(TRUE, FALSE), nrow(dta_iris), replace = TRUE, prob = c(0.7, 0.3))
dta_train <- dta_iris %>%
filter(train_rows)
dta_test <- dta_iris %>%
filter(!train_rows)
nrow(dta_train) # about 70% of the data (106 of 150; the random split is not exact)
## [1] 106
nrow(dta_test) # about 30% of the data (44 of 150)
## [1] 44
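Because each row is assigned TRUE or FALSE independently, the split above is only approximately 70/30. If an exact split is preferred, one option (not used in the rest of this file) is to sample row indices directly. A minimal sketch on the built-in iris data; train_idx, iris_train, and iris_test are hypothetical names chosen for illustration:

```r
set.seed(123)
n <- nrow(iris)
# Draw exactly 70% of the row indices for the training set
train_idx <- sample(seq_len(n), size = round(0.7 * n))
iris_train <- iris[train_idx, ]
iris_test <- iris[-train_idx, ]   # all remaining rows
c(train = nrow(iris_train), test = nrow(iris_test))
```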
The assumption-checking process is skipped here. For more detailed information about checking and evaluating the assumptions of discriminant analysis, please refer to the lecture slides and the previous Rmarkdown file for MANOVA.
Here, the question we are trying to answer is: can the IVs significantly predict DV group membership?
As mentioned in the lecture, the fundamental equations for testing the significance of a set of discriminant functions are the same as those for MANOVA, so we can answer this question with a MANOVA:
# MANOVA model (Manova() comes from the car package)
model <- lm(cbind(Sepal.Length, Petal.Length, Sepal.Width, Petal.Width) ~ Species, data = dta_iris)
car::Manova(model, test.statistic = "Wilks")
##
## Type II MANOVA Tests: Wilks test statistic
## Df test stat approx F num Df den Df Pr(>F)
## Species 2 0.023439 199.15 8 288 < 2.2e-16 ***
## ---
## Signif. codes: 0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1
# Effect size
effectsize::eta_squared(model, partial = FALSE)
## # Effect Size for ANOVA (Type I)
##
## Response | Parameter | Eta2 | 95% CI
## ----------------------------------------------
## Sepal.Length | Species | 0.62 | [0.54, 1.00]
## Petal.Length | Species | 0.94 | [0.93, 1.00]
## Sepal.Width | Species | 0.40 | [0.30, 1.00]
## Petal.Width | Species | 0.93 | [0.91, 1.00]
# Post-hoc test
dta_iris %>%
gather(key = "Variable", value = "Value", Sepal.Length, Petal.Length, Sepal.Width, Petal.Width) %>%
group_by(Variable) %>%
tukey_hsd(Value ~ Species) # Pair-wise comparison
## # A tibble: 12 × 10
## Variable term group1 group2 null.value estimate conf.low conf.high p.adj
## * <chr> <chr> <chr> <chr> <dbl> <dbl> <dbl> <dbl> <dbl>
## 1 Petal.Le… Spec… setosa versi… 0 1.59 1.47 1.70 3 e-15
## 2 Petal.Le… Spec… setosa virgi… 0 2.32 2.20 2.43 3 e-15
## 3 Petal.Le… Spec… versi… virgi… 0 0.732 0.616 0.847 3 e-15
## 4 Petal.Wi… Spec… setosa versi… 0 1.42 1.29 1.54 3 e-15
## 5 Petal.Wi… Spec… setosa virgi… 0 2.34 2.21 2.46 3 e-15
## 6 Petal.Wi… Spec… versi… virgi… 0 0.918 0.791 1.05 3 e-15
## 7 Sepal.Le… Spec… setosa versi… 0 1.12 0.829 1.42 3.39e-14
## 8 Sepal.Le… Spec… setosa virgi… 0 1.91 1.62 2.20 3 e-15
## 9 Sepal.Le… Spec… versi… virgi… 0 0.787 0.493 1.08 8.29e- 9
## 10 Sepal.Wi… Spec… setosa versi… 0 -1.51 -1.88 -1.14 3.10e-14
## 11 Sepal.Wi… Spec… setosa virgi… 0 -1.04 -1.41 -0.673 1.36e- 9
## 12 Sepal.Wi… Spec… versi… virgi… 0 0.468 0.0990 0.837 8.78e- 3
## # ℹ 1 more variable: p.adj.signif <chr>
From the results, we can conclude that the four flower features significantly predict the species of the flower. For each flower feature, there is a significant mean difference between every pair of species.
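The shared machinery between the MANOVA test and discriminant analysis can be made concrete. Wilks’ lambda equals det(E) / det(E + H), where E and H are the within-group and between-group SSCP matrices, and it also equals the product of 1 / (1 + λ_i) over the eigenvalues λ_i of E⁻¹H; the eigenvectors with nonzero eigenvalues define the discriminant functions. The sketch below computes both forms on the raw iris data and compares them with base R’s manova(). Because Wilks’ lambda does not change when predictors are standardized, it should also agree with the 0.023439 reported above.

```r
# Within-group (E) and between-group (H) SSCP matrices for the four predictors
Y <- as.matrix(iris[, 1:4])
g <- iris$Species
grand_mean <- colMeans(Y)
p <- ncol(Y)
E <- matrix(0, p, p)
H <- matrix(0, p, p)
for (lev in levels(g)) {
  Yk <- Y[g == lev, , drop = FALSE]
  mk <- colMeans(Yk)
  E <- E + crossprod(sweep(Yk, 2, mk))              # deviations from the group mean
  H <- H + nrow(Yk) * tcrossprod(mk - grand_mean)   # group mean vs. grand mean
}

# Two equivalent forms of Wilks' lambda
wilks_det <- det(E) / det(E + H)
lambdas <- Re(eigen(solve(E) %*% H, only.values = TRUE)$values)
wilks_eig <- prod(1 / (1 + lambdas))

# Reference value from base R
wilks_ref <- summary(manova(Y ~ g), test = "Wilks")$stats[1, "Wilks"]
round(c(det = wilks_det, eigen = wilks_eig, manova = wilks_ref), 6)

# Nonzero eigenvalues = number of discriminant functions, min(p, groups - 1)
sum(lambdas > 1e-8)
```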
Using discriminant functions, we are trying to answer a different set of questions:
Let’s start with linear discriminant analysis (LDA), as it is the most restrictive of these models.
LDA can be easily computed using the function lda() from the MASS package. Let’s compute our first linear discriminant analysis model. Here I would like to highlight a key argument of lda():
formula: a formula of the form groups ~ x1 + x2 + ... . Note that this formula is the opposite of that for an ANOVA: the grouping variable is the response here.
library(MASS)
##
## Attaching package: 'MASS'
## The following object is masked from 'package:rstatix':
##
## select
## The following object is masked from 'package:dplyr':
##
## select
model <- lda(Species~., data = dta_train)
model
## Call:
## lda(Species ~ ., data = dta_train)
##
## Prior probabilities of groups:
## setosa versicolor virginica
## 0.3301887 0.3396226 0.3301887
##
## Group means:
## Sepal.Length Sepal.Width Petal.Length Petal.Width
## setosa -1.09089516 0.7992840 -1.3098556 -1.2548267
## versicolor 0.07849615 -0.6668712 0.2519562 0.1357115
## virginica 0.95518021 -0.3019712 1.0240277 1.1066415
##
## Coefficients of linear discriminants:
## LD1 LD2
## Sepal.Length 0.4412815 -0.2436325
## Sepal.Width 0.9265112 -0.8423339
## Petal.Length -3.4642736 2.0185682
## Petal.Width -2.7144209 -2.2894355
##
## Proportion of trace:
## LD1 LD2
## 0.993 0.007
LDA determines group means and computes, for each individual, the probability of belonging to each of the groups. The individual is then assigned to the group with the highest probability score.
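This assignment rule can be reproduced by hand. Assuming multivariate normal classes that share one pooled covariance matrix Σ, the score for class k is δ_k(x) = xᵀΣ⁻¹μ_k − ½μ_kᵀΣ⁻¹μ_k + log π_k, and normalizing exp(δ_k) across classes gives the posterior probabilities. A minimal sketch, fitting lda() on the built-in iris data and computing the pooled covariance with denominator n − g (the number of observations minus the number of groups):

```r
library(MASS)
fit <- lda(Species ~ ., data = iris)
X <- as.matrix(iris[, 1:4])
mu <- fit$means          # group means, one row per species
prior <- fit$prior       # prior probabilities (class proportions)

# Pooled within-group covariance matrix
resid <- X - mu[as.integer(iris$Species), ]
S <- crossprod(resid) / (nrow(X) - nlevels(iris$Species))
Sinv <- solve(S)

# Linear discriminant score for each observation (rows) and class (columns)
scores <- sapply(seq_len(nrow(mu)), function(k) {
  m <- mu[k, ]
  drop(X %*% Sinv %*% m) - 0.5 * drop(t(m) %*% Sinv %*% m) + log(prior[k])
})

# Softmax over classes gives posterior probabilities
post <- exp(scores - apply(scores, 1, max))
post <- post / rowSums(post)
max(abs(post - predict(fit)$posterior))
```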
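The lda() output contains the following elements:
Prior probabilities of groups: the proportion of training observations in each group. For example, about 33% of the training observations belong to the setosa group.
Group means: the mean of each predictor variable within each group (the group centroids).
Coefficients of linear discriminants: the weights that combine the predictor variables into each linear discriminant function (LD1 and LD2).
Proportion of trace: the share of the between-group separation captured by each discriminant function. Here LD1 accounts for 99.3% and LD2 for only 0.7%.
Variable selection: Note that, if the predictor variables are standardized before computing LDA, the discriminant weights (coefficients) can be used as rough measures of variable importance for feature selection.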
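The reason standardized coefficients are comparable can be checked directly: rescaling the predictors does not change LDA’s classifications, it only multiplies each predictor’s coefficient by that predictor’s standard deviation. A minimal sketch on the built-in iris data (absolute values are compared, to stay robust to sign conventions):

```r
library(MASS)
raw <- iris
std <- iris
std[1:4] <- scale(std[1:4])

fit_raw <- lda(Species ~ ., data = raw)
fit_std <- lda(Species ~ ., data = std)

# Classifications do not depend on the scale of the predictors
same_class <- identical(predict(fit_raw)$class, predict(fit_std)$class)

# Standardized coefficient = raw coefficient x predictor SD (rows of scaling are predictors)
sds <- sapply(iris[1:4], sd)
coef_check <- abs(fit_raw$scaling * sds)
round(cbind(coef_check, abs(fit_std$scaling)), 3)

# Rank predictors by the magnitude of their standardized LD1 coefficient
sort(abs(fit_std$scaling[, "LD1"]), decreasing = TRUE)
```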
Now we have built our LDA model, but we do not yet know how good the model is. Once we’ve fit the model using our training data, we can use it to make predictions on our test data.
#use LDA model to make predictions on test data
predictions <- predict(model, dta_test)
names(predictions)
## [1] "class" "posterior" "x"
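This returns a list with three elements: class, the predicted class of each observation; posterior, a matrix of posterior probabilities that each observation belongs to each class; and x, the linear discriminant scores of each observation.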
We can quickly view each of these results for the first six observations in our test dataset:
#view predicted class for first six observations in test set
head(predictions$class)
## [1] setosa setosa setosa setosa setosa setosa
## Levels: setosa versicolor virginica
#view posterior probabilities for first six observations in test set
head(predictions$posterior)
## setosa versicolor virginica
## 1 1 1.368345e-19 3.862237e-43
## 2 1 5.821808e-19 3.138592e-42
## 3 1 7.504010e-26 1.959362e-51
## 4 1 6.042884e-23 1.505230e-47
## 5 1 6.272356e-27 7.218899e-53
## 6 1 8.088751e-32 1.435061e-58
#view linear discriminants for first six observations in test set
head(predictions$x)
## LD1 LD2
## 1 7.589072 0.64800788
## 2 7.445525 0.65736554
## 3 8.917770 -0.54094463
## 4 8.296391 -0.04008727
## 5 9.147257 -0.73754007
## 6 10.082879 -2.77930577
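The three elements are linked. In MASS’s predict() method for lda objects, x is obtained by centering the predictors at the prior-weighted average of the group means and multiplying by the coefficient (scaling) matrix, and class is the level with the largest posterior probability. The sketch below checks both relationships on the built-in iris data; the centering detail is stated from my reading of MASS and is verified by the check rather than taken as documented API.

```r
library(MASS)
fit <- lda(Species ~ ., data = iris)
pred <- predict(fit, iris)

# x: centered predictors projected onto the discriminant coefficients
center <- colSums(fit$prior * fit$means)   # prior-weighted average of group means
scores <- scale(as.matrix(iris[, 1:4]), center = center, scale = FALSE) %*% fit$scaling
max(abs(scores - pred$x))

# class: the species with the highest posterior probability
top <- colnames(pred$posterior)[max.col(pred$posterior)]
all(top == as.character(pred$class))
```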
We can use the following code to compute the proportion of test observations for which the LDA model predicted Species correctly:
#find accuracy of model
mean(predictions$class==dta_test$Species)
## [1] 0.9545455
It turns out that the model correctly predicted the Species for 95.45% of the observations in our test dataset. This is very good.
In real-world applications, an LDA model will rarely classify this accurately; the iris data set simply has well-separated classes, so machine learning algorithms such as LDA tend to perform very well on it.
We can also compute the confusion matrix (rows are predicted species, columns are actual species) to examine the classification accuracy more closely.
table(predictions$class, dta_test$Species)
##
## setosa versicolor virginica
## setosa 15 0 0
## versicolor 0 13 1
## virginica 0 1 14
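The accuracy reported earlier can be read directly off this table, because correct predictions lie on the diagonal. The sketch below re-enters the counts printed above (rows are predicted species, columns are actual species) and also computes per-species recall, the share of each actual species that was classified correctly.

```r
species <- c("setosa", "versicolor", "virginica")
# Counts copied from the LDA confusion matrix above
cm <- matrix(c(15,  0,  0,
                0, 13,  1,
                0,  1, 14),
             nrow = 3, byrow = TRUE,
             dimnames = list(predicted = species, actual = species))

accuracy <- sum(diag(cm)) / sum(cm)   # correct predictions / all predictions
recall <- diag(cm) / colSums(cm)      # correct / total per actual species
round(accuracy, 4)
round(recall, 3)
```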
Lastly, we can create an LDA plot (i.e., discriminant function plot) to view the linear discriminants of the model and visualize how well it separated the three different species in our dataset:
#define data to plot
lda_plot <- cbind(dta_train, predict(model)$x)
library(ggplot2)
#create plot
ggplot(lda_plot, aes(LD1, LD2)) +
geom_point(aes(color = Species))
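From the discriminant function plot, we can see that the three species are clearly separated by the linear discriminant functions. Given the proportion of trace reported above (0.993 for LD1), almost all of this separation should run along the LD1 axis.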
Quadratic discriminant analysis (QDA) is a little more flexible than LDA, in the sense that it does not assume equality of the variance-covariance matrices. In other words, for QDA the covariance matrix can be different for each class. LDA tends to be a better choice than QDA when you have a small training set.
In contrast, QDA is recommended if the training set is very large, so that the variance of the classifier is not a major issue, or if the assumption of a common covariance matrix for the K classes is clearly untenable (James et al. 2014).
QDA can be computed using the R function qda() from the MASS package. The QDA procedure is essentially the same as that for LDA.
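What a class-specific covariance means for the decision rule can be sketched by hand: each class k contributes log π_k − ½ log det(Σ_k) − ½ (x − μ_k)ᵀΣ_k⁻¹(x − μ_k), where Σ_k is that class’s own covariance matrix (the constant −(p/2) log 2π cancels). Because Σ_k differs across classes, the quadratic terms no longer cancel and the boundaries become quadratic. A minimal sketch on the built-in iris data, compared with qda():

```r
library(MASS)
fit <- qda(Species ~ ., data = iris)
X <- as.matrix(iris[, 1:4])

# Log of prior x Gaussian density for each class, each with its own covariance
log_score <- sapply(levels(iris$Species), function(lev) {
  Xk <- X[iris$Species == lev, ]
  Sk <- cov(Xk)                                              # class-specific covariance
  logdet <- as.numeric(determinant(Sk, logarithm = TRUE)$modulus)
  -0.5 * mahalanobis(X, colMeans(Xk), Sk) - 0.5 * logdet + log(fit$prior[[lev]])
})

# Softmax over classes gives posterior probabilities
post <- exp(log_score - apply(log_score, 1, max))
post <- post / rowSums(post)
max(abs(post - predict(fit)$posterior))
```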
library(MASS)
# Fit the model
model <- qda(Species~., data = dta_train)
model
## Call:
## qda(Species ~ ., data = dta_train)
##
## Prior probabilities of groups:
## setosa versicolor virginica
## 0.3301887 0.3396226 0.3301887
##
## Group means:
## Sepal.Length Sepal.Width Petal.Length Petal.Width
## setosa -1.09089516 0.7992840 -1.3098556 -1.2548267
## versicolor 0.07849615 -0.6668712 0.2519562 0.1357115
## virginica 0.95518021 -0.3019712 1.0240277 1.1066415
# Make predictions
predictions <- predict(model, dta_test)
# Model accuracy
mean(predictions$class==dta_test$Species)
## [1] 0.8863636
table(predictions$class, dta_test$Species)
##
## setosa versicolor virginica
## setosa 15 0 0
## versicolor 0 10 1
## virginica 0 4 14
The LDA classifier assumes that each class comes from a single normal (or Gaussian) distribution, which can be too restrictive. Mixture discriminant analysis (MDA) relaxes this: each class is assumed to be a Gaussian mixture of subclasses, and each data point has a probability of belonging to each subclass. Equality of the covariance matrix among classes is still assumed. MDA might outperform LDA and QDA in situations where there are subgroups within each group.
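In mda(), the number of Gaussian subclasses per class is set by the subclasses argument (3 by default). A minimal sketch on the built-in iris data that fits models with two and three subclasses per species and reports the training accuracy of each. The subclass initialization may involve random starts, so a seed is set for reproducibility; training accuracy is shown only to illustrate the argument and is an optimistic estimate.

```r
library(mda)
set.seed(123)
n_sub <- c(2, 3)
train_acc <- sapply(n_sub, function(k) {
  fit <- mda(Species ~ ., data = iris, subclasses = k)   # k Gaussian subclasses per species
  mean(predict(fit, iris) == iris$Species)               # training accuracy
})
names(train_acc) <- paste0("subclasses_", n_sub)
round(train_acc, 3)
```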
# install.packages("mda")
library(mda)
## Loading required package: class
## Loaded mda 0.5-5
# Fit the model
model <- mda(Species~., data = dta_train)
model
## Call:
## mda(formula = Species ~ ., data = dta_train)
##
## Dimension: 4
##
## Percent Between-Group Variance Explained:
## v1 v2 v3 v4
## 95.25 99.50 99.83 100.00
##
## Degrees of Freedom (per dimension): 5
##
## Training Misclassification Error: 0 ( N = 106 )
##
## Deviance: 1.248
# Make predictions
predictions <- predict(model, dta_test)
# Model accuracy
mean(predictions==dta_test$Species)
## [1] 0.9090909
table(predictions, dta_test$Species)
##
## predictions setosa versicolor virginica
## setosa 15 0 0
## versicolor 0 11 1
## virginica 0 3 14
Flexible discriminant analysis (FDA) is a flexible extension of LDA that can use non-linear combinations of predictors, such as splines. FDA is useful for modeling multivariate non-normality or non-linear relationships among variables within each group, which can allow for more accurate classification.
# Fit the model
model <- fda(Species~., data = dta_train)
# Make predictions
predictions <- predict(model, dta_test)
# Model accuracy
mean(predictions==dta_test$Species)
## [1] 0.9545455
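By default, fda() uses method = polyreg, which fits linear regressions, so the model above is expected to behave much like LDA (its test accuracy matches the LDA result). To obtain the non-linear fits described earlier, a different regression method can be passed; the mda package provides mars() for this purpose. A minimal sketch on the built-in iris data, using a hypothetical train/test split of its own:

```r
library(mda)
set.seed(123)
idx <- sample(seq_len(nrow(iris)), size = 105)
train <- iris[idx, ]
test <- iris[-idx, ]

fit_linear <- fda(Species ~ ., data = train)                # default: method = polyreg
fit_mars <- fda(Species ~ ., data = train, method = mars)   # MARS basis functions

acc <- c(
  polyreg = mean(predict(fit_linear, test) == test$Species),
  mars = mean(predict(fit_mars, test) == test$Species)
)
round(acc, 3)
```

On a data set as cleanly separated as iris, the more flexible fit is not guaranteed to improve test accuracy.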
Regularized discriminant analysis (RDA) builds a classification rule by regularizing the group covariance matrices (Friedman 1989), which makes the model more robust to multicollinearity in the data. This can be very useful for a large multivariate data set containing highly correlated predictors.
RDA is a kind of trade-off between LDA and QDA. Recall that LDA assumes a common covariance matrix for all of the classes, whereas QDA allows a different covariance matrix for each class; RDA is an intermediate between the two.
RDA shrinks the separate covariances of QDA toward a common covariance as in LDA. This improves the estimates of the covariance matrices in situations where the number of predictors is larger than the number of samples in the training data, potentially leading to an improvement in model accuracy.
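The shrinkage idea can be written down directly. In a simplified form of Friedman’s (1989) scheme, each class covariance Σ_k is first blended with the pooled covariance Σ using λ, giving Σ_k(λ) = (1 − λ)Σ_k + λΣ, and the result is then shrunk toward a scaled identity matrix using γ: (1 − γ)Σ_k(λ) + γ (tr Σ_k(λ) / p) I. Setting λ = 0 and γ = 0 recovers QDA’s class covariances, while λ = 1 and γ = 0 recovers LDA’s pooled covariance. The helper regularized_cov() below is hypothetical, written for illustration only; klaR’s rda() exposes corresponding gamma and lambda arguments and, when they are not supplied, estimates them from the training data.

```r
# Hypothetical helper: simplified Friedman-style regularized covariance
regularized_cov <- function(S_k, S_pooled, lambda, gamma) {
  S_l <- (1 - lambda) * S_k + lambda * S_pooled                 # blend toward pooled
  p <- nrow(S_l)
  (1 - gamma) * S_l + gamma * (sum(diag(S_l)) / p) * diag(p)    # shrink toward scaled identity
}

X <- as.matrix(iris[, 1:4])
g <- iris$Species
n <- nrow(X)

# Class-specific covariances and the pooled within-group covariance (denominator n - groups)
S_k <- lapply(split(as.data.frame(X), g), function(d) cov(as.matrix(d)))
S_pooled <- Reduce(`+`, lapply(names(S_k), function(lev) (sum(g == lev) - 1) * S_k[[lev]])) /
  (n - nlevels(g))

setosa_qda <- regularized_cov(S_k$setosa, S_pooled, lambda = 0, gamma = 0)
setosa_lda <- regularized_cov(S_k$setosa, S_pooled, lambda = 1, gamma = 0)
setosa_mid <- regularized_cov(S_k$setosa, S_pooled, lambda = 0.5, gamma = 0.1)
round(setosa_mid, 3)
```

Note that the γ step preserves the trace of Σ_k(λ): it redistributes variance toward the diagonal rather than adding or removing it.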
# install.packages("klaR")
library(klaR)
# Fit the model
model <- rda(Species~., data = dta_train)
# Make predictions
predictions <- predict(model, dta_test)
# Model accuracy
mean(predictions$class==dta_test$Species)
## [1] 0.9545455
table(predictions$class, dta_test$Species)
##
## setosa versicolor virginica
## setosa 15 0 0
## versicolor 0 13 1
## virginica 0 1 14