Creating a CNN model for handwritten digit recognition (MNIST)
March 2, 2018
Following my overview of convolutional neural networks (CNNs) in a previous post, let's now build a CNN model to 1) classify images of handwritten digits, and 2) see what is learned by this type of model.
Handwritten digit recognition is the ‘Hello World’ example of the CNN world. I’ll be using the MNIST database of handwritten digits, which you can find here. The MNIST database contains greyscale images of size 28×28 (pixels), each containing a handwritten digit from 0-9 (inclusive). The goal: given a single image, how do we build a model that can accurately recognize the number that is shown?
Let’s take a moment to consider how this model might work. The numbers from 0-9 look different and have different features: some contain only straight edges (e.g. 4), some contain rounded edges (e.g. 5), and others have enclosed spaces (e.g. 8). A CNN is an ideal model for this problem because it can learn these different feature types and detect their presence/location in an image. We don’t need to be explicit about which features are important (often we don’t know ourselves) – these features, in the form of kernels, will be learned by our model.
For this example I’ll be using the Lasagne package in Python. Lasagne is a library for building and training neural networks using Theano, and it lets us avoid a lot of the plumbing required to pass data between the layers of the network. If you’re interested in that plumbing, Michael Nielsen wrote a great module containing classes for the different layer types that we’ll need, which you can find here. In this post I’ll mainly be breaking down the Lasagne tutorial, with a few additions; the code below is primarily adapted from it.
Loading the data
To begin, let’s get some imports out of the way:
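A representative set of imports, roughly following the Lasagne tutorial, might look like this (assuming Theano, Lasagne, NumPy, and matplotlib are installed):

```python
# Standard library helpers for downloading and reading the MNIST files
import os
import gzip

# Numerics, the Theano backend, and Lasagne itself
import numpy as np
import theano
import theano.tensor as T
import lasagne

# For visualizing images, kernels, and feature maps later on
import matplotlib.pyplot as plt
```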
Now, let’s import the MNIST data. The following function will download the data from Yann LeCun’s website and split it into training and test sets (along with their labels):
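A sketch of such a loader, closely following the structure of the Lasagne tutorial’s version (the exact code in the original post may differ slightly):

```python
def load_dataset():
    """Download the raw MNIST files if needed and return
    (X_train, y_train, X_test, y_test)."""
    url = 'http://yann.lecun.com/exdb/mnist/'

    def download(filename):
        # Fetch the file from Yann LeCun's site if we don't have it locally
        if not os.path.exists(filename):
            try:
                from urllib.request import urlretrieve   # Python 3
            except ImportError:
                from urllib import urlretrieve            # Python 2
            urlretrieve(url + filename, filename)

    def load_images(filename):
        download(filename)
        with gzip.open(filename, 'rb') as f:
            data = np.frombuffer(f.read(), np.uint8, offset=16)
        # Reshape to (examples, channels, rows, cols) and rescale to [0, 1]
        return data.reshape(-1, 1, 28, 28).astype(np.float32) / 255.0

    def load_labels(filename):
        download(filename)
        with gzip.open(filename, 'rb') as f:
            return np.frombuffer(f.read(), np.uint8, offset=8)

    X_train = load_images('train-images-idx3-ubyte.gz')
    y_train = load_labels('train-labels-idx1-ubyte.gz')
    X_test = load_images('t10k-images-idx3-ubyte.gz')
    y_test = load_labels('t10k-labels-idx1-ubyte.gz')
    return X_train, y_train, X_test, y_test
```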
To load the data, simply call the function:
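With the loader sketched above, that’s just:

```python
# Returns the images as 4D float32 arrays and the labels as integer vectors
X_train, y_train, X_test, y_test = load_dataset()
```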
We can take a quick look at a random MNIST image to get a sense of what we’re working with:
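For example, something like the following will display one randomly chosen training image with its label:

```python
# Pick a random training example and show it as a greyscale image
idx = np.random.randint(X_train.shape[0])
plt.imshow(X_train[idx][0], cmap='gray', interpolation='nearest')
plt.title('Label: %d' % y_train[idx])
plt.axis('off')
plt.show()
```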
Training the CNN
The first thing we need to do is define the architecture of the network. We can do that using Lasagne layers.
The following function defines the model architecture (note that the output of one layer is passed as input into the next layer); a sketch of the function follows the list of layers below:
Input layer that accepts a 4D tensor containing our input image
The first index is the mini-batch index, the second is the number of channels (1 channel for greyscale images), and the last two are the width and height of the input
Convolution layer with 32 5×5 kernels followed by ReLU activation. It is common to use a rectifier as the activation function instead of the traditional sigmoid to avoid saturation of neurons.
Max pooling layer (2×2)
Another pair of convolution and pooling layers
Fully connected layer with 256 neurons and dropout of 50%
Output layer with one unit for each of our 10 digits, again with dropout on its input
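Here is a sketch of a builder function in the spirit of the Lasagne tutorial, mirroring the layers listed above; the name `build_cnn` and the `input_var` argument are just conventions from that tutorial:

```python
def build_cnn(input_var=None):
    # Input layer: (batch size, channels, rows, cols); None = variable batch size
    network = lasagne.layers.InputLayer(shape=(None, 1, 28, 28),
                                        input_var=input_var)

    # Convolution layer 1: 32 kernels of size 5x5, ReLU activation
    network = lasagne.layers.Conv2DLayer(
        network, num_filters=32, filter_size=(5, 5),
        nonlinearity=lasagne.nonlinearities.rectify,
        W=lasagne.init.GlorotUniform())

    # Max pooling layer 1: 2x2
    network = lasagne.layers.MaxPool2DLayer(network, pool_size=(2, 2))

    # Convolution layer 2 + max pooling layer 2
    network = lasagne.layers.Conv2DLayer(
        network, num_filters=32, filter_size=(5, 5),
        nonlinearity=lasagne.nonlinearities.rectify)
    network = lasagne.layers.MaxPool2DLayer(network, pool_size=(2, 2))

    # Fully connected layer: 256 units, with 50% dropout on its input
    network = lasagne.layers.DenseLayer(
        lasagne.layers.dropout(network, p=0.5),
        num_units=256,
        nonlinearity=lasagne.nonlinearities.rectify)

    # Output layer: 10 units (one per digit) with softmax, again with dropout
    network = lasagne.layers.DenseLayer(
        lasagne.layers.dropout(network, p=0.5),
        num_units=10,
        nonlinearity=lasagne.nonlinearities.softmax)

    return network
```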
CNNs are slow to train, so it is typical to use a GPU. However, CNNs also consume a lot of memory in the convolution layers, and memory is generally limited on consumer GPUs. To get around this limit we use mini-batch gradient descent and only train on small batches of the training data at a time. So we need to define a function to iterate over the dataset in batches:
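Something like the following (essentially the iterator from the Lasagne tutorial) will do:

```python
def iterate_minibatches(inputs, targets, batchsize, shuffle=False):
    """Yield (inputs, targets) slices of size `batchsize`,
    optionally shuffling the order each epoch."""
    assert len(inputs) == len(targets)
    indices = np.arange(len(inputs))
    if shuffle:
        np.random.shuffle(indices)
    for start in range(0, len(inputs) - batchsize + 1, batchsize):
        excerpt = indices[start:start + batchsize]
        yield inputs[excerpt], targets[excerpt]
```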
Finally, a function that defines the loss and update expressions, begins the main training loop over the mini-batches, and returns the learned parameter values:
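A sketch of that function is below; the name `train_cnn` and the specific hyperparameters (learning rate, momentum) are illustrative rather than the original post’s exact values:

```python
def train_cnn(X_train, y_train, X_test, y_test, batch_size=500, num_epochs=5):
    # Symbolic variables for a mini-batch of images and their labels
    input_var = T.tensor4('inputs')
    target_var = T.ivector('targets')
    network = build_cnn(input_var)

    # Training loss (categorical cross-entropy) and Nesterov momentum updates
    prediction = lasagne.layers.get_output(network)
    loss = lasagne.objectives.categorical_crossentropy(prediction, target_var).mean()
    params = lasagne.layers.get_all_params(network, trainable=True)
    updates = lasagne.updates.nesterov_momentum(
        loss, params, learning_rate=0.01, momentum=0.9)

    # Test-time accuracy (deterministic=True disables dropout)
    test_prediction = lasagne.layers.get_output(network, deterministic=True)
    test_acc = T.mean(T.eq(T.argmax(test_prediction, axis=1), target_var),
                      dtype=theano.config.floatX)

    # Compile the Theano functions
    train_fn = theano.function([input_var, target_var], loss, updates=updates)
    val_fn = theano.function([input_var, target_var], test_acc)

    # Main training loop over mini-batches
    for epoch in range(num_epochs):
        train_err = 0.0
        for inputs, targets in iterate_minibatches(X_train, y_train,
                                                   batch_size, shuffle=True):
            train_err += train_fn(inputs, targets)
        print('Epoch %d, training loss %.4f' % (epoch + 1, train_err))

    # Report test set accuracy
    acc = np.mean([val_fn(inputs, targets) for inputs, targets in
                   iterate_minibatches(X_test, y_test, batch_size)])
    print('Test accuracy: %.2f%%' % (acc * 100))

    # Return the learned parameter values
    return lasagne.layers.get_all_param_values(network)
```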
To start training, simply call the function with our desired batch size and number of training epochs; weights will contain our learned parameters:
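With the sketch above, that call would look like:

```python
# Train for 5 epochs with mini-batches of 500 images
weights = train_cnn(X_train, y_train, X_test, y_test,
                    batch_size=500, num_epochs=5)
```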
With only 5 epochs (about 2 minutes of training on a GPU) we get a test set classification accuracy of 98.85%. If we train over 50 epochs we can achieve classification accuracy of 99.6%.
CNNs are extraordinarily good at learning from data with clear spatial structure, even with only two convolution layers. We can create deep CNNs by adding more layers, which increases the representational power of the network.
What is being learned?
We can visualize the 32 kernels the network learned for convolution layer 1. Each kernel is capturing an important element of a handwritten digit, which can be combined through later layers to form each of the 0-9 digits:
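One way to do this is to plot the first convolution layer’s weights directly; here I assume `weights[0]` holds that layer’s kernels with shape (32, 1, 5, 5), which is how `get_all_param_values` orders the parameters for the architecture above:

```python
# Plot the 32 learned 5x5 kernels of the first convolution layer
kernels = weights[0]  # assumed shape: (32, 1, 5, 5)
fig, axes = plt.subplots(4, 8, figsize=(8, 4))
for i, ax in enumerate(axes.flat):
    ax.imshow(kernels[i, 0], cmap='gray', interpolation='nearest')
    ax.set_title(str(i + 1), fontsize=8)
    ax.axis('off')
plt.show()
```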
They may look somewhat random at first glance, but clear structure is being learned in most kernels. For example, kernels 3 and 4 seem to be learning diagonal edges in opposite directions, and others capture rounded edges or enclosed spaces:
We can also take a single training example and convolve each kernel over the digit to see the feature maps, which show the areas of the input that are activated by that kernel. Kernels and feature maps for later convolution layers cannot be visualized as easily because they don’t operate directly on the input images – their input depends on the output of the previous layer. But we can directly visualize the C1-level feature maps for a single random training image:
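A rough way to produce these maps is to convolve each learned kernel over one image with SciPy; this is only an approximation of the network’s own forward pass (padding and bias details may differ), but it illustrates the idea:

```python
from scipy.signal import convolve2d

# Convolve each first-layer kernel (from the block above) over one training image
img = X_train[np.random.randint(X_train.shape[0])][0]
fig, axes = plt.subplots(4, 8, figsize=(8, 4))
for i, ax in enumerate(axes.flat):
    fmap = convolve2d(img, kernels[i, 0], mode='valid')
    ax.imshow(np.maximum(fmap, 0), cmap='gray')  # apply a ReLU to the response
    ax.axis('off')
plt.show()
```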
In summary, CNNs can learn the visual structure of images and learn to identify the features that distinguish one number from another. This idea can be extended to colour images pretty easily (for example with the CIFAR-10 dataset), or even non-image-based inputs as long as there’s some spatial structure inherent in the input.
As always, the full code for this example can be found in my GitHub repo here.