Machine learning extraction

Machine learning extraction is the process of applying statistical algorithms to data to extract meaningful patterns or insights. One common algorithm is Linear Discriminant Analysis (LDA), a supervised method used for classification tasks. LDA finds the linear combinations of features that best separate the classes in the data.

The motivation behind machine learning extraction techniques like LDA is to classify and predict outcomes from input data. By learning the patterns that distinguish classes in labeled training data, a model can predict the classes of new, unseen data; its accuracy on a held-out test set measures how well it does so.

In the following script, we load the MNIST dataset, preprocess the data, fit an LDA model, make predictions on the test set, and calculate the model's accuracy. We then visualize the model's predictions on a sample of images from the test set.

# Load necessary packages
library(keras)  # provides dataset_mnist() and array_reshape()
library(knitr)

# Load the MNIST dataset (downloaded and cached on first use)
mnist <- dataset_mnist()

Preparation

This code chunk pertains to the preparation and preprocessing of the MNIST dataset for subsequent analysis, specifically using Linear Discriminant Analysis (LDA).

  1. Loading and Assigning Data:
    • x_train and y_train variables are assigned the training features (pixel values of images) and labels (digit classes) from the MNIST dataset, respectively.
    • x_test and y_test variables are assigned the test features and labels.
  2. Reshaping and Normalizing Data:
    • The array_reshape() function flattens x_train, a 3D array of shape (number of images, 28, 28), into a 2D matrix of shape (number of images, 784). Each MNIST image is 28x28 pixels, so each row of the result holds one image's 784 pixel values. array_reshape() fills values in row-major (C-style) order by default, so each row lists an image's pixels row by row.
    • The test set x_test is reshaped the same way so that its columns line up with the training data.
    • The pixel values in x_train and x_test are divided by 255, the maximum pixel intensity, which rescales them to the range 0 to 1. Putting features on a common scale can help many machine learning algorithms, especially gradient-trained models such as neural networks, converge faster and perform better.

These steps format and scale the data for training and testing the LDA model on MNIST. Reshaping turns each image into a feature vector, and normalization puts every pixel feature on the same 0 to 1 scale.

# Prepare the data
x_train <- mnist$train$x
y_train <- mnist$train$y
x_test <- mnist$test$x
y_test <- mnist$test$y

# Reshape and normalize the data
x_train <- array_reshape(x_train, c(nrow(x_train), 784))
x_test <- array_reshape(x_test, c(nrow(x_test), 784))
x_train <- x_train / 255
x_test <- x_test / 255
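Because array_reshape() flattens in row-major order while R's own matrix() and dim<- work in column-major order, the pixel order matters later when an image is rebuilt for plotting. The sketch below reproduces row-major flattening in base R on a tiny hypothetical array, so the two orders can be compared without downloading MNIST. The names imgs, flatten_row_major, and flatten_col_major are illustrative and not part of the original script.

```r
# Hypothetical toy data standing in for MNIST: 2 "images" of 3x3 pixels,
# stored like mnist$train$x as an (images, rows, cols) array.
imgs <- array(0:17, dim = c(2, 3, 3))

# Row-major (C-style) flattening, the order array_reshape() uses by default:
# each image's pixels are read row by row into one feature vector.
flatten_row_major <- function(a) {
  t(apply(a, 1, function(img) as.vector(t(img))))
}

# Base R's matrix() fills in column-major order instead.
flatten_col_major <- function(a) {
  matrix(a, nrow = dim(a)[1])
}

flat_row <- flatten_row_major(imgs)
flat_col <- flatten_col_major(imgs)

print(imgs[1, , ])    # first image as a 3x3 matrix
print(flat_row[1, ])  # 0 6 12 2 8 14 4 10 16 (row by row)
print(flat_col[1, ])  # 0 2 4 6 8 10 12 14 16 (column by column)

# Rescale to [0, 1], as done for MNIST with its maximum intensity of 255
flat_scaled <- flat_row / 255
```

Rebuilding an image from a row-major row therefore needs matrix(..., byrow = TRUE), which is what the plotting function in the Plotting Predictions section uses.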

Linear Discriminant Analysis

This code chunk involves the process of Linear Discriminant Analysis (LDA) using the MASS package for classification tasks on the MNIST dataset.

Linear Discriminant Analysis (LDA) is a classification algorithm that aims to find the linear combination of features that best separates the classes in the data. In the context of the MNIST dataset, LDA can be used to classify handwritten digits based on the pixel values of the images.

The key mathematical concepts used in LDA with the MNIST dataset:

  1. Mean Vectors: For each class \(i\) in the dataset, calculate the mean vector \(\mu_i\) of the features (pixel values) corresponding to that class.

\[ \mu_i = \dfrac{1}{n_i} \sum_{x \in X_i} x \]

Where:
    • \(\mu_i\) is the mean vector for class \(i\).
    • \(n_i\) is the number of samples in class \(i\).
    • \(X_i\) is the set of feature vectors for class \(i\).
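The mean-vector formula can be checked on a small dataset. This sketch uses R's built-in iris data as a stand-in for MNIST (an assumption for illustration: it has 4 features and 3 classes rather than 784 pixels and 10 digits) and computes one mean vector \(\mu_i\) per class.

```r
# Hypothetical stand-in for MNIST: the built-in iris data (4 features, 3 classes)
X <- as.matrix(iris[, 1:4])
g <- iris$Species

# rowsum() adds up the rows of X within each class; dividing by the class
# sizes n_i turns those sums into the mean vectors mu_i (one row per class).
n_i <- as.vector(table(g))
class_means <- rowsum(X, g) / n_i

print(class_means)
```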

  2. Scatter Matrices: Calculate the within-class scatter matrix \(S_W\) and the between-class scatter matrix \(S_B\).

\[ S_W = \sum_{i=1}^{c} \sum_{x \in X_i} (x - \mu_i)(x - \mu_i)^T \]

\[ S_B = \sum_{i=1}^{c} n_i (\mu_i - \mu)(\mu_i - \mu)^T \]

Where:
    • \(S_W\) is the within-class scatter matrix.
    • \(S_B\) is the between-class scatter matrix.
    • \(\mu\) is the overall mean vector of all samples.
    • \(c\) is the total number of classes.
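A direct translation of the two scatter-matrix formulas, again on iris as a hypothetical stand-in for MNIST. The last line checks the standard identity that within-class and between-class scatter add up to the total scatter of all samples around the overall mean \(\mu\).

```r
# Hypothetical stand-in for MNIST: the built-in iris data (4 features, 3 classes)
X <- as.matrix(iris[, 1:4])
g <- iris$Species
mu <- colMeans(X)                      # overall mean vector

p <- ncol(X)
S_W <- matrix(0, p, p)                 # within-class scatter
S_B <- matrix(0, p, p)                 # between-class scatter
for (cls in levels(g)) {
  X_i  <- X[g == cls, , drop = FALSE]
  mu_i <- colMeans(X_i)
  D    <- sweep(X_i, 2, mu_i)          # each row is (x - mu_i)^T
  S_W  <- S_W + crossprod(D)           # adds the sum over X_i of (x - mu_i)(x - mu_i)^T
  S_B  <- S_B + nrow(X_i) * outer(mu_i - mu, mu_i - mu)  # adds n_i (mu_i - mu)(mu_i - mu)^T
}

# Sanity check: within-class plus between-class scatter equals the total
# scatter of all points around the overall mean
S_T <- crossprod(sweep(X, 2, mu))
all.equal(S_W + S_B, S_T, check.attributes = FALSE)  # TRUE
```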

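  3. Optimization: Solve the eigenvalue problem \(S_W^{-1} S_B w = \lambda w\) for the eigenvectors \(w\) and eigenvalues \(\lambda\). Because \(S_B\) is built from \(c\) class means whose weighted deviations from \(\mu\) sum to zero, its rank is at most \(c - 1\), so at most \(c - 1\) eigenvalues are non-zero (9 for the ten MNIST digits).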

  4. Projection: Project the input data onto the directions defined by the eigenvectors with the largest eigenvalues.

\[ y = W^T x \]

Where:
    • \(y\) is the projected data.
    • \(W\) is a matrix whose columns are the selected eigenvectors.
    • \(x\) is an input feature vector.
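The optimization and projection steps can be sketched the same way on iris (a stand-in, not MNIST). Rather than passing the non-symmetric product \(S_W^{-1} S_B\) to eigen(), which treats it as a general matrix and may return complex values, this sketch uses the Cholesky factor of \(S_W\) to form an equivalent symmetric problem with the same eigenvalues. It is one common way to solve the problem, not necessarily what MASS::lda() does internally.

```r
# Hypothetical stand-in for MNIST: the built-in iris data (4 features, 3 classes)
X <- as.matrix(iris[, 1:4])
g <- iris$Species
n_i <- as.vector(table(g))
mu <- colMeans(X)
means <- rowsum(X, g) / n_i                        # class means, one row per class

S_W <- crossprod(X - means[as.integer(g), ])       # within-class scatter
S_B <- crossprod(sweep(means, 2, mu) * sqrt(n_i))  # between-class scatter

# With S_W = t(R) %*% R, the symmetric matrix M = R^{-T} S_B R^{-1} has the
# same eigenvalues as S_W^{-1} S_B, and w = R^{-1} v maps its eigenvectors back.
R <- chol(S_W)
R_inv <- backsolve(R, diag(ncol(R)))               # R^{-1}
M <- t(R_inv) %*% S_B %*% R_inv
es <- eigen(M, symmetric = TRUE)                   # real eigenvalues, sorted decreasing
W_all <- R_inv %*% es$vectors                      # eigenvectors of S_W^{-1} S_B

# Keep the c - 1 = 2 directions with the largest eigenvalues and project:
# each row of Y is y^T = x^T W, i.e. y = W^T x for that observation.
W <- W_all[, 1:(nlevels(g) - 1)]
Y <- X %*% W

round(es$values, 6)  # only the first two are numerically non-zero
```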

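  5. Classification: Assign each observation to the class with the highest posterior probability. For lda models, MASS's predict() computes these posteriors by treating each class as Gaussian with a shared covariance matrix and, by default, using the class proportions in the training data as prior probabilities.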
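The Gaussian decision rule can be written out explicitly. With a pooled covariance \(\Sigma = S_W / (n - c)\) and class priors \(\pi_i\), each class gets a linear score \(\delta_i(x) = x^T \Sigma^{-1} \mu_i - \tfrac{1}{2} \mu_i^T \Sigma^{-1} \mu_i + \log \pi_i\), and the class with the largest score is chosen. The sketch below computes these scores on iris (a stand-in for MNIST) and compares them with MASS's own predictions; it assumes MASS is installed, as the original script does.

```r
library(MASS)

# Hypothetical stand-in for MNIST: the built-in iris data (4 features, 3 classes)
X <- as.matrix(iris[, 1:4])
g <- iris$Species
n <- nrow(X)
k <- nlevels(g)

n_i     <- as.vector(table(g))
means   <- rowsum(X, g) / n_i                              # class means (k x p)
prior   <- n_i / n                                         # class proportions as priors
Sigma   <- crossprod(X - means[as.integer(g), ]) / (n - k) # pooled covariance S_W / (n - c)
Sinv_mu <- solve(Sigma, t(means))                          # Sigma^{-1} mu_i, one column per class

# delta_i(x) = x^T Sigma^{-1} mu_i - 0.5 mu_i^T Sigma^{-1} mu_i + log(pi_i)
const  <- -0.5 * colSums(t(means) * Sinv_mu) + log(prior)
scores <- X %*% Sinv_mu + matrix(const, n, k, byrow = TRUE)
manual_class <- factor(levels(g)[max.col(scores, ties.method = "first")],
                       levels = levels(g))

# Compare with MASS's predictions from a model fit to the same data
mass_class <- predict(lda(X, g), X)$class
mean(manual_class == mass_class)  # share of observations where the two agree
```

Because predict() uses all \(c - 1\) discriminant directions by default, and those directions preserve the distances this rule depends on, the two sets of predictions are expected to agree apart from numerical near-ties.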

Using mean vectors, scatter matrices, and eigenvectors, LDA finds a projection that maximizes the separation between classes while minimizing the variance within each class.

# Load the MASS package for Linear Discriminant Analysis
library(MASS)

# Convert y_train to a factor
y_train <- as.factor(y_train)

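# Remove pixels that are constant across all training images;
# lda() stops with an error on variables that are constant within groups
vars_to_keep <- apply(x_train, 2, function(x) length(unique(x)) > 1)
x_train <- x_train[, vars_to_keep]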

# Fit a linear discriminant analysis model
model <- lda(x_train, y_train)
## Warning in lda.default(x, grouping, ...): variables are collinear
# Make predictions on the test set, keeping the same pixel columns used in training
x_test_filtered <- x_test[, vars_to_keep]
predictions <- predict(model, x_test_filtered)

# Calculate accuracy of the model

accuracy <- sum(predictions$class == y_test) / length(y_test)

# Print the accuracy
print(paste("Accuracy of linear discriminant analysis model:", accuracy))
## [1] "Accuracy of linear discriminant analysis model: 0.873"

Plotting Predictions

The code below visualizes a sample of the LDA model's predictions on the MNIST test set alongside the true labels. Its components are:

  1. Selecting Correct and Incorrect Predictions:
    • which() collects the indices of test images whose predicted class matches y_test (correct_indices) and of those whose predicted class does not (incorrect_indices).
  2. Converting Predicted Labels to Numeric Format:
    • predicted_labels <- as.numeric(levels(predictions$class))[predictions$class] converts the predicted class labels from a factor to the digit values they represent. Calling as.numeric() on the factor directly would return the internal level codes (1 to 10) instead of the digits 0 to 9.
  3. Definition of Plotting Function:
    • plot_predictions() shows images together with their predicted and true labels. It takes the image data, true labels, predicted labels, the indices of images to display, and the number of images n. The image data must contain all 784 pixel columns so that each row can be rebuilt into a 28x28 matrix; x_test_filtered, which keeps only the non-constant pixels, cannot be reshaped this way.
  4. Adjusting Plotting Area and Image Display:
    • The function splits the plotting area into a sqrt(n) by sqrt(n) grid, so n should be a perfect square (the calls below use n = 9). It then plots up to n of the given images, each titled with its predicted and true label.
  5. Plotting Correct and Incorrect Predictions:
    • The function is called twice: once with correct_indices and once with incorrect_indices. y_test is passed through as.numeric() so that the true and predicted labels are both numeric.

This segment facilitates the visualization of predictions made by the LDA model on a subset of images from the MNIST test set, enabling a quick assessment of the model’s performance in classifying handwritten digits.
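The flip-and-transpose inside plot_predictions() is easiest to see on a tiny matrix. image(x, y, z) draws z[i, j] at horizontal position i and vertical position j, with j increasing upwards, whereas an image matrix stores its top row first. This sketch uses a hypothetical 3x3 matrix whose values mark pixel positions and checks where each pixel ends up.

```r
# Hypothetical 3x3 "image": row 1 is the top of the picture, as in MNIST.
img <- matrix(1:9, nrow = 3, byrow = TRUE)  # rows: 1 2 3 / 4 5 6 / 7 8 9

# Reverse each column (the top row moves to the bottom), then transpose so that
# z[i, j] = img[nrow(img) + 1 - j, i]: picture column i at height j.
z <- t(apply(img, 2, rev))

z[1, 3]  # 1: the top-left pixel is drawn at x = 1, y = 3 (top left of the plot)
z[3, 1]  # 9: the bottom-right pixel is drawn at x = 3, y = 1 (bottom right)

# Draw it when running interactively
if (interactive()) image(1:3, 1:3, z, col = gray.colors(9))
```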

# Indices of correctly and incorrectly classified test images
correct_indices <- which(predictions$class == y_test)
incorrect_indices <- which(predictions$class != y_test)

# predictions$class is a factor with levels "0" to "9"; convert it to the digit
# values themselves (as.numeric() alone would return the level codes 1 to 10)
predicted_labels <- as.numeric(levels(predictions$class))[predictions$class]

plot_predictions <- function(images, true_labels, pred_labels, indices, n = 9) {
  # Lay out n images in a sqrt(n) x sqrt(n) grid (n should be a perfect square)
  old_par <- par(mfrow = c(sqrt(n), sqrt(n)))
  on.exit(par(old_par))  # restore the previous layout when done
  for (idx in seq_len(min(n, length(indices)))) {
    current_index <- indices[idx]
    # Rebuild the 28x28 image; byrow = TRUE undoes the row-major flattening
    img <- matrix(images[current_index, ], nrow = 28, ncol = 28, byrow = TRUE)

    # image() draws z[i, j] at x = i, y = j with y increasing upwards, so flip
    # the rows (top of the digit to the top of the plot) and then transpose
    img <- apply(img, 2, rev)
    img <- t(img)

    # Display the image with its predicted and true labels
    image(1:28, 1:28, img, col = gray.colors(256), xaxt = 'n', yaxt = 'n',
          main = paste("Predicted:", pred_labels[current_index],
                       "\nTrue:", true_labels[current_index]))
  }
}
# Plot correctly classified test images; pass the full 784-pixel x_test
# (not x_test_filtered) so each row can be rebuilt into a 28x28 image
plot_predictions(x_test, 
                 as.numeric(y_test), 
                 predicted_labels, 
                 correct_indices, 
                 n=9)

# Plot misclassified test images (again using the full 784-pixel x_test)
plot_predictions(x_test, 
                 as.numeric(y_test), 
                 predicted_labels, 
                 incorrect_indices, 
                 n=9)