Machine learning extraction is the process of applying statistical algorithms to data to extract meaningful patterns or insights. One common machine learning algorithm is Linear Discriminant Analysis (LDA), which is used for classification and for supervised dimensionality reduction. LDA looks for the linear combinations of features that best separate the classes in the data.
The motivation for using machine learning extraction techniques like LDA is to classify and predict outcomes from the input data. A model that captures the underlying patterns in the training data can then make predictions on new, unseen data, and its accuracy on a held-out test set indicates how well it generalizes.
In the following script, we load the MNIST dataset, preprocess the data, fit an LDA model, make predictions on the test set, and calculate the model's accuracy. We also visualize the model's predictions on a sample of images from the test set.
# Load necessary packages (keras wraps the Python Keras library, so it needs a working Python backend to download MNIST)
library(keras)
library(knitr)
# Load MNIST dataset
mnist <- dataset_mnist()
- `x_train` and `y_train` are assigned the training features (pixel values of the images) and labels (digit classes) from the MNIST dataset, respectively.
- `x_test` and `y_test` are assigned the test features and labels.
- `array_reshape()` reshapes the 3D array of images in `x_train` (number of images x 28 x 28) into a 2D matrix of size (number of images, 784) for further processing. Each MNIST image is 28x28 pixels, so flattening it yields a vector of length 784, one entry per pixel. By default `array_reshape()` uses row-major order, so each row of the result holds the image's pixel rows laid end to end.
- `x_test` is reshaped the same way so that it is compatible with the training data.
- `x_train` and `x_test` are normalized by dividing them by 255, which maps pixel intensities from the range 0 to 255 to the range 0 to 1. This kind of scaling can improve the convergence and performance of gradient-based learners, especially neural networks. For LDA it is mostly a convention: multiplying every feature by the same positive constant does not change LDA's class assignments in exact arithmetic.

These steps put the data into the samples-by-features matrix that `lda()` expects, with every feature on the same scale, ready for training and testing the LDA model on MNIST.
# Prepare the data
x_train <- mnist$train$x
y_train <- mnist$train$y
x_test <- mnist$test$x
y_test <- mnist$test$y
# Reshape and normalize the data
x_train <- array_reshape(x_train, c(nrow(x_train), 784))
x_test <- array_reshape(x_test, c(nrow(x_test), 784))
x_train <- x_train / 255
x_test <- x_test / 255
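The reshape and scaling steps above can be illustrated without keras. This base-R sketch uses a hypothetical 3x3 "image" in place of a 28x28 MNIST digit and mimics the row-major flattening that `array_reshape()` performs by default; it also shows why an image flattened this way must later be rebuilt with `byrow = TRUE`.

```r
# Hypothetical 3x3 "image" of 0-255 intensities standing in for a 28x28 MNIST digit
img <- matrix(c(  0, 128, 255,
                 64,  32,  16,
                255,   0,   8), nrow = 3, byrow = TRUE)

# Row-major flattening (like array_reshape()'s default order = "C"): rows laid end to end
flat <- as.vector(t(img))

# Column-major flattening (base R's default), which orders the pixels differently
flat_colmajor <- as.vector(img)

# Scale intensities from 0..255 to [0, 1], as done for x_train and x_test
flat_scaled <- flat / 255

# Undoing a row-major flattening requires byrow = TRUE
restored <- matrix(flat, nrow = 3, byrow = TRUE)

print(flat)           # 0 128 255 64 32 16 255 0 8
print(flat_colmajor)  # 0 64 255 128 32 0 255 16 8
```

Mixing the two orders silently scrambles the pixels, which is why the plotting code later in this section rebuilds each image with `byrow = TRUE`.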
The next code chunk applies Linear Discriminant Analysis (LDA) from the MASS package to classify the MNIST digits.
Linear Discriminant Analysis (LDA) is a classification algorithm that aims to find the linear combination of features that best separates the classes in the data. In the context of the MNIST dataset, LDA can be used to classify handwritten digits based on the pixel values of the images.
The key mathematical concepts behind LDA, as applied to MNIST, are as follows. Each class \(i\) (here, each digit) has a mean vector:
\[ \mu_i = \dfrac{1}{n_i} \sum_{x \in X_i} x \]
Where:

- \(\mu_i\) is the mean vector for class \(i\).
- \(n_i\) is the number of samples in class \(i\).
- \(X_i\) is the set of feature vectors for class \(i\).
The spread of the data is described by the within-class and between-class scatter matrices:
\[ S_W = \sum_{i=1}^{c} \sum_{x \in X_i} (x - \mu_i)(x - \mu_i)^T \]
\[ S_B = \sum_{i=1}^{c} n_i (\mu_i - \mu)(\mu_i - \mu)^T \]
Where:

- \(S_W\) is the within-class scatter matrix.
- \(S_B\) is the between-class scatter matrix.
- \(\mu\) is the overall mean vector of all samples.
- \(c\) is the total number of classes (10 for the MNIST digits).
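Optimization: Solve the eigenvalue problem \(S_W^{-1} S_B w = \lambda w\) to find the eigenvectors \(w\) and eigenvalues \(\lambda\). This requires \(S_W\) to be invertible, which fails on raw MNIST: pixels that are 0 in every training image contribute all-zero rows and columns to \(S_W\).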
Projection: Project the input data onto the directions defined by the eigenvectors with the largest eigenvalues.
\[ y = W^T x \]
Where:

- \(y\) is the projected data.
- \(W\) is a matrix whose columns are the selected eigenvectors.
- \(x\) is the input data.
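The eigenvalue problem in the optimization step follows from Fisher's criterion, which scores a direction \(w\) by the ratio of between-class to within-class scatter of the projected data. Setting its gradient to zero (using the symmetry of \(S_B\) and \(S_W\)) gives:

\[ J(w) = \dfrac{w^T S_B w}{w^T S_W w}, \qquad \nabla_w J = \dfrac{2\,(w^T S_W w)\, S_B w - 2\,(w^T S_B w)\, S_W w}{(w^T S_W w)^2} = 0 \]

\[ \Longrightarrow \; S_B w = J(w)\, S_W w \; \Longrightarrow \; S_W^{-1} S_B w = \lambda w, \qquad \lambda = J(w) \]

Each eigenvalue is therefore the criterion value of its eigenvector, so the eigenvectors with the largest eigenvalues are the most discriminative directions. Because \(\sum_{i=1}^{c} n_i (\mu_i - \mu) = 0\), the \(c\) vectors that build \(S_B\) are linearly dependent and \(\operatorname{rank}(S_B) \le c - 1\): with \(c = 10\) digit classes there are at most 9 nonzero eigenvalues.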
Using mean vectors, scatter matrices, and eigenvectors, LDA finds a projection that maximizes the separation between classes while minimizing the variance within each class.
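To make the formulas concrete, here is a minimal base-R sketch that builds \(\mu_i\), \(S_W\) and \(S_B\) exactly as defined above, solves the eigenvalue problem for \(S_W^{-1} S_B\), and projects the data with \(y = W^T x\). It uses the built-in `iris` data (4 features, 3 classes) as a small stand-in for MNIST. It illustrates the math only; `MASS::lda()` reaches its discriminant directions differently (via singular value decompositions), so its `scaling` matrix may differ from `W` by column scaling and sign.

```r
# Minimal sketch of Fisher's LDA from scatter matrices, using iris as a stand-in dataset
X <- as.matrix(iris[, 1:4])
g <- iris$Species

mu  <- colMeans(X)                    # overall mean vector
p   <- ncol(X)
S_W <- matrix(0, p, p)                # within-class scatter
S_B <- matrix(0, p, p)                # between-class scatter
for (cls in levels(g)) {
  Xi   <- X[g == cls, , drop = FALSE]
  mu_i <- colMeans(Xi)                # class mean vector
  centered <- sweep(Xi, 2, mu_i)      # x - mu_i for every row of the class
  S_W <- S_W + crossprod(centered)    # sum of (x - mu_i)(x - mu_i)^T
  d   <- mu_i - mu
  S_B <- S_B + nrow(Xi) * outer(d, d) # n_i (mu_i - mu)(mu_i - mu)^T
}

# Eigen-decomposition of S_W^{-1} S_B; Re() drops any round-off imaginary parts
ev   <- eigen(solve(S_W) %*% S_B)
vals <- Re(ev$values)                        # sorted by decreasing magnitude
W    <- Re(ev$vectors[, 1:(nlevels(g) - 1)]) # keep the c - 1 leading directions

# Project every sample: row k of Y is (W^T x_k)^T
Y <- X %*% W
print(round(vals, 4))
```

Only the first \(c - 1 = 2\) eigenvalues are meaningfully nonzero here. On MNIST the same computation would first need the constant pixels removed so that `solve(S_W)` succeeds.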
# Load the MASS package for Linear Discriminant Analysis
library(MASS)
# Convert y_train to a factor
y_train <- as.factor(y_train)
# Remove constant variables (pixels with the same value in every training image make S_W singular)
vars_to_keep <- apply(x_train, 2, function(x) length(unique(x)) > 1)
x_train <- x_train[, vars_to_keep]
# Fit a linear discriminant analysis model
model <- lda(x_train, y_train)
## Warning in lda.default(x, grouping, ...): variables are collinear
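The reason for dropping constant pixels can be checked on a toy example. In this base-R sketch (hypothetical data, not MNIST), a feature that never varies adds an all-zero row and column to \(S_W\), so \(S_W\) cannot be inverted; the same `length(unique(x)) > 1` test used above removes it and restores full rank. The "variables are collinear" warning that `lda()` still prints appears to indicate that some of the kept pixels remain (nearly) linearly dependent within classes, which this simple filter does not address.

```r
# Hypothetical toy data: 6 samples, 2 classes, 3 features; feature 3 is always 0,
# like an MNIST border pixel that is blank in every training image
x <- cbind(c(1, 2, 4, 5, 7, 6),
           c(3, 1, 2, 8, 9, 7),
           rep(0, 6))
cls <- factor(c("a", "a", "a", "b", "b", "b"))

# Within-class scatter S_W: sum over classes of crossprod() of the class-centered rows
within_scatter <- function(x, cls) {
  Reduce(`+`, lapply(levels(cls), function(k) {
    xk <- x[cls == k, , drop = FALSE]
    centered <- sweep(xk, 2, colMeans(xk))  # subtract the class mean from each row
    crossprod(centered)                      # t(centered) %*% centered
  }))
}

S_W <- within_scatter(x, cls)
print(qr(S_W)$rank)       # below 3: the constant column leaves S_W singular

# Keep only columns with more than one distinct value
keep <- apply(x, 2, function(v) length(unique(v)) > 1)
S_W_kept <- within_scatter(x[, keep], cls)
print(qr(S_W_kept)$rank)  # equals the number of kept columns, so S_W_kept is invertible
```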
# Make predictions on the test set
x_test_filtered <- x_test[, vars_to_keep]  # apply the same pixel filter as for the training data
predictions <- predict(model, x_test_filtered)
# Calculate accuracy of the model
accuracy <- sum(predictions$class == y_test) / length(y_test)
# Print the accuracy
print(paste("Accuracy of linear discriminant analysis model:", accuracy))
## [1] "Accuracy of linear discriminant analysis model: 0.873"
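Accuracy alone hides which classes get confused with each other; on the objects above, `table(predicted = predictions$class, actual = y_test)` would give that breakdown per digit. The sketch below runs the same fit, predict, and score steps end to end on the small built-in `iris` data so that it is self-contained. The every-other-row train/test split is an arbitrary choice for illustration.

```r
library(MASS)  # provides lda(); ships with standard R installations

# Arbitrary illustrative split: odd rows for training, even rows for testing
train_idx <- seq(1, nrow(iris), by = 2)
x_tr <- as.matrix(iris[train_idx, 1:4]);  y_tr <- iris$Species[train_idx]
x_te <- as.matrix(iris[-train_idx, 1:4]); y_te <- iris$Species[-train_idx]

fit  <- lda(x_tr, y_tr)
pred <- predict(fit, x_te)

# Accuracy is the proportion of matching labels; mean() of a logical vector gives it directly
acc <- mean(pred$class == y_te)

# Confusion matrix: rows are predicted classes, columns are true classes
conf_mat <- table(predicted = pred$class, actual = y_te)
print(acc)
print(conf_mat)
```

The diagonal of the confusion matrix holds the correct predictions, so `sum(diag(conf_mat)) / sum(conf_mat)` equals the accuracy.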
The next code chunk uses the model's predictions and the test labels to visualize correctly and incorrectly classified test images. Its components are:
- `correct_indices` and `incorrect_indices` store the positions of the test images whose predicted class, `predictions$class`, does or does not match `y_test`. Check that the format of `predictions$class` is appropriate for further processing.
- `predicted_labels <- as.numeric(levels(predictions$class))[predictions$class]` converts the predicted class labels from a factor to a numeric format. This conversion facilitates the subsequent comparison and plotting operations.
- The `plot_predictions()` function visualizes images alongside their predicted and true labels. It takes the image data, the true labels, the predicted labels, and the indices of the images to display.
- The calls to `plot_predictions()` compare the true labels stored in `y_test` with the predicted labels. Both `y_test` and `predicted_labels` are passed in numeric format so that the displayed labels are the digits themselves. Rebuilding each image as a 28x28 matrix requires all 784 pixel columns, so the unfiltered `x_test` must be passed rather than `x_test_filtered`, which has the constant pixels removed.

This segment visualizes the predictions made by the LDA model on a subset of the MNIST test images, giving a quick qualitative check of how well the model classifies handwritten digits.
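The factor-to-numeric conversion used for `predicted_labels` is easy to get wrong. This small sketch, with a hypothetical factor standing in for `predictions$class`, shows why `as.numeric()` alone is not enough: it returns the factor's internal level codes rather than the digits.

```r
# Hypothetical predicted classes stored as a factor, as predict.lda() returns them
pred_class <- factor(c("7", "2", "1", "0"))  # levels are sorted: "0", "1", "2", "7"

# Wrong: as.numeric() on a factor returns the internal level codes
codes <- as.numeric(pred_class)

# Right: convert the level labels to numbers, then index them by the factor's codes
digits <- as.numeric(levels(pred_class))[pred_class]

print(codes)   # 4 3 2 1
print(digits)  # 7 2 1 0
```

`as.numeric(as.character(pred_class))` gives the same digits; the `levels()` form converts each distinct level only once, which is slightly cheaper for long vectors.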
correct_indices <- which(predictions$class == y_test)
incorrect_indices <- which(predictions$class != y_test)
# predict.lda() returns the class as a factor with levels "0" to "9"; convert the level labels (not the internal codes) to numbers
predicted_labels <- as.numeric(levels(predictions$class))[predictions$class]
plot_predictions <- function(images, true_labels, pred_labels, indices, n = 10) {
  # Arrange the plotting area as a near-square grid with room for n images
  n_rows <- ceiling(sqrt(n))
  old_par <- par(mfrow = c(n_rows, ceiling(n / n_rows)))
  on.exit(par(old_par))  # restore the previous layout when the function exits
  for (idx in seq_len(n)) {
    if (idx > length(indices)) break  # stop early if fewer than n indices are available
    current_index <- indices[idx]
    # Reshape the 784 pixel values back into a 28x28 matrix (row-major, hence byrow = TRUE)
    img <- matrix(images[current_index, ], nrow = 28, ncol = 28, byrow = TRUE)
    # image() draws z[i, j] at x = i, y = j: reverse the row order so the top of the
    # digit is drawn at the top, then transpose so image columns run along the x-axis
    img <- t(apply(img, 2, rev))
    # Display the image with its predicted and true labels
    image(1:28, 1:28, img, col = gray.colors(256), xaxt = "n", yaxt = "n",
          main = paste("Predicted:", pred_labels[current_index],
                       "\nTrue:", true_labels[current_index]))
  }
}
# Plot correct labels (pass the unfiltered x_test: rebuilding a 28x28 image needs all 784 pixels)
plot_predictions(x_test,
                 as.numeric(y_test),
                 predicted_labels,
                 correct_indices,
                 n = 9)
# Plot incorrect labels
plot_predictions(x_test,
                 as.numeric(y_test),
                 predicted_labels,
                 incorrect_indices,
                 n = 9)