Overview

To gain familiarity with custom loss functions in MXNet, I first worked through the example of a custom regression loss function given here.

Since I largely work with classification problems, I wanted to test my understanding by implementing the logistic equivalent of the linked example on a different data set.

First I will set up the data; then I will build the network using the built-in function mx.symbol.LogisticRegressionOutput(); then I will do the same thing with my own custom loss function. At the end I will compare the results.

Set up the data

Using the mtcars dataset, the goal is to predict whether a car has an automatic or manual transmission (the am variable, column 9).

library(mxnet)
data(mtcars)

test.ind = seq(1, 32, 5)    # 1 pt in 5 used for testing
train.x = data.matrix(mtcars[-test.ind,-9])
train.y = mtcars[-test.ind, 9]
test.x = data.matrix(mtcars[test.ind, -9])
test.y = mtcars[test.ind, 9]
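
Before building anything, a quick optional sanity check on the split; mtcars has 32 rows, so this leaves 25 cars for training and 7 held out for testing:

dim(train.x)    # 25 cars x 10 predictors
dim(test.x)     # 7 held-out cars
table(train.y)  # balance of automatic (0) vs manual (1) in the training labels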

Set up shared architecture

data <- mx.symbol.Variable("data")
label <- mx.symbol.Variable("label")
fc1 <- mx.symbol.FullyConnected(data, num_hidden = 14, name = "fc1")
tanh1 <- mx.symbol.Activation(fc1, act_type = "tanh", name = "tanh1")
fc2 <- mx.symbol.FullyConnected(tanh1, num_hidden = 1, name = "fc2")
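
If the DiagrammeR package is installed, graph.viz() can optionally render the symbol graph as a quick sketch to confirm the two-layer architecture:

graph.viz(fc2)   # optional visualization of the shared network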

Logistic output using mx.symbol.LogisticRegressionOutput()

lro <- mx.symbol.LogisticRegressionOutput(fc2, name = "lro")
  
mx.set.seed(0)
modelBuiltIn <- mx.model.FeedForward.create(lro, X = train.x, y = train.y,
                                     ctx = mx.cpu(),
                                     num.round = 5,
                                     array.batch.size = 10,
                                     optimizer = "rmsprop",
                                     verbose = TRUE,
                                     array.layout = "rowmajor",
                                     batch.end.callback = NULL,
                                     epoch.end.callback = NULL)
predBuiltIn <- t(predict(modelBuiltIn,test.x))
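
predBuiltIn holds the predicted probabilities of a manual transmission. As a quick sketch (the 0.5 cutoff is my own assumption, not part of the original example), they can be thresholded and scored against test.y:

classBuiltIn <- ifelse(predBuiltIn > 0.5, 1, 0)   # assumed 0.5 cutoff
mean(classBuiltIn == test.y)                      # test-set accuracy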

Logistic output using a custom loss function (cross entropy)

sigOut <- mx.symbol.Activation(fc2, act_type = "sigmoid", name = "sigOut")

# Implement the cross-entropy loss: -[y*log(yhat) + (1-y)*log(1-yhat)]
loss <- mx.symbol.MakeLoss(mx.symbol.negative(
  label*mx.symbol.log(mx.symbol.Reshape(sigOut, shape = 0)) +
    (1-label)*mx.symbol.log(1-mx.symbol.Reshape(sigOut, shape = 0))),
  name="loss")

mx.set.seed(0)
modelCustom <- mx.model.FeedForward.create(loss, X = train.x, y = train.y,
                                      ctx = mx.cpu(),
                                      num.round = 5,
                                      array.batch.size = 10,
                                      optimizer = "rmsprop",
                                      verbose = TRUE,
                                      array.layout = "rowmajor",
                                      batch.end.callback = NULL,
                                      epoch.end.callback = NULL)

# predict() on a MakeLoss symbol returns the loss rather than the prediction,
# so extract the sigmoid output from the network internals instead
internals = internals(modelCustom$symbol)
fc_symbol = internals[[match("sigOut_output", outputs(internals))]]

modelCustom2 <- list(symbol = fc_symbol,
                  arg.params = modelCustom$arg.params,
                  aux.params = modelCustom$aux.params)

class(modelCustom2) <- "MXFeedForwardModel"

predCustom <- t(predict(modelCustom2,test.x))

Compare results

The results are very close, but not identical. I would be curious to see exactly which function(s) the built-in output layer uses.

cbind(predBuiltIn,predCustom,predBuiltIn-predCustom)
##            [,1]       [,2]          [,3]
## [1,] 0.29726729 0.29726833 -1.043081e-06
## [2,] 0.14318821 0.14321354 -2.533197e-05
## [3,] 0.25930476 0.25930521 -4.470348e-07
## [4,] 0.09107397 0.09107383  1.341105e-07
## [5,] 0.73811299 0.73810983  3.159046e-06
## [6,] 0.87688315 0.87688297  1.788139e-07
## [7,] 0.16522829 0.16522804  2.533197e-07
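
A one-line summary of the disagreement, using the objects already defined above (a sketch, not part of the original comparison):

max(abs(predBuiltIn - predCustom))   # largest absolute difference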