CNN using PyTorch for MNIST classification

In this example we use a simple convolutional neural network for handwritten digit recognition. It reaches an accuracy of \(98\%\) on the test dataset, which is better performance with a smaller model than in the previous example. Training took less than a minute on my laptop.

There is a nice explanation of CNNs here.

Depth: the depth of the input volume is the number of channels, e.g. the three RGB values associated with each pixel of a colour image (MNIST images have a single greyscale channel). The depth of the output volume is the number of filters in the convolutional layer.
Stride: the number of pixels we move the filter at each step.
Zero-padding: the width of the border of zeros added around the image.
Pooling layer: a layer included between convolutions that reduces the spatial size of the volume by taking a summary (typically the max) over each local window. A common choice is max pooling with a \(2\times 2\) filter and a stride of \(2\), which keeps the largest value in each window.
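To make stride, padding and pooling concrete, here is a small stdlib-only sketch. `conv_output_size` uses the standard output-size formula \(\lfloor (W + 2P - F)/S \rfloor + 1\) for an input of width \(W\), filter size \(F\), padding \(P\) and stride \(S\); `max_pool_2x2` is a hypothetical helper that applies \(2\times 2\) max pooling with stride \(2\) to a list-of-lists image. Neither function is part of the original code.

```python
def conv_output_size(size: int, kernel: int, padding: int = 0, stride: int = 1) -> int:
    """Output length along one spatial axis for a convolution or pooling window."""
    return (size + 2 * padding - kernel) // stride + 1


def max_pool_2x2(image):
    """Hypothetical helper: 2x2 max pooling, stride 2, on an image with even dimensions."""
    return [
        [
            max(image[r][c], image[r][c + 1], image[r + 1][c], image[r + 1][c + 1])
            for c in range(0, len(image[0]), 2)
        ]
        for r in range(0, len(image), 2)
    ]


# A 5x5 filter with padding 2 and stride 1 preserves a 28-pixel width.
print(conv_output_size(28, kernel=5, padding=2))  # 28
# A 2x2 window with stride 2 halves it.
print(conv_output_size(28, kernel=2, stride=2))   # 14

image = [
    [1, 2, 0, 1],
    [3, 4, 1, 0],
    [0, 1, 5, 6],
    [2, 1, 7, 8],
]
print(max_pool_2x2(image))  # [[4, 1], [2, 8]]
```

The floor division matches PyTorch's default behaviour for convolution and pooling layers, where any leftover pixels that do not fill a whole window are dropped.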

Loading packages

import time
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import pandas as pd
from plotnine import ggplot, aes, geom_point, geom_line

import niceneuron.io as nn_io
import niceneuron.data as nn_data

Define input and output filenames

Define the files that are used as input and output to keep the program DRY.

images_path = 'data/train-images.idx3-ubyte'
labels_path = 'data/train-labels.idx1-ubyte'
test_images_path = 'data/t10k-images.idx3-ubyte'
test_labels_path = 'data/t10k-labels.idx1-ubyte'

loss_csv = "example-2024-01-04-loss.csv"
loss_png = "example-2024-01-04-loss.png"
trained_model_file = "example-2024-01-04-model.pt"

Loading testing and training data

images = nn_io.read_idx(images_path)
images = images.astype(np.float32)
labels = nn_io.read_idx(labels_path)

train_dataset = nn_data.MNISTDataset(
    images, labels, flatten=False, normalise=True)
train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=64, shuffle=True)

test_images = nn_io.read_idx(test_images_path)
test_images = test_images.astype(np.float32)
test_labels = nn_io.read_idx(test_labels_path)
test_dataset = nn_data.MNISTDataset(
    test_images, test_labels, flatten=False, normalise=True)
test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=64, shuffle=False)
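The loading code relies on `nn_io.read_idx` from the author's niceneuron package, whose implementation is not shown here. As a rough guide to what such a reader has to do, here is a hypothetical stdlib-only sketch, `parse_idx`, for the IDX format used by the MNIST files: two zero bytes, a type code (`0x08` for unsigned bytes), the number of dimensions, one big-endian 32-bit size per dimension, and then the raw data. It is not niceneuron's code, it only handles the unsigned-byte type, and it returns a flat list rather than an array.

```python
import struct


def parse_idx(raw: bytes):
    """Hypothetical IDX parser: return (shape, flat list of values).

    Sketch only: supports the 0x08 (unsigned byte) type that MNIST uses.
    """
    zero1, zero2, dtype_code, ndim = struct.unpack(">BBBB", raw[:4])
    if zero1 != 0 or zero2 != 0:
        raise ValueError("not an IDX file: magic number must start with two zero bytes")
    if dtype_code != 0x08:
        raise ValueError(f"unsupported data type code 0x{dtype_code:02x}")

    # One big-endian unsigned 32-bit integer per dimension.
    shape = struct.unpack(f">{ndim}I", raw[4:4 + 4 * ndim])
    offset = 4 + 4 * ndim

    count = 1
    for dim in shape:
        count *= dim

    # Each value is a single unsigned byte, stored in row-major order.
    values = list(raw[offset:offset + count])
    if len(values) != count:
        raise ValueError("file is truncated")
    return shape, values
```

For example, `parse_idx(open(labels_path, "rb").read())` would return the shape `(60000,)` for the MNIST training labels together with the label values.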

Defining the model

Here we define the network as the DemoCNN class: two convolution, ReLU and average-pooling stages followed by two fully connected layers.

class DemoCNN(nn.Module):
    def __init__(self):
        super(DemoCNN, self).__init__()
        self._conv1 = nn.Conv2d(1, 6, kernel_size=5, padding=2)
        self._pool1 = nn.AvgPool2d(2)
        self._conv2 = nn.Conv2d(6, 16, kernel_size=5, padding=0)
        self._pool2 = nn.AvgPool2d(2)
        self._fc1 = nn.Linear(5*5*16, 84)
        self._fc2 = nn.Linear(84, 10)
        self._relu = nn.ReLU()

    def forward(self, x):
        x = self._pool1(self._relu(self._conv1(x)))
        x = self._pool2(self._relu(self._conv2(x)))
        x = x.view(-1, 5*5*16)
        x = self._relu(self._fc1(x))
        x = self._fc2(x)
        return x
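The constant `5*5*16` in `forward` comes from following a \(28\times 28\) image through the layers. The stdlib-only sketch below traces one spatial axis through DemoCNN and counts its parameters by hand. The helper names are hypothetical, and the layer settings are copied from the class above (`AvgPool2d(2)` uses a stride equal to its kernel size).

```python
def conv_output_size(size: int, kernel: int, padding: int = 0, stride: int = 1) -> int:
    """Output length along one spatial axis for a convolution or pooling window."""
    return (size + 2 * padding - kernel) // stride + 1


def trace_democnn(input_size: int = 28):
    """Follow one spatial axis through DemoCNN; return (sizes, flat_features, n_params)."""
    sizes = [input_size]
    sizes.append(conv_output_size(sizes[-1], kernel=5, padding=2))  # _conv1, 1 -> 6 channels
    sizes.append(conv_output_size(sizes[-1], kernel=2, stride=2))   # _pool1
    sizes.append(conv_output_size(sizes[-1], kernel=5))             # _conv2, 6 -> 16 channels
    sizes.append(conv_output_size(sizes[-1], kernel=2, stride=2))   # _pool2
    flat_features = sizes[-1] * sizes[-1] * 16                      # input width of _fc1

    n_params = (
        (1 * 6 * 5 * 5 + 6)            # _conv1: weights + biases
        + (6 * 16 * 5 * 5 + 16)        # _conv2
        + (flat_features * 84 + 84)    # _fc1
        + (84 * 10 + 10)               # _fc2
    )
    return sizes, flat_features, n_params


sizes, flat_features, n_params = trace_democnn()
print(sizes)          # [28, 28, 14, 10, 5]
print(flat_features)  # 400
print(n_params)       # 37106
```

In PyTorch the same total can be cross-checked with `sum(p.numel() for p in DemoCNN().parameters())`; the ReLU and pooling layers have no parameters. Because `5*5*16` is hard-coded, inputs of any size other than \(28\times 28\) would need a different value in `forward`.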

Train the model

Next we train the model for five epochs, using the cross-entropy loss and the Adam optimiser with a learning rate of \(10^{-3}\).

model = DemoCNN()

loss_fn = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-3)

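loss_history = []
training_start_time = time.time()
for epoch in range(5):
    model.train()
    epoch_cumloss = 0.0
    for batch_images, labels in train_dataloader:
        # Add a channel dimension: (batch, 28, 28) -> (batch, 1, 28, 28)
        batch_images = batch_images.unsqueeze(1)

        # Forward pass
        logits = model(batch_images)
        loss = loss_fn(logits, labels)
        epoch_cumloss += loss.item()

        # Backward pass: clear old gradients, backpropagate, update weights
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    # Record the summed loss for this epoch so it can be saved and plotted
    loss_history.append((epoch, epoch_cumloss))
    print(f"Epoch {epoch} loss: {epoch_cumloss}")

training_finish_time = time.time()
print(f"Training took: {training_finish_time - training_start_time:.1f} seconds")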

loss_df = pd.DataFrame(loss_history, columns=['epoch', 'loss'])
loss_df.to_csv(loss_csv, index=False)
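For intuition about what `nn.CrossEntropyLoss` and `optim.Adam` compute inside this loop, here is a stdlib-only sketch for a single example and a single scalar parameter. `cross_entropy` combines log-softmax with the negative log-likelihood of the true class, and `ScalarAdam` is a hypothetical class implementing the Adam update with PyTorch's default `betas=(0.9, 0.999)` and `eps=1e-8`. This is an illustration of the maths, not PyTorch's implementation.

```python
import math


def cross_entropy(logits, target):
    """Cross-entropy for one example from raw logits (log-softmax + negative log-likelihood)."""
    m = max(logits)  # subtract the max before exponentiating, for numerical stability
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_sum_exp - logits[target]


class ScalarAdam:
    """Hypothetical Adam optimiser for one scalar parameter."""

    def __init__(self, lr=1e-3, beta1=0.9, beta2=0.999, eps=1e-8):
        self.lr, self.beta1, self.beta2, self.eps = lr, beta1, beta2, eps
        self.m = 0.0  # running mean of gradients
        self.v = 0.0  # running mean of squared gradients
        self.t = 0    # number of steps taken

    def step(self, param, grad):
        self.t += 1
        self.m = self.beta1 * self.m + (1 - self.beta1) * grad
        self.v = self.beta2 * self.v + (1 - self.beta2) * grad * grad
        # Bias correction compensates for initialising m and v at zero.
        m_hat = self.m / (1 - self.beta1 ** self.t)
        v_hat = self.v / (1 - self.beta2 ** self.t)
        return param - self.lr * m_hat / (math.sqrt(v_hat) + self.eps)


print(cross_entropy([0.0, 0.0], 0))  # log(2), about 0.693
opt = ScalarAdam()
print(opt.step(1.0, 2.0))            # about 0.999: the first step has size close to lr
```

By default `nn.CrossEntropyLoss` averages this per-example loss over the batch, and `optim.Adam` applies the update element-wise to every parameter tensor.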

Saving the model

Rather than saving the whole nn.Module object, we save only its state dictionary (the learned parameters and buffers); this is the approach recommended in the PyTorch documentation.

torch.save(model.state_dict(), trained_model_file)

You could then reload this state later with a new instance of the model:

prev_state_dict = torch.load(trained_model_file)
new_model = DemoCNN()
new_model.load_state_dict(prev_state_dict)
new_model.eval()
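The round trip can be checked without touching the disk. This sketch uses a small `nn.Linear` as a stand-in for DemoCNN and an in-memory `io.BytesIO` buffer in place of `trained_model_file`; both substitutions are for illustration only, and the same pattern applies to the real model and file.

```python
import io

import torch
import torch.nn as nn

# Stand-in model; DemoCNN would work the same way.
original = nn.Linear(4, 2)

# Save only the state dict, into an in-memory buffer instead of a file.
buffer = io.BytesIO()
torch.save(original.state_dict(), buffer)
buffer.seek(0)  # rewind so torch.load reads from the start

# A fresh instance starts with different random weights...
restored = nn.Linear(4, 2)
# ...until the saved state is loaded into it.
restored.load_state_dict(torch.load(buffer))
restored.eval()

same = all(
    torch.equal(a, b)
    for a, b in zip(original.state_dict().values(), restored.state_dict().values())
)
print(same)  # True
```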

Test the model

Then we test the model with the following loop. Gradient tracking is turned off with torch.no_grad(): gradients are not needed for evaluation, so this saves memory and runs slightly faster.

model.eval()
correct = 0
total = 0
with torch.no_grad():
    for batch_images, labels in test_dataloader:
        batch_images = batch_images.unsqueeze(1)
        logits = model(batch_images)
        # The predicted class is the index of the largest logit
        predicted = torch.argmax(logits, dim=1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

# print test accuracy
print(f"Test accuracy: {correct / total:.4f}")
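The accuracy computation above reduces to "take the argmax of each row of logits and count the matches". Here is a plain-Python sketch of that logic; the function name `accuracy` and the made-up logits in the example are illustrative only.

```python
def accuracy(batches):
    """Fraction of examples whose largest logit is at the true label's index.

    `batches` is an iterable of (logits_batch, labels) pairs, where each
    logits_batch is a list of per-class scores for every example.
    """
    correct = 0
    total = 0
    for logits_batch, labels in batches:
        for logits, label in zip(logits_batch, labels):
            # Index of the largest logit (first index wins on ties).
            predicted = max(range(len(logits)), key=logits.__getitem__)
            correct += int(predicted == label)
            total += 1
    return correct / total


# Two made-up batches over three classes: predictions are 1, 0 and 2.
batches = [
    ([[0.1, 2.0, -1.0], [3.0, 0.0, 0.0]], [1, 2]),
    ([[0.0, 0.0, 5.0]], [2]),
]
print(accuracy(batches))  # 0.666... (2 of 3 correct)
```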

Visualise the results

Plot the resulting loss values across the epochs.


loss_df = pd.read_csv(loss_csv)
loss_p9 = (
    ggplot(loss_df, aes(x='epoch', y='loss')) +
    geom_point() +
    geom_line()
)
loss_p9.save(loss_png, height=2.9, width=4.1)

Author: Alexander E. Zarebski

Created: 2025-01-06 Mon 09:32
