Apply what you have learned by presenting a simple R Markdown document that demonstrates neural network modelling to identify fashion items in photos. You can use either mxnet, discussed in a later section of this course book, or another neural network tool of your choice.
The data is stored in your course material, in the fashionmnist directory under data_input. The folder contains a train set and a test set of 28 x 28 pixel fashion images in 10 categories. Use the following glossary for your target labels:
categories <- c("T-shirt", "Trouser", "Pullover", "Dress", "Coat", "Sandal", "Shirt", "Sneaker", "Bag", "Boot")
Students should be awarded full points if:
The document demonstrates the student's ability in data preparation.
The document shows the student's ability to design the input, hidden and output layers of a neural network and to choose their activation functions.
The document shows a cross-validation method and model evaluation.
We'll set up caching for this notebook, since some of the code we will write is computationally expensive.
knitr::opts_chunk$set(cache=TRUE)
options(scipen = 9999)
If any of the packages below are not yet installed on your machine, install them with install.packages(). You then load each package into your workspace with the library() function:
library(dplyr)
##
## Attaching package: 'dplyr'
## The following objects are masked from 'package:stats':
##
## filter, lag
## The following objects are masked from 'package:base':
##
## intersect, setdiff, setequal, union
library(tidyr)
library(Hmisc)
## Loading required package: lattice
## Loading required package: survival
## Loading required package: Formula
## Loading required package: ggplot2
##
## Attaching package: 'Hmisc'
## The following objects are masked from 'package:dplyr':
##
## src, summarize
## The following objects are masked from 'package:base':
##
## format.pval, units
library(ggplot2)
library(neuralnet)
##
## Attaching package: 'neuralnet'
## The following object is masked from 'package:dplyr':
##
## compute
library(NeuralNetTools)
library(mxnet)
library(caret)
##
## Attaching package: 'caret'
## The following object is masked from 'package:survival':
##
## cluster
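If some of these packages are missing, a small pair of helpers can install only those. `missing_pkgs()` and `install_missing()` are hypothetical names used for illustration, not functions from any of the packages above:

```r
# Hypothetical helpers, not part of any package loaded in this notebook.
# Return the packages in `pkgs` that are not installed in any library on .libPaths()
missing_pkgs <- function(pkgs) {
  setdiff(pkgs, rownames(installed.packages()))
}

# Install only the missing packages from CRAN and return their names invisibly
install_missing <- function(pkgs) {
  to_install <- missing_pkgs(pkgs)
  if (length(to_install) > 0) {
    install.packages(to_install)
  }
  invisible(to_install)
}
```

For this notebook, the call would be install_missing(c("dplyr", "tidyr", "Hmisc", "ggplot2", "neuralnet", "NeuralNetTools", "caret")). mxnet is left out because, as far as we know, it is not distributed through CRAN and has its own installation instructions.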
We’ll read the data into our environment:
fashionmnist_train <- read.csv("data_input/fashionmnist/train.csv", sep = ",", header = TRUE)
fashionmnist_test <- read.csv("data_input/fashionmnist/test.csv", sep = ",", header = TRUE)
tr_lab <- table(fashionmnist_train$label)
barplot(tr_lab, main="Distribution of Categories in Training Sample", col=gray.colors(10))
range(fashionmnist_train$label)
## [1] 0 9
range(fashionmnist_test$label)
## [1] 0 9
The label values in fashionmnist_train and fashionmnist_test range from 0 to 9 and map, in order, to these categories:
categories <- c("T-shirt", "Trouser", "Pullover", "Dress", "Coat", "Sandal", "Shirt", "Sneaker", "Bag", "Boot")
The first column is label. The remaining 784 columns (pixel1 to pixel784) hold the grayscale intensity of each pixel, from 0 to 255.
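The sketch below assumes the usual Fashion-MNIST convention that the 784 pixel columns are stored row by row, starting at the top-left corner. Under that assumption, the position of column pixelk in the 28 x 28 image follows from integer division. `pixel_position()` is a hypothetical helper for illustration:

```r
# Hypothetical helper: map pixel column number k (1..784) to its (row, col)
# position in the image, assuming row-by-row storage from the top-left corner.
pixel_position <- function(k, side = 28) {
  c(row = (k - 1) %/% side + 1,
    col = (k - 1) %% side + 1)
}

pixel_position(1)    # row 1, col 1 (top-left)
pixel_position(29)   # row 2, col 1 (first pixel of the second row)
pixel_position(784)  # row 28, col 28 (bottom-right)
```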
colnames(fashionmnist_train)[c(1:5,781:785)]
## [1] "label" "pixel1" "pixel2" "pixel3" "pixel4" "pixel780"
## [7] "pixel781" "pixel782" "pixel783" "pixel784"
colnames(fashionmnist_test)[c(1:5,781:785)]
## [1] "label" "pixel1" "pixel2" "pixel3" "pixel4" "pixel780"
## [7] "pixel781" "pixel782" "pixel783" "pixel784"
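Putting it together, we convert the 784 pixel columns of the first row into a 28 x 28 matrix. We then make the elements numeric, because subsetting one row of a data frame returns a list and matrix computations need numeric values. Finally, we reverse the order of the rows in the matrix. Note that matrix() fills column by column by default, while the pixels are stored row by row. As a result, the image drawn from this matrix appears rotated by 180 degrees. The vizTrain() function further below fills the matrix with byrow = TRUE and transposes it so the items are drawn upright: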
m_train <- matrix(fashionmnist_train[1,2:ncol(fashionmnist_train)], nrow=28, ncol=28)
m_train <- apply(m_train, 2, as.numeric)
m_train <- apply(m_train, 2, rev)
m_train
## [,1] [,2] [,3] [,4] [,5] [,6] [,7] [,8] [,9] [,10] [,11] [,12] [,13]
## [1,] 0 0 0 0 0 0 0 0 0 0 0 0 1
## [2,] 0 0 0 0 0 0 0 0 0 0 4 0 0
## [3,] 0 0 0 0 0 0 0 0 2 0 0 0 13
## [4,] 0 0 0 0 0 0 1 0 0 0 31 177 255
## [5,] 0 0 0 0 0 0 0 0 13 175 229 232 216
## [6,] 0 0 0 0 0 0 0 150 255 231 213 213 209
## [7,] 0 0 0 0 0 45 254 237 218 213 215 220 214
## [8,] 0 0 0 0 0 254 240 211 199 206 212 207 208
## [9,] 0 0 0 0 135 249 221 226 206 217 207 223 226
## [10,] 0 0 0 0 255 231 223 213 214 210 224 197 183
## [11,] 0 0 0 61 222 224 217 219 215 211 228 162 158
## [12,] 0 0 0 136 235 214 221 208 210 213 232 128 147
## [13,] 0 0 0 51 255 207 219 221 213 210 228 139 121
## [14,] 0 0 0 23 137 180 233 211 214 209 233 156 100
## [15,] 0 0 0 29 62 108 254 208 213 209 224 147 138
## [16,] 0 0 0 21 115 229 220 210 214 210 215 195 186
## [17,] 0 0 0 61 255 215 217 214 206 215 222 136 103
## [18,] 0 0 0 62 225 215 225 216 213 203 230 139 90
## [19,] 0 0 0 0 228 224 224 207 212 208 217 193 158
## [20,] 0 0 0 0 201 238 213 224 210 211 204 225 228
## [21,] 0 0 0 0 88 234 210 224 207 200 205 211 205
## [22,] 0 0 0 0 0 252 222 207 215 218 206 203 209
## [23,] 0 0 0 0 0 47 214 237 222 210 215 212 211
## [24,] 0 0 0 4 0 0 0 128 237 228 224 212 207
## [25,] 0 0 0 0 0 0 0 0 0 85 217 225 226
## [26,] 0 0 0 0 0 0 1 0 0 0 0 21 123
## [27,] 0 0 0 0 0 0 0 0 2 4 0 0 0
## [28,] 0 0 0 0 0 0 0 1 0 0 0 1 0
## [,14] [,15] [,16] [,17] [,18] [,19] [,20] [,21] [,22] [,23] [,24]
## [1,] 0 0 0 0 25 114 26 0 0 0 0
## [2,] 0 0 107 230 218 236 181 0 0 2 0
## [3,] 148 254 234 222 212 203 230 85 0 3 0
## [4,] 233 213 208 209 214 202 234 188 0 3 0
## [5,] 211 218 213 208 208 255 224 0 0 3 2
## [6,] 213 210 214 214 228 119 0 0 0 0 0
## [7,] 213 217 214 215 201 139 242 238 165 155 95
## [8,] 216 223 214 218 215 245 221 221 226 222 236
## [9,] 219 224 213 211 205 210 206 207 201 200 212
## [10,] 200 221 207 203 212 210 210 215 212 203 220
## [11,] 186 219 210 212 214 207 215 206 209 203 221
## [12,] 188 220 205 206 216 210 211 207 209 204 217
## [13,] 168 226 203 210 216 209 209 206 209 203 219
## [14,] 164 233 197 212 216 208 206 205 211 202 219
## [15,] 170 224 201 213 216 207 204 206 211 201 220
## [16,] 193 213 201 217 209 199 206 216 209 199 222
## [17,] 150 221 208 215 201 202 211 213 208 198 224
## [18,] 156 229 207 204 202 211 210 207 208 196 228
## [19,] 184 230 198 197 207 210 205 206 208 192 232
## [20,] 216 221 205 202 208 206 207 207 207 199 226
## [21,] 205 222 215 213 210 205 209 200 190 199 218
## [22,] 206 218 209 209 210 208 217 250 253 254 255
## [23,] 208 211 211 210 213 213 216 115 129 89 0
## [24,] 202 214 207 211 208 208 220 157 0 0 0
## [25,] 219 204 203 204 200 206 203 255 31 0 5
## [26,] 226 227 226 213 207 201 212 201 0 0 1
## [27,] 0 45 157 235 255 217 238 145 0 0 0
## [28,] 0 0 0 0 52 118 171 39 0 2 0
## [,25] [,26] [,27] [,28]
## [1,] 0 0 0 0
## [2,] 0 0 0 0
## [3,] 1 0 0 0
## [4,] 0 0 0 0
## [5,] 1 0 0 0
## [6,] 0 0 0 0
## [7,] 0 0 0 0
## [8,] 180 0 0 0
## [9,] 161 0 0 0
## [10,] 167 0 0 0
## [11,] 171 0 0 0
## [12,] 172 0 0 0
## [13,] 175 0 0 0
## [14,] 177 0 0 0
## [15,] 179 0 0 0
## [16,] 173 0 0 0
## [17,] 173 0 0 0
## [18,] 171 0 0 0
## [19,] 170 0 0 0
## [20,] 168 0 0 0
## [21,] 194 0 0 0
## [22,] 155 0 0 0
## [23,] 0 0 0 0
## [24,] 0 0 0 0
## [25,] 0 0 0 0
## [26,] 0 0 0 0
## [27,] 0 0 0 0
## [28,] 0 0 0 0
m_test <- matrix(fashionmnist_test[1,2:ncol(fashionmnist_test)], nrow=28, ncol=28)
m_test <- apply(m_test, 2, as.numeric)
m_test <- apply(m_test, 2, rev)
m_test
## [,1] [,2] [,3] [,4] [,5] [,6] [,7] [,8] [,9] [,10] [,11] [,12] [,13]
## [1,] 0 0 0 0 0 0 0 0 0 0 0 0 0
## [2,] 0 0 0 0 0 0 0 0 0 0 0 0 0
## [3,] 0 0 0 0 2 1 0 0 0 0 0 21 1
## [4,] 1 0 1 1 0 0 0 0 29 123 179 172 28
## [5,] 1 1 0 0 0 33 133 226 221 224 219 216 56
## [6,] 0 1 0 20 173 229 231 219 208 204 203 220 106
## [7,] 1 1 0 184 249 205 208 212 217 214 211 227 177
## [8,] 3 0 0 203 216 194 207 237 237 224 220 236 236
## [9,] 3 0 37 236 205 219 234 165 173 234 208 207 225
## [10,] 0 0 157 238 200 249 50 54 67 112 245 218 218
## [11,] 0 0 212 236 212 226 0 72 113 122 231 198 212
## [12,] 24 80 14 133 223 232 186 68 119 230 203 232 210
## [13,] 11 123 61 74 232 204 236 247 232 203 207 129 192
## [14,] 0 122 210 113 251 194 188 200 190 206 193 80 186
## [15,] 7 99 94 61 237 185 210 200 191 224 138 127 207
## [16,] 29 88 0 125 255 197 191 201 192 214 162 114 211
## [17,] 34 44 67 214 216 199 216 219 205 193 213 246 197
## [18,] 0 0 249 238 206 206 226 64 138 226 184 213 187
## [19,] 0 0 174 239 205 201 233 172 165 231 200 221 144
## [20,] 8 0 46 242 207 200 195 229 207 212 205 201 224
## [21,] 9 0 3 206 223 197 192 194 211 207 203 219 224
## [22,] 0 0 0 168 245 201 208 216 220 219 211 233 160
## [23,] 0 1 0 23 175 225 223 218 207 200 204 218 64
## [24,] 0 0 0 0 0 53 133 216 221 216 212 191 5
## [25,] 0 0 2 2 0 0 0 0 50 131 195 185 4
## [26,] 0 4 1 2 1 7 1 0 0 0 0 8 21
## [27,] 0 0 0 0 0 0 0 0 0 0 0 0 0
## [28,] 0 0 0 0 0 0 0 0 0 0 0 0 0
## [,14] [,15] [,16] [,17] [,18] [,19] [,20] [,21] [,22] [,23] [,24]
## [1,] 0 0 0 0 0 0 0 0 0 0 0
## [2,] 0 0 0 0 0 0 0 0 0 0 0
## [3,] 2 0 2 0 1 0 0 1 1 0 0
## [4,] 4 0 1 0 0 0 0 1 1 0 0
## [5,] 0 2 2 0 0 0 0 0 1 0 2
## [6,] 3 2 4 4 3 4 2 2 2 0 1
## [7,] 1 0 0 0 0 0 0 0 0 0 0
## [8,] 118 95 102 112 119 128 139 148 157 181 187
## [9,] 252 236 237 231 227 229 230 230 230 226 229
## [10,] 162 174 203 207 206 206 205 204 204 204 207
## [11,] 64 140 237 199 210 208 211 211 212 207 213
## [12,] 76 175 226 203 207 206 206 208 207 206 210
## [13,] 79 125 231 192 206 204 204 205 206 205 207
## [14,] 123 159 224 198 203 201 204 203 203 203 205
## [15,] 70 161 217 195 200 200 200 201 203 203 204
## [16,] 73 165 220 198 201 201 201 201 205 199 199
## [17,] 51 155 221 194 199 200 201 204 203 200 208
## [18,] 120 170 217 199 206 204 205 204 203 206 201
## [19,] 96 145 204 201 199 198 195 194 194 195 197
## [20,] 252 226 225 223 223 223 223 221 220 221 224
## [21,] 116 115 131 135 134 139 145 157 166 171 165
## [22,] 0 0 0 0 0 0 0 0 0 0 0
## [23,] 2 0 2 3 1 1 1 0 0 0 0
## [24,] 0 0 1 0 0 0 0 1 1 0 1
## [25,] 1 0 0 1 1 0 0 0 1 0 0
## [26,] 1 0 0 1 1 0 0 1 1 0 0
## [27,] 0 0 0 0 0 0 0 0 0 0 0
## [28,] 0 0 0 0 0 0 0 0 0 0 0
## [,25] [,26] [,27] [,28]
## [1,] 0 0 0 0
## [2,] 0 0 0 0
## [3,] 0 0 0 0
## [4,] 0 0 0 0
## [5,] 2 0 0 0
## [6,] 0 0 3 0
## [7,] 0 0 0 0
## [8,] 191 161 234 56
## [9,] 225 218 245 87
## [10,] 205 199 224 103
## [11,] 213 199 245 125
## [12,] 210 195 242 135
## [13,] 206 193 236 145
## [14,] 205 191 232 142
## [15,] 203 187 231 138
## [16,] 211 187 239 138
## [17,] 207 194 233 127
## [18,] 201 188 239 123
## [19,] 203 191 213 94
## [20,] 201 201 240 69
## [21,] 128 141 212 37
## [22,] 0 0 0 0
## [23,] 0 1 1 0
## [24,] 0 1 0 0
## [25,] 0 0 0 0
## [26,] 0 0 0 0
## [27,] 0 0 0 0
## [28,] 0 0 0 0
image(1:28, 1:28, m_train)
text(5, 2, col="white", cex=1.2, fashionmnist_train[1, 1])
range(fashionmnist_train[1,2:ncol(fashionmnist_train)])
## [1] 0 255
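Two details decide how an item is drawn. First, matrix() fills column by column unless byrow = TRUE. Second, image() draws z[i, j] at x = i, y = j, with y increasing upward. The sketch below accounts for both, assuming the pixels are stored row by row. `to_image_matrix()` is a hypothetical helper, and a synthetic image with a single lit pixel checks the orientation:

```r
# Hypothetical helper: turn one image's 784 pixel values (assumed stored row by row,
# starting top-left) into a matrix that image() draws upright.
to_image_matrix <- function(pixels, side = 28) {
  # m[r, c] holds the pixel at image row r, column c
  m <- matrix(as.numeric(unlist(pixels)), nrow = side, byrow = TRUE)
  # reverse each column so the top image row comes last, then transpose,
  # because image() puts z[i, j] at x = i, y = j with y increasing upward
  t(apply(m, 2, rev))
}

# Synthetic check: light up only the top-left pixel
px <- numeric(784)
px[1] <- 255
z <- to_image_matrix(px)
which(z == 255, arr.ind = TRUE)  # row 1, col 28: drawn at x = 1, y = 28 (top-left)
```

With the real data, image(1:28, 1:28, to_image_matrix(fashionmnist_train[1, -1]), col = grey.colors(255)) should show the first training item upright.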
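# plot each row of `input` (label in column 1, 784 pixels after it) as an image
vizTrain <- function(input){
  dimmax <- sqrt(ncol(input) - 1)         # 28 pixels per side
  dimn <- ceiling(sqrt(nrow(input)))      # size of the square grid of panels
  par(mfrow=c(dimn, dimn), mar=c(.1, .1, .1, .1))
  for (i in 1:nrow(input)){
    # pixels are stored row by row, so fill the matrix by row
    m1 <- matrix(as.numeric(unlist(input[i, -1])), nrow=dimmax, byrow=TRUE)
    # flip vertically and transpose so image() draws the item upright
    m1 <- t(apply(m1, 2, rev))
    image(1:dimmax, 1:dimmax, m1, col=grey.colors(255), xaxt = 'n', yaxt = 'n')
    # label each panel with the label of its own row of `input`
    text(5, 20, col="red", cex=1.2, input[i, 1])
  }
}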
vizTrain(fashionmnist_train[1:25,])
vizTrain(fashionmnist_train[255,])
Next, we design the neural network architecture: how many layers we use, how many nodes each layer has, and which activation functions connect them.
mxnet has its own “grammar” for defining a neural network's architecture. Specifically, it uses what it calls the symbolic API, starting with a placeholder data symbol.
In the following code, we’re assembling a network with the following architecture: data --> 128 units ReLU --> 64 units ReLU --> 10 units output softMax
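Each fully connected layer has one weight per input-output pair plus one bias per output unit. For this architecture that gives 784 x 128 + 128 parameters in fc1, 128 x 64 + 64 in fc2, and 64 x 10 + 10 in fc3. A small hypothetical helper makes the arithmetic explicit:

```r
# Hypothetical helper: count the trainable parameters (weights + biases) of a
# stack of fully connected layers, given the units per layer including the input.
count_fc_params <- function(units) {
  fan_in  <- head(units, -1)   # inputs to each layer
  fan_out <- tail(units, -1)   # units in each layer
  sum(fan_in * fan_out + fan_out)
}

count_fc_params(c(784, 128, 64, 10))  # 109386
```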
# notice how our layers are fully connected and cascading
m1f.data <- mx.symbol.Variable("data") # placeholder for the input data
# num_hidden: number of units (nodes) in this fully connected layer
m1f.fc1 <- mx.symbol.FullyConnected(m1f.data,
                                    name = "fc1",
                                    num_hidden = 128)
# act_type: activation function applied to the layer's output
m1f.act1 <- mx.symbol.Activation(m1f.fc1,
                                 name = "activation1",
                                 act_type = "relu")
m1f.fc2 <- mx.symbol.FullyConnected(m1f.act1,
                                    name = "fc2",
                                    num_hidden = 64)
m1f.act2 <- mx.symbol.Activation(m1f.fc2,
                                 name = "activation2",
                                 act_type = "relu")
# output layer: one unit per category
m1f.fc3 <- mx.symbol.FullyConnected(m1f.act2,
                                    name = "fc3",
                                    num_hidden = 10)
# softmax turns the 10 outputs into class probabilities
m1f.softmax <- mx.symbol.SoftmaxOutput(m1f.fc3,
                                       name="softMax")
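To see what these symbols compute, here is a base-R sketch of one forward pass for a single image. It assumes FullyConnected computes W %*% x + b, "relu" is max(0, z), and SoftmaxOutput returns softmax probabilities at prediction time. The weights are random placeholders, whereas mxnet learns them during training. `relu()`, `softmax()` and `dense()` are hypothetical helpers:

```r
# Hypothetical helpers mirroring the layers defined above
relu <- function(z) pmax(z, 0)
softmax <- function(z) {
  e <- exp(z - max(z))   # subtract the max for numerical stability
  e / sum(e)
}
dense <- function(x, W, b) as.vector(W %*% x + b)  # fully connected layer

set.seed(0)
units <- c(784, 128, 64, 10)
# random weights and zero biases, for illustration only
W <- lapply(1:3, function(l) matrix(rnorm(units[l + 1] * units[l], sd = 0.01),
                                    nrow = units[l + 1]))
b <- lapply(1:3, function(l) numeric(units[l + 1]))

x  <- runif(784)                          # a fake image already scaled to [0, 1]
h1 <- relu(dense(x,  W[[1]], b[[1]]))     # fc1 + activation1: 128 values
h2 <- relu(dense(h1, W[[2]], b[[2]]))     # fc2 + activation2: 64 values
p  <- softmax(dense(h2, W[[3]], b[[3]]))  # fc3 + softMax: 10 class probabilities
round(sum(p), 10)                         # probabilities sum to 1
```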
# convert the data frames into numeric matrices for mxnet
train <- data.matrix(fashionmnist_train)
test <- data.matrix(fashionmnist_test)
# separate the label (first column) from the pixel values
# trainf_x: a 60000 x 784 matrix with values from 0 to 255
trainf_x <- train[,-1]
trainf_y <- train[,1]
testf_x <- test[,-1]
testf_y <- test[,1]
# scale the pixels to [0, 1] by dividing by the maximum value, 255, and transpose:
# with array.layout = "colmajor" mxnet expects one column per image (784 x 60000)
trainf_x <- t(trainf_x/255)
testf_x <- t(testf_x/255)
# train the model
logger <- mx.metric.logger$new() # records the training accuracy after each epoch
startime <- proc.time()
mx.set.seed(0)
m1f <- mx.model.FeedForward.create(m1f.softmax, # the network configuration defined above
                                   X = trainf_x, # input (predictors)
                                   y = trainf_y, # the labels
                                   ctx = mx.cpu(), # train on the CPU rather than a GPU
                                   num.round = 50, # number of epochs (passes over the data)
                                   array.batch.size = 80, # images per mini-batch
                                   momentum = 0.95,
                                   array.layout="colmajor", # each column of X is one image
                                   learning.rate = 0.001,
                                   eval.metric = mx.metric.accuracy,
                                   epoch.end.callback = mx.callback.log.train.metric(1, logger)
)
## Start training with 1 devices
## [1] Train-accuracy=0.101050000359615
## [2] Train-accuracy=0.162933333282669
## [3] Train-accuracy=0.405249999329448
## [4] Train-accuracy=0.608633332808812
## [5] Train-accuracy=0.695250000476837
## [6] Train-accuracy=0.728516666253408
## [7] Train-accuracy=0.74975
## [8] Train-accuracy=0.769633332808812
## [9] Train-accuracy=0.786516666889191
## [10] Train-accuracy=0.799883333683014
## [11] Train-accuracy=0.810099999666214
## [12] Train-accuracy=0.818133334318797
## [13] Train-accuracy=0.823550000111262
## [14] Train-accuracy=0.830083332935969
## [15] Train-accuracy=0.836250000397364
## [16] Train-accuracy=0.84138333272934
## [17] Train-accuracy=0.846416666348775
## [18] Train-accuracy=0.850599999745687
## [19] Train-accuracy=0.854599999030431
## [20] Train-accuracy=0.858049999237061
## [21] Train-accuracy=0.861733332633972
## [22] Train-accuracy=0.864566665331523
## [23] Train-accuracy=0.867333331902822
## [24] Train-accuracy=0.869883333206177
## [25] Train-accuracy=0.87263333328565
## [26] Train-accuracy=0.874733333587646
## [27] Train-accuracy=0.87686666782697
## [28] Train-accuracy=0.878650000969569
## [29] Train-accuracy=0.880566666523615
## [30] Train-accuracy=0.882366666873296
## [31] Train-accuracy=0.884283332983653
## [32] Train-accuracy=0.885450000127157
## [33] Train-accuracy=0.886966666936874
## [34] Train-accuracy=0.888333333015442
## [35] Train-accuracy=0.889450000127157
## [36] Train-accuracy=0.89061666671435
## [37] Train-accuracy=0.892016667048136
## [38] Train-accuracy=0.893349999984105
## [39] Train-accuracy=0.894633333683014
## [40] Train-accuracy=0.896166667222977
## [41] Train-accuracy=0.897200001319249
## [42] Train-accuracy=0.898316667636236
## [43] Train-accuracy=0.899533334175746
## [44] Train-accuracy=0.900533333619436
## [45] Train-accuracy=0.901416666984558
## [46] Train-accuracy=0.902383333762487
## [47] Train-accuracy=0.903216666062673
## [48] Train-accuracy=0.904500000079473
## [49] Train-accuracy=0.905783333539963
## [50] Train-accuracy=0.906750001033147
print(paste("Training took:", round((proc.time() - startime)[3],2),"seconds"))
## [1] "Training took: 250.98 seconds"
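The learning.rate and momentum arguments configure stochastic gradient descent with momentum. The sketch below assumes that mxnet's default "sgd" optimizer updates a velocity v <- momentum * v - learning.rate * gradient and then the weights w <- w + v. Weight decay and gradient rescaling are left out. It applies this rule to a one-dimensional toy problem, and `sgd_momentum()` is a hypothetical helper. Assuming the standard 60,000-image Fashion-MNIST training set, array.batch.size = 80 means 750 updates per epoch, so num.round = 50 amounts to 37,500 updates.

```r
# Hypothetical helper: gradient descent with momentum (assumed update rule)
#   v <- momentum * v - learning.rate * gradient
#   w <- w + v
sgd_momentum <- function(grad_fn, w0, learning.rate = 0.001, momentum = 0.95,
                         steps = 2000) {
  w <- w0
  v <- 0 * w0
  for (s in seq_len(steps)) {
    v <- momentum * v - learning.rate * grad_fn(w)
    w <- w + v
  }
  w
}

# Toy problem: minimise (w - 3)^2, whose gradient is 2 * (w - 3)
w_hat <- sgd_momentum(function(w) 2 * (w - 3), w0 = 0)
w_hat  # very close to 3
```

With such a small learning rate, the momentum term is what keeps progress reasonably fast. Each step keeps 95% of the previous velocity.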
After 50 epochs the model reaches about 90.7% accuracy on the training data.
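Accuracy logged during training is measured on the images the model learns from, so it says little about generalization. A stratified hold-out split is the simplest form of cross-validation: set part of the training set aside and monitor accuracy on it while training. The sketch below uses a hypothetical base-R helper, `stratified_holdout()`. caret::createDataPartition() offers a comparable stratified split, and caret::createFolds() builds the folds for k-fold cross-validation.

```r
# Hypothetical helper: pick `frac` of the rows of each class for validation,
# so every category appears in both parts in the same proportions.
stratified_holdout <- function(labels, frac = 0.2, seed = 100) {
  set.seed(seed)
  idx <- split(seq_along(labels), labels)   # row indices grouped by class
  val <- unlist(lapply(idx, function(i) i[sample.int(length(i), round(length(i) * frac))]),
                use.names = FALSE)
  sort(val)
}

# Toy labels: 10 classes x 50 rows each
y_toy <- rep(0:9, each = 50)
val_idx <- stratified_holdout(y_toy, frac = 0.2)
table(y_toy[val_idx])  # 10 rows of each class
```

With the real data, val_idx <- stratified_holdout(trainf_y) would index columns of trainf_x, since that matrix holds one image per column. Training could then use X = trainf_x[, -val_idx] and y = trainf_y[-val_idx]. The held-out part would be passed as eval.data = list(data = trainf_x[, val_idx], label = trainf_y[val_idx]), an argument that mx.model.FeedForward.create accepts for validation data.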
fashionmnist_train$label <- factor(fashionmnist_train$label, labels = categories)
fashionmnist_test$label <- factor(fashionmnist_test$label, labels = categories)
m1f_pred <- predict(m1f,
testf_x,
array.layout = "colmajor")
t(round(m1f_pred[,1:5], 2))
## [,1] [,2] [,3] [,4] [,5] [,6] [,7] [,8] [,9] [,10]
## [1,] 0.85 0 0.00 0.00 0.00 0 0.15 0 0 0
## [2,] 0.00 1 0.00 0.00 0.00 0 0.00 0 0 0
## [3,] 0.03 0 0.79 0.00 0.01 0 0.17 0 0 0
## [4,] 0.23 0 0.68 0.00 0.00 0 0.09 0 0 0
## [5,] 0.00 0 0.00 0.18 0.82 0 0.00 0 0 0
m1f_pred_result <- max.col(t(m1f_pred)) -1
m1f_pred_result <- factor(m1f_pred_result, labels = categories)
head(m1f_pred_result)
## [1] T-shirt Trouser Pullover Pullover Coat Shirt
## 10 Levels: T-shirt Trouser Pullover Dress Coat Sandal Shirt ... Boot
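predict() returns one column of class probabilities per test image, with one row per category, as the transposed printout above shows. max.col() picks the position of the largest value in each row, hence the t(). Subtracting 1 moves from R's 1-based positions to the 0 to 9 label coding. A toy 3-class version with made-up probabilities:

```r
# Toy probabilities: 3 classes (rows) x 3 images (columns), filled column by column
probs <- matrix(c(0.85, 0.00, 0.15,   # image 1
                  0.00, 1.00, 0.00,   # image 2
                  0.23, 0.68, 0.09),  # image 3
                nrow = 3)
# one row per image after t(), then the column of the largest probability
pred <- max.col(t(probs)) - 1
pred  # 0 1 1
```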
table(fashionmnist_test$label, m1f_pred_result)
## m1f_pred_result
## T-shirt Trouser Pullover Dress Coat Sandal Shirt Sneaker Bag
## T-shirt 878 1 17 9 0 2 77 0 16
## Trouser 5 982 1 6 1 0 5 0 0
## Pullover 18 1 808 6 97 1 67 0 2
## Dress 43 19 14 863 47 0 14 0 0
## Coat 1 2 62 14 866 2 50 0 3
## Sandal 1 0 0 0 0 950 0 22 3
## Shirt 167 3 61 12 64 1 684 0 8
## Sneaker 0 0 0 0 0 47 0 838 1
## Bag 6 0 11 2 4 3 8 1 963
## Boot 0 0 1 0 0 6 0 11 2
## m1f_pred_result
## Boot
## T-shirt 0
## Trouser 0
## Pullover 0
## Dress 0
## Coat 0
## Sandal 24
## Shirt 0
## Sneaker 114
## Bag 2
## Boot 980
caret::confusionMatrix(data = m1f_pred_result,
reference = fashionmnist_test$label)
## Confusion Matrix and Statistics
##
## Reference
## Prediction T-shirt Trouser Pullover Dress Coat Sandal Shirt Sneaker Bag
## T-shirt 878 5 18 43 1 1 167 0 6
## Trouser 1 982 1 19 2 0 3 0 0
## Pullover 17 1 808 14 62 0 61 0 11
## Dress 9 6 6 863 14 0 12 0 2
## Coat 0 1 97 47 866 0 64 0 4
## Sandal 2 0 1 0 2 950 1 47 3
## Shirt 77 5 67 14 50 0 684 0 8
## Sneaker 0 0 0 0 0 22 0 838 1
## Bag 16 0 2 0 3 3 8 1 963
## Boot 0 0 0 0 0 24 0 114 2
## Reference
## Prediction Boot
## T-shirt 0
## Trouser 0
## Pullover 1
## Dress 0
## Coat 0
## Sandal 6
## Shirt 0
## Sneaker 11
## Bag 2
## Boot 980
##
## Overall Statistics
##
## Accuracy : 0.8812
## 95% CI : (0.8747, 0.8875)
## No Information Rate : 0.1
## P-Value [Acc > NIR] : < 0.00000000000000022
##
## Kappa : 0.868
## Mcnemar's Test P-Value : NA
##
## Statistics by Class:
##
## Class: T-shirt Class: Trouser Class: Pullover
## Sensitivity 0.8780 0.9820 0.8080
## Specificity 0.9732 0.9971 0.9814
## Pos Pred Value 0.7846 0.9742 0.8287
## Neg Pred Value 0.9863 0.9980 0.9787
## Prevalence 0.1000 0.1000 0.1000
## Detection Rate 0.0878 0.0982 0.0808
## Detection Prevalence 0.1119 0.1008 0.0975
## Balanced Accuracy 0.9256 0.9896 0.8947
## Class: Dress Class: Coat Class: Sandal Class: Shirt
## Sensitivity 0.8630 0.8660 0.9500 0.6840
## Specificity 0.9946 0.9763 0.9931 0.9754
## Pos Pred Value 0.9463 0.8026 0.9387 0.7558
## Neg Pred Value 0.9849 0.9850 0.9944 0.9653
## Prevalence 0.1000 0.1000 0.1000 0.1000
## Detection Rate 0.0863 0.0866 0.0950 0.0684
## Detection Prevalence 0.0912 0.1079 0.1012 0.0905
## Balanced Accuracy 0.9288 0.9212 0.9716 0.8297
## Class: Sneaker Class: Bag Class: Boot
## Sensitivity 0.8380 0.9630 0.9800
## Specificity 0.9962 0.9961 0.9844
## Pos Pred Value 0.9610 0.9649 0.8750
## Neg Pred Value 0.9823 0.9959 0.9977
## Prevalence 0.1000 0.1000 0.1000
## Detection Rate 0.0838 0.0963 0.0980
## Detection Prevalence 0.0872 0.0998 0.1120
## Balanced Accuracy 0.9171 0.9796 0.9822
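On the unseen test data the model still reaches 88.12% accuracy (95% CI: 87.47% to 88.75%). This is close to its 90.68% accuracy on the training data, so the model does not appear to overfit badly. Shirt is the hardest category (sensitivity 0.684) and is most often mistaken for T-shirt, Coat or Pullover.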
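The headline numbers follow directly from the confusion matrix. Accuracy is the sum of its diagonal divided by the number of test images, and each class's sensitivity is its diagonal entry divided by the 1,000 test images of that class. The counts below are copied from the diagonal printed above:

```r
# Correct predictions per class: the diagonal of the confusion matrix above
correct <- c(`T-shirt` = 878, Trouser = 982, Pullover = 808, Dress = 863, Coat = 866,
             Sandal = 950, Shirt = 684, Sneaker = 838, Bag = 963, Boot = 980)
n_per_class <- 1000                       # each class has 1000 test images (prevalence 0.1)
accuracy    <- sum(correct) / (n_per_class * length(correct))
sensitivity <- correct / n_per_class      # recall per class
accuracy                                  # 0.8812
sort(sensitivity)[1:3]                    # Shirt 0.684, Pullover 0.808, Sneaker 0.838
```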
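Finally, we plot the first 49 test images with their predicted labels in red:
plotResults <- function(images, preds){
  # arrange the plots in a square grid
  x <- ceiling(sqrt(length(images)))
  par(mfrow = c(x,x),
      mar = c(.1,.1,.1,.1))
  for (i in images){
    # pixels of test image i (column 1 of `test` is the label), filled row by row
    m <- matrix(test[i, 2:785],
                nrow = 28,
                byrow = TRUE)
    # flip vertically; image(t(m)) then draws the item upright
    m <- apply(m, 2, rev)
    image(t(m),
          col = grey.colors(255),
          axes = FALSE)
    # predicted category in red
    text(0.5,0.1,col="red", cex=1.2, preds[i])
  }
}
plotResults(1:49,
            m1f_pred_result)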