AANN 22/02/2024

Early stopping

In this example we consider an approach to early stopping described by Lutz Prechelt (1998), who proposes several classes of stopping criteria. We will consider the third class because it is the simplest and there doesn't seem to be a clear winner among the different classes. The idea of this criterion is to stop the training loop once the validation error has increased \(s\) times in a row. The hyperparameter \(s\) then tunes how early we stop: a smaller \(s\) stops sooner.
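To make the rule concrete before any PyTorch is involved, here is a minimal sketch that applies it to a plain list of validation losses. The function name `should_stop` and the example losses are made up for illustration; they are not part of the code used later in this post.

```python
def should_stop(validation_losses, s):
    """Return True if the last s validation losses each exceed their predecessor."""
    if s < 1:
        raise ValueError("s must be at least 1")
    if len(validation_losses) < s + 1:
        return False  # not enough history for s consecutive increases
    recent = validation_losses[-(s + 1):]
    return all(earlier < later for earlier, later in zip(recent, recent[1:]))


losses = [0.90, 0.50, 0.40, 0.42, 0.45, 0.47]
print(should_stop(losses, s=3))  # True: 0.40 < 0.42 < 0.45 < 0.47
print(should_stop(losses, s=4))  # False: the drop from 0.50 to 0.40 breaks the run
```

Note that \(s\) consecutive increases need \(s + 1\) recorded losses, which is why the function looks at the last `s + 1` values.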

As an example, we will consider the same CNN used for MNIST classification in this previous post and provide a helper validation manager class that implements the early stopping rule. This ends up taking a bit longer to run, but it gives us a reasonable justification for stopping the training, beyond saying "we ran out of time to keep training". The resulting model gets an accuracy of \(99\%\) on the validation set and takes less than 5 minutes to train.

Loading packages

There are two packages we will use that haven't seen much action in previous posts: collections, which provides a deque, and copy, which is useful for copying model states. You could easily get away without using these, but they make some of the code a little cleaner and they are both part of the standard library.
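As a quick illustration of why these two are handy, here is a small standard-library-only sketch. The dictionary `state` is just a stand-in for a model's state, and the `maxlen` argument is one way to get a sliding window; it is not necessarily how the code later in this post does it.

```python
import collections
import copy

# A deque with maxlen behaves as a fixed-size sliding window:
# appending to a full deque silently drops the oldest entry.
window = collections.deque([float("inf")] * 3, maxlen=3)
for loss in [0.9, 0.5, 0.4, 0.45]:
    window.append(loss)
print(list(window))  # [0.5, 0.4, 0.45]

# deepcopy takes an independent snapshot, so later in-place updates
# to the original do not change the saved copy.
state = {"weights": [1.0, 2.0]}  # stand-in for a model's state
snapshot = copy.deepcopy(state)
state["weights"][0] = 99.0
print(snapshot["weights"])  # [1.0, 2.0]
```

The deep copy matters for PyTorch too: state_dict returns references to the parameter tensors, so without a copy the "best" snapshot would keep changing as the optimizer updates the weights in place.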

import collections
import copy
import time
import torch
import torch.nn as nn
import torch.optim as optim

torch.manual_seed(0)

import numpy as np
import pandas as pd
import plotnine as p9
from plotnine import *

import sys
import os
from pathlib import Path

current_directory = Path(os.getcwd())
sys.path.append(str(current_directory))

import niceneuron.plot as nn_plot
import niceneuron.io as nn_io
import niceneuron.data as nn_data

Setting data and filenames

As usual there is a bit of fuss here to get the data all set up. This can be skipped on a first reading.

images_path = "data/train-images.idx3-ubyte"
labels_path = "data/train-labels.idx1-ubyte"
test_images_path = "data/t10k-images.idx3-ubyte"
test_labels_path = "data/t10k-labels.idx1-ubyte"
loss_csv = "example-2024-02-22/loss.csv"
training_png = "example-2024-02-22/training.png"


images = nn_io.read_idx(images_path)
images = images.astype(np.float32)
labels = nn_io.read_idx(labels_path)

train_dataset = nn_data.MNISTDataset(images, labels, flatten=False, normalise=True)
train_dataloader = torch.utils.data.DataLoader(
    train_dataset, batch_size=64, shuffle=True
)

test_images = nn_io.read_idx(test_images_path)
test_images = test_images.astype(np.float32)
test_labels = nn_io.read_idx(test_labels_path)
test_dataset = nn_data.MNISTDataset(
    test_images, test_labels, flatten=False, normalise=True
)
test_dataloader = torch.utils.data.DataLoader(
    test_dataset, batch_size=64, shuffle=False
)

Defining the network

This is the simple convolutional neural network we have used before.

class DemoCNN(nn.Module):
    def __init__(self):
        super(DemoCNN, self).__init__()
        self._conv1 = nn.Conv2d(1, 6, kernel_size=5, padding=2)
        self._pool1 = nn.AvgPool2d(2)
        self._conv2 = nn.Conv2d(6, 16, kernel_size=5, padding=0)
        self._pool2 = nn.AvgPool2d(2)
        self._fc1 = nn.Linear(5 * 5 * 16, 84)
        self._fc2 = nn.Linear(84, 10)
        self._relu = nn.ReLU()

    def forward(self, x):
        x = self._pool1(self._relu(self._conv1(x)))
        x = self._pool2(self._relu(self._conv2(x)))
        x = x.view(-1, 5 * 5 * 16)
        x = self._relu(self._fc1(x))
        x = self._fc2(x)
        return x

Early stopping

The idea behind this method of early stopping is simple: keep training until the validation error has increased \(s\) times in a row. To simplify this, I have written a class ValidationManager which takes care of the bookkeeping.
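Stated a little more formally, and writing \(E_{\text{va}}(t)\) for the validation error after epoch \(t\) (this notation is an assumption made here for illustration), training stops at the first epoch \(t\) for which

\[
E_{\text{va}}(t - i) > E_{\text{va}}(t - i - 1) \quad \text{for all } i = 0, 1, \dots, s - 1.
\]

For example, with \(s = 2\) and validation errors \(0.40, 0.42, 0.45\) over three successive epochs, the condition first holds at the third epoch, since \(0.45 > 0.42\) and \(0.42 > 0.40\).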

ValidationManager

This could be more general, but it serves as a nice starting point for creating these helper objects. The main methods to note are validate, which runs the validation loop and records how the training is going, and the should_stop_early predicate, which checks whether the stopping criterion has been met.

class ValidationManager:
    """
    A class for managing early stopping during training.
    """

    def __init__(self, model, loss_fn, patience=3):
        self._model = model
        self._loss_fn = loss_fn
        self._patience = patience
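        # Hold patience + 1 losses so that `patience` consecutive increases
        # can be detected. Starting from infinity means the rule cannot
        # trigger until the window is full of real losses.
        self._loss_history = collections.deque([float("inf")] * (patience + 1))
        self._best_loss = float("inf")
        self._best_model = None
        self._validation_history = []

    def validate(self, epoch, dataloader):
        # Sum of the per-batch mean losses over the validation set.
        cum_loss = 0
        with torch.no_grad():
            for images, label in dataloader:
                images = images.unsqueeze(1)
                logits = self._model(images)
                loss = self._loss_fn(logits, label)
                cum_loss += loss.item()

        # Snapshot the weights whenever the validation loss improves.
        if cum_loss < self._best_loss:
            self._best_loss = cum_loss
            self._best_model = copy.deepcopy(self._model.state_dict())

        self._validation_history.append((epoch, cum_loss))
        # Slide the window forward by one epoch.
        self._loss_history.popleft()
        self._loss_history.append(cum_loss)

    def should_stop_early(self):
        # True when each loss in the window exceeds the one before it.
        return all(
            self._loss_history[i] < self._loss_history[i + 1]
            for i in range(self._patience)
        )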

    def get_validation_history(self):
        return self._validation_history

    def get_best_model(self):
        return self._best_model

Training loop (with early stopping)

Here we have the same training loop as used in the past, but now slightly cleaner thanks to the validation manager object.

model = DemoCNN()
loss_fn = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-3)
validation_manager = ValidationManager(model, loss_fn, patience=4)

model.train()
loss_history = []
training_start_time = time.time()
for epoch in range(30):
    epoch_cumloss = 0
    for images, label in train_dataloader:
        images = images.unsqueeze(1)
        logits = model(images)
        loss = loss_fn(logits, label)
        epoch_cumloss += loss.item()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    print(f"Epoch {epoch} loss: {epoch_cumloss}")
    loss_history.append((epoch, epoch_cumloss))
    validation_manager.validate(epoch, test_dataloader)
    if validation_manager.should_stop_early():
        break

training_finish_time = time.time()
print(f"Training took: {training_finish_time - training_start_time}")

Visualise the training

training.png

Figure 1: Training and validation losses during the training loop.

loss_df = pd.DataFrame(loss_history, columns=["epoch", "loss"])
loss_df["model"] = "training"
validation_df = pd.DataFrame(
    validation_manager.get_validation_history(), columns=["epoch", "loss"]
)
validation_df["model"] = "validation"

plot_df = pd.concat([loss_df, validation_df], axis=0)
plot_df.to_csv(loss_csv, index=False)

training_p9 = nn_plot.plot_loss_curve(plot_df) + p9.scale_y_log10()
training_p9.save(training_png, width=10, height=10, dpi=300)

Testing the model

We end up with \(99\%\) accuracy, as expected 🥳

best_model = validation_manager.get_best_model()
model.load_state_dict(best_model)
model.eval()
correct = 0
total = 0
with torch.no_grad():
    for images, label in test_dataloader:
        images = images.unsqueeze(1)
        logits = model(images)
        predicted = torch.argmax(logits, dim=1)
        total += label.size(0)
        correct += (predicted == label).sum().item()

print(f"Test accuracy: {correct / total}")

Discussion

Early stopping provides an uncontroversial stopping criterion. Empirical results from Lutz Prechelt (1998) suggest there isn't a universally best criterion, and the one used here is intuitive and simple, which makes it a nice choice.
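For comparison, here is a rough sketch of how I understand the first class of criteria in that paper, which stops once the "generalisation loss" (the percentage by which the current validation error exceeds the best seen so far) passes a threshold. The function names, the example losses and the threshold `alpha` are illustrative assumptions, and the sketch assumes strictly positive losses.

```python
def generalisation_loss(validation_losses):
    """Percentage by which the latest validation loss exceeds the best so far."""
    best = min(validation_losses)
    return 100.0 * (validation_losses[-1] / best - 1.0)


def should_stop_gl(validation_losses, alpha):
    """Stop once the generalisation loss exceeds the threshold alpha."""
    return generalisation_loss(validation_losses) > alpha


losses = [0.90, 0.50, 0.40, 0.42, 0.45]
print(generalisation_loss(losses))         # roughly 12.5
print(should_stop_gl(losses, alpha=5.0))   # True
print(should_stop_gl(losses, alpha=20.0))  # False
```

Unlike the consecutive-increases rule, this one reacts to the size of the increase rather than how many epochs it lasts, so the two can disagree on noisy validation curves.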

The ValidationManager class (see here) may be over-doing it a little, but it does keep a lot of the details out of the training loop, which is nice. The class could be made a bit more general, but this is just a proof-of-concept of the value of a validation manager.
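One possible direction for generalising it is to separate the stopping bookkeeping from the model and the validation loop entirely. The sketch below is hypothetical: the class `EarlyStopper` and its methods are invented for illustration, and it accepts any loss value together with any object that can be deep-copied as the "state". In this sketch the window holds patience + 1 losses so that `patience` consecutive increases can be observed.

```python
import collections
import copy


class EarlyStopper:
    """Hypothetical, framework-free variant of a validation manager."""

    def __init__(self, patience=3):
        self._patience = patience
        # patience + 1 losses are needed to observe `patience` increases.
        self._window = collections.deque(
            [float("inf")] * (patience + 1), maxlen=patience + 1
        )
        self.best_loss = float("inf")
        self.best_state = None

    def update(self, loss, state):
        """Record a validation loss; snapshot `state` if it is the best so far."""
        if loss < self.best_loss:
            self.best_loss = loss
            self.best_state = copy.deepcopy(state)
        self._window.append(loss)

    def should_stop(self):
        """True when every loss in the window exceeds the one before it."""
        values = list(self._window)
        return all(a < b for a, b in zip(values, values[1:]))


stopper = EarlyStopper(patience=2)
state = {"epoch": None}  # stand-in for model.state_dict()
for epoch, loss in enumerate([0.9, 0.5, 0.4, 0.42, 0.45, 0.3]):
    state["epoch"] = epoch
    stopper.update(loss, state)
    if stopper.should_stop():
        break

print(epoch, stopper.best_loss, stopper.best_state)  # 4 0.4 {'epoch': 2}
```

With this split, the training loop computes the validation loss however it likes and the stopper only decides when to stop and which snapshot to keep.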

Author: Alexander E. Zarebski

Created: 2024-02-21 Wed 11:49
