Copyright 2025, MASSACHUSETTS INSTITUTE OF TECHNOLOGY
Subject to FAR 52.227-11 – Patent Rights – Ownership by the Contractor (May 2014).
SPDX-License-Identifier: MIT

Wrap an Image Classification Dataset

In this how-to, we show how to wrap the CIFAR-10 dataset as a maite.protocols.image_classification.Dataset. CIFAR-10 is an image classification dataset created by Alex Krizhevsky, Vinod Nair, and Geoffrey Hinton. It is available through both the Hugging Face Hub and torchvision.datasets; this notebook uses the Hugging Face Hub.

The general steps for wrapping a dataset are:

  • Understand the source (native) dataset.

  • Create a wrapper that makes the source dataset conform to the MAITE dataset protocol (interface).

  • Verify that the wrapped dataset works correctly.

1 Load the CIFAR-10 dataset from Hugging Face Hub

Load the necessary modules:

import PIL.Image
import datasets
import numpy as np
import numpy.typing as npt
from IPython.display import display

from typing import Literal

import maite.protocols.image_classification as ic
from maite.protocols import DatasetMetadata

%load_ext watermark
%watermark -iv -v
Python implementation: CPython
Python version       : 3.9.21
IPython version      : 8.18.1

maite   : 0.7.3
IPython : 8.18.1
PIL     : 11.1.0
numpy   : 1.26.4
datasets: 3.3.1

Load the CIFAR-10 dataset from Hugging Face Hub:

cifar10_dataset_dict = datasets.load_dataset("uoft-cs/cifar10")
cifar10_dataset_dict
DatasetDict({
    train: Dataset({
        features: ['img', 'label'],
        num_rows: 50000
    })
    test: Dataset({
        features: ['img', 'label'],
        num_rows: 10000
    })
})

The cifar10_dataset_dict variable is a Hugging Face datasets.DatasetDict with two splits: a 50,000-image train split and a 10,000-image test split. We extract each split as a datasets.Dataset:

cifar10_train: datasets.Dataset = cifar10_dataset_dict["train"] # type: ignore
cifar10_test: datasets.Dataset = cifar10_dataset_dict["test"] # type: ignore

2 Examine the source dataset

In this section we examine the dataset and confirm we understand it before wrapping it.

Print all the labels (class names) in the dataset:

for label_number, label_name in enumerate(cifar10_train.features["label"].names):
    print(f"Label {label_number}: {label_name}")
Label 0: airplane
Label 1: automobile
Label 2: bird
Label 3: cat
Label 4: deer
Label 5: dog
Label 6: frog
Label 7: horse
Label 8: ship
Label 9: truck

Spot check the items in the dataset:

for i in range(0, len(cifar10_train), 10000):
    item = cifar10_train[i]
    label, img = item["label"], item["img"]
    label_name = cifar10_train.features["label"].names[label]
    print(f"CIFAR-10 Train {i}")
    print(f"Label: {label} {label_name}")
    display(img)
CIFAR-10 Train 0
Label: 0 airplane
[displayed image]
CIFAR-10 Train 10000
Label: 8 ship
[displayed image]
CIFAR-10 Train 20000
Label: 1 automobile
[displayed image]
CIFAR-10 Train 30000
Label: 1 automobile
[displayed image]
CIFAR-10 Train 40000
Label: 6 frog
[displayed image]

The sampled items look as expected: the images are 32x32 RGB, and each label matches the content of its image.
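The visual spot check can be complemented by a programmatic one. The sketch below uses a hypothetical helper, `hwc_image_problems`, that reports whether an array has the (height, width, channel) uint8 RGB layout we expect after converting a CIFAR-10 PIL image with `np.array`. It runs on synthetic arrays so nothing needs to be downloaded; with the real data you would pass `np.array(cifar10_train[i]["img"])`.

```python
import numpy as np


def hwc_image_problems(img_hwc: np.ndarray, height: int = 32, width: int = 32) -> list:
    """Return a list of problems; an empty list means the array looks like a CIFAR-10 image.

    Hypothetical helper. With the real data, img_hwc would be np.array(cifar10_train[i]["img"]).
    """
    problems = []
    if img_hwc.shape != (height, width, 3):
        problems.append(f"expected shape {(height, width, 3)}, got {img_hwc.shape}")
    if img_hwc.dtype != np.uint8:
        problems.append(f"expected dtype uint8, got {img_hwc.dtype}")
    return problems


# Synthetic stand-ins for decoded images, so nothing needs to be downloaded
good = np.zeros((32, 32, 3), dtype=np.uint8)
bad = np.zeros((3, 32, 32), dtype=np.float32)  # wrong axis order and wrong dtype
print(hwc_image_problems(good))  # → []
print(hwc_image_problems(bad))
```

Returning a list of problems, rather than raising on the first one, lets a loop over many indices report every mismatch at once.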

3 Create a MAITE wrapper for the source dataset

We create a class that implements the maite.protocols.image_classification.Dataset protocol.

A Dataset compliant with the MAITE image classification protocol must implement the following two methods:

  • __getitem__(index: int): Returns a tuple containing the image ArrayLike, the target label, and the datum metadata.

  • __len__(): Returns the number of data elements in the dataset.

and an attribute:

  • metadata: a dictionary containing an “id” string and an optional (but recommended) “index2label” map from class indices to label names.
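To make these requirements concrete, here is a structural sketch written with Python's `typing.Protocol`. The name `ImageClassificationDatasetSketch` and the toy class `TinyDataset` are hypothetical; this is not MAITE's actual definition of maite.protocols.image_classification.Dataset, only a mirror of the two methods and one attribute listed above.

```python
from typing import Any, Dict, Protocol, Tuple, runtime_checkable

import numpy as np


@runtime_checkable
class ImageClassificationDatasetSketch(Protocol):
    """Hypothetical structural stand-in for the requirements listed above.

    Not MAITE's real protocol; it only mirrors __getitem__, __len__, and metadata.
    """

    metadata: Dict[str, Any]

    def __getitem__(self, index: int) -> Tuple[np.ndarray, np.ndarray, Dict[str, Any]]: ...

    def __len__(self) -> int: ...


class TinyDataset:
    """A two-item toy dataset that satisfies the sketch structurally (no inheritance needed)."""

    def __init__(self) -> None:
        self.metadata = {"id": "tiny", "index2label": {0: "a", 1: "b"}}

    def __getitem__(self, index: int):
        image = np.zeros((3, 4, 4), dtype=np.uint8)  # (C, H, W)
        target = np.eye(2)[index % 2]  # one-hot over 2 classes
        return image, target, {"id": index}

    def __len__(self) -> int:
        return 2


print(isinstance(TinyDataset(), ImageClassificationDatasetSketch))  # → True
```

Because protocols are structural, a wrapper does not subclass anything; it only needs matching methods and attributes. Note that the runtime `isinstance` check verifies only that the names exist, not their signatures or return types, which is why static type checking is still worth enabling.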

The __getitem__ method is the most complex; it returns a tuple consisting of the image ArrayLike, the target label, and the datum metadata. The image ArrayLike has shape (channel, height, width). The protocol specifies neither the dtype nor the value range of the image, so you must ensure that consumers of the Dataset, such as augmentations or models, expect the dtype and values you provide. The target is a one-hot encoding of the class label: a vector whose length is the number of classes, with a 1 at the label's index and 0 elsewhere. The datum metadata is a dictionary that includes at least an id, of type str or int, for the datum.
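The two conversions that __getitem__ performs, reordering image axes and one-hot encoding the label, can be checked on a tiny synthetic example. The 2x2 image below is made up for illustration; `np.array` on a PIL RGB image likewise yields an (H, W, C) array.

```python
import numpy as np

# Synthetic 2x2 RGB image in (H, W, C) layout, like np.array(pil_image) returns
img_hwc = np.arange(2 * 2 * 3, dtype=np.uint8).reshape(2, 2, 3)

# Move the channel axis to the front: (H, W, C) -> (C, H, W)
img_chw = img_hwc.transpose(2, 0, 1)
print(img_chw.shape)  # → (3, 2, 2)

# Channel 0 of the CHW array now holds the red value of every pixel (0, 3, 6, 9)
print(img_chw[0])

# One-hot target for class 3 out of 10: take row 3 of the identity matrix
num_classes = 10
target = np.eye(num_classes)[3]
print(target)  # → [0. 0. 0. 1. 0. 0. 0. 0. 0. 0.]
```

Taking a row of `np.eye(num_classes)` is the same trick the wrapper uses; the wrapper additionally calls `.copy()` so that callers cannot mutate its shared identity matrix through the returned target.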

# Extend DatasetMetadata to record whether dataset is a train or test split
class CustomDatasetMetadata(DatasetMetadata):
    split: str

class Cifar10Dataset:
    def __init__(
        self,
        cifar10_dataset_split: datasets.Dataset,
        split: Literal["train", "test"]
    ):
        # Save the CIFAR-10 dataset given by the user. This is helpful if you want to
        # sample the dataset using the Hugging Face API prior to using it.
        self.dataset = cifar10_dataset_split

        # Create a dictionary mapping label number to label name from the label metadata
        # in the underlying dataset.
        index2label = {
            i: label for i, label in enumerate(self.dataset.features["label"].names)
        }

        # Create required metadata attribute (with custom split key)
        self.metadata: DatasetMetadata = CustomDatasetMetadata(
            id="CIFAR-10",
            index2label=index2label,
            split=split
        )

        # Get the number of classes used in the dataset
        num_classes = self.dataset.features["label"].num_classes

        # Create a collection of target vectors to be used for the one-hot encoding of labels
        self.targets = np.eye(num_classes)

    def __len__(self) -> int:
        return len(self.dataset)

    def __getitem__(
        self, index: int
    ) -> tuple[npt.NDArray, npt.NDArray, ic.DatumMetadataType]:
        # Look up item in the dataset, which returns a dictionary with two keys:
        # - "img": <PIL.PngImagePlugin.PngImageFile image mode=RGB size=32x32>,
        # - "label": int
        item = self.dataset[index]
        img_pil = item["img"]
        label = item["label"]

        # Convert the PIL image to a NumPy array
        img_hwc = np.array(img_pil)  # shape (H, W, C)

        # Use MAITE array index convention for representing images: shape (C, H, W)
        img_chw = img_hwc.transpose(2, 0, 1)

        # Get one-hot encoded tensor indicating the class label for this image
        target = self.targets[label, :].copy()

        # CIFAR-10 does not have any extra metadata, so we record only the index of this datum
        metadata: ic.DatumMetadataType = {"id": index}

        return img_chw, target, metadata
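As written, Cifar10Dataset returns uint8 images with values in [0, 255]. If a downstream model instead expects float32 values in [0, 1], one option is a thin adapter around the wrapper. The sketch below is a hypothetical `ScaledImageDataset` (not part of MAITE), exercised with a stand-in `_FakeBase` so it runs without downloading CIFAR-10; in the notebook you would pass a Cifar10Dataset instance instead.

```python
import numpy as np


class ScaledImageDataset:
    """Hypothetical adapter: wraps a dataset that returns uint8 (C, H, W) images
    and rescales them to float32 values in [0, 1]. Targets and metadata pass through."""

    def __init__(self, base_dataset):
        self.base = base_dataset
        self.metadata = base_dataset.metadata  # reuse the dataset-level metadata

    def __len__(self) -> int:
        return len(self.base)

    def __getitem__(self, index: int):
        img_chw, target, datum_metadata = self.base[index]
        img_float = img_chw.astype(np.float32) / 255.0  # uint8 [0, 255] -> float32 [0, 1]
        return img_float, target, datum_metadata


class _FakeBase:
    """Stand-in for Cifar10Dataset so the sketch runs offline."""

    metadata = {"id": "fake"}

    def __len__(self) -> int:
        return 1

    def __getitem__(self, index: int):
        img = np.full((3, 32, 32), 255, dtype=np.uint8)  # an all-white image
        return img, np.eye(10)[0], {"id": index}


scaled = ScaledImageDataset(_FakeBase())
img, target, md = scaled[0]
print(img.dtype, img.max())  # → float32 1.0
```

Keeping the rescaling in a separate adapter leaves Cifar10Dataset faithful to the source data and makes the dtype and value-range choice explicit at the point where a particular model is plugged in.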

4 Examine the MAITE-wrapped dataset

Wrap the train split and print its metadata:

train_dataset: ic.Dataset = Cifar10Dataset(cifar10_dataset_split=cifar10_train, split="train")
train_dataset.metadata
{'id': 'CIFAR-10',
 'index2label': {0: 'airplane',
  1: 'automobile',
  2: 'bird',
  3: 'cat',
  4: 'deer',
  5: 'dog',
  6: 'frog',
  7: 'horse',
  8: 'ship',
  9: 'truck'},
 'split': 'train'}
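The printed metadata is a plain dict that carries the extra "split" key, which is consistent with DatasetMetadata being a TypedDict that CustomDatasetMetadata extends by subclassing. The sketch below shows that pattern with a hypothetical stand-in, `DatasetMetadataStandIn`, reduced to the two keys this guide uses; MAITE's real DatasetMetadata definition may differ (for example, index2label is described as optional there, while the stand-in makes it required for simplicity).

```python
from typing import Dict, TypedDict


# Hypothetical stand-in for maite.protocols.DatasetMetadata (not MAITE's definition)
class DatasetMetadataStandIn(TypedDict):
    id: str
    index2label: Dict[int, str]


# Same pattern as CustomDatasetMetadata: subclassing adds a required "split" key
class SplitMetadata(DatasetMetadataStandIn):
    split: str


md = SplitMetadata(id="CIFAR-10", index2label={0: "airplane", 1: "automobile"}, split="train")
print(type(md))  # → <class 'dict'>
print(md["split"])  # → train
```

Since a TypedDict instance is an ordinary dict at runtime, the extra key costs nothing for code that only reads "id" and "index2label", while a static type checker still knows that "split" exists.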

Wrap the test split and print its metadata:

test_dataset: ic.Dataset = Cifar10Dataset(cifar10_dataset_split=cifar10_test, split="test")
test_dataset.metadata
{'id': 'CIFAR-10',
 'index2label': {0: 'airplane',
  1: 'automobile',
  2: 'bird',
  3: 'cat',
  4: 'deer',
  5: 'dog',
  6: 'frog',
  7: 'horse',
  8: 'ship',
  9: 'truck'},
 'split': 'test'}

Check the length of the datasets:

print(f"CIFAR-10 size: train={len(train_dataset)}, test={len(test_dataset)}")
CIFAR-10 size: train=50000, test=10000

Examine some of the data points in the wrapped dataset:

def print_datum(dataset, index):
    img_arr_chw, target, datum_metadata = dataset[index]
    print(f"Datum {datum_metadata['id']}")
    # Show only the first 30 characters of the array's string form
    print(f"  Input Image Array: {str(img_arr_chw)[:30]}...")
    print(f"    shape={img_arr_chw.shape}")
    print(f"    dtype={img_arr_chw.dtype}")
    # PIL expects (H, W, C), so move the channel axis back to the end for display
    display(PIL.Image.fromarray(img_arr_chw.transpose(1, 2, 0)))
    print(f"  Target: {target}")
    # Recover the class index from the one-hot target and look up its name
    label_index = int(np.argmax(target))
    print(f"    target index: {label_index}")
    print(f"    target label: {dataset.metadata['index2label'][label_index]}")
    print("  Metadata:")
    print(f"    {datum_metadata}")

for i in [0, 3000, 6000]:
    print_datum(test_dataset, i)
    print()
Datum 0
  Input Image Array: [[[158 159 165 ... 137 126 116...
    shape=(3, 32, 32)
    dtype=uint8
[displayed image]
  Target: [0. 0. 0. 1. 0. 0. 0. 0. 0. 0.]
    target index: 3
    target label: cat
  Metadata:
    {'id': 0}

Datum 3000
  Input Image Array: [[[17 17 18 ... 23 21 20]
  [2...
    shape=(3, 32, 32)
    dtype=uint8
[displayed image]
  Target: [0. 0. 0. 0. 0. 1. 0. 0. 0. 0.]
    target index: 5
    target label: dog
  Metadata:
    {'id': 3000}

Datum 6000
  Input Image Array: [[[201 211 211 ... 191 183 188...
    shape=(3, 32, 32)
    dtype=uint8
[displayed image]
  Target: [0. 0. 0. 0. 0. 0. 0. 0. 1. 0.]
    target index: 8
    target label: ship
  Metadata:
    {'id': 6000}

5 Conclusion

We’ve successfully wrapped the source CIFAR-10 dataset as a MAITE-compliant image classification dataset. Note that:

  • We don’t observe any static type checking errors (assuming we’ve enabled static type checking in our IDE, e.g., by following the steps in the Enable Static Type Checking guide).

  • The wrapped dataset appears to work correctly: for the sampled data points, the image, ground-truth target label, and metadata are consistent with one another.

At this point, the wrapped dataset could be used as part of a larger test and evaluation workflow, e.g., by using MAITE’s evaluate workflow to compute the accuracy of a model on the CIFAR-10 test split.
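For reference, the accuracy computation itself reduces to comparing the argmax of the model's class scores with the argmax of each one-hot target. The sketch below shows only that arithmetic with a hypothetical function, `accuracy_from_one_hot`; it is not MAITE's evaluate workflow, and the scores are made up for illustration rather than produced by a real model.

```python
import numpy as np


def accuracy_from_one_hot(predicted_scores: np.ndarray, one_hot_targets: np.ndarray) -> float:
    """Fraction of rows whose highest-scoring class matches the one-hot target.

    Illustrative only; not MAITE's evaluate workflow. Both arrays have shape (N, num_classes).
    """
    predicted = np.argmax(predicted_scores, axis=1)
    actual = np.argmax(one_hot_targets, axis=1)
    return float(np.mean(predicted == actual))


# Made-up scores for 4 images over 3 classes (not real CIFAR-10 model output)
scores = np.array([
    [0.7, 0.2, 0.1],  # predicts class 0
    [0.1, 0.8, 0.1],  # predicts class 1
    [0.3, 0.3, 0.4],  # predicts class 2
    [0.6, 0.3, 0.1],  # predicts class 0
])
targets = np.eye(3)[[0, 1, 2, 1]]  # true classes 0, 1, 2, 1
print(accuracy_from_one_hot(scores, targets))  # → 0.75
```

Three of the four predictions match, giving 0.75. Because the wrapped dataset stores targets as one-hot vectors, the same argmax comparison applies directly to its outputs.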