Save Parameter Weights While Fitting a Model

September 12, 2024

QMNIST

The MNIST dataset bundled with machine-learning packages has only 10,000 samples in its test set. The full MNIST test set described in the early report had 60,000 samples, but it was not included in early versions of these packages*1.

QMNIST is a reconstruction of the MNIST dataset based on the original data found in NIST Special Database 19.

This post uses the QMNIST test set, which has 60,000 samples, as an exercise to show how to save weights while fitting a model and how to load the weights back to continue the fitting process.

Different format

The label files in MNIST are one-dimensional tensors of unsigned bytes (idx1-ubyte), whereas QMNIST stores its labels as two-dimensional tensors of integers (idx2-int). Each row holds eight pieces of information about one digit. We will use only column 0.

Column  Description                   Range
0       Character class               0 to 9
1       NIST HSF series               0, 1, or 4
2       NIST writer ID                0-610 and 2100-2599
3       Digit index for this writer   0 to 149
4       NIST class code               30-39
5       Global NIST digit index       0 to 281769
6       Duplicate                     0
7       Unused                        0
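
To make the format difference concrete, here is a minimal sketch that builds a tiny idx2-int buffer in memory and parses it with the standard library only. The helper name parse_idx2_int and the two label rows are made up for illustration. The magic number 0x0C02 (data type 0x0C, a 4-byte integer, with 2 dimensions) is the same value that the reader in Appendix 2 asserts.

```python
import struct


def parse_idx2_int(data):
    """Parse an in-memory idx2-int buffer into a list of rows (hypothetical helper)."""
    # Header: magic number, number of rows, number of columns (big-endian uint32).
    magic, length, width = struct.unpack(">III", data[:12])
    # 0x0C = 32-bit integer data type, 0x02 = two dimensions.
    assert magic == 0x0C02, "not an idx2-int buffer"
    count = length * width
    values = struct.unpack(">{}i".format(count), data[12:12 + 4 * count])
    return [list(values[r * width:(r + 1) * width]) for r in range(length)]


# Two made-up label rows with the eight QMNIST columns.
rows = [
    [7, 4, 2100, 0, 37, 1000, 0, 0],
    [2, 4, 2100, 1, 32, 1001, 0, 0],
]
header = struct.pack(">III", 0x0C02, len(rows), 8)
body = b"".join(struct.pack(">8i", *row) for row in rows)

labels = parse_idx2_int(header + body)
digit_classes = [row[0] for row in labels]  # column 0: character class
print(digit_classes)
```

An MNIST idx1-ubyte label file would instead start with the magic number 0x0801 and carry one unsigned byte per label, so no column selection is needed there.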

To import the QMNIST dataset, we use the Python scripts provided with the QMNIST dataset. This approach saves us the time it takes to work out the original NIST image formats[*1].

The first 50 digits in the MNIST and QMNIST test sets appear identical. The last 50 digits in QMNIST do not appear in MNIST.

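Figure (MNIST): first 50 digits of the test set

Figure (QMNIST): first 50 digits of the test set

Figure (MNIST): last 50 digits of the test set

Figure (QMNIST): last 50 digits of the test set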

A Very Deep CNN Model

ResNet is one of the models that scored a very low error rate in the ILSVRC 2015 challenge. Its authors introduced residual (shortcut) connections, which make very deep CNNs easier to train. We are going to use one of the publicly available ResNet implementations to fit QMNIST.

Here, we use the ResNet model implemented in an ensemble of three models: the "small VGG", the ResNet, and the "very small VGG". The ensemble scored the lowest error rate, 0.20%, in the MNIST competition listed in "How far can we go with MNIST?". The following is a step-by-step guide.

1. Load the Package and Function

We update one line, from keras.layers import PReLU, because of changes in recent versions of Keras. [See Appendix 2 for the script.]

2. Build the ResNet

Then, we use the model-building script from the MNIST competition to build the model.

3. Load QMNIST

Next, we use the script provided by the QMNIST to load the data.

We split the training set into a training set and a validation set and keep the test set as it is. We use train_test_split from sklearn.model_selection to split the data.

train, valid, test = load_qmnist()

4. Set up a Callback to Save the Weights

This is a rather deep model, so we set up a callback to save the weights at the end of every epoch. As usual, we also save the fitting history to check the evolution of the loss and accuracy.

import os
import pandas as pd
import pickle as pkl


EPOCH = 3
# Python's open() does not expand "~", so resolve the home directory explicitly.
MODEL_PATH = os.path.expanduser("~/resnet.weights.h5")

X = Input(shape=[28, 28, 1])
y = resnet(X)
model = Model(X, y, name="resnet")

model.compile(
    loss="sparse_categorical_crossentropy", 
    optimizer="nadam", metrics=["accuracy"])

# Overwrite the same weights file at the end of every epoch
cp_callback = tf.keras.callbacks.ModelCheckpoint(
        filepath=MODEL_PATH, save_weights_only=True, verbose=1)

history1 = model.fit(
        train[0], train[1],
        epochs=EPOCH,
        validation_data=(valid[0], valid[1]),
        callbacks=[cp_callback])


# Only save the history of loss and accuracy (a dict of per-epoch lists)
SAVE_PATH = os.path.expanduser("~/history1.pkl")
with open(SAVE_PATH, "wb") as f:
  pkl.dump(history1.history, f)
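
Saving only the per-epoch numbers keeps the pickle file small and independent of the model object. The sketch below pickles a plain dictionary shaped like the history attribute of the object returned by model.fit and reads it back. The numbers are placeholders rather than results from this post, and a temporary directory stands in for the home directory.

```python
import os
import pickle as pkl
import tempfile

# Placeholder curves shaped like history1.history (one value per epoch).
history_dict = {
    "loss": [0.30, 0.12, 0.08],
    "accuracy": [0.91, 0.96, 0.97],
    "val_loss": [0.15, 0.10, 0.09],
    "val_accuracy": [0.95, 0.97, 0.97],
}

with tempfile.TemporaryDirectory() as tmp_dir:
    save_path = os.path.join(tmp_dir, "history1.pkl")
    # Write the dictionary, then read it back to confirm the round trip.
    with open(save_path, "wb") as f:
        pkl.dump(history_dict, f)
    with open(save_path, "rb") as f:
        restored = pkl.load(f)

print(sorted(restored.keys()))
print(len(restored["loss"]), "epochs restored")
```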

5. Resume the Model Fitting

Suppose that we stop the fitting at the end of the third epoch and want to continue later. Because we saved only the weights, we need to rebuild the model first and then load the weights into it.

input_shape = train[0][0].shape
X = Input(shape=input_shape)
y = resnet(X)
model = Model(X, y, name="resnet")

model.compile(
    loss="sparse_categorical_crossentropy",
    optimizer="nadam", metrics=["accuracy"])

cp_callback = tf.keras.callbacks.ModelCheckpoint(
        filepath=MODEL_PATH, save_weights_only=True, verbose=1)

model.load_weights(MODEL_PATH)

We can evaluate the model with its weights adjusted for only 3 epochs.

# Evaluate the model
loss, acc = model.evaluate(test[0], test[1], verbose=2)
print("Restored model, accuracy: {:5.2f}%".format(100 * acc))
# 1875/1875 - 15s - 8ms/step - accuracy: 0.9802 - loss: 0.0672

Then we resume the model fitting for more epochs.

EPOCH = 27

history2 = model.fit(
        train[0], train[1],
        epochs=EPOCH,
        validation_data=(valid[0], valid[1]),
        callbacks=[cp_callback])
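
Because the fitting ran in two sessions, the loss and accuracy curves come in two pieces. Here is a minimal sketch of stitching them together for plotting, assuming both pieces are dictionaries of per-epoch lists such as history1.history and history2.history. The helper merge_histories is hypothetical and the values are placeholders.

```python
def merge_histories(first, second):
    """Concatenate per-epoch metric lists from two fitting sessions."""
    merged = {}
    for key in first:
        # Keep only the metrics recorded in both sessions.
        if key in second:
            merged[key] = list(first[key]) + list(second[key])
    return merged


h1 = {"loss": [0.30, 0.12, 0.08], "val_loss": [0.15, 0.10, 0.09]}
h2 = {"loss": [0.06, 0.05], "val_loss": [0.08, 0.08]}

full = merge_histories(h1, h2)
# Number the epochs continuously across both sessions.
epochs = list(range(1, len(full["loss"]) + 1))
print(epochs)
print(full["loss"])
```

With the 3 and 27 epochs used in this post, the merged lists would hold 30 entries each.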

Result

qmnist_ResNet

The result looks impressive. When we test the model on the 60,000-sample test set, the model reports an accuracy of 99.16%, that is, an error rate of about 0.84%. Recall that the ensemble of the three models reached 0.20% on the 10,000-sample MNIST test set, and the simple CNN model we tested before resulted in a 3% error rate.

The following figure shows 50 of the digits for which the model's prediction differs from the given label. There are 506 such digits in total.

qmnist_error
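
The misclassified digits can be gathered by comparing the most probable class with the label. The sketch below uses plain Python lists and made-up probabilities for three classes. With the real model, the probabilities would presumably come from model.predict(test[0]). The last lines redo the arithmetic behind the accuracy and error rate quoted above.

```python
def misclassified_indices(probabilities, labels):
    """Return the indices where the most probable class differs from the label."""
    wrong = []
    for i, (probs, label) in enumerate(zip(probabilities, labels)):
        # Index of the largest probability, i.e. the predicted class.
        predicted = max(range(len(probs)), key=probs.__getitem__)
        if predicted != label:
            wrong.append(i)
    return wrong


# Made-up softmax outputs for four samples and three classes.
probabilities = [
    [0.9, 0.05, 0.05],
    [0.2, 0.7, 0.1],
    [0.6, 0.1, 0.3],
    [0.1, 0.1, 0.8],
]
labels = [0, 1, 2, 2]
wrong = misclassified_indices(probabilities, labels)
print(wrong)  # [2]

# 506 misclassified digits out of 60,000 test samples.
n_errors, n_test = 506, 60000
error_rate = 100 * n_errors / n_test
print("error rate: {:.2f}%, accuracy: {:.2f}%".format(error_rate, 100 - error_rate))
```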

We can plot a confusion matrix to get a rough idea of which digits were misclassified most often. As it turns out, 3 was classified as 2 (27 times) and 4 was classified as 9 (42 times). 3 was also classified as 5 (23 times).

confusion_matrix
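
A confusion matrix is a table of counts indexed by (true label, predicted label). Here is a minimal sketch with the standard library, using a handful of made-up label pairs instead of the 60,000 test predictions. The helpers count_confusions and top_confusions are hypothetical names written for this illustration.

```python
from collections import Counter


def count_confusions(y_true, y_pred, n_classes=10):
    """Build an n_classes x n_classes table: rows are true labels, columns are predictions."""
    matrix = [[0] * n_classes for _ in range(n_classes)]
    for t, p in zip(y_true, y_pred):
        matrix[t][p] += 1
    return matrix


def top_confusions(matrix, k=3):
    """Return the k most frequent off-diagonal entries as (true, predicted, count)."""
    pairs = Counter()
    for t, row in enumerate(matrix):
        for p, count in enumerate(row):
            if t != p and count > 0:
                pairs[(t, p)] = count
    return [(t, p, c) for (t, p), c in pairs.most_common(k)]


# Made-up labels and predictions for eight samples.
y_true = [4, 4, 4, 3, 3, 3, 9, 2]
y_pred = [9, 9, 4, 2, 5, 3, 9, 2]

matrix = count_confusions(y_true, y_pred)
print(top_confusions(matrix, k=1))
```

In this toy data, the most frequent confusion is a 4 predicted as 9, the same pair that dominates the errors reported above.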

In summary, achieving very high accuracy can require significant effort. When the benefits outweigh the costs, such as in the hypothetical case of processing financial transactions, using an ensemble model might be worthwhile. If its 0.20% error rate carried over to the 60,000-sample test set, the ensemble would reduce the number of misidentified digits to about 120 out of 60,000. A closer look at the error pattern points to the underlying difficulty: each handwritten digit is unique.

Figures: 49, 32, 35

Footnote

*1. QMNIST is also included in recent versions of torchvision, the PyTorch vision package.

Appendix

1. The Script for Reading QMNIST and MNIST

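import os

import matplotlib.pyplot as plt
import numpy as np
from sklearn.model_selection import train_test_split

from utils import mnist_reader
from utils.helper import read_idx3_ubyte
from utils.helper import read_idx2_int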


def load_qmnist(path="../data/QMNIST/", validation_ratio=0.2, random_seed=42):

  X_train = read_idx3_ubyte(path + "qmnist-train-images-idx3-ubyte.gz")
  Y_train_all = read_idx2_int(path + "qmnist-train-labels-idx2-int.gz")
  X_test = read_idx3_ubyte(path + "qmnist-test-images-idx3-ubyte.gz")
  Y_test_all = read_idx2_int(path + "qmnist-test-labels-idx2-int.gz")

  y_test = np.array([i[0] for i in Y_test_all])
  y_train = np.array([i[0] for i in Y_train_all])

  train_images = X_train.reshape(-1, 28, 28, 1) / 255.0
  test_images = X_test.reshape(-1, 28, 28, 1) / 255.0

  X_train_split, X_val_split, y_train_split, y_val_split = train_test_split(train_images, y_train, test_size=validation_ratio, random_state=random_seed)
  return (X_train_split, y_train_split), (X_val_split, y_val_split), (test_images, y_test)


PROJECT_ROOT_DIR = "."
IMAGES_PATH = os.path.join(PROJECT_ROOT_DIR, "images")
os.makedirs(IMAGES_PATH, exist_ok=True)


def save_fig(fig_id, tight_layout=True, fig_extension="png", resolution=300):
    path = os.path.join(IMAGES_PATH, fig_id + "." + fig_extension)
    print("Saving figure", fig_id)
    if tight_layout:
        plt.tight_layout()
    plt.savefig(path, format=fig_extension, dpi=resolution)


def plot_examples(Xdata, Ydata, fig_id):
    n_rows = 5
    n_cols = 10
    plt.figure(figsize=(n_cols * 2, n_rows * 2))
    for row in range(n_rows):
        for col in range(n_cols):
            index = n_cols * row + col
        
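            # convert2image_format is not defined in this post; it is assumed to
            # return a 2-D (28 x 28) array that plt.imshow can display.
            image = convert2image_format(Xdata[index,:])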
            label_data = Ydata[index]
            combined_label = f"{label_data}"

            plt.subplot(n_rows, n_cols, index + 1)
            plt.imshow(image, cmap="binary", interpolation="nearest")
            plt.axis('off')
            plt.title(combined_label, fontsize=20)
        
    plt.subplots_adjust(wspace=0.2, hspace=0.5)
    plt.tight_layout()
    save_fig(fig_id, tight_layout=False)



train, valid, test = load_qmnist(path='../data/QMNIST/')
# Count the samples per digit class in the test labels (test[1], not the images in test[0])
np.unique(test[1], return_counts=True)
# (array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9]),
#  array([5952, 6791, 6026, 6084, 5780, 5454, 5957, 6231, 5890, 5835]))

DATA_DIR = "../data/MNIST/"
X0_test, Y0_test = mnist_reader.load_mnist(path=DATA_DIR, kind="t10k")
X0_train, Y0_train = mnist_reader.load_mnist(path=DATA_DIR, kind="train")


plot_examples(train[0], train[1], 'first_50_qmnist_train_digits')
plot_examples(test[0], test[1], 'first_50_qmnist_test_digits')
plot_examples(X0_train, Y0_train, 'first_50_mnist_train_digits')
plot_examples(X0_test, Y0_test, 'first_50_mnist_test_digits')

2. The Python Script for Fitting the ResNet

import pickle as pkl
import os
import numpy as np
from keras.layers import (
    Conv2D,
    BatchNormalization,
    MaxPooling2D,
    ZeroPadding2D,
    AveragePooling2D,
    add,
    Dense,
    Flatten,
)

from keras.layers import PReLU 
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras.layers import Input
from tensorflow.keras.models import Model

# ResNet
def resnet(input_tensor):
    """Inference function for ResNet

    y = resnet(X)

    Parameters
    ----------
    input_tensor : keras.layers.Input

    Returns
    ----------
    y : softmax output
    """

    def name_builder(type, stage, block, name):
        return "{}{}{}_branch{}".format(type, stage, block, name)

    def identity_block(input_tensor, kernel_size, filters, stage, block):
        F1, F2, F3 = filters

        def name_fn(type, name):
            return name_builder(type, stage, block, name)

        x = Conv2D(F1, (1, 1), name=name_fn("res", "2a"))(input_tensor)
        x = BatchNormalization(name=name_fn("bn", "2a"))(x)
        x = PReLU()(x)

        x = Conv2D(F2, kernel_size, padding="same", name=name_fn("res", "2b"))(x)
        x = BatchNormalization(name=name_fn("bn", "2b"))(x)
        x = PReLU()(x)

        x = Conv2D(F3, (1, 1), name=name_fn("res", "2c"))(x)
        x = BatchNormalization(name=name_fn("bn", "2c"))(x)
        x = PReLU()(x)

        x = add([x, input_tensor])
        x = PReLU()(x)

        return x

    def conv_block(input_tensor, kernel_size, filters, stage, block, strides=(2, 2)):
        def name_fn(type, name):
            return name_builder(type, stage, block, name)

        F1, F2, F3 = filters

        x = Conv2D(F1, (1, 1), strides=strides, name=name_fn("res", "2a"))(input_tensor)
        x = BatchNormalization(name=name_fn("bn", "2a"))(x)
        x = PReLU()(x)

        x = Conv2D(F2, kernel_size, padding="same", name=name_fn("res", "2b"))(x)
        x = BatchNormalization(name=name_fn("bn", "2b"))(x)
        x = PReLU()(x)

        x = Conv2D(F3, (1, 1), name=name_fn("res", "2c"))(x)
        x = BatchNormalization(name=name_fn("bn", "2c"))(x)

        sc = Conv2D(F3, (1, 1), strides=strides, name=name_fn("res", "1"))(input_tensor)
        sc = BatchNormalization(name=name_fn("bn", "1"))(sc)

        x = add([x, sc])
        x = PReLU()(x)

        return x

    net = ZeroPadding2D((3, 3))(input_tensor)
    net = Conv2D(64, (7, 7), strides=(2, 2), name="conv1")(net)
    net = BatchNormalization(name="bn_conv1")(net)
    net = PReLU()(net)
    net = MaxPooling2D((3, 3), strides=(2, 2))(net)

    net = conv_block(net, 3, [64, 64, 256], stage=2, block="a", strides=(1, 1))
    net = identity_block(net, 3, [64, 64, 256], stage=2, block="b")
    net = identity_block(net, 3, [64, 64, 256], stage=2, block="c")

    net = conv_block(net, 3, [128, 128, 512], stage=3, block="a")
    net = identity_block(net, 3, [128, 128, 512], stage=3, block="b")
    net = identity_block(net, 3, [128, 128, 512], stage=3, block="c")
    net = identity_block(net, 3, [128, 128, 512], stage=3, block="d")

    net = conv_block(net, 3, [256, 256, 1024], stage=4, block="a")
    net = identity_block(net, 3, [256, 256, 1024], stage=4, block="b")
    net = identity_block(net, 3, [256, 256, 1024], stage=4, block="c")
    net = identity_block(net, 3, [256, 256, 1024], stage=4, block="d")
    net = identity_block(net, 3, [256, 256, 1024], stage=4, block="e")
    net = identity_block(net, 3, [256, 256, 1024], stage=4, block="f")
    net = AveragePooling2D((2, 2))(net)

    net = Flatten()(net)
    net = Dense(10, activation="softmax", name="softmax")(net)

    return net

# Code for Loading QMNIST
import codecs
import gzip
import lzma
import numpy as np
from sklearn.model_selection import train_test_split


def get_int(b):
    return int(codecs.encode(b, "hex"), 16)


def open_maybe_compressed_file(path):
    if path.endswith(".gz"):
        return gzip.open(path, "rb")
    elif path.endswith(".xz"):
        return lzma.open(path, "rb")
    else:
        return open(path, "rb")


def read_idx2_int(path):
    with open_maybe_compressed_file(path) as f:
        data = f.read()
        assert get_int(data[:4]) == 12 * 256 + 2
        length = get_int(data[4:8])
        width = get_int(data[8:12])
        parsed = np.frombuffer(data, dtype=np.dtype(">i4"), offset=12)
        return parsed.astype("i4").reshape(length, width)


def read_idx3_ubyte(path):
    with open_maybe_compressed_file(path) as f:
        data = f.read()
        assert get_int(data[:4]) == 8 * 256 + 3
        length = get_int(data[4:8])
        num_rows = get_int(data[8:12])
        num_cols = get_int(data[12:16])
        parsed = np.frombuffer(data, dtype=np.uint8, offset=16)
        return parsed.reshape(length, num_rows, num_cols)


def load_qmnist(path="../data/QMNIST/", validation_ratio=0.2, random_seed=42):

  X_train = read_idx3_ubyte(path + "qmnist-train-images-idx3-ubyte.gz")
  Y_train_all = read_idx2_int(path + "qmnist-train-labels-idx2-int.gz")
  X_test = read_idx3_ubyte(path + "qmnist-test-images-idx3-ubyte.gz")
  Y_test_all = read_idx2_int(path + "qmnist-test-labels-idx2-int.gz")

  y_test = np.array([i[0] for i in Y_test_all])
  y_train = np.array([i[0] for i in Y_train_all])

  train_images = X_train.reshape(-1, 28, 28, 1) / 255.0
  test_images = X_test.reshape(-1, 28, 28, 1) / 255.0

  X_train_split, X_val_split, y_train_split, y_val_split = train_test_split(train_images, y_train, test_size=validation_ratio, random_state=random_seed)
  return (X_train_split, y_train_split), (X_val_split, y_val_split), (test_images, y_test)