Deep learning is a subset of machine learning that uses neural networks with multiple layers (“deep” architectures). Compared with many classical ML models, deep learning can capture complex nonlinear relationships and can be especially powerful for unstructured data such as images and text. In practice, the deep learning workflow still follows the same discipline as standard ML.
In this chapter, we demonstrate deep neural networks for:
- binary classification (Pima Indians Diabetes),
- regression (Boston housing),
- image classification using a convolutional neural network (CNN) (MNIST).
We use the keras package in R, which provides a
user-friendly interface to define and train neural networks.
Before modeling, we load the core packages.
- readr: data I/O (not heavily used here but commonly
needed in practice)
- keras: deep learning framework interface
- DT: interactive tables for quick viewing (useful in
exploratory steps)
We also suppress messages to keep the knitted output clean and
readable.
We start with the same Pima Indians Diabetes dataset used in the machine learning chapter. This dataset is small, tabular, and appropriate for demonstrating a basic dense neural network for classification.
Important note: neural networks are data-hungry. For small datasets, deep learning may not outperform classical ML. However, the workflow is still valuable to learn.
library(reticulate)
## Warning: package 'reticulate' was built under R version 4.4.3
k_utils <- import("keras.utils")
# load the Pima Indians dataset from the mlbench dataset
library(mlbench)
## Warning: package 'mlbench' was built under R version 4.4.3
data(PimaIndiansDiabetes)
# Use a shorter name for the dataset
diabetes <- PimaIndiansDiabetes
data.set <- diabetes
# datatable(data.set[sample(nrow(data.set),
# replace = FALSE,
# size = 0.005 * nrow(data.set)), ])
A quick summary helps confirm variable types and detect obvious issues. For example, you may want to check if any predictors have implausible zeros or missingness patterns (common in clinical measurements).
summary(data.set)
## pregnant glucose pressure triceps
## Min. : 0.000 Min. : 0.0 Min. : 0.00 Min. : 0.00
## 1st Qu.: 1.000 1st Qu.: 99.0 1st Qu.: 62.00 1st Qu.: 0.00
## Median : 3.000 Median :117.0 Median : 72.00 Median :23.00
## Mean : 3.845 Mean :120.9 Mean : 69.11 Mean :20.54
## 3rd Qu.: 6.000 3rd Qu.:140.2 3rd Qu.: 80.00 3rd Qu.:32.00
## Max. :17.000 Max. :199.0 Max. :122.00 Max. :99.00
## insulin mass pedigree age diabetes
## Min. : 0.0 Min. : 0.00 Min. :0.0780 Min. :21.00 neg:500
## 1st Qu.: 0.0 1st Qu.:27.30 1st Qu.:0.2437 1st Qu.:24.00 pos:268
## Median : 30.5 Median :32.00 Median :0.3725 Median :29.00
## Mean : 79.8 Mean :31.99 Mean :0.4719 Mean :33.24
## 3rd Qu.:127.2 3rd Qu.:36.60 3rd Qu.:0.6262 3rd Qu.:41.00
## Max. :846.0 Max. :67.10 Max. :2.4200 Max. :81.00
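A quick way to quantify the implausible zeros is to count them per column. The sketch below uses only the first six rows of five clinical predictors, typed in by hand from the head(data.set) output printed later in this section, so it runs without any packages. The object names pima_head and pima_na are ours and are not part of the original analysis. On the full data, the same colSums() call could be applied to data.set.

```r
# First six rows of five clinical predictors, copied from the head() output
pima_head <- data.frame(
  glucose  = c(148, 85, 183, 89, 137, 116),
  pressure = c(72, 66, 64, 66, 40, 74),
  triceps  = c(35, 29, 0, 23, 35, 0),
  insulin  = c(0, 0, 0, 94, 168, 0),
  mass     = c(33.6, 26.6, 23.3, 28.1, 43.1, 25.6)
)

# Count exact zeros per column; a zero is physiologically implausible
# for these variables and is commonly treated as a missing value
zero_counts <- colSums(pima_head == 0)
print(zero_counts)

# Optionally recode zeros as NA so they are not treated as real measurements
pima_na <- pima_head
pima_na[pima_na == 0] <- NA
print(colSums(is.na(pima_na)))
```

Whether to impute, drop, or keep such values is a modeling decision; the chapter's analysis keeps them as they are.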
Keras expects numeric matrices for the input features and numeric outcomes encoded in a way that is compatible with the chosen loss function.
Here we convert the outcome diabetes from a factor to numeric and then shift it to 0/1. as.numeric() on a factor returns the level codes (neg = 1, pos = 2), so subtracting 1 gives neg = 0 and pos = 1.
This step is common because one-hot encoding functions such as to_categorical() assume that class labels start at 0.
data.set$diabetes <- as.numeric(data.set$diabetes)
data.set$diabetes <- data.set$diabetes - 1
head(data.set$diabetes)
## [1] 1 0 1 0 1 0
We check the dataset again to confirm that the outcome and predictors are in the expected format and that the structure matches what the model will consume.
head(data.set)
## pregnant glucose pressure triceps insulin mass pedigree age diabetes
## 1 6 148 72 35 0 33.6 0.627 50 1
## 2 1 85 66 29 0 26.6 0.351 31 0
## 3 8 183 64 0 0 23.3 0.672 32 1
## 4 1 89 66 23 94 28.1 0.167 21 0
## 5 0 137 40 35 168 43.1 2.288 33 1
## 6 5 116 74 0 0 25.6 0.201 30 0
str(data.set)
## 'data.frame': 768 obs. of 9 variables:
## $ pregnant: num 6 1 8 1 0 5 3 10 2 8 ...
## $ glucose : num 148 85 183 89 137 116 78 115 197 125 ...
## $ pressure: num 72 66 64 66 40 74 50 0 70 96 ...
## $ triceps : num 35 29 0 23 35 0 32 0 45 0 ...
## $ insulin : num 0 0 0 94 168 0 88 0 543 0 ...
## $ mass : num 33.6 26.6 23.3 28.1 43.1 25.6 31 35.3 30.5 0 ...
## $ pedigree: num 0.627 0.351 0.672 0.167 2.288 ...
## $ age : num 50 31 32 21 33 30 26 29 53 54 ...
## $ diabetes: num 1 0 1 0 1 0 1 0 1 1 ...
We convert the data frame to a matrix and remove the dimnames for a clean numeric structure.
# Cast dataframe as a matrix
data.set <- as.matrix(data.set)
# Remove column names
dimnames(data.set) = NULL
We view the matrix head to confirm that numeric conversion and ordering look correct.
head(data.set)
## [,1] [,2] [,3] [,4] [,5] [,6] [,7] [,8] [,9]
## [1,] 6 148 72 35 0 33.6 0.627 50 1
## [2,] 1 85 66 29 0 26.6 0.351 31 0
## [3,] 8 183 64 0 0 23.3 0.672 32 1
## [4,] 1 89 66 23 94 28.1 0.167 21 0
## [5,] 0 137 40 35 168 43.1 2.288 33 1
## [6,] 5 116 74 0 0 25.6 0.201 30 0
We now split the data into x_train, y_train, x_test, and y_test.
Practical note: this split is a simple random split. For small datasets, results can vary depending on the split. In more formal analyses, you may repeat splits or use cross-validation.
# Split for train and test data
set.seed(100)
indx <- sample(2,
nrow(data.set),
replace = TRUE,
prob = c(0.8, 0.2)) # Makes index with values 1 and 2
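Note that sample(2, ..., prob = c(0.8, 0.2)) assigns each row independently, so the realized split is only approximately 80/20. If an exact split size is preferred, one alternative is to sample row indices without replacement. This is a sketch and not the approach used in this chapter; n_obs, n_train, train_idx, and test_idx are illustrative names.

```r
set.seed(100)
n_obs   <- 768                    # number of rows in the Pima dataset
n_train <- floor(0.8 * n_obs)     # exact training size

# Draw training row indices without replacement; the rest form the test set
train_idx <- sample(seq_len(n_obs), size = n_train)
test_idx  <- setdiff(seq_len(n_obs), train_idx)

length(train_idx)  # 614
length(test_idx)   # 154
```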
We define the predictor matrices (x_train,
x_test) by selecting the first 8 columns as features.
# Select only the feature variables
# Take rows with index = 1
x_train <- data.set[indx == 1, 1:8]
x_test <- data.set[indx == 2, 1:8]
Feature scaling is usually necessary for dense neural networks on
tabular data. Scaling improves numerical stability and helps
gradient-based optimization converge faster.
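Here we use scale() (standardization). The means and standard deviations are computed on the training features only and then reused to scale the test features, so that no information from the test set leaks into preprocessing.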
# Feature Scaling
x_train <- scale(x_train )
train_center <- attr(x_train, "scaled:center") # the mean of each column in the training set
train_scale <- attr(x_train, "scaled:scale") # the standard deviation of each column in the training set
x_test <- scale(x_test, center = train_center, scale = train_scale)
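The effect of reusing the training means and standard deviations on the test set can be seen on a toy example. The matrices below are made-up numbers for illustration only, and all object names are ours.

```r
# Hypothetical toy data: 4 training rows and 2 test rows, 2 features
toy_train <- cbind(c(1, 2, 3, 4), c(10, 20, 30, 40))
toy_test  <- cbind(c(2.5, 5), c(25, 50))

# Standardize the training set and keep its statistics
toy_train_s <- scale(toy_train)
ctr <- attr(toy_train_s, "scaled:center")
scl <- attr(toy_train_s, "scaled:scale")

# Apply the SAME statistics to the test set (nothing is estimated from test rows)
toy_test_s <- scale(toy_test, center = ctr, scale = scl)

colMeans(toy_train_s)       # 0 for each column
apply(toy_train_s, 2, sd)   # 1 for each column
toy_test_s[1, ]             # 0 0: a test row equal to the training mean maps to 0
```

The scaled test columns will generally not have mean 0 and standard deviation 1, which is expected: they are expressed in units of the training distribution.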
We store the true test labels in their original 0/1 numeric form for later evaluation.
y_test_actual <- data.set[indx == 2, 9]
to_categorical() converts the class label (0/1) into a two-column indicator matrix.
# Using similar indices to correspond to the training and test set
k_utils <- import("keras.utils")
y_train <- k_utils$to_categorical(data.set[indx == 1, 9])
y_test <- k_utils$to_categorical(data.set[indx == 2, 9])
head(y_train)
## [,1] [,2]
## [1,] 0 1
## [2,] 1 0
## [3,] 0 1
## [4,] 1 0
## [5,] 0 1
## [6,] 1 0
head(data.set[indx == 1, 9],20)
## [1] 1 0 1 0 1 0 0 1 1 0 0 1 1 1 1 1 0 1 0 0
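What to_categorical() does here can be reproduced in base R, which may help when checking the encoding by hand. This is only an illustrative equivalent for integer labels starting at 0 (the helper name one_hot is ours), not how Keras implements it.

```r
# Base R equivalent of one-hot encoding for integer class labels starting at 0
one_hot <- function(labels, num_classes = max(labels) + 1) {
  # Row i of the identity matrix is the indicator vector for class i - 1
  diag(num_classes)[labels + 1, , drop = FALSE]
}

labels <- c(1, 0, 1, 0, 1, 0)   # first six training labels shown above
one_hot(labels)
```

Column 1 flags class 0 and column 2 flags class 1, so the label 1 becomes the row (0, 1).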
The number of rows in x_train must match y_train, and similarly for the test set.
dim(x_train)
## [1] 609 8
dim(y_train)
## [1] 609 2
dim(x_test)
## [1] 159 8
dim(y_test)
## [1] 159 2
A dense (fully-connected) neural network is the baseline architecture for tabular data.
Conceptually:
- the input layer receives the 8 standardized predictors,
- hidden layers apply nonlinear transformations (ReLU),
- the output layer produces class probabilities via softmax.
Because the outcome is one-hot encoded into two columns, the output layer uses units = 2, one unit per class.
# 1. Initialize the model
model <- keras_model_sequential()
# 2. Explicitly add layers (use $add to avoid positional-argument ambiguity in Keras 3 when using pipes)
model$add(layer_input(shape = c(8))) # 8 corresponds to your input_shape
model$add(layer_dense(
units = 10,
activation = "relu",
name = "DeepLayer1"
))
model$add(layer_dense(
units = 10,
activation = "relu",
name = "DeepLayer2"
))
model$add(layer_dense(
units = 2,
activation = "softmax",
name = "OutputLayer"
))
# 3. View model structure
model$summary()
## Model: "sequential"
## ┌─────────────────────────────────┬────────────────────────┬───────────────┐
## │ Layer (type) │ Output Shape │ Param # │
## ├─────────────────────────────────┼────────────────────────┼───────────────┤
## │ DeepLayer1 (Dense) │ (None, 10) │ 90 │
## ├─────────────────────────────────┼────────────────────────┼───────────────┤
## │ DeepLayer2 (Dense) │ (None, 10) │ 110 │
## ├─────────────────────────────────┼────────────────────────┼───────────────┤
## │ OutputLayer (Dense) │ (None, 2) │ 22 │
## └─────────────────────────────────┴────────────────────────┴───────────────┘
## Total params: 222 (888.00 B)
## Trainable params: 222 (888.00 B)
## Non-trainable params: 0 (0.00 B)
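The parameter counts in the summary can be verified by hand: a dense layer with n_in inputs and n_out units has n_in × n_out weights plus n_out biases. A small helper (the name dense_params is ours) reproduces the three numbers.

```r
# Weights plus biases of a fully-connected layer
dense_params <- function(n_in, n_out) n_in * n_out + n_out

p1 <- dense_params(8, 10)    # DeepLayer1: 90
p2 <- dense_params(10, 10)   # DeepLayer2: 110
p3 <- dense_params(10, 2)    # OutputLayer: 22
c(p1, p2, p3, total = p1 + p2 + p3)   # total: 222
```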
Compiling sets the loss function, optimizer, and evaluation
metrics.
- categorical_crossentropy: appropriate for multi-class
(including binary with one-hot) classification
- adam: a widely used optimizer that works well in many
practical settings
- accuracy: a basic metric; for imbalanced datasets,
consider sensitivity/specificity or AUC in addition.
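For a one-hot target, categorical cross-entropy reduces to the negative log of the probability assigned to the true class, averaged over samples. The sketch below computes it by hand in base R. The function name cat_crossentropy and the example probabilities are illustrative assumptions, and the small eps is our own guard against log(0), not a statement about Keras internals.

```r
cat_crossentropy <- function(y_onehot, probs, eps = 1e-7) {
  probs <- pmax(probs, eps)                 # avoid log(0)
  mean(-rowSums(y_onehot * log(probs)))     # average over samples
}

y_onehot <- rbind(c(0, 1), c(1, 0))         # true classes: 1, then 0
probs    <- rbind(c(0.2, 0.8), c(0.5, 0.5)) # hypothetical predicted probabilities

cat_crossentropy(y_onehot, probs)           # equals mean(c(-log(0.8), -log(0.5)))
```

A confident correct prediction contributes a loss near 0, while a confident wrong prediction is penalized heavily, which is why the loss can keep changing even when accuracy is flat.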
# Compiling the model
model$compile(
loss = "categorical_crossentropy",
optimizer = "adam",
metrics = list("accuracy") # In Keras 3, using list() is recommended for Python-side compatibility
)
Training is performed using mini-batch gradient descent. Key training parameters:
- epochs: number of passes through the training data
- batch_size: number of samples per gradient update
- validation_split: portion of training data held out to monitor validation performance
Validation monitoring is essential: if training accuracy keeps improving but validation accuracy stagnates or declines, the model is likely overfitting.
# Train the model. Note the argument is 'epochs' (plural) and must be integer.
history <- model$fit(
x = as.matrix(x_train),
y = y_train,
epochs = as.integer(60),
batch_size = as.integer(64),
validation_split = 0.15,
verbose = 2
)
## Epoch 1/60
## 9/9 - 1s - 108ms/step - accuracy: 0.5899 - loss: 0.6824 - val_accuracy: 0.6196 - val_loss: 0.6471
## Epoch 2/60
## 9/9 - 0s - 10ms/step - accuracy: 0.6170 - loss: 0.6670 - val_accuracy: 0.6413 - val_loss: 0.6371
## Epoch 3/60
## 9/9 - 0s - 12ms/step - accuracy: 0.6364 - loss: 0.6546 - val_accuracy: 0.6196 - val_loss: 0.6277
## Epoch 4/60
## 9/9 - 0s - 11ms/step - accuracy: 0.6402 - loss: 0.6432 - val_accuracy: 0.6196 - val_loss: 0.6187
## Epoch 5/60
## 9/9 - 0s - 10ms/step - accuracy: 0.6480 - loss: 0.6323 - val_accuracy: 0.6304 - val_loss: 0.6103
## Epoch 6/60
## 9/9 - 0s - 11ms/step - accuracy: 0.6480 - loss: 0.6236 - val_accuracy: 0.6413 - val_loss: 0.6027
## Epoch 7/60
## 9/9 - 0s - 12ms/step - accuracy: 0.6518 - loss: 0.6143 - val_accuracy: 0.6413 - val_loss: 0.5957
## Epoch 8/60
## 9/9 - 0s - 12ms/step - accuracy: 0.6518 - loss: 0.6065 - val_accuracy: 0.6413 - val_loss: 0.5901
## Epoch 9/60
## 9/9 - 0s - 11ms/step - accuracy: 0.6518 - loss: 0.5993 - val_accuracy: 0.6413 - val_loss: 0.5848
## Epoch 10/60
## 9/9 - 0s - 11ms/step - accuracy: 0.6518 - loss: 0.5925 - val_accuracy: 0.6413 - val_loss: 0.5795
## Epoch 11/60
## 9/9 - 0s - 11ms/step - accuracy: 0.6538 - loss: 0.5861 - val_accuracy: 0.6413 - val_loss: 0.5742
## Epoch 12/60
## 9/9 - 0s - 11ms/step - accuracy: 0.6557 - loss: 0.5796 - val_accuracy: 0.6413 - val_loss: 0.5684
## Epoch 13/60
## 9/9 - 0s - 13ms/step - accuracy: 0.6596 - loss: 0.5726 - val_accuracy: 0.6413 - val_loss: 0.5628
## Epoch 14/60
## 9/9 - 0s - 12ms/step - accuracy: 0.6576 - loss: 0.5662 - val_accuracy: 0.6522 - val_loss: 0.5575
## Epoch 15/60
## 9/9 - 0s - 12ms/step - accuracy: 0.6692 - loss: 0.5601 - val_accuracy: 0.6522 - val_loss: 0.5521
## Epoch 16/60
## 9/9 - 0s - 11ms/step - accuracy: 0.6789 - loss: 0.5545 - val_accuracy: 0.6522 - val_loss: 0.5476
## Epoch 17/60
## 9/9 - 0s - 12ms/step - accuracy: 0.6944 - loss: 0.5492 - val_accuracy: 0.6848 - val_loss: 0.5427
## Epoch 18/60
## 9/9 - 0s - 13ms/step - accuracy: 0.6983 - loss: 0.5437 - val_accuracy: 0.6957 - val_loss: 0.5402
## Epoch 19/60
## 9/9 - 0s - 11ms/step - accuracy: 0.6925 - loss: 0.5383 - val_accuracy: 0.7391 - val_loss: 0.5378
## Epoch 20/60
## 9/9 - 0s - 11ms/step - accuracy: 0.7195 - loss: 0.5340 - val_accuracy: 0.7283 - val_loss: 0.5344
## Epoch 21/60
## 9/9 - 0s - 10ms/step - accuracy: 0.7389 - loss: 0.5286 - val_accuracy: 0.7391 - val_loss: 0.5292
## Epoch 22/60
## 9/9 - 0s - 11ms/step - accuracy: 0.7389 - loss: 0.5227 - val_accuracy: 0.7283 - val_loss: 0.5239
## Epoch 23/60
## 9/9 - 0s - 11ms/step - accuracy: 0.7582 - loss: 0.5169 - val_accuracy: 0.7500 - val_loss: 0.5193
## Epoch 24/60
## 9/9 - 0s - 10ms/step - accuracy: 0.7640 - loss: 0.5109 - val_accuracy: 0.7609 - val_loss: 0.5155
## Epoch 25/60
## 9/9 - 0s - 10ms/step - accuracy: 0.7621 - loss: 0.5057 - val_accuracy: 0.7717 - val_loss: 0.5117
## Epoch 26/60
## 9/9 - 0s - 11ms/step - accuracy: 0.7602 - loss: 0.5012 - val_accuracy: 0.7609 - val_loss: 0.5083
## Epoch 27/60
## 9/9 - 0s - 11ms/step - accuracy: 0.7640 - loss: 0.4976 - val_accuracy: 0.7609 - val_loss: 0.5062
## Epoch 28/60
## 9/9 - 0s - 15ms/step - accuracy: 0.7660 - loss: 0.4942 - val_accuracy: 0.7609 - val_loss: 0.5040
## Epoch 29/60
## 9/9 - 0s - 15ms/step - accuracy: 0.7679 - loss: 0.4908 - val_accuracy: 0.7717 - val_loss: 0.5020
## Epoch 30/60
## 9/9 - 0s - 13ms/step - accuracy: 0.7660 - loss: 0.4870 - val_accuracy: 0.7609 - val_loss: 0.5006
## Epoch 31/60
## 9/9 - 0s - 13ms/step - accuracy: 0.7621 - loss: 0.4828 - val_accuracy: 0.7717 - val_loss: 0.4992
## Epoch 32/60
## 9/9 - 0s - 11ms/step - accuracy: 0.7621 - loss: 0.4799 - val_accuracy: 0.7609 - val_loss: 0.4991
## Epoch 33/60
## 9/9 - 0s - 11ms/step - accuracy: 0.7640 - loss: 0.4769 - val_accuracy: 0.7609 - val_loss: 0.4983
## Epoch 34/60
## 9/9 - 0s - 11ms/step - accuracy: 0.7660 - loss: 0.4743 - val_accuracy: 0.7609 - val_loss: 0.4986
## Epoch 35/60
## 9/9 - 0s - 11ms/step - accuracy: 0.7679 - loss: 0.4715 - val_accuracy: 0.7609 - val_loss: 0.4974
## Epoch 36/60
## 9/9 - 0s - 11ms/step - accuracy: 0.7679 - loss: 0.4694 - val_accuracy: 0.7500 - val_loss: 0.4959
## Epoch 37/60
## 9/9 - 0s - 12ms/step - accuracy: 0.7698 - loss: 0.4673 - val_accuracy: 0.7500 - val_loss: 0.4962
## Epoch 38/60
## 9/9 - 0s - 11ms/step - accuracy: 0.7737 - loss: 0.4659 - val_accuracy: 0.7500 - val_loss: 0.4960
## Epoch 39/60
## 9/9 - 0s - 11ms/step - accuracy: 0.7737 - loss: 0.4640 - val_accuracy: 0.7609 - val_loss: 0.4963
## Epoch 40/60
## 9/9 - 0s - 13ms/step - accuracy: 0.7814 - loss: 0.4625 - val_accuracy: 0.7500 - val_loss: 0.4970
## Epoch 41/60
## 9/9 - 0s - 14ms/step - accuracy: 0.7814 - loss: 0.4604 - val_accuracy: 0.7500 - val_loss: 0.4973
## Epoch 42/60
## 9/9 - 0s - 13ms/step - accuracy: 0.7814 - loss: 0.4586 - val_accuracy: 0.7391 - val_loss: 0.4979
## Epoch 43/60
## 9/9 - 0s - 12ms/step - accuracy: 0.7776 - loss: 0.4574 - val_accuracy: 0.7391 - val_loss: 0.4985
## Epoch 44/60
## 9/9 - 0s - 16ms/step - accuracy: 0.7756 - loss: 0.4561 - val_accuracy: 0.7283 - val_loss: 0.4991
## Epoch 45/60
## 9/9 - 0s - 11ms/step - accuracy: 0.7834 - loss: 0.4546 - val_accuracy: 0.7283 - val_loss: 0.5002
## Epoch 46/60
## 9/9 - 0s - 13ms/step - accuracy: 0.7834 - loss: 0.4536 - val_accuracy: 0.7283 - val_loss: 0.5006
## Epoch 47/60
## 9/9 - 0s - 14ms/step - accuracy: 0.7814 - loss: 0.4522 - val_accuracy: 0.7283 - val_loss: 0.5009
## Epoch 48/60
## 9/9 - 0s - 14ms/step - accuracy: 0.7776 - loss: 0.4510 - val_accuracy: 0.7283 - val_loss: 0.5019
## Epoch 49/60
## 9/9 - 0s - 13ms/step - accuracy: 0.7737 - loss: 0.4503 - val_accuracy: 0.7283 - val_loss: 0.5046
## Epoch 50/60
## 9/9 - 0s - 13ms/step - accuracy: 0.7737 - loss: 0.4486 - val_accuracy: 0.7283 - val_loss: 0.5059
## Epoch 51/60
## 9/9 - 0s - 14ms/step - accuracy: 0.7795 - loss: 0.4481 - val_accuracy: 0.7391 - val_loss: 0.5075
## Epoch 52/60
## 9/9 - 0s - 13ms/step - accuracy: 0.7814 - loss: 0.4472 - val_accuracy: 0.7391 - val_loss: 0.5078
## Epoch 53/60
## 9/9 - 0s - 13ms/step - accuracy: 0.7795 - loss: 0.4465 - val_accuracy: 0.7391 - val_loss: 0.5082
## Epoch 54/60
## 9/9 - 0s - 15ms/step - accuracy: 0.7795 - loss: 0.4456 - val_accuracy: 0.7391 - val_loss: 0.5096
## Epoch 55/60
## 9/9 - 0s - 12ms/step - accuracy: 0.7795 - loss: 0.4445 - val_accuracy: 0.7391 - val_loss: 0.5110
## Epoch 56/60
## 9/9 - 0s - 12ms/step - accuracy: 0.7795 - loss: 0.4436 - val_accuracy: 0.7283 - val_loss: 0.5126
## Epoch 57/60
## 9/9 - 0s - 12ms/step - accuracy: 0.7795 - loss: 0.4431 - val_accuracy: 0.7283 - val_loss: 0.5136
## Epoch 58/60
## 9/9 - 0s - 13ms/step - accuracy: 0.7814 - loss: 0.4427 - val_accuracy: 0.7283 - val_loss: 0.5144
## Epoch 59/60
## 9/9 - 0s - 11ms/step - accuracy: 0.7814 - loss: 0.4428 - val_accuracy: 0.7283 - val_loss: 0.5158
## Epoch 60/60
## 9/9 - 0s - 13ms/step - accuracy: 0.7814 - loss: 0.4418 - val_accuracy: 0.7283 - val_loss: 0.5160
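In the log above, validation loss bottoms out around epoch 36 and then drifts upward while training loss keeps falling, which is the overfitting pattern described earlier. One way to locate that point programmatically is which.min() on the validation-loss series. The vector below copies the val_loss values for epochs 34 to 40 from the log; with a real run, the val_loss values stored in the history object would be used instead.

```r
# val_loss for epochs 34 to 40, copied from the training log
epochs   <- 34:40
val_loss <- c(0.4986, 0.4974, 0.4959, 0.4962, 0.4960, 0.4963, 0.4970)

best_epoch <- epochs[which.min(val_loss)]
best_epoch   # 36
```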
Plotting training history helps diagnose convergence and overfitting. Typically you look at:
- training vs validation loss curves,
- training vs validation accuracy curves.
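# Extract metrics from the Python History object for R plotting
metrics_df <- as.data.frame(history$history)
metrics_df$epoch <- seq_len(nrow(metrics_df))
par(mfrow = c(1, 2))
# Plot loss curves (blue = training, red = validation)
plot(metrics_df$epoch, metrics_df$loss, type = "l", col = "blue", main = "Loss",
     xlab = "Epoch", ylab = "Loss",
     ylim = range(metrics_df$loss, metrics_df$val_loss)) # keep both curves in view
lines(metrics_df$epoch, metrics_df$val_loss, col = "red")
legend("topright", legend = c("training", "validation"), col = c("blue", "red"), lty = 1)
# Plot accuracy curves (blue = training, red = validation)
plot(metrics_df$epoch, metrics_df$accuracy, type = "l", col = "blue", main = "Accuracy",
     xlab = "Epoch", ylab = "Accuracy",
     ylim = range(metrics_df$accuracy, metrics_df$val_accuracy))
lines(metrics_df$epoch, metrics_df$val_accuracy, col = "red")
legend("bottomright", legend = c("training", "validation"), col = c("blue", "red"), lty = 1)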
We use the x_test and y_test data sets to evaluate the trained model directly.
Evaluation on the test set provides a final, unbiased estimate of model performance (under the chosen split).
The output includes the loss and accuracy.
model$evaluate(as.matrix(x_test), y_test)
##
## 1/5 ━━━━━━━━━━━━━━━━━━━━ 0s 25ms/step - accuracy: 0.7812 - loss: 0.4574
## 5/5 ━━━━━━━━━━━━━━━━━━━━ 0s 7ms/step - accuracy: 0.8113 - loss: 0.4208
## [1] 0.4207639 0.8113208
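# A previous run reported accuracy: 0.7924528, loss: 0.4190769; results can differ between runs because the network weights are randomly initialized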
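Here we generate predicted class labels. The model outputs class probabilities; we select the class with the highest probability by applying which.max() to each row and subtracting 1 to return to 0/1 labels.
The confusion table compares predicted classes with actual test labels. In binary classification:
- diagonal counts are correct predictions,
- off-diagonal counts are misclassifications.
Note that caret's confusionMatrix() treats the first factor level (here 0, no diabetes) as the positive class by default, so the reported sensitivity refers to class 0. To report diabetes as the positive class, pass positive = "1".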
# Predict probabilities for the test set
prob_preds <- model$predict(as.matrix(x_test))
##
## 1/5 ━━━━━━━━━━━━━━━━━━━━ 0s 53ms/step
## 5/5 ━━━━━━━━━━━━━━━━━━━━ 0s 12ms/step
# Convert probabilities to class labels using argmax
pred <- apply(prob_preds, 1, which.max) - 1
# Compute the confusion matrix
library(caret)
## Loading required package: ggplot2
## Warning: package 'ggplot2' was built under R version 4.4.3
## Loading required package: lattice
confusionMatrix(reference = as.factor(y_test_actual), data = as.factor(pred ))
## Confusion Matrix and Statistics
##
## Reference
## Prediction 0 1
## 0 92 18
## 1 12 37
##
## Accuracy : 0.8113
## 95% CI : (0.7417, 0.8689)
## No Information Rate : 0.6541
## P-Value [Acc > NIR] : 9.431e-06
##
## Kappa : 0.572
##
## Mcnemar's Test P-Value : 0.3613
##
## Sensitivity : 0.8846
## Specificity : 0.6727
## Pos Pred Value : 0.8364
## Neg Pred Value : 0.7551
## Prevalence : 0.6541
## Detection Rate : 0.5786
## Detection Prevalence : 0.6918
## Balanced Accuracy : 0.7787
##
## 'Positive' Class : 0
##
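The headline statistics can be recomputed from the four cell counts, which is a useful check on which class is being treated as positive. The sketch below rebuilds the table by hand (object names are ours) and reports sensitivity and specificity both with class 0 as positive, as in the output above, and with class 1 (diabetes) as positive.

```r
# Rows = prediction, columns = reference, copied from the output above
cm <- matrix(c(92, 12, 18, 37), nrow = 2,
             dimnames = list(Prediction = c("0", "1"), Reference = c("0", "1")))

accuracy <- sum(diag(cm)) / sum(cm)             # (92 + 37) / 159

# Positive class = 0 (the first factor level, as in the output above)
sens_pos0 <- cm["0", "0"] / sum(cm[, "0"])      # 92 / 104
spec_pos0 <- cm["1", "1"] / sum(cm[, "1"])      # 37 / 55

# Positive class = 1 (diabetes): the two quantities swap roles
sens_pos1 <- spec_pos0
spec_pos1 <- sens_pos0

round(c(accuracy = accuracy, sens_pos0 = sens_pos0, spec_pos0 = spec_pos0), 4)
```

Read with diabetes as the positive class, the model detects about 67% of diabetic cases in this test split, which is a less flattering picture than the 88% sensitivity printed for class 0.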
For many applied use cases, predicted probabilities are more
informative than predicted labels (especially if you plan to choose a
custom probability threshold).
This block prints the first few rows of predicted probabilities for the
two classes.
head(prob_preds)
## [,1] [,2]
## [1,] 0.9432676 0.05673239
## [2,] 0.1724173 0.82758278
## [3,] 0.9800782 0.01992181
## [4,] 0.3977331 0.60226691
## [5,] 0.9327911 0.06720894
## [6,] 0.8021081 0.19789185
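If a threshold other than 0.5 is wanted, it can be applied directly to the class-1 column of the probability matrix. The sketch below uses the six class-1 probabilities printed above. The threshold of 0.15 is an arbitrary value chosen only to show the effect, and the function name classify is ours.

```r
# Class-1 (diabetes) probabilities for the first six test rows, from the output above
p_pos <- c(0.05673239, 0.82758278, 0.01992181, 0.60226691, 0.06720894, 0.19789185)

# Label a case positive when its class-1 probability reaches the threshold
classify <- function(p, threshold = 0.5) as.integer(p >= threshold)

classify(p_pos)                    # 0 1 0 1 0 0 (matches the argmax labels here)
classify(p_pos, threshold = 0.15)  # 0 1 0 1 0 1 (a lower threshold flags more positives)
```

Lowering the threshold trades specificity for sensitivity, which may be appropriate when missing a diabetic case is costlier than a false alarm.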
Viewing prob, pred, and y_test_actual side by side is helpful for model debugging:
- prob: predicted probabilities for each class
- pred: predicted class label (argmax)
- y_test_actual: true class label
In practice, you may also compute calibration plots or ROC curves when probability quality matters.
comparison <- cbind(prob_preds, pred, y_test_actual)
head(comparison)
## pred y_test_actual
## [1,] 0.9432676 0.05673239 0 1
## [2,] 0.1724173 0.82758278 1 1
## [3,] 0.9800782 0.01992181 0 0
## [4,] 0.3977331 0.60226691 1 1
## [5,] 0.9327911 0.06720894 0 0
## [6,] 0.8021081 0.19789185 0 0
Neural networks can also model continuous outcomes. In regression settings:
- the output layer typically has units = 1 and no activation (linear output),
- loss functions commonly include MSE,
- evaluation metrics often include MAE and RMSE.
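These three quantities are simple to compute by hand. In the sketch below the observed values are the first three medv values of the Boston data, while the predictions are made-up numbers for illustration only.

```r
y_obs  <- c(24.0, 21.6, 34.7)   # first three medv values
y_pred <- c(25.0, 20.0, 30.0)   # hypothetical predictions, not model output

err  <- y_obs - y_pred
mse  <- mean(err^2)             # the loss typically used for training
mae  <- mean(abs(err))          # same unit as the outcome
rmse <- sqrt(mse)               # same unit as the outcome, weights large errors more

round(c(mse = mse, mae = mae, rmse = rmse), 3)
```

RMSE is never smaller than MAE; a large gap between the two suggests that a few large errors dominate.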
We demonstrate regression using the Boston housing dataset
(MASS::Boston).
We load required libraries and then load the dataset.
plotly is included for interactive plotting (although the
core training plot uses base plotting via
plot(history)).
library(readr)
library(keras)
library(plotly)
data("Boston", package = "MASS")
data.set <- Boston
We inspect dataset dimensions. This helps confirm the number of predictors and the target column index.
dim(data.set)
## [1] 506 14
As above, we convert to matrix and remove dimnames for Keras input compatibility.
library(DT)
# Cast dataframe as a matrix
data.set <- as.matrix(data.set)
# Remove column names
dimnames(data.set) = NULL
head(data.set)
## [,1] [,2] [,3] [,4] [,5] [,6] [,7] [,8] [,9] [,10] [,11] [,12]
## [1,] 0.00632 18 2.31 0 0.538 6.575 65.2 4.0900 1 296 15.3 396.90
## [2,] 0.02731 0 7.07 0 0.469 6.421 78.9 4.9671 2 242 17.8 396.90
## [3,] 0.02729 0 7.07 0 0.469 7.185 61.1 4.9671 2 242 17.8 392.83
## [4,] 0.03237 0 2.18 0 0.458 6.998 45.8 6.0622 3 222 18.7 394.63
## [5,] 0.06905 0 2.18 0 0.458 7.147 54.2 6.0622 3 222 18.7 396.90
## [6,] 0.02985 0 2.18 0 0.458 6.430 58.7 6.0622 3 222 18.7 394.12
## [,13] [,14]
## [1,] 4.98 24.0
## [2,] 9.14 21.6
## [3,] 4.03 34.7
## [4,] 2.94 33.4
## [5,] 5.33 36.2
## [6,] 5.21 28.7
The target variable is in column 14 (medv in the Boston
dataset). We summarize it to understand the outcome range and
distribution.
summary(data.set[, 14])
## Min. 1st Qu. Median Mean 3rd Qu. Max.
## 5.00 17.02 21.20 22.53 25.00 50.00
A histogram provides a quick view of distribution shape (skewness, outliers, multimodality).
hist(data.set[, 14])
Fig 1 Histogram of the target variable
We split the dataset into training and test sets. The split ratio here is 75/25.
# Split for train and test data
set.seed(123)
indx <- sample(2,
nrow(data.set),
replace = TRUE,
prob = c(0.75, 0.25)) # Makes index with values 1 and 2
We separate predictors (first 13 columns) and outcome (14th column).
x_train <- data.set[indx == 1, 1:13]
x_test <- data.set[indx == 2, 1:13]
y_train <- data.set[indx == 1, 14]
y_test <- data.set[indx == 2, 14]
We scale the x_train and x_test data. Neural networks benefit from standardized inputs. This typically improves training stability and reduces the chance that a few large-scale predictors dominate gradient updates.
x_train <- scale(x_train)
train_center <- attr(x_train, "scaled:center") # the mean of each column in the training set
train_scale <- attr(x_train, "scaled:scale") # the standard deviation of each column in the training set
x_test <- scale(x_test, center = train_center, scale = train_scale)
This network uses one hidden dense layer (25 units, ReLU) followed by dropout and a single linear output unit.
- Dense layers learn nonlinear transformations.
- Dropout randomly zeros some activations during training, which reduces
overfitting and acts as regularization.
For tabular regression, this architecture is a common practical baseline.
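The mechanics of dropout can be sketched in a few lines of base R. This is a simplified illustration of so-called inverted dropout, in which surviving activations are scaled up by 1 / (1 - rate) during training so that their expected value is unchanged. The function name dropout_train is ours, and this is not the Keras implementation.

```r
dropout_train <- function(x, rate = 0.2) {
  keep <- rbinom(length(x), size = 1, prob = 1 - rate)  # 1 = keep, 0 = drop
  x * keep / (1 - rate)                                  # rescale survivors
}

set.seed(1)
activations <- rep(1, 10000)
dropped <- dropout_train(activations, rate = 0.2)

mean(dropped == 0)   # close to 0.2: about 20% of units are zeroed
mean(dropped)        # close to 1: the expected activation is preserved
```

At prediction time no units are dropped, which is why the rescaling during training matters: it keeps the scale of activations comparable between the two phases.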
# Regression model for Boston Housing
model_reg <- keras_model_sequential()
model_reg$add(layer_input(shape = c(13)))
model_reg$add(layer_dense(units = 25, activation = "relu"))
model_reg$add(layer_dropout(rate = 0.2))
model_reg$add(layer_dense(units = 1))
We print the model summary to verify layer shapes and parameter counts.
model_reg$summary()
## Model: "sequential_1"
## ┌─────────────────────────────────┬────────────────────────┬───────────────┐
## │ Layer (type) │ Output Shape │ Param # │
## ├─────────────────────────────────┼────────────────────────┼───────────────┤
## │ dense (Dense) │ (None, 25) │ 350 │
## ├─────────────────────────────────┼────────────────────────┼───────────────┤
## │ dropout (Dropout) │ (None, 25) │ 0 │
## ├─────────────────────────────────┼────────────────────────┼───────────────┤
## │ dense_1 (Dense) │ (None, 1) │ 26 │
## └─────────────────────────────────┴────────────────────────┴───────────────┘
## Total params: 376 (1.47 KB)
## Trainable params: 376 (1.47 KB)
## Non-trainable params: 0 (0.00 B)
Printing configuration is sometimes useful for documenting architecture in reports or debugging.
model_reg$get_config()
## $name
## [1] "sequential_1"
##
## $trainable
## [1] TRUE
##
## $dtype
## $dtype$module
## [1] "keras"
##
## $dtype$class_name
## [1] "DTypePolicy"
##
## $dtype$config
## $dtype$config$name
## [1] "float32"
##
##
## $dtype$registered_name
## NULL
##
##
## $layers
## $layers[[1]]
## $layers[[1]]$module
## [1] "keras.layers"
##
## $layers[[1]]$class_name
## [1] "InputLayer"
##
## $layers[[1]]$config
## $layers[[1]]$config$batch_shape
## $layers[[1]]$config$batch_shape[[1]]
## NULL
##
## $layers[[1]]$config$batch_shape[[2]]
## [1] 13
##
##
## $layers[[1]]$config$dtype
## [1] "float32"
##
## $layers[[1]]$config$sparse
## [1] FALSE
##
## $layers[[1]]$config$ragged
## [1] FALSE
##
## $layers[[1]]$config$name
## [1] "input_layer_1"
##
## $layers[[1]]$config$optional
## [1] FALSE
##
##
## $layers[[1]]$registered_name
## NULL
##
##
## $layers[[2]]
## $layers[[2]]$module
## [1] "keras.layers"
##
## $layers[[2]]$class_name
## [1] "Dense"
##
## $layers[[2]]$config
## $layers[[2]]$config$name
## [1] "dense"
##
## $layers[[2]]$config$trainable
## [1] TRUE
##
## $layers[[2]]$config$dtype
## $layers[[2]]$config$dtype$module
## [1] "keras"
##
## $layers[[2]]$config$dtype$class_name
## [1] "DTypePolicy"
##
## $layers[[2]]$config$dtype$config
## $layers[[2]]$config$dtype$config$name
## [1] "float32"
##
##
## $layers[[2]]$config$dtype$registered_name
## NULL
##
##
## $layers[[2]]$config$units
## [1] 25
##
## $layers[[2]]$config$activation
## [1] "relu"
##
## $layers[[2]]$config$use_bias
## [1] TRUE
##
## $layers[[2]]$config$kernel_initializer
## $layers[[2]]$config$kernel_initializer$module
## [1] "keras.initializers"
##
## $layers[[2]]$config$kernel_initializer$class_name
## [1] "GlorotUniform"
##
## $layers[[2]]$config$kernel_initializer$config
## $layers[[2]]$config$kernel_initializer$config$seed
## NULL
##
##
## $layers[[2]]$config$kernel_initializer$registered_name
## NULL
##
##
## $layers[[2]]$config$bias_initializer
## $layers[[2]]$config$bias_initializer$module
## [1] "keras.initializers"
##
## $layers[[2]]$config$bias_initializer$class_name
## [1] "Zeros"
##
## $layers[[2]]$config$bias_initializer$config
## named list()
##
## $layers[[2]]$config$bias_initializer$registered_name
## NULL
##
##
## $layers[[2]]$config$kernel_regularizer
## NULL
##
## $layers[[2]]$config$bias_regularizer
## NULL
##
## $layers[[2]]$config$kernel_constraint
## NULL
##
## $layers[[2]]$config$bias_constraint
## NULL
##
## $layers[[2]]$config$quantization_config
## NULL
##
##
## $layers[[2]]$registered_name
## NULL
##
## $layers[[2]]$build_config
## $layers[[2]]$build_config$input_shape
## $layers[[2]]$build_config$input_shape[[1]]
## NULL
##
## $layers[[2]]$build_config$input_shape[[2]]
## [1] 13
##
##
##
##
## $layers[[3]]
## $layers[[3]]$module
## [1] "keras.layers"
##
## $layers[[3]]$class_name
## [1] "Dropout"
##
## $layers[[3]]$config
## $layers[[3]]$config$name
## [1] "dropout"
##
## $layers[[3]]$config$trainable
## [1] TRUE
##
## $layers[[3]]$config$dtype
## $layers[[3]]$config$dtype$module
## [1] "keras"
##
## $layers[[3]]$config$dtype$class_name
## [1] "DTypePolicy"
##
## $layers[[3]]$config$dtype$config
## $layers[[3]]$config$dtype$config$name
## [1] "float32"
##
##
## $layers[[3]]$config$dtype$registered_name
## NULL
##
##
## $layers[[3]]$config$rate
## [1] 0.2
##
## $layers[[3]]$config$seed
## NULL
##
## $layers[[3]]$config$noise_shape
## NULL
##
##
## $layers[[3]]$registered_name
## NULL
##
##
## $layers[[4]]
## $layers[[4]]$module
## [1] "keras.layers"
##
## $layers[[4]]$class_name
## [1] "Dense"
##
## $layers[[4]]$config
## $layers[[4]]$config$name
## [1] "dense_1"
##
## $layers[[4]]$config$trainable
## [1] TRUE
##
## $layers[[4]]$config$dtype
## $layers[[4]]$config$dtype$module
## [1] "keras"
##
## $layers[[4]]$config$dtype$class_name
## [1] "DTypePolicy"
##
## $layers[[4]]$config$dtype$config
## $layers[[4]]$config$dtype$config$name
## [1] "float32"
##
##
## $layers[[4]]$config$dtype$registered_name
## NULL
##
##
## $layers[[4]]$config$units
## [1] 1
##
## $layers[[4]]$config$activation
## [1] "linear"
##
## $layers[[4]]$config$use_bias
## [1] TRUE
##
## $layers[[4]]$config$kernel_initializer
## $layers[[4]]$config$kernel_initializer$module
## [1] "keras.initializers"
##
## $layers[[4]]$config$kernel_initializer$class_name
## [1] "GlorotUniform"
##
## $layers[[4]]$config$kernel_initializer$config
## $layers[[4]]$config$kernel_initializer$config$seed
## NULL
##
##
## $layers[[4]]$config$kernel_initializer$registered_name
## NULL
##
##
## $layers[[4]]$config$bias_initializer
## $layers[[4]]$config$bias_initializer$module
## [1] "keras.initializers"
##
## $layers[[4]]$config$bias_initializer$class_name
## [1] "Zeros"
##
## $layers[[4]]$config$bias_initializer$config
## named list()
##
## $layers[[4]]$config$bias_initializer$registered_name
## NULL
##
##
## $layers[[4]]$config$kernel_regularizer
## NULL
##
## $layers[[4]]$config$bias_regularizer
## NULL
##
## $layers[[4]]$config$kernel_constraint
## NULL
##
## $layers[[4]]$config$bias_constraint
## NULL
##
## $layers[[4]]$config$quantization_config
## NULL
##
##
## $layers[[4]]$registered_name
## NULL
##
## $layers[[4]]$build_config
## $layers[[4]]$build_config$input_shape
## $layers[[4]]$build_config$input_shape[[1]]
## NULL
##
## $layers[[4]]$build_config$input_shape[[2]]
## [1] 25
##
##
##
##
##
## $build_input_shape
## $build_input_shape[[1]]
## NULL
##
## $build_input_shape[[2]]
## [1] 13
For regression:
- loss: MSE ("mse")
- optimizer: RMSprop is commonly used for regression tasks and works well in many settings
- metric: MAE is often easier to interpret because it is in the same unit as the outcome.
model_reg$compile(
loss = "mse",
optimizer = "rmsprop",
metrics = list("mean_absolute_error")
)
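We train the model for up to 100 epochs, holding out 10% of the training data for validation. Early stopping, which monitors a validation metric such as validation MAE and halts training once it has not improved for several epochs, is a simple and effective way to control overfitting. Note, however, that the fit() call shown here does not pass an early-stopping callback, so training runs for the full number of epochs.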
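The early-stopping rule itself is simple to state: track the best validation metric seen so far and stop once it has failed to improve for a set number of consecutive epochs (the patience). The base R sketch below applies that rule to a made-up series of validation MAE values. The function name, the patience of 3, and the numbers are all illustrative assumptions; in Keras this logic is provided by an early-stopping callback passed to fit().

```r
early_stop_epoch <- function(val_metric, patience = 3) {
  best <- Inf
  wait <- 0
  for (epoch in seq_along(val_metric)) {
    if (val_metric[epoch] < best) {
      best <- val_metric[epoch]   # improvement: remember it and reset the counter
      wait <- 0
    } else {
      wait <- wait + 1            # no improvement this epoch
      if (wait >= patience) return(epoch)
    }
  }
  length(val_metric)              # never triggered: all epochs were run
}

val_mae <- c(5.0, 4.2, 3.9, 4.0, 3.95, 4.1, 4.3)  # hypothetical values
early_stop_epoch(val_mae, patience = 3)            # stops at epoch 6
```

The best value in this toy series occurs at epoch 3, so in practice one would also restore the weights from that epoch rather than keep those from the stopping epoch.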
history_reg <- model_reg$fit(
x = as.matrix(x_train),
y = as.matrix(y_train),
epochs = as.integer(100),
batch_size = as.integer(64),
validation_split = 0.1,
verbose = 2
)
## Epoch 1/100
## 6/6 - 1s - 126ms/step - loss: 593.5840 - mean_absolute_error: 21.9404 - val_loss: 319.0117 - val_mean_absolute_error: 17.3724
## Epoch 2/100
## 6/6 - 0s - 16ms/step - loss: 580.2263 - mean_absolute_error: 21.6798 - val_loss: 315.5471 - val_mean_absolute_error: 17.2689
## Epoch 3/100
## 6/6 - 0s - 17ms/step - loss: 573.4169 - mean_absolute_error: 21.5376 - val_loss: 312.4537 - val_mean_absolute_error: 17.1760
## Epoch 4/100
## 6/6 - 0s - 17ms/step - loss: 566.1923 - mean_absolute_error: 21.3858 - val_loss: 309.2842 - val_mean_absolute_error: 17.0804
## Epoch 5/100
## 6/6 - 0s - 17ms/step - loss: 556.9926 - mean_absolute_error: 21.1751 - val_loss: 306.3380 - val_mean_absolute_error: 16.9904
## Epoch 6/100
## 6/6 - 0s - 17ms/step - loss: 551.1225 - mean_absolute_error: 21.0478 - val_loss: 303.3892 - val_mean_absolute_error: 16.8996
## Epoch 7/100
## 6/6 - 0s - 17ms/step - loss: 541.5051 - mean_absolute_error: 20.8749 - val_loss: 300.8377 - val_mean_absolute_error: 16.8206
## Epoch 8/100
## 6/6 - 0s - 17ms/step - loss: 536.0032 - mean_absolute_error: 20.7577 - val_loss: 298.1372 - val_mean_absolute_error: 16.7367
## Epoch 9/100
## 6/6 - 0s - 16ms/step - loss: 526.4520 - mean_absolute_error: 20.5228 - val_loss: 295.1030 - val_mean_absolute_error: 16.6417
## Epoch 10/100
## 6/6 - 0s - 18ms/step - loss: 523.0405 - mean_absolute_error: 20.4553 - val_loss: 292.0318 - val_mean_absolute_error: 16.5449
## Epoch 11/100
## 6/6 - 0s - 19ms/step - loss: 510.6674 - mean_absolute_error: 20.1880 - val_loss: 289.0659 - val_mean_absolute_error: 16.4511
## Epoch 12/100
## 6/6 - 0s - 16ms/step - loss: 503.2650 - mean_absolute_error: 20.0903 - val_loss: 286.3202 - val_mean_absolute_error: 16.3639
## Epoch 13/100
## 6/6 - 0s - 17ms/step - loss: 497.8701 - mean_absolute_error: 19.8772 - val_loss: 283.5000 - val_mean_absolute_error: 16.2739
## Epoch 14/100
## 6/6 - 0s - 17ms/step - loss: 487.1030 - mean_absolute_error: 19.7094 - val_loss: 280.1739 - val_mean_absolute_error: 16.1656
## Epoch 15/100
## 6/6 - 0s - 17ms/step - loss: 475.0551 - mean_absolute_error: 19.4116 - val_loss: 276.9488 - val_mean_absolute_error: 16.0598
## Epoch 16/100
## 6/6 - 0s - 17ms/step - loss: 466.6918 - mean_absolute_error: 19.2212 - val_loss: 274.0783 - val_mean_absolute_error: 15.9648
## Epoch 17/100
## 6/6 - 0s - 17ms/step - loss: 459.9448 - mean_absolute_error: 19.0493 - val_loss: 270.9745 - val_mean_absolute_error: 15.8605
## Epoch 18/100
## 6/6 - 0s - 16ms/step - loss: 454.8932 - mean_absolute_error: 18.8769 - val_loss: 267.6364 - val_mean_absolute_error: 15.7474
## Epoch 19/100
## 6/6 - 0s - 16ms/step - loss: 442.8012 - mean_absolute_error: 18.6006 - val_loss: 264.5380 - val_mean_absolute_error: 15.6423
## Epoch 20/100
## 6/6 - 0s - 16ms/step - loss: 434.8492 - mean_absolute_error: 18.3910 - val_loss: 260.9543 - val_mean_absolute_error: 15.5191
## Epoch 21/100
## 6/6 - 0s - 18ms/step - loss: 423.1867 - mean_absolute_error: 18.1615 - val_loss: 257.5658 - val_mean_absolute_error: 15.4021
## Epoch 22/100
## 6/6 - 0s - 15ms/step - loss: 414.3769 - mean_absolute_error: 17.9670 - val_loss: 254.3262 - val_mean_absolute_error: 15.2897
## Epoch 23/100
## 6/6 - 0s - 15ms/step - loss: 408.3499 - mean_absolute_error: 17.7833 - val_loss: 250.9887 - val_mean_absolute_error: 15.1718
## Epoch 24/100
## 6/6 - 0s - 15ms/step - loss: 398.6693 - mean_absolute_error: 17.5503 - val_loss: 247.7304 - val_mean_absolute_error: 15.0572
## Epoch 25/100
## 6/6 - 0s - 15ms/step - loss: 384.7629 - mean_absolute_error: 17.1936 - val_loss: 244.3035 - val_mean_absolute_error: 14.9345
## Epoch 26/100
## 6/6 - 0s - 15ms/step - loss: 376.5441 - mean_absolute_error: 17.0126 - val_loss: 240.7584 - val_mean_absolute_error: 14.8068
## Epoch 27/100
## 6/6 - 0s - 14ms/step - loss: 367.9996 - mean_absolute_error: 16.7997 - val_loss: 237.2258 - val_mean_absolute_error: 14.6773
## Epoch 28/100
## 6/6 - 0s - 13ms/step - loss: 357.5495 - mean_absolute_error: 16.4173 - val_loss: 233.6853 - val_mean_absolute_error: 14.5470
## Epoch 29/100
## 6/6 - 0s - 13ms/step - loss: 346.7405 - mean_absolute_error: 16.1320 - val_loss: 229.6084 - val_mean_absolute_error: 14.3941
## Epoch 30/100
## 6/6 - 0s - 15ms/step - loss: 340.1533 - mean_absolute_error: 16.0121 - val_loss: 225.8949 - val_mean_absolute_error: 14.2540
## Epoch 31/100
## 6/6 - 0s - 15ms/step - loss: 329.0853 - mean_absolute_error: 15.7891 - val_loss: 222.1243 - val_mean_absolute_error: 14.1103
## Epoch 32/100
## 6/6 - 0s - 16ms/step - loss: 317.6953 - mean_absolute_error: 15.3888 - val_loss: 218.7649 - val_mean_absolute_error: 13.9822
## Epoch 33/100
## 6/6 - 0s - 14ms/step - loss: 317.1198 - mean_absolute_error: 15.2866 - val_loss: 215.2253 - val_mean_absolute_error: 13.8462
## Epoch 34/100
## 6/6 - 0s - 15ms/step - loss: 308.3627 - mean_absolute_error: 14.9944 - val_loss: 211.7213 - val_mean_absolute_error: 13.7106
## Epoch 35/100
## 6/6 - 0s - 16ms/step - loss: 299.4528 - mean_absolute_error: 14.8277 - val_loss: 208.1525 - val_mean_absolute_error: 13.5717
## Epoch 36/100
## 6/6 - 0s - 18ms/step - loss: 283.6210 - mean_absolute_error: 14.2640 - val_loss: 204.9196 - val_mean_absolute_error: 13.4446
## Epoch 37/100
## 6/6 - 0s - 18ms/step - loss: 271.9949 - mean_absolute_error: 13.9959 - val_loss: 200.8699 - val_mean_absolute_error: 13.2801
## Epoch 38/100
## 6/6 - 0s - 16ms/step - loss: 270.2120 - mean_absolute_error: 13.8545 - val_loss: 197.4629 - val_mean_absolute_error: 13.1438
## Epoch 39/100
## 6/6 - 0s - 15ms/step - loss: 256.9817 - mean_absolute_error: 13.6217 - val_loss: 193.8249 - val_mean_absolute_error: 12.9936
## Epoch 40/100
## 6/6 - 0s - 15ms/step - loss: 252.4408 - mean_absolute_error: 13.2687 - val_loss: 190.7803 - val_mean_absolute_error: 12.8710
## Epoch 41/100
## 6/6 - 0s - 14ms/step - loss: 246.6973 - mean_absolute_error: 13.2068 - val_loss: 187.0960 - val_mean_absolute_error: 12.7163
## Epoch 42/100
## 6/6 - 0s - 14ms/step - loss: 233.6218 - mean_absolute_error: 12.6591 - val_loss: 183.4566 - val_mean_absolute_error: 12.5622
## Epoch 43/100
## 6/6 - 0s - 14ms/step - loss: 230.4716 - mean_absolute_error: 12.6024 - val_loss: 180.3488 - val_mean_absolute_error: 12.4344
## Epoch 44/100
## 6/6 - 0s - 14ms/step - loss: 220.0802 - mean_absolute_error: 12.3003 - val_loss: 177.1870 - val_mean_absolute_error: 12.3007
## Epoch 45/100
## 6/6 - 0s - 15ms/step - loss: 215.3739 - mean_absolute_error: 12.1226 - val_loss: 173.9112 - val_mean_absolute_error: 12.1608
## Epoch 46/100
## 6/6 - 0s - 15ms/step - loss: 209.4984 - mean_absolute_error: 11.8126 - val_loss: 170.9959 - val_mean_absolute_error: 12.0379
## Epoch 47/100
## 6/6 - 0s - 14ms/step - loss: 199.9228 - mean_absolute_error: 11.6930 - val_loss: 167.7660 - val_mean_absolute_error: 11.8985
## Epoch 48/100
## 6/6 - 0s - 15ms/step - loss: 195.7981 - mean_absolute_error: 11.5360 - val_loss: 165.0495 - val_mean_absolute_error: 11.7866
## Epoch 49/100
## 6/6 - 0s - 18ms/step - loss: 184.8871 - mean_absolute_error: 11.0040 - val_loss: 162.5565 - val_mean_absolute_error: 11.6845
## Epoch 50/100
## 6/6 - 0s - 16ms/step - loss: 182.8613 - mean_absolute_error: 10.9473 - val_loss: 159.0256 - val_mean_absolute_error: 11.5255
## Epoch 51/100
## 6/6 - 0s - 16ms/step - loss: 167.6076 - mean_absolute_error: 10.3881 - val_loss: 155.9083 - val_mean_absolute_error: 11.3894
## Epoch 52/100
## 6/6 - 0s - 16ms/step - loss: 161.4895 - mean_absolute_error: 10.2168 - val_loss: 152.8675 - val_mean_absolute_error: 11.2550
## Epoch 53/100
## 6/6 - 0s - 15ms/step - loss: 159.6780 - mean_absolute_error: 10.1255 - val_loss: 149.7382 - val_mean_absolute_error: 11.1148
## Epoch 54/100
## 6/6 - 0s - 15ms/step - loss: 154.7589 - mean_absolute_error: 9.9321 - val_loss: 147.0170 - val_mean_absolute_error: 10.9951
## Epoch 55/100
## 6/6 - 0s - 15ms/step - loss: 150.3096 - mean_absolute_error: 9.8900 - val_loss: 143.8374 - val_mean_absolute_error: 10.8498
## Epoch 56/100
## 6/6 - 0s - 15ms/step - loss: 137.5247 - mean_absolute_error: 9.4035 - val_loss: 140.8621 - val_mean_absolute_error: 10.7118
## Epoch 57/100
## 6/6 - 0s - 17ms/step - loss: 131.9872 - mean_absolute_error: 9.2019 - val_loss: 137.7275 - val_mean_absolute_error: 10.5626
## Epoch 58/100
## 6/6 - 0s - 15ms/step - loss: 138.0352 - mean_absolute_error: 9.3365 - val_loss: 135.1038 - val_mean_absolute_error: 10.4437
## Epoch 59/100
## 6/6 - 0s - 14ms/step - loss: 131.2623 - mean_absolute_error: 8.9877 - val_loss: 132.0300 - val_mean_absolute_error: 10.2955
## Epoch 60/100
## 6/6 - 0s - 14ms/step - loss: 122.7888 - mean_absolute_error: 8.8636 - val_loss: 129.3799 - val_mean_absolute_error: 10.1691
## Epoch 61/100
## 6/6 - 0s - 14ms/step - loss: 128.0574 - mean_absolute_error: 8.7179 - val_loss: 126.5291 - val_mean_absolute_error: 10.0299
## Epoch 62/100
## 6/6 - 0s - 14ms/step - loss: 114.5122 - mean_absolute_error: 8.3719 - val_loss: 123.3873 - val_mean_absolute_error: 9.8702
## Epoch 63/100
## 6/6 - 0s - 14ms/step - loss: 114.6687 - mean_absolute_error: 8.3478 - val_loss: 121.3452 - val_mean_absolute_error: 9.7753
## Epoch 64/100
## 6/6 - 0s - 14ms/step - loss: 112.6014 - mean_absolute_error: 8.3505 - val_loss: 118.8064 - val_mean_absolute_error: 9.6466
## Epoch 65/100
## 6/6 - 0s - 14ms/step - loss: 103.2491 - mean_absolute_error: 7.9530 - val_loss: 116.2254 - val_mean_absolute_error: 9.5098
## Epoch 66/100
## 6/6 - 0s - 19ms/step - loss: 101.6294 - mean_absolute_error: 7.8726 - val_loss: 114.1571 - val_mean_absolute_error: 9.4022
## Epoch 67/100
## 6/6 - 0s - 16ms/step - loss: 93.2813 - mean_absolute_error: 7.5609 - val_loss: 111.8671 - val_mean_absolute_error: 9.2826
## Epoch 68/100
## 6/6 - 0s - 18ms/step - loss: 92.0893 - mean_absolute_error: 7.6338 - val_loss: 109.5237 - val_mean_absolute_error: 9.1593
## Epoch 69/100
## 6/6 - 0s - 18ms/step - loss: 91.1658 - mean_absolute_error: 7.4903 - val_loss: 107.3103 - val_mean_absolute_error: 9.0401
## Epoch 70/100
## 6/6 - 0s - 23ms/step - loss: 85.6506 - mean_absolute_error: 7.1149 - val_loss: 104.9072 - val_mean_absolute_error: 8.9061
## Epoch 71/100
## 6/6 - 0s - 18ms/step - loss: 87.7790 - mean_absolute_error: 7.3511 - val_loss: 102.5584 - val_mean_absolute_error: 8.7724
## Epoch 72/100
## 6/6 - 0s - 18ms/step - loss: 81.0432 - mean_absolute_error: 6.9713 - val_loss: 101.0766 - val_mean_absolute_error: 8.7011
## Epoch 73/100
## 6/6 - 0s - 18ms/step - loss: 82.2645 - mean_absolute_error: 6.9941 - val_loss: 98.9647 - val_mean_absolute_error: 8.5949
## Epoch 74/100
## 6/6 - 0s - 21ms/step - loss: 74.1519 - mean_absolute_error: 6.5301 - val_loss: 97.0205 - val_mean_absolute_error: 8.4945
## Epoch 75/100
## 6/6 - 0s - 21ms/step - loss: 75.7695 - mean_absolute_error: 6.9746 - val_loss: 95.5871 - val_mean_absolute_error: 8.4209
## Epoch 76/100
## 6/6 - 0s - 20ms/step - loss: 66.9709 - mean_absolute_error: 6.3546 - val_loss: 93.6590 - val_mean_absolute_error: 8.3129
## Epoch 77/100
## 6/6 - 0s - 19ms/step - loss: 73.5376 - mean_absolute_error: 6.4901 - val_loss: 91.4757 - val_mean_absolute_error: 8.1880
## Epoch 78/100
## 6/6 - 0s - 24ms/step - loss: 71.1550 - mean_absolute_error: 6.3472 - val_loss: 89.4585 - val_mean_absolute_error: 8.0626
## Epoch 79/100
## 6/6 - 0s - 20ms/step - loss: 62.3945 - mean_absolute_error: 5.9646 - val_loss: 87.6243 - val_mean_absolute_error: 7.9555
## Epoch 80/100
## 6/6 - 0s - 22ms/step - loss: 64.1137 - mean_absolute_error: 6.2182 - val_loss: 86.1264 - val_mean_absolute_error: 7.8641
## Epoch 81/100
## 6/6 - 0s - 18ms/step - loss: 64.1134 - mean_absolute_error: 6.0799 - val_loss: 85.2056 - val_mean_absolute_error: 7.8027
## Epoch 82/100
## 6/6 - 0s - 19ms/step - loss: 60.7491 - mean_absolute_error: 5.8326 - val_loss: 83.6794 - val_mean_absolute_error: 7.7050
## Epoch 83/100
## 6/6 - 0s - 18ms/step - loss: 57.1465 - mean_absolute_error: 5.8048 - val_loss: 82.5543 - val_mean_absolute_error: 7.6272
## Epoch 84/100
## 6/6 - 0s - 17ms/step - loss: 57.6039 - mean_absolute_error: 5.7027 - val_loss: 81.8496 - val_mean_absolute_error: 7.5688
## Epoch 85/100
## 6/6 - 0s - 18ms/step - loss: 57.8376 - mean_absolute_error: 5.6051 - val_loss: 80.2347 - val_mean_absolute_error: 7.4543
## Epoch 86/100
## 6/6 - 0s - 16ms/step - loss: 55.4941 - mean_absolute_error: 5.5302 - val_loss: 79.1947 - val_mean_absolute_error: 7.3733
## Epoch 87/100
## 6/6 - 0s - 18ms/step - loss: 50.5229 - mean_absolute_error: 5.3698 - val_loss: 78.0666 - val_mean_absolute_error: 7.3167
## Epoch 88/100
## 6/6 - 0s - 18ms/step - loss: 50.7679 - mean_absolute_error: 5.2179 - val_loss: 78.0236 - val_mean_absolute_error: 7.3333
## Epoch 89/100
## 6/6 - 0s - 16ms/step - loss: 53.5579 - mean_absolute_error: 5.5393 - val_loss: 77.1940 - val_mean_absolute_error: 7.3001
## Epoch 90/100
## 6/6 - 0s - 16ms/step - loss: 49.9652 - mean_absolute_error: 5.2475 - val_loss: 76.0469 - val_mean_absolute_error: 7.2431
## Epoch 91/100
## 6/6 - 0s - 16ms/step - loss: 48.4014 - mean_absolute_error: 5.2141 - val_loss: 75.1038 - val_mean_absolute_error: 7.1985
## Epoch 92/100
## 6/6 - 0s - 16ms/step - loss: 46.5787 - mean_absolute_error: 5.0246 - val_loss: 74.1083 - val_mean_absolute_error: 7.1415
## Epoch 93/100
## 6/6 - 0s - 19ms/step - loss: 48.5440 - mean_absolute_error: 5.2215 - val_loss: 73.3276 - val_mean_absolute_error: 7.0982
## Epoch 94/100
## 6/6 - 0s - 21ms/step - loss: 48.6879 - mean_absolute_error: 5.2063 - val_loss: 72.6031 - val_mean_absolute_error: 7.0567
## Epoch 95/100
## 6/6 - 0s - 22ms/step - loss: 46.4703 - mean_absolute_error: 5.0719 - val_loss: 72.1957 - val_mean_absolute_error: 7.0289
## Epoch 96/100
## 6/6 - 0s - 17ms/step - loss: 45.2099 - mean_absolute_error: 5.0528 - val_loss: 71.8247 - val_mean_absolute_error: 7.0128
## Epoch 97/100
## 6/6 - 0s - 18ms/step - loss: 46.8128 - mean_absolute_error: 5.0841 - val_loss: 71.3840 - val_mean_absolute_error: 6.9940
## Epoch 98/100
## 6/6 - 0s - 18ms/step - loss: 41.9701 - mean_absolute_error: 4.7342 - val_loss: 70.5959 - val_mean_absolute_error: 6.9530
## Epoch 99/100
## 6/6 - 0s - 19ms/step - loss: 41.4004 - mean_absolute_error: 4.7965 - val_loss: 70.8889 - val_mean_absolute_error: 6.9964
## Epoch 100/100
## 6/6 - 0s - 17ms/step - loss: 44.7818 - mean_absolute_error: 4.9413 - val_loss: 70.1961 - val_mean_absolute_error: 6.9670
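To make the early stopping rule concrete, the following base-R sketch applies a patience-based rule to a vector of validation MAE values. The function name early_stop_epoch, the patience of 3, and the toy val_mae vector are illustrative assumptions, not part of the Keras API or of the run above. In Keras itself, the same idea is provided by an EarlyStopping callback passed to fit().

```r
# Return the epoch at which a patience-based early stopping rule would halt.
# val_metric: validation metric per epoch (lower is better)
# patience:   number of consecutive epochs without improvement to tolerate
early_stop_epoch <- function(val_metric, patience = 3) {
  best <- Inf
  wait <- 0
  for (epoch in seq_along(val_metric)) {
    if (val_metric[epoch] < best) {
      best <- val_metric[epoch]  # new best value: reset the counter
      wait <- 0
    } else {
      wait <- wait + 1           # no improvement over the best value so far
      if (wait >= patience) return(epoch)
    }
  }
  length(val_metric)             # rule never triggered: training runs to the end
}

# Toy validation MAE curve (hypothetical values): improves, then plateaus
val_mae <- c(9.0, 8.2, 7.6, 7.3, 7.4, 7.5, 7.35, 7.2)

stop_at <- early_stop_epoch(val_mae, patience = 3)
best_at <- which.min(val_mae[seq_len(stop_at)])
cat("Stop after epoch", stop_at, "; best epoch so far:", best_at, "\n")
```

In the log above, validation MAE improves in nearly every epoch through epoch 100, so a rule like this with a patience of a few epochs would probably have let training run to the end.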
After training, we evaluate on the test set and print MAE. MAE is directly interpretable: it is the average absolute prediction error.
# Evaluate the model on the test set
# Pass x_test and y_test as matrices
evaluation <- model_reg$evaluate(
x = as.matrix(x_test),
y = as.matrix(y_test),
verbose = 0
)
# Extract Loss (MSE) and MAE
# Keras 3 returns a named list or vector depending on the backend
loss <- evaluation[[1]]
mae <- evaluation[[2]]
cat("Test Set Mean Squared Error (Loss):", loss, "\n")
## Test Set Mean Squared Error (Loss): 31.45444
cat("Test Set Mean Absolute Error (MAE):", mae, "\n")
## Test Set Mean Absolute Error (MAE): 3.761943
Training curves help diagnose:
- convergence (loss decreases smoothly),
- overfitting (validation loss stops improving while training loss continues to improve),
- underfitting (both losses remain high).
# Convert the Python training history to an R data frame
h_df <- as.data.frame(history_reg$history)
h_df$epoch <- 1:nrow(h_df)
# Set a two-panel layout (Loss and MAE)
par(mfrow = c(1, 2))
# Plot Loss (MSE)
plot(h_df$epoch, h_df$loss, type = "l", col = "blue",
main = "Model Loss (MSE)", xlab = "Epoch", ylab = "Loss")
lines(h_df$epoch, h_df$val_loss, col = "red")
legend("topright", legend = c("Train", "Val"), col = c("blue", "red"), lty = 1)
# Plot MAE
plot(h_df$epoch, h_df$mean_absolute_error, type = "l", col = "blue",
main = "Model MAE", xlab = "Epoch", ylab = "Error")
lines(h_df$epoch, h_df$val_mean_absolute_error, col = "red")
legend("topright", legend = c("Train", "Val"), col = c("blue", "red"), lty = 1)
We predict on test features and compare predictions with observed
values.
A quick head view helps confirm shape and reasonableness.
# Generate predictions
predictions <- model_reg$predict(as.matrix(x_test))
##
## [1m1/4[0m [32m━━━━━[0m[37m━━━━━━━━━━━━━━━[0m [1m0s[0m 51ms/step
## [1m4/4[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 15ms/step
head(cbind(predictions,y_test))
## y_test
## [1,] 26.56299 21.6
## [2,] 34.28972 33.4
## [3,] 33.66969 36.2
## [4,] 16.81835 27.1
## [5,] 16.75894 15.0
## [6,] 15.40538 19.9
We compute the prediction errors and the root mean square error, sqrt(mean(error^2)), which complements the mean absolute error reported above, and then plot the errors. The code follows the same structure used in earlier chapters.
error <- y_test - predictions
head(error)
## [,1]
## [1,] -4.9629883
## [2,] -0.8897224
## [3,] 2.5303146
## [4,] 10.2816483
## [5,] -1.7589359
## [6,] 4.4946226
rmse <- sqrt(mean(error^2))
rmse
## [1] 5.608426
plot(error)
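As a small worked example, the two error summaries can be recomputed by hand from the six predictions and observed values printed above. The vectors pred_head and obs_head are copied from that output, and the variable names are introduced here for illustration only. Because they cover six test cases rather than the full test set, the results will differ from the test-set MAE and RMSE.

```r
# First six predictions and observed values, copied from the head() output
pred_head <- c(26.56299, 34.28972, 33.66969, 16.81835, 16.75894, 15.40538)
obs_head  <- c(21.6, 33.4, 36.2, 27.1, 15.0, 19.9)

# Same sign convention as above: observed minus predicted
err_head <- obs_head - pred_head

mae_head  <- mean(abs(err_head))      # average absolute error
rmse_head <- sqrt(mean(err_head^2))   # squares errors, so large misses weigh more

cat("MAE  (first six cases):", round(mae_head, 3), "\n")
cat("RMSE (first six cases):", round(rmse_head, 3), "\n")
```

RMSE is never smaller than MAE, and the gap widens when a few errors are large, as with the fourth case here (an error of about 10).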
# Create a comparison plot
plot(y_test, predictions,
main = "Actual vs. Predicted House Prices",
xlab = "Actual Value ($1000s)",
ylab = "Predicted Value ($1000s)",
pch = 19, col = adjustcolor("blue", alpha.f = 0.5))
# Add a 45-degree line (Perfect prediction line)
abline(0, 1, col = "red", lwd = 2)
Convolutional neural networks (CNNs) are designed for grid-like data
such as images.
Unlike dense networks, CNNs use:
- convolution layers to detect local patterns (edges, shapes),
- pooling layers to reduce dimensionality and add translation invariance,
- dropout to reduce overfitting,
- dense layers at the end for classification.
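To illustrate what convolution and pooling layers compute, here is a minimal base-R sketch on a toy image. The helper functions conv2d_valid and max_pool, the 6 x 6 toy image, and the vertical-edge kernel are all hypothetical illustrations. A real convolution layer learns its kernel values during training rather than using a hand-picked one.

```r
# Valid (no padding, stride 1) 2D cross-correlation, the operation
# a convolution layer applies with each of its filters
conv2d_valid <- function(img, kernel) {
  kh <- nrow(kernel)
  kw <- ncol(kernel)
  out <- matrix(0, nrow(img) - kh + 1, ncol(img) - kw + 1)
  for (i in seq_len(nrow(out))) {
    for (j in seq_len(ncol(out))) {
      patch <- img[i:(i + kh - 1), j:(j + kw - 1)]
      out[i, j] <- sum(patch * kernel)
    }
  }
  out
}

# Non-overlapping max pooling with a square window
max_pool <- function(fm, size = 2) {
  nr <- nrow(fm) %/% size
  nc <- ncol(fm) %/% size
  out <- matrix(0, nr, nc)
  for (i in seq_len(nr)) {
    for (j in seq_len(nc)) {
      rows <- ((i - 1) * size + 1):(i * size)
      cols <- ((j - 1) * size + 1):(j * size)
      out[i, j] <- max(fm[rows, cols])
    }
  }
  out
}

# Toy 6 x 6 image: dark left half (0), bright right half (1)
img <- cbind(matrix(0, 6, 3), matrix(1, 6, 3))

# Hand-picked 3 x 3 kernel that responds to dark-to-bright vertical edges
# (every row is -1, 0, 1)
kernel <- matrix(c(-1, 0, 1), nrow = 3, ncol = 3, byrow = TRUE)

feature_map <- conv2d_valid(img, kernel)  # 4 x 4, nonzero only near the edge
pooled <- max_pool(feature_map, 2)        # 2 x 2 summary of the feature map

print(feature_map)
print(pooled)
```

The feature map is nonzero only where the window straddles the edge, which is the sense in which a filter detects a local pattern. Pooling then keeps the strongest response in each 2 x 2 block.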
We demonstrate CNN using the MNIST handwritten digit dataset.
We load keras.
library(keras)
MNIST is included as a built-in dataset in Keras. It contains:
- training images and labels,
- test images and labels.
mnist <- dataset_mnist()
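The returned object is a list containing the training and test images and labels (trainx, trainy, testx, testy), stored as mnist$train$x, mnist$train$y, mnist$test$x, and mnist$test$y.
class(mnist)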
## [1] "list"
View the first image in the training set.
# head(mnist)
# Extract the first image from the training set
first_image <- mnist$train$x[1, , ]
# 1. Check dimensions (should be 28 x 28)
print(dim(first_image))
## [1] 28 28
# 2. Inspect pixel value range (0-255)
# Print a small block of values for a quick look
print(first_image[10:15, 10:15])
## [,1] [,2] [,3] [,4] [,5] [,6]
## [1,] 156 107 253 253 205 11
## [2,] 14 1 154 253 90 0
## [3,] 0 0 139 253 190 2
## [4,] 0 0 11 190 253 70
## [5,] 0 0 0 35 241 225
## [6,] 0 0 0 0 81 240
# 3. Get the corresponding label (outcome)
first_label <- mnist$train$y[1]
cat("The true label of this image is:", first_label, "\n")
## The true label of this image is: 5
Visualize the first image in the training set.
# Define a plotting function
plot_mnist_image <- function(image_matrix, title = "") {
# Fix rotation: transpose the matrix and reverse rows
rotated_image <- t(apply(image_matrix, 2, rev))
image(rotated_image,
col = gray.colors(256),
axes = FALSE,
main = title)
}
# Plot the first image
plot_mnist_image(first_image, paste("Label:", first_label))
set.seed(123)
# ---- Training set preparation (1000 samples) ----
idx_train <- sample(nrow(mnist$train$x), 1000)
# x_train_cnn <- array_reshape(mnist$train$x[idx_train,,], c(1000, 28, 28, 1)) / 255
# y_train_cnn <- k_utils$to_categorical(mnist$train$y[idx_train] )
x_train_sample <- mnist$train$x[idx_train, , ]
y_train_sample <- (mnist$train$y[idx_train] )
# ---- Test set preparation (200 samples) ----
# Sample from mnist$test to ensure the model has not seen these images during training
idx_test <- sample(nrow(mnist$test$x), 200)
# x_test_cnn <- array_reshape(mnist$test$x[idx_test,,], c(200, 28, 28, 1)) / 255
# y_test_cnn <- k_utils$to_categorical(mnist$test$y[idx_test] )
x_test_sample <- mnist$test$x[idx_test, , ]
y_test_sample <- (mnist$test$y[idx_test] )
dim(x_train_sample)
## [1] 1000 28 28
dim(y_train_sample)
## [1] 1000
dim(x_test_sample)
## [1] 200 28 28
img_rows <- 28
img_cols <- 28
We use the array_reshape() function to transform the image data into 4D tensors with shape (samples, rows, columns, channels).
x_train_reshaped <- array_reshape(x_train_sample,
c(nrow(x_train_sample),
img_rows,
img_cols, 1))
x_test_reshaped <- array_reshape(x_test_sample,
c(nrow(x_test_sample),
img_rows,
img_cols, 1))
input_shape <- c(img_rows,
img_cols, 1)
dim(x_train_reshaped)
## [1] 1000 28 28 1
x_train_cnn <- x_train_reshaped / 255
x_test_cnn <- x_test_reshaped / 255
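The reshape and rescale steps can be sketched in base R on a toy array, without Keras. The object toy_images is a hypothetical stand-in for a batch of images. Assigning dim directly is used here in place of array_reshape(), which is safe in this particular case because appending a trailing dimension of size 1 does not change the order of the elements.

```r
# Toy stand-in for a batch of images: 2 images of 3 x 3 pixels, values 0-255
toy_images <- array(seq(0, 255, length.out = 18), dim = c(2, 3, 3))

# Add a trailing channel dimension: (samples, rows, cols) -> (samples, rows, cols, 1)
toy_tensor <- toy_images
dim(toy_tensor) <- c(dim(toy_images), 1)

# Rescale pixel intensities from [0, 255] to [0, 1]
toy_scaled <- toy_tensor / 255

dim(toy_scaled)     # four dimensions, the last one being the single channel
range(toy_scaled)   # values now lie between 0 and 1
```

Rescaling to the unit interval keeps the inputs on a small, consistent scale, which generally makes gradient-based training better behaved.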
We one-hot encode the labels with the to_categorical() function.
y_train_cnn <- k_utils$to_categorical(y_train_sample)
y_test_cnn <- k_utils$to_categorical(y_test_sample )
We print the first encoded label as a sanity check. Exactly one element should be 1 and the rest 0.
y_train_cnn[1,]
## [1] 0 0 0 0 0 0 1 0 0 0
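In the output above, the 1 sits in the seventh position, which corresponds to digit 6 because the classes are numbered from 0. The base-R sketch below shows the same encoding and how to decode it again. The function one_hot and the vector toy_labels are hypothetical names used for illustration and are not a replacement for to_categorical().

```r
# One-hot encode integer class labels 0..(num_classes - 1)
one_hot <- function(labels, num_classes = 10) {
  encoded <- matrix(0, nrow = length(labels), ncol = num_classes)
  # Label k goes in column k + 1 because R indexes from 1
  encoded[cbind(seq_along(labels), labels + 1)] <- 1
  encoded
}

toy_labels  <- c(6, 0, 9)            # hypothetical digit labels
toy_encoded <- one_hot(toy_labels)

toy_encoded[1, ]                     # a single 1 in the seventh position
rowSums(toy_encoded)                 # every row sums to 1

# Decode: the column holding the 1, shifted back to a 0-based label
decoded <- max.col(toy_encoded) - 1
decoded
```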
This CNN includes:
- one convolution layer with 32 filters of size 3 x 3 (feature extraction),
- a max pooling layer (downsampling),
- a flatten step to convert the 2D feature maps to a vector,
- a final dense softmax output layer for 10-class classification.
The model is deliberately minimal: it has no dropout and no hidden dense layer.
model_cnn <- keras_model_sequential()
model_cnn$add(layer_input(shape = c(28, 28, 1)))
model_cnn$add(layer_conv_2d(filters = 32, kernel_size = c(3,3), activation = 'relu'))
model_cnn$add(layer_max_pooling_2d(pool_size = c(2, 2)))
model_cnn$add(layer_flatten())
model_cnn$add(layer_dense(units = 10, activation = 'softmax'))
model_cnn$summary()
## Model: "sequential_2"
## ┌─────────────────────────────────┬────────────────────────┬───────────────┐
## │ Layer (type) │ Output Shape │ Param # │
## ├─────────────────────────────────┼────────────────────────┼───────────────┤
## │ conv2d (Conv2D) │ (None, 26, 26, 32) │ 320 │
## ├─────────────────────────────────┼────────────────────────┼───────────────┤
## │ max_pooling2d (MaxPooling2D) │ (None, 13, 13, 32) │ 0 │
## ├─────────────────────────────────┼────────────────────────┼───────────────┤
## │ flatten (Flatten) │ (None, 5408) │ 0 │
## ├─────────────────────────────────┼────────────────────────┼───────────────┤
## │ dense_2 (Dense) │ (None, 10) │ 54,090 │
## └─────────────────────────────────┴────────────────────────┴───────────────┘
## Total params: 54,410 (212.54 KB)
## Trainable params: 54,410 (212.54 KB)
## Non-trainable params: 0 (0.00 B)
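The output shapes and parameter counts in the summary can be checked with a few lines of arithmetic. The sketch below assumes the Keras defaults for the convolution layer (stride 1, no padding), which is consistent with the 26 x 26 output shape shown. The variable names are introduced here for illustration.

```r
# Input image and convolution settings, as in the model above
img_size    <- 28   # height and width
channels_in <- 1
filters     <- 32
kernel      <- 3    # 3 x 3 kernel, no padding, stride 1

# Convolution: each filter has kernel * kernel * channels_in weights plus a bias
conv_out    <- img_size - kernel + 1                          # 26
conv_params <- (kernel * kernel * channels_in + 1) * filters  # 320

# 2 x 2 max pooling halves height and width and has no parameters
pool_out <- conv_out %/% 2                                    # 13

# Flatten, then a dense layer with 10 units (weights plus biases)
flat_len     <- pool_out * pool_out * filters                 # 5408
dense_params <- (flat_len + 1) * 10                           # 54090

total_params <- conv_params + dense_params                    # 54410
cat("Total parameters:", total_params, "\n")
```

Almost all of the parameters sit in the dense layer, because every one of the 5,408 flattened features connects to each of the 10 output units.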
We compile the model with categorical cross-entropy as the loss, the Adam optimizer for gradient descent, and accuracy as the reported metric.
model_cnn$compile(
loss = 'categorical_crossentropy',
optimizer = 'adam',
metrics = list('accuracy')
)
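To give a feel for the loss being minimized, the following base-R sketch computes categorical cross-entropy for a single sample. The functions softmax and cross_entropy and the example scores are hypothetical illustrations. Keras computes the same quantity internally and averages it over each batch.

```r
# Softmax turns raw scores into probabilities that sum to 1
softmax <- function(z) {
  e <- exp(z - max(z))   # subtract the max for numerical stability
  e / sum(e)
}

# Categorical cross-entropy for one sample with a one-hot label
cross_entropy <- function(y_onehot, p) {
  -sum(y_onehot * log(p))
}

y_true <- c(0, 0, 0, 0, 0, 0, 1, 0, 0, 0)    # one-hot label for digit 6

# An uninformed model that spreads probability evenly over 10 classes
p_uniform  <- rep(1 / 10, 10)
ce_uniform <- cross_entropy(y_true, p_uniform)       # equals log(10)

# Hypothetical scores that favor the correct class
scores       <- c(0, 0, 0, 0, 0, 0, 4, 0, 0, 0)
p_confident  <- softmax(scores)
ce_confident <- cross_entropy(y_true, p_confident)   # much smaller loss

cat("Loss with uniform guess:   ", round(ce_uniform, 4), "\n")
cat("Loss with confident guess: ", round(ce_confident, 4), "\n")
```

A loss near log(10), roughly 2.30, is therefore what to expect from a 10-class model before it has learned anything, and the loss falls toward 0 as the predicted probability of the correct class approaches 1.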
We train for a small number of epochs (10) due to the small sampled dataset and demonstration focus.
In practice, you would tune:
- number of epochs,
- batch size,
- learning rate,
- architecture depth,
and you would use a larger training set.
# Train model
history_cnn <- model_cnn$fit(
x_train_cnn, y_train_cnn,
batch_size = as.integer(128),
epochs = as.integer(10),
validation_split = 0.2
)
## Epoch 1/10
## 7/7 ━━━━━━━━━━━━━━━━━━━━ 1s 50ms/step - accuracy: 0.2837 - loss: 2.1738 - val_accuracy: 0.4450 - val_loss: 1.9331
## Epoch 2/10
## 7/7 ━━━━━━━━━━━━━━━━━━━━ 0s 20ms/step - accuracy: 0.5750 - loss: 1.7572 - val_accuracy: 0.6900 - val_loss: 1.5350
## Epoch 3/10
## 7/7 ━━━━━━━━━━━━━━━━━━━━ 0s 24ms/step - accuracy: 0.7750 - loss: 1.3360 - val_accuracy: 0.7550 - val_loss: 1.1616
## Epoch 4/10
## 7/7 ━━━━━━━━━━━━━━━━━━━━ 0s 23ms/step - accuracy: 0.8150 - loss: 0.9861 - val_accuracy: 0.7600 - val_loss: 0.8896
## Epoch 5/10
## 7/7 ━━━━━━━━━━━━━━━━━━━━ 0s 22ms/step - accuracy: 0.8138 - loss: 0.7405 - val_accuracy: 0.7900 - val_loss: 0.7259
## Epoch 6/10
## 7/7 ━━━━━━━━━━━━━━━━━━━━ 0s 19ms/step - accuracy: 0.8562 - loss: 0.5915 - val_accuracy: 0.8400 - val_loss: 0.5972
## Epoch 7/10
## 7/7 ━━━━━━━━━━━━━━━━━━━━ 0s 18ms/step - accuracy: 0.8775 - loss: 0.4856 - val_accuracy: 0.8450 - val_loss: 0.5220
## Epoch 8/10
## 7/7 ━━━━━━━━━━━━━━━━━━━━ 0s 17ms/step - accuracy: 0.8850 - loss: 0.4194 - val_accuracy: 0.8650 - val_loss: 0.4574
## Epoch 9/10
## 7/7 ━━━━━━━━━━━━━━━━━━━━ 0s 20ms/step - accuracy: 0.9050 - loss: 0.3610 - val_accuracy: 0.8700 - val_loss: 0.4206
## Epoch 10/10
## 7/7 ━━━━━━━━━━━━━━━━━━━━ 0s 20ms/step - accuracy: 0.9100 - loss: 0.3275 - val_accuracy: 0.8750 - val_loss: 0.4019
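The "7/7" in each epoch of the log is the number of batches per epoch, which follows from the sample size, the validation split, and the batch size. A short arithmetic sketch, with variable names introduced here for illustration:

```r
n_sampled        <- 1000   # training images sampled above
validation_split <- 0.2
batch_size       <- 128

n_val <- n_sampled * validation_split    # 200 images held out for validation
n_fit <- n_sampled - n_val               # 800 images used for weight updates

n_steps    <- ceiling(n_fit / batch_size)            # 7 batches per epoch
last_batch <- n_fit - (n_steps - 1) * batch_size     # 32 images in the final batch

cat("Batches per epoch:", n_steps, "\n")
cat("Size of last batch:", last_batch, "\n")
```

With only 200 validation images, each validation accuracy in the log is a multiple of 1/200 = 0.005, so small differences between epochs amount to one or two images.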
We plot training history to see whether the model is learning and whether validation performance improves.
# Convert the history object to a data frame for visualization
history_df <- as.data.frame(history_cnn$history)
history_df$epoch <- 1:nrow(history_df)
# Set up the plotting area (1 row, 2 columns)
par(mfrow = c(1, 2))
# Plot Accuracy
plot(history_df$epoch, history_df$accuracy, type = "l", col = "blue",
main = "Model Accuracy", xlab = "Epoch", ylab = "Accuracy", ylim = c(0, 1))
lines(history_df$epoch, history_df$val_accuracy, col = "red")
legend("bottomright", legend = c("Train", "Val"), col = c("blue", "red"), lty = 1)
# Plot Loss
plot(history_df$epoch, history_df$loss, type = "l", col = "blue",
main = "Model Loss", xlab = "Epoch", ylab = "Loss")
lines(history_df$epoch, history_df$val_loss, col = "red")
legend("topright", legend = c("Train", "Val"), col = c("blue", "red"), lty = 1)
We evaluate on the test set and obtain loss and accuracy.
Because we used a small sample (200 test images), the accuracy estimate will have variability, but it is sufficient to demonstrate the process.
# Evaluate the model on the test data
# Note: x_test_cnn and y_test_cnn must be pre-processed tensors
evaluation <- model_cnn$evaluate(
x = x_test_cnn,
y = y_test_cnn,
verbose = 0
)
# Extract Loss and Accuracy
# In Keras 3, evaluate returns a vector: [Loss, Accuracy]
cat("Test Loss: ", evaluation[1], "\n")
## Test Loss: 0.4227977
cat("Test Accuracy: ", evaluation[2] * 100, "%\n")
## Test Accuracy: 92.5 %
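The variability of this estimate can be roughly quantified by treating the 200 test predictions as independent binomial trials. The sketch below uses a normal-approximation interval. This is one simple choice among several interval methods, and the variable names are introduced here for illustration.

```r
acc    <- 0.925   # test accuracy reported above
n_test <- 200     # number of sampled test images

n_correct <- acc * n_test                    # images classified correctly
se <- sqrt(acc * (1 - acc) / n_test)         # binomial standard error
ci <- acc + c(-1, 1) * qnorm(0.975) * se     # approximate 95% interval

cat("Correct:", n_correct, "of", n_test, "\n")
cat("Standard error:", round(se, 4), "\n")
cat("Approx. 95% CI:", round(ci[1], 3), "to", round(ci[2], 3), "\n")
```

The interval spans several percentage points on either side of the estimate, so differences of one or two points between models evaluated on a sample this size should not be over-interpreted.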