Analysis Objective

Apply what you have learned by presenting a simple R Markdown document that demonstrates neural network modelling to identify the fashion items in the photos. You can use either mxnet, discussed in a later section of this course book, or another neural network tool of your choice.

The data is stored in your course material, in the fashionmnist directory under data_input. The folder contains a train set and a test set of 28 x 28 pixel fashion images from 10 categories. Use the following glossary for your target labels:

categories <- c("T-shirt", "Trouser", "Pullover", "Dress", "Coat", "Sandal", "Shirt", "Sneaker", "Bag", "Boot")

Students should be awarded full points if:

The document demonstrates the student’s ability in data preparation.

The document shows the student’s ability to design the input, hidden, and output layers of a neural network and to choose their activation functions.

The document shows a cross-validation method and model evaluation.

Libraries and Setup

We’ll set up caching for this notebook, since some of the code we will write is computationally expensive.

knitr::opts_chunk$set(cache=TRUE)
options(scipen = 9999)

Use install.packages() to install any packages that are not already on your machine, then load each package into your workspace with the library() function:

library(dplyr)
## 
## Attaching package: 'dplyr'
## The following objects are masked from 'package:stats':
## 
##     filter, lag
## The following objects are masked from 'package:base':
## 
##     intersect, setdiff, setequal, union
library(tidyr)
library(Hmisc)
## Loading required package: lattice
## Loading required package: survival
## Loading required package: Formula
## Loading required package: ggplot2
## 
## Attaching package: 'Hmisc'
## The following objects are masked from 'package:dplyr':
## 
##     src, summarize
## The following objects are masked from 'package:base':
## 
##     format.pval, units
library(ggplot2)
library(neuralnet)
## 
## Attaching package: 'neuralnet'
## The following object is masked from 'package:dplyr':
## 
##     compute
library(NeuralNetTools)
library(mxnet)
library(caret)
## 
## Attaching package: 'caret'
## The following object is masked from 'package:survival':
## 
##     cluster

We’ll read the data into our environment and check how the labels are distributed in the training set:

fashionmnist_train <- read.csv("data_input/fashionmnist/train.csv", sep = ",", header = TRUE)
fashionmnist_test <- read.csv("data_input/fashionmnist/test.csv", sep = ",", header = TRUE)
tr_lab <- table(fashionmnist_train$label)
barplot(tr_lab, main="Distribution of Categories in Training Sample", col=gray.colors(10))

range(fashionmnist_train$label)
## [1] 0 9
range(fashionmnist_test$label)
## [1] 0 9

The labels 0 to 9 in both fashionmnist_train and fashionmnist_test correspond, in order, to these categories:

categories <- c("T-shirt", "Trouser", "Pullover", "Dress", "Coat", "Sandal", "Shirt", "Sneaker", "Bag", "Boot")
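Because the labels are stored as integers 0 to 9, a minimal sketch of the mapping passes `levels = 0:9` explicitly, so each code stays aligned with its name in `categories` even if a class is missing from the vector being converted. The helper `label_to_name()` is hypothetical and the example codes are made up:

```r
categories <- c("T-shirt", "Trouser", "Pullover", "Dress", "Coat",
                "Sandal", "Shirt", "Sneaker", "Bag", "Boot")

# Map integer codes 0-9 to category names; levels = 0:9 keeps all ten
# classes as levels even when some of them do not occur in `codes`
label_to_name <- function(codes) {
  factor(codes, levels = 0:9, labels = categories)
}

example_codes <- c(0, 9, 5)   # illustrative codes, not taken from the data
label_to_name(example_codes)
```

The same call could be applied to `fashionmnist_train$label` or `fashionmnist_test$label` when converting them to named factors for evaluation.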

The first column is the label; the remaining 784 columns (pixel1 to pixel784) hold the grey-scale intensity of each pixel, from 0 for the empty background up to 255.

colnames(fashionmnist_train)[c(1:5,781:785)]
##  [1] "label"    "pixel1"   "pixel2"   "pixel3"   "pixel4"   "pixel780"
##  [7] "pixel781" "pixel782" "pixel783" "pixel784"
colnames(fashionmnist_test)[c(1:5,781:785)]
##  [1] "label"    "pixel1"   "pixel2"   "pixel3"   "pixel4"   "pixel780"
##  [7] "pixel781" "pixel782" "pixel783" "pixel784"

Next, we convert the 784 predictor columns of the first row into a 28 x 28 matrix, make sure its elements are numeric (matrix computations make no sense on non-numeric values), and finally reverse the order of the rows:

m_train <- matrix(fashionmnist_train[1,2:ncol(fashionmnist_train)], nrow=28, ncol=28)
m_train <- apply(m_train, 2, as.numeric)
m_train <- apply(m_train, 2, rev)
m_train
##       [,1] [,2] [,3] [,4] [,5] [,6] [,7] [,8] [,9] [,10] [,11] [,12] [,13]
##  [1,]    0    0    0    0    0    0    0    0    0     0     0     0     1
##  [2,]    0    0    0    0    0    0    0    0    0     0     4     0     0
##  [3,]    0    0    0    0    0    0    0    0    2     0     0     0    13
##  [4,]    0    0    0    0    0    0    1    0    0     0    31   177   255
##  [5,]    0    0    0    0    0    0    0    0   13   175   229   232   216
##  [6,]    0    0    0    0    0    0    0  150  255   231   213   213   209
##  [7,]    0    0    0    0    0   45  254  237  218   213   215   220   214
##  [8,]    0    0    0    0    0  254  240  211  199   206   212   207   208
##  [9,]    0    0    0    0  135  249  221  226  206   217   207   223   226
## [10,]    0    0    0    0  255  231  223  213  214   210   224   197   183
## [11,]    0    0    0   61  222  224  217  219  215   211   228   162   158
## [12,]    0    0    0  136  235  214  221  208  210   213   232   128   147
## [13,]    0    0    0   51  255  207  219  221  213   210   228   139   121
## [14,]    0    0    0   23  137  180  233  211  214   209   233   156   100
## [15,]    0    0    0   29   62  108  254  208  213   209   224   147   138
## [16,]    0    0    0   21  115  229  220  210  214   210   215   195   186
## [17,]    0    0    0   61  255  215  217  214  206   215   222   136   103
## [18,]    0    0    0   62  225  215  225  216  213   203   230   139    90
## [19,]    0    0    0    0  228  224  224  207  212   208   217   193   158
## [20,]    0    0    0    0  201  238  213  224  210   211   204   225   228
## [21,]    0    0    0    0   88  234  210  224  207   200   205   211   205
## [22,]    0    0    0    0    0  252  222  207  215   218   206   203   209
## [23,]    0    0    0    0    0   47  214  237  222   210   215   212   211
## [24,]    0    0    0    4    0    0    0  128  237   228   224   212   207
## [25,]    0    0    0    0    0    0    0    0    0    85   217   225   226
## [26,]    0    0    0    0    0    0    1    0    0     0     0    21   123
## [27,]    0    0    0    0    0    0    0    0    2     4     0     0     0
## [28,]    0    0    0    0    0    0    0    1    0     0     0     1     0
##       [,14] [,15] [,16] [,17] [,18] [,19] [,20] [,21] [,22] [,23] [,24]
##  [1,]     0     0     0     0    25   114    26     0     0     0     0
##  [2,]     0     0   107   230   218   236   181     0     0     2     0
##  [3,]   148   254   234   222   212   203   230    85     0     3     0
##  [4,]   233   213   208   209   214   202   234   188     0     3     0
##  [5,]   211   218   213   208   208   255   224     0     0     3     2
##  [6,]   213   210   214   214   228   119     0     0     0     0     0
##  [7,]   213   217   214   215   201   139   242   238   165   155    95
##  [8,]   216   223   214   218   215   245   221   221   226   222   236
##  [9,]   219   224   213   211   205   210   206   207   201   200   212
## [10,]   200   221   207   203   212   210   210   215   212   203   220
## [11,]   186   219   210   212   214   207   215   206   209   203   221
## [12,]   188   220   205   206   216   210   211   207   209   204   217
## [13,]   168   226   203   210   216   209   209   206   209   203   219
## [14,]   164   233   197   212   216   208   206   205   211   202   219
## [15,]   170   224   201   213   216   207   204   206   211   201   220
## [16,]   193   213   201   217   209   199   206   216   209   199   222
## [17,]   150   221   208   215   201   202   211   213   208   198   224
## [18,]   156   229   207   204   202   211   210   207   208   196   228
## [19,]   184   230   198   197   207   210   205   206   208   192   232
## [20,]   216   221   205   202   208   206   207   207   207   199   226
## [21,]   205   222   215   213   210   205   209   200   190   199   218
## [22,]   206   218   209   209   210   208   217   250   253   254   255
## [23,]   208   211   211   210   213   213   216   115   129    89     0
## [24,]   202   214   207   211   208   208   220   157     0     0     0
## [25,]   219   204   203   204   200   206   203   255    31     0     5
## [26,]   226   227   226   213   207   201   212   201     0     0     1
## [27,]     0    45   157   235   255   217   238   145     0     0     0
## [28,]     0     0     0     0    52   118   171    39     0     2     0
##       [,25] [,26] [,27] [,28]
##  [1,]     0     0     0     0
##  [2,]     0     0     0     0
##  [3,]     1     0     0     0
##  [4,]     0     0     0     0
##  [5,]     1     0     0     0
##  [6,]     0     0     0     0
##  [7,]     0     0     0     0
##  [8,]   180     0     0     0
##  [9,]   161     0     0     0
## [10,]   167     0     0     0
## [11,]   171     0     0     0
## [12,]   172     0     0     0
## [13,]   175     0     0     0
## [14,]   177     0     0     0
## [15,]   179     0     0     0
## [16,]   173     0     0     0
## [17,]   173     0     0     0
## [18,]   171     0     0     0
## [19,]   170     0     0     0
## [20,]   168     0     0     0
## [21,]   194     0     0     0
## [22,]   155     0     0     0
## [23,]     0     0     0     0
## [24,]     0     0     0     0
## [25,]     0     0     0     0
## [26,]     0     0     0     0
## [27,]     0     0     0     0
## [28,]     0     0     0     0
m_test <- matrix(fashionmnist_test[1,2:ncol(fashionmnist_test)], nrow=28, ncol=28)
m_test <- apply(m_test, 2, as.numeric)
m_test <- apply(m_test, 2, rev)
m_test
##       [,1] [,2] [,3] [,4] [,5] [,6] [,7] [,8] [,9] [,10] [,11] [,12] [,13]
##  [1,]    0    0    0    0    0    0    0    0    0     0     0     0     0
##  [2,]    0    0    0    0    0    0    0    0    0     0     0     0     0
##  [3,]    0    0    0    0    2    1    0    0    0     0     0    21     1
##  [4,]    1    0    1    1    0    0    0    0   29   123   179   172    28
##  [5,]    1    1    0    0    0   33  133  226  221   224   219   216    56
##  [6,]    0    1    0   20  173  229  231  219  208   204   203   220   106
##  [7,]    1    1    0  184  249  205  208  212  217   214   211   227   177
##  [8,]    3    0    0  203  216  194  207  237  237   224   220   236   236
##  [9,]    3    0   37  236  205  219  234  165  173   234   208   207   225
## [10,]    0    0  157  238  200  249   50   54   67   112   245   218   218
## [11,]    0    0  212  236  212  226    0   72  113   122   231   198   212
## [12,]   24   80   14  133  223  232  186   68  119   230   203   232   210
## [13,]   11  123   61   74  232  204  236  247  232   203   207   129   192
## [14,]    0  122  210  113  251  194  188  200  190   206   193    80   186
## [15,]    7   99   94   61  237  185  210  200  191   224   138   127   207
## [16,]   29   88    0  125  255  197  191  201  192   214   162   114   211
## [17,]   34   44   67  214  216  199  216  219  205   193   213   246   197
## [18,]    0    0  249  238  206  206  226   64  138   226   184   213   187
## [19,]    0    0  174  239  205  201  233  172  165   231   200   221   144
## [20,]    8    0   46  242  207  200  195  229  207   212   205   201   224
## [21,]    9    0    3  206  223  197  192  194  211   207   203   219   224
## [22,]    0    0    0  168  245  201  208  216  220   219   211   233   160
## [23,]    0    1    0   23  175  225  223  218  207   200   204   218    64
## [24,]    0    0    0    0    0   53  133  216  221   216   212   191     5
## [25,]    0    0    2    2    0    0    0    0   50   131   195   185     4
## [26,]    0    4    1    2    1    7    1    0    0     0     0     8    21
## [27,]    0    0    0    0    0    0    0    0    0     0     0     0     0
## [28,]    0    0    0    0    0    0    0    0    0     0     0     0     0
##       [,14] [,15] [,16] [,17] [,18] [,19] [,20] [,21] [,22] [,23] [,24]
##  [1,]     0     0     0     0     0     0     0     0     0     0     0
##  [2,]     0     0     0     0     0     0     0     0     0     0     0
##  [3,]     2     0     2     0     1     0     0     1     1     0     0
##  [4,]     4     0     1     0     0     0     0     1     1     0     0
##  [5,]     0     2     2     0     0     0     0     0     1     0     2
##  [6,]     3     2     4     4     3     4     2     2     2     0     1
##  [7,]     1     0     0     0     0     0     0     0     0     0     0
##  [8,]   118    95   102   112   119   128   139   148   157   181   187
##  [9,]   252   236   237   231   227   229   230   230   230   226   229
## [10,]   162   174   203   207   206   206   205   204   204   204   207
## [11,]    64   140   237   199   210   208   211   211   212   207   213
## [12,]    76   175   226   203   207   206   206   208   207   206   210
## [13,]    79   125   231   192   206   204   204   205   206   205   207
## [14,]   123   159   224   198   203   201   204   203   203   203   205
## [15,]    70   161   217   195   200   200   200   201   203   203   204
## [16,]    73   165   220   198   201   201   201   201   205   199   199
## [17,]    51   155   221   194   199   200   201   204   203   200   208
## [18,]   120   170   217   199   206   204   205   204   203   206   201
## [19,]    96   145   204   201   199   198   195   194   194   195   197
## [20,]   252   226   225   223   223   223   223   221   220   221   224
## [21,]   116   115   131   135   134   139   145   157   166   171   165
## [22,]     0     0     0     0     0     0     0     0     0     0     0
## [23,]     2     0     2     3     1     1     1     0     0     0     0
## [24,]     0     0     1     0     0     0     0     1     1     0     1
## [25,]     1     0     0     1     1     0     0     0     1     0     0
## [26,]     1     0     0     1     1     0     0     1     1     0     0
## [27,]     0     0     0     0     0     0     0     0     0     0     0
## [28,]     0     0     0     0     0     0     0     0     0     0     0
##       [,25] [,26] [,27] [,28]
##  [1,]     0     0     0     0
##  [2,]     0     0     0     0
##  [3,]     0     0     0     0
##  [4,]     0     0     0     0
##  [5,]     2     0     0     0
##  [6,]     0     0     3     0
##  [7,]     0     0     0     0
##  [8,]   191   161   234    56
##  [9,]   225   218   245    87
## [10,]   205   199   224   103
## [11,]   213   199   245   125
## [12,]   210   195   242   135
## [13,]   206   193   236   145
## [14,]   205   191   232   142
## [15,]   203   187   231   138
## [16,]   211   187   239   138
## [17,]   207   194   233   127
## [18,]   201   188   239   123
## [19,]   203   191   213    94
## [20,]   201   201   240    69
## [21,]   128   141   212    37
## [22,]     0     0     0     0
## [23,]     0     1     1     0
## [24,]     0     1     0     0
## [25,]     0     0     0     0
## [26,]     0     0     0     0
## [27,]     0     0     0     0
## [28,]     0     0     0     0
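# m_train was filled column by column from row-major pixels, so reverse both axes to draw the item upright
image(1:28, 1:28, m_train[28:1, 28:1])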
text(5, 2, col="white", cex=1.2, fashionmnist_train[1, 1])
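Assuming, as vizTrain below does with byrow = T, that the CSV stores each image row by row starting from the top-left pixel, a hypothetical helper `row_to_image()` can build a matrix that image() draws upright in one step. This is a sketch, not part of any package:

```r
# Hypothetical helper: reshape one image's 784 pixel values, assumed to be
# stored row by row from the top-left pixel, into a matrix for image().
row_to_image <- function(pixels, side = 28) {
  pixels <- as.numeric(unlist(pixels))            # accepts a data-frame row or a vector
  stopifnot(length(pixels) == side * side)
  img <- matrix(pixels, nrow = side, byrow = TRUE)  # img[r, c]: photo row r, column c
  # image() draws z[i, j] at x = i, y = j with y pointing up, so flip the
  # photo rows and transpose: z[x, y] = img[side + 1 - y, x]
  t(img[side:1, ])
}

z <- row_to_image(1:784)
z[1, 28]   # the first (top-left) pixel lands at x = 1, y = 28
```

With the real data this would be used as `image(1:28, 1:28, row_to_image(fashionmnist_train[1, -1]), col = grey.colors(255))`.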

range(fashionmnist_train[1,2:ncol(fashionmnist_train)])
## [1]   0 255
vizTrain <- function(input){
  
  # side length of each square image (28 when there are 784 pixel columns)
  dimmax <- sqrt(ncol(input) - 1)
  
  # arrange the plots in a near-square grid
  dimn <- ceiling(sqrt(nrow(input)))
  par(mfrow=c(dimn, dimn), mar=c(.1, .1, .1, .1))
  
  for (i in 1:nrow(input)){
      m1 <- matrix(input[i,2:ncol(input)], nrow=dimmax, byrow=T)
      m1 <- apply(m1, 2, as.numeric)
      m1 <- t(apply(m1, 2, rev))
      
      image(1:dimmax, 1:dimmax, m1, col=grey.colors(255), xaxt = 'n', yaxt = 'n')
      # label each image with the label from its own row of `input`
      text(5, 20, col="red", cex=1.2, input[i, 1])
  }
  
}
vizTrain(fashionmnist_train[1:25,])

vizTrain(fashionmnist_train[255,])

Next, we design the neural network architecture: how many layers to stack, how many nodes each layer has, and which activation function each one applies.

mxnet has its own “grammar” for defining a neural network’s architecture. Specifically, it uses what it calls the symbolic API, which starts from a placeholder data symbol.

In the following code, we’re assembling a network with the following architecture: data --> 128 units ReLU --> 64 units ReLU --> 10 units output softMax

# notice how our layers are fully-connected and cascading
m1f.data  <- mx.symbol.Variable("data") #define data 

# num_hidden: number of units (neurons) in the layer
m1f.fc1 <- mx.symbol.FullyConnected(m1f.data, 
                                    name = "fc1", 
                                    num_hidden = 128) 

# act_type : activation function used
m1f.act1 <- mx.symbol.Activation(m1f.fc1, 
                                 name = "activation1", 
                                 act_type = "relu")

m1f.fc2 <- mx.symbol.FullyConnected(m1f.act1, 
                                    name = "fc2", 
                                    num_hidden = 64)

m1f.act2 <- mx.symbol.Activation(m1f.fc2, 
                                 name = "activation2", 
                                 act_type = "relu")

m1f.fc3 <- mx.symbol.FullyConnected(m1f.act2, 
                                    name = "fc3", 
                                    num_hidden = 10)

m1f.softmax <- mx.symbol.SoftmaxOutput(m1f.fc3, 
                                       name="softMax")
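To see what this symbol graph computes, here is a minimal base-R sketch with random weights standing in for trained ones. The helpers `relu()`, `softmax_cols()` and `forward()` are hypothetical, and mxnet’s own implementation differs in detail. The sketch also counts trainable parameters, assuming every FullyConnected layer keeps one bias per unit:

```r
# Hypothetical base-R sketch of the symbol graph above for a batch of
# images stored one per column (random weights, not trained ones).
relu <- function(z) pmax(z, 0)

softmax_cols <- function(z) {
  z <- sweep(z, 2, apply(z, 2, max))   # subtract each column's max for stability
  e <- exp(z)
  sweep(e, 2, colSums(e), "/")         # normalise each column to sum to 1
}

forward <- function(x, p) {
  h1 <- relu(p$W1 %*% x + p$b1)        # fc1 + activation1 -> 128 x n
  h2 <- relu(p$W2 %*% h1 + p$b2)       # fc2 + activation2 -> 64 x n
  softmax_cols(p$W3 %*% h2 + p$b3)     # fc3 + softMax     -> 10 x n
}

sizes <- c(784, 128, 64, 10)           # input, hidden 1, hidden 2, output
set.seed(42)
p <- list()
for (l in 1:3) {
  p[[paste0("W", l)]] <- matrix(rnorm(sizes[l + 1] * sizes[l], sd = 0.05),
                                nrow = sizes[l + 1])
  p[[paste0("b", l)]] <- rep(0, sizes[l + 1])
}

# Trainable parameters: weights (n_in * n_out) plus one bias per unit
n_params <- sum(sizes[-4] * sizes[-1] + sizes[-1])
n_params                                 # 109386

x <- matrix(runif(784 * 5), nrow = 784)  # five fake images scaled to [0, 1]
probs <- forward(x, p)
dim(probs)                               # 10 rows (classes) x 5 columns (images)
colSums(probs)                           # each column sums to 1
```

The count shows that almost all of the parameters sit in the first layer (100,480 of 109,386), because every one of the 784 pixels connects to each of its 128 units.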
# convert the data frames into numeric matrices for mxnet
train <- data.matrix(fashionmnist_train)
test <- data.matrix(fashionmnist_test)

# separate the 'label' column from the pixel columns
# trainf_x: 60000 x 784 matrix (one row per image), values 0 to 255
trainf_x <- train[,-1]
trainf_y <- train[,1] 

testf_x <- test[,-1]
testf_y <- test[,1] 

# trainf_x becomes a 784 x 60000 matrix with values 0 to 1 (divided by 255)
# each of the 784 pixels is an input (predictor) and each column an image,
# hence the transpose and array.layout = "colmajor" in the training call
trainf_x <- t(trainf_x/255) # scale by the maximum pixel value, 255
testf_x <- t(testf_x/255)
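The preparation steps above can be bundled into one reusable function. This is a sketch with a hypothetical helper `prep_fashion()`, tested here on a small made-up data frame rather than the real CSV:

```r
# Hypothetical helper bundling the steps above: drop the label column,
# scale pixels to [0, 1] and transpose so that each column is one image,
# matching array.layout = "colmajor" in the training call.
prep_fashion <- function(df) {
  m <- data.matrix(df)                       # factor columns would become 1-based codes here
  list(x = t(m[, -1, drop = FALSE] / 255),   # 784 x n, values in [0, 1]
       y = unname(m[, 1]))                   # labels 0-9
}

# Toy stand-in for train.csv: 3 images with 784 random pixel values each
set.seed(1)
pixels <- matrix(sample(0:255, 3 * 784, replace = TRUE), nrow = 3,
                 dimnames = list(NULL, paste0("pixel", 1:784)))
toy <- data.frame(label = c(0, 9, 5), pixels)

prepped <- prep_fashion(toy)
dim(prepped$x)   # 784 rows (pixels) by 3 columns (images)
```

It should be called while `label` is still numeric: once the label column has been turned into a factor, data.matrix() returns the factor codes 1 to 10 instead of 0 to 9.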
# Train the model

log <- mx.metric.logger$new()
startime <- proc.time() 
mx.set.seed(0)

m1f <- mx.model.FeedForward.create(m1f.softmax,  # the network configuration made above
                                     X = trainf_x, # input (predictors)
                                     y = trainf_y, # the labels
                                     ctx = mx.cpu(), # use cpu not gpu 
                                     num.round = 50,
                                     array.batch.size = 80,
                                     momentum = 0.95,
                                     array.layout="colmajor",
                                     learning.rate = 0.001,
                                     eval.metric = mx.metric.accuracy,
                                     epoch.end.callback = mx.callback.log.train.metric(1,log)
)
## Start training with 1 devices
## [1] Train-accuracy=0.101050000359615
## [2] Train-accuracy=0.162933333282669
## [3] Train-accuracy=0.405249999329448
## [4] Train-accuracy=0.608633332808812
## [5] Train-accuracy=0.695250000476837
## [6] Train-accuracy=0.728516666253408
## [7] Train-accuracy=0.74975
## [8] Train-accuracy=0.769633332808812
## [9] Train-accuracy=0.786516666889191
## [10] Train-accuracy=0.799883333683014
## [11] Train-accuracy=0.810099999666214
## [12] Train-accuracy=0.818133334318797
## [13] Train-accuracy=0.823550000111262
## [14] Train-accuracy=0.830083332935969
## [15] Train-accuracy=0.836250000397364
## [16] Train-accuracy=0.84138333272934
## [17] Train-accuracy=0.846416666348775
## [18] Train-accuracy=0.850599999745687
## [19] Train-accuracy=0.854599999030431
## [20] Train-accuracy=0.858049999237061
## [21] Train-accuracy=0.861733332633972
## [22] Train-accuracy=0.864566665331523
## [23] Train-accuracy=0.867333331902822
## [24] Train-accuracy=0.869883333206177
## [25] Train-accuracy=0.87263333328565
## [26] Train-accuracy=0.874733333587646
## [27] Train-accuracy=0.87686666782697
## [28] Train-accuracy=0.878650000969569
## [29] Train-accuracy=0.880566666523615
## [30] Train-accuracy=0.882366666873296
## [31] Train-accuracy=0.884283332983653
## [32] Train-accuracy=0.885450000127157
## [33] Train-accuracy=0.886966666936874
## [34] Train-accuracy=0.888333333015442
## [35] Train-accuracy=0.889450000127157
## [36] Train-accuracy=0.89061666671435
## [37] Train-accuracy=0.892016667048136
## [38] Train-accuracy=0.893349999984105
## [39] Train-accuracy=0.894633333683014
## [40] Train-accuracy=0.896166667222977
## [41] Train-accuracy=0.897200001319249
## [42] Train-accuracy=0.898316667636236
## [43] Train-accuracy=0.899533334175746
## [44] Train-accuracy=0.900533333619436
## [45] Train-accuracy=0.901416666984558
## [46] Train-accuracy=0.902383333762487
## [47] Train-accuracy=0.903216666062673
## [48] Train-accuracy=0.904500000079473
## [49] Train-accuracy=0.905783333539963
## [50] Train-accuracy=0.906750001033147
print(paste("Training took:", round((proc.time() - startime)[3],2),"seconds"))
## [1] "Training took: 250.98 seconds"

After 50 epochs, the model reaches about 90.7% accuracy on the training data.
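The rubric asks for a cross-validation method, while this notebook evaluates on a single train/test split. A minimal sketch of stratified k-fold assignment follows; `make_stratified_folds()` is a hypothetical helper, not part of mxnet or caret, and the labels are made up:

```r
# Hypothetical helper: assign each observation to one of k folds so that
# every class is spread evenly across the folds (stratified k-fold).
make_stratified_folds <- function(y, k = 5, seed = 0) {
  set.seed(seed)
  folds <- integer(length(y))
  for (cls in unique(y)) {
    idx <- which(y == cls)
    # cycle through folds 1..k for this class, then shuffle the assignment
    folds[idx] <- rep_len(seq_len(k), length(idx))[sample.int(length(idx))]
  }
  folds
}

# Toy labels: 10 classes with 20 images each
y_toy <- rep(0:9, each = 20)
folds <- make_stratified_folds(y_toy, k = 5)
table(y_toy, folds)[1:3, ]   # each class contributes 4 images to every fold
```

With `folds <- make_stratified_folds(trainf_y, k = 5)`, each fold `f` would be fitted with the same mx.model.FeedForward.create() call on `trainf_x[, folds != f]` and `trainf_y[folds != f]` (columns are images in the colmajor layout), scored on `trainf_x[, folds == f]`, and the k held-out accuracies averaged. At roughly 250 seconds per fit, 5 folds take about five times as long. caret’s createFolds() is a ready-made stratified alternative.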

fashionmnist_train$label <- factor(fashionmnist_train$label, labels = categories)
fashionmnist_test$label <- factor(fashionmnist_test$label, labels = categories)
m1f_pred <- predict(m1f,
                    testf_x,
                    array.layout = "colmajor")
t(round(m1f_pred[,1:5], 2))
##      [,1] [,2] [,3] [,4] [,5] [,6] [,7] [,8] [,9] [,10]
## [1,] 0.85    0 0.00 0.00 0.00    0 0.15    0    0     0
## [2,] 0.00    1 0.00 0.00 0.00    0 0.00    0    0     0
## [3,] 0.03    0 0.79 0.00 0.01    0 0.17    0    0     0
## [4,] 0.23    0 0.68 0.00 0.00    0 0.09    0    0     0
## [5,] 0.00    0 0.00 0.18 0.82    0 0.00    0    0     0
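# choose the most probable class for each test image (one column per image), then shift to 0-9
m1f_pred_result <- max.col(t(m1f_pred)) - 1
# levels = 0:9 keeps all ten categories even if one is never predicted
m1f_pred_result <- factor(m1f_pred_result, levels = 0:9, labels = categories)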
head(m1f_pred_result)
## [1] T-shirt  Trouser  Pullover Pullover Coat     Shirt   
## 10 Levels: T-shirt Trouser Pullover Dress Coat Sandal Shirt ... Boot
table(fashionmnist_test$label, m1f_pred_result)
##           m1f_pred_result
##            T-shirt Trouser Pullover Dress Coat Sandal Shirt Sneaker Bag
##   T-shirt      878       1       17     9    0      2    77       0  16
##   Trouser        5     982        1     6    1      0     5       0   0
##   Pullover      18       1      808     6   97      1    67       0   2
##   Dress         43      19       14   863   47      0    14       0   0
##   Coat           1       2       62    14  866      2    50       0   3
##   Sandal         1       0        0     0    0    950     0      22   3
##   Shirt        167       3       61    12   64      1   684       0   8
##   Sneaker        0       0        0     0    0     47     0     838   1
##   Bag            6       0       11     2    4      3     8       1 963
##   Boot           0       0        1     0    0      6     0      11   2
##           m1f_pred_result
##            Boot
##   T-shirt     0
##   Trouser     0
##   Pullover    0
##   Dress       0
##   Coat        0
##   Sandal     24
##   Shirt       0
##   Sneaker   114
##   Bag         2
##   Boot      980
caret::confusionMatrix(data = m1f_pred_result,
                       reference = fashionmnist_test$label)
## Confusion Matrix and Statistics
## 
##           Reference
## Prediction T-shirt Trouser Pullover Dress Coat Sandal Shirt Sneaker Bag
##   T-shirt      878       5       18    43    1      1   167       0   6
##   Trouser        1     982        1    19    2      0     3       0   0
##   Pullover      17       1      808    14   62      0    61       0  11
##   Dress          9       6        6   863   14      0    12       0   2
##   Coat           0       1       97    47  866      0    64       0   4
##   Sandal         2       0        1     0    2    950     1      47   3
##   Shirt         77       5       67    14   50      0   684       0   8
##   Sneaker        0       0        0     0    0     22     0     838   1
##   Bag           16       0        2     0    3      3     8       1 963
##   Boot           0       0        0     0    0     24     0     114   2
##           Reference
## Prediction Boot
##   T-shirt     0
##   Trouser     0
##   Pullover    1
##   Dress       0
##   Coat        0
##   Sandal      6
##   Shirt       0
##   Sneaker    11
##   Bag         2
##   Boot      980
## 
## Overall Statistics
##                                                
##                Accuracy : 0.8812               
##                  95% CI : (0.8747, 0.8875)     
##     No Information Rate : 0.1                  
##     P-Value [Acc > NIR] : < 0.00000000000000022
##                                                
##                   Kappa : 0.868                
##  Mcnemar's Test P-Value : NA                   
## 
## Statistics by Class:
## 
##                      Class: T-shirt Class: Trouser Class: Pullover
## Sensitivity                  0.8780         0.9820          0.8080
## Specificity                  0.9732         0.9971          0.9814
## Pos Pred Value               0.7846         0.9742          0.8287
## Neg Pred Value               0.9863         0.9980          0.9787
## Prevalence                   0.1000         0.1000          0.1000
## Detection Rate               0.0878         0.0982          0.0808
## Detection Prevalence         0.1119         0.1008          0.0975
## Balanced Accuracy            0.9256         0.9896          0.8947
##                      Class: Dress Class: Coat Class: Sandal Class: Shirt
## Sensitivity                0.8630      0.8660        0.9500       0.6840
## Specificity                0.9946      0.9763        0.9931       0.9754
## Pos Pred Value             0.9463      0.8026        0.9387       0.7558
## Neg Pred Value             0.9849      0.9850        0.9944       0.9653
## Prevalence                 0.1000      0.1000        0.1000       0.1000
## Detection Rate             0.0863      0.0866        0.0950       0.0684
## Detection Prevalence       0.0912      0.1079        0.1012       0.0905
## Balanced Accuracy          0.9288      0.9212        0.9716       0.8297
##                      Class: Sneaker Class: Bag Class: Boot
## Sensitivity                  0.8380     0.9630      0.9800
## Specificity                  0.9962     0.9961      0.9844
## Pos Pred Value               0.9610     0.9649      0.8750
## Neg Pred Value               0.9823     0.9959      0.9977
## Prevalence                   0.1000     0.1000      0.1000
## Detection Rate               0.0838     0.0963      0.0980
## Detection Prevalence         0.0872     0.0998      0.1120
## Balanced Accuracy            0.9171     0.9796      0.9822
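The headline numbers in caret’s output can be reproduced by hand. A minimal sketch, using a hypothetical helper `confusion_metrics()` that takes a matrix laid out like caret’s (rows = prediction, columns = reference) and a made-up 3-class example:

```r
# Accuracy, Cohen's kappa and per-class recall from a confusion matrix
# with rows = prediction and columns = reference.
confusion_metrics <- function(cm) {
  n  <- sum(cm)
  po <- sum(diag(cm)) / n                      # observed agreement = accuracy
  pe <- sum(rowSums(cm) * colSums(cm)) / n^2   # agreement expected by chance
  list(accuracy = po,
       kappa    = (po - pe) / (1 - pe),
       recall   = diag(cm) / colSums(cm))      # sensitivity per reference class
}

# Toy 3-class example with made-up counts
cm <- matrix(c(8, 1, 1,
               2, 7, 1,
               0, 2, 8), nrow = 3, byrow = TRUE,
             dimnames = list(Prediction = c("a", "b", "c"),
                             Reference  = c("a", "b", "c")))
confusion_metrics(cm)
```

For this model, `confusion_metrics(table(m1f_pred_result, fashionmnist_test$label))` would work on the matrix above: its diagonal sums to 8,812 of 10,000 images, so accuracy is 0.8812, and since every reference class has 1,000 images the chance agreement is 0.1, giving kappa = (0.8812 - 0.1) / 0.9 ≈ 0.868, as caret reports.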

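On the unseen test data, the model still reaches 88.12% accuracy (95% CI 0.8747 to 0.8875). That is only about 2.6 percentage points below its final training accuracy of 90.68%, so overfitting appears modest. The weakest class is Shirt (sensitivity 0.684), most often predicted as T-shirt (167), Coat (64) or Pullover (61), and Sneakers are frequently mistaken for Boots (114) or Sandals (47). Finally, plotResults() displays the first 49 test images with their predicted labels in red.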

plotResults <- function(images, preds){

  x <- ceiling(sqrt(length(images)))
  par(mfrow = c(x,x), 
      mar = c(.1,.1,.1,.1))
  
  for (i in images){
    m <- matrix(test[i, 2:785], 
                nrow = 28, 
                byrow = TRUE)
    m <- apply(m, 2, rev)
    image(t(m), 
          col = grey.colors(255), 
          axes = FALSE)
    text(0.5,0.1,col="red", cex=1.2, preds[i])
  }

}

plotResults(1:49,
            m1f_pred_result)