Code
# ============================================================
# TWO-LAYER NEURAL NETWORK FOR HANDWRITTEN DIGIT DATA (MNIST)
# ============================================================

# Install if needed:
# install.packages("keras")
# library(keras)
# install_keras()

library(keras)

# -----------------------------
# 1. Load MNIST data
# -----------------------------
# dataset_mnist() is provided by the keras package attached above. The result
# is read below as a nested list: $train and $test, each with $x and $y.
mnist <- dataset_mnist()
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz

    8192/11490434 [..............................] - ETA: 0s
   16384/11490434 [..............................] - ETA: 1:01
   32768/11490434 [..............................] - ETA: 58s 
   49152/11490434 [..............................] - ETA: 1:21
   65536/11490434 [..............................] - ETA: 1:10
   81920/11490434 [..............................] - ETA: 1:24
  106496/11490434 [..............................] - ETA: 1:10
  131072/11490434 [..............................] - ETA: 1:03
  147456/11490434 [..............................] - ETA: 1:00
  163840/11490434 [..............................] - ETA: 58s 
  196608/11490434 [..............................] - ETA: 52s
  229376/11490434 [..............................] - ETA: 48s
  262144/11490434 [..............................] - ETA: 44s
  278528/11490434 [..............................] - ETA: 43s
  327680/11490434 [..............................] - ETA: 42s
  409600/11490434 [>.............................] - ETA: 35s
  507904/11490434 [>.............................] - ETA: 29s
  589824/11490434 [>.............................] - ETA: 26s
  671744/11490434 [>.............................] - ETA: 23s
  778240/11490434 [=>............................] - ETA: 21s
  868352/11490434 [=>............................] - ETA: 19s
  901120/11490434 [=>............................] - ETA: 19s
 1015808/11490434 [=>............................] - ETA: 17s
 1163264/11490434 [==>...........................] - ETA: 15s
 1277952/11490434 [==>...........................] - ETA: 14s
 1310720/11490434 [==>...........................] - ETA: 14s
 1507328/11490434 [==>...........................] - ETA: 12s
 1753088/11490434 [===>..........................] - ETA: 11s
 1884160/11490434 [===>..........................] - ETA: 10s
 1916928/11490434 [====>.........................] - ETA: 10s
 2203648/11490434 [====>.........................] - ETA: 9s 
 2531328/11490434 [=====>........................] - ETA: 7s
 2678784/11490434 [=====>........................] - ETA: 7s
 2777088/11490434 [======>.......................] - ETA: 7s
 3170304/11490434 [=======>......................] - ETA: 6s
 3645440/11490434 [========>.....................] - ETA: 5s
 3907584/11490434 [=========>....................] - ETA: 4s
 4071424/11490434 [=========>....................] - ETA: 4s
 4489216/11490434 [==========>...................] - ETA: 4s
 4882432/11490434 [===========>..................] - ETA: 3s
 5193728/11490434 [============>.................] - ETA: 3s
 5537792/11490434 [=============>................] - ETA: 2s
 5914624/11490434 [==============>...............] - ETA: 2s
 6299648/11490434 [===============>..............] - ETA: 2s
 6668288/11490434 [================>.............] - ETA: 2s
 7012352/11490434 [=================>............] - ETA: 1s
 7405568/11490434 [==================>...........] - ETA: 1s
 7749632/11490434 [===================>..........] - ETA: 1s
 7913472/11490434 [===================>..........] - ETA: 1s
 8142848/11490434 [====================>.........] - ETA: 1s
 8503296/11490434 [=====================>........] - ETA: 1s
 8830976/11490434 [======================>.......] - ETA: 0s
 9240576/11490434 [=======================>......] - ETA: 0s
 9601024/11490434 [========================>.....] - ETA: 0s
 9977856/11490434 [=========================>....] - ETA: 0s
10338304/11490434 [=========================>....] - ETA: 0s
10715136/11490434 [==========================>...] - ETA: 0s
11075584/11490434 [===========================>..] - ETA: 0s
11386880/11490434 [============================>.] - ETA: 0s
11490434/11490434 [==============================] - 4s 0us/step
Code
# Training split: image array and the matching digit labels.
x_train <- mnist[["train"]][["x"]]
y_train <- mnist[["train"]][["y"]]

# Test split, same layout.
x_test <- mnist[["test"]][["x"]]
y_test <- mnist[["test"]][["y"]]

# -----------------------------
# 2. Reshape and normalize
# -----------------------------
# Original: 28 x 28
# Flatten to 784 and scale to [0,1]

# Training images: one row per image with 784 columns, then divide by 255.
x_train <- array_reshape(x_train, c(nrow(x_train), 784)) / 255

# Test images receive exactly the same treatment.
x_test <- array_reshape(x_test, c(nrow(x_test), 784)) / 255

# -----------------------------
# 3. One-hot encode response
# -----------------------------
# Digits are 0,1,...,9
# to_categorical() (keras) is asked for num_classes = 10 columns per label.
# The training loop later subtracts rows of y_train_onehot from the 10-column
# softmax output, so the column counts have to agree.
# NOTE(review): y_test_onehot is not referenced again in the visible script.
y_train_onehot <- to_categorical(y_train, num_classes = 10)
y_test_onehot  <- to_categorical(y_test,  num_classes = 10)

# -----------------------------
# 4. Activation functions
# -----------------------------
# Logistic function 1 / (1 + e^(-z)), applied elementwise.
#
# z: numeric scalar, vector or matrix.
# Returns an object shaped like z with every entry mapped into [0, 1].
sigmoid <- function(z) {
  denom <- 1 + exp(-z)
  1 / denom
}

# Row-wise softmax.
#
# Z: numeric matrix, one row of scores per observation.
# Returns a matrix of the same shape whose rows each sum to 1.
softmax <- function(Z) {
  # Subtracting each row's maximum keeps exp() from overflowing; a length
  # nrow(Z) vector recycles down the columns, so row i is shifted by its own
  # maximum.
  row_peak <- apply(Z, 1, max)
  scores <- exp(Z - row_peak)

  # Same recycling rule: every entry of row i is divided by row i's total.
  scores / rowSums(scores)
}

# -----------------------------
# 5. Initialize parameters
# -----------------------------
# Layer widths: 784 inputs -> 128 hidden units -> 10 outputs.
input_dim <- 784
hidden_dim <- 128
output_dim <- 10

# Fix the RNG so the weight draws below are reproducible.
set.seed(123)

# Hidden layer: Gaussian weights with sd 0.01 (mean defaults to 0), zero
# biases. W1 is drawn before W2; swapping the order would change both.
W1 <- matrix(
  rnorm(hidden_dim * input_dim, sd = 0.01),
  nrow = hidden_dim,
  ncol = input_dim
)
b1 <- matrix(0, hidden_dim, 1)

# Output layer, initialised the same way.
W2 <- matrix(
  rnorm(output_dim * hidden_dim, sd = 0.01),
  nrow = output_dim,
  ncol = hidden_dim
)
b2 <- matrix(0, output_dim, 1)

# -----------------------------
# 6. Forward propagation
# -----------------------------
# Forward pass of the two-layer network.
#
# X:  numeric matrix with one observation per row and ncol(W1) columns; a
#     plain vector is accepted and treated as a single observation.
# W1: hidden-layer weights, hidden units x inputs.
# b1: hidden-layer biases (column matrix or vector, one per hidden unit).
# W2: output-layer weights, outputs x hidden units.
# b2: output-layer biases (column matrix or vector, one per output).
#
# Returns a list with the pre-activations and activations of both layers:
# Z1, A1 = sigmoid(Z1), Z2, A2 = softmax(Z2).
forward <- function(X, W1, b1, W2, b2) {

  # A single sample such as x[i, ] arrives without dimensions; give it one
  # row so nrow() and the row-wise softmax behave as for a batch.
  if (is.null(dim(X))) {
    X <- matrix(X, nrow = 1)
  }
  n_obs <- nrow(X)

  # Adds bias b[j] to every entry of column j: byrow = TRUE lays the bias
  # vector out along each of the n_obs rows.
  add_bias <- function(Z, b) {
    Z + matrix(as.vector(b), nrow = n_obs, ncol = length(b), byrow = TRUE)
  }

  # Z1 = X W1^T + b1
  Z1 <- add_bias(X %*% t(W1), b1)

  # A1 = sigmoid(Z1)
  A1 <- sigmoid(Z1)

  # Z2 = A1 W2^T + b2
  Z2 <- add_bias(A1 %*% t(W2), b2)

  # A2 = softmax(Z2)
  A2 <- softmax(Z2)

  list(Z1 = Z1, A1 = A1, Z2 = Z2, A2 = A2)
}

# -----------------------------
# 7. Loss function
# -----------------------------
# Mean categorical cross-entropy over the rows of a batch.
#
# y_true: matrix of target indicators, one row per observation.
# y_pred: matrix of predicted probabilities with the same shape.
# Returns a single number: minus the average over rows of
# sum_k y_true[i, k] * log(y_pred[i, k]).
cross_entropy <- function(y_true, y_pred) {
  eps <- 1e-8

  # Clamp into [eps, 1 - eps] so log() never receives 0.
  clipped <- pmax(y_pred, eps)
  clipped <- pmin(clipped, 1 - eps)

  per_row <- rowSums(y_true * log(clipped))
  -mean(per_row)
}

# -----------------------------
# 8. Mini-batch training
# -----------------------------
# Settings for the mini-batch gradient descent loop below.
batch_size <- 128      # rows per gradient step
epochs <- 20           # passes over the training rows
learning_rate <- 0.1   # multiplier applied to every gradient

# One slot per epoch; the loop stores that epoch's mean mini-batch loss here.
loss_history <- vector("numeric", length = epochs)
n_train <- nrow(x_train)

# Mini-batch gradient descent over the training rows.
#
# Reads:   x_train, y_train_onehot, n_train, epochs, batch_size,
#          learning_rate, output_dim, hidden_dim, forward(), cross_entropy().
# Updates: W1, b1, W2, b2 (in place, once per mini-batch) and loss_history
#          (mean mini-batch loss, once per epoch).
#
# x_train and y_train_onehot are NOT reordered: overwriting them with a
# shuffled copy each epoch would leave them out of step with y_train. The
# cumulative shuffle is tracked in row_order instead and applied when each
# mini-batch is sliced, which visits exactly the same rows in the same order.
row_order <- seq_len(n_train)
batch_starts <- seq(1, n_train, by = batch_size)

for (epoch in seq_len(epochs)) {

  # Compose this epoch's permutation with the order left by earlier epochs.
  row_order <- row_order[sample.int(n_train)]

  # Preallocated: one loss per mini-batch.
  batch_losses <- numeric(length(batch_starts))

  for (batch in seq_along(batch_starts)) {

    start <- batch_starts[batch]
    end <- min(start + batch_size - 1, n_train)   # last batch may be short
    rows <- row_order[start:end]

    Xb <- x_train[rows, , drop = FALSE]
    Yb <- y_train_onehot[rows, , drop = FALSE]

    m <- nrow(Xb)

    # ===== forward =====
    fp <- forward(Xb, W1, b1, W2, b2)
    A1 <- fp$A1
    A2 <- fp$A2

    batch_losses[batch] <- cross_entropy(Yb, A2)

    # ===== backpropagation =====

    # Output error (softmax + cross-entropy):
    # dZ2 = A2 - Y
    dZ2 <- A2 - Yb                     # m x 10

    # dW2 = (dZ2^T A1)/m
    dW2 <- t(dZ2) %*% A1 / m           # 10 x hidden_dim

    # db2 = column means
    db2 <- matrix(colMeans(dZ2), nrow = output_dim, ncol = 1)

    # Hidden error:
    # dA1 = dZ2 W2
    dA1 <- dZ2 %*% W2                  # m x hidden_dim

    # dZ1 = dA1 * A1 * (1 - A1), using sigmoid'(z) = A1 * (1 - A1)
    dZ1 <- dA1 * A1 * (1 - A1)

    # dW1 = (dZ1^T X)/m
    dW1 <- t(dZ1) %*% Xb / m           # hidden_dim x 784

    # db1 = column means
    db1 <- matrix(colMeans(dZ1), nrow = hidden_dim, ncol = 1)

    # ===== parameter updates =====
    # All gradients above were computed from the pre-update weights.
    W2 <- W2 - learning_rate * dW2
    b2 <- b2 - learning_rate * db2

    W1 <- W1 - learning_rate * dW1
    b1 <- b1 - learning_rate * db1
  }

  loss_history[epoch] <- mean(batch_losses)
  cat("Epoch:", epoch, " Loss:", round(loss_history[epoch], 5), "\n")
}
Epoch: 1  Loss: 1.73623 
Epoch: 2  Loss: 0.63927 
Epoch: 3  Loss: 0.43884 
Epoch: 4  Loss: 0.37306 
Epoch: 5  Loss: 0.33952 
Epoch: 6  Loss: 0.31833 
Epoch: 7  Loss: 0.30238 
Epoch: 8  Loss: 0.28992 
Epoch: 9  Loss: 0.27891 
Epoch: 10  Loss: 0.26923 
Epoch: 11  Loss: 0.26043 
Epoch: 12  Loss: 0.25203 
Epoch: 13  Loss: 0.24446 
Epoch: 14  Loss: 0.23692 
Epoch: 15  Loss: 0.22968 
Epoch: 16  Loss: 0.2233 
Epoch: 17  Loss: 0.21688 
Epoch: 18  Loss: 0.21069 
Epoch: 19  Loss: 0.20512 
Epoch: 20  Loss: 0.19956 
Code
# -----------------------------
# 9. Prediction function
# -----------------------------
# Class predictions from the trained network.
#
# X:              inputs, passed straight to forward().
# W1, b1, W2, b2: network parameters, passed straight to forward().
#
# Returns a list with
#   probabilities: the A2 matrix produced by forward(), one row per input;
#   predictions:   for each row, the zero-based index of its largest entry,
#                  i.e. a digit in 0,...,9 for a 10-column output.
predict_nn <- function(X, W1, b1, W2, b2) {
  fp <- forward(X, W1, b1, W2, b2)
  probs <- fp$A2

  # max.col() defaults to ties.method = "random", which treats entries within
  # a 1e-5 relative tolerance as tied and picks among them with the RNG.
  # "first" gives an exact, repeatable argmax and leaves the RNG untouched.
  preds <- max.col(probs, ties.method = "first") - 1   # digits 0,...,9

  list(probabilities = probs, predictions = preds)
}

# -----------------------------
# 10. Evaluate on test data
# -----------------------------
# Run the trained parameters over the test images.
test_pred <- predict_nn(x_test, W1, b1, W2, b2)
test_class <- test_pred[["predictions"]]

# Fraction of test rows whose predicted digit equals the true label.
accuracy <- mean(y_test == test_class)
cat("\nTest Accuracy =", round(accuracy, 4), "\n")

Test Accuracy = 0.9436 
Code
# -----------------------------
# 11. Confusion matrix
# -----------------------------
# Rows are predicted digits, columns are true digits. Fixing the factor
# levels to 0..9 keeps the table 10 x 10 with aligned rows and columns even
# if some digit never appears among the predictions (or the labels);
# table() on the raw vectors would silently drop that row or column.
conf_mat <- table(
  Predicted = factor(test_class, levels = 0:9),
  Actual = factor(y_test, levels = 0:9)
)
print(conf_mat)
         Actual
Predicted    0    1    2    3    4    5    6    7    8    9
        0  965    0    9    0    2    9   11    2    4   10
        1    0 1113    2    1    2    2    3    7    5    8
        2    2    2  964   14    5    1    4   22    5    1
        3    1    2   11  950    0   22    1    7   20   11
        4    0    0    8    0  926    4    6    5    7   23
        5    4    1    1   17    0  813   11    2   14    6
        6    6    4   11    1   11   12  917    0    9    1
        7    1    2    8   11    3    3    2  957    8   12
        8    1   11   16   14    5   18    3    2  899    5
        9    0    0    2    2   28    8    0   24    3  932
Code
# -----------------------------
# 12. Plot training loss
# -----------------------------
# Line chart of the per-epoch mean loss; the x positions are the epoch
# indices 1, 2, ..., length(loss_history).
plot(
  x = seq_along(loss_history),
  y = loss_history,
  type = "l",
  lwd = 2,
  xlab = "Epoch",
  ylab = "Cross-Entropy Loss",
  main = "Training Loss for MNIST Neural Network"
)

Code
# -----------------------------
# 13. Show first few predictions
# -----------------------------
# Side-by-side view of true and predicted digits for the first test rows.
# NOTE(review): ten rows are selected, but head() keeps only its default of
# six, so six rows are printed. Pass n = 10 to head() if all ten are wanted.
head(data.frame(
  Actual = y_test[1:10],
  Predicted = test_class[1:10]
))
  Actual Predicted
1      7         7
2      2         2
3      1         1
4      0         0
5      4         4
6      1         1