Build a network to classify Reuters newswires from 1986 into 46 mutually exclusive topics.
The argument num_words = 10000 restricts the data to the 10,000 most frequently occurring words.
library(keras)
reuters <- dataset_reuters(num_words = 10000)
c(c(train_data, train_labels), c(test_data, test_labels)) %<-% reuters
There are 8982 training samples and 2246 testing samples.
length(train_data)
[1] 8982
length(test_data)
[1] 2246
Again, each sample is a list of integers serving as word indices.
train_data[[1]]
[1] 1 2 2 8 43 10 447 5 25 207 270 5 3095 111 16 369 186 90 67 7 89 5 19
[24] 102 6 19 124 15 90 67 84 22 482 26 7 48 4 49 8 864 39 209 154 6 151 6
[47] 83 11 15 22 155 11 15 7 48 9 4579 1005 504 6 258 6 272 11 15 22 134 44 11
[70] 15 16 8 197 1245 90 67 52 29 209 30 32 132 6 109 15 17 12
Here's how to decode the indices back to words.
# word_index is a named list mapping words to an integer index
word_index <- dataset_reuters_word_index()
# Reverses it, mapping integer indices to words
reverse_word_index <- names(word_index)
names(reverse_word_index) <- word_index
# Decodes the 1st wire. Note that the indices are offset by 3 because 0, 1, and 2 are reserved indices for "padding," "start of sequence," and "unknown."
decoded_newswire <- sapply(train_data[[1]], function(index) {
word <- if (index >= 3) reverse_word_index[[as.character(index - 3)]]
if (!is.null(word)) word else "?"
})
paste(decoded_newswire, collapse = " ")
[1] "? ? ? said as a result of its december acquisition of space co it expects earnings per share in 1987 of 1 15 to 1 30 dlrs per share up from 70 cts in 1986 the company said pretax net should rise to nine to 10 mln dlrs from six mln dlrs in 1986 and rental operation revenues to 19 to 22 mln dlrs from 12 5 mln dlrs it said cash flow per share this year should be 2 50 to three dlrs reuter 3"
The label associated with a sample is an integer between 0 and 45 – a topic index.
train_labels[[1]]
[1] 3
Vectorize the data with the same one-hot encoding used before.
vectorize_sequences <- function(sequences, dimension = 10000) {
  # Create an all-zero matrix of shape (length(sequences), dimension)
  results <- matrix(0, nrow = length(sequences), ncol = dimension)
  # Set results[i, j] to 1 for every word index j in the i-th sequence
  for (i in seq_along(sequences))
    results[i, sequences[[i]]] <- 1
  results
}
x_train <- vectorize_sequences(train_data)
x_test <- vectorize_sequences(test_data)
str(x_train[1,])
num [1:10000] 1 1 0 1 1 1 1 1 1 1 ...
One-hot encode the labels as well. For the first sample, the integer label 3 (from above) becomes a 46-dimensional vector with a 1 in the position for class 3 (the fourth element, since R vectors are 1-indexed while the labels start at 0).
one_hot_train_labels <- to_categorical(train_labels)
one_hot_test_labels <- to_categorical(test_labels)
str(one_hot_train_labels[1,])
num [1:46] 0 0 0 1 0 0 0 0 0 0 ...
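For illustration, to_categorical is equivalent to building the indicator matrix by hand. A minimal sketch (note the + 1, needed because the labels are 0-based while R matrices are 1-indexed):
to_one_hot <- function(labels, dimension = 46) {
  results <- matrix(0, nrow = length(labels), ncol = dimension)
  for (i in seq_along(labels))
    # Column labels[[i]] + 1 corresponds to 0-based class labels[[i]]
    results[i, labels[[i]] + 1] <- 1
  results
}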
The dimensionality of the output space is now 46, which makes this a harder problem than binary classification.
In the binary problem, we used 16-unit intermediate layers. With 46 separate classes, the layers need more capacity: a layer that is too small can act as an information bottleneck, permanently dropping information that later layers can never recover (see the sketch after the model definition below). Here we go with 64 units.
The last layer uses a softmax activation. You saw this pattern in the MNIST example. It means the network will output a probability distribution over the 46 different output classes. For every input sample, the network will produce a 46-dimensional output vector, where output[[i]] is the probability that the sample belongs to class i. The 46 scores will sum to 1.
model <- keras_model_sequential() %>%
layer_dense(units = 64, activation = "relu", input_shape = c(10000)) %>%
layer_dense(units = 64, activation = "relu") %>%
layer_dense(units = 46, activation = "softmax")
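(As an aside, you can see the information-bottleneck effect mentioned above by shrinking the middle layer. A hypothetical variant; training it should show noticeably worse validation accuracy than the 64-unit version:)
# Sketch: a 4-unit middle layer bottlenecks the 46-way classification
bottleneck_model <- keras_model_sequential() %>%
  layer_dense(units = 64, activation = "relu", input_shape = c(10000)) %>%
  layer_dense(units = 4, activation = "relu") %>%
  layer_dense(units = 46, activation = "softmax")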
The best loss function to use in this case is categorical_crossentropy. It measures the distance between two probability distributions: here, between the probability distribution output by the network and the true distribution of the labels. By minimizing the distance between these two distributions, you train the network to output something as close as possible to the true labels.
model %>% compile(
optimizer = "rmsprop",
loss = "categorical_crossentropy",
metrics = c("accuracy")
)
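Concretely, for a one-hot target vector y and a predicted distribution p, categorical crossentropy is -sum(y * log(p)). A toy illustration with made-up numbers:
# True class is 1 (one-hot over 3 hypothetical classes)
y_true <- c(0, 1, 0)
# Model's predicted probability distribution
y_pred <- c(0.1, 0.7, 0.2)
# Categorical crossentropy; here -log(0.7), about 0.357
-sum(y_true * log(y_pred))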
Let’s set apart 1,000 samples in the training data to use as a validation set.
# Hold out the first 1,000 training samples as a validation set
val_indices <- 1:1000
x_val <- x_train[val_indices,]
partial_x_train <- x_train[-val_indices,]
y_val <- one_hot_train_labels[val_indices,]
partial_y_train <- one_hot_train_labels[-val_indices,]
Now, let’s train the network for 20 epochs.
history <- model %>% fit(
partial_x_train,
partial_y_train,
epochs = 20,
batch_size = 512,
validation_data = list(x_val, y_val)
)
Epoch 1/20 - loss: 2.9587 - accuracy: 0.4711 - val_loss: 2.0052 - val_accuracy: 0.6250
Epoch 2/20 - loss: 1.5721 - accuracy: 0.7041 - val_loss: 1.3750 - val_accuracy: 0.7160
Epoch 3/20 - loss: 1.1015 - accuracy: 0.7790 - val_loss: 1.1533 - val_accuracy: 0.7550
Epoch 4/20 - loss: 0.8570 - accuracy: 0.8236 - val_loss: 1.0368 - val_accuracy: 0.7950
Epoch 5/20 - loss: 0.6864 - accuracy: 0.8581 - val_loss: 0.9683 - val_accuracy: 0.8090
Epoch 6/20 - loss: 0.5521 - accuracy: 0.8879 - val_loss: 0.9294 - val_accuracy: 0.8030
Epoch 7/20 - loss: 0.4465 - accuracy: 0.9069 - val_loss: 0.9028 - val_accuracy: 0.8190
Epoch 8/20 - loss: 0.3621 - accuracy: 0.9245 - val_loss: 0.8920 - val_accuracy: 0.8160
Epoch 9/20 - loss: 0.2984 - accuracy: 0.9344 - val_loss: 0.9076 - val_accuracy: 0.8150
Epoch 10/20 - loss: 0.2539 - accuracy: 0.9424 - val_loss: 0.8895 - val_accuracy: 0.8240
Epoch 11/20 - loss: 0.2175 - accuracy: 0.9474 - val_loss: 0.9538 - val_accuracy: 0.8070
Epoch 12/20 - loss: 0.1905 - accuracy: 0.9528 - val_loss: 0.9623 - val_accuracy: 0.8010
Epoch 13/20 - loss: 0.1719 - accuracy: 0.9533 - val_loss: 0.9416 - val_accuracy: 0.8160
Epoch 14/20 - loss: 0.1551 - accuracy: 0.9534 - val_loss: 0.9476 - val_accuracy: 0.8060
Epoch 15/20 - loss: 0.1438 - accuracy: 0.9551 - val_loss: 0.9849 - val_accuracy: 0.8150
Epoch 16/20 - loss: 0.1318 - accuracy: 0.9575 - val_loss: 1.0448 - val_accuracy: 0.7990
Epoch 17/20 - loss: 0.1278 - accuracy: 0.9578 - val_loss: 0.9959 - val_accuracy: 0.8150
Epoch 18/20 - loss: 0.1218 - accuracy: 0.9592 - val_loss: 1.0678 - val_accuracy: 0.8010
Epoch 19/20 - loss: 0.1143 - accuracy: 0.9568 - val_loss: 1.0761 - val_accuracy: 0.7990
Epoch 20/20 - loss: 0.1164 - accuracy: 0.9573 - val_loss: 1.0617 - val_accuracy: 0.8090
plot(history)
(Plot: training and validation loss and accuracy over the 20 epochs.)
The validation loss stops improving after about nine epochs, at which point the network begins to overfit. Let's train a new network from scratch for nine epochs and then evaluate it on the test set.
model <- keras_model_sequential() %>%
layer_dense(units = 64, activation = "relu", input_shape = c(10000)) %>%
layer_dense(units = 64, activation = "relu") %>%
layer_dense(units = 46, activation = "softmax")
model %>% compile(
optimizer = "rmsprop",
loss = "categorical_crossentropy",
metrics = c("accuracy")
)
history <- model %>% fit(
partial_x_train,
partial_y_train,
epochs = 9,
batch_size = 512,
validation_data = list(x_val, y_val)
)
Epoch 1/9 - loss: 2.6486 - accuracy: 0.4986 - val_loss: 1.7172 - val_accuracy: 0.6580
Epoch 2/9 - loss: 1.4227 - accuracy: 0.7082 - val_loss: 1.3113 - val_accuracy: 0.7100
Epoch 3/9 - loss: 1.0570 - accuracy: 0.7771 - val_loss: 1.1626 - val_accuracy: 0.7460
Epoch 4/9 - loss: 0.8331 - accuracy: 0.8211 - val_loss: 1.0473 - val_accuracy: 0.7750
Epoch 5/9 - loss: 0.6616 - accuracy: 0.8578 - val_loss: 0.9789 - val_accuracy: 0.7970
Epoch 6/9 - loss: 0.5251 - accuracy: 0.8911 - val_loss: 0.9226 - val_accuracy: 0.8090
Epoch 7/9 - loss: 0.4225 - accuracy: 0.9126 - val_loss: 0.9009 - val_accuracy: 0.8210
Epoch 8/9 - loss: 0.3405 - accuracy: 0.9276 - val_loss: 0.8964 - val_accuracy: 0.8220
Epoch 9/9 - loss: 0.2817 - accuracy: 0.9397 - val_loss: 0.9044 - val_accuracy: 0.8200
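As an aside, rather than hand-picking nine epochs from the plot, you could let Keras stop training automatically. A sketch using callback_early_stopping, on a freshly compiled model as above:
history <- model %>% fit(
  partial_x_train, partial_y_train,
  epochs = 20, batch_size = 512,
  validation_data = list(x_val, y_val),
  # Stop once val_loss has not improved for 2 epochs; keep the best weights
  callbacks = list(callback_early_stopping(monitor = "val_loss", patience = 2,
                                           restore_best_weights = TRUE))
)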
results <- model %>% evaluate(x_test, one_hot_test_labels)
71/71 [==============================] - 0s - loss: 0.9809 - accuracy: 0.7943
results
loss accuracy
0.9808653 0.7943010
This approach reaches an accuracy of ~79%. With a balanced binary classification problem, a purely random classifier would reach about 50% accuracy. Here, with 46 imbalanced classes, a random classifier lands closer to 18%, so these results look quite good by comparison. We can verify the random baseline by shuffling the test labels:
# Simulate a random classifier by shuffling the test labels
test_labels_copy <- test_labels
test_labels_copy <- sample(test_labels_copy)
# Fraction of shuffled labels that happen to match the true labels
length(which(test_labels == test_labels_copy)) / length(test_labels)
[1] 0.182992
The predict method of the model instance returns a probability distribution over all 46 topics. Let’s generate topic predictions for all of the test data.
predictions <- model %>% predict(x_test)
dim(predictions)
[1] 2246 46
The coefficients in each prediction vector sum to 1, since they form a probability distribution.
sum(predictions[1,])
[1] 1
The largest entry is the predicted class, i.e. the class with the highest probability. Note that which.max returns a 1-based position, so the value 4 below corresponds to topic label 3.
which.max(predictions[1,])
[1] 4
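To turn the whole prediction matrix into class labels, take the arg-max of each row, subtracting 1 to get back to the 0-based topic labels. Comparing against test_labels should reproduce the evaluated accuracy (~0.79):
# Predicted topic (0-based) for every test sample
predicted_classes <- apply(predictions, 1, which.max) - 1
# Fraction of correct predictions
mean(predicted_classes == test_labels)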
To wrap up: for single-label, N-way classification, end the network with a dense layer of size N and a softmax activation, so that it outputs a probability distribution over the N output classes.
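Finally, one-hot encoding the labels isn't the only option: Keras can also consume the integer labels directly if you switch the loss to sparse_categorical_crossentropy. Everything else stays the same; a sketch:
model %>% compile(
  optimizer = "rmsprop",
  # Mathematically the same loss, but it expects integer labels
  # instead of one-hot vectors
  loss = "sparse_categorical_crossentropy",
  metrics = c("accuracy")
)
# Then fit with the raw integer labels, e.g. train_labels
# instead of one_hot_train_labels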