::

   Copyright 2025, MASSACHUSETTS INSTITUTE OF TECHNOLOGY
   Subject to FAR 52.227-11 – Patent Rights – Ownership by the Contractor (May 2014).
   SPDX-License-Identifier: MIT

Wrap an Image Classification Dataset
====================================

In this how-to, we will show you how to wrap the `CIFAR-10 `__ dataset as a
`maite.protocols.image_classification.Dataset <../generated/maite.protocols.image_classification.Dataset.html>`__.

CIFAR-10 is an image classification dataset created by Alex Krizhevsky, Vinod Nair, and
Geoffrey Hinton. It is available through both `Hugging Face Hub `__ and
`torchvision.datasets `__; this notebook uses Hugging Face Hub.

The general steps for wrapping a dataset are:

- Understand the source (native) dataset
- Create a wrapper that makes the source dataset conform to the MAITE dataset protocol (interface)
- Verify that the wrapped dataset works correctly

1 Load the CIFAR-10 dataset from Hugging Face Hub
-------------------------------------------------

Load the necessary modules:

.. code:: ipython3

    import PIL.Image
    import datasets
    import numpy as np
    import numpy.typing as npt
    from IPython.display import display
    from typing import Literal

    import maite.protocols.image_classification as ic
    from maite.protocols import DatasetMetadata

    %load_ext watermark
    %watermark -iv -v


.. parsed-literal::

    Python implementation: CPython
    Python version       : 3.9.21
    IPython version      : 8.18.1

    maite   : 0.7.3
    IPython : 8.18.1
    PIL     : 11.1.0
    numpy   : 1.26.4
    datasets: 3.3.1


Load the CIFAR-10 dataset from Hugging Face Hub:

.. code:: ipython3

    cifar10_dataset_dict = datasets.load_dataset("uoft-cs/cifar10")
    cifar10_dataset_dict

We get a ``DatasetDict`` object containing a 50000-image train dataset and a
10000-image test dataset.

.. code:: ipython3

    cifar10_train: datasets.Dataset = cifar10_dataset_dict["train"]  # type: ignore
    cifar10_test: datasets.Dataset = cifar10_dataset_dict["test"]  # type: ignore

2 Examine the source dataset
----------------------------

In this section we examine the dataset and confirm we understand it before wrapping it.

Print all the labels (class names) in the dataset:

.. code:: ipython3

    for label_number, label_name in enumerate(cifar10_train.features["label"].names):
        print(f"Label {label_number}: {label_name}")


.. parsed-literal::

    Label 0: airplane
    Label 1: automobile
    Label 2: bird
    Label 3: cat
    Label 4: deer
    Label 5: dog
    Label 6: frog
    Label 7: horse
    Label 8: ship
    Label 9: truck


Spot-check items in the dataset:

.. code:: ipython3

    for i in range(0, len(cifar10_train), 10000):
        item = cifar10_train[i]
        label, img = item["label"], item["img"]
        label_name = cifar10_train.features["label"].names[label]
        print(f"CIFAR-10 Train {i}")
        print(f"Label: {label} {label_name}")
        display(img)


.. parsed-literal::

    CIFAR-10 Train 0
    Label: 0 airplane

.. image:: wrap_image_classification_dataset_files/wrap_image_classification_dataset_13_1.png

.. parsed-literal::

    CIFAR-10 Train 10000
    Label: 8 ship

.. image:: wrap_image_classification_dataset_files/wrap_image_classification_dataset_13_3.png

.. parsed-literal::

    CIFAR-10 Train 20000
    Label: 1 automobile

.. image:: wrap_image_classification_dataset_files/wrap_image_classification_dataset_13_5.png

.. parsed-literal::

    CIFAR-10 Train 30000
    Label: 1 automobile

.. image:: wrap_image_classification_dataset_files/wrap_image_classification_dataset_13_7.png

.. parsed-literal::

    CIFAR-10 Train 40000
    Label: 6 frog

.. image:: wrap_image_classification_dataset_files/wrap_image_classification_dataset_13_9.png

Everything appears as expected: CIFAR-10 images are 32x32 RGB, and the labels correctly
correspond to the images.
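As an additional, optional check, we could also count how many training images fall into
each class; CIFAR-10 is balanced, with 5,000 training images per class. The cell below is
a minimal sketch that relies only on the Hugging Face ``datasets`` column access used
above:

.. code:: ipython3

    # Optional sketch: count training images per class to confirm the split is balanced.
    label_names = cifar10_train.features["label"].names
    label_column = np.array(cifar10_train["label"])  # all label integers for this split
    counts = np.bincount(label_column, minlength=len(label_names))
    for name, count in zip(label_names, counts):
        print(f"{name:>10}: {count}")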
3 Create a MAITE wrapper for the source dataset
-----------------------------------------------

We create a class that implements the
`maite.protocols.image_classification.Dataset <../generated/maite.protocols.image_classification.Dataset.html>`__
protocol. A MAITE-compliant image classification dataset must provide the following two
methods:

- ``__getitem__(index: int)``: Returns a tuple containing the image ``ArrayLike``, the target label, and the datum metadata.
- ``__len__()``: Returns the number of data elements in the dataset.

and an attribute:

- ``metadata``: Contains an ``id`` string and an optional (but recommended) mapping from class indexes to label names.

The ``__getitem__`` method is the most complex; it returns a tuple consisting of the
image ``ArrayLike``, the target label, and the datum metadata.

The image `ArrayLike <../explanation/protocol_overview.html#concept-bridging-arraylikes>`__
has the shape (channel, height, width). Neither the ``dtype`` nor the value range of the
image ``ArrayLike`` is specified by the protocol, so you must ensure that users of the
dataset, such as augmentations or models, expect the ``dtype`` and values you provide.

The target is a one-hot encoding of the class label.

The datum metadata is a dictionary that includes at least an ``id``, of type ``str`` or
``int``, for the datum.

.. code:: ipython3

    # Extend DatasetMetadata to record whether dataset is a train or test split
    class CustomDatasetMetadata(DatasetMetadata):
        split: str


    class Cifar10Dataset:
        def __init__(
            self, cifar10_dataset_split: datasets.Dataset, split: Literal["train", "test"]
        ):
            # Save the CIFAR-10 dataset given by the user. This is helpful if you want to
            # sample the dataset using the Hugging Face API prior to using it.
            self.dataset = cifar10_dataset_split

            # Create a dictionary mapping label number to label name from the label metadata
            # in the underlying dataset.
            index2label = {
                i: label for i, label in enumerate(self.dataset.features["label"].names)
            }

            # Create required metadata attribute (with custom split key)
            self.metadata: DatasetMetadata = CustomDatasetMetadata(
                id="CIFAR-10", index2label=index2label, split=split
            )

            # Get the number of classes used in the dataset
            num_classes = self.dataset.features["label"].num_classes

            # Create a collection of target vectors to be used for the one-hot encoding of labels
            self.targets = np.eye(num_classes)

        def __len__(self) -> int:
            return len(self.dataset)

        def __getitem__(
            self, index: int
        ) -> tuple[npt.NDArray, npt.NDArray, ic.DatumMetadataType]:
            # Look up item in the dataset, which returns a dictionary with two keys:
            # - "img": <PIL.PngImagePlugin.PngImageFile image mode=RGB size=32x32>
            # - "label": int
            item = self.dataset[index]
            img_pil = item["img"]
            label = item["label"]

            # Convert the PIL image to a NumPy array
            img_hwc = np.array(img_pil)  # shape (H, W, C)

            # Use MAITE array index convention for representing images: shape (C, H, W)
            img_chw = img_hwc.transpose(2, 0, 1)

            # Get one-hot encoded vector indicating the class label for this image
            target = self.targets[label, :].copy()

            # CIFAR-10 does not have any extra metadata, so we record only the index of this datum
            metadata: ic.DatumMetadataType = {"id": index}

            return img_chw, target, metadata
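As noted above, the protocol does not fix the image ``dtype`` or value range; the wrapper
above returns ``uint8`` values in [0, 255]. If a downstream model or augmentation expected
``float32`` images in [0, 1] instead, the conversion inside ``__getitem__`` could be
adapted along the lines of this minimal sketch (the helper name is hypothetical and is not
used in the rest of this how-to):

.. code:: ipython3

    # Sketch only: an alternative image conversion for consumers that expect
    # float32 images scaled to [0, 1]. Cifar10Dataset above keeps native uint8 values.
    def to_float_chw(img_pil: PIL.Image.Image) -> npt.NDArray:
        img_hwc = np.array(img_pil)                # shape (H, W, C), dtype uint8
        img_chw = img_hwc.transpose(2, 0, 1)       # MAITE convention: shape (C, H, W)
        return img_chw.astype(np.float32) / 255.0  # float32 values in [0, 1]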
4 Examine the MAITE-wrapped dataset
-----------------------------------

Wrap the train dataset and print its metadata:

.. code:: ipython3

    train_dataset: ic.Dataset = Cifar10Dataset(cifar10_dataset_split=cifar10_train, split="train")
    train_dataset.metadata


.. parsed-literal::

    {'id': 'CIFAR-10',
     'index2label': {0: 'airplane', 1: 'automobile', 2: 'bird', 3: 'cat', 4: 'deer',
                     5: 'dog', 6: 'frog', 7: 'horse', 8: 'ship', 9: 'truck'},
     'split': 'train'}


Wrap the test dataset using the ``test`` split:

.. code:: ipython3

    test_dataset: ic.Dataset = Cifar10Dataset(cifar10_dataset_split=cifar10_test, split="test")
    test_dataset.metadata


.. parsed-literal::

    {'id': 'CIFAR-10',
     'index2label': {0: 'airplane', 1: 'automobile', 2: 'bird', 3: 'cat', 4: 'deer',
                     5: 'dog', 6: 'frog', 7: 'horse', 8: 'ship', 9: 'truck'},
     'split': 'test'}


Check the length of the datasets:

.. code:: ipython3

    print(f"CIFAR-10 size: train={len(train_dataset)}, test={len(test_dataset)}")


.. parsed-literal::

    CIFAR-10 size: train=50000, test=10000


Examine some of the data points in the wrapped dataset:

.. code:: ipython3

    def print_datum(dataset, index):
        img_arr_chw, target, datum_metadata = dataset[index]

        print(f"Datum {datum_metadata['id']}")

        print(f"  Input Image Array: {str(img_arr_chw)[:30]}...")
        print(f"    shape={img_arr_chw.shape}")
        print(f"    dtype={img_arr_chw.dtype}")
        display(PIL.Image.fromarray(img_arr_chw.transpose(1, 2, 0)))

        print(f"  Target: {target}")
        label_index = np.argmax(target)
        print(f"    target index: {label_index}")
        print(f"    target label: {dataset.metadata['index2label'][label_index]}")

        print("  Metadata:")
        print(f"    {datum_metadata}")

.. code:: ipython3

    for i in [0, 3000, 6000]:
        print_datum(test_dataset, i)
        print()


.. parsed-literal::

    Datum 0
      Input Image Array: [[[158 159 165 ... 137 126 116...
        shape=(3, 32, 32)
        dtype=uint8

.. image:: wrap_image_classification_dataset_files/wrap_image_classification_dataset_27_1.png

.. parsed-literal::

      Target: [0. 0. 0. 1. 0. 0. 0. 0. 0. 0.]
        target index: 3
        target label: cat
      Metadata:
        {'id': 0}

    Datum 3000
      Input Image Array: [[[17 17 18 ... 23 21 20] [2...
        shape=(3, 32, 32)
        dtype=uint8

.. image:: wrap_image_classification_dataset_files/wrap_image_classification_dataset_27_3.png

.. parsed-literal::

      Target: [0. 0. 0. 0. 0. 1. 0. 0. 0. 0.]
        target index: 5
        target label: dog
      Metadata:
        {'id': 3000}

    Datum 6000
      Input Image Array: [[[201 211 211 ... 191 183 188...
        shape=(3, 32, 32)
        dtype=uint8

.. image:: wrap_image_classification_dataset_files/wrap_image_classification_dataset_27_5.png

.. parsed-literal::

      Target: [0. 0. 0. 0. 0. 0. 0. 0. 1. 0.]
        target index: 8
        target label: ship
      Metadata:
        {'id': 6000}


5 Conclusion
------------

We've successfully wrapped the source CIFAR-10 dataset as a MAITE-compliant image
classification dataset. Note that:

- We don't observe any static type checking errors (assuming we've enabled static type
  checking in our IDE, e.g., by following the steps in the
  `Enable Static Type Checking <./static_typing.html>`__ guide).
- The wrapped dataset appears to be working correctly, e.g., the image, ground truth
  target label, and metadata are consistent.

At this point, the wrapped dataset could be used as part of a larger test & evaluation
workflow, e.g., by using MAITE's `evaluate <../generated/maite.workflows.evaluate.html>`__
workflow to compute the accuracy of a model on the CIFAR-10 test split.
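As a final, optional check before plugging the wrapped dataset into such a workflow, we
could programmatically verify a few protocol-level invariants across a sample of the test
split. The loop below is a minimal sketch that uses only the wrapper defined in this
how-to:

.. code:: ipython3

    # Optional sketch: verify shapes, one-hot targets, and metadata ids on a sample
    # of the wrapped test dataset.
    num_classes = len(test_dataset.metadata["index2label"])
    for i in range(0, len(test_dataset), 1000):
        img_chw, target, datum_metadata = test_dataset[i]
        assert img_chw.shape == (3, 32, 32)    # (C, H, W) per MAITE convention
        assert target.shape == (num_classes,)  # one-hot target vector
        assert target.sum() == 1.0             # exactly one class is marked
        assert datum_metadata["id"] == i       # id matches the dataset index
    print("All sampled datums passed the checks.")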