## https://tensorflow.rstudio.com/install/
## remotes::install_github("rstudio/tensorflow")
## reticulate::install_python()

library(tensorflow)
### install_tensorflow(envname = "r-tensorflow")

## install.packages("keras")
library(keras)
### install_keras()
library(ggplot2)

# Load the Fashion-MNIST dataset and unpack its train/test splits into
# separate image and label objects.
mnist <- dataset_fashion_mnist()

# Training split: 60,000 grayscale 28x28 images (a 60000x28x28 array) and a
# 60,000-element label vector
train_images <- mnist[["train"]][["x"]]
train_labels <- mnist[["train"]][["y"]]

# Test split: 10,000 images (10000x28x28 array) and 10,000 labels
test_images <- mnist[["test"]][["x"]]
test_labels <- mnist[["test"]][["y"]]

# Display the 100th training image. Pixel values run from 0 to 255, so
# as.raster() is told that 255 is the maximum intensity.
digit <- train_images[100, , ]
digit %>%
  as.raster(max = 255) %>%
  plot()

# Prepare inputs and targets for the CNN.
# Run this block exactly once: the normalization and one-hot steps below are
# not idempotent (re-running would divide by 255 again and re-encode labels).

# Add a single grayscale channel dimension:
# (samples, 28, 28) -> (samples, 28, 28, 1).
# The target shape is read from the arrays instead of hard-coding
# 60000/10000, so this also works on a subset such as train_images[1:5000, , ].
# Taking only the first three dims leaves an already-reshaped array unchanged.
train_images <- array_reshape(train_images, c(dim(train_images)[1:3], 1L))
test_images <- array_reshape(test_images, c(dim(test_images)[1:3], 1L))

# Normalize the pixel values from [0, 255] to [0, 1]
train_images <- train_images / 255
test_images <- test_images / 255

# One-hot encode the class labels into (samples, 10) matrices
train_labels <- to_categorical(train_labels, num_classes = 10)
test_labels <- to_categorical(test_labels, num_classes = 10)

# Build the CNN model: two convolution + max-pooling stages extract spatial
# features, then a flatten + dense head maps them to 10 class probabilities.
network <- keras_model_sequential() %>%
  # Stage 1: 32 filters of size 3x3 over the 28x28x1 input with ReLU;
  # "same" padding preserves the spatial size before 2x2 pooling.
  layer_conv_2d(
    filters = 32,
    kernel_size = c(3, 3),
    padding = "same",
    activation = "relu",
    input_shape = c(28, 28, 1)
  ) %>%
  layer_max_pooling_2d(pool_size = c(2, 2)) %>%
  # Stage 2: 64 filters of size 3x3 with ReLU, again followed by 2x2 pooling
  layer_conv_2d(
    filters = 64,
    kernel_size = c(3, 3),
    padding = "same",
    activation = "relu"
  ) %>%
  layer_max_pooling_2d(pool_size = c(2, 2)) %>%
  # Head: flatten the feature maps, a 128-unit ReLU layer, then a 10-unit
  # softmax output (one unit per class)
  layer_flatten() %>%
  layer_dense(units = 128, activation = "relu") %>%
  layer_dense(units = 10, activation = "softmax")

# Configure training: RMSprop optimizer, categorical cross-entropy loss
# (matching the one-hot encoded labels) and accuracy as the reported metric.
compile(
  network,
  optimizer = "rmsprop",
  loss = "categorical_crossentropy",
  metrics = "accuracy"
)

# Check the dimensions to verify correctness: images should be
# (samples, 28, 28, 1) and labels (samples, 10) after one-hot encoding.
dim(train_images)
## [1] 60000    28    28     1
dim(test_images)
## [1] 10000    28    28     1
dim(train_labels)
## [1] 60000    10
dim(test_labels)
## [1] 10000    10

# Fail fast on a shape mismatch instead of discovering it mid-training:
# every image needs exactly one one-hot label row over 10 classes.
stopifnot(
  all(dim(train_images)[-1] == c(28, 28, 1)),
  all(dim(test_images)[-1] == c(28, 28, 1)),
  dim(train_images)[1] == nrow(train_labels),
  dim(test_images)[1] == nrow(test_labels),
  ncol(train_labels) == 10,
  ncol(test_labels) == 10
)
# Train the CNN for 10 epochs in mini-batches of 1000 images, holding out 10%
# of the training data to report validation loss/accuracy after each epoch.
history <- fit(
  network,
  x = train_images,
  y = train_labels,
  batch_size = 1000,
  epochs = 10,
  validation_split = 0.1
)
## Epoch 1/10
## 54/54 - 44s - loss: 0.9846 - accuracy: 0.6490 - val_loss: 0.5675 - val_accuracy: 0.7882 - 44s/epoch - 817ms/step
## Epoch 2/10
## 54/54 - 37s - loss: 0.5570 - accuracy: 0.7929 - val_loss: 0.5161 - val_accuracy: 0.7982 - 37s/epoch - 677ms/step
## Epoch 3/10
## 54/54 - 39s - loss: 0.4529 - accuracy: 0.8346 - val_loss: 0.4689 - val_accuracy: 0.8152 - 39s/epoch - 731ms/step
## Epoch 4/10
## 54/54 - 35s - loss: 0.4012 - accuracy: 0.8537 - val_loss: 0.3832 - val_accuracy: 0.8610 - 35s/epoch - 653ms/step
## Epoch 5/10
## 54/54 - 35s - loss: 0.3631 - accuracy: 0.8677 - val_loss: 0.3668 - val_accuracy: 0.8668 - 35s/epoch - 653ms/step
## Epoch 6/10
## 54/54 - 35s - loss: 0.3432 - accuracy: 0.8749 - val_loss: 0.3332 - val_accuracy: 0.8775 - 35s/epoch - 657ms/step
## Epoch 7/10
## 54/54 - 38s - loss: 0.3218 - accuracy: 0.8825 - val_loss: 0.3234 - val_accuracy: 0.8812 - 38s/epoch - 709ms/step
## Epoch 8/10
## 54/54 - 39s - loss: 0.3037 - accuracy: 0.8883 - val_loss: 0.3324 - val_accuracy: 0.8773 - 39s/epoch - 728ms/step
## Epoch 9/10
## 54/54 - 35s - loss: 0.2907 - accuracy: 0.8923 - val_loss: 0.3006 - val_accuracy: 0.8922 - 35s/epoch - 653ms/step
## Epoch 10/10
## 54/54 - 43s - loss: 0.2787 - accuracy: 0.8968 - val_loss: 0.3080 - val_accuracy: 0.8838 - 43s/epoch - 799ms/step
# Plot the per-epoch training and validation metrics with a black-and-white
# ggplot2 theme. `%>%` binds tighter than `+`, so the theme is added to the
# finished plot object.
history %>%
  plot() +
  theme_bw()

# Score the trained model on the held-out test split; prints the test loss
# and accuracy.
evaluate(network, x = test_images, y = test_labels)
## 313/313 - 2s - loss: 0.3292 - accuracy: 0.8800 - 2s/epoch - 8ms/step
##      loss  accuracy 
## 0.3291659 0.8800000