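Neural Networks with MNIST

library(keras3)      # R interface to Keras (building and fitting the network)
library(caret)       # confusionMatrix() for evaluating predictions
library(plot.matrix) # plot() method that draws a matrix as a heat map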
1 Introduction
This document adapts exercises from Chapter 10 of An Introduction to Statistical Learning by James et al. It introduces neural networks using the MNIST handwritten digit dataset.
The code below uses keras3, the R interface to Keras. Under the hood, model fitting relies on Python-based deep learning tools (Keras running on a backend such as TensorFlow). Depending on your setup, you may need to install both the R package and its Python dependencies separately.
A neural network is not a completely different idea from regression. It starts with the same basic ingredient: a weighted sum of predictors. The difference is that neural networks stack many weighted sums together and apply nonlinear transformations between them. Nonlinear transformations sound scary, but they play the same role as the (inverse) link functions you already know from generalized linear models (including our old friend, plogis!).
1.1 Linear regression
Linear regression models an outcome as a weighted sum of predictors:
\[ y = \beta_0 + \beta_1 x_1 + \cdots + \beta_p x_p \]
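To make the weighted sum concrete, here is a minimal sketch on simulated data. The variables x1, x2, y and their true coefficients are made up for illustration. It checks that lm()'s fitted values are exactly the design matrix multiplied by the estimated coefficients:

```r
# Simulated data (hypothetical, for illustration only)
set.seed(1)
n  <- 100
x1 <- rnorm(n)
x2 <- rnorm(n)
y  <- 2 + 0.5 * x1 - 1.5 * x2 + rnorm(n)

fit  <- lm(y ~ x1 + x2)
X    <- model.matrix(fit)  # design matrix: a column of 1s, then x1 and x2
beta <- coef(fit)          # estimates of beta_0, beta_1, beta_2

# Each prediction is just the weighted sum beta_0 + beta_1 * x1 + beta_2 * x2
yhat <- as.vector(X %*% beta)
all.equal(yhat, unname(fitted(fit)))
```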
1.2 Logistic regression
Logistic regression also starts with a weighted sum of predictors, but it transforms that sum so predictions become probabilities:
\[ \text{logit}(p) = \beta_0 + \beta_1 x_1 + \cdots + \beta_p x_p \]
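The same check works for logistic regression. In this sketch, again on made-up simulated data, the weighted sum gives the log-odds. plogis(), R's inverse logit, then turns those log-odds into the probabilities that glm() reports:

```r
# Simulated binary outcome (hypothetical, for illustration only)
set.seed(2)
n <- 200
x <- rnorm(n)
y <- rbinom(n, size = 1, prob = plogis(-0.5 + 1.2 * x))

fit  <- glm(y ~ x, family = binomial)
eta  <- as.vector(model.matrix(fit) %*% coef(fit))  # weighted sum = log-odds
prob <- plogis(eta)                                 # inverse logit of the weighted sum

all.equal(prob, unname(fitted(fit)))         # glm's fitted probabilities
all.equal(plogis(eta), 1 / (1 + exp(-eta)))  # plogis is 1 / (1 + exp(-eta))
```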
1.3 Neural networks
A neural network keeps this same core idea of weighted sums, but allows us to:
- create many intermediate combinations of predictors,
- apply nonlinear activation functions after each combination, and
- stack multiple layers together.
So, in a very real sense, a neural network is built from the same ingredients as linear and logistic regression, but with many more layers and much more flexibility.
- Linear regression: one weighted sum of predictors
- Logistic regression: one weighted sum, then a nonlinear transformation to probabilities
- Neural network: many weighted sums, many nonlinear transformations, layered together
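Here is a hand-written sketch of a tiny network. Its sizes are made up: 3 predictors, 4 hidden ReLU units and 1 logistic output, with random weights. It only shows that each layer is a set of weighted sums followed by a nonlinear transformation. It does not reflect how keras stores or trains its weights:

```r
# A hypothetical network: 3 predictors -> 4 hidden units (ReLU) -> 1 output (logistic)
set.seed(3)
relu <- function(z) pmax(z, 0)  # elementwise max(0, z), keeps the matrix shape

x  <- c(0.2, -1.0, 0.5)                # one observation with p = 3 predictors
W1 <- matrix(rnorm(4 * 3), nrow = 4)   # 4 hidden units, each with 3 weights
b1 <- rnorm(4)                         # one intercept ("bias") per hidden unit
W2 <- matrix(rnorm(1 * 4), nrow = 1)   # output unit: one weight per hidden unit
b2 <- rnorm(1)                         # output intercept

h <- relu(W1 %*% x + b1)     # layer 1: 4 weighted sums, then ReLU
p <- plogis(W2 %*% h + b2)   # layer 2: 1 weighted sum, then inverse logit
p
```

If W1 were dropped and h replaced by x, this would be exactly the logistic regression above. The hidden layer is what adds the extra, learned combinations of predictors.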
2 The data
We will work with the MNIST dataset, which contains 60,000 training images and 10,000 test images of handwritten digits from 0 through 9. Each image is 28 × 28 pixels.
mnist <- dataset_mnist()

The training images are stored as a three-dimensional array:
dim(mnist$train$x)
[1] 60000 28 28
dim(mnist$test$x)
[1] 10000 28 28
That means the training data are arranged as:
- image number,
- row of pixels,
- column of pixels.
Let us visualize a couple of example digits.
plot(mnist$train$x[1, , ], main = mnist$train$y[1])  # first training image, titled with its label
plot(mnist$train$x[2, , ], main = mnist$train$y[2])  # second training image

3 Connecting the data to regression thinking
In linear or logistic regression, each observation is usually a row in a data frame, and each predictor is a column.
For image data, each pixel will play the role of a predictor.
- One image = one observation
- 28 × 28 = 784 pixels
- Therefore, each image will become a row with 784 predictors
So this neural network is still using the familiar “rows are observations, columns are predictors” setup from regression. The main difference is that the predictors come from image pixels rather than variables like height, temperature, or income.
If you already understand a design matrix in regression, you already understand the input to a neural network. The main new idea is what happens after the first weighted combination of predictors.
4 Processing the data
We first extract the training and test sets.
x_train <- mnist$train$x
g_train <- mnist$train$y
x_test <- mnist$test$x
g_test <- mnist$test$y
dim(x_train)
[1] 60000 28 28
dim(x_test)
[1] 10000 28 28
Now we reshape each 28 × 28 image into a vector of length 784.
# Flatten each 28 x 28 image into one row of 784 pixel values (row-major order)
x_train <- array_reshape(x_train, c(nrow(x_train), 784))
x_test <- array_reshape(x_test, c(nrow(x_test), 784))
dim(x_train)
[1] 60000 784
dim(x_test)
[1] 10000 784
str(x_train)
 num [1:60000, 1:784] 0 0 0 0 0 0 0 0 0 0 ...
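One detail worth knowing is that array_reshape() fills in row-major (C-style) order. Each flattened image therefore lists pixel row 1 from left to right, then row 2, and so on. Base R's matrix() and dim<- fill in column-major order instead, which would order each image's pixels column by column.

For a dense network the ordering does not change accuracy, as long as training and test images are flattened the same way. It does matter if you later want to turn a row back into a picture. This sketch uses a small toy array in place of the images and reproduces the row-major layout in base R with aperm():

```r
# Toy stand-in for image data: 2 "images", each 3 rows x 4 columns of pixels
x <- array(1:24, dim = c(2, 3, 4))

# Row-major flattening in base R: put the column index before the row index,
# then let matrix() fill column-major, giving one image per row
flat <- matrix(aperm(x, c(1, 3, 2)), nrow = dim(x)[1])

dim(flat)      # 2 images x 12 "pixels"
flat[1, 1:4]   # first pixel row of image 1: 1 7 13 19
# Image 1 is read row by row: row 1 (left to right), then row 2, then row 3
all.equal(flat[1, ], as.vector(t(x[1, , ])))
```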
4.1 Why do we reshape?
This is directly analogous to how regression expects a design matrix:
- rows = observations
- columns = predictors
After reshaping:
- x_train has one row per image,
- each of the 784 columns is a pixel value.
If we were to fit a multinomial regression model here, we could use this same predictor matrix. The neural network starts from the same input matrix, but then builds more complex nonlinear features internally.
5 One-hot encoding
The response variable is the digit label: 0, 1, 2, …, 9.
Neural networks for multiclass classification often represent these categories using one-hot encoding, which converts each class label into a binary indicator vector. Well my dudes, this is exactly what happens in a regression model when you include a categorical predictor (the statistical term is a “dummy variable”). The one difference is that regression dummy coding usually drops a reference level, whereas one-hot encoding keeps a column for every class.
For example:
- 3 becomes (0, 0, 0, 1, 0, 0, 0, 0, 0, 0)
- 7 becomes (0, 0, 0, 0, 0, 0, 0, 1, 0, 0)
# Convert digit labels 0-9 into 10 indicator (one-hot) columns
y_train <- to_categorical(g_train, 10)
y_test <- to_categorical(g_test, 10)
dim(y_train)
[1] 60000 10
dim(y_test)
[1] 10000 10
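To see the connection to dummy variables, the sketch below builds indicator columns with base R's model.matrix() for a few hypothetical labels. Dropping the intercept gives all 10 columns, the same layout that to_categorical() produces. The usual regression coding instead keeps an intercept plus 9 dummies, with digit 0 as the reference level:

```r
g <- factor(c(3, 7, 0), levels = 0:9)   # three hypothetical digit labels

# One-hot coding: all 10 indicator columns (no intercept, so no reference level)
onehot <- model.matrix(~ g - 1)
onehot[1, ]   # the label 3 has a 1 in the fourth column (g3)

# Usual regression dummy coding: intercept plus 9 dummies, with 0 as reference
dummies <- model.matrix(~ g)
ncol(dummies)   # 10 = intercept + 9 dummy variables
```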
5.1 Connection to logistic regression
This is the multiclass version of what logistic regression does with a binary response.
- In binary logistic regression, the model predicts the probability of one class.
- In multiclass classification, the model predicts a probability for each possible class.
- The final probabilities are constrained to add up to 1.
That final step is handled by the softmax function, which is the multiclass analogue of the logistic transformation.
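Here is a minimal sketch of softmax in base R. Note that softmax() is a hypothetical helper written here, not a keras function. With two classes and one score fixed at 0, it reduces to plogis(). This is the sense in which softmax is the multiclass analogue of the logistic transformation:

```r
# Softmax: exponentiate each score, then divide by the total
softmax <- function(z) {
  z <- z - max(z)        # subtract the max for numerical stability (result unchanged)
  exp(z) / sum(exp(z))
}

scores <- c(1.0, 2.0, 0.5)
probs  <- softmax(scores)
probs
sum(probs)   # the probabilities add up to 1

# With two classes and the first score fixed at 0, softmax is the logistic function
eta <- 0.8
all.equal(softmax(c(0, eta))[2], plogis(eta))
```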
6 Rescaling
Pixel values range from 0 to 255. We rescale them to the range 0 to 1.
x_train <- x_train / 255  # pixel intensities now lie in [0, 1]
x_test <- x_test / 255

6.1 Why rescale?
Rescaling helps optimization and makes the network easier to train.
This is similar in spirit to standardizing predictors in regression:
- predictors on wildly different scales can make estimation harder,
- gradient-based methods work better when inputs are reasonably scaled.
In neural networks, scaling is especially important because repeated matrix multiplication and nonlinear activations can become unstable when inputs are too large.
7 Developing the model
Now we specify the neural network.
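modelnn <- keras_model_sequential()
modelnn %>%
  # Hidden layer 1: 256 weighted sums of the 784 pixels, each passed through ReLU
  layer_dense(
    units = 256,
    activation = "relu",
    input_shape = c(784)
  ) %>%
  # Randomly zero 40% of the hidden-layer-1 outputs during training
  layer_dropout(rate = 0.4) %>%
  # Hidden layer 2: 128 weighted sums of the 256 hidden values, then ReLU
  layer_dense(units = 128, activation = "relu") %>%
  layer_dropout(rate = 0.3) %>%
  # Output layer: 10 scores turned into class probabilities by softmax
  layer_dense(units = 10, activation = "softmax")

9 How to interpret the layers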
This model has:
- an input layer with 784 predictors,
- a first hidden layer with 256 units,
- a second hidden layer with 128 units,
- an output layer with 10 units, one for each digit.
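Each unit has one weight per input plus an intercept, so the number of parameters follows from simple arithmetic. The helper layer_params() below is hypothetical. summary(modelnn) should report the same total, and the dropout layers contribute no parameters:

```r
# Each dense layer has (inputs x units) weights plus one intercept per unit
layer_params <- function(n_in, n_units) n_in * n_units + n_units

p1 <- layer_params(784, 256)   # 200960
p2 <- layer_params(256, 128)   #  32896
p3 <- layer_params(128, 10)    #   1290
p1 + p2 + p3                   # 235146 parameters in total
```

For comparison, a multinomial regression on the same 784 pixels has roughly 784 x 10 coefficients. The two hidden layers are where most of the extra flexibility (and the extra parameters) come from.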
9.1 Connection to linear regression
Each unit in a dense layer computes its own weighted sum of the layer's inputs plus an intercept (called a bias in neural network terminology), just like regression:
\[ z = \beta_0 + \beta_1 x_1 + \cdots + \beta_p x_p \]
If we stopped there, we would essentially have a linear model.
9.2 Connection to logistic regression
In logistic regression, we take that linear predictor and pass it through a nonlinear function to get probabilities.
Neural networks do the same basic thing, but repeatedly:
- compute weighted sums,
- apply nonlinear activation functions,
- pass the transformed values to the next layer.
9.3 What does ReLU do?
The hidden layers use the ReLU activation function:
\[ \text{ReLU}(x) = \max(0, x) \]
This introduces nonlinearity. Without nonlinear activation functions, stacking layers would still collapse to a linear model.
That is the key reason neural networks can capture more complex relationships than ordinary linear regression.
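The collapse argument can be checked numerically. This sketch uses made-up sizes and random weights. Two stacked layers without an activation give exactly the same output as a single linear layer with weights W2 %*% W1 and intercept W2 %*% b1 + b2. ReLU itself is just pmax(z, 0) applied elementwise, and putting it between the layers is what, in general, prevents this simplification:

```r
set.seed(4)
x  <- rnorm(5)                      # 5 predictors
W1 <- matrix(rnorm(3 * 5), nrow = 3) # layer 1: 3 units
b1 <- rnorm(3)
W2 <- matrix(rnorm(2 * 3), nrow = 2) # layer 2: 2 units
b2 <- rnorm(2)

# Two stacked layers with NO activation ...
two_layers <- W2 %*% (W1 %*% x + b1) + b2
# ... equal ONE linear layer with weights W2 W1 and intercept W2 b1 + b2
one_layer <- (W2 %*% W1) %*% x + (W2 %*% b1 + b2)
all.equal(two_layers, one_layer)

# ReLU keeps positive values and sets negative values to 0
relu <- function(z) pmax(z, 0)
relu(c(-2, 0, 3))   # 0 0 3
```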
9.4 What does softmax do?
The final layer uses softmax. This converts the 10 output scores into probabilities that sum to 1. The class with the highest probability becomes the predicted digit.
This is the multiclass counterpart of logistic regression.
9.5 What does dropout do?
Dropout is a form of regularization. During training, at each gradient update it randomly sets a fraction of the units' outputs to zero: 40% after the first hidden layer and 30% after the second in this model. At prediction time, all units are used.
This is conceptually similar to regularization ideas in regression, such as regularizing priors. The goals are the same:
- avoiding overfitting,
- discouraging the model from relying too heavily on any one predictor or pathway,
- improving performance on new data.
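Keras uses what is often called inverted dropout. During training, each output is zeroed with probability rate, and the surviving outputs are scaled up by 1 / (1 - rate) so that their expected value is unchanged. At prediction time the layer passes values through untouched. The helper dropout_train() below is a hypothetical base R sketch of that training-time behavior, not keras code:

```r
# Hypothetical helper mimicking what a dropout layer does during training
dropout_train <- function(h, rate) {
  keep <- runif(length(h)) >= rate   # each unit is kept with probability 1 - rate
  h * keep / (1 - rate)              # zero the dropped units, rescale the kept ones
}

set.seed(5)
h   <- c(0.9, 0.0, 2.3, 1.1, 0.4, 1.7, 0.2, 0.8, 1.5, 0.6)  # hidden-unit outputs
out <- dropout_train(h, rate = 0.4)
out   # about 40% of entries are 0; the rest are h / 0.6
```

A different random subset is dropped at every update. As a result, no single unit can be relied on, which is the "discouraging reliance on any one pathway" idea in the list above.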
10 Fitting the model
Before training, we compile the model by specifying:
- the loss function,
- the optimizer,
- and the performance metric.
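modelnn %>% compile(
  loss = "categorical_crossentropy",  # multiclass log-loss
  optimizer = optimizer_rmsprop(),    # iterative, gradient-based parameter updates
  metrics = c("accuracy")             # proportion of correctly classified images
)

Now fit the model: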
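system.time(
  history <- modelnn %>%
    fit(
      x_train,
      y_train,
      epochs = 30,            # 30 full passes through the training data
      batch_size = 128,       # 128 images per gradient update
      validation_split = 0.2  # hold out 20% of the training rows to monitor performance
    )
)
Epoch 1/30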
375/375 - 2s - 5ms/step - accuracy: 0.8678 - loss: 0.4317 - val_accuracy: 0.9457 - val_loss: 0.1778
Epoch 2/30
375/375 - 1s - 3ms/step - accuracy: 0.9394 - loss: 0.2045 - val_accuracy: 0.9644 - val_loss: 0.1209
Epoch 3/30
375/375 - 1s - 3ms/step - accuracy: 0.9528 - loss: 0.1589 - val_accuracy: 0.9692 - val_loss: 0.1055
Epoch 4/30
375/375 - 1s - 3ms/step - accuracy: 0.9599 - loss: 0.1336 - val_accuracy: 0.9712 - val_loss: 0.0976
Epoch 5/30
375/375 - 1s - 3ms/step - accuracy: 0.9660 - loss: 0.1132 - val_accuracy: 0.9727 - val_loss: 0.0972
Epoch 6/30
375/375 - 1s - 3ms/step - accuracy: 0.9678 - loss: 0.1070 - val_accuracy: 0.9737 - val_loss: 0.0905
Epoch 7/30
375/375 - 1s - 3ms/step - accuracy: 0.9720 - loss: 0.0951 - val_accuracy: 0.9768 - val_loss: 0.0867
Epoch 8/30
375/375 - 1s - 3ms/step - accuracy: 0.9736 - loss: 0.0873 - val_accuracy: 0.9760 - val_loss: 0.0885
Epoch 9/30
375/375 - 1s - 3ms/step - accuracy: 0.9746 - loss: 0.0835 - val_accuracy: 0.9750 - val_loss: 0.0912
Epoch 10/30
375/375 - 1s - 3ms/step - accuracy: 0.9764 - loss: 0.0803 - val_accuracy: 0.9773 - val_loss: 0.0842
Epoch 11/30
375/375 - 1s - 3ms/step - accuracy: 0.9782 - loss: 0.0710 - val_accuracy: 0.9787 - val_loss: 0.0840
Epoch 12/30
375/375 - 1s - 3ms/step - accuracy: 0.9784 - loss: 0.0711 - val_accuracy: 0.9789 - val_loss: 0.0895
Epoch 13/30
375/375 - 1s - 3ms/step - accuracy: 0.9799 - loss: 0.0664 - val_accuracy: 0.9792 - val_loss: 0.0815
Epoch 14/30
375/375 - 1s - 3ms/step - accuracy: 0.9809 - loss: 0.0662 - val_accuracy: 0.9794 - val_loss: 0.0844
Epoch 15/30
375/375 - 1s - 3ms/step - accuracy: 0.9825 - loss: 0.0610 - val_accuracy: 0.9793 - val_loss: 0.0871
Epoch 16/30
375/375 - 1s - 3ms/step - accuracy: 0.9821 - loss: 0.0598 - val_accuracy: 0.9787 - val_loss: 0.0902
Epoch 17/30
375/375 - 1s - 3ms/step - accuracy: 0.9823 - loss: 0.0589 - val_accuracy: 0.9790 - val_loss: 0.0881
Epoch 18/30
375/375 - 1s - 3ms/step - accuracy: 0.9828 - loss: 0.0566 - val_accuracy: 0.9797 - val_loss: 0.0851
Epoch 19/30
375/375 - 1s - 3ms/step - accuracy: 0.9843 - loss: 0.0541 - val_accuracy: 0.9791 - val_loss: 0.0913
Epoch 20/30
375/375 - 1s - 3ms/step - accuracy: 0.9842 - loss: 0.0523 - val_accuracy: 0.9804 - val_loss: 0.0925
Epoch 21/30
375/375 - 1s - 4ms/step - accuracy: 0.9855 - loss: 0.0498 - val_accuracy: 0.9797 - val_loss: 0.0901
Epoch 22/30
375/375 - 1s - 4ms/step - accuracy: 0.9855 - loss: 0.0472 - val_accuracy: 0.9808 - val_loss: 0.0948
Epoch 23/30
375/375 - 1s - 3ms/step - accuracy: 0.9851 - loss: 0.0483 - val_accuracy: 0.9799 - val_loss: 0.0910
Epoch 24/30
375/375 - 1s - 3ms/step - accuracy: 0.9857 - loss: 0.0471 - val_accuracy: 0.9805 - val_loss: 0.0933
Epoch 25/30
375/375 - 1s - 3ms/step - accuracy: 0.9865 - loss: 0.0464 - val_accuracy: 0.9799 - val_loss: 0.0892
Epoch 26/30
375/375 - 1s - 3ms/step - accuracy: 0.9857 - loss: 0.0459 - val_accuracy: 0.9809 - val_loss: 0.0974
Epoch 27/30
375/375 - 1s - 3ms/step - accuracy: 0.9864 - loss: 0.0449 - val_accuracy: 0.9805 - val_loss: 0.0972
Epoch 28/30
375/375 - 1s - 3ms/step - accuracy: 0.9877 - loss: 0.0419 - val_accuracy: 0.9810 - val_loss: 0.0978
Epoch 29/30
375/375 - 1s - 3ms/step - accuracy: 0.9877 - loss: 0.0421 - val_accuracy: 0.9813 - val_loss: 0.0997
Epoch 30/30
375/375 - 1s - 3ms/step - accuracy: 0.9869 - loss: 0.0427 - val_accuracy: 0.9818 - val_loss: 0.1014
user system elapsed
72.09 13.24 38.37
Plot the training history:
plot(history, smooth = FALSE)

11 How fitting connects to regression
There is a strong connection here to generalized linear models.
11.1 Loss function
The loss function is categorical cross-entropy, which is the multiclass extension of the log-loss used in logistic regression.
So this part of the neural network is very closely related to logistic regression. Minimizing cross-entropy is the same as maximizing the (multinomial) log-likelihood, so both methods estimate parameters by minimizing a likelihood-based loss.
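Here is a small worked example with made-up labels and probabilities. Categorical cross-entropy is the average of minus the log of the probability assigned to the true class. That is the same as the average negative log-likelihood. With two classes, it is exactly the log-loss of logistic regression:

```r
# One-hot labels for 3 hypothetical observations and 3 classes
y <- rbind(c(1, 0, 0),
           c(0, 1, 0),
           c(0, 0, 1))
# Predicted probabilities (each row sums to 1), e.g. softmax output
p <- rbind(c(0.7, 0.2, 0.1),
           c(0.1, 0.8, 0.1),
           c(0.3, 0.3, 0.4))

# Categorical cross-entropy: average of -log(probability given to the true class)
cce <- mean(-rowSums(y * log(p)))
# The same quantity as a negative log-likelihood per observation
nll <- -mean(log(p[y == 1]))
all.equal(cce, nll)

# With two classes this is the familiar logistic-regression log-loss
yb <- c(1, 0, 1)
pb <- c(0.9, 0.2, 0.6)
logloss <- -mean(yb * log(pb) + (1 - yb) * log(1 - pb))
cce2    <- mean(-rowSums(cbind(yb, 1 - yb) * log(cbind(pb, 1 - pb))))
all.equal(logloss, cce2)
```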
11.2 Optimizer
The optimizer (RMSprop) is the algorithm used to update parameters.
In regression classes, parameter estimation is often presented as something solved directly (as with the least-squares formula) or through likelihood maximization. In neural networks the goal is the same, but the parameter space is much larger and there is no closed-form solution, so iterative optimization is essential.
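To illustrate iterative optimization on a model you already know, this sketch fits a logistic regression on simulated data by plain gradient descent and compares the result with glm(). It uses ordinary full-batch gradient descent with a hand-picked step size. It is not RMSprop, which additionally adapts the step size for each parameter:

```r
# Simulated data (hypothetical, for illustration only)
set.seed(6)
n <- 500
x <- rnorm(n)
y <- rbinom(n, size = 1, prob = plogis(-0.5 + 1.2 * x))
X <- cbind(1, x)                        # design matrix with an intercept column

beta <- c(0, 0)                         # start from zero
lr   <- 0.5                             # learning rate (step size), chosen by hand
for (step in 1:5000) {
  p    <- plogis(X %*% beta)            # current predicted probabilities
  grad <- crossprod(X, p - y) / n       # gradient of the average log-loss
  beta <- beta - lr * as.vector(grad)   # take a step downhill
}

fit <- glm(y ~ x, family = binomial)    # R's IRLS fit, for comparison
rbind(gradient_descent = beta, glm = coef(fit))
```

Both routes reach the same maximum-likelihood estimates. A neural network is trained the same way, just with far more parameters and gradients computed automatically.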
11.3 Epochs and batches
- An epoch means one full pass through the training data.
- A batch is a subset of observations used for one gradient update.
This differs from the usual regression workflow, where we often imagine all parameters being estimated from the full dataset at once.
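The arithmetic behind the training log is worth spelling out. This assumes the settings from the fit() call above: 60,000 training images, validation_split = 0.2, batch_size = 128 and 30 epochs.

```r
n_train    <- 60000
val_frac   <- 0.2    # validation_split used in fit()
batch_size <- 128
epochs     <- 30

n_fit <- floor(n_train * (1 - val_frac))  # 48000 images actually used for fitting
steps <- ceiling(n_fit / batch_size)      # gradient updates per epoch: 375
total <- steps * epochs                   # 11250 updates over the whole run
c(n_fit = n_fit, steps = steps, total = total)
```

The 375 steps per epoch match the 375/375 counter printed for each epoch in the training log. So every epoch consists of 375 small parameter updates, not one.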
11.4 Validation split
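The validation split sets aside part of the training data to monitor predictive performance during fitting. With validation_split = 0.2, Keras holds out the last 20% of the training rows (12,000 images), taken before any shuffling, and fits the model to the remaining 48,000.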
This is closely related to the train/test or cross-validation logic used in regression and other statistical learning methods.
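Keras takes the validation rows from the end of the training data, before shuffling, and that split can be mimicked by hand. This sketch uses a small hypothetical matrix in place of x_train and y_train; with the real data, split_at would be 48,000. Splitting manually gives full control over which rows are held out. The pieces should then be usable in fit() through its validation_data argument, as list(x_val, y_val).

```r
# Hypothetical stand-ins for x_train / y_train (10 rows instead of 60,000)
x <- matrix(seq_len(10 * 3), nrow = 10)
y <- seq_len(10)
val_frac <- 0.2

split_at <- floor(nrow(x) * (1 - val_frac))  # the last 20% of rows are held out
x_fit <- x[seq_len(split_at), , drop = FALSE]
y_fit <- y[seq_len(split_at)]
x_val <- x[(split_at + 1):nrow(x), , drop = FALSE]
y_val <- y[(split_at + 1):nrow(x)]
```

Because the last rows are used, this only gives a fair validation set if the data are not sorted in some meaningful way (for example, by digit). Otherwise the rows should be shuffled first.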
12 Model accuracy
After fitting the model, we can generate predictions for the test data.
predictions <- modelnn %>%
  predict(x_test)
313/313 - 0s - 774us/step
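# Predicted digit = column with the highest probability (columns are digits 0-9)
predcat <- factor(max.col(predictions) - 1, levels = 0:9)
realcat <- factor(mnist$test$y, levels = 0:9)

Now create a confusion matrix. In caret's confusionMatrix(), the first argument (data) holds the predicted classes and the second (reference) holds the true classes.

confusionMatrix(data = predcat, reference = realcat)

The output below was generated with the arguments in the opposite order, confusionMatrix(realcat, predcat). Its rows labeled "Prediction" are therefore the true digits, and its columns labeled "Reference" are the predictions. With the corrected call, overall accuracy and kappa are unchanged. However, the table is transposed, per-class Sensitivity and Pos Pred Value (and Specificity and Neg Pred Value) trade places, and the No Information Rate is computed from the true class frequencies.

Confusion Matrix and Statistics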
Reference
Prediction 0 1 2 3 4 5 6 7 8 9
0 973 1 1 1 0 0 2 1 1 0
1 0 1128 3 0 0 0 3 0 1 0
2 1 1 1019 0 2 0 0 5 3 1
3 0 0 5 988 0 3 0 7 2 5
4 0 0 2 0 965 0 5 2 1 7
5 2 0 0 9 1 866 8 1 3 2
6 5 2 0 1 4 1 944 0 1 0
7 2 1 9 2 1 0 0 1008 1 4
8 8 2 3 3 6 3 2 5 940 2
9 2 3 0 3 16 0 0 5 3 977
Overall Statistics
Accuracy : 0.9808
95% CI : (0.9779, 0.9834)
No Information Rate : 0.1138
P-Value [Acc > NIR] : < 2.2e-16
Kappa : 0.9787
Mcnemar's Test P-Value : NA
Statistics by Class:
Class: 0 Class: 1 Class: 2 Class: 3 Class: 4 Class: 5
Sensitivity 0.9799 0.9912 0.9779 0.9811 0.9698 0.9920
Specificity 0.9992 0.9992 0.9985 0.9976 0.9981 0.9972
Pos Pred Value 0.9929 0.9938 0.9874 0.9782 0.9827 0.9709
Neg Pred Value 0.9978 0.9989 0.9974 0.9979 0.9967 0.9992
Prevalence 0.0993 0.1138 0.1042 0.1007 0.0995 0.0873
Detection Rate 0.0973 0.1128 0.1019 0.0988 0.0965 0.0866
Detection Prevalence 0.0980 0.1135 0.1032 0.1010 0.0982 0.0892
Balanced Accuracy 0.9895 0.9952 0.9882 0.9893 0.9840 0.9946
Class: 6 Class: 7 Class: 8 Class: 9
Sensitivity 0.9793 0.9749 0.9833 0.9790
Specificity 0.9985 0.9978 0.9962 0.9964
Pos Pred Value 0.9854 0.9805 0.9651 0.9683
Neg Pred Value 0.9978 0.9971 0.9982 0.9977
Prevalence 0.0964 0.1034 0.0956 0.0998
Detection Rate 0.0944 0.1008 0.0940 0.0977
Detection Prevalence 0.0958 0.1028 0.0974 0.1009
Balanced Accuracy 0.9889 0.9863 0.9898 0.9877
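To see where these numbers come from, here is a base R sketch on a handful of hypothetical true and predicted digits. table() lays the results out the way caret does, with predictions in rows and reference labels in columns. Overall accuracy is the share on the diagonal. Sensitivity divides each diagonal count by its column total (the true class). Positive predictive value divides it by its row total (the predicted class).

```r
# Hypothetical true and predicted digits for 8 images
truth <- factor(c(3, 3, 3, 5, 5, 8, 8, 8), levels = c(3, 5, 8))
pred  <- factor(c(3, 3, 5, 5, 5, 8, 8, 3), levels = c(3, 5, 8))

# Same layout as caret: rows = predictions, columns = reference (truth)
tab <- table(Prediction = pred, Reference = truth)
tab

accuracy    <- sum(diag(tab)) / sum(tab)   # share of images on the diagonal
sensitivity <- diag(tab) / colSums(tab)    # correct / all images of that true class
ppv         <- diag(tab) / rowSums(tab)    # correct / all images predicted as that class
```

Here digit 5 has sensitivity 1 (every true 5 was found) but a positive predictive value of 2/3 (one 3 was also called a 5). That asymmetry is why swapping the predicted and true labels changes the per-class statistics.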
13 Why use a confusion matrix?
The model is trained by minimizing cross-entropy, but that quantity is not always the easiest to interpret.
A confusion matrix is more intuitive because it shows:
- which digits are commonly classified correctly,
- which digits tend to be confused with each other,
- and overall classification performance.
This is similar to how we may do model selection with an information criterion (AIC, BIC, and so on) but prefer to present results with an absolute metric that is easy to interpret (such as RMSE or MAE).

14 Summary: neural networks as an extension of regression
A useful way to think about this entire exercise is:
- Linear regression: weighted sums of predictors
- Logistic regression: weighted sums + nonlinear transformation for probabilities
- Neural networks: repeated weighted sums + nonlinear transformations + multiple layers
So neural networks are not a completely different idea. They are an extension of concepts from regression:
- coefficients become weights,
- intercepts are still there,
- predictors are combined linearly within each layer,
- nonlinear link functions are applied,
- and the model is trained to minimize a loss function.
The main innovations are depth, nonlinearity, and flexibility.