ID5059 Lecture 15 - Neural Networks 3

C. Donovan
11 April 2018

Administrivia

  • Project 2:
    • Using other people's code and ideas

NB: If it's not in the lecture or lab, it's not in the exam

Today

  • Example NN
  • BP calculations
  • Preventing over-fitting
  • NN pros/cons to date

Example NN on images

We'll fit a basic NN to some image data for classification and see how we did

Example NN on image - the data

  • Hand-written numbers from the MNIST dataset
  • 60,000 training images of numbers, 28 \( \times \) 28 resolution
  • Also has a test set of 10,000
  • NNs can be quite good with images, and this is a multi-class response (10 categories), which is also a good match for a NN

Example NN on image - the data

[R: You'll do something similar in the lab this week]
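A minimal R sketch of this sort of fit, assuming the MNIST images have already been read into a matrix x_train (one flattened 28 \( \times \) 28 image per row, pixels scaled to [0, 1]) with a factor y_train of digit labels, and analogous x_test/y_test. The object names, the choice of the nnet package, and the settings (30 hidden units, decay 0.1) are illustrative, not the lab's exact recipe.

```r
library(nnet)

# Single-hidden-layer network on flattened images (slow on all 60,000 images)
fit <- nnet(x = x_train, y = class.ind(y_train),
            size = 30, softmax = TRUE,          # 10-class output via softmax
            decay = 0.1, maxit = 100,
            MaxNWts = 30000, trace = FALSE)     # ~23,860 weights to estimate

# Proportion of the 10,000 test images classified correctly
pred <- predict(fit, x_test, type = "class")
mean(pred == y_test)
```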

Fitting NNs - a gradient search example (BP)

Simple in principle:

  • Given weights, the NN gives a prediction \( \hat{y} \)
  • \( \hat{y} \) compared to \( y \) gives an error measure (RSS say)
  • Changing the weights can make this bigger or smaller
  • Want to change weights to make this smaller
  • Error is a function of weights - so numerically optimise to reduce

It's a search over multiple dimensions (dictated by number of parameters/weights).

Error Surface

Nasty ones (like NNs)

  • Maybe lots of local minima - starting locations are influential
  • The surface is less predictable and we have to search intensively/come up with tricks


Back propagation - more detailed

Note:

  • The fine-scale details are not examinable
  • I discovered this is torturous, so will put a detailed set of calculations on Moodle, rather than in the lecture
  • Nonetheless, some elements follow

Back propagation - high level view

Simple in principle:

  • Set some initial weights (can't estimate error without a parameterised model) - software deals with this - probably random uniform.
  • Calculate an initial error (based on observed versus current predicted).
  • For each weight determine if increasing or decreasing the weight increases/decreases the error.
  • Move a bit in the correct direction. Recalculate error with new parameters. Repeat.
  • Stop at some point i.e. further weight alterations make no/little improvement.

This is a gradient search, iterating over multiple dimensions (dictated by number of parameters/weights).
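As a rough sketch of that loop (not any particular package's implementation): assume hypothetical functions R_fun(w), returning the error for a weight vector w, and grad_fun(w), returning its gradient.

```r
# Minimal gradient search: move each weight a little in the downhill
# direction until further alterations make no/little improvement
gradient_search <- function(w, R_fun, grad_fun, gamma = 0.01,
                            max_iter = 1000, tol = 1e-6) {
  for (r in seq_len(max_iter)) {
    R_old <- R_fun(w)
    w <- w - gamma * grad_fun(w)
    if (abs(R_old - R_fun(w)) < tol) break   # little/no improvement: stop
  }
  w
}
```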

Back propagation - mid-level view

Refer H, T & F sections 11.3 & 11.4. Simplified version follows.

  • Create little local problems to solve at each non-input node. Iteration \( r+1 \): \[ \beta^{r+1}=\beta^r-\gamma \frac{\partial R}{\partial \beta^r} \]
  • So, if \( R \) increases with increasing \( \beta^r \), decrease to create \( \beta^{r+1} \) by step \( \gamma \).
  • Keep doing this until \( R \) gets small.

Back propagation in more detail

Consider the following simple NN \[ y = \beta_0 + \beta_1z_1 + \beta_2z_2 \] where

\[ \begin{align*} z_1 &= \frac{1}{1+e^{-(\alpha_0 + \alpha_1x_1 + \alpha_2x_2)}}\\ z_2 &= \frac{1}{1+e^{-(\alpha_3 + \alpha_4x_1 + \alpha_5x_2)}}\\ \end{align*} \]

We're seeking to optimise the weights (the \( \alpha \) and \( \beta \)).
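To make the example concrete, here is that forward pass in R; the argument layout is just one convenient choice: alpha = \( (\alpha_0,\ldots,\alpha_5) \), beta = \( (\beta_0,\beta_1,\beta_2) \).

```r
# Forward pass for the toy network: two logistic hidden units, identity output
forward <- function(x, alpha, beta) {
  z1 <- 1 / (1 + exp(-(alpha[1] + alpha[2] * x[1] + alpha[3] * x[2])))
  z2 <- 1 / (1 + exp(-(alpha[4] + alpha[5] * x[1] + alpha[6] * x[2])))
  beta[1] + beta[2] * z1 + beta[3] * z2   # y-hat
}
```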

Back propagation in more detail

  • As discussed, this is a (non-linear) optimisation problem - we want to change the weights to better predict \( y \).
  • Define a loss function - let's say simple squared error (i.e. we want to minimise the RSS). Use \( R \) for the resubstitution error \[ R_i = (y_i-\hat{y}_i)^2 \]
  • Set initial weights - just random numbers will do.
  • Now want to know whether increasing or decreasing a particular weight is good or bad WRT \( R \).

Back propagation in more detail

  • Use the initial weights and inputs to make predictions \( \hat{y} \).
  • \( \hat{y} \) (and \( R \)) is a function of many things, but we want to alter the weights. So we want to determine \[ \frac{\partial R_i}{\partial \beta_k}\quad {\rm and}\quad \frac{\partial R_i}{\partial \alpha_s} \] for \( k=0,1,2 \) and \( s=0,\ldots,5 \)

Back propagation in more detail

  • Start with the weights nearest \( y \), say \( \beta_1 \), noting \( \hat{y}=\beta_0+\beta_1 z_1 + \beta_2 z_2 \).
  • \( R \) is a function of \( y \) (fixed) and \( \hat{y} \): \( (y - \hat{y})^2 \), say \( h(y, \hat{y}) \)
  • \( f \)=identity (a placeholder for a potential activation function), \( g \)=linear combination function, so we have \( \hat{y} = f(g(z, \beta))=\beta_0+\beta_1 z_1 + \beta_2 z_2 \)
  • Meaning \( R=h(f(g(z, \beta))) \) - to get \( \frac{\partial R}{\partial \beta} \) we can apply the chain rule
  • Let's drop \( i \) for the time being. Also we're using an identity activation function for \( f \), which makes things easier:

\[ \frac{\partial R}{\partial \beta_1} = 2(y-(\beta_0+\beta_1 z_{1} + \beta_2 z_{2}))(-1) \times z_{1} \]
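A quick numerical sanity check of this expression, using the forward function sketched above and arbitrary made-up weights (the numbers have no meaning beyond illustration):

```r
# Compare the analytic dR/d(beta_1) with a central finite difference
x <- c(0.5, -1); y <- 1
alpha <- c(0.1, -0.2, 0.3, 0.05, 0.2, -0.1)
beta  <- c(0.4, 0.6, -0.3)

z1 <- 1 / (1 + exp(-(alpha[1] + alpha[2] * x[1] + alpha[3] * x[2])))
y_hat <- forward(x, alpha, beta)
analytic <- 2 * (y - y_hat) * (-1) * z1

eps  <- 1e-6
R_at <- function(b1) (y - forward(x, alpha, c(beta[1], b1, beta[3])))^2
numeric <- (R_at(beta[2] + eps) - R_at(beta[2] - eps)) / (2 * eps)

c(analytic = analytic, numeric = numeric)   # should agree closely
```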

Back propagation in more detail

  • Do this further down the net, say for \( \alpha_1 \), noting \( z_1=\frac{1}{1+e^{-(\alpha_0+\alpha_1 x_1 + \alpha_2 x_2)}} \)
  • We already have the derivative with respect to \( \hat{y} \) from above; we further need \( \frac{\partial \hat{y}}{\partial z_1}=\beta_1 \) and \( \frac{\partial z_1}{\partial \alpha_1} \)

\[ \frac{\partial R}{\partial \alpha_1} = 2(y-(\beta_0+\beta_1 z_{1} + \beta_2 z_{2}))(-1) \times \beta_1 \times z_1(1-z_1) \times x_1 \]

  • Noting that the derivative of a logistic \( f \) is \( f(1-f) \).

The details change particularly with the loss and activation functions (combination function is probably the same).
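The analogous numerical check for \( \partial R/\partial \alpha_1 \), reusing the objects (x, y, alpha, beta, z1, y_hat, eps) defined in the previous chunk:

```r
# Analytic dR/d(alpha_1): propagate through beta_1 and the logistic z_1
analytic_a1 <- 2 * (y - y_hat) * (-1) * beta[2] * z1 * (1 - z1) * x[1]

# Central finite difference in alpha_1 (stored as alpha[2])
R_at_a1 <- function(a1) {
  a <- alpha; a[2] <- a1
  (y - forward(x, a, beta))^2
}
numeric_a1 <- (R_at_a1(alpha[2] + eps) - R_at_a1(alpha[2] - eps)) / (2 * eps)

c(analytic = analytic_a1, numeric = numeric_a1)   # should agree closely
```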

Back propagation in more detail

The following are sometimes referred to as the errors (often denoted \( \delta \)) that have been “propagated backwards”:

  • \( 2(y-(\beta_0+\beta_1 z_{1} + \beta_2 z_{2}))(-1) \)
  • \( 2(y-(\beta_0+\beta_1 z_{1} + \beta_2 z_{2}))(-1) \times \beta_1 \times z_1(1-z_1) \)

We need a forward pass to get the error, then work backwards from this to evaluate the derivatives.

Back propagation in more detail

  • Armed with these, we update the weights based on all \( n \) observations:

\[ \begin{align*} \beta_k^{r+1}&=\beta_k^{r}-\gamma\sum_i^{n}\frac{\partial R_i}{\partial \beta_k^r}\\ \alpha_s^{r+1}&=\alpha_s^{r}-\gamma\sum_i^{n}\frac{\partial R_i}{\partial \alpha_s^r} \end{align*} \]

where the size of the movements is controlled by \( \gamma \) (the learning rate). We alter the weights from the bottom up (starting with the \( \alpha \)). A hand-rolled sketch of the whole procedure follows.
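Purely as an illustration (not how you would fit a NN in practice), a batch gradient-descent fit of the toy network with all the gradients written out as above:

```r
# X: n x 2 matrix of inputs, y: numeric response of length n
fit_toy_nn <- function(X, y, gamma = 0.001, iters = 5000) {
  alpha <- runif(6, -0.5, 0.5)
  beta  <- runif(3, -0.5, 0.5)
  for (r in seq_len(iters)) {
    # forward pass
    z1 <- 1 / (1 + exp(-(alpha[1] + alpha[2] * X[, 1] + alpha[3] * X[, 2])))
    z2 <- 1 / (1 + exp(-(alpha[4] + alpha[5] * X[, 1] + alpha[6] * X[, 2])))
    y_hat <- beta[1] + beta[2] * z1 + beta[3] * z2
    # backward pass: the "errors" propagated backwards
    d  <- 2 * (y - y_hat) * (-1)
    d1 <- d * beta[2] * z1 * (1 - z1)
    d2 <- d * beta[3] * z2 * (1 - z2)
    # gradients summed over all n observations, then the update step
    g_beta  <- c(sum(d), sum(d * z1), sum(d * z2))
    g_alpha <- c(sum(d1), sum(d1 * X[, 1]), sum(d1 * X[, 2]),
                 sum(d2), sum(d2 * X[, 1]), sum(d2 * X[, 2]))
    beta  <- beta  - gamma * g_beta
    alpha <- alpha - gamma * g_alpha
  }
  list(alpha = alpha, beta = beta)
}

# e.g. on simulated data; the learning rate and iteration count may need tuning
set.seed(1)
X <- matrix(rnorm(200), ncol = 2)
y <- 1 + 2 / (1 + exp(-(X[, 1] - X[, 2]))) + rnorm(100, sd = 0.1)
fit <- fit_toy_nn(X, y)
```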

Over-view and over-fitting

Fitting

  • NNs can be viewed as a potentially complex combination of basis functions in \( \mathbf{X} \)….
  • Fitting can be thought of as a gradient search method e.g. some variant on the Newton method that seeks to minimise the error. So we have to consider:

    • The efficiency of approaches - many exist.
    • Step sizes in gradient search (learning rates).
    • Problems such as oscillation, local minima, slow convergence.
    • Solutions such as multiple starts and momentum parameters (which allow over-shooting of a minimum) to mitigate local minima.

Overfitting

  • A NN can be a very rich class of functions with even just a single hidden layer with a few hidden units
  • So we are likely to have a model with sufficient inherent complexity to model complex systems
  • This presents a problem too - the model can easily overfit i.e. learn the training dataset very well, giving a model with poor generality
  • This is the standard problem we have encountered throughout our consideration of automated model selection

Overfitting - controls

  • Early-stopping: stop the training at an appropriate point (the models get progressively closer to the training data with fitting iterations)
  • Regularisation (e.g. the weight-decay parameter, later): use penalised fit measures, where we balance fidelity to the data against a measure of model complexity
  • Architecture: make the NN simpler/more complex, number of nodes, layers, connections.
  • Dropout regularisation: (refer Srivastava et al 2014) - randomly dropping nodes/units and connections during training iterations
  • Batch normalisation: (refer Ioffe & Szegedy 2015) a normalisation of hidden layers during fitting. Speeds things up and in effect regularises a little i.e. constrains the weights
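As one concrete illustration of early stopping (the other controls have analogous implementations): refit with an increasing iteration budget from the same starting weights and watch a held-out validation set. The objects x_tr, y_tr, x_val, y_val are assumed to exist (a numeric response here, hence linout = TRUE); this is a sketch of the idea rather than a recommended workflow.

```r
library(nnet)

budgets <- c(10, 25, 50, 100, 200, 400)
val_rss <- sapply(budgets, function(m) {
  set.seed(1)                               # identical initial weights each time
  fit <- nnet(x_tr, y_tr, size = 5, linout = TRUE,
              maxit = m, trace = FALSE)
  sum((y_val - predict(fit, x_val))^2)      # validation error at this budget
})
setNames(val_rss, budgets)   # stop adding iterations once this deteriorates
```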

Weight decay

This is a type of regularisation - penalised fitting

  • Similar to the approach in tree-methods we can balance our raw model fit against a measure of model complexity
  • Using \( R_{\boldsymbol{\theta}} \) as our measure of resubstitution error with a given set of parameters \( \boldsymbol{\theta} \):

\[ R_{\boldsymbol{\theta}}+\lambda J(\boldsymbol{\theta}) \]

Here \( J \) is a measure of the size of the weights e.g. \( \sum \beta^2 + \sum \alpha^2 \) (usually excluding biases).

  • \( R \) and \( J \) are effectively in competition, and as we are using a gradient search, you can think of \( \lambda J \) as preventing us from reaching our global minimum for \( R \)
  • We must estimate \( \lambda \) (e.g. using validation data)
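In the nnet package this penalty enters through the decay argument (playing the role of \( \lambda \)). A small sketch, again assuming x_tr, y_tr, x_val, y_val exist and taking the decay values as purely illustrative:

```r
library(nnet)

# Same architecture, with and without a weight-decay penalty
fit_raw   <- nnet(x_tr, y_tr, size = 10, linout = TRUE,
                  decay = 0,    maxit = 500, trace = FALSE)
fit_decay <- nnet(x_tr, y_tr, size = 10, linout = TRUE,
                  decay = 0.01, maxit = 500, trace = FALSE)

# Compare generalisation on the validation set
c(raw   = sum((y_val - predict(fit_raw,   x_val))^2),
  decay = sum((y_val - predict(fit_decay, x_val))^2))
```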

Validation

Measuring generalisation error:

  • Maintain an independent dataset which is not used to develop the model, but is used to measure the model's performance/generality
  • Seek a model that predicts data we have not yet seen - the use of validation or cross-validation data simulates this scenario
  • The simplest method is to use a single validation dataset, and stop fitting when the performance of the model against the validation dataset begins to deteriorate

Validation

Use the validation data to:

  • Determine number of hidden nodes and/or layers.
  • Estimate the weight-decay parameter (above).
  • Terminate fitting (NNs get more complex the more iterations of weight estimation we do).
  • Other parameters or architecture decisions.

These are all aspects of NN complexity.

Controlling complexity

Example via the caret package (sketch below).
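A hedged sketch of that: cross-validation over the number of hidden units and the weight-decay penalty. The grid values and the data objects (x_tr, y_tr) are illustrative.

```r
library(caret)

ctrl <- trainControl(method = "cv", number = 5)                  # 5-fold CV
grid <- expand.grid(size = c(2, 5, 10), decay = c(0, 0.01, 0.1))

nn_tuned <- train(x = x_tr, y = y_tr, method = "nnet",
                  tuneGrid = grid, trControl = ctrl,
                  maxit = 200, trace = FALSE)
nn_tuned$bestTune   # the chosen size and decay
```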

NN problems overview

  • Lack of interpretability: these models are effectively black-box - although variable importance measures exist.
  • Over-fitting: NNs are clearly prone to overfitting if some proper controls are not put in place
  • Specification decisions: there are a bewildering array of activation functions, combination functions, output functions, training methods, parameters (e.g. number of hidden units and layers), standardisations etc.
  • Local minima: as for standard non-linear regression, we may require multiple fits to ensure we have not been trapped in a sub-optimal solution by local minima in the error function
  • Long run-times (as hinted at in SAS EM by the default option of “Maximum run-time=4 hours”.): These models can take a very long time to fit

More on NNs

We've yet to cover:

  • Some guidance for setting up an NN
    • Output functions and loss functions
    • Standardisation of inputs
    • Dealing with different types of inputs
  • Convolutional Neural Networks